MJobe committed on
Commit
f0e6e2e
1 Parent(s): fd35e4e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +36 -37
main.py CHANGED
@@ -406,54 +406,53 @@ def get_sub_classification(statement: str) -> str:
406
  return sub_label
407
  return "None" # Default to "None" if no keywords match
408
 
409
- @app.post("/classify_with_subcategory/", response_model=ClassificationResponse, description="Classify text into main categories with subcategories.")
410
- async def classify_with_subcategory(statement: str = Form(...)) -> ClassificationResponse:
411
  try:
412
- # Keyword-based classification override
413
- if check_keywords(statement, change_to_quote_keywords):
414
- main_best_label = "Change to quote"
415
- main_best_score = 1.0 # High confidence since it's a direct match
416
- elif check_keywords(statement, copy_quote_requested_keywords):
417
- main_best_label = "Copy quote requested"
418
- main_best_score = 1.0
 
 
 
 
 
 
 
 
 
 
 
 
419
  else:
420
- # If no keywords matched, perform the main classification using the model
421
  loop = asyncio.get_running_loop()
422
- main_classification_result = await loop.run_in_executor(
423
- None,
424
- lambda: nlp_sequence_classification(statement, main_labels, multi_label=False)
425
  )
426
 
427
- # Extract the best main classification label and confidence score
428
- main_best_label = main_classification_result["labels"][0]
429
- main_best_score = main_classification_result["scores"][0]
430
-
431
- # Perform sub-classification only if the main classification is "Copy quote requested"
432
- if main_best_label == "Copy quote requested":
433
- best_sub_label = get_sub_classification(statement)
434
- else:
435
- best_sub_label = "None"
436
-
437
- # Gather the scores for response
438
- scores = {"main": main_best_score}
439
- if best_sub_label != "None":
440
- scores[best_sub_label] = 1.0 # Assign full confidence to sub-classification matches
441
-
442
- return ClassificationResponse(
443
- classification=main_best_label,
444
- sub_classification=best_sub_label,
445
- confidence=main_best_score,
446
- scores=scores
447
- )
448
 
449
  except asyncio.TimeoutError:
450
- # Handle timeout errors
451
  return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
452
  except HTTPException as http_exc:
453
- # Handle HTTP errors
454
  return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
455
  except Exception as e:
456
- # Handle any other errors
457
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
458
 
459
  # Set up CORS middleware
 
406
  return sub_label
407
  return "None" # Default to "None" if no keywords match
408
 
409
+ @app.post("/classify_with_subcategory/", description="Quickly classify text into predefined categories.")
410
+ async def fast_classify_text(statement: str = Form(...)):
411
  try:
412
+ # Check for empty or "N/A" statements
413
+ if not statement or statement.strip().lower() == "n/a":
414
+ return {"classification": "Note not clear", "confidence": 1.0, "sub_classification": "None", "scores": {}}
415
+
416
+ # Determine main classification based on keywords
417
+ if any(keyword.lower() in statement.lower() for keyword in change_to_quote_keywords):
418
+ main_classification = "Change to Quote"
419
+ sub_classification = "None"
420
+ elif any(keyword.lower() in statement.lower() for keyword in copy_quote_requested_keywords):
421
+ main_classification = "Copy Quote Requested"
422
+ # Perform sub-classification for Copy Quote Requested
423
+ if "msrp" in statement.lower():
424
+ sub_classification = "MRSP"
425
+ elif "all pricing" in statement.lower():
426
+ sub_classification = "All"
427
+ elif "direct" in statement.lower():
428
+ sub_classification = "Direct"
429
+ else:
430
+ sub_classification = "None" # No sub-classification when keywords don’t match
431
  else:
432
+ # Call the Hugging Face model for cases where keywords don’t match
433
  loop = asyncio.get_running_loop()
434
+ result = await loop.run_in_executor(
435
+ executor,
436
+ lambda: nlp_sequence_classification(statement, labels, multi_label=False)
437
  )
438
 
439
+ main_classification = result["labels"][0]
440
+ main_confidence = result["scores"][0]
441
+ scores = dict(zip(result["labels"], result["scores"]))
442
+ sub_classification = "None" # Set sub-classification to None for non-matching keywords
443
+
444
+ return {
445
+ "classification": main_classification,
446
+ "confidence": main_confidence,
447
+ "sub_classification": sub_classification,
448
+ "scores": scores
449
+ }
 
 
 
 
 
 
 
 
 
 
450
 
451
  except asyncio.TimeoutError:
 
452
  return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
453
  except HTTPException as http_exc:
 
454
  return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
455
  except Exception as e:
 
456
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
457
 
458
  # Set up CORS middleware