MJobe commited on
Commit
ab46adf
1 Parent(s): 31d9e37

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +66 -69
main.py CHANGED
@@ -19,6 +19,8 @@ import logging
19
  import asyncio
20
  from concurrent.futures import ThreadPoolExecutor
21
  import re
 
 
22
 
23
  app = FastAPI()
24
 
@@ -368,91 +370,86 @@ async def fast_classify_text(statement: str = Form(...)):
368
  # Handle general errors
369
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
370
 
371
- # Labels for main classification
372
- labels = [
373
  "Change to quote",
374
  "Copy quote requested",
375
  "Expired Quote",
376
  "Notes not clear"
377
  ]
378
 
379
- # Keywords for sub-classifications
380
- keyword_map = {
381
- "MRSP": ["MSRP", "MRSP copy quote", "msrp only"],
382
- "Direct": ["Direct quote", "send directly"],
383
- "All": ["All Pricing", "all pricing"],
384
- "MRSP & All": ["MSRP & All Pricing", "msrp only with all pricing"]
385
- }
386
-
387
- # Function to detect if input is blank or vague
388
- def is_blank_or_vague(text):
389
- # Checks for empty or only contains general filler words (adjust as needed)
390
- return not text.strip() or re.match(r'^\s*(please|send|quote|request|thank you|thanks)\s*$', text, re.IGNORECASE)
391
-
392
- # Function to identify sub-classifications based on keywords
393
- def get_sub_classification(text):
394
- sub_labels = []
395
- for sub_class, keywords in keyword_map.items():
396
- if any(keyword.lower() in text.lower() for keyword in keywords):
397
- sub_labels.append(sub_class)
398
- return sub_labels if sub_labels else ["Uncategorized"]
399
-
400
- @app.post("/classify_text/")
401
- async def classify_text(statement: str = Form(...)):
402
- try:
403
- # Handle blank or vague text as "Notes not clear"
404
- if is_blank_or_vague(statement):
405
- return {
406
- "main_classification": {
407
- "label": "Notes not clear",
408
- "confidence": 1.0,
409
- "scores": {"Notes not clear": 1.0}
410
- },
411
- "sub_classification": {
412
- "labels": ["Uncategorized"],
413
- "scores": {"Uncategorized": 1.0}
414
- }
415
- }
416
 
417
- # Run main classification in executor for async handling
418
- loop = asyncio.get_running_loop()
419
- main_classification_task = loop.run_in_executor(
420
- None,
421
- lambda: nlp_main_classification(statement, labels)
422
- )
423
-
424
- # Await result
425
- main_class_result = await main_classification_task
 
426
 
427
- # Extract main classification label and scores
428
- main_class_scores = {label: score for label, score in zip(main_class_result["labels"], main_class_result["scores"])}
429
- best_main_classification = main_class_result["labels"][0]
430
- best_main_score = main_class_result["scores"][0]
431
 
432
- # Detect sub-classifications using keywords
433
- sub_classification = get_sub_classification(statement)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
- # Assign default high confidence for keyword-based sub-classification
436
- sub_class_scores = {sub: 1.0 for sub in sub_classification}
 
 
 
 
 
 
 
437
 
438
- # Return results
439
- return {
440
- "main_classification": {
441
- "label": best_main_classification,
442
- "confidence": best_main_score,
443
- "scores": main_class_scores
444
- },
445
- "sub_classification": {
446
- "labels": sub_classification,
447
- "scores": sub_class_scores
448
- }
449
- }
 
450
 
451
  except asyncio.TimeoutError:
452
- return JSONResponse(content="Classification timed out.", status_code=504)
 
453
  except HTTPException as http_exc:
 
454
  return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
455
  except Exception as e:
 
456
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
457
 
458
  # Set up CORS middleware
 
19
  import asyncio
20
  from concurrent.futures import ThreadPoolExecutor
21
  import re
22
+ from pydantic import BaseModel
23
+ from typing import List, Dict, Any
24
 
25
  app = FastAPI()
26
 
 
370
  # Handle general errors
371
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
372
 
373
+ # Labels for main and sub classifications
374
+ main_labels = [
375
  "Change to quote",
376
  "Copy quote requested",
377
  "Expired Quote",
378
  "Notes not clear"
379
  ]
380
 
381
+ sub_labels = [
382
+ "MRSP",
383
+ "Direct",
384
+ "All",
385
+ "MRSP & All"
386
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
+ # Define a model for the response
389
+ class ClassificationResponse(BaseModel):
390
+ classification: str
391
+ sub_classification: str
392
+ confidence: float
393
+ scores: Dict[str, float]
394
+
395
+ # Keyword dictionaries for overriding classifications
396
+ change_to_quote_keywords = ["ATP", "Add", "Revised", "Per", "Remove", "Advise"]
397
+ copy_quote_requested_keywords = ["MSRP", "Quote", "Send", "Copy", "All pricing", "Retail"]
398
 
399
+ # Helper function to check for keywords in a case-insensitive way
400
+ def check_keywords(statement: str, keywords: List[str]) -> bool:
401
+ return any(re.search(rf"\b{keyword}\b", statement, re.IGNORECASE) for keyword in keywords)
 
402
 
403
+ @app.post("/classify_with_subcategory/", response_model=ClassificationResponse, description="Classify text into main categories with subcategories.")
404
+ async def classify_with_subcategory(statement: str = Form(...)) -> ClassificationResponse:
405
+ try:
406
+ # Check for keyword-based classification
407
+ if check_keywords(statement, change_to_quote_keywords):
408
+ main_best_label = "Change to quote"
409
+ main_best_score = 1.0 # High confidence since it's a direct match
410
+ elif check_keywords(statement, copy_quote_requested_keywords):
411
+ main_best_label = "Copy quote requested"
412
+ main_best_score = 1.0
413
+ else:
414
+ # If no keywords matched, perform the main classification using the model
415
+ loop = asyncio.get_running_loop()
416
+ main_classification_result = await loop.run_in_executor(
417
+ None,
418
+ lambda: nlp_sequence_classification(statement, main_labels, multi_label=False)
419
+ )
420
 
421
+ # Extract the best main classification label and confidence score
422
+ main_best_label = main_classification_result["labels"][0]
423
+ main_best_score = main_classification_result["scores"][0]
424
+
425
+ # Perform sub-classification if main classification was successful
426
+ sub_classification_result = await loop.run_in_executor(
427
+ None,
428
+ lambda: nlp_sequence_classification(statement, sub_labels, multi_label=True)
429
+ )
430
 
431
+ # Extract all sub classification scores
432
+ sub_scores = dict(zip(sub_classification_result["labels"], sub_classification_result["scores"]))
433
+
434
+ # Determine the best sub classification label
435
+ best_sub_label = sub_classification_result["labels"][0] if sub_classification_result["labels"] else "None"
436
+ best_sub_score = sub_classification_result["scores"][0] if sub_classification_result["scores"] else 0.0
437
+
438
+ return ClassificationResponse(
439
+ classification=main_best_label,
440
+ sub_classification=best_sub_label,
441
+ confidence=main_best_score,
442
+ scores={"main": main_best_score, **sub_scores}
443
+ )
444
 
445
  except asyncio.TimeoutError:
446
+ # Handle timeout errors
447
+ return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
448
  except HTTPException as http_exc:
449
+ # Handle HTTP errors
450
  return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
451
  except Exception as e:
452
+ # Handle any other errors
453
  return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
454
 
455
  # Set up CORS middleware