|
import asyncio
import json

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
|
|
|
# FastAPI application; the CORS middleware and the /generate/ route are
# registered on this instance.
app = FastAPI()

# Hugging Face model repo used for chat completions. Kept as a named constant,
# like the other generation settings, so it can be changed in one place.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Shared inference client; respond() uses it for every chat completion.
client = InferenceClient(MODEL_ID)
|
|
|
# System prompt placed first in every conversation sent to the model.
SYSTEM_MESSAGE = (
    "You are a helpful, respectful and honest assistant. Always answer as helpfully "
    "as possible, while being safe. Your answers should not include any harmful, "
    "unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure "
    "that your responses are socially unbiased and positive in nature.\n\nIf a question "
    "does not make any sense, or is not factually coherent, explain why instead of "
    "answering something not correct. If you don't know the answer to a question, please "
    # Adjacent string literals are joined with no separator, so the trailing
    # space below is what keeps these two sentences apart.
    "don't share false information. "
    "Always respond in the same language as the user's most recent prompt."
)

# Sampling settings passed to client.chat_completion() in respond().
MAX_TOKENS = 2000  # upper bound on generated tokens per reply
TEMPERATURE = 0.7
TOP_P = 0.95
|
|
|
def respond(message, history: list[tuple[str, str]]):
    """Stream the assistant's reply to *message*, given earlier turns.

    Builds a chat message list: the system prompt first, then the user and
    assistant turns from *history*, then the new user *message*. The list is
    streamed through the module-level ``client``.

    Args:
        message: The new user prompt.
        history: Earlier turns. Each entry is indexed as
            ``(user_text, assistant_text)``. A falsy half (``""`` or ``None``)
            is skipped, so a turn without an assistant reply is accepted.

    Yields:
        Non-empty text fragments of the assistant reply, in order.
    """
    messages = [{"role": "system", "content": SYSTEM_MESSAGE}]

    for turn in history:
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})

    messages.append({"role": "user", "content": message})

    stream = client.chat_completion(
        messages,
        max_tokens=MAX_TOKENS,
        stream=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )

    # Use a distinct loop name so the ``message`` argument is not shadowed.
    for chunk in stream:
        # A stream chunk may carry no choices, or a delta whose ``content`` is
        # None (the field is Optional in huggingface_hub's stream output).
        # Skip those so callers only ever receive strings.
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content:
            yield content
|
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
# Allow cross-origin browser requests, including credentials, from the single
# frontend origin below. Any HTTP method and any request header is permitted
# for that origin.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["https://artixiban-ll3.static.hf.space"],
)
|
|
|
def _is_valid_turn(turn) -> bool:
    """Return True if *turn* is a list whose first two items are str or None.

    respond() reads only the first two items, so any further items are ignored.
    """
    return (
        isinstance(turn, list)
        and len(turn) >= 2
        and all(part is None or isinstance(part, str) for part in turn[:2])
    )


def _parse_history(raw_history) -> list:
    """Decode and validate the ``history`` form field; raise HTTP 400 if bad."""
    # Form values can be UploadFile objects, which json.loads cannot read.
    if not isinstance(raw_history, str):
        raise HTTPException(status_code=400, detail="History must be a JSON string")
    try:
        history = json.loads(raw_history)
    except json.JSONDecodeError as err:
        raise HTTPException(status_code=400, detail="History must be valid JSON") from err
    if not isinstance(history, list) or not all(_is_valid_turn(turn) for turn in history):
        raise HTTPException(
            status_code=400,
            detail="History must be a list of [user, assistant] pairs",
        )
    return history


def _collect_response(prompt: str, history: list) -> str:
    """Drain respond() and join its text fragments, skipping empty/None ones."""
    return "".join(part for part in respond(prompt, history) if part)


@app.post("/generate/")
async def generate(request: Request):
    """Return the complete assistant reply for a form-encoded chat request.

    Form fields:
        prompt: Required, non-empty text.
        history: Optional JSON array of ``[user_text, assistant_text]`` pairs.
            Defaults to ``"[]"``.

    Returns:
        ``JSONResponse`` whose body is ``{"response": <full reply text>}``.

    Raises:
        HTTPException(403): The ``Origin`` header is not the allowed origin.
        HTTPException(400): ``prompt`` is missing, empty or not text, or
            ``history`` is not a valid JSON list of pairs.
    """
    # Keep this in sync with the CORSMiddleware ``allow_origins`` list.
    # Browsers set the Origin header, but any other HTTP client can forge it,
    # so this check only restricts browser-based cross-site use.
    allowed_origin = "https://artixiban-ll3.static.hf.space"
    origin = request.headers.get("origin")
    if origin != allowed_origin:
        raise HTTPException(status_code=403, detail="Origin not allowed")

    form = await request.form()

    prompt = form.get("prompt")
    # A file upload arrives as an UploadFile rather than str; reject it as well.
    if not prompt or not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="Prompt is required")

    history = _parse_history(form.get("history", "[]"))

    # respond() makes blocking network calls through the synchronous
    # InferenceClient. Run it in a worker thread so the event loop can keep
    # serving other requests while the model generates.
    final_response = await asyncio.to_thread(_collect_response, prompt, history)

    return JSONResponse(content={"response": final_response})
|
|
|
|