add message listing function and add websockets to get notification about messages in chat

master
Ernest Litvinenko 2024-05-20 13:15:56 +03:00
parent 5ce8c9026c
commit 6660180be4
13 changed files with 241 additions and 38 deletions

View File

@ -1,8 +1,11 @@
from datetime import datetime
from typing import Literal
from fastapi import APIRouter, Depends, Response
from core.helpers.auth.helpers import get_current_user
from core.models.message.db import MPProfile
from core.models.message.requests import SendMessageRequest
from core.models.message.requests import SendMessageRequest, ListMessagesRequest
from core.services import message_service
router = APIRouter(prefix="/message", tags=["message"])
@ -14,13 +17,13 @@ async def send_message(response: Response, message: SendMessageRequest, user: MP
return (await message_service.send_message(user, message)).model_dump(exclude_none=True, by_alias=True)
async def list_messages():
pass
@router.get("")
async def list_messages(
chat_id: int,
from_date: datetime | None = None,
to_date: datetime | None = None,
order_by: Literal['desc'] | Literal['asc'] = 'desc',
async def delete_message():
pass
async def edit_message():
pass
user: MPProfile = Depends(get_current_user)):
query = ListMessagesRequest(chat_id=chat_id, from_date=from_date, to_date=to_date, order_by=order_by)
return await message_service.list_messages(user, query)

View File

@ -10,7 +10,10 @@ from core.errors.errors import not_authenticated_error
async def get_current_user(token: str = Header(alias="Authorization")) -> MPProfile:
"""Get the current user."""
try:
token = token.split("Bearer ")[1]
token = token.split("Bearer ")
if len(token) != 2:
raise not_authenticated_error()
token = token[1]
token_data = jwt.decode(token, Config.secret, algorithms=["HS256"])
user = await auth_storage.get_user_by_id(token_data["user_id"])
if user:

View File

@ -27,7 +27,7 @@ class MPMessage(Base):
__tablename__ = 'mp_message'
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
sender_id: Mapped[int] = mapped_column(ForeignKey('mp_profile.id'))
chat_id: Mapped[int] = mapped_column(ForeignKey('mp_chat.id'))
chat_id: Mapped[int] = mapped_column(ForeignKey('mp_chat.id', ondelete='CASCADE'))
content: Mapped[str]
chat: Mapped[MPChat] = relationship("MPChat", back_populates="messages")
sender: Mapped[MPProfile] = relationship("MPProfile", back_populates="messages")

View File

@ -1,3 +1,6 @@
import datetime
from typing import Literal
from pydantic import BaseModel
@ -8,3 +11,10 @@ class CreateChatRequest(BaseModel):
class SendMessageRequest(BaseModel):
chat_id: int
content: str
class ListMessagesRequest(BaseModel):
chat_id: int
from_date: datetime.datetime | None = None
to_date: datetime.datetime | None = None
order_by: Literal['desc'] | Literal['asc'] = 'desc'

View File

@ -23,7 +23,7 @@ class ChatResponse(BaseModel):
modified_at: datetime.datetime
class MessageResponse(BaseModel):
class MessageDetailResponse(BaseModel):
model_config = ConfigDict(alias_generator=AliasGenerator(validation_alias=to_snake,
serialization_alias=to_camel))
id: int
@ -32,3 +32,22 @@ class MessageResponse(BaseModel):
chat: ChatResponse
created_at: datetime.datetime
modified_at: datetime.datetime
class MessageShortResponse(BaseModel):
model_config = ConfigDict(alias_generator=AliasGenerator(validation_alias=to_snake,
serialization_alias=to_camel))
id: int
sender: ProfileResponse
content: str
created_at: datetime.datetime
modified_at: datetime.datetime
class ListMessageResponse(BaseModel):
model_config = ConfigDict(alias_generator=AliasGenerator(validation_alias=to_snake,
serialization_alias=to_camel))
messages: list[MessageShortResponse]
total: int
from_date: datetime.datetime
to_date: datetime.datetime

View File

@ -1,17 +1,18 @@
from core.errors.errors import not_a_member_of_chat
from core.models.message.db import MPProfile, MPMessage
from core.models.message.requests import SendMessageRequest
from core.models.message.responses import MessageResponse, ProfileResponse, ChatResponse
from core.storage import message_storage, auth_storage, chat_storage
from core.models.message.db import MPProfile, MPMessage, MPChat
from core.models.message.requests import SendMessageRequest, ListMessagesRequest
from core.models.message.responses import MessageDetailResponse, ProfileResponse, ChatResponse
from core.storage import message_storage, chat_storage
from core.ws.handlers import connection_manager
class Service:
async def build_message_response(self, msg: MPMessage) -> MessageResponse:
async def build_message_response(self, msg: MPMessage) -> MessageDetailResponse:
"""
Build a message response
"""
return MessageResponse(
return MessageDetailResponse(
id=msg.id,
sender=ProfileResponse(id=msg.sender.id, external_id=msg.sender.external_id,
created_at=msg.sender.created_at, modified_at=msg.sender.modified_at),
@ -28,7 +29,14 @@ class Service:
modified_at=msg.modified_at
)
async def send_message(self, user: MPProfile, message: SendMessageRequest) -> MessageResponse:
async def notify_users_in_chat(self, chat: MPChat, message: MPMessage):
"""
Notify all users in chat
"""
for user in chat.users:
await connection_manager.send_personal_message_by_user_id((await self.build_message_response(message)).model_dump_json(), user.id)
async def send_message(self, user: MPProfile, message: SendMessageRequest) -> MessageDetailResponse:
"""
Send message to chat
"""
@ -43,4 +51,22 @@ class Service:
# Add message to Database
msg = await message_storage.insert_message(user.id, message.chat_id, message.content)
await self.notify_users_in_chat(chat, msg)
return await self.build_message_response(msg)
async def list_messages(self, user: MPProfile, query: ListMessagesRequest):
"""
List messages in chat
"""
# Check chat exists
chat = await chat_storage.get_chat(query.chat_id)
# Check user is in chat
if user.id not in [x.id for x in chat.users]:
not_a_member_of_chat()
# Get messages from Database
messages = await message_storage.list_messages(query.chat_id, query.from_date, query.to_date)
return [await self.build_message_response(msg) for msg in messages]

View File

@ -1,4 +1,6 @@
from sqlalchemy import insert, select
import datetime
from sqlalchemy import insert, select, and_, desc, asc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, subqueryload
@ -20,3 +22,23 @@ class Storage(BaseStorage):
result = await session.execute(stmt)
data = result.scalar_one_or_none()
return data
async def list_messages(self, chat_id: int, from_date: datetime.datetime, to_date: datetime.datetime,
order: str = "desc") -> list[MPMessage]:
async with self.get_session() as session:
session: AsyncSession
order_type = desc if order == "desc" else asc
selection_query = MPMessage.chat_id == chat_id
if from_date:
selection_query = and_(selection_query, MPMessage.created_at >= from_date)
if to_date:
selection_query = and_(selection_query, MPMessage.created_at <= to_date)
stmt = select(MPMessage).options(joinedload(MPMessage.sender), joinedload(MPMessage.chat).joinedload(MPChat.admin)).where(
selection_query).order_by(order_type(MPMessage.created_at)).limit(100)
result = await session.execute(stmt)
data = result.scalars().all()
return data

0
core/ws/__init__.py Normal file
View File

80
core/ws/handlers.py Normal file
View File

@ -0,0 +1,80 @@
"""
This package for handling the websocket connections and messages between the server and the client.
Write here the code for routing the messages between the server and the client.
"""
import json
import typing
from enum import Enum
import pydantic
from fastapi import WebSocket
from pydantic import BaseModel
from core.errors.errors import CustomExceptionHandler
from core.helpers.auth.helpers import get_current_user
class Message(BaseModel):
action: 'AvailableActions'
data: dict
class AvailableActions(str, Enum):
AUTH = 'auth'
class ConnectionManager:
active_connections: list[WebSocket] = []
token_connections: dict[int, WebSocket] = {}
def __init__(self):
# We can use this dictionary to route the messages to the correct method
self._action_routing: dict[AvailableActions, typing.Callable[[Message, WebSocket], typing.Coroutine]] = {
AvailableActions.AUTH: self.authorize
}
async def connect(self, ws: WebSocket):
await ws.accept()
self.active_connections.append(ws)
def disconnect(self, ws: WebSocket):
self.token_connections = {k: v for k, v in self.token_connections.items() if v != ws}
self.active_connections.remove(ws)
async def send_personal_message(self, message: str, ws: WebSocket):
await ws.send_text(message)
async def broadcast(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
async def authorize(self, message: Message, ws: WebSocket):
token = message.data.get('token')
if token is None:
await self.send_personal_message('{"error": "Token is required"}', ws)
return
try:
user = await get_current_user(token)
self.token_connections[user.id] = ws
msg = {"message": "Successfully authorized", "user_id": user.id}
await self.send_personal_message(json.dumps(msg), ws)
except CustomExceptionHandler as exc:
await self.send_personal_message(exc.response.model_dump_json(), ws)
async def send_personal_message_by_user_id(self, message: str, user_id: int):
ws = self.token_connections.get(user_id)
if ws is not None:
await ws.send_text(message)
async def route_message(self, message: str, ws: WebSocket):
try:
message = Message.parse_raw(message)
await self._action_routing[message.action](message, ws)
except pydantic.ValidationError:
await self.send_personal_message('{"error": "Invalid message format"}', ws)
connection_manager: ConnectionManager = ConnectionManager()

16
main.py
View File

@ -1,7 +1,11 @@
from datetime import datetime, timezone
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, WebSocket
from starlette.websockets import WebSocketDisconnect
from core.ws.handlers import connection_manager
from loguru import logger
from starlette.middleware.cors import CORSMiddleware
@ -39,5 +43,15 @@ async def custom_exception_handler(_, exc):
return exc.result()
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await connection_manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
await connection_manager.route_message(data, websocket)
except WebSocketDisconnect:
connection_manager.disconnect(websocket)
if __name__ == '__main__':
uvicorn.run('main:app', host=str(Config.host), port=Config.port, reload=True)

View File

@ -9,4 +9,5 @@ alembic
pytest
httpx
pyjwt
trio
pydantic==2.7.1
Faker

View File

@ -4,6 +4,8 @@
#
# pip-compile
#
pydantic==2.7.1
websockets==12.0
alembic==1.13.0
# via -r requirements.in
annotated-types==0.6.0
@ -15,10 +17,6 @@ anyio==3.7.1
# starlette
asyncpg==0.29.0
# via -r requirements.in
attrs==23.2.0
# via
# outcome
# trio
certifi==2024.2.2
# via
# httpcore
@ -29,6 +27,8 @@ dnspython==2.4.2
# via email-validator
email-validator==2.1.0.post1
# via pydantic
faker==25.2.0
# via -r requirements.in
fastapi==0.104.1
# via -r requirements.in
greenlet==3.0.2
@ -46,7 +46,6 @@ idna==3.6
# anyio
# email-validator
# httpx
# trio
iniconfig==2.0.0
# via pytest
loguru==0.7.2
@ -55,18 +54,16 @@ mako==1.3.0
# via alembic
markupsafe==2.1.3
# via mako
outcome==1.3.0.post0
# via trio
packaging==24.0
# via pytest
pluggy==1.5.0
# via pytest
pydantic[email]==2.5.2
pydantic[email]==2.7.1
# via
# -r requirements.in
# fastapi
# pydantic-settings
pydantic-core==2.14.5
pydantic-core==2.18.2
# via pydantic
pydantic-settings==2.1.0
# via -r requirements.in
@ -74,23 +71,22 @@ pyjwt==2.8.0
# via -r requirements.in
pytest==8.2.0
# via -r requirements.in
python-dateutil==2.9.0.post0
# via faker
python-dotenv==1.0.0
# via pydantic-settings
six==1.16.0
# via python-dateutil
sniffio==1.3.0
# via
# anyio
# httpx
# trio
sortedcontainers==2.4.0
# via trio
sqlalchemy[asyncio]==2.0.23
# via
# -r requirements.in
# alembic
starlette==0.27.0
# via fastapi
trio==0.25.1
# via -r requirements.in
typing-extensions==4.9.0
# via
# alembic

View File

@ -1,5 +1,9 @@
import threading
import time
import pytest
import uvicorn
from faker import Faker
from main import app
from httpx import Client
@ -42,3 +46,28 @@ def test__message_send_without_chat():
"chat_id": -1,
"content": "test__message_send_without_chat"}, headers={'Authorization': f'Bearer {token}'})
assert resp.status_code == 404
@pytest.mark.skip(reason="Stress test. Should be run separately")
def test_stress():
# Complete authorization
resp = client.post('/api/v1/auth', json={'username': 1})
fake = Faker()
assert resp.status_code == 200
token = resp.json()['access_token']
chat_id = client.get('/api/v1/chat', headers={'Authorization': f'Bearer {token}'}).json()[0]["id"]
for i in range(10000):
text = fake.text()
start = time.time() * 1000
resp = client.post('/api/v1/message', json={
"chat_id": chat_id,
"content": text}, headers={'Authorization': f'Bearer {token}'})
_ = client.get(f'/api/v1/message', headers={'Authorization': f'Bearer {token}'}, params={"chat_id": chat_id})
end = time.time() * 1000
print(f"Time: {end - start}")
assert end - start < 100
assert resp.status_code == 201
assert resp.json()["content"] == text
assert resp.json()["id"] > 0