from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
import json

app = FastAPI()

# Hosted inference client for Meta-Llama-3-8B-Instruct; depending on the model's
# gating, an HF access token may be required (commonly via the HF_TOKEN env var).
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")

SYSTEM_MESSAGE = (
    "You are a helpful, respectful and honest assistant. Always answer as helpfully "
    "as possible, while being safe. Your answers should not include any harmful, "
    "unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure "
    "that your responses are socially unbiased and positive in nature.\n\nIf a question "
    "does not make any sense, or is not factually coherent, explain why instead of "
    "answering something not correct. If you don't know the answer to a question, please "
    "don't share false information."
)
MAX_TOKENS = 2000    # Upper bound on tokens generated per completion
TEMPERATURE = 0.7    # Sampling temperature; lower values are more deterministic
TOP_P = 0.95         # Nucleus (top-p) sampling cutoff

def respond(message: str, history: list[tuple[str, str]]):
    """Yield completion tokens for `message`, given prior (user, assistant) turns."""
    messages = [{"role": "system", "content": SYSTEM_MESSAGE}]

    # Replay the prior conversation so the model sees the full context.
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    # Request a streamed chat completion; tokens arrive as incremental chunks.
    response = client.chat_completion(
        messages,
        max_tokens=MAX_TOKENS,
        stream=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )

    for chunk in response:  # Iterate over the streamed chunks
        token = chunk.choices[0].delta.content
        if token is not None:  # Final chunks may carry no content
            yield token

@app.post("/generate/")
async def generate(request: Request):
    form = await request.form()
    prompt = form.get("prompt")
    history = json.loads(form.get("history", "[]"))  # Default to empty history
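    # `history` is expected as a JSON array of [user, assistant] pairs, e.g.
    # [["Hi there", "Hello! How can I help?"]].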

    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    # Drain the token generator and return the completion as a single JSON payload.
    final_response = "".join(respond(prompt, history))

    return JSONResponse(content={"response": final_response})
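
# A possible streaming variant (a sketch, not part of the original app): instead
# of accumulating the whole completion server-side, forward tokens as they are
# produced. The "/generate/stream/" path and plain-text media type are assumptions.
from fastapi.responses import StreamingResponse

@app.post("/generate/stream/")
async def generate_stream(request: Request):
    form = await request.form()
    prompt = form.get("prompt")
    history = json.loads(form.get("history", "[]"))

    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    # StreamingResponse accepts the (synchronous) token generator directly.
    return StreamingResponse(respond(prompt, history), media_type="text/plain")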

app.mount("/", StaticFiles(directory="static", html=True), name="static")

@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")
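
# Usage sketch: with the app served locally (e.g. `uvicorn main:app --port 8000`;
# the module name `main` is an assumption), the endpoint can be exercised like so:
#
#   import json, requests
#   resp = requests.post(
#       "http://localhost:8000/generate/",
#       data={"prompt": "Hello!", "history": json.dumps([["Hi", "Hello!"]])},
#   )
#   print(resp.json()["response"])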