# Source: ser / app.py — "Update app.py" (commit f24dc63, verified)
import asyncio
import json
import time
from typing import Dict, List, Optional

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
# FastAPI application instance; the /ws/{user_id} WebSocket route is registered on it.
app = FastAPI()
# --- DATABASE / STATE ---
# The fixed set of user ids this server accepts (currently users 0 and 1).
VALID_USER_IDS = [0, 1]
# Last time each user was seen, in epoch milliseconds (0 = never seen).
# Whether a user is online right now is answered by ConnectionManager instead.
userLastOnlineTime: Dict[int, int] = dict.fromkeys(VALID_USER_IDS, 0)
# In-memory "database": every chat message, stored at the index equal to its id.
messages: List[Dict] = []
def timeMiliSec(sec_time):
    """Convert a time in seconds (e.g. ``time.time()``) to whole milliseconds.

    Any fractional millisecond is truncated toward zero by ``int``.
    """
    milliseconds = sec_time * 1000
    return int(milliseconds)
class ConnectionManager:
    """Manages active WebSocket connections, mapping user_id to WebSocket.

    At most one socket is kept per user id: connecting again with an id that is
    already registered replaces the stored socket (the old one is not closed
    here).
    """

    def __init__(self):
        # Maps {user_id: WebSocket} for easy personal messaging and status tracking
        self.active_connections: Dict[int, WebSocket] = {}

    async def connect(self, user_id: int, websocket: WebSocket):
        """Accept the WebSocket handshake and register the socket for *user_id*."""
        await websocket.accept()
        self.active_connections[user_id] = websocket

    def disconnect(self, user_id: int, websocket: Optional[WebSocket] = None):
        """Unregister the socket stored for *user_id*; unknown ids are ignored.

        When *websocket* is given, the entry is removed only if it is still that
        exact socket. Internal cleanup passes the socket that failed, so a newer
        connection registered for the same user in the meantime is left alone.
        """
        if user_id not in self.active_connections:
            return
        if websocket is not None and self.active_connections[user_id] is not websocket:
            return
        del self.active_connections[user_id]

    def isOnline(self, user_id: int) -> bool:
        """Return True if a socket is currently registered for *user_id*."""
        return user_id in self.active_connections

    async def send_personal_message(self, message: Dict, user_id: int) -> bool:
        """Send *message* as JSON to a single connected user.

        Returns:
            True if the message was handed to the user's socket; False if the
            user is not connected or the send failed. On failure the broken
            socket is unregistered.
        """
        websocket = self.active_connections.get(user_id)
        if websocket is None:
            return False
        try:
            await websocket.send_json(message)
        except (RuntimeError, OSError, WebSocketDisconnect) as e:
            # A dead peer may surface as RuntimeError (socket already closed),
            # OSError, or WebSocketDisconnect (newer Starlette). None of these may
            # escape: the caller is usually a *different* user's session handler,
            # which would otherwise treat the receiver's failure as its own
            # disconnect.
            print(f"Error sending message to user {user_id}: {e}")
            # Drop only the socket that failed; the user may have reconnected.
            self.disconnect(user_id, websocket)
            return False
        return True

    async def broadcast(self, message: Dict):
        """Send *message* as JSON to every connected user, dropping sockets that fail."""
        failed = []
        # Iterate over a snapshot: each await yields to other tasks, which may
        # connect/disconnect users and would otherwise mutate the dict while it is
        # being iterated ("dictionary changed size during iteration").
        for user_id, connection in list(self.active_connections.items()):
            try:
                # Use send_json for reliable serialization
                await connection.send_json(message)
            except RuntimeError:
                # If connection is already closed, mark for removal
                failed.append((user_id, connection))
            except Exception as e:
                print(f"Broadcast error to {user_id}: {e}")
                failed.append((user_id, connection))
        for user_id, connection in failed:
            self.disconnect(user_id, connection)
            print(f"User {user_id} disconnected during broadcast cleanup.")
