Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -267,7 +267,10 @@ async def test_transcription(file: UploadFile = File(...)):
|
|
267 |
except Exception as e:
|
268 |
raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
|
269 |
|
270 |
-
#
|
|
|
|
|
|
|
271 |
labels = [
|
272 |
"All Pricing copy quote requested",
|
273 |
"Change to quote",
|
@@ -285,58 +288,30 @@ labels = [
|
|
285 |
"Status Check (possibly)"
|
286 |
]
|
287 |
|
288 |
-
|
289 |
-
keyword_map = {
|
290 |
-
"All Pricing copy quote requested": ["Per ATP", "Not Want", "No Advertising"],
|
291 |
-
"Change to quote": ["MSRP"],
|
292 |
-
"Change to quote & Status Check": ["Send clean Quote"],
|
293 |
-
"Change to quote (Items missed?)": ["Per ATP"],
|
294 |
-
"Confirmation": ["N/A"],
|
295 |
-
"Copy quote requested": ["N/A"],
|
296 |
-
"Cost copy quote requested": ["N/A"],
|
297 |
-
"MRSP copy quote requested": ["Quote Requested", "Revise Quote", "Include the following items"],
|
298 |
-
"MSRP & All Pricing copy quote requested": ["Revise and resend Quote"],
|
299 |
-
"MSRP & Cost copy quote requested": ["Per RTS"],
|
300 |
-
"No narrative in email": ["Requesting"],
|
301 |
-
"Notes not clear": ["Add to the quote"],
|
302 |
-
"Retail copy quote requested": ["Action Required", "Send copy of Quote", "All Pricing"],
|
303 |
-
"Status Check (possibly)": ["Re-quote", "Current date and Pricing", "Send quote"]
|
304 |
-
}
|
305 |
-
|
306 |
-
# Thread pool executor for handling requests concurrently
|
307 |
-
executor = ThreadPoolExecutor(max_workers=10)
|
308 |
-
|
309 |
-
# Helper function to check keywords
|
310 |
-
def check_keywords(statement):
|
311 |
-
for label, keywords in keyword_map.items():
|
312 |
-
for keyword in keywords:
|
313 |
-
if re.search(rf"\b{keyword}\b", statement, re.IGNORECASE):
|
314 |
-
return label
|
315 |
-
return None
|
316 |
-
|
317 |
-
@app.post("/fast_classify/", description="Quickly classify text with a rule-based approach combined with model inference.")
|
318 |
async def fast_classify_text(statement: str = Form(...)):
|
319 |
try:
|
320 |
-
#
|
321 |
-
rule_based_classification = check_keywords(statement)
|
322 |
-
if rule_based_classification:
|
323 |
-
return {"classification": rule_based_classification, "confidence": 1.0}
|
324 |
-
|
325 |
-
# Step 2: Run model classification if no rule-based match
|
326 |
loop = asyncio.get_running_loop()
|
327 |
result = await loop.run_in_executor(
|
328 |
-
executor,
|
329 |
lambda: nlp_sequence_classification(statement, labels, multi_label=False)
|
330 |
)
|
331 |
|
|
|
332 |
best_label = result["labels"][0]
|
333 |
best_score = result["scores"][0]
|
334 |
|
335 |
return {"classification": best_label, "confidence": best_score}
|
336 |
except asyncio.TimeoutError:
|
|
|
337 |
return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
|
|
|
|
|
|
|
338 |
except Exception as e:
|
339 |
-
|
|
|
340 |
|
341 |
# Set up CORS middleware
|
342 |
origins = ["*"] # or specify your list of allowed origins
|
|
|
267 |
except Exception as e:
|
268 |
raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
|
269 |
|
270 |
+
# Module-level thread pool (max 10 workers) shared by the endpoints; synchronous model calls are submitted to it via loop.run_in_executor
|
271 |
+
executor = ThreadPoolExecutor(max_workers=10)
|
272 |
+
|
273 |
+
# Predefined classifications
|
274 |
labels = [
|
275 |
"All Pricing copy quote requested",
|
276 |
"Change to quote",
|
|
|
288 |
"Status Check (possibly)"
|
289 |
]
|
290 |
|
291 |
+
@app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
async def fast_classify_text(statement: str = Form(...)):
|
293 |
try:
|
294 |
+
# Use run_in_executor to handle the synchronous model call asynchronously
|
|
|
|
|
|
|
|
|
|
|
295 |
loop = asyncio.get_running_loop()
|
296 |
result = await loop.run_in_executor(
|
297 |
+
executor,
|
298 |
lambda: nlp_sequence_classification(statement, labels, multi_label=False)
|
299 |
)
|
300 |
|
301 |
+
# Extract the best label and score
|
302 |
best_label = result["labels"][0]
|
303 |
best_score = result["scores"][0]
|
304 |
|
305 |
return {"classification": best_label, "confidence": best_score}
|
306 |
except asyncio.TimeoutError:
|
307 |
+
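# Handle timeout. NOTE(review): the run_in_executor call above is awaited without asyncio.wait_for or any timeout, so this branch is only reached if the classifier call itself raises asyncio.TimeoutError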
|
308 |
return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
|
309 |
+
except HTTPException as http_exc:
|
310 |
+
# Handle HTTP errors
|
311 |
+
return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
|
312 |
except Exception as e:
|
313 |
+
# Handle general errors
|
314 |
+
return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
|
315 |
|
316 |
# Set up CORS middleware
|
317 |
origins = ["*"] # or specify your list of allowed origins
|