Ll3doke / main.py
Ashrafb's picture
Update main.py
8c3bfa5 verified
raw
history blame
No virus
2.74 kB
import ast
import random

from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi import HTTPException
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import InferenceClient
# Base URL of the Hugging Face hosted Inference API (appears unused in this file).
API_URL = "https://api-inference.huggingface.co/models/"

# Shared inference client for the chat model; generate() streams completions through it.
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")

# ASGI application serving the /generate/ endpoint and the static front-end.
app = FastAPI()
def format_prompt(message, history):
    """Build a Mistral-instruct style prompt from prior turns plus a new message.

    Each ``(user_prompt, bot_response)`` pair in *history* is rendered as
    ``"[INST] {user} [/INST] {bot}</s> "``. The new *message* is appended as a
    final, still-open ``"[INST] {message} [/INST]"`` turn. The whole string is
    prefixed with ``"<s>"``.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    )
    return "<s>" + past_turns + f"[INST] {message} [/INST]"
def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Generate the assistant's reply to *prompt* given earlier conversation turns.

    Args:
        prompt: The new user message.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest first.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling cutoff.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The generated reply as one string, with leading/trailing whitespace removed.
    """
    # Keep temperature strictly positive. Sampling is enabled below (do_sample=True),
    # and a zero temperature is presumably rejected by the backend.
    temperature = max(float(temperature), 1e-2)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fresh seed per call so repeated prompts get varied replies.
        seed=random.randint(0, 10**7),
    )
    formatted_prompt = format_prompt(prompt, history)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    # Concatenate the raw token texts without stripping them. The spacing between
    # words presumably lives inside the tokens themselves (e.g. " world"); stripping
    # each token glued all the words together.
    # Special tokens (such as an end-of-sequence marker) are not part of the reply.
    output = "".join(
        response.token.text for response in stream if not response.token.special
    )
    # Turn any lone surrogates into visible backslash escapes so the text can be
    # UTF-8 encoded in the JSON response. This is a no-op for ordinary strings.
    output = output.encode("utf-8", "backslashreplace").decode("utf-8")
    return output.strip()
@app.post("/generate/")
def generate_chat(request: Request, prompt: str = Form(...), history: str = Form(...), temperature: float = Form(0.9), max_new_tokens: int = Form(512), top_p: float = Form(0.95), repetition_penalty: float = Form(1.0)):
    """Generate one chat reply from form-encoded input.

    Declared with a plain ``def`` rather than ``async def`` because ``generate``
    makes blocking network calls. FastAPI runs sync endpoints in a threadpool,
    so the event loop is not stalled while the model streams.

    Form fields:
        prompt: The new user message.
        history: A Python/JSON-style list literal of ``[user, bot]`` pairs,
            e.g. ``[["hi", "hello!"]]``.
        temperature, max_new_tokens, top_p, repetition_penalty: Sampling
            settings forwarded to ``generate``.

    Returns:
        ``{"response": <reply text>}``.

    Raises:
        HTTPException: 400 if ``history`` is not a list of two-item pairs.
    """
    # ``request`` is unused but kept so the endpoint signature stays unchanged.
    bad_history = "history must be a list literal of [user, bot] pairs"
    # ``history`` comes straight from the client. Parse it as a literal only:
    # eval() would execute arbitrary code sent by the caller.
    try:
        parsed_history = ast.literal_eval(history)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        raise HTTPException(status_code=400, detail=bad_history) from err
    if not isinstance(parsed_history, (list, tuple)) or not all(
        isinstance(turn, (list, tuple)) and len(turn) == 2 for turn in parsed_history
    ):
        raise HTTPException(status_code=400, detail=bad_history)
    response = generate(prompt, parsed_history, temperature, max_new_tokens, top_p, repetition_penalty)
    # ``generate`` already returns one string. Joining it again would iterate its
    # characters and put a space between every one of them.
    return {"response": response}
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page at the site root.

    This route must be registered before the static mount below. A mount at
    "/" matches every path, so any route added after it would never be reached.
    """
    # Relative to the working directory, the same base the StaticFiles(directory="static")
    # mount uses, instead of a hard-coded /app/... absolute path.
    return FileResponse(path="static/index.html", media_type="text/html")


# Catch-all static file mount. Keep it last so it does not shadow the routes above.
app.mount("/", StaticFiles(directory="static", html=True), name="static")