File size: 3,950 Bytes
4096277
 
 
750073f
 
4096277
 
 
 
 
 
 
 
750073f
 
 
c30d351
4096277
 
c30d351
 
750073f
4096277
 
750073f
4096277
 
 
 
 
 
 
 
 
750073f
 
 
 
 
 
 
 
 
 
 
c30d351
 
4096277
c30d351
 
 
4096277
 
 
 
 
 
 
 
 
c30d351
4096277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
750073f
 
4096277
 
 
 
 
 
 
 
 
 
 
 
750073f
 
4096277
 
 
 
 
 
 
 
 
750073f
4096277
 
 
 
750073f
c30d351
 
 
750073f
c30d351
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import asyncio
import logging
import os
import secrets

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Response, WebSocket, WebSocketDisconnect, status, HTTPException
from fastapi.concurrency import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from models.connection_manager import ConnectionManager
from models.request_payload import RequestPayload
from utils.package_manager import PackageManager

# Load environment variables from .env file
load_dotenv()

# Any ENV value other than 'PROD' (including an unset ENV) counts as development.
IS_DEV = os.getenv('ENV', 'DEV') != 'PROD'
# Shared secret that WebSocket clients must present in the `x-token` header.
WEBSOCKET_SECURE_TOKEN = os.getenv("SECURE_TOKEN")
# Credentials compared by is_valid().
X_REQUEST_USER = os.getenv('X_REQUEST_USER')
X_API_KEY = os.getenv('X_API_KEY')

# Comma-separated allow-list of channel ids accepted on the WebSocket endpoint.
# Entries are whitespace-trimmed and empty entries dropped, so "a, b," parses
# as ['a', 'b'] and an empty (but set) variable yields [] instead of [''] —
# the latter would otherwise whitelist a blank `x-channel-id` header.
WHITELIST_CHANNEL_IDS = [
    channel_id.strip()
    for channel_id in (os.getenv('WHITELIST_CHANNEL_IDS') or '').split(',')
    if channel_id.strip()
]

# The ASGI application plus the shared singletons the handlers below rely on.
app = FastAPI()
manager = ConnectionManager()  # WebSocket connections, keyed by channel id
package = PackageManager()  # gzip-packs request payloads before they are sent

# Root logging: WARNING and above, time-only timestamps.
logging.basicConfig(
    format='%(asctime)s %(name)s %(levelname)-8s  %(message)s',
    datefmt='(%H:%M:%S)',
    level=logging.WARNING,
)

# CORS middleware. Allowed origins come from the comma-separated
# CORS_ALLOW_ORIGINS environment variable and default to "*" (any origin).
# Set it to the trusted frontend domain(s), e.g.
# "https://your-frontend-domain.com", to actually restrict access.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in os.getenv('CORS_ALLOW_ORIGINS', '*').split(',')
        if origin.strip()
    ],
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
def root():
    """Liveness probe for the service root; always replies 200 with body 'ok'."""
    # Starlette's Response takes the body as `content`; the previous `data=`
    # keyword raised TypeError, turning every request into a 500.
    return Response(content='ok', status_code=status.HTTP_200_OK, media_type='text/plain')
        
@app.get("/health")
def healthcheck():
    """Health-check endpoint; always replies 200 with body 'ok'."""
    # Body must be passed as `content` (Response has no `data` keyword).
    return Response(content='ok', status_code=status.HTTP_200_OK, media_type='text/plain')

@app.post("/hi_mlhub")
async def hi_mlhub(payload: RequestPayload):
    """Forward *payload* to the available MLaaS connection and relay its reply.

    Responses:
        200 with the reply data when it arrives within 10 seconds.
        502 if no MLaaS connection is available, or sending/receiving fails.
        504 if no reply arrives within the 10-second timeout.
    """
    # Snapshot the target once: `manager.available` may change while this
    # coroutine is suspended at one of the awaits below.
    target = manager.available
    if target is None:
        return JSONResponse(status_code=status.HTTP_502_BAD_GATEWAY, content={ "error": "MLaaS is not available." })

    request_id, compressed_data = package.gzip(payload)

    try:
        # Send the packed payload to the single available connection.
        await manager.send_bytes(target, compressed_data)
        # Wait for the reply tagged with this request_id.
        data = await asyncio.wait_for(manager.listen(target, request_id), timeout=10.0)
    except asyncio.TimeoutError:
        return JSONResponse(status_code=status.HTTP_504_GATEWAY_TIMEOUT, content={ "error": "Timeout" })
    except Exception:
        # Previously every failure was misreported as a timeout; log the real
        # cause and report an upstream failure instead.
        logging.getLogger(__name__).exception("MLaaS request %s failed", request_id)
        return JSONResponse(status_code=status.HTTP_502_BAD_GATEWAY, content={ "error": "MLaaS request failed." })

    return JSONResponse(status_code=status.HTTP_200_OK, content=data)

