from typing import Any, Dict
import os
import random

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from json_repair import repair_json
import nltk


app = FastAPI()

# Punkt sentence tokenizer: fetch the data (a no-op if already present) and
# load the English model once at startup.
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

# Hugging Face Inference API token, read from the environment.
HF_TOKEN = os.getenv("HF_TOKEN")


class InputData(BaseModel):
    """Request payload for the /generate-response/ endpoint."""
    model: str
    system_prompt_template: str
    prompt_template: str
    end_token: str
    system_prompt: str
    user_input: str
    json_prompt: str
    history: str = ""
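
# Illustrative request body (all values are hypothetical, shown only to
# document the {SystemPrompt} / {Prompt} placeholders this service fills in):
# {
#   "model": "HuggingFaceH4/zephyr-7b-beta",
#   "system_prompt_template": "<|system|>\n{SystemPrompt}</s>\n",
#   "prompt_template": "<|user|>\n{Prompt}</s>\n<|assistant|>\n",
#   "end_token": "</s>",
#   "system_prompt": "You are a medieval innkeeper.",
#   "user_input": "Good evening! Any rooms left?",
#   "json_prompt": "Reply as JSON with keys '###New response###' and '###Sentence count###'.",
#   "history": ""
# }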


@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
    client = InferenceClient(model=data.model, token=HF_TOKEN)

    # Split the user input into sentences and wrap it in the JSON structure
    # the model is asked to mirror in its reply.
    sentences = tokenizer.tokenize(data.user_input)
    data_dict = {
        '###New response###': sentences,
        '###Sentence count###': len(sentences),
    }

    data.history += data.prompt_template.replace("{Prompt}", str(data_dict))

    # Full prompt = system prompt + JSON-format instructions + running history.
    inputs = (
        data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt) +
        data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt) +
        data.history
    )

    # Fresh random seed so identical prompts can still yield different replies.
    seed = random.randint(0, 2**32 - 1)

    try:
        response = client.text_generation(
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed
        )

        # The model should answer with JSON, possibly malformed; repair it and
        # get Python objects back.
        repaired_response = repair_json(str(response), return_objects=True)

        if isinstance(repaired_response, str):
            # repair_json falls back to a plain string when it cannot recover
            # an object, so there is nothing usable to return.
            raise HTTPException(status_code=500, detail="Invalid response from model")

        # Strip the ### markers, e.g. "###New response###" -> "New response".
        cleaned_response = {
            key.replace("###", ""): value for key, value in repaired_response.items()
        }

        # Keep at most three entries, each trimmed to its first sentence.
        # Building a new list avoids mutating the list while iterating over it.
        trimmed = []
        for text in cleaned_response.get("New response", [])[:3]:
            first = tokenizer.tokenize(text)
            if first:
                trimmed.append(first[0])
        cleaned_response["New response"] = trimmed

        if cleaned_response.get("Sentence count"):
            cleaned_response["Sentence count"] = min(cleaned_response["Sentence count"], 3)
        else:
            cleaned_response["Sentence count"] = len(cleaned_response["New response"])

        data.history += str(cleaned_response)

        return {
            "response": cleaned_response,
            "history": data.history + data.end_token
        }

    except HTTPException:
        # Re-raise deliberate HTTP errors instead of masking them as the
        # generic 500 below.
        raise
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(status_code=500, detail=f"Model {data.model} failed to generate response")
@app.post("/get-medieval-name/") |
|
async def get_medieval_name() -> Dict[str, str]: |
|
try: |
|
file_path = "medieval_names.txt" |
|
if not os.path.exists(file_path): |
|
raise HTTPException(status_code=404, detail="File not found") |
|
|
|
with open(file_path, "r") as file: |
|
names = file.read().splitlines() |
|
|
|
if not names: |
|
raise HTTPException(status_code=404, detail="No names found in the file") |
|
|
|
random_name = random.choice(names) |
|
|
|
return {"name": random_name} |
|
|
|
except Exception as e: |
|
print(f"Error: {e}") |
|
raise HTTPException(status_code=500, detail="An error occurred while processing the request") |
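

if __name__ == "__main__":
    # Minimal sketch of a local entry point, assuming uvicorn is installed
    # (uvicorn is not imported or required anywhere else in this file, and the
    # deployed service may be launched differently).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)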