81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
|
"""
|
||
|
This package handles WebSocket connections and the messages exchanged between the server and the client.
|
||
|
Place the code that routes messages between the server and the client here.
|
||
|
"""
|
||
|
import json
|
||
|
import typing
|
||
|
from enum import Enum
|
||
|
|
||
|
import pydantic
|
||
|
from fastapi import WebSocket
|
||
|
from pydantic import BaseModel
|
||
|
|
||
|
from core.errors.errors import CustomExceptionHandler
|
||
|
from core.helpers.auth.helpers import get_current_user
|
||
|
|
||
|
|
||
|
class AvailableActions(str, Enum):
    """Actions a client may request in the ``action`` field of a :class:`Message`.

    Being a ``str`` enum, members compare equal to their raw string values,
    so the JSON value ``"auth"`` validates to ``AvailableActions.AUTH``.
    """

    AUTH = 'auth'  # routed to ConnectionManager.authorize


class Message(BaseModel):
    """Envelope for a JSON message received from a websocket client.

    Attributes:
        action: Which handler the message is dispatched to.
        data: Action-specific payload; e.g. ``authorize`` reads its ``'token'`` key.
    """

    # The enum is defined above, so this annotation is a concrete type at class
    # creation time instead of a string forward reference that pydantic must
    # resolve lazily (and that pydantic v1 would reject without
    # ``update_forward_refs()``).
    action: AvailableActions
    data: dict
|
||
|
|
||
|
|
||
|
class ConnectionManager:
    """Track open WebSocket connections and route incoming client messages.

    Each instance keeps two registries:

    * ``active_connections`` -- every socket accepted via :meth:`connect`;
      used by :meth:`broadcast`.
    * ``token_connections`` -- maps an authorized user's id to the socket that
      authorized it; used by :meth:`send_personal_message_by_user_id`.
    """

    active_connections: list[WebSocket]
    token_connections: dict[int, WebSocket]

    def __init__(self):
        # Per-instance containers. Class-level ``[]`` / ``{}`` defaults would be
        # shared (and mutated) by every ConnectionManager ever created.
        self.active_connections = []
        self.token_connections = {}
        # Dispatch table: each action maps to the coroutine that handles it.
        self._action_routing: dict[
            AvailableActions,
            typing.Callable[[Message, WebSocket], typing.Awaitable[None]],
        ] = {
            AvailableActions.AUTH: self.authorize,
        }

    async def connect(self, ws: WebSocket):
        """Accept the websocket handshake for *ws* and start tracking it."""
        await ws.accept()
        self.active_connections.append(ws)

    def disconnect(self, ws: WebSocket):
        """Stop tracking *ws* and drop any user-id bindings that point at it.

        Calling this for a socket that is not (or no longer) tracked is a no-op.
        """
        self._forget_socket_bindings(ws)
        # Filter by identity and update in place: ``list.remove`` raised
        # ValueError on a double disconnect or for a never-connected socket.
        self.active_connections[:] = [
            conn for conn in self.active_connections if conn is not ws
        ]

    def _forget_socket_bindings(self, ws: WebSocket) -> None:
        """Remove every ``user_id -> socket`` entry whose socket is *ws*."""
        # Identity comparison: the socket object itself is what matters, not
        # whatever ``==`` the WebSocket class may define.
        self.token_connections = {
            user_id: conn
            for user_id, conn in self.token_connections.items()
            if conn is not ws
        }

    async def send_personal_message(self, message: str, ws: WebSocket):
        """Send the text frame *message* to the single socket *ws*."""
        await ws.send_text(message)

    async def broadcast(self, message: str):
        """Send the text frame *message* to every tracked connection."""
        # Iterate over a snapshot. Each ``await`` yields to the event loop, and
        # a concurrent ``disconnect`` mutating the live list would otherwise
        # make this loop skip sockets.
        for connection in list(self.active_connections):
            await connection.send_text(message)

    async def authorize(self, message: Message, ws: WebSocket):
        """Authenticate *ws* with the ``token`` found in ``message.data``.

        On success *ws* is registered under the user's id and a JSON
        confirmation containing ``user_id`` is sent back. If the token is
        missing, or ``get_current_user`` raises ``CustomExceptionHandler``, an
        error payload is sent to *ws* instead.
        """
        token = message.data.get('token')
        if token is None:
            await self.send_personal_message('{"error": "Token is required"}', ws)
            return

        try:
            user = await get_current_user(token)
        except CustomExceptionHandler as exc:
            await self.send_personal_message(exc.response.model_dump_json(), ws)
            return

        # If this socket previously authorized as another user, unbind it so it
        # stops receiving that user's personal messages.
        self._forget_socket_bindings(ws)
        self.token_connections[user.id] = ws
        msg = {"message": "Successfully authorized", "user_id": user.id}
        await self.send_personal_message(json.dumps(msg), ws)

    async def send_personal_message_by_user_id(self, message: str, user_id: int):
        """Send *message* to the socket bound to *user_id*, if there is one."""
        ws = self.token_connections.get(user_id)
        if ws is not None:
            await ws.send_text(message)

    async def route_message(self, message: str, ws: WebSocket):
        """Parse raw JSON text received on *ws* and dispatch it by ``action``.

        Payloads that fail validation (bad JSON, missing fields, unknown
        action value) get an error reply instead of raising. Exceptions raised
        by the handler itself are no longer masked as format errors.
        """
        # Keep the try body to parsing only: a ValidationError raised inside a
        # handler is not a client format error.
        try:
            parsed = Message.model_validate_json(message)
        except pydantic.ValidationError:
            await self.send_personal_message('{"error": "Invalid message format"}', ws)
            return

        handler = self._action_routing.get(parsed.action)
        if handler is None:
            # A valid action with no registered handler (previously a KeyError).
            await self.send_personal_message('{"error": "Unsupported action"}', ws)
            return
        await handler(parsed, ws)
||
|
# Module-level ConnectionManager instance; presumably imported by the
# websocket route(s) so all connections share one registry -- confirm callers.
connection_manager: ConnectionManager = ConnectionManager()
|
||
|
|