Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -35,7 +35,7 @@ nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
|
|
35 |
nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
|
36 |
nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
|
37 |
nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
|
38 |
-
|
39 |
|
40 |
description = """
|
41 |
## Image-based Document QA
|
@@ -264,25 +264,37 @@ async def test_transcription(file: UploadFile = File(...)):
|
|
264 |
except Exception as e:
|
265 |
raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
|
266 |
|
267 |
-
|
268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
try:
|
270 |
-
#
|
271 |
-
|
272 |
-
"Please classify the statement in one of the following classifications: "
|
273 |
-
"Negative, Neutral, Positive"
|
274 |
-
f"Statement: {statement}"
|
275 |
-
)
|
276 |
|
277 |
-
#
|
278 |
-
|
|
|
279 |
|
280 |
-
|
281 |
-
classification = result[0]['generated_text'].strip()
|
282 |
-
|
283 |
-
return {"classification": classification}
|
284 |
except Exception as e:
|
285 |
-
return JSONResponse(content=f"Error in
|
286 |
|
287 |
# Set up CORS middleware
|
288 |
origins = ["*"] # or specify your list of allowed origins
|
|
|
35 |
nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
|
36 |
nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
|
37 |
nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
|
38 |
+
nlp_sequence_classification = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
39 |
|
40 |
description = """
|
41 |
## Image-based Document QA
|
|
|
264 |
except Exception as e:
|
265 |
raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
|
266 |
|
267 |
+
# Predefined classifications
|
268 |
+
labels = [
|
269 |
+
"All Pricing copy quote requested",
|
270 |
+
"Change to quote",
|
271 |
+
"Change to quote & Status Check",
|
272 |
+
"Change to quote (Items missed?)",
|
273 |
+
"Confirmation",
|
274 |
+
"Copy quote requested",
|
275 |
+
"Cost copy quote requested",
|
276 |
+
"MRSP copy quote requested",
|
277 |
+
"MSRP & All Pricing copy quote requested",
|
278 |
+
"MSRP & Cost copy quote requested",
|
279 |
+
"No narrative in email",
|
280 |
+
"Notes not clear",
|
281 |
+
"Retail copy quote requested",
|
282 |
+
"Status Check (possibly)"
|
283 |
+
]
|
284 |
+
|
285 |
+
@app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
|
286 |
+
async def fast_classify_text(statement: str = Form(...)):
|
287 |
try:
|
288 |
+
# Use zero-shot classification to classify statement into one of the provided labels
|
289 |
+
result = nlp_sequence_classification(statement, labels, multi_label=False)
|
|
|
|
|
|
|
|
|
290 |
|
291 |
+
# Extract the best label and score
|
292 |
+
best_label = result["labels"][0]
|
293 |
+
best_score = result["scores"][0]
|
294 |
|
295 |
+
return {"classification": best_label, "confidence": best_score}
|
|
|
|
|
|
|
296 |
except Exception as e:
|
297 |
+
return JSONResponse(content=f"Error in classification: {str(e)}", status_code=500)
|
298 |
|
299 |
# Set up CORS middleware
|
300 |
origins = ["*"] # or specify your list of allowed origins
|