# Module-level connection registry shared by the WebSocket endpoint and the
# offline-delivery helper.
manager = ConnectionManager()
def addMessage(sender_id: int, data: Dict) -> Dict:
    """
    Creates a new message dictionary and adds it to the global messages list.

    Args:
        sender_id: id of the user who sent the message.
        data: client payload. Reads "content" (default ""), "timestamp"
            (client time; defaults to the server time) and "receiver_id".

    Returns:
        The stored message dict. Its "id" equals its index in ``messages``.
        Its "status" maps sender_id -> "sent" and receiver_id -> "delivered".
        When sender and receiver are the same user, only "delivered" remains,
        because the second key overwrites the first.
    """
    # Read the clock once. A defaulted client time then equals the server time,
    # and no second (possibly different) reading is taken.
    now_ms = timeMiliSec(time.time())
    receiver_id = data.get("receiver_id")
    # Message Status Tracking:
    # Sender status is implicitly 'sent' by the server.
    # Receiver status starts as 'delivered': the server has logged the message
    # and is attempting delivery, but the receiver has not confirmed it yet.
    status_tracker = {
        sender_id: "sent",
        receiver_id: "delivered",
    }
    new_message = {
        # Ids are list indices; messages are only ever appended, never removed.
        "id": len(messages),
        "content": data.get("content", ""),
        "time_client": data.get("timestamp", now_ms),
        "time_server": now_ms,
        "sender_id": sender_id,
        "receiver_id": receiver_id,  # Target receiver ID
        "status": status_tracker,
    }
    messages.append(new_message)
    return new_message
# --- NEW FUNCTION FOR OFFLINE DELIVERY ---
async def send_undelivered_messages(user_id: int):
    """
    Checks the message history and attempts to send any messages
    that were logged while the user was offline.

    A message qualifies when its receiver is *user_id* and the receiver's status
    is still 'delivered' (stored by the server but never confirmed). Each
    successfully pushed message is marked 'received', and its original sender
    gets a "read_receipt" with that status. The scan stops at the first failed
    send; the remaining messages stay 'delivered' for the next reconnect.
    """
    print(f"Checking for undelivered messages for user {user_id}...")
    undelivered_count = 0
    # Iterate over a snapshot taken now. Every await below yields to other
    # sessions, which may append messages that their senders already tried to
    # route live to this (now connected) user; re-sending those here would
    # duplicate them.
    for msg in list(messages):
        if msg.get("receiver_id") != user_id:
            continue
        if msg["status"].get(user_id) != "delivered":
            continue
        delivery_payload = {
            "type": "new_message",
            "message": msg
        }
        sent_successfully = await manager.send_personal_message(delivery_payload, user_id)
        if not sent_successfully:
            # The receiver's socket is gone; further attempts cannot succeed.
            break
        # Update the status in the global list to prevent re-delivery
        msg["status"][user_id] = "received"
        undelivered_count += 1
        # Notify the original sender that the message has now been received
        receipt_payload = {
            "type": "read_receipt",
            "message_id": msg["id"],
            "status": "received",  # The new status is 'received' upon successful delivery
            "updated_by_user": user_id
        }
        await manager.send_personal_message(receipt_payload, msg["sender_id"])
    if undelivered_count > 0:
        print(f"Delivered {undelivered_count} historical messages to user {user_id}.")
@app.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: int):
    """Serve one user's chat session over a WebSocket.

    Lifecycle:
      1. Reject ids not in VALID_USER_IDS (close code 1008).
      2. Register the socket, stamp the user's last-online time, broadcast that
         the user is online and tell them every other user's status.
      3. Flush messages still waiting for them (send_undelivered_messages).
      4. Loop on incoming JSON objects, routed by their "type" field
         ("chat_message" by default, or "status_update").
      5. On any exit (client disconnect, failed send, unexpected error),
         unregister the socket and broadcast the user as offline. This is
         skipped if a newer connection for the same user replaced this one.
    """
    # 1. Validation and Connection
    if user_id not in VALID_USER_IDS:
        await websocket.close(code=1008, reason="Invalid User ID")
        return
    await manager.connect(user_id, websocket)
    userLastOnlineTime[user_id] = timeMiliSec(time.time())  # Now officially "online"
    print(f"User {user_id} connected.")
    # Everything after registration runs under try/finally. Every exit path then
    # unregisters the socket and announces the user offline, including failures
    # during the initial status exchange.
    try:
        # 2. Initial Status Broadcast
        await _announce_online(user_id)
        # 3. Deliver messages that accumulated while the user was offline.
        await send_undelivered_messages(user_id)
        # 4. Main receive loop; it only ends by raising.
        await _receive_loop(websocket, user_id)
    except WebSocketDisconnect:
        pass  # Normal client disconnect; handled in finally.
    except Exception as e:
        print(f"Unexpected error in WS loop for user {user_id}: {e}")
    finally:
        # 5. Disconnection handling and offline broadcast
        await _end_session(websocket, user_id)


async def _announce_online(user_id: int):
    """Broadcast that *user_id* is online and send them every other user's status."""
    await manager.broadcast({"type": "status_update", "user_id": user_id, "is_online": True})
    for other_user_id in VALID_USER_IDS:
        if other_user_id == user_id:
            continue
        if manager.isOnline(other_user_id):
            status = {"type": "status_update", "user_id": other_user_id, "is_online": True}
        else:
            status = {"type": "status_update", "user_id": other_user_id, "is_online": False, "last_seen": userLastOnlineTime.get(other_user_id, 0.0)}
        await manager.send_personal_message(status, user_id)


