from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from starlette.middleware.cors import CORSMiddleware
from transformers import pipeline
from PIL import Image
from io import BytesIO
from pydub import AudioSegment
import numpy as np
import json
import torchaudio
# Model pipelines: two document-QA models for images, an extractive text-QA model,
# two sentiment classifiers, and a speech-to-text model.
nlp_qa = pipeline("document-question-answering", model="jinhybr/OCR-DocVQA-Donut")
nlp_qa_v2 = pipeline("document-question-answering", model="faisalraza/layoutlm-invoices")
nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
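# For reference, a sketch of the result shapes these pipelines return, inferred from
# how the handlers below index into them (exact fields may vary by model and
# transformers version):
#
#   nlp_qa(image, "What is the total?")            -> [{"answer": "..."}]
#   nlp_qa_v3({"question": ..., "context": ...})   -> {"answer": "...", "score": ...}
#   nlp_classification_v2("great!")                -> [{"label": "positive", "score": 0.99}]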
description = """ | |
## Image-based Document QA | |
This API performs document question answering using a LayoutLMv2-based model. | |
### Endpoints: | |
- **POST /uploadfile/:** Upload an image file to extract text and answer provided questions. | |
- **POST /pdfQA/:** Provide a PDF file to extract text and answer provided questions. | |
""" | |
app = FastAPI(docs_url="/", description=description) | |
# Route path taken from the API description above; the original decorator was not
# preserved, so this registration is a best guess.
@app.post("/uploadfile/")
async def perform_document_qa(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    try:
        # Read the uploaded file as bytes
        contents = await file.read()
        # Open the image using PIL
        image = Image.open(BytesIO(contents))
        # Answer each comma-separated question with the Donut-based document-QA model
        answers_dict = {}
        for question in questions.split(','):
            result = nlp_qa(
                image,
                question.strip()
            )
            # The pipeline returns a list of dicts; take the 'answer' from the top result
            answer = result[0]['answer']
            # Use the question, minus surrounding whitespace and brackets, as the key
            formatted_question = question.strip().strip("[]")
            answers_dict[formatted_question] = answer
        return answers_dict
    except Exception as e:
        return JSONResponse(content=f"Error processing file: {str(e)}", status_code=500)
# Hypothetical route path; the original decorator was not preserved. This variant
# is identical to /uploadfile/ but uses the invoice-tuned LayoutLM model instead
# of Donut. Renamed from perform_document_qa to avoid shadowing the handler above.
@app.post("/uploadfilev2/")
async def perform_document_qa_v2(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    try:
        # Read the uploaded file as bytes
        contents = await file.read()
        # Open the image using PIL
        image = Image.open(BytesIO(contents))
        # Answer each comma-separated question with the LayoutLM-based document-QA model
        answers_dict = {}
        for question in questions.split(','):
            result = nlp_qa_v2(
                image,
                question.strip()
            )
            # The pipeline returns a list of dicts; take the 'answer' from the top result
            answer = result[0]['answer']
            # Use the question, minus surrounding whitespace and brackets, as the key
            formatted_question = question.strip().strip("[]")
            answers_dict[formatted_question] = answer
        return answers_dict
    except Exception as e:
        return JSONResponse(content=f"Error processing file: {str(e)}", status_code=500)
# Hypothetical route path; the original decorator was not preserved. Extractive QA
# over plain text, no image involved. Renamed from perform_document_qa to avoid
# shadowing the handlers above.
@app.post("/textQA/")
async def perform_text_qa(
    context: str = Form(...),
    question: str = Form(...),
):
    try:
        QA_input = {
            'question': question,
            'context': context
        }
        res = nlp_qa_v3(QA_input)
        return res['answer']
    except Exception as e:
        return JSONResponse(content=f"Error answering question: {str(e)}", status_code=500)
# Hypothetical route path; the original decorator was not preserved.
@app.post("/classify/")
async def classify_text(text: str = Form(...)):
    try:
        # Perform binary sentiment classification (positive/negative) with DistilBERT
        result = nlp_classification(text)
        # Return the raw pipeline output: [{"label": ..., "score": ...}]
        return result
    except Exception as e:
        return JSONResponse(content=f"Error classifying text: {str(e)}", status_code=500)
# Hypothetical route path; the original decorator was not preserved.
@app.post("/test_classify/")
async def test_classify_text(text: str = Form(...)):
    try:
        # Perform classification with the three-way model (positive/neutral/negative)
        result = nlp_classification_v2(text)
        # Print the raw label for debugging purposes (can be removed later)
        raw_label = result[0]['label']
        print(f"Raw label from model: {raw_label}")
        # Map the model's lowercase labels to human-readable format
        label_map = {
            "negative": "Negative",
            "neutral": "Neutral",
            "positive": "Positive"
        }
        # Get the readable label from the map
        formatted_label = label_map.get(raw_label, "Unknown")
        return {"label": formatted_label, "score": result[0]['score']}
    except Exception as e:
        return JSONResponse(content=f"Error classifying text: {str(e)}", status_code=500)
# Hypothetical route path; the original decorator was not preserved.
@app.post("/transcribe_and_answer/")
async def transcribe_and_answer(
    file: UploadFile = File(...),
    questions: str = Form(...)
):
    try:
        # Step 1: Read and decode the uploaded audio file
        contents = await file.read()
        audio = AudioSegment.from_file(BytesIO(contents))
        # Step 2: Ensure the audio is mono and resample if needed
        audio = audio.set_channels(1)  # Convert to mono if it's not already
        audio = audio.set_frame_rate(16000)  # Resample to 16 kHz, as expected by wav2vec2
        # Step 3: Export to WAV format and load with torchaudio
        wav_buffer = BytesIO()
        audio.export(wav_buffer, format="wav")
        wav_buffer.seek(0)
        waveform, sample_rate = torchaudio.load(wav_buffer)
        # torchaudio returns shape (channels, samples); the ASR pipeline expects a
        # 1-D float32 array, so drop the channel dimension
        waveform_np = waveform.squeeze(0).numpy().astype(np.float32)
        # Step 4: Transcribe the audio
        transcription_result = nlp_speech_to_text(waveform_np)
        transcription_text = transcription_result['text']
        # Step 5: Parse the JSON-formatted questions (e.g. {"q1": "Who is speaking?"})
        questions_dict = json.loads(questions)
        # Step 6: Answer each question against the transcribed text
        answers_dict = {}
        for key, question in questions_dict.items():
            QA_input = {
                'question': question,
                'context': transcription_text
            }
            result = nlp_qa_v3(QA_input)
            answers_dict[key] = result['answer']
        # Step 7: Return transcription + answers
        return {
            "transcription": transcription_text,
            "answers": answers_dict
        }
    except Exception as e:
        return JSONResponse(content={"error": f"Error processing audio or answering questions: {str(e)}"}, status_code=500)
# Set up CORS middleware
origins = ["*"]  # or specify your list of allowed origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
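# To run locally (a sketch; assumes this file is named app.py and uvicorn is
# installed; Hugging Face Spaces conventionally serve on port 7860):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860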