"""FastAPI service exposing a sentence mapper and an RWKV answerer.

Routes:
- ``GET  /``            serves the static ``HTML`` page.
- ``WS   /accelerate``  hands the socket to ``accelerator.connect`` and keeps it
                        open while the accelerator reports itself connected.
- ``POST /map``         runs ``Mapper`` on ``query`` and ``items`` and returns
                        the JSON-encoded result.
- ``WS   /answer``      receives one text prompt and replies once, using the
                        accelerator when connected, otherwise the local model.
"""
import logging
from asyncio import sleep
from typing import Optional

from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.websockets import WebSocket, WebSocketDisconnect
from websockets import ConnectionClosed

from accelerator import Accelerator
from answerer import Answerer
from mapper import Mapper

logger = logging.getLogger(__name__)

# The mapper is optional: if its model cannot be loaded the service still
# starts, and ``/map`` answers 503 instead of failing with a NameError.
mapper: Optional[Mapper] = None
try:
    mapper = Mapper("sentence-transformers/multi-qa-distilbert-cos-v1")
except Exception:  # top-level boundary: log with traceback, keep serving
    logger.exception("ERROR! cannot load Mapper model!")

answerer = Answerer(
    model="RWKV-5-World-3B-v2-20231118-ctx16k",
    vocab="rwkv_vocab_v20230424",
    strategy="cpu bf16",
    ctx_limit=16 * 1024,
)
accelerator = Accelerator()

app = FastAPI()

HTML = """
"""


@app.get("/")
def index():
    """Serve the static front-end page held in ``HTML``."""
    return HTMLResponse(HTML)


@app.websocket("/accelerate")
async def accelerate(ws: WebSocket):
    """Attach ``ws`` to the accelerator and hold the handler open.

    The socket itself is handed to ``accelerator.connect``; this coroutine
    only stays alive (polling every 10 seconds) for as long as
    ``accelerator.connected()`` is true, so the connection is not torn down
    when the handler returns.
    """
    await accelerator.connect(ws)
    while accelerator.connected():
        await sleep(10)


# Function renamed so it no longer shadows the builtin ``map``; ``name="map"``
# keeps the route name and generated OpenAPI operation id unchanged.
@app.post("/map", name="map")
def map_items(query: Optional[str], items: Optional[list[str]]):
    """Score ``items`` against ``query`` with the loaded ``Mapper``.

    Returns the mapper's result JSON-encoded (presumably one similarity score
    per item — confirm against ``Mapper``), or a 503 error when the mapper
    model failed to load at start-up.
    """
    if mapper is None:
        return JSONResponse(
            {"error": "Mapper model is not loaded"}, status_code=503
        )
    scores = mapper(query, items)
    return JSONResponse(jsonable_encoder(scores))


async def handle_answerer_local(ws: WebSocket, input: str):
    """Generate an answer with the local model and send it over ``ws``.

    ``answerer`` is iterated to exhaustion and only the last yielded value is
    sent (presumably the full accumulated answer — confirm against
    ``Answerer``). If it yields nothing, an empty string is sent so the client
    still receives exactly one reply instead of the handler crashing.

    ``input`` shadows the builtin but is kept for interface compatibility.
    """
    last = ""
    # 128 is presumably the generation length limit — confirm in Answerer.
    async for chunk in answerer(input, 128):
        last = chunk
    await ws.send_text(last)


async def handle_answerer_accelerated(ws: WebSocket, input: str):
    """Answer via the remote accelerator, falling back to the local model.

    Any falsy result from ``accelerator.accelerate`` (e.g. ``None`` or an
    empty string) triggers the local fallback.
    """
    output = await accelerator.accelerate(input)
    if output:
        await ws.send_text(output)
    else:
        await handle_answerer_local(ws, input)


@app.websocket("/answer")
async def answer(ws: WebSocket):
    """Receive one prompt over ``ws``, send one answer, then close.

    If the client disconnects at any point, the handler returns quietly
    without attempting to close the already-closed socket.
    """
    await ws.accept()
    try:
        prompt = await ws.receive_text()
        if accelerator.connected():
            await handle_answerer_accelerated(ws, prompt)
        else:
            await handle_answerer_local(ws, prompt)
    except (ConnectionClosed, WebSocketDisconnect):
        return
    await ws.close()