async def _receive_loop(websocket: WebSocket, user_id: int):
    """Read and route client frames until receiving raises (e.g. WebSocketDisconnect)."""
    while True:
        try:
            data = await websocket.receive_json()
        except json.JSONDecodeError:
            # The bad frame has already been consumed and the socket is still
            # usable, so report the error and keep the session alive.
            print(f"Received invalid JSON from user {user_id}")
            await manager.send_personal_message({"type": "error", "message": "Invalid JSON format."}, user_id)
            continue
        # Valid JSON can still be a list/number/string, which has no .get().
        if not isinstance(data, dict):
            await manager.send_personal_message({"type": "error", "message": "Payload must be a JSON object."}, user_id)
            continue
        # --- Message Routing based on client message type ---
        message_type = data.get("type", "chat_message")
        if message_type == "chat_message":
            await _handle_chat_message(user_id, data)
        elif message_type == "status_update":
            await _handle_status_update(user_id, data)
        else:
            await manager.send_personal_message({"type": "error", "message": "Unknown message type."}, user_id)


async def _handle_chat_message(user_id: int, data: Dict):
    """Validate, store and route a new chat message sent by *user_id*."""
    if not all(k in data for k in ["content", "timestamp", "receiver_id"]):
        await manager.send_personal_message({"type": "error", "message": "Missing required fields (content, timestamp, receiver_id)."}, user_id)
        return
    # An unknown receiver would be stored forever and never delivered. An
    # unhashable one (list/dict) would crash addMessage's status dict.
    if data["receiver_id"] not in VALID_USER_IDS:
        await manager.send_personal_message({"type": "error", "message": "Invalid receiver_id."}, user_id)
        return
    new_message = addMessage(user_id, data)
    delivery_payload = {"type": "new_message", "message": new_message}
    sent_to_receiver = await manager.send_personal_message(delivery_payload, new_message["receiver_id"])
    # 'delivered': handed to the receiver's socket. 'pending': receiver offline;
    # the stored status stays 'delivered' so send_undelivered_messages re-sends
    # it when the receiver reconnects.
    receipt_status = "delivered" if sent_to_receiver else "pending"
    await manager.send_personal_message({"type": "delivery_receipt", "message_id": new_message["id"], "status": receipt_status}, user_id)


async def _handle_status_update(user_id: int, data: Dict):
    """Apply a 'received'/'read' receipt from *user_id* and notify the message's sender."""
    if not all(k in data for k in ["message_id", "new_status"]):
        await manager.send_personal_message({"type": "error", "message": "Missing required fields for status update."}, user_id)
        return
    message_id = data["message_id"]
    new_status = data["new_status"]
    # A non-int id would make the range comparison raise TypeError.
    if not isinstance(message_id, int) or not 0 <= message_id < len(messages):
        await manager.send_personal_message({"type": "error", "message": f"Message ID {message_id} not found."}, user_id)
        return
    message_to_update = messages[message_id]
    # Only the receiver may acknowledge a message.
    if message_to_update["receiver_id"] != user_id:
        await manager.send_personal_message({"type": "error", "message": "Only the receiver can update this message's status."}, user_id)
        return
    if new_status not in ("received", "read"):
        await manager.send_personal_message({"type": "error", "message": "Invalid status; expected 'received' or 'read'."}, user_id)
        return
    message_to_update["status"][user_id] = new_status
    receipt_payload = {
        "type": "read_receipt",
        "message_id": message_id,
        "status": new_status
    }
    await manager.send_personal_message(receipt_payload, message_to_update["sender_id"])


async def _end_session(websocket: WebSocket, user_id: int):
    """Unregister *websocket* and broadcast that *user_id* went offline.

    Skipped when a *different* socket is registered for the user, meaning a
    newer session replaced this one and that user is still online. If no socket
    is registered (e.g. the manager already dropped it after a failed send), the
    user is still announced offline.
    """
    current = manager.active_connections.get(user_id)
    if current is not None and current is not websocket:
        print(f"Stale session for user {user_id} closed; a newer connection is active.")
        return
    manager.disconnect(user_id)
    userLastOnlineTime[user_id] = timeMiliSec(time.time())
    print(f"User {user_id} disconnected at {userLastOnlineTime[user_id]}.")
    offline_status = {
        "type": "status_update",
        "user_id": user_id,
        "is_online": False,
        "last_seen": userLastOnlineTime[user_id]
    }
    await manager.broadcast(offline_status)