from fastapi import FastAPI, HTTPException
from typing import Any
from pydantic import BaseModel
from os import getenv
from huggingface_hub import InferenceClient
import random
from json_repair import repair_json
import nltk
import sys

app = FastAPI()

# Sentence tokenizer used to split user input and model output into sentences.
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

HF_TOKEN = getenv("HF_TOKEN")
MODEL = "mistralai/Mistral-7B-Instruct-v0.2"

# Models to try, in order, when the requested model fails.
FALLBACK_MODELS = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1",
]


class InputData(BaseModel):
    model: str
    system_prompt_template: str
    prompt_template: str
    system_prompt: str
    user_input: str
    json_prompt: str
    history: str = ""
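
# Illustrative request body for POST /generate-response/ (a sketch, not from the
# original repo; only the {SystemPrompt}/{Prompt} markers and the ### keys come
# from the code, all other values are placeholders):
# {
#   "model": "mistralai/Mistral-7B-Instruct-v0.2",
#   "system_prompt_template": "<s>[INST] {SystemPrompt} [/INST]",
#   "prompt_template": "[INST] {Prompt} [/INST]",
#   "system_prompt": "You are a helpful assistant.",
#   "user_input": "Hello there. How are you today?",
#   "json_prompt": "Reply only with a JSON object containing ###New response### and ###Sentence count###.",
#   "history": ""
# }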


@app.post("/generate-response/")
async def generate_response(data: InputData) -> Any:
    # Split the user input into sentences and wrap it in the dict format the
    # model is asked to mirror in its JSON reply.
    sentences = tokenizer.tokenize(data.user_input)
    data_dict = {'###New response###': [], '###Sentence count###': 0}
    for i, sentence in enumerate(sentences):
        data_dict["###New response###"].append(sentence)
        data_dict["###Sentence count###"] = i + 1

    # Build the full prompt: system prompt, JSON-format instructions, then the
    # running conversation history including the new user turn.
    data.history += data.prompt_template.replace("{Prompt}", str(data_dict))
    inputs = (
        data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt)
        + data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt)
        + data.history
    )
    seed = random.randint(0, 2**32 - 1)

    # Try the requested model first, then each fallback in turn.
    models_to_try = [data.model] + FALLBACK_MODELS
    for model in models_to_try:
        try:
            # Create the client per attempt so fallbacks actually query a different model.
            client = InferenceClient(model=model, token=HF_TOKEN)
            response = client.text_generation(inputs,
                                              temperature=1.0,
                                              max_new_tokens=1000,
                                              seed=seed)
            # The model is asked to reply with JSON; repair it into a Python object.
            strict_response = str(response)
            repaired_response = repair_json(strict_response, return_objects=True)
            if isinstance(repaired_response, str):
                # repair_json could not recover an object; treat this attempt as
                # a failure and fall through to the next model.
                raise ValueError("Invalid JSON response from model")

            # Strip the ### markers from the keys.
            cleaned_response = {}
            for key, value in repaired_response.items():
                cleaned_key = key.replace("###", "")
                cleaned_response[cleaned_key] = value

            # Trim the first three entries to their first sentence, dropping any
            # that contain no sentences (build a new list instead of deleting
            # from the list being iterated).
            trimmed = []
            for i, text in enumerate(cleaned_response["New response"]):
                if i <= 2:
                    text_sentences = tokenizer.tokenize(text)
                    if text_sentences:
                        trimmed.append(text_sentences[0])
                else:
                    trimmed.append(text)
            cleaned_response["New response"] = trimmed

            # Cap the reported sentence count at 3, or recompute it if missing.
            if cleaned_response.get("Sentence count"):
                if cleaned_response["Sentence count"] > 3:
                    cleaned_response["Sentence count"] = 3
            else:
                cleaned_response["Sentence count"] = len(cleaned_response["New response"])

            data.history += str(cleaned_response)
            return cleaned_response
        except Exception as e:
            print(f"Model {model} failed with error: {e}")

    raise HTTPException(status_code=500, detail="All models failed to generate response")
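
# A minimal sketch of exercising the endpoint with FastAPI's TestClient
# (illustrative only; `payload` is a dict like the example shown above the
# InputData model, and HF_TOKEN must be set for the upstream call to succeed):
#
#   from fastapi.testclient import TestClient
#   test_client = TestClient(app)
#   r = test_client.post("/generate-response/", json=payload)
#   print(r.status_code, r.json())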