Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -19,6 +19,8 @@ import logging
|
|
19 |
import asyncio
|
20 |
from concurrent.futures import ThreadPoolExecutor
|
21 |
import re
|
|
|
|
|
22 |
|
23 |
app = FastAPI()
|
24 |
|
@@ -368,91 +370,86 @@ async def fast_classify_text(statement: str = Form(...)):
|
|
368 |
# Handle general errors
|
369 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
370 |
|
371 |
-
# Labels for main
|
372 |
-
|
373 |
"Change to quote",
|
374 |
"Copy quote requested",
|
375 |
"Expired Quote",
|
376 |
"Notes not clear"
|
377 |
]
|
378 |
|
379 |
-
|
380 |
-
|
381 |
-
"
|
382 |
-
"
|
383 |
-
"
|
384 |
-
|
385 |
-
}
|
386 |
-
|
387 |
-
# Function to detect if input is blank or vague
|
388 |
-
def is_blank_or_vague(text):
|
389 |
-
# Checks for empty or only contains general filler words (adjust as needed)
|
390 |
-
return not text.strip() or re.match(r'^\s*(please|send|quote|request|thank you|thanks)\s*$', text, re.IGNORECASE)
|
391 |
-
|
392 |
-
# Function to identify sub-classifications based on keywords
|
393 |
-
def get_sub_classification(text):
|
394 |
-
sub_labels = []
|
395 |
-
for sub_class, keywords in keyword_map.items():
|
396 |
-
if any(keyword.lower() in text.lower() for keyword in keywords):
|
397 |
-
sub_labels.append(sub_class)
|
398 |
-
return sub_labels if sub_labels else ["Uncategorized"]
|
399 |
-
|
400 |
-
@app.post("/classify_text/")
|
401 |
-
async def classify_text(statement: str = Form(...)):
|
402 |
-
try:
|
403 |
-
# Handle blank or vague text as "Notes not clear"
|
404 |
-
if is_blank_or_vague(statement):
|
405 |
-
return {
|
406 |
-
"main_classification": {
|
407 |
-
"label": "Notes not clear",
|
408 |
-
"confidence": 1.0,
|
409 |
-
"scores": {"Notes not clear": 1.0}
|
410 |
-
},
|
411 |
-
"sub_classification": {
|
412 |
-
"labels": ["Uncategorized"],
|
413 |
-
"scores": {"Uncategorized": 1.0}
|
414 |
-
}
|
415 |
-
}
|
416 |
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
|
|
426 |
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
best_main_score = main_class_result["scores"][0]
|
431 |
|
432 |
-
|
433 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
434 |
|
435 |
-
|
436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
437 |
|
438 |
-
#
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
|
|
450 |
|
451 |
except asyncio.TimeoutError:
|
452 |
-
|
|
|
453 |
except HTTPException as http_exc:
|
|
|
454 |
return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
|
455 |
except Exception as e:
|
|
|
456 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
457 |
|
458 |
# Set up CORS middleware
|
|
|
19 |
import asyncio
|
20 |
from concurrent.futures import ThreadPoolExecutor
|
21 |
import re
|
22 |
+
from pydantic import BaseModel
|
23 |
+
from typing import List, Dict, Any
|
24 |
|
25 |
app = FastAPI()
|
26 |
|
|
|
370 |
# Handle general errors
|
371 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
372 |
|
373 |
+
# Labels for main and sub classifications
|
374 |
+
main_labels = [
|
375 |
"Change to quote",
|
376 |
"Copy quote requested",
|
377 |
"Expired Quote",
|
378 |
"Notes not clear"
|
379 |
]
|
380 |
|
381 |
+
sub_labels = [
|
382 |
+
"MRSP",
|
383 |
+
"Direct",
|
384 |
+
"All",
|
385 |
+
"MRSP & All"
|
386 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
|
388 |
+
# Define a model for the response
|
389 |
+
class ClassificationResponse(BaseModel):
|
390 |
+
classification: str
|
391 |
+
sub_classification: str
|
392 |
+
confidence: float
|
393 |
+
scores: Dict[str, float]
|
394 |
+
|
395 |
+
# Keyword dictionaries for overriding classifications
|
396 |
+
change_to_quote_keywords = ["ATP", "Add", "Revised", "Per", "Remove", "Advise"]
|
397 |
+
copy_quote_requested_keywords = ["MSRP", "Quote", "Send", "Copy", "All pricing", "Retail"]
|
398 |
|
399 |
+
# Helper function to check for keywords in a case-insensitive way
|
400 |
+
def check_keywords(statement: str, keywords: List[str]) -> bool:
|
401 |
+
return any(re.search(rf"\b{keyword}\b", statement, re.IGNORECASE) for keyword in keywords)
|
|
|
402 |
|
403 |
+
@app.post("/classify_with_subcategory/", response_model=ClassificationResponse, description="Classify text into main categories with subcategories.")
|
404 |
+
async def classify_with_subcategory(statement: str = Form(...)) -> ClassificationResponse:
|
405 |
+
try:
|
406 |
+
# Check for keyword-based classification
|
407 |
+
if check_keywords(statement, change_to_quote_keywords):
|
408 |
+
main_best_label = "Change to quote"
|
409 |
+
main_best_score = 1.0 # High confidence since it's a direct match
|
410 |
+
elif check_keywords(statement, copy_quote_requested_keywords):
|
411 |
+
main_best_label = "Copy quote requested"
|
412 |
+
main_best_score = 1.0
|
413 |
+
else:
|
414 |
+
# If no keywords matched, perform the main classification using the model
|
415 |
+
loop = asyncio.get_running_loop()
|
416 |
+
main_classification_result = await loop.run_in_executor(
|
417 |
+
None,
|
418 |
+
lambda: nlp_sequence_classification(statement, main_labels, multi_label=False)
|
419 |
+
)
|
420 |
|
421 |
+
# Extract the best main classification label and confidence score
|
422 |
+
main_best_label = main_classification_result["labels"][0]
|
423 |
+
main_best_score = main_classification_result["scores"][0]
|
424 |
+
|
425 |
+
# Perform sub-classification if main classification was successful
|
426 |
+
sub_classification_result = await loop.run_in_executor(
|
427 |
+
None,
|
428 |
+
lambda: nlp_sequence_classification(statement, sub_labels, multi_label=True)
|
429 |
+
)
|
430 |
|
431 |
+
# Extract all sub classification scores
|
432 |
+
sub_scores = dict(zip(sub_classification_result["labels"], sub_classification_result["scores"]))
|
433 |
+
|
434 |
+
# Determine the best sub classification label
|
435 |
+
best_sub_label = sub_classification_result["labels"][0] if sub_classification_result["labels"] else "None"
|
436 |
+
best_sub_score = sub_classification_result["scores"][0] if sub_classification_result["scores"] else 0.0
|
437 |
+
|
438 |
+
return ClassificationResponse(
|
439 |
+
classification=main_best_label,
|
440 |
+
sub_classification=best_sub_label,
|
441 |
+
confidence=main_best_score,
|
442 |
+
scores={"main": main_best_score, **sub_scores}
|
443 |
+
)
|
444 |
|
445 |
except asyncio.TimeoutError:
|
446 |
+
# Handle timeout errors
|
447 |
+
return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
|
448 |
except HTTPException as http_exc:
|
449 |
+
# Handle HTTP errors
|
450 |
return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
|
451 |
except Exception as e:
|
452 |
+
# Handle any other errors
|
453 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
454 |
|
455 |
# Set up CORS middleware
|