Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -370,7 +370,7 @@ async def fast_classify_text(statement: str = Form(...)):
|
|
370 |
# Handle general errors
|
371 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
372 |
|
373 |
-
# Labels for main
|
374 |
main_labels = [
|
375 |
"Change to quote",
|
376 |
"Copy quote requested",
|
@@ -378,13 +378,6 @@ main_labels = [
|
|
378 |
"Notes not clear"
|
379 |
]
|
380 |
|
381 |
-
# Sub-classification label strings.
sub_labels = ["MRSP", "Direct", "All", "MRSP & All"]
|
387 |
-
|
388 |
# Define a model for the response
|
389 |
class ClassificationResponse(BaseModel):
|
390 |
classification: str
|
@@ -395,11 +388,24 @@ class ClassificationResponse(BaseModel):
|
|
395 |
# Keyword lists for overriding classifications, one list per main label.
change_to_quote_keywords = [
    "ATP",
    "Add",
    "Revised",
    "Per",
    "Remove",
    "Advise",
]
copy_quote_requested_keywords = [
    "MSRP",
    "Quote",
    "Send",
    "Copy",
    "All pricing",
    "Retail",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
398 |
|
399 |
# Helper function to check for keywords in a case-insensitive way
|
400 |
def check_keywords(statement: str, keywords: List[str]) -> bool:
    """Return True if any keyword occurs in *statement* as a whole word.

    Matching is case-insensitive and anchored on word boundaries, so "Add"
    does not match inside "Address". Each keyword is escaped before being
    placed in the pattern, so it is always matched as literal text and
    characters such as ``.`` or ``+`` cannot act as regex operators.

    Args:
        statement: Text to search.
        keywords: Literal words or phrases to look for.

    Returns:
        True when at least one keyword is found; False otherwise, including
        when *keywords* is empty.
    """
    return any(
        re.search(rf"\b{re.escape(keyword)}\b", statement, re.IGNORECASE)
        for keyword in keywords
    )
|
402 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
403 |
@app.post("/classify_with_subcategory/", response_model=ClassificationResponse, description="Classify text into main categories with subcategories.")
|
404 |
async def classify_with_subcategory(statement: str = Form(...)) -> ClassificationResponse:
|
405 |
try:
|
@@ -422,25 +428,22 @@ async def classify_with_subcategory(statement: str = Form(...)) -> Classificatio
|
|
422 |
main_best_label = main_classification_result["labels"][0]
|
423 |
main_best_score = main_classification_result["scores"][0]
|
424 |
|
425 |
-
# Perform sub-classification
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
)
|
431 |
|
432 |
-
#
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
best_sub_label = sub_classification_result["labels"][0] if sub_classification_result["labels"] else "None"
|
437 |
-
best_sub_score = sub_classification_result["scores"][0] if sub_classification_result["scores"] else 0.0
|
438 |
|
439 |
return ClassificationResponse(
|
440 |
classification=main_best_label,
|
441 |
sub_classification=best_sub_label,
|
442 |
confidence=main_best_score,
|
443 |
-
scores=
|
444 |
)
|
445 |
|
446 |
except asyncio.TimeoutError:
|
|
|
370 |
# Handle general errors
|
371 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
372 |
|
373 |
+
# Labels for main classifications
|
374 |
main_labels = [
|
375 |
"Change to quote",
|
376 |
"Copy quote requested",
|
|
|
378 |
"Notes not clear"
|
379 |
]
|
380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
# Define a model for the response
|
382 |
class ClassificationResponse(BaseModel):
|
383 |
classification: str
|
|
|
388 |
# Keyword lists for overriding classifications, one list per main label.
change_to_quote_keywords = [
    "ATP",
    "Add",
    "Revised",
    "Per",
    "Remove",
    "Advise",
]
copy_quote_requested_keywords = [
    "MSRP",
    "Quote",
    "Send",
    "Copy",
    "All pricing",
    "Retail",
]

# Sub-classification label -> keywords that must ALL be present in the
# statement for that label to apply (see get_sub_classification).
# NOTE(review): the labels are spelled "MRSP" while the keyword is "MSRP".
# The labels are returned to API clients, so confirm with consumers before
# renaming them.
sub_classification_keywords = {
    "MRSP": ["MSRP"],
    "Direct": ["Direct"],
    "All": ["All pricing"],
    "MRSP & All": ["MSRP", "All pricing"],
}
|
397 |
|
398 |
# Helper function to check for keywords in a case-insensitive way
|
399 |
def check_keywords(statement: str, keywords: List[str]) -> bool:
    """Return True if any keyword occurs in *statement* as a whole word.

    Matching is case-insensitive and anchored on word boundaries, so "Add"
    does not match inside "Address". Each keyword is escaped before being
    placed in the pattern, so it is always matched as literal text and
    characters such as ``.`` or ``+`` cannot act as regex operators.

    Args:
        statement: Text to search.
        keywords: Literal words or phrases to look for.

    Returns:
        True when at least one keyword is found; False otherwise, including
        when *keywords* is empty.
    """
    return any(
        re.search(rf"\b{re.escape(keyword)}\b", statement, re.IGNORECASE)
        for keyword in keywords
    )
|
401 |
|
402 |
+
# Function to determine sub-classification based on keywords
|
403 |
+
def get_sub_classification(statement: str) -> str:
    """Pick the sub-classification label whose keywords all occur in *statement*.

    A label from ``sub_classification_keywords`` matches when every one of
    its keywords is found by ``check_keywords``. When several labels match,
    the one with the most keywords wins; ties keep the label that appears
    first in the dictionary. Returning the first match in dictionary order
    made a combined label such as "MRSP & All" unreachable, because any
    statement satisfying it also satisfies the single-keyword labels listed
    before it.

    Args:
        statement: Text to inspect.

    Returns:
        The matching label, or the string "None" when no label matches.
        Labels with an empty keyword list never match.
    """
    best_label = "None"  # Default to "None" if no keywords match
    best_specificity = 0
    for sub_label, keywords in sub_classification_keywords.items():
        # Only evaluate labels that would beat the current best; the strict
        # comparison keeps the earliest label among equally specific ones.
        if len(keywords) > best_specificity and all(
            check_keywords(statement, [keyword]) for keyword in keywords
        ):
            best_label = sub_label
            best_specificity = len(keywords)
    return best_label
|
408 |
+
|
409 |
@app.post("/classify_with_subcategory/", response_model=ClassificationResponse, description="Classify text into main categories with subcategories.")
|
410 |
async def classify_with_subcategory(statement: str = Form(...)) -> ClassificationResponse:
|
411 |
try:
|
|
|
428 |
main_best_label = main_classification_result["labels"][0]
|
429 |
main_best_score = main_classification_result["scores"][0]
|
430 |
|
431 |
+
# Perform sub-classification only if the main classification is "Copy quote requested"
|
432 |
+
if main_best_label == "Copy quote requested":
|
433 |
+
best_sub_label = get_sub_classification(statement)
|
434 |
+
else:
|
435 |
+
best_sub_label = "None"
|
|
|
436 |
|
437 |
+
# Gather the scores for response
|
438 |
+
scores = {"main": main_best_score}
|
439 |
+
if best_sub_label != "None":
|
440 |
+
scores[best_sub_label] = 1.0 # Assign full confidence to sub-classification matches
|
|
|
|
|
441 |
|
442 |
return ClassificationResponse(
|
443 |
classification=main_best_label,
|
444 |
sub_classification=best_sub_label,
|
445 |
confidence=main_best_score,
|
446 |
+
scores=scores
|
447 |
)
|
448 |
|
449 |
except asyncio.TimeoutError:
|