mp_message/core/ws/handlers.py

81 lines
2.6 KiB
Python
Raw Permalink Normal View History

"""
This module handles WebSocket connections and the messages exchanged between the server and the client.
Put the code that routes messages between the server and the client here.
"""
import json
import typing
from enum import Enum
import pydantic
from fastapi import WebSocket
from pydantic import BaseModel
from core.errors.errors import CustomExceptionHandler
from core.helpers.auth.helpers import get_current_user
class AvailableActions(str, Enum):
    """Actions a client may request via ``Message.action``."""

    AUTH = 'auth'


class Message(BaseModel):
    """Envelope for a client-to-server WebSocket frame.

    ``action`` selects the handler that processes the frame; ``data`` is the
    handler-specific payload (e.g. the ``auth`` handler reads ``data["token"]``).
    """

    # AvailableActions is defined above, so no string forward reference
    # (and no deferred pydantic resolution) is needed here.
    action: AvailableActions
    data: dict
class ConnectionManager:
    """Track open WebSocket connections and dispatch incoming messages.

    ``connect`` registers a socket in ``active_connections``. A socket is also
    mapped to a user id in ``token_connections`` once it completes the
    ``auth`` action successfully.
    """

    def __init__(self):
        # Per-instance state. These were previously class attributes, so every
        # ConnectionManager instance shared (and mutated) the same list/dict.
        self.active_connections: list[WebSocket] = []
        # user id -> connection that authorized as that user (last one wins).
        self.token_connections: dict[int, WebSocket] = {}
        # Maps each action to the coroutine method that handles it.
        self._action_routing: dict[
            AvailableActions, typing.Callable[[Message, WebSocket], typing.Coroutine]
        ] = {
            AvailableActions.AUTH: self.authorize,
        }

    async def connect(self, ws: WebSocket):
        """Accept the handshake on *ws* and start tracking it."""
        await ws.accept()
        self.active_connections.append(ws)

    def disconnect(self, ws: WebSocket):
        """Stop tracking *ws* and drop any user mapping that points to it.

        Safe to call more than once, or for a socket that was never
        registered (the old ``list.remove`` raised ``ValueError`` there).
        """
        # Match by identity: we want to drop this exact connection object.
        self.token_connections = {
            user_id: conn
            for user_id, conn in self.token_connections.items()
            if conn is not ws
        }
        # Slice assignment updates the existing list object in place.
        self.active_connections[:] = [
            conn for conn in self.active_connections if conn is not ws
        ]

    async def send_personal_message(self, message: str, ws: WebSocket):
        """Send *message* as a text frame to the single connection *ws*."""
        await ws.send_text(message)

    async def broadcast(self, message: str):
        """Send *message* as a text frame to every tracked connection.

        Iterates over a snapshot. Other coroutines may run ``connect`` or
        ``disconnect`` while this one is suspended in ``send_text``. Mutating
        the live list mid-iteration would skip connections.
        """
        for connection in list(self.active_connections):
            await connection.send_text(message)

    async def authorize(self, message: Message, ws: WebSocket):
        """Handle the ``auth`` action: bind *ws* to the user owning the token.

        Reads the token from ``message.data["token"]``. Replies on *ws* with a
        JSON success payload containing ``user_id``, or with an error payload
        if the token is missing or rejected by ``get_current_user``.
        """
        token = message.data.get('token')
        if token is None:
            await self.send_personal_message('{"error": "Token is required"}', ws)
            return
        try:
            user = await get_current_user(token)
        except CustomExceptionHandler as exc:
            # Forward the error payload carried by the project's exception.
            await self.send_personal_message(exc.response.model_dump_json(), ws)
            return
        self.token_connections[user.id] = ws
        msg = {"message": "Successfully authorized", "user_id": user.id}
        await self.send_personal_message(json.dumps(msg), ws)

    async def send_personal_message_by_user_id(self, message: str, user_id: int):
        """Send *message* to the connection authorized as *user_id*, if any.

        Does nothing when no connection is registered for that user.
        """
        ws = self.token_connections.get(user_id)
        if ws is not None:
            await ws.send_text(message)

    async def route_message(self, message: str, ws: WebSocket):
        """Parse a raw JSON text frame and dispatch it to its action handler.

        Replies with an error payload when the frame does not validate as a
        ``Message`` or names an action with no registered handler. Exceptions
        raised by the handler itself propagate to the caller unchanged.
        """
        # Only parsing belongs in the try. A ValidationError raised inside a
        # handler must not be misreported as a malformed client message.
        try:
            parsed = Message.model_validate_json(message)
        except pydantic.ValidationError:
            await self.send_personal_message('{"error": "Invalid message format"}', ws)
            return
        handler = self._action_routing.get(parsed.action)
        if handler is None:
            # A valid enum member without a registered handler used to raise KeyError.
            await self.send_personal_message('{"error": "Unsupported action"}', ws)
            return
        await handler(parsed, ws)
# Module-level manager instance; every importer of this module gets the same object.
connection_manager: ConnectionManager
connection_manager = ConnectionManager()