Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -406,54 +406,53 @@ def get_sub_classification(statement: str) -> str:
|
|
406 |
return sub_label
|
407 |
return "None" # Default to "None" if no keywords match
|
408 |
|
409 |
-
@app.post("/classify_with_subcategory/",
|
410 |
-
async def
|
411 |
try:
|
412 |
-
#
|
413 |
-
if
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
else:
|
420 |
-
#
|
421 |
loop = asyncio.get_running_loop()
|
422 |
-
|
423 |
-
|
424 |
-
lambda: nlp_sequence_classification(statement,
|
425 |
)
|
426 |
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
scores = {"main": main_best_score}
|
439 |
-
if best_sub_label != "None":
|
440 |
-
scores[best_sub_label] = 1.0 # Assign full confidence to sub-classification matches
|
441 |
-
|
442 |
-
return ClassificationResponse(
|
443 |
-
classification=main_best_label,
|
444 |
-
sub_classification=best_sub_label,
|
445 |
-
confidence=main_best_score,
|
446 |
-
scores=scores
|
447 |
-
)
|
448 |
|
449 |
except asyncio.TimeoutError:
|
450 |
-
# Handle timeout errors
|
451 |
return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
|
452 |
except HTTPException as http_exc:
|
453 |
-
# Handle HTTP errors
|
454 |
return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
|
455 |
except Exception as e:
|
456 |
-
# Handle any other errors
|
457 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
458 |
|
459 |
# Set up CORS middleware
|
|
|
406 |
return sub_label
|
407 |
return "None" # Default to "None" if no keywords match
|
408 |
|
409 |
+
@app.post("/classify_with_subcategory/", description="Quickly classify text into predefined categories.")
|
410 |
+
async def fast_classify_text(statement: str = Form(...)):
|
411 |
try:
|
412 |
+
# Check for empty or "N/A" statements
|
413 |
+
if not statement or statement.strip().lower() == "n/a":
|
414 |
+
return {"classification": "Note not clear", "confidence": 1.0, "sub_classification": "None", "scores": {}}
|
415 |
+
|
416 |
+
# Determine main classification based on keywords
|
417 |
+
if any(keyword.lower() in statement.lower() for keyword in change_to_quote_keywords):
|
418 |
+
main_classification = "Change to Quote"
|
419 |
+
sub_classification = "None"
|
420 |
+
elif any(keyword.lower() in statement.lower() for keyword in copy_quote_requested_keywords):
|
421 |
+
main_classification = "Copy Quote Requested"
|
422 |
+
# Perform sub-classification for Copy Quote Requested
|
423 |
+
if "msrp" in statement.lower():
|
424 |
+
sub_classification = "MRSP"
|
425 |
+
elif "all pricing" in statement.lower():
|
426 |
+
sub_classification = "All"
|
427 |
+
elif "direct" in statement.lower():
|
428 |
+
sub_classification = "Direct"
|
429 |
+
else:
|
430 |
+
sub_classification = "None" # No sub-classification when keywords don’t match
|
431 |
else:
|
432 |
+
# Call the Hugging Face model for cases where keywords don’t match
|
433 |
loop = asyncio.get_running_loop()
|
434 |
+
result = await loop.run_in_executor(
|
435 |
+
executor,
|
436 |
+
lambda: nlp_sequence_classification(statement, labels, multi_label=False)
|
437 |
)
|
438 |
|
439 |
+
main_classification = result["labels"][0]
|
440 |
+
main_confidence = result["scores"][0]
|
441 |
+
scores = dict(zip(result["labels"], result["scores"]))
|
442 |
+
sub_classification = "None" # Set sub-classification to None for non-matching keywords
|
443 |
+
|
444 |
+
return {
|
445 |
+
"classification": main_classification,
|
446 |
+
"confidence": main_confidence,
|
447 |
+
"sub_classification": sub_classification,
|
448 |
+
"scores": scores
|
449 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
450 |
|
451 |
except asyncio.TimeoutError:
|
|
|
452 |
return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
|
453 |
except HTTPException as http_exc:
|
|
|
454 |
return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
|
455 |
except Exception as e:
|
|
|
456 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
457 |
|
458 |
# Set up CORS middleware
|