MJobe committed on
Commit
e2d5a35
1 Parent(s): 17a7267

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +14 -39
main.py CHANGED
@@ -267,7 +267,10 @@ async def test_transcription(file: UploadFile = File(...)):
267
  except Exception as e:
268
  raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
269
 
270
- # Predefined classification labels
 
 
 
271
  labels = [
272
  "All Pricing copy quote requested",
273
  "Change to quote",
@@ -285,58 +288,30 @@ labels = [
285
  "Status Check (possibly)"
286
  ]
287
 
288
- # Keyword map for rule-based shortcuts
289
- keyword_map = {
290
- "All Pricing copy quote requested": ["Per ATP", "Not Want", "No Advertising"],
291
- "Change to quote": ["MSRP"],
292
- "Change to quote & Status Check": ["Send clean Quote"],
293
- "Change to quote (Items missed?)": ["Per ATP"],
294
- "Confirmation": ["N/A"],
295
- "Copy quote requested": ["N/A"],
296
- "Cost copy quote requested": ["N/A"],
297
- "MRSP copy quote requested": ["Quote Requested", "Revise Quote", "Include the following items"],
298
- "MSRP & All Pricing copy quote requested": ["Revise and resend Quote"],
299
- "MSRP & Cost copy quote requested": ["Per RTS"],
300
- "No narrative in email": ["Requesting"],
301
- "Notes not clear": ["Add to the quote"],
302
- "Retail copy quote requested": ["Action Required", "Send copy of Quote", "All Pricing"],
303
- "Status Check (possibly)": ["Re-quote", "Current date and Pricing", "Send quote"]
304
- }
305
-
306
- # Thread pool executor for handling requests concurrently
307
- executor = ThreadPoolExecutor(max_workers=10)
308
-
309
- # Helper function to check keywords
310
- def check_keywords(statement):
311
- for label, keywords in keyword_map.items():
312
- for keyword in keywords:
313
- if re.search(rf"\b{keyword}\b", statement, re.IGNORECASE):
314
- return label
315
- return None
316
-
317
- @app.post("/fast_classify/", description="Quickly classify text with a rule-based approach combined with model inference.")
318
  async def fast_classify_text(statement: str = Form(...)):
319
  try:
320
- # Step 1: Check for rule-based keyword match
321
- rule_based_classification = check_keywords(statement)
322
- if rule_based_classification:
323
- return {"classification": rule_based_classification, "confidence": 1.0}
324
-
325
- # Step 2: Run model classification if no rule-based match
326
  loop = asyncio.get_running_loop()
327
  result = await loop.run_in_executor(
328
- executor,
329
  lambda: nlp_sequence_classification(statement, labels, multi_label=False)
330
  )
331
 
 
332
  best_label = result["labels"][0]
333
  best_score = result["scores"][0]
334
 
335
  return {"classification": best_label, "confidence": best_score}
336
  except asyncio.TimeoutError:
 
337
  return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
 
 
 
338
  except Exception as e:
339
- return JSONResponse(content=f"Error in classification: {str(e)}", status_code=500)
 
340
 
341
  # Set up CORS middleware
342
  origins = ["*"] # or specify your list of allowed origins
 
267
  except Exception as e:
268
  raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
269
 
270
+ # Define the ThreadPoolExecutor globally to manage asynchronous execution
271
+ executor = ThreadPoolExecutor(max_workers=10)
272
+
273
+ # Predefined classifications
274
  labels = [
275
  "All Pricing copy quote requested",
276
  "Change to quote",
 
288
  "Status Check (possibly)"
289
  ]
290
 
291
+ @app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  async def fast_classify_text(statement: str = Form(...)):
293
  try:
294
+ # Use run_in_executor to handle the synchronous model call asynchronously
 
 
 
 
 
295
  loop = asyncio.get_running_loop()
296
  result = await loop.run_in_executor(
297
+ executor,
298
  lambda: nlp_sequence_classification(statement, labels, multi_label=False)
299
  )
300
 
301
+ # Extract the best label and score
302
  best_label = result["labels"][0]
303
  best_score = result["scores"][0]
304
 
305
  return {"classification": best_label, "confidence": best_score}
306
  except asyncio.TimeoutError:
307
+ # Handle timeout
308
  return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
309
+ except HTTPException as http_exc:
310
+ # Handle HTTP errors
311
+ return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
312
  except Exception as e:
313
+ # Handle general errors
314
+ return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
315
 
316
  # Set up CORS middleware
317
  origins = ["*"] # or specify your list of allowed origins