# Shared-secret check used by the WebSocket handshake.
def is_valid_token(token: str):
    """Return True only if *token* equals the configured SECURE_TOKEN.

    Fails closed: a missing/empty token, or an unset/empty SECURE_TOKEN
    environment variable, is always rejected. Without this, an unset
    variable and a missing header compared as ``None == None`` and let
    unauthenticated clients through.
    """
    if not token or not WEBSOCKET_SECURE_TOKEN:
        return False
    # Constant-time comparison avoids leaking the secret through timing.
    # Compare bytes so non-ASCII header values cannot raise TypeError.
    return secrets.compare_digest(token.encode('utf-8'), WEBSOCKET_SECURE_TOKEN.encode('utf-8'))

def is_valid_apikey(channel_id: str):
    """Return True if *channel_id* is in the WHITELIST_CHANNEL_IDS allow-list.

    Despite the name, this checks a channel id (the WebSocket endpoint passes
    the ``x-channel-id`` header), not an API key. None and empty ids are
    always rejected.
    """
    # Reject blanks explicitly: a set-but-empty WHITELIST_CHANNEL_IDS can
    # split to [''], which would otherwise admit an empty header value.
    if not channel_id:
        return False
    return channel_id in WHITELIST_CHANNEL_IDS

# WebSocket endpoint through which MLaaS clients connect and send replies.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Authenticate a connecting client, register it, and relay its messages.

    The handshake must carry a valid ``x-token`` header (is_valid_token) and a
    whitelisted ``x-channel-id`` header (is_valid_apikey). Otherwise the socket
    is closed with policy-violation code 1008 before it is ever accepted.
    Each text message received is passed to ``manager.notify``. On disconnect
    the channel is removed from the manager and a notice is broadcast.
    """
    headers = websocket.headers
    token = headers.get("x-token")
    channel_id = headers.get("x-channel-id")

    # Returning an HTTPException object from a WebSocket route does not reject
    # the handshake; the socket has to be closed explicitly.
    if not is_valid_token(token) or not is_valid_apikey(channel_id):
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await manager.connect(channel_id, websocket)

    try:
        while True:
            # Common receiver
            data = await manager.receive_text(channel_id)
            print(f"Message from MLaaS: {data}")

            # Notify the manager that a message was received
            await manager.notify(channel_id, data)
    except WebSocketDisconnect:
        manager.disconnect(channel_id)
        await manager.broadcast(f"A client has disconnected with ID: {channel_id}")

def is_valid(u, p):
    """Return True only if (*u*, *p*) match X_REQUEST_USER / X_API_KEY.

    Fails closed: missing/empty credentials, or unset/empty environment
    variables, are always rejected. Previously ``None == None`` made an
    unconfigured deployment accept missing credentials. Both arguments are
    expected to be ``str`` (or None).
    """
    if not u or not p or not X_REQUEST_USER or not X_API_KEY:
        return False
    # Compare both fields in constant time and without short-circuiting, so
    # timing does not reveal which one was wrong.
    user_ok = secrets.compare_digest(u.encode('utf-8'), X_REQUEST_USER.encode('utf-8'))
    key_ok = secrets.compare_digest(p.encode('utf-8'), X_API_KEY.encode('utf-8'))
    return user_ok and key_ok

if __name__ == "__main__":
    # Direct-run entry point. Auto-reload is a development convenience, so it
    # follows IS_DEV (enabled unless ENV=PROD) instead of being always on.
    uvicorn.run('app:app', host='0.0.0.0', port=7860, reload=IS_DEV)