diff --git a/core/api/message/handlers.py b/core/api/message/handlers.py index 5bb3902..95bfedb 100644 --- a/core/api/message/handlers.py +++ b/core/api/message/handlers.py @@ -1,8 +1,11 @@ +from datetime import datetime +from typing import Literal + from fastapi import APIRouter, Depends, Response from core.helpers.auth.helpers import get_current_user from core.models.message.db import MPProfile -from core.models.message.requests import SendMessageRequest +from core.models.message.requests import SendMessageRequest, ListMessagesRequest from core.services import message_service router = APIRouter(prefix="/message", tags=["message"]) @@ -14,13 +17,13 @@ async def send_message(response: Response, message: SendMessageRequest, user: MP return (await message_service.send_message(user, message)).model_dump(exclude_none=True, by_alias=True) -async def list_messages(): - pass +@router.get("") +async def list_messages( + chat_id: int, + from_date: datetime | None = None, + to_date: datetime | None = None, + order_by: Literal['desc'] | Literal['asc'] = 'desc', - -async def delete_message(): - pass - - -async def edit_message(): - pass + user: MPProfile = Depends(get_current_user)): + query = ListMessagesRequest(chat_id=chat_id, from_date=from_date, to_date=to_date, order_by=order_by) + return await message_service.list_messages(user, query) diff --git a/core/helpers/auth/helpers.py b/core/helpers/auth/helpers.py index b605d38..d9740b1 100644 --- a/core/helpers/auth/helpers.py +++ b/core/helpers/auth/helpers.py @@ -10,7 +10,10 @@ from core.errors.errors import not_authenticated_error async def get_current_user(token: str = Header(alias="Authorization")) -> MPProfile: """Get the current user.""" try: - token = token.split("Bearer ")[1] + token = token.split("Bearer ") + if len(token) != 2: + raise not_authenticated_error() + token = token[1] token_data = jwt.decode(token, Config.secret, algorithms=["HS256"]) user = await auth_storage.get_user_by_id(token_data["user_id"]) if user: diff --git a/core/models/message/db.py b/core/models/message/db.py index f10cc60..5ae3480 100644 --- a/core/models/message/db.py +++ b/core/models/message/db.py @@ -27,7 +27,7 @@ class MPMessage(Base): __tablename__ = 'mp_message' id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) sender_id: Mapped[int] = mapped_column(ForeignKey('mp_profile.id')) - chat_id: Mapped[int] = mapped_column(ForeignKey('mp_chat.id')) + chat_id: Mapped[int] = mapped_column(ForeignKey('mp_chat.id', ondelete='CASCADE')) content: Mapped[str] chat: Mapped[MPChat] = relationship("MPChat", back_populates="messages") sender: Mapped[MPProfile] = relationship("MPProfile", back_populates="messages") @@ -37,4 +37,4 @@ class MPChatUser(Base): __tablename__ = 'mp_chat_user' id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) user_id: Mapped[int] = mapped_column(ForeignKey('mp_profile.id', ondelete='CASCADE')) - chat_id: Mapped[int] = mapped_column(ForeignKey('mp_chat.id', ondelete='CASCADE')) \ No newline at end of file + chat_id: Mapped[int] = mapped_column(ForeignKey('mp_chat.id', ondelete='CASCADE')) diff --git a/core/models/message/requests.py b/core/models/message/requests.py index 7d8dcb9..5b9ef65 100644 --- a/core/models/message/requests.py +++ b/core/models/message/requests.py @@ -1,3 +1,6 @@ +import datetime +from typing import Literal + from pydantic import BaseModel @@ -8,3 +11,10 @@ class CreateChatRequest(BaseModel): class SendMessageRequest(BaseModel): chat_id: int content: str + + +class ListMessagesRequest(BaseModel): + chat_id: int + from_date: datetime.datetime | None = None + to_date: datetime.datetime | None = None + order_by: Literal['desc'] | Literal['asc'] = 'desc' diff --git a/core/models/message/responses.py b/core/models/message/responses.py index 68657b4..c202981 100644 --- a/core/models/message/responses.py +++ b/core/models/message/responses.py @@ -23,7 +23,7 @@ class ChatResponse(BaseModel): modified_at: datetime.datetime -class MessageResponse(BaseModel): +class MessageDetailResponse(BaseModel): model_config = ConfigDict(alias_generator=AliasGenerator(validation_alias=to_snake, serialization_alias=to_camel)) id: int @@ -32,3 +32,22 @@ class MessageResponse(BaseModel): chat: ChatResponse created_at: datetime.datetime modified_at: datetime.datetime + + +class MessageShortResponse(BaseModel): + model_config = ConfigDict(alias_generator=AliasGenerator(validation_alias=to_snake, + serialization_alias=to_camel)) + id: int + sender: ProfileResponse + content: str + created_at: datetime.datetime + modified_at: datetime.datetime + + +class ListMessageResponse(BaseModel): + model_config = ConfigDict(alias_generator=AliasGenerator(validation_alias=to_snake, + serialization_alias=to_camel)) + messages: list[MessageShortResponse] + total: int + from_date: datetime.datetime + to_date: datetime.datetime diff --git a/core/services/message/service.py b/core/services/message/service.py index 1ad7d88..7f8fbf2 100644 --- a/core/services/message/service.py +++ b/core/services/message/service.py @@ -1,17 +1,18 @@ from core.errors.errors import not_a_member_of_chat -from core.models.message.db import MPProfile, MPMessage -from core.models.message.requests import SendMessageRequest -from core.models.message.responses import MessageResponse, ProfileResponse, ChatResponse -from core.storage import message_storage, auth_storage, chat_storage +from core.models.message.db import MPProfile, MPMessage, MPChat +from core.models.message.requests import SendMessageRequest, ListMessagesRequest +from core.models.message.responses import MessageDetailResponse, ProfileResponse, ChatResponse +from core.storage import message_storage, chat_storage +from core.ws.handlers import connection_manager class Service: - async def build_message_response(self, msg: MPMessage) -> MessageResponse: + async def build_message_response(self, msg: MPMessage) -> MessageDetailResponse: """ Build a message response """ - return MessageResponse( + return MessageDetailResponse( id=msg.id, sender=ProfileResponse(id=msg.sender.id, external_id=msg.sender.external_id, created_at=msg.sender.created_at, modified_at=msg.sender.modified_at), @@ -28,7 +29,14 @@ class Service: modified_at=msg.modified_at ) - async def send_message(self, user: MPProfile, message: SendMessageRequest) -> MessageResponse: + async def notify_users_in_chat(self, chat: MPChat, message: MPMessage): + """ + Notify all users in chat + """ + for user in chat.users: + await connection_manager.send_personal_message_by_user_id((await self.build_message_response(message)).model_dump_json(), user.id) + + async def send_message(self, user: MPProfile, message: SendMessageRequest) -> MessageDetailResponse: """ Send message to chat """ @@ -43,4 +51,22 @@ class Service: # Add message to Database msg = await message_storage.insert_message(user.id, message.chat_id, message.content) + await self.notify_users_in_chat(chat, msg) return await self.build_message_response(msg) + + async def list_messages(self, user: MPProfile, query: ListMessagesRequest): + """ + List messages in chat + """ + + # Check chat exists + chat = await chat_storage.get_chat(query.chat_id) + + # Check user is in chat + if user.id not in [x.id for x in chat.users]: + not_a_member_of_chat() + + # Get messages from Database + messages = await message_storage.list_messages(query.chat_id, query.from_date, query.to_date) + + return [await self.build_message_response(msg) for msg in messages] diff --git a/core/storage/message/storage.py b/core/storage/message/storage.py index 2dc3a4f..2fcc54f 100644 --- a/core/storage/message/storage.py +++ b/core/storage/message/storage.py @@ -1,4 +1,6 @@ -from sqlalchemy import insert, select +import datetime + +from sqlalchemy import insert, select, and_, desc, asc from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, subqueryload @@ -20,3 +22,23 @@ class Storage(BaseStorage): result = await session.execute(stmt) data = result.scalar_one_or_none() return data + + async def list_messages(self, chat_id: int, from_date: datetime.datetime, to_date: datetime.datetime, + order: str = "desc") -> list[MPMessage]: + async with self.get_session() as session: + session: AsyncSession + order_type = desc if order == "desc" else asc + selection_query = MPMessage.chat_id == chat_id + + if from_date: + selection_query = and_(selection_query, MPMessage.created_at >= from_date) + + if to_date: + selection_query = and_(selection_query, MPMessage.created_at <= to_date) + + stmt = select(MPMessage).options(joinedload(MPMessage.sender), joinedload(MPMessage.chat).joinedload(MPChat.admin)).where( + selection_query).order_by(order_type(MPMessage.created_at)).limit(100) + + result = await session.execute(stmt) + data = result.scalars().all() + return data diff --git a/core/ws/__init__.py b/core/ws/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/ws/handlers.py b/core/ws/handlers.py new file mode 100644 index 0000000..b6b630d --- /dev/null +++ b/core/ws/handlers.py @@ -0,0 +1,80 @@ +""" + This package for handling the websocket connections and messages between the server and the client. + Write here the code for routing the messages between the server and the client. +""" +import json +import typing +from enum import Enum + +import pydantic +from fastapi import WebSocket +from pydantic import BaseModel + +from core.errors.errors import CustomExceptionHandler +from core.helpers.auth.helpers import get_current_user + + +class Message(BaseModel): + action: 'AvailableActions' + data: dict + + +class AvailableActions(str, Enum): + AUTH = 'auth' + + +class ConnectionManager: + active_connections: list[WebSocket] = [] + token_connections: dict[int, WebSocket] = {} + + def __init__(self): + # We can use this dictionary to route the messages to the correct method + self._action_routing: dict[AvailableActions, typing.Callable[[Message, WebSocket], typing.Coroutine]] = { + AvailableActions.AUTH: self.authorize + + } + + async def connect(self, ws: WebSocket): + await ws.accept() + self.active_connections.append(ws) + + def disconnect(self, ws: WebSocket): + self.token_connections = {k: v for k, v in self.token_connections.items() if v != ws} + self.active_connections.remove(ws) + + async def send_personal_message(self, message: str, ws: WebSocket): + await ws.send_text(message) + + async def broadcast(self, message: str): + for connection in self.active_connections: + await connection.send_text(message) + + async def authorize(self, message: Message, ws: WebSocket): + token = message.data.get('token') + if token is None: + await self.send_personal_message('{"error": "Token is required"}', ws) + return + + try: + user = await get_current_user(token) + self.token_connections[user.id] = ws + msg = {"message": "Successfully authorized", "user_id": user.id} + await self.send_personal_message(json.dumps(msg), ws) + except CustomExceptionHandler as exc: + await self.send_personal_message(exc.response.model_dump_json(), ws) + + async def send_personal_message_by_user_id(self, message: str, user_id: int): + ws = self.token_connections.get(user_id) + if ws is not None: + await ws.send_text(message) + + async def route_message(self, message: str, ws: WebSocket): + try: + message = Message.parse_raw(message) + await self._action_routing[message.action](message, ws) + except pydantic.ValidationError: + await self.send_personal_message('{"error": "Invalid message format"}', ws) + + +connection_manager: ConnectionManager = ConnectionManager() + diff --git a/main.py b/main.py index 0ca8bc8..06aab45 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,11 @@ from datetime import datetime, timezone import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, WebSocket +from starlette.websockets import WebSocketDisconnect + +from core.ws.handlers import connection_manager + from loguru import logger from starlette.middleware.cors import CORSMiddleware @@ -39,5 +43,15 @@ async def custom_exception_handler(_, exc): return exc.result() +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + await connection_manager.connect(websocket) + try: + while True: + data = await websocket.receive_text() + await connection_manager.route_message(data, websocket) + except WebSocketDisconnect: + connection_manager.disconnect(websocket) + if __name__ == '__main__': uvicorn.run('main:app', host=str(Config.host), port=Config.port, reload=True) diff --git a/requirements.in b/requirements.in index d2663f2..f8327a3 100644 --- a/requirements.in +++ b/requirements.in @@ -9,4 +9,5 @@ alembic pytest httpx pyjwt -trio \ No newline at end of file +pydantic==2.7.1 +Faker \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index af10c86..c168523 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,8 @@ # # pip-compile # +pydantic==2.7.1 +websockets==12.0 alembic==1.13.0 # via -r requirements.in annotated-types==0.6.0 @@ -15,10 +17,6 @@ anyio==3.7.1 # starlette asyncpg==0.29.0 # via -r requirements.in -attrs==23.2.0 - # via - # outcome - # trio certifi==2024.2.2 # via # httpcore @@ -29,6 +27,8 @@ dnspython==2.4.2 # via email-validator email-validator==2.1.0.post1 # via pydantic +faker==25.2.0 + # via -r requirements.in fastapi==0.104.1 # via -r requirements.in greenlet==3.0.2 @@ -46,7 +46,6 @@ idna==3.6 # anyio # email-validator # httpx - # trio iniconfig==2.0.0 # via pytest loguru==0.7.2 @@ -55,18 +54,16 @@ mako==1.3.0 # via alembic markupsafe==2.1.3 # via mako -outcome==1.3.0.post0 - # via trio packaging==24.0 # via pytest pluggy==1.5.0 # via pytest -pydantic[email]==2.5.2 +pydantic[email]==2.7.1 # via # -r requirements.in # fastapi # pydantic-settings -pydantic-core==2.14.5 +pydantic-core==2.18.2 # via pydantic pydantic-settings==2.1.0 # via -r requirements.in @@ -74,23 +71,22 @@ pyjwt==2.8.0 # via -r requirements.in pytest==8.2.0 # via -r requirements.in +python-dateutil==2.9.0.post0 + # via faker python-dotenv==1.0.0 # via pydantic-settings +six==1.16.0 + # via python-dateutil sniffio==1.3.0 # via # anyio # httpx - # trio -sortedcontainers==2.4.0 - # via trio sqlalchemy[asyncio]==2.0.23 # via # -r requirements.in # alembic starlette==0.27.0 # via fastapi -trio==0.25.1 - # via -r requirements.in typing-extensions==4.9.0 # via # alembic diff --git a/tests/test_message.py b/tests/test_message.py index 23d4fb1..19965ac 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,5 +1,9 @@ import threading +import time + +import pytest import uvicorn +from faker import Faker from main import app from httpx import Client @@ -42,3 +46,28 @@ def test__message_send_without_chat(): "chat_id": -1, "content": "test__message_send_without_chat"}, headers={'Authorization': f'Bearer {token}'}) assert resp.status_code == 404 + + +@pytest.mark.skip(reason="Stress test. Should be run separately") +def test_stress(): + # Complete authorization + resp = client.post('/api/v1/auth', json={'username': 1}) + fake = Faker() + assert resp.status_code == 200 + token = resp.json()['access_token'] + chat_id = client.get('/api/v1/chat', headers={'Authorization': f'Bearer {token}'}).json()[0]["id"] + + for i in range(10000): + text = fake.text() + start = time.time() * 1000 + resp = client.post('/api/v1/message', json={ + "chat_id": chat_id, + "content": text}, headers={'Authorization': f'Bearer {token}'}) + _ = client.get(f'/api/v1/message', headers={'Authorization': f'Bearer {token}'}, params={"chat_id": chat_id}) + end = time.time() * 1000 + print(f"Time: {end - start}") + + assert end - start < 100 + assert resp.status_code == 201 + assert resp.json()["content"] == text + assert resp.json()["id"] > 0