Update main.py
main.py CHANGED
@@ -30,7 +30,7 @@ class InputData(BaseModel):
     json_prompt: str
     history: str = ""
 
-
+@app.post("/generate-response/")
 async def generate_response(data: InputData) -> Any:
     client = InferenceClient(model=data.model, token=HF_TOKEN)
 
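For reference, the hunk above wires `generate_response` up as a POST endpoint on the FastAPI app. Below is a minimal smoke-test sketch, assuming `main.py` exposes the application object as `app` (as the decorator implies) and that `InputData` also declares the `model` field referenced in the handler; every payload value is a placeholder, not something taken from this commit:

```python
# Hypothetical smoke test for the new route; payload values are placeholders.
from fastapi.testclient import TestClient

from main import app  # assumes the FastAPI instance in main.py is named `app`

client = TestClient(app)

payload = {
    "model": "some-org/some-model",      # placeholder model id
    "json_prompt": '{"task": "reply"}',  # placeholder prompt
    "history": "",                       # optional, defaults to ""
}

resp = client.post("/generate-response/", json=payload)
print(resp.status_code, resp.json())
```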
@@ -50,10 +50,49 @@ async def generate_response(data: InputData) -> Any:
 
     seed = random.randint(0, 2**32 - 1)
 
-
-
-
+    try:
+        response = client.text_generation(inputs,
+                                          temperature=1.0,
+                                          max_new_tokens=1000,
+                                          seed=seed)
+
+        strict_response = str(response)
+
+        repaired_response = repair_json(strict_response,
+                                        return_objects=True)
+
+        if isinstance(repaired_response, str):
+            raise HTTPException(status_code=500, detail="Invalid response from model")
+        else:
+            cleaned_response = {}
+            for key, value in repaired_response.items():
+                cleaned_key = key.replace("###", "")
+                cleaned_response[cleaned_key] = value
+
+            for i, text in enumerate(cleaned_response["New response"]):
+                if i <= 2:
+                    sentences = tokenizer.tokenize(text)
+                    if sentences:
+                        cleaned_response["New response"][i] = sentences[0]
+                    else:
+                        del cleaned_response["New response"][i]
+            if cleaned_response.get("Sentence count"):
+                if cleaned_response["Sentence count"] > 3:
+                    cleaned_response["Sentence count"] = 3
+            else:
+                cleaned_response["Sentence count"] = len(cleaned_response["New response"])
+
+            data.history += str(cleaned_response)
+
+            return cleaned_response
+
+    except Exception as e:
+        print(f"Primary model {data.model} failed with error: {e}")
+
+        # If the primary model fails, try fallback models
+        for model in FALLBACK_MODELS:
             try:
+                client = InferenceClient(model=model, token=HF_TOKEN)
                 response = client.text_generation(inputs,
                                                   temperature=1.0,
                                                   max_new_tokens=1000,
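The new try block depends on `repair_json` from the json_repair package, which with `return_objects=True` hands back a parsed Python object and returns a plain string only when nothing parseable could be recovered, which is why the handler treats a `str` result as a failure. Here is a standalone sketch of that cleaning step, with invented sample data and an NLTK punkt tokenizer standing in for whatever `main.py` binds to `tokenizer`:

```python
# Standalone sketch of the cleaning step; `raw` is invented sample data, and
# `tokenizer` is assumed to be an NLTK punkt sentence tokenizer.
import nltk
from json_repair import repair_json

nltk.download("punkt", quiet=True)
tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")

raw = '{"###New response": ["First sentence. Second.", "Only one."]'  # unclosed brace

repaired = repair_json(raw, return_objects=True)  # dict here; str on failure
assert not isinstance(repaired, str)

# Strip the "###" markers from keys, as the handler does.
cleaned = {key.replace("###", ""): value for key, value in repaired.items()}

# Compact equivalent of the commit's loop (which edits the list in place):
# keep at most three replies, each trimmed to its first sentence.
cleaned["New response"] = [
    tokenizer.tokenize(text)[0]
    for text in cleaned["New response"][:3]
    if tokenizer.tokenize(text)
]
cleaned["Sentence count"] = len(cleaned["New response"])

print(cleaned)  # {'New response': ['First sentence.', 'Only one.'], 'Sentence count': 2}
```

Passing `return_objects=True` skips a second `json.loads` round-trip on the repaired string, which is presumably why the handler only needs the `isinstance` check before cleaning.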