Ll3doke / main.py
Ashrafb's picture
Update main.py
353f81e verified
raw
history blame
2.68 kB
import asyncio
import json

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
# FastAPI application; the CORS middleware and the /generate/ route are registered on it.
app = FastAPI()

# Module-level Hugging Face inference client for the Llama 3 8B Instruct model,
# reused by every call to respond().
client = InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
# System prompt placed first in every conversation sent to the model.
# Adjacent string literals are concatenated by Python, so each piece must carry
# its own trailing whitespace/newline separator.
SYSTEM_MESSAGE = (
    "You are a helpful, respectful and honest assistant. Always answer as helpfully "
    "as possible, while being safe. Your answers should not include any harmful, "
    "unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure "
    "that your responses are socially unbiased and positive in nature.\n\nIf a question "
    "does not make any sense, or is not factually coherent, explain why instead of "
    "answering something not correct. If you don't know the answer to a question, please "
    # The separator below was missing, producing "information.Always respond".
    "don't share false information.\n\n"
    "Always respond in the same language as the user's most recent prompt."
)

# Generation parameters passed to client.chat_completion() in respond().
MAX_TOKENS = 2000  # upper bound on tokens generated per reply
TEMPERATURE = 0.7  # sampling temperature
TOP_P = 0.95  # nucleus-sampling probability mass
def respond(message, history: list[tuple[str, str]]):
    """Stream the model's reply to *message*, given earlier conversation turns.

    Builds a chat message list (system prompt, then each prior turn from
    *history*, then *message* as the newest user turn) and requests a
    streaming completion from the module-level ``client``.

    Args:
        message: The new user prompt.
        history: Prior turns as ``(user_text, assistant_text)`` pairs; only the
            first two items of each entry are read. Falsy texts are skipped.

    Yields:
        str: Successive non-empty text fragments of the assistant's reply.
    """
    messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
    for entry in history:
        user_text, assistant_text = entry[0], entry[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})

    stream = client.chat_completion(
        messages,
        max_tokens=MAX_TOKENS,
        stream=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )
    # Use a distinct loop name so the `message` parameter is not shadowed.
    for chunk in stream:
        # A chunk may carry no choices (e.g. usage-only chunks) — guard the [0].
        if not chunk.choices:
            continue
        # delta.content is Optional and may be None (typically on the final
        # chunk); yielding None would break string concatenation in callers.
        content = chunk.choices[0].delta.content
        if content:
            yield content
from fastapi.middleware.cors import CORSMiddleware

# Browser CORS policy: only the static front-end Space may call this API
# cross-origin. For that origin, credentials plus any HTTP method and any
# request header are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://artixiban-ll3.static.hf.space"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
def _parse_history(raw_history) -> list:
    """Decode the ``history`` form field into ``(user, assistant)`` tuples.

    Missing (None) or empty-string input means "no history". Otherwise the value
    must be a JSON array whose entries are arrays of at least two items, the
    first two each being a string or null; any extra items are ignored.

    Raises:
        HTTPException: 400 if the value is not a string, is not valid JSON, or
            does not have the expected shape.
    """
    if raw_history is None or raw_history == "":
        return []
    if not isinstance(raw_history, str):
        raise HTTPException(status_code=400, detail="History must be a JSON string")
    try:
        decoded = json.loads(raw_history)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="History is not valid JSON") from None
    if not isinstance(decoded, list):
        raise HTTPException(status_code=400, detail="History must be a JSON array")

    pairs = []
    for entry in decoded:
        if (
            not isinstance(entry, list)
            or len(entry) < 2
            or not all(item is None or isinstance(item, str) for item in entry[:2])
        ):
            raise HTTPException(
                status_code=400,
                detail="Each history entry must be a [user, assistant] pair of strings",
            )
        pairs.append((entry[0], entry[1]))
    return pairs


@app.post("/generate/")
async def generate(request: Request):
    """Generate a chat reply for a form-encoded prompt and return it as JSON.

    Form fields:
        prompt: The new user message (required, non-empty string).
        history: Optional JSON array of ``[user, assistant]`` pairs for earlier
            turns; missing or empty means no history.

    Returns:
        JSONResponse: ``{"response": <full assistant reply>}``.

    Raises:
        HTTPException: 403 if the ``Origin`` header is not the allowed front-end;
            400 if the prompt is missing or the history is malformed.
    """
    allowed_origin = "https://artixiban-ll3.static.hf.space"
    # NOTE: this only blocks browsers on other sites. Non-browser clients can
    # send any Origin header, so this is not an authentication mechanism.
    if request.headers.get("origin") != allowed_origin:
        raise HTTPException(status_code=403, detail="Origin not allowed")

    form = await request.form()
    prompt = form.get("prompt")
    # Form values are str, or UploadFile for file parts — accept only text.
    if not isinstance(prompt, str) or not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")
    history = _parse_history(form.get("history"))

    def _collect_reply() -> str:
        # respond() is a synchronous generator over a blocking client; skip
        # None/empty fragments so the join cannot fail on a content-less chunk.
        return "".join(part for part in respond(prompt, history) if part)

    # Run the blocking stream in a worker thread so the event loop keeps
    # serving other requests while the model generates.
    final_response = await asyncio.to_thread(_collect_reply)
    return JSONResponse(content={"response": final_response})