Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -39,6 +39,8 @@ nlp_classification = pipeline("text-classification", model="distilbert/distilber
|
|
39 |
nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
|
40 |
nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
|
41 |
nlp_sequence_classification = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
|
|
|
|
42 |
description = """
|
43 |
## Image-based Document QA
|
44 |
This API performs document question answering using a LayoutLMv2-based model.
|
@@ -365,6 +367,93 @@ async def fast_classify_text(statement: str = Form(...)):
|
|
365 |
except Exception as e:
|
366 |
# Handle general errors
|
367 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
|
369 |
# Set up CORS middleware
|
370 |
origins = ["*"] # or specify your list of allowed origins
|
|
|
39 |
nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
|
40 |
nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
|
41 |
nlp_sequence_classification = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
42 |
+
nlp_main_classification = pipeline("zero-shot-classification", model="roberta-large-mnli")
|
43 |
+
|
44 |
description = """
|
45 |
## Image-based Document QA
|
46 |
This API performs document question answering using a LayoutLMv2-based model.
|
|
|
367 |
except Exception as e:
|
368 |
# Handle general errors
|
369 |
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
370 |
+
|
371 |
+
# Labels for main classification
|
372 |
+
labels = [
|
373 |
+
"Change to quote",
|
374 |
+
"Copy quote requested",
|
375 |
+
"Expired Quote",
|
376 |
+
"Notes not clear"
|
377 |
+
]
|
378 |
+
|
379 |
+
# Keywords for sub-classifications
|
380 |
+
keyword_map = {
|
381 |
+
"MRSP": ["MSRP", "MRSP copy quote", "msrp only"],
|
382 |
+
"Direct": ["Direct quote", "send directly"],
|
383 |
+
"All": ["All Pricing", "all pricing"],
|
384 |
+
"MRSP & All": ["MSRP & All Pricing", "msrp only with all pricing"]
|
385 |
+
}
|
386 |
+
|
387 |
+
# Function to detect if input is blank or vague
|
388 |
+
def is_blank_or_vague(text):
|
389 |
+
# Checks for empty or only contains general filler words (adjust as needed)
|
390 |
+
return not text.strip() or re.match(r'^\s*(please|send|quote|request|thank you|thanks)\s*$', text, re.IGNORECASE)
|
391 |
+
|
392 |
+
# Function to identify sub-classifications based on keywords
|
393 |
+
def get_sub_classification(text):
|
394 |
+
sub_labels = []
|
395 |
+
for sub_class, keywords in keyword_map.items():
|
396 |
+
if any(keyword.lower() in text.lower() for keyword in keywords):
|
397 |
+
sub_labels.append(sub_class)
|
398 |
+
return sub_labels if sub_labels else ["Uncategorized"]
|
399 |
+
|
400 |
+
@app.post("/classify_text/")
|
401 |
+
async def classify_text(statement: str = Form(...)):
|
402 |
+
try:
|
403 |
+
# Handle blank or vague text as "Notes not clear"
|
404 |
+
if is_blank_or_vague(statement):
|
405 |
+
return {
|
406 |
+
"main_classification": {
|
407 |
+
"label": "Notes not clear",
|
408 |
+
"confidence": 1.0,
|
409 |
+
"scores": {"Notes not clear": 1.0}
|
410 |
+
},
|
411 |
+
"sub_classification": {
|
412 |
+
"labels": ["Uncategorized"],
|
413 |
+
"scores": {"Uncategorized": 1.0}
|
414 |
+
}
|
415 |
+
}
|
416 |
+
|
417 |
+
# Run main classification in executor for async handling
|
418 |
+
loop = asyncio.get_running_loop()
|
419 |
+
main_classification_task = loop.run_in_executor(
|
420 |
+
None,
|
421 |
+
lambda: nlp_main_classification(statement, labels)
|
422 |
+
)
|
423 |
+
|
424 |
+
# Await result
|
425 |
+
main_class_result = await main_classification_task
|
426 |
+
|
427 |
+
# Extract main classification label and scores
|
428 |
+
main_class_scores = {label: score for label, score in zip(main_class_result["labels"], main_class_result["scores"])}
|
429 |
+
best_main_classification = main_class_result["labels"][0]
|
430 |
+
best_main_score = main_class_result["scores"][0]
|
431 |
+
|
432 |
+
# Detect sub-classifications using keywords
|
433 |
+
sub_classification = get_sub_classification(statement)
|
434 |
+
|
435 |
+
# Assign default high confidence for keyword-based sub-classification
|
436 |
+
sub_class_scores = {sub: 1.0 for sub in sub_classification}
|
437 |
+
|
438 |
+
# Return results
|
439 |
+
return {
|
440 |
+
"main_classification": {
|
441 |
+
"label": best_main_classification,
|
442 |
+
"confidence": best_main_score,
|
443 |
+
"scores": main_class_scores
|
444 |
+
},
|
445 |
+
"sub_classification": {
|
446 |
+
"labels": sub_classification,
|
447 |
+
"scores": sub_class_scores
|
448 |
+
}
|
449 |
+
}
|
450 |
+
|
451 |
+
except asyncio.TimeoutError:
|
452 |
+
return JSONResponse(content="Classification timed out.", status_code=504)
|
453 |
+
except HTTPException as http_exc:
|
454 |
+
return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
|
455 |
+
except Exception as e:
|
456 |
+
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
457 |
|
458 |
# Set up CORS middleware
|
459 |
origins = ["*"] # or specify your list of allowed origins
|