answerer-api / main.py
DaniilAlpha's picture
Update main.py
9fe8e1f
from asyncio import sleep
from typing import Optional
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.websockets import WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse
from websockets import ConnectionClosed
from accelerator import Accelerator
from answerer import Answerer
from mapper import Mapper
# Sentence-embedding model used by POST /map. Loading is best-effort: if it
# fails, the server still starts (the /answer endpoints never use `mapper`).
# `mapper` is then bound to None so /map can report the problem instead of
# crashing with a NameError on an unbound global.
try:
    mapper = Mapper("sentence-transformers/multi-qa-distilbert-cos-v1")
except Exception as e:
    mapper = None
    print(f"ERROR! cannot load Mapper model!\n{e}")

# Local RWKV model; /answer falls back to it when no accelerator client is
# connected or when the accelerator returns an empty reply.
answerer = Answerer(
    model="RWKV-5-World-3B-v2-20231118-ctx16k",
    vocab="rwkv_vocab_v20230424",
    strategy="cpu bf16",
    ctx_limit=16 * 1024,
)

# Tracks the remote client that connects on /accelerate; /answer routes
# prompts to it while `accelerator.connected()` is true.
accelerator = Accelerator()

app = FastAPI()
# Minimal browser client for the /answer WebSocket, served as-is by index().
HTML = """
<!DOCTYPE HTML>
<html>
<body>
<form action="" onsubmit="ask(event)">
<textarea id="prompt"></textarea>
<br>
<input type="submit" value="SEND" />
</form>
<p id="output" style="white-space: pre-wrap"></p>
<script>
const prompt = document.getElementById("prompt");
const output = document.getElementById("output");
// Connect back to the host that served this page (ws:// or wss:// to match).
const wsUrl = (location.protocol === "https:" ? "wss://" : "ws://") + location.host + "/answer";
// The server answers one prompt per connection and then closes it,
// so a fresh WebSocket is opened for every question.
function ask(event) {
    // Cancel the native form submit first so the page never reloads,
    // even when the connection fails.
    event.preventDefault();
    const ws = new WebSocket(wsUrl);
    ws.onopen = () => ws.send(prompt.value);
    ws.onmessage = (e) => answer(e.data);
    ws.onerror = () => answer("websocket is not connected!");
}
function answer(value) {
    // Model output is plain text; textContent never interprets it as HTML.
    output.textContent = value;
}
</script>
</body>
</html>
"""
@app.get("/")
def index():
    """Serve the single-page browser client stored in the `HTML` constant."""
    page = HTMLResponse(content=HTML)
    return page
@app.websocket("/accelerate")
async def accelerate(ws: WebSocket):
    """Hand `ws` to the accelerator and keep this handler alive while it is used.

    The connection itself is managed by `accelerator.connect`. This coroutine
    only stays running while `accelerator.connected()` is true, because
    returning from the endpoint ends the WebSocket connection.
    """
    # Named `accelerate`, not `answer`: the /answer handler in this module
    # defines `answer` as well, and the second definition shadowed this one.
    await accelerator.connect(ws)
    # Poll instead of awaiting a disconnect event; a disconnect is therefore
    # noticed up to 10 seconds late.
    while accelerator.connected():
        await sleep(10)
@app.post("/map")
def map(query: Optional[str], items: Optional[list[str]]):
    """Score `items` against `query` with the sentence-embedding Mapper.

    Per FastAPI's parameter rules, `query` is read from the query string and
    `items` (a list without `Query()`) from the JSON request body.

    Returns:
        The Mapper's scores, passed through `jsonable_encoder`. Returns HTTP 503
        with a JSON error when the Mapper model failed to load at startup.
    """
    # Startup loading is best-effort, so the global may be None or, depending
    # on how startup ran, not bound at all.
    try:
        model = mapper
    except NameError:
        model = None
    if model is None:
        return JSONResponse(
            {"error": "mapper model is not loaded"}, status_code=503
        )
    scores = model(query, items)
    return JSONResponse(jsonable_encoder(scores))
async def handle_answerer_local(ws: WebSocket, input: str):
    """Answer `input` with the local model and send the result over `ws`.

    `answerer(input, 128)` is consumed as an async iterator and only its last
    item is sent. Each item is presumably the cumulative text so far, and 128
    presumably a length limit (TODO confirm both against Answerer). An empty
    string is sent if the iterator yields nothing.
    """
    # Pre-bind the result: previously an empty generator left the loop
    # variable unbound and send_text raised UnboundLocalError.
    last = ""
    async for chunk in answerer(input, 128):
        last = chunk
    await ws.send_text(last)
async def handle_answerer_accelerated(ws: WebSocket, input: str):
    """Answer `input` through the accelerator, falling back to the local model.

    Any falsy reply from `accelerator.accelerate` (e.g. None or an empty
    string) is treated as "no answer" and handled by the local answerer.
    """
    reply = await accelerator.accelerate(input)
    if not reply:
        await handle_answerer_local(ws, input)
        return
    await ws.send_text(reply)
@app.websocket("/answer")
async def answer(ws: WebSocket):
    """Answer exactly one prompt per WebSocket connection.

    Accepts the socket and reads one text message. The reply comes from the
    accelerator when one is connected, otherwise from the local model, and the
    socket is then closed. If the client disconnects while the prompt is being
    read or answered, the handler returns quietly without closing.
    """
    await ws.accept()
    try:
        prompt = await ws.receive_text()
        handler = (
            handle_answerer_accelerated
            if accelerator.connected()
            else handle_answerer_local
        )
        await handler(ws, prompt)
    except (ConnectionClosed, WebSocketDisconnect):
        return
    else:
        await ws.close()