Uhhy commited on
Commit
90624da
·
verified ·
1 Parent(s): ebc22be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -9
app.py CHANGED
@@ -121,20 +121,29 @@ def remove_repetitive_responses(responses):
121
  return unique_responses
122
 
123
  @app.post("/generate/")
 
124
  async def generate(request: ChatRequest):
125
  try:
126
- normalized_message = normalize_input(request.message)
127
- with ThreadPoolExecutor() as executor:
128
- futures = [executor.submit(model.generate, f"<s>[INST]{normalized_message} [/INST]",
129
- top_k=request.top_k, top_p=request.top_p, temperature=request.temperature)
130
- for model in global_data['models'].values()]
131
- responses = [{'model': model, 'response': future.result()}
132
- for model, future in zip(global_data['models'].keys(), as_completed(futures))]
133
-
134
  unique_responses = remove_repetitive_responses(responses)
135
  return unique_responses
136
  except Exception as e:
137
- raise HTTPException(status_code=500, detail=f"Error generating responses: {e}")
 
 
 
 
 
 
 
 
 
 
138
 
139
  if __name__ == "__main__":
140
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
121
  return unique_responses
122
 
123
  @app.post("/generate/")
124
+ @GPU(duration=0)
125
  async def generate(request: ChatRequest):
126
  try:
127
+ inputs = normalize_input(request.message)
128
+ futures = [
129
+ executor.submit(model.generate, inputs, top_k=request.top_k, top_p=request.top_p, temperature=request.temperature)
130
+ for model in global_data['models'].values()
131
+ ]
132
+ responses = [{'model': model, 'response': future.result()} for model, future in zip(global_data['models'].keys(), as_completed(futures))]
 
 
133
  unique_responses = remove_repetitive_responses(responses)
134
  return unique_responses
135
  except Exception as e:
136
+ print(f"Error generating responses: {e}")
137
+ raise HTTPException(status_code=500, detail="Error generating responses")
138
+
139
+ @app.middleware("http")
140
+ async def process_request(request: Request, call_next):
141
+ try:
142
+ response = await call_next(request)
143
+ return response
144
+ except Exception as e:
145
+ print(f"Request error: {e}")
146
+ raise HTTPException(status_code=500, detail="Internal Server Error")
147
 
148
  if __name__ == "__main__":
149
  uvicorn.run(app, host="0.0.0.0", port=8000)