MJobe commited on
Commit
a90d0cf
·
verified ·
1 Parent(s): adb3079

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +9 -0
main.py CHANGED
@@ -24,6 +24,7 @@ nlp_qa = pipeline("document-question-answering", model="jinhybr/OCR-DocVQA-Donut
24
  nlp_qa_v2 = pipeline("document-question-answering", model="faisalraza/layoutlm-invoices")
25
  nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
26
  nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
 
27
 
28
  description = """
29
  ## Image-based Document QA
@@ -128,6 +129,14 @@ async def classify_text(text: str = Form(...)):
128
  except Exception as e:
129
  return JSONResponse(content=f"Error classifying text: {str(e)}", status_code=500)
130
 
 
 
 
 
 
 
 
 
131
 
132
  # Set up CORS middleware
133
  origins = ["*"] # or specify your list of allowed origins
 
24
  nlp_qa_v2 = pipeline("document-question-answering", model="faisalraza/layoutlm-invoices")
25
  nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
26
  nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
27
+ nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
28
 
29
  description = """
30
  ## Image-based Document QA
 
129
  except Exception as e:
130
  return JSONResponse(content=f"Error classifying text: {str(e)}", status_code=500)
131
 
132
+ @app.post("/test_classify/", description="Classify the provided text with positive, neutral, or negative sentiment.")
133
+ async def test_classify_text(text: str = Form(...)):
134
+ try:
135
+ # Perform text classification using the updated model that returns positive, neutral, or negative
136
+ result = nlp_classification_v2(text)
137
+ return result
138
+ except Exception as e:
139
+ return JSONResponse(content=f"Error classifying text: {str(e)}", status_code=500)
140
 
141
  # Set up CORS middleware
142
  origins = ["*"] # or specify your list of allowed origins