MJobe commited on
Commit
cfd8768
1 Parent(s): 11d5e31

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +43 -0
main.py CHANGED
@@ -25,6 +25,7 @@ nlp_qa_v2 = pipeline("document-question-answering", model="faisalraza/layoutlm-i
25
  nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
26
  nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
27
  nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
 
28
 
29
  description = """
30
  ## Image-based Document QA
@@ -153,6 +154,48 @@ async def test_classify_text(text: str = Form(...)):
153
  except Exception as e:
154
  return JSONResponse(content=f"Error classifying text: {str(e)}", status_code=500)
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  # Set up CORS middleware
157
  origins = ["*"] # or specify your list of allowed origins
158
  app.add_middleware(
 
25
  nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
26
  nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
27
  nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
28
+ nlp_speech_to_text = pipeline("automatic-speech-recognition", model="openai/whisper-base")
29
 
30
  description = """
31
  ## Image-based Document QA
 
154
  except Exception as e:
155
  return JSONResponse(content=f"Error classifying text: {str(e)}", status_code=500)
156
 
157
+
158
+ @app.post("/transcribe_and_match/", description="Transcribe audio and match responses to form fields.")
159
+ async def transcribe_and_match(
160
+ file: UploadFile = File(...),
161
+ field_data: str = Form(...)
162
+ ):
163
+ """
164
+ Transcribe audio and match it to form fields.
165
+ :param file: The uploaded audio file.
166
+ :param field_data: A JSON string that contains form field information (field names and IDs).
167
+ """
168
+ try:
169
+ # Step 1: Read and transcribe the audio file
170
+ contents = await file.read()
171
+ transcription_result = nlp_speech_to_text(contents)
172
+ transcription_text = transcription_result['text']
173
+
174
+ # Step 2: Parse the field_data (which contains field names/IDs)
175
+ # Example: [{"field_id": "name_field", "field_label": "Name"}, {"field_id": "email_field", "field_label": "Email"}]
176
+ import json
177
+ fields = json.loads(field_data)
178
+
179
+ # Step 3: Find the matching field for the transcription
180
+ field_matches = {}
181
+
182
+ for field in fields:
183
+ field_label = field.get("field_label", "").lower()
184
+ field_id = field.get("field_id", "")
185
+
186
+ # Simple matching: if the transcribed text contains the field label (or something close)
187
+ if field_label in transcription_text.lower():
188
+ field_matches[field_id] = transcription_text
189
+
190
+ # Step 4: Return transcription + matched fields
191
+ return {
192
+ "transcription": transcription_text,
193
+ "matched_fields": field_matches
194
+ }
195
+
196
+ except Exception as e:
197
+ return JSONResponse(content=f"Error processing audio or matching fields: {str(e)}", status_code=500)
198
+
199
  # Set up CORS middleware
200
  origins = ["*"] # or specify your list of allowed origins
201
  app.add_middleware(