import fitz
import io
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from transformers import pipeline
from PIL import Image
from io import BytesIO
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from pdf2image import convert_from_bytes
from pydub import AudioSegment
import numpy as np
import json
import torchaudio
import torch
import speech_recognition as sr
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
import re
from pydantic import BaseModel
from typing import List, Dict, Any
# The FastAPI app is created below, after the model pipelines and API description
# are defined; CORS middleware is attached to that instance at the bottom of the file.
# Model pipelines (loaded once at startup)
nlp_qa = pipeline("document-question-answering", model="jinhybr/OCR-DocVQA-Donut")
nlp_qa_v2 = pipeline("document-question-answering", model="faisalraza/layoutlm-invoices", ignore_mismatched_sizes=True)
nlp_qa_v3 = pipeline("question-answering", model="deepset/roberta-base-squad2")
nlp_classification = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
nlp_classification_v2 = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
nlp_speech_to_text = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
nlp_sequence_classification = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
nlp_main_classification = pipeline("zero-shot-classification", model="roberta-large-mnli")
description = """ | |
## Image-based Document QA | |
This API performs document question answering using a LayoutLMv2-based model. | |
### Endpoints: | |
- **POST /uploadfile/:** Upload an image file to extract text and answer provided questions. | |
- **POST /pdfQA/:** Provide a PDF file to extract text and answer provided questions. | |
""" | |
app = FastAPI(docs_url="/", description=description) | |
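# Route decorator (assumption): the handler below takes File(...)/Form(...) parameters, so it
# appears intended as a POST endpoint; "/uploadfile/" is the path named in the API description
# above. Registering it here also keeps later reuses of the name perform_document_qa from
# hiding this route, since FastAPI records the handler at decoration time.
@app.post("/uploadfile/")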
async def perform_document_qa(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    try:
        # Read the uploaded file as bytes
        contents = await file.read()
        # Open the image using PIL
        image = Image.open(BytesIO(contents))
        # Answer each question against the image using the Donut document-QA model (nlp_qa)
        answers_dict = {}
        for question in questions.split(','):
            result = nlp_qa(
                image,
                question.strip()
            )
            # Access the 'answer' key from the first item in the result list
            answer = result[0]['answer']
            # Use the question, trimmed of whitespace and stray brackets, as the dictionary key
            formatted_question = question.strip().strip("[]")
            answers_dict[formatted_question] = answer
        return answers_dict
    except Exception as e:
        return JSONResponse(content=f"Error processing file: {str(e)}", status_code=500)
async def perform_document_qa(
    file: UploadFile = File(...),
    questions: str = Form(...),
):
    try:
        # Read the uploaded file as bytes
        contents = await file.read()
        # Open the image using PIL
        image = Image.open(BytesIO(contents))
        # Answer each question against the image using the LayoutLM invoices model (nlp_qa_v2)
        answers_dict = {}
        for question in questions.split(','):
            result = nlp_qa_v2(
                image,
                question.strip()
            )
            # Access the 'answer' key from the first item in the result list
            answer = result[0]['answer']
            # Use the question, trimmed of whitespace and stray brackets, as the dictionary key
            formatted_question = question.strip().strip("[]")
            answers_dict[formatted_question] = answer
        return answers_dict
    except Exception as e:
        return JSONResponse(content=f"Error processing file: {str(e)}", status_code=500)
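# Route decorator (assumption): placeholder path for extractive QA over a plain-text context;
# the original route path is not shown in this file.
@app.post("/textQA/")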
async def perform_document_qa(
    context: str = Form(...),
    question: str = Form(...),
):
    try:
        QA_input = {
            'question': question,
            'context': context
        }
        res = nlp_qa_v3(QA_input)
        return res['answer']
    except Exception as e:
        return JSONResponse(content=f"Error processing request: {str(e)}", status_code=500)
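# Route decorator (assumption): placeholder path for SST-2 sentiment classification.
@app.post("/classify/")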
async def classify_text(text: str = Form(...)):
    try:
        # Perform text classification using the pipeline
        result = nlp_classification(text)
        # Return the classification result
        return result
    except Exception as e:
        return JSONResponse(content=f"Error classifying text: {str(e)}", status_code=500)
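# Route decorator (assumption): placeholder path for the three-way
# (positive/neutral/negative) sentiment model.
@app.post("/test_classify/")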
async def test_classify_text(text: str = Form(...)):
    try:
        # Perform text classification using the updated model that returns positive, neutral, or negative
        result = nlp_classification_v2(text)
        # Print the raw label for debugging purposes (can be removed later)
        raw_label = result[0]['label']
        print(f"Raw label from model: {raw_label}")
        # Map the model labels to human-readable format
        label_map = {
            "negative": "Negative",
            "neutral": "Neutral",
            "positive": "Positive"
        }
        # Get the readable label from the map
        formatted_label = label_map.get(raw_label, "Unknown")
        return {"label": formatted_label, "score": result[0]['score']}
    except Exception as e:
        return JSONResponse(content=f"Error classifying text: {str(e)}", status_code=500)
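# Route decorator (assumption): placeholder path for audio transcription plus QA over the transcript.
@app.post("/transcribe_and_answer/")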
async def transcribe_and_answer(
    file: UploadFile = File(...),
    questions: str = Form(...)
):
    try:
        # Ensure correct file format
        if file.content_type not in ["audio/wav", "audio/mpeg", "audio/mp3", "audio/webm"]:
            raise HTTPException(status_code=400, detail="Unsupported audio format. Please upload a WAV, MP3, or WebM file.")
        logging.info(f"Received file type: {file.content_type}")
        logging.info(f"Received questions: {questions}")
        # Convert uploaded file to WAV if needed
        audio_data = await file.read()
        audio_file = io.BytesIO(audio_data)
        if file.content_type in ["audio/mpeg", "audio/mp3"]:
            audio = AudioSegment.from_file(audio_file, format="mp3")
            audio_wav = io.BytesIO()
            audio.export(audio_wav, format="wav")
            audio_wav.seek(0)
        elif file.content_type == "audio/webm":
            audio = AudioSegment.from_file(audio_file, format="webm")
            audio_wav = io.BytesIO()
            audio.export(audio_wav, format="wav")
            audio_wav.seek(0)
        else:
            audio_wav = audio_file
        # Transcription via the Google Web Speech API (requires network access)
        recognizer = sr.Recognizer()
        with sr.AudioFile(audio_wav) as source:
            audio = recognizer.record(source)
            transcription_text = recognizer.recognize_google(audio)
        # Parse questions JSON
        try:
            questions_dict = json.loads(questions)
        except json.JSONDecodeError:
            raise HTTPException(status_code=400, detail="Invalid JSON format for questions")
        # Answer each question against the transcription with the extractive QA model
        answers_dict = {}
        for key, question in questions_dict.items():
            QA_input = {
                'question': question,
                'context': transcription_text
            }
            try:
                result = nlp_qa_v3(QA_input)
                answers_dict[key] = result['answer']
            except Exception as e:
                logging.error(f"Error in question answering model: {e}")
                answers_dict[key] = "Error in answering this question."
        # Return transcription + answers
        return {
            "transcription": transcription_text,
            "answers": answers_dict
        }
    except HTTPException:
        # Let the explicit 400s raised above propagate with their original status codes
        raise
    except Exception as e:
        logging.error(f"General error: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
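# Route decorator (assumption): placeholder path for transcription-only testing.
@app.post("/test_transcription/")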
async def test_transcription(file: UploadFile = File(...)):
    try:
        # Check if the file format is supported
        if file.content_type not in ["audio/wav", "audio/mpeg", "audio/mp3"]:
            raise HTTPException(status_code=400, detail="Unsupported audio format. Please upload a WAV or MP3 file.")
        # Convert uploaded file to WAV if necessary for compatibility with SpeechRecognition
        audio_data = await file.read()
        audio_file = io.BytesIO(audio_data)
        if file.content_type in ["audio/mpeg", "audio/mp3"]:
            # Convert MP3 to WAV
            audio = AudioSegment.from_file(audio_file, format="mp3")
            audio_wav = io.BytesIO()
            audio.export(audio_wav, format="wav")
            audio_wav.seek(0)
        else:
            audio_wav = audio_file
        # Transcribe audio using speech_recognition
        recognizer = sr.Recognizer()
        with sr.AudioFile(audio_wav) as source:
            audio = recognizer.record(source)
            transcription = recognizer.recognize_google(audio)
        # Return the transcription
        return {"transcription": transcription}
    except HTTPException:
        # Preserve the 400 raised above instead of converting it to a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
# Define the ThreadPoolExecutor globally to manage asynchronous execution
executor = ThreadPoolExecutor(max_workers=10)
# Predefined classifications
labels = [
    "All Pricing copy quote requested",
    "Change to quote",
    "Change to quote & Status Check",
    "Change to quote (Items missed?)",
    "Confirmation",
    "Copy quote requested",
    "Cost copy quote requested",
    "MRSP copy quote requested",
    "MSRP & All Pricing copy quote requested",
    "MSRP & Cost copy quote requested",
    "No narrative in email",
    "Notes not clear",
    "Retail copy quote requested",
    "Status Check (possibly)"
]
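# Route decorator (assumption): placeholder path for zero-shot classification returning the top label.
@app.post("/fast_classify/")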
async def fast_classify_text(statement: str = Form(...)):
    try:
        # Use run_in_executor to handle the synchronous model call asynchronously
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            executor,
            lambda: nlp_sequence_classification(statement, labels, multi_label=False)
        )
        # Extract the best label and score
        best_label = result["labels"][0]
        best_score = result["scores"][0]
        return {"classification": best_label, "confidence": best_score}
    except asyncio.TimeoutError:
        # Handle timeout
        return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
    except HTTPException as http_exc:
        # Handle HTTP errors
        return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
    except Exception as e:
        # Handle general errors
        return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
# The predefined `labels` list above is shared by this second classification endpoint as well.
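# Route decorator (assumption): placeholder path for the variant that also returns a score for every label.
@app.post("/fast_classify_v2/")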
async def fast_classify_text(statement: str = Form(...)):
    try:
        # Use run_in_executor to handle the synchronous model call asynchronously
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            executor,
            lambda: nlp_sequence_classification(statement, labels, multi_label=False)
        )
        # Extract all labels and their scores
        all_labels = result["labels"]
        all_scores = result["scores"]
        # Extract the best label and score
        best_label = all_labels[0]
        best_score = all_scores[0]
        # Prepare the response
        full_response = {
            "classification": best_label,
            "confidence": best_score,
            "all_labels": {label: score for label, score in zip(all_labels, all_scores)}
        }
        return full_response
    except asyncio.TimeoutError:
        # Handle timeout
        return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
    except HTTPException as http_exc:
        # Handle HTTP errors
        return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
    except Exception as e:
        # Handle general errors
        return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
# Labels for main classifications
main_labels = [
    "Change to quote",
    "Copy quote requested",
    "Expired Quote",
    "Notes not clear"
]
# Define a model for the response
class ClassificationResponse(BaseModel):
    classification: str
    sub_classification: str
    confidence: float
    scores: Dict[str, float]
# Keyword dictionaries for overriding classifications
change_to_quote_keywords = ["Per ATP", "Add", "Revised", "Remove", "Advise"]
copy_quote_requested_keywords = ["MSRP", "Send Quote", "Copy", "All pricing", "Retail"]
sub_classification_keywords = {
    "MRSP": ["MSRP"],
    "Direct": ["Direct"],
    "All": ["All pricing"],
    "MRSP & All": ["MSRP", "All pricing"]
}
# Helper function to check for keywords in a case-insensitive way
def check_keywords(statement: str, keywords: List[str]) -> bool:
    # re.escape guards against keywords containing regex metacharacters
    return any(re.search(rf"\b{re.escape(keyword)}\b", statement, re.IGNORECASE) for keyword in keywords)
# Function to determine sub-classification based on keywords
def get_sub_classification(statement: str) -> str:
    for sub_label, keywords in sub_classification_keywords.items():
        if all(check_keywords(statement, [keyword]) for keyword in keywords):
            return sub_label
    return "None"  # Default to "None" if no keywords match
async def classify_with_subcategory(statement: str = Form(...)) -> ClassificationResponse:
    try:
        # Check if the statement is empty or "N/A"
        if not statement or statement.strip().lower() == "n/a":
            return ClassificationResponse(
                classification="Notes not clear",
                sub_classification="None",
                confidence=1.0,
                scores={"main": 1.0}
            )
        # Keyword-based classification override
        if check_keywords(statement, change_to_quote_keywords):
            main_best_label = "Change to quote"
            main_best_score = 1.0  # High confidence since it's a direct match
        elif check_keywords(statement, copy_quote_requested_keywords):
            main_best_label = "Copy quote requested"
            main_best_score = 1.0
        else:
            # If no keywords matched, perform the main classification using the model
            loop = asyncio.get_running_loop()
            main_classification_result = await loop.run_in_executor(
                None,
                lambda: nlp_sequence_classification(statement, main_labels, multi_label=False)
            )
            # Extract the best main classification label and confidence score
            main_best_label = main_classification_result["labels"][0]
            main_best_score = main_classification_result["scores"][0]
        # Perform sub-classification only if the main classification is "Copy quote requested"
        if main_best_label == "Copy quote requested":
            best_sub_label = get_sub_classification(statement)
        else:
            best_sub_label = "None"
        # Gather the scores for response
        scores = {"main": main_best_score}
        if best_sub_label != "None":
            scores[best_sub_label] = 1.0  # Assign full confidence to sub-classification matches
        return ClassificationResponse(
            classification=main_best_label,
            sub_classification=best_sub_label,
            confidence=main_best_score,
            scores=scores
        )
    except asyncio.TimeoutError:
        # Handle timeout errors
        return JSONResponse(content="Classification timed out. Try a shorter input or increase timeout.", status_code=504)
    except HTTPException as http_exc:
        # Handle HTTP errors
        return JSONResponse(content=f"HTTP error: {http_exc.detail}", status_code=http_exc.status_code)
    except Exception as e:
        # Handle any other errors
        return JSONResponse(content=f"Error in classification pipeline: {str(e)}", status_code=500)
# Set up CORS middleware
origins = ["*"]  # or specify your list of allowed origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
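# Example request (assumes the placeholder "/uploadfile/" route above and a locally running
# server, e.g. `uvicorn app:app --host 0.0.0.0 --port 7860`; the file name is illustrative):
#
#   curl -X POST http://localhost:7860/uploadfile/ \
#        -F "file=@invoice.png" \
#        -F "questions=What is the invoice number?, What is the total amount?"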