MJobe commited on
Commit
ee808b2
1 Parent(s): 8601b67

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +28 -16
main.py CHANGED
@@ -35,7 +35,7 @@ nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
35
  nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
36
  nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
37
  nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
38
- nlp_prompt_classification = pipeline("text2text-generation", model="google/flan-t5-large")
39
 
40
  description = """
41
  ## Image-based Document QA
@@ -264,25 +264,37 @@ async def test_transcription(file: UploadFile = File(...)):
264
  except Exception as e:
265
  raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
266
 
267
- @app.post("/prompt_classify/", description="Classify the provided statement into one of the predefined categories.")
268
- async def prompt_classify_text(statement: str = Form(...)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  try:
270
- # Predefined prompt with placeholders
271
- prompt = (
272
- "Please classify the statement in one of the following classifications: "
273
- "Negative, Neutral, Positive"
274
- f"Statement: {statement}"
275
- )
276
 
277
- # Generate the response based on the prompt
278
- result = nlp_prompt_classification(prompt, max_length=50, num_return_sequences=1)
 
279
 
280
- # Extract the generated classification from the response
281
- classification = result[0]['generated_text'].strip()
282
-
283
- return {"classification": classification}
284
  except Exception as e:
285
- return JSONResponse(content=f"Error in prompt classification: {str(e)}", status_code=500)
286
 
287
  # Set up CORS middleware
288
  origins = ["*"] # or specify your list of allowed origins
 
35
  nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
36
  nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
37
  nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
38
+ nlp_sequence_classification = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
39
 
40
  description = """
41
  ## Image-based Document QA
 
264
  except Exception as e:
265
  raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
266
 
267
+ # Predefined classifications
268
+ labels = [
269
+ "All Pricing copy quote requested",
270
+ "Change to quote",
271
+ "Change to quote & Status Check",
272
+ "Change to quote (Items missed?)",
273
+ "Confirmation",
274
+ "Copy quote requested",
275
+ "Cost copy quote requested",
276
+ "MRSP copy quote requested",
277
+ "MSRP & All Pricing copy quote requested",
278
+ "MSRP & Cost copy quote requested",
279
+ "No narrative in email",
280
+ "Notes not clear",
281
+ "Retail copy quote requested",
282
+ "Status Check (possibly)"
283
+ ]
284
+
285
+ @app.post("/fast_classify/", description="Quickly classify text into predefined categories.")
286
+ async def fast_classify_text(statement: str = Form(...)):
287
  try:
288
+ # Use zero-shot classification to classify statement into one of the provided labels
289
+ result = nlp_sequence_classification(statement, labels, multi_label=False)
 
 
 
 
290
 
291
+ # Extract the best label and score
292
+ best_label = result["labels"][0]
293
+ best_score = result["scores"][0]
294
 
295
+ return {"classification": best_label, "confidence": best_score}
 
 
 
296
  except Exception as e:
297
+ return JSONResponse(content=f"Error in classification: {str(e)}", status_code=500)
298
 
299
  # Set up CORS middleware
300
  origins = ["*"] # or specify your list of allowed origins