# Ll3doke / main1.py (uploaded by Ashrafb)
# Commit ada3958 (verified): "Rename main.py to main1.py" (2.28 kB; scan: no virus)
import asyncio
import json

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
# Inference client for the hosted Llama 3 8B Instruct model; respond() sends
# every chat completion through it.
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")

# ASGI application; the API route and the static-file mount are registered on it.
app = FastAPI()
# Sampling settings passed to every chat_completion call made by respond().
MAX_TOKENS: int = 2000  # cap on generated tokens per reply
TEMPERATURE: float = 0.7
TOP_P: float = 0.95

# System prompt placed first in every conversation sent to the model.
SYSTEM_MESSAGE: str = (
    "You are a helpful, respectful and honest assistant. "
    "Always answer as helpfully as possible, while being safe. "
    "Your answers should not include any harmful, unethical, racist, sexist, "
    "toxic, dangerous, or illegal content. "
    "Please ensure that your responses are socially unbiased and positive in "
    "nature.\n\n"
    "If a question does not make any sense, or is not factually coherent, "
    "explain why instead of answering something not correct. "
    "If you don't know the answer to a question, please don't share false "
    "information."
)
def respond(message, history: list[tuple[str, str]]):
    """Stream the model's reply to ``message`` given the prior conversation.

    Builds a chat transcript (system prompt, then each prior turn, then the new
    user message) and streams a chat completion from ``client`` using the
    module-level sampling settings.

    Args:
        message: The new user message.
        history: Prior turns as ``(user_text, assistant_text)`` pairs. A falsy
            value on either side of a pair is left out of the transcript.

    Yields:
        Successive text fragments of the assistant reply. Chunks whose delta
        carries no text (``None`` or empty) are skipped, so callers can
        concatenate the fragments directly.
    """
    messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
    for user_text, assistant_text, *_ in history:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})

    stream = client.chat_completion(
        messages,
        max_tokens=MAX_TOKENS,
        stream=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )
    # Named ``chunk`` (not ``message``) so the parameter above isn't shadowed.
    for chunk in stream:
        # Defensive: tolerate a stream chunk with an empty ``choices`` list.
        if not chunk.choices:
            continue
        # ``delta.content`` is optional in the streaming schema and may be None;
        # yielding it would make ``str + None`` fail in callers that join parts.
        content = chunk.choices[0].delta.content
        if content:
            yield content
def _collect_response(prompt: str, history: list) -> str:
    """Run ``respond`` to completion and return the concatenated reply text.

    This is blocking (the ``InferenceClient`` used by ``respond`` is the
    synchronous client), so ``generate`` runs it in a worker thread.
    """
    # Drop empty/None fragments: the streamed delta content is optional.
    return "".join(part for part in respond(prompt, history) if part)


@app.post("/generate/")
async def generate(request: Request):
    """Generate a chat reply for a form-encoded request.

    Form fields:
        prompt: The new user message (required, non-empty text).
        history: Optional JSON array of ``[user, assistant]`` pairs; a missing
            or empty field means no prior conversation.

    Returns:
        ``JSONResponse`` with body ``{"response": "<generated text>"}``.

    Raises:
        HTTPException: 400 when ``prompt`` is missing or empty, or when
            ``history`` is not valid JSON shaped as a list of pairs.
    """
    form = await request.form()

    prompt = form.get("prompt")
    # Form values may be upload objects for file parts; only text is usable.
    if not isinstance(prompt, str) or not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    raw_history = form.get("history") or "[]"  # Default to empty history
    if not isinstance(raw_history, str):
        raise HTTPException(status_code=400, detail="History must be a JSON string")
    try:
        history = json.loads(raw_history)
    except json.JSONDecodeError as err:
        raise HTTPException(
            status_code=400, detail="History must be valid JSON"
        ) from err
    # respond() unpacks each turn into (user, assistant); reject other shapes
    # here with a 400 instead of failing later with a 500.
    if not isinstance(history, list) or not all(
        isinstance(turn, list) and len(turn) >= 2 for turn in history
    ):
        raise HTTPException(
            status_code=400,
            detail="History must be a list of [user, assistant] pairs",
        )

    # Consuming the stream does blocking network I/O; run it off the event loop
    # so one slow generation does not stall every other request.
    final_response = await asyncio.to_thread(_collect_response, prompt, history)
    return JSONResponse(content={"response": final_response})
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end page at the site root.

    Registered before the static mount: routes are matched in registration
    order, and a mount at "/" registered first would capture every path,
    including "/", leaving this handler unreachable.
    """
    # Relative to the working directory, like StaticFiles(directory="static")
    # below, so both resolve the same file instead of depending on /app.
    return FileResponse(path="static/index.html", media_type="text/html")


# Catch-all: serve remaining paths from ./static (html=True serves index.html
# for directory requests). Must stay last so it doesn't shadow the routes above.
app.mount("/", StaticFiles(directory="static", html=True), name="static")