New-Place / main.py
oflakne26's picture
Update main.py
bf0e60d verified
raw
history blame contribute delete
No virus
3.8 kB
from fastapi import FastAPI, HTTPException
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from os import getenv
from huggingface_hub import InferenceClient
import random
import nltk
import re
from word_forms.word_forms import get_word_forms
# ASGI application object; the @app.post route decorators below register
# their endpoints on it.
app = FastAPI()
# Fetch the Punkt sentence-tokenizer data at import time. This must run before
# the nltk.data.load call below, which reads the downloaded resource.
# NOTE(review): newer NLTK releases appear to ship 'punkt_tab' instead of the
# pickled Punkt models -- confirm the pinned NLTK version still supports this.
nltk.download('punkt')
# English Punkt sentence tokenizer, used by generate_response to split text
# into sentences.
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
# Hugging Face API token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = getenv("HF_TOKEN")
class InputData(BaseModel):
    """Request body accepted by the ``/generate-response/`` endpoint."""

    # Model identifier handed to ``InferenceClient(model=...)``.
    model: str
    # Template whose "{SystemPrompt}" placeholder is replaced by each system prompt.
    system_prompt_template: str
    # Template whose "{Prompt}" placeholder is replaced by each user input.
    prompt_template: str
    # Text appended to the history string returned in the response.
    end_token: str
    # System prompts, rendered in order ahead of the conversation history.
    system_prompts: List[str]
    # User inputs, rendered in order and appended to the history.
    user_inputs: List[str]
    # Conversation rendered so far; empty for a fresh conversation.
    history: str = ""
    # When true, user input and model output are split into sentences.
    segment: bool = False
    # Optional cap on returned sentences: a non-zero value forces segmentation,
    # and 0 makes the endpoint skip generation entirely.
    max_sentences: Optional[int] = None
class WordCheckData(BaseModel):
    """Request body accepted by the ``/check-word/`` endpoint."""

    # Text that is searched, word by word.
    string: str
    # Word whose forms (as returned by ``get_word_forms``) are looked for.
    word: str
@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
    """Render the prompt, query the hosted model and return its reply.

    The conversation history is extended with every entry of
    ``data.user_inputs`` (rendered through ``data.prompt_template``), prefixed
    with the rendered system prompts, and sent to the Hugging Face inference
    endpoint for ``data.model``.

    Args:
        data: Request payload. ``max_sentences == 0`` short-circuits: the user
            inputs are appended to the history and no model call is made. Any
            other non-``None`` ``max_sentences`` forces sentence segmentation
            and caps the number of sentences returned.

    Returns:
        A dict with:
          * ``"response"``: list of sentences when segmenting, otherwise a
            single-element list holding the whole reply (empty list when
            generation was skipped);
          * ``"sentence_count"``: number of sentences returned, or ``None``
            when not segmenting;
          * ``"history"``: the updated history followed by ``data.end_token``.

    Raises:
        HTTPException: with status 500 when the inference call fails.
    """
    # Imported locally so the handler is self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # Work on local copies instead of mutating the request model.
    history = data.history
    max_sentences = data.max_sentences

    if max_sentences == 0:
        # Nothing to generate: just record the user turns (no newline
        # separator in this branch) and hand the history back.
        history += "".join(
            data.prompt_template.replace("{Prompt}", user_input)
            for user_input in data.user_inputs
        )
        return {
            "response": [],
            "sentence_count": None,
            "history": history + data.end_token
        }

    # Past the early return, a non-None max_sentences is necessarily non-zero
    # and therefore forces segmentation.
    segment = data.segment or max_sentences is not None

    rendered_turns = []
    for user_input in data.user_inputs:
        if segment:
            # One sentence per line inside the rendered prompt.
            user_input = "\n".join(tokenizer.tokenize(user_input))
        rendered_turns.append(
            data.prompt_template.replace("{Prompt}", user_input) + "\n"
        )
    history += "".join(rendered_turns)

    system_block = "".join(
        data.system_prompt_template.replace("{SystemPrompt}", system_prompt) + "\n"
        for system_prompt in data.system_prompts
    )
    inputs = system_block + history

    seed = random.randint(0, 2**32 - 1)

    # Only the remote call is guarded, so that a failure reported as
    # "model failed" really is a model/inference failure.
    try:
        client = InferenceClient(model=data.model, token=HF_TOKEN)
        # text_generation is a blocking network call; running it in the
        # thread pool keeps the event loop free to serve other requests.
        response = await run_in_threadpool(
            client.text_generation,
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed,
        )
        response_str = str(response)
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Model {data.model} failed to generate response"
        ) from e

    if segment:
        ai_sentences = tokenizer.tokenize(response_str)
        if max_sentences is not None:
            # NOTE(review): a negative max_sentences drops sentences from the
            # end of the reply rather than capping it -- confirm whether
            # negative values should be rejected upstream.
            ai_sentences = ai_sentences[:max_sentences]
        responses = ai_sentences
        sentence_count = len(ai_sentences)
    else:
        responses = [response_str]
        sentence_count = None

    # The history keeps the full, untruncated reply.
    history += response_str + "\n"

    return {
        "response": responses,
        "sentence_count": sentence_count,
        "history": history + data.end_token
    }
@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
    """Report whether any form of ``data.word`` occurs in ``data.string``.

    Both the text and the word are lower-cased. The text is split on
    whitespace, and each token is reduced to its alphabetic characters before
    being compared against the forms returned by ``get_word_forms``.

    Returns:
        ``{"found": True}`` if at least one token matches a form, otherwise
        ``{"found": False}``.
    """
    haystack = data.string.lower()
    needle = data.word.lower()

    # Flatten the per-category collections of forms into one lookup set.
    candidate_forms = set().union(*get_word_forms(needle).values())

    # any() stops at the first matching token, like an explicit break.
    found = any(
        "".join(ch for ch in token if ch.isalpha()) in candidate_forms
        for token in haystack.split()
    )

    return {"found": found}