oflakne26 committed
Commit 57766c5
1 Parent(s): 83be1bc

Update main.py

Files changed (1): main.py (+42, -38)
main.py CHANGED
@@ -7,7 +7,6 @@ import random
 from json_repair import repair_json
 import nltk
 import sys
-sys.setrecursionlimit(10000) # Set a higher recursion limit
 
 app = FastAPI()
 
@@ -22,6 +21,8 @@ FALLBACK_MODELS = [
     "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"
 ]
 
+MAX_RETRIES = 3  # Maximum number of retries
+
 class InputData(BaseModel):
     model: str
     system_prompt_template: str
@@ -52,45 +53,48 @@ async def generate_response(data: InputData) -> Any:
     seed = random.randint(0, 2**32 - 1)
 
     models_to_try = [data.model] + FALLBACK_MODELS
-
-    for model in models_to_try:
-        try:
-            response = client.text_generation(inputs,
-                                              temperature=1.0,
-                                              max_new_tokens=1000,
-                                              seed=seed)
-
-            strict_response = str(response)
-
-            repaired_response = repair_json(strict_response,
-                                            return_objects=True)
-
-            if isinstance(repaired_response, str):
-                raise HTTPException(status_code=500, detail="Invalid response from model")
-            else:
-                cleaned_response = {}
-                for key, value in repaired_response.items():
-                    cleaned_key = key.replace("###", "")
-                    cleaned_response[cleaned_key] = value
-
-                for i, text in enumerate(cleaned_response["New response"]):
-                    if i <= 2:
-                        sentences = tokenizer.tokenize(text)
-                        if sentences:
-                            cleaned_response["New response"][i] = sentences[0]
-                        else:
-                            del cleaned_response["New response"][i]
-                if cleaned_response.get("Sentence count"):
-                    if cleaned_response["Sentence count"] > 3:
-                        cleaned_response["Sentence count"] = 3
-                    else:
-                        cleaned_response["Sentence count"] = len(cleaned_response["New response"])
+    retries = 0
+
+    while retries < MAX_RETRIES:
+        for model in models_to_try:
+            try:
+                response = client.text_generation(inputs,
+                                                  temperature=1.0,
+                                                  max_new_tokens=1000,
+                                                  seed=seed)
+
+                strict_response = str(response)
+
+                repaired_response = repair_json(strict_response,
+                                                return_objects=True)
+
+                if isinstance(repaired_response, str):
+                    raise HTTPException(status_code=500, detail="Invalid response from model")
+                else:
+                    cleaned_response = {}
+                    for key, value in repaired_response.items():
+                        cleaned_key = key.replace("###", "")
+                        cleaned_response[cleaned_key] = value
+
+                    for i, text in enumerate(cleaned_response["New response"]):
+                        if i <= 2:
+                            sentences = tokenizer.tokenize(text)
+                            if sentences:
+                                cleaned_response["New response"][i] = sentences[0]
+                            else:
+                                del cleaned_response["New response"][i]
+                    if cleaned_response.get("Sentence count"):
+                        if cleaned_response["Sentence count"] > 3:
+                            cleaned_response["Sentence count"] = 3
+                        else:
+                            cleaned_response["Sentence count"] = len(cleaned_response["New response"])
 
-            data.history += str(cleaned_response)
+                data.history += str(cleaned_response)
 
-            return cleaned_response
+                return cleaned_response
 
-        except Exception as e:
-            print(f"Model {model} failed with error: {e}")
+            except Exception as e:
+                print(f"Model {model} failed with error: {e}")
+                retries += 1
 
-    raise HTTPException(status_code=500, detail="All models failed to generate response")
+    raise HTTPException(status_code=500, detail="All models failed to generate response after maximum retries")
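
Read on its own, the change wraps the existing model-fallback loop in a bounded retry pass. The sketch below reconstructs that pattern as a standalone function, under stated assumptions: `client` and `inputs` are defined elsewhere in main.py and do not appear in these hunks, so the sketch constructs its own InferenceClient per model (in the committed code the loop variable `model` is only used in the log message, not passed to the generation call); `call_with_fallback` is an illustrative name, not part of the commit.

# Hedged sketch of the retry-with-fallback pattern this commit introduces.
# Assumes huggingface_hub and json_repair are installed; constants mirror main.py.
from huggingface_hub import InferenceClient
from json_repair import repair_json

MAX_RETRIES = 3
FALLBACK_MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1",
]

def call_with_fallback(prompt: str, preferred_model: str, seed: int) -> dict:
    """Try the preferred model, then each fallback; repeat up to MAX_RETRIES passes."""
    models_to_try = [preferred_model] + FALLBACK_MODELS
    retries = 0
    while retries < MAX_RETRIES:
        for model in models_to_try:
            try:
                client = InferenceClient(model=model)  # assumption: one client per model
                raw = client.text_generation(prompt,
                                             temperature=1.0,
                                             max_new_tokens=1000,
                                             seed=seed)
                # repair_json(..., return_objects=True) parses possibly malformed
                # JSON into a Python object; when it cannot recover an object it
                # returns a string instead, which is why the commit checks
                # isinstance(..., str) and treats that case as a failure.
                obj = repair_json(str(raw), return_objects=True)
                if isinstance(obj, str):
                    raise ValueError("model did not return parseable JSON")
                return obj
            except Exception as exc:  # broad catch, mirroring main.py
                print(f"Model {model} failed with error: {exc}")
                retries += 1
    raise RuntimeError("all models failed after maximum retries")

As in the committed version, retries counts individual failures but the bound is only re-evaluated once the for loop finishes, so each pass always tries every model in models_to_try before the while condition can stop it.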
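
Separately, one behavior of the sentence-trimming loop is worth flagging, since it is only re-indented by this commit: `del cleaned_response["New response"][i]` removes an element from the list that `enumerate` is still iterating, which shifts later items down by one and causes the element after each deletion to be skipped. A minimal sketch of an equivalent rewrite that avoids mutating the list mid-iteration; nltk.sent_tokenize stands in for the `tokenizer` object, whose construction is outside these hunks:

import nltk

nltk.download("punkt", quiet=True)  # assumption: punkt is the data `tokenizer` relies on

def trim_first_sentences(texts: list[str]) -> list[str]:
    """First three items are cut to their first sentence; items whose
    tokenization comes back empty are dropped; later items pass through."""
    result = []
    for i, text in enumerate(texts):
        if i <= 2:
            sentences = nltk.sent_tokenize(text)
            if sentences:
                result.append(sentences[0])
            # empty tokenization: drop the item; building a new list avoids
            # the index shifting that in-place `del` causes during iteration
        else:
            result.append(text)
    return result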