add message listing function and add websockets to get notification about messages in chat
parent
5ce8c9026c
commit
6660180be4
|
@ -1,8 +1,11 @@
|
|||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, Response
|
||||
|
||||
from core.helpers.auth.helpers import get_current_user
|
||||
from core.models.message.db import MPProfile
|
||||
from core.models.message.requests import SendMessageRequest
|
||||
from core.models.message.requests import SendMessageRequest, ListMessagesRequest
|
||||
from core.services import message_service
|
||||
|
||||
router = APIRouter(prefix="/message", tags=["message"])
|
||||
|
@ -14,13 +17,13 @@ async def send_message(response: Response, message: SendMessageRequest, user: MP
|
|||
return (await message_service.send_message(user, message)).model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
|
||||
async def list_messages():
|
||||
pass
|
||||
@router.get("")
|
||||
async def list_messages(
|
||||
chat_id: int,
|
||||
from_date: datetime | None = None,
|
||||
to_date: datetime | None = None,
|
||||
order_by: Literal['desc'] | Literal['asc'] = 'desc',
|
||||
|
||||
|
||||
async def delete_message():
|
||||
pass
|
||||
|
||||
|
||||
async def edit_message():
|
||||
pass
|
||||
user: MPProfile = Depends(get_current_user)):
|
||||
query = ListMessagesRequest(chat_id=chat_id, from_date=from_date, to_date=to_date, order_by=order_by)
|
||||
return await message_service.list_messages(user, query)
|
||||
|
|
|
@ -10,7 +10,10 @@ from core.errors.errors import not_authenticated_error
|
|||
async def get_current_user(token: str = Header(alias="Authorization")) -> MPProfile:
|
||||
"""Get the current user."""
|
||||
try:
|
||||
token = token.split("Bearer ")[1]
|
||||
token = token.split("Bearer ")
|
||||
if len(token) != 2:
|
||||
raise not_authenticated_error()
|
||||
token = token[1]
|
||||
token_data = jwt.decode(token, Config.secret, algorithms=["HS256"])
|
||||
user = await auth_storage.get_user_by_id(token_data["user_id"])
|
||||
if user:
|
||||
|
|
|
@ -27,7 +27,7 @@ class MPMessage(Base):
|
|||
__tablename__ = 'mp_message'
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
sender_id: Mapped[int] = mapped_column(ForeignKey('mp_profile.id'))
|
||||
chat_id: Mapped[int] = mapped_column(ForeignKey('mp_chat.id'))
|
||||
chat_id: Mapped[int] = mapped_column(ForeignKey('mp_chat.id', ondelete='CASCADE'))
|
||||
content: Mapped[str]
|
||||
chat: Mapped[MPChat] = relationship("MPChat", back_populates="messages")
|
||||
sender: Mapped[MPProfile] = relationship("MPProfile", back_populates="messages")
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
@ -8,3 +11,10 @@ class CreateChatRequest(BaseModel):
|
|||
class SendMessageRequest(BaseModel):
|
||||
chat_id: int
|
||||
content: str
|
||||
|
||||
|
||||
class ListMessagesRequest(BaseModel):
|
||||
chat_id: int
|
||||
from_date: datetime.datetime | None = None
|
||||
to_date: datetime.datetime | None = None
|
||||
order_by: Literal['desc'] | Literal['asc'] = 'desc'
|
||||
|
|
|
@ -23,7 +23,7 @@ class ChatResponse(BaseModel):
|
|||
modified_at: datetime.datetime
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
class MessageDetailResponse(BaseModel):
|
||||
model_config = ConfigDict(alias_generator=AliasGenerator(validation_alias=to_snake,
|
||||
serialization_alias=to_camel))
|
||||
id: int
|
||||
|
@ -32,3 +32,22 @@ class MessageResponse(BaseModel):
|
|||
chat: ChatResponse
|
||||
created_at: datetime.datetime
|
||||
modified_at: datetime.datetime
|
||||
|
||||
|
||||
class MessageShortResponse(BaseModel):
|
||||
model_config = ConfigDict(alias_generator=AliasGenerator(validation_alias=to_snake,
|
||||
serialization_alias=to_camel))
|
||||
id: int
|
||||
sender: ProfileResponse
|
||||
content: str
|
||||
created_at: datetime.datetime
|
||||
modified_at: datetime.datetime
|
||||
|
||||
|
||||
class ListMessageResponse(BaseModel):
|
||||
model_config = ConfigDict(alias_generator=AliasGenerator(validation_alias=to_snake,
|
||||
serialization_alias=to_camel))
|
||||
messages: list[MessageShortResponse]
|
||||
total: int
|
||||
from_date: datetime.datetime
|
||||
to_date: datetime.datetime
|
||||
|
|
|
@ -1,17 +1,18 @@
|
|||
from core.errors.errors import not_a_member_of_chat
|
||||
from core.models.message.db import MPProfile, MPMessage
|
||||
from core.models.message.requests import SendMessageRequest
|
||||
from core.models.message.responses import MessageResponse, ProfileResponse, ChatResponse
|
||||
from core.storage import message_storage, auth_storage, chat_storage
|
||||
from core.models.message.db import MPProfile, MPMessage, MPChat
|
||||
from core.models.message.requests import SendMessageRequest, ListMessagesRequest
|
||||
from core.models.message.responses import MessageDetailResponse, ProfileResponse, ChatResponse
|
||||
from core.storage import message_storage, chat_storage
|
||||
from core.ws.handlers import connection_manager
|
||||
|
||||
|
||||
class Service:
|
||||
|
||||
async def build_message_response(self, msg: MPMessage) -> MessageResponse:
|
||||
async def build_message_response(self, msg: MPMessage) -> MessageDetailResponse:
|
||||
"""
|
||||
Build a message response
|
||||
"""
|
||||
return MessageResponse(
|
||||
return MessageDetailResponse(
|
||||
id=msg.id,
|
||||
sender=ProfileResponse(id=msg.sender.id, external_id=msg.sender.external_id,
|
||||
created_at=msg.sender.created_at, modified_at=msg.sender.modified_at),
|
||||
|
@ -28,7 +29,14 @@ class Service:
|
|||
modified_at=msg.modified_at
|
||||
)
|
||||
|
||||
async def send_message(self, user: MPProfile, message: SendMessageRequest) -> MessageResponse:
|
||||
async def notify_users_in_chat(self, chat: MPChat, message: MPMessage):
|
||||
"""
|
||||
Notify all users in chat
|
||||
"""
|
||||
for user in chat.users:
|
||||
await connection_manager.send_personal_message_by_user_id((await self.build_message_response(message)).model_dump_json(), user.id)
|
||||
|
||||
async def send_message(self, user: MPProfile, message: SendMessageRequest) -> MessageDetailResponse:
|
||||
"""
|
||||
Send message to chat
|
||||
"""
|
||||
|
@ -43,4 +51,22 @@ class Service:
|
|||
# Add message to Database
|
||||
msg = await message_storage.insert_message(user.id, message.chat_id, message.content)
|
||||
|
||||
await self.notify_users_in_chat(chat, msg)
|
||||
return await self.build_message_response(msg)
|
||||
|
||||
async def list_messages(self, user: MPProfile, query: ListMessagesRequest):
|
||||
"""
|
||||
List messages in chat
|
||||
"""
|
||||
|
||||
# Check chat exists
|
||||
chat = await chat_storage.get_chat(query.chat_id)
|
||||
|
||||
# Check user is in chat
|
||||
if user.id not in [x.id for x in chat.users]:
|
||||
not_a_member_of_chat()
|
||||
|
||||
# Get messages from Database
|
||||
messages = await message_storage.list_messages(query.chat_id, query.from_date, query.to_date)
|
||||
|
||||
return [await self.build_message_response(msg) for msg in messages]
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
from sqlalchemy import insert, select
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import insert, select, and_, desc, asc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload, subqueryload
|
||||
|
||||
|
@ -20,3 +22,23 @@ class Storage(BaseStorage):
|
|||
result = await session.execute(stmt)
|
||||
data = result.scalar_one_or_none()
|
||||
return data
|
||||
|
||||
async def list_messages(self, chat_id: int, from_date: datetime.datetime, to_date: datetime.datetime,
|
||||
order: str = "desc") -> list[MPMessage]:
|
||||
async with self.get_session() as session:
|
||||
session: AsyncSession
|
||||
order_type = desc if order == "desc" else asc
|
||||
selection_query = MPMessage.chat_id == chat_id
|
||||
|
||||
if from_date:
|
||||
selection_query = and_(selection_query, MPMessage.created_at >= from_date)
|
||||
|
||||
if to_date:
|
||||
selection_query = and_(selection_query, MPMessage.created_at <= to_date)
|
||||
|
||||
stmt = select(MPMessage).options(joinedload(MPMessage.sender), joinedload(MPMessage.chat).joinedload(MPChat.admin)).where(
|
||||
selection_query).order_by(order_type(MPMessage.created_at)).limit(100)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
data = result.scalars().all()
|
||||
return data
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
"""
|
||||
This package for handling the websocket connections and messages between the server and the client.
|
||||
Write here the code for routing the messages between the server and the client.
|
||||
"""
|
||||
import json
|
||||
import typing
|
||||
from enum import Enum
|
||||
|
||||
import pydantic
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.errors.errors import CustomExceptionHandler
|
||||
from core.helpers.auth.helpers import get_current_user
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
action: 'AvailableActions'
|
||||
data: dict
|
||||
|
||||
|
||||
class AvailableActions(str, Enum):
|
||||
AUTH = 'auth'
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
active_connections: list[WebSocket] = []
|
||||
token_connections: dict[int, WebSocket] = {}
|
||||
|
||||
def __init__(self):
|
||||
# We can use this dictionary to route the messages to the correct method
|
||||
self._action_routing: dict[AvailableActions, typing.Callable[[Message, WebSocket], typing.Coroutine]] = {
|
||||
AvailableActions.AUTH: self.authorize
|
||||
|
||||
}
|
||||
|
||||
async def connect(self, ws: WebSocket):
|
||||
await ws.accept()
|
||||
self.active_connections.append(ws)
|
||||
|
||||
def disconnect(self, ws: WebSocket):
|
||||
self.token_connections = {k: v for k, v in self.token_connections.items() if v != ws}
|
||||
self.active_connections.remove(ws)
|
||||
|
||||
async def send_personal_message(self, message: str, ws: WebSocket):
|
||||
await ws.send_text(message)
|
||||
|
||||
async def broadcast(self, message: str):
|
||||
for connection in self.active_connections:
|
||||
await connection.send_text(message)
|
||||
|
||||
async def authorize(self, message: Message, ws: WebSocket):
|
||||
token = message.data.get('token')
|
||||
if token is None:
|
||||
await self.send_personal_message('{"error": "Token is required"}', ws)
|
||||
return
|
||||
|
||||
try:
|
||||
user = await get_current_user(token)
|
||||
self.token_connections[user.id] = ws
|
||||
msg = {"message": "Successfully authorized", "user_id": user.id}
|
||||
await self.send_personal_message(json.dumps(msg), ws)
|
||||
except CustomExceptionHandler as exc:
|
||||
await self.send_personal_message(exc.response.model_dump_json(), ws)
|
||||
|
||||
async def send_personal_message_by_user_id(self, message: str, user_id: int):
|
||||
ws = self.token_connections.get(user_id)
|
||||
if ws is not None:
|
||||
await ws.send_text(message)
|
||||
|
||||
async def route_message(self, message: str, ws: WebSocket):
|
||||
try:
|
||||
message = Message.parse_raw(message)
|
||||
await self._action_routing[message.action](message, ws)
|
||||
except pydantic.ValidationError:
|
||||
await self.send_personal_message('{"error": "Invalid message format"}', ws)
|
||||
|
||||
|
||||
connection_manager: ConnectionManager = ConnectionManager()
|
||||
|
16
main.py
16
main.py
|
@ -1,7 +1,11 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, WebSocket
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
from core.ws.handlers import connection_manager
|
||||
|
||||
from loguru import logger
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
|
@ -39,5 +43,15 @@ async def custom_exception_handler(_, exc):
|
|||
return exc.result()
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await connection_manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
await connection_manager.route_message(data, websocket)
|
||||
except WebSocketDisconnect:
|
||||
connection_manager.disconnect(websocket)
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run('main:app', host=str(Config.host), port=Config.port, reload=True)
|
||||
|
|
|
@ -9,4 +9,5 @@ alembic
|
|||
pytest
|
||||
httpx
|
||||
pyjwt
|
||||
trio
|
||||
pydantic==2.7.1
|
||||
Faker
|
|
@ -4,6 +4,8 @@
|
|||
#
|
||||
# pip-compile
|
||||
#
|
||||
pydantic==2.7.1
|
||||
websockets==12.0
|
||||
alembic==1.13.0
|
||||
# via -r requirements.in
|
||||
annotated-types==0.6.0
|
||||
|
@ -15,10 +17,6 @@ anyio==3.7.1
|
|||
# starlette
|
||||
asyncpg==0.29.0
|
||||
# via -r requirements.in
|
||||
attrs==23.2.0
|
||||
# via
|
||||
# outcome
|
||||
# trio
|
||||
certifi==2024.2.2
|
||||
# via
|
||||
# httpcore
|
||||
|
@ -29,6 +27,8 @@ dnspython==2.4.2
|
|||
# via email-validator
|
||||
email-validator==2.1.0.post1
|
||||
# via pydantic
|
||||
faker==25.2.0
|
||||
# via -r requirements.in
|
||||
fastapi==0.104.1
|
||||
# via -r requirements.in
|
||||
greenlet==3.0.2
|
||||
|
@ -46,7 +46,6 @@ idna==3.6
|
|||
# anyio
|
||||
# email-validator
|
||||
# httpx
|
||||
# trio
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
loguru==0.7.2
|
||||
|
@ -55,18 +54,16 @@ mako==1.3.0
|
|||
# via alembic
|
||||
markupsafe==2.1.3
|
||||
# via mako
|
||||
outcome==1.3.0.post0
|
||||
# via trio
|
||||
packaging==24.0
|
||||
# via pytest
|
||||
pluggy==1.5.0
|
||||
# via pytest
|
||||
pydantic[email]==2.5.2
|
||||
pydantic[email]==2.7.1
|
||||
# via
|
||||
# -r requirements.in
|
||||
# fastapi
|
||||
# pydantic-settings
|
||||
pydantic-core==2.14.5
|
||||
pydantic-core==2.18.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.1.0
|
||||
# via -r requirements.in
|
||||
|
@ -74,23 +71,22 @@ pyjwt==2.8.0
|
|||
# via -r requirements.in
|
||||
pytest==8.2.0
|
||||
# via -r requirements.in
|
||||
python-dateutil==2.9.0.post0
|
||||
# via faker
|
||||
python-dotenv==1.0.0
|
||||
# via pydantic-settings
|
||||
six==1.16.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# trio
|
||||
sortedcontainers==2.4.0
|
||||
# via trio
|
||||
sqlalchemy[asyncio]==2.0.23
|
||||
# via
|
||||
# -r requirements.in
|
||||
# alembic
|
||||
starlette==0.27.0
|
||||
# via fastapi
|
||||
trio==0.25.1
|
||||
# via -r requirements.in
|
||||
typing-extensions==4.9.0
|
||||
# via
|
||||
# alembic
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import uvicorn
|
||||
from faker import Faker
|
||||
from main import app
|
||||
from httpx import Client
|
||||
|
||||
|
@ -42,3 +46,28 @@ def test__message_send_without_chat():
|
|||
"chat_id": -1,
|
||||
"content": "test__message_send_without_chat"}, headers={'Authorization': f'Bearer {token}'})
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Stress test. Should be run separately")
|
||||
def test_stress():
|
||||
# Complete authorization
|
||||
resp = client.post('/api/v1/auth', json={'username': 1})
|
||||
fake = Faker()
|
||||
assert resp.status_code == 200
|
||||
token = resp.json()['access_token']
|
||||
chat_id = client.get('/api/v1/chat', headers={'Authorization': f'Bearer {token}'}).json()[0]["id"]
|
||||
|
||||
for i in range(10000):
|
||||
text = fake.text()
|
||||
start = time.time() * 1000
|
||||
resp = client.post('/api/v1/message', json={
|
||||
"chat_id": chat_id,
|
||||
"content": text}, headers={'Authorization': f'Bearer {token}'})
|
||||
_ = client.get(f'/api/v1/message', headers={'Authorization': f'Bearer {token}'}, params={"chat_id": chat_id})
|
||||
end = time.time() * 1000
|
||||
print(f"Time: {end - start}")
|
||||
|
||||
assert end - start < 100
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["content"] == text
|
||||
assert resp.json()["id"] > 0
|
||||
|
|
Loading…
Reference in New Issue