from fastapi import FastAPI, HTTPException
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from os import getenv
from huggingface_hub import InferenceClient
import random
import nltk
import re
from word_forms.word_forms import get_word_forms
app = FastAPI()

# Punkt sentence tokenizer used for optional sentence segmentation
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

# Hugging Face API token, read from the environment
HF_TOKEN = getenv("HF_TOKEN")


class InputData(BaseModel):
    model: str
    system_prompt_template: str
    prompt_template: str
    end_token: str
    system_prompts: List[str]
    user_inputs: List[str]
    history: str = ""
    segment: bool = False
    max_sentences: Optional[int] = None


class WordCheckData(BaseModel):
    string: str
    word: str


@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
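    """Format the system prompts, history and user inputs into a single prompt,
    call the Hugging Face Inference API for `data.model`, and return the response
    (optionally split into sentences) together with the updated history."""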
    # A non-zero max_sentences implies sentence segmentation
    if data.max_sentences is not None and data.max_sentences != 0:
        data.segment = True
    elif data.max_sentences == 0:
        # max_sentences == 0: only append the prompts to the history, skip generation
        for user_input in data.user_inputs:
            data.history += data.prompt_template.replace("{Prompt}", user_input)
        return {
            "response": [],
            "sentence_count": None,
            "history": data.history + data.end_token
        }

    responses = []
    if data.segment:
        # Split each user input into sentences before adding it to the history
        for user_input in data.user_inputs:
            user_sentences = tokenizer.tokenize(user_input)
            user_input_str = "\n".join(user_sentences)
            data.history += data.prompt_template.replace("{Prompt}", user_input_str) + "\n"
    else:
        for user_input in data.user_inputs:
            data.history += data.prompt_template.replace("{Prompt}", user_input) + "\n"

    # Prepend the formatted system prompts to the accumulated history
    inputs = ""
    for system_prompt in data.system_prompts:
        inputs += data.system_prompt_template.replace("{SystemPrompt}", system_prompt) + "\n"
    inputs += data.history

    seed = random.randint(0, 2**32 - 1)
    try:
        client = InferenceClient(model=data.model, token=HF_TOKEN)
        response = client.text_generation(
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed
        )
        response_str = str(response)

        if data.segment:
            # Segment the model output and optionally truncate to max_sentences
            ai_sentences = tokenizer.tokenize(response_str)
            if data.max_sentences is not None:
                ai_sentences = ai_sentences[:data.max_sentences]
            responses = ai_sentences
            sentence_count = len(ai_sentences)
        else:
            responses = [response_str]
            sentence_count = None

        data.history += response_str + "\n"
        return {
            "response": responses,
            "sentence_count": sentence_count,
            "history": data.history + data.end_token
        }
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(status_code=500, detail=f"Model {data.model} failed to generate response")


@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
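    """Return whether any inflected form of `data.word` (as produced by
    word_forms) occurs in `data.string`, ignoring case and non-alphabetic
    characters."""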
    input_string = data.string.lower()
    word = data.word.lower()
    forms = get_word_forms(word)
    all_forms = set()
    for words in forms.values():
        all_forms.update(words)
    # Initialize found flag
    found = False
    # Split the input string into words
    input_words = input_string.split()
    # Loop through each word in the input string
    for input_word in input_words:
        # Strip the word to contain only alphabetic characters
        input_word = ''.join(filter(str.isalpha, input_word))
        # Check if the stripped word is equal to any of the forms
        if input_word in all_forms:
            found = True
            break  # Exit loop if word is found
    result = {
        "found": found
    }
    return result
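
# --- Example usage (a minimal sketch, not part of the API itself) ---
# Assumes this file is saved as `main.py` and served with uvicorn; the filename
# and port are assumptions, not taken from this file:
#
#     uvicorn main:app --host 0.0.0.0 --port 7860
#
# Example POST /generate-response/ body (field names come from InputData above;
# the model id and template strings are hypothetical placeholders):
#
#     {
#       "model": "some-org/some-text-generation-model",
#       "system_prompt_template": "System: {SystemPrompt}",
#       "prompt_template": "User: {Prompt}",
#       "end_token": "</s>",
#       "system_prompts": ["You are a helpful assistant."],
#       "user_inputs": ["Tell me about sentence tokenization."],
#       "segment": true,
#       "max_sentences": 3
#     }
#
# Example POST /check-word/ body (fields from WordCheckData above):
#
#     {"string": "The runners were running quickly.", "word": "run"}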