MJobe commited on
Commit
1a6d882
1 Parent(s): 3d61dca

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +24 -21
main.py CHANGED
@@ -370,7 +370,7 @@ async def fast_classify_text(statement: str = Form(...)):
370
  # Handle general errors
371
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
372
 
373
- # Labels for main and sub classifications
374
  main_labels = [
375
  "Change to quote",
376
  "Copy quote requested",
@@ -378,13 +378,6 @@ main_labels = [
378
  "Notes not clear"
379
  ]
380
 
381
- sub_labels = [
382
- "MRSP",
383
- "Direct",
384
- "All",
385
- "MRSP & All"
386
- ]
387
-
388
  # Define a model for the response
389
  class ClassificationResponse(BaseModel):
390
  classification: str
@@ -395,11 +388,24 @@ class ClassificationResponse(BaseModel):
395
  # Keyword dictionaries for overriding classifications
396
  change_to_quote_keywords = ["ATP", "Add", "Revised", "Per", "Remove", "Advise"]
397
  copy_quote_requested_keywords = ["MSRP", "Quote", "Send", "Copy", "All pricing", "Retail"]
 
 
 
 
 
 
398
 
399
  # Helper function to check for keywords in a case-insensitive way
400
  def check_keywords(statement: str, keywords: List[str]) -> bool:
401
  return any(re.search(rf"\b{keyword}\b", statement, re.IGNORECASE) for keyword in keywords)
402
 
 
 
 
 
 
 
 
403
  @app.post("/classify_with_subcategory/", response_model=ClassificationResponse, description="Classify text into main categories with subcategories.")
404
  async def classify_with_subcategory(statement: str = Form(...)) -> ClassificationResponse:
405
  try:
@@ -422,25 +428,22 @@ async def classify_with_subcategory(statement: str = Form(...)) -> Classificatio
422
  main_best_label = main_classification_result["labels"][0]
423
  main_best_score = main_classification_result["scores"][0]
424
 
425
- # Perform sub-classification regardless of how main classification was determined
426
- loop = asyncio.get_running_loop()
427
- sub_classification_result = await loop.run_in_executor(
428
- None,
429
- lambda: nlp_sequence_classification(statement, sub_labels, multi_label=True)
430
- )
431
 
432
- # Extract all sub classification scores
433
- sub_scores = dict(zip(sub_classification_result["labels"], sub_classification_result["scores"]))
434
-
435
- # Determine the best sub classification label
436
- best_sub_label = sub_classification_result["labels"][0] if sub_classification_result["labels"] else "None"
437
- best_sub_score = sub_classification_result["scores"][0] if sub_classification_result["scores"] else 0.0
438
 
439
  return ClassificationResponse(
440
  classification=main_best_label,
441
  sub_classification=best_sub_label,
442
  confidence=main_best_score,
443
- scores={"main": main_best_score, **sub_scores}
444
  )
445
 
446
  except asyncio.TimeoutError:
 
370
  # Handle general errors
371
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
372
 
373
+ # Labels for main classifications
374
  main_labels = [
375
  "Change to quote",
376
  "Copy quote requested",
 
378
  "Notes not clear"
379
  ]
380
 
 
 
 
 
 
 
 
381
  # Define a model for the response
382
  class ClassificationResponse(BaseModel):
383
  classification: str
 
388
  # Keyword dictionaries for overriding classifications
389
  change_to_quote_keywords = ["ATP", "Add", "Revised", "Per", "Remove", "Advise"]
390
  copy_quote_requested_keywords = ["MSRP", "Quote", "Send", "Copy", "All pricing", "Retail"]
391
+ sub_classification_keywords = {
392
+ "MRSP": ["MSRP"],
393
+ "Direct": ["Direct"],
394
+ "All": ["All pricing"],
395
+ "MRSP & All": ["MSRP", "All pricing"]
396
+ }
397
 
398
  # Helper function to check for keywords in a case-insensitive way
399
  def check_keywords(statement: str, keywords: List[str]) -> bool:
400
  return any(re.search(rf"\b{keyword}\b", statement, re.IGNORECASE) for keyword in keywords)
401
 
402
+ # Function to determine sub-classification based on keywords
403
+ def get_sub_classification(statement: str) -> str:
404
+ for sub_label, keywords in sub_classification_keywords.items():
405
+ if all(check_keywords(statement, [keyword]) for keyword in keywords):
406
+ return sub_label
407
+ return "None" # Default to "None" if no keywords match
408
+
409
  @app.post("/classify_with_subcategory/", response_model=ClassificationResponse, description="Classify text into main categories with subcategories.")
410
  async def classify_with_subcategory(statement: str = Form(...)) -> ClassificationResponse:
411
  try:
 
428
  main_best_label = main_classification_result["labels"][0]
429
  main_best_score = main_classification_result["scores"][0]
430
 
431
+ # Perform sub-classification only if the main classification is "Copy quote requested"
432
+ if main_best_label == "Copy quote requested":
433
+ best_sub_label = get_sub_classification(statement)
434
+ else:
435
+ best_sub_label = "None"
 
436
 
437
+ # Gather the scores for response
438
+ scores = {"main": main_best_score}
439
+ if best_sub_label != "None":
440
+ scores[best_sub_label] = 1.0 # Assign full confidence to sub-classification matches
 
 
441
 
442
  return ClassificationResponse(
443
  classification=main_best_label,
444
  sub_classification=best_sub_label,
445
  confidence=main_best_score,
446
+ scores=scores
447
  )
448
 
449
  except asyncio.TimeoutError: