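# ocrtest / app.py
# Author: kopeck -- commit cc68c7f ("Update app.py", verified), 9.45 kB.
# (Header recovered from the Hugging Face file page: raw / history / blame
# links and a "No virus" scan badge.)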
import io
import logging
import re
import tempfile
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from typing import Dict, List
from xml.sax.saxutils import escape

import docx
import docx2txt
import gradio as gr
import pytesseract
import torch
from pdf2image import convert_from_bytes
from PIL import Image
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
# Configure root logging for the whole app: records at INFO and above, each
# prefixed with a timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the processing code below.
logger = logging.getLogger(__name__)
class AdvancedDocProcessor:
    """Extract text from uploaded documents and analyze it with NLP models.

    Three Hugging Face models are loaded once in ``__init__`` and reused:

    * ``facebook/bart-large-cnn`` summarizes the extracted text.
    * ``google/flan-t5-base`` applies a free-form user prompt to the summary.
    * ``dbmdz/bert-large-cased-finetuned-conll03-english`` tags named entities.

    The extraction and analysis methods log failures and return a fallback
    value instead of raising, so one failing stage does not abort a request.
    """

    # Character budgets per model call. The models limit *tokens*, not
    # characters; these sizes keep ordinary prose well inside each model's
    # input limit. The NER pipeline silently truncates longer inputs to the
    # model's maximum length, so its budget must stay small or text is lost.
    BART_CHUNK_CHARS = 1024
    T5_CHUNK_CHARS = 512
    NER_CHUNK_CHARS = 1000

    def __init__(self):
        # BART for summarization ("cleaning") of the extracted text.
        self.bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
        self.bart_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
        # FLAN-T5 for prompt-driven text generation.
        self.t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
        # NER pipeline. "simple" aggregation merges word pieces ("##...") and
        # B-/I- tags into whole entities labelled under "entity_group".
        self.ner_pipeline = pipeline(
            "ner",
            model="dbmdz/bert-large-cased-finetuned-conll03-english",
            aggregation_strategy="simple",
        )

    @staticmethod
    def _chunk_text(text: str, size: int) -> List[str]:
        """Split *text* into pieces of at most *size* characters on word breaks.

        Runs of whitespace are normalized to single spaces. A single "word"
        longer than *size* (e.g. OCR noise) is cut hard at *size*. Empty or
        whitespace-only input yields ``[]``, so no model is ever called on
        blank text.
        """
        chunks: List[str] = []
        current: List[str] = []
        current_len = 0
        for word in text.split():
            # Hard-split an over-long token; flush whatever was pending first.
            while len(word) > size:
                if current:
                    chunks.append(" ".join(current))
                    current, current_len = [], 0
                chunks.append(word[:size])
                word = word[size:]
            added = len(word) + (1 if current else 0)  # +1 for the joining space
            if current_len + added > size:
                chunks.append(" ".join(current))
                current, current_len = [word], len(word)
            else:
                current.append(word)
                current_len += added
        if current:
            chunks.append(" ".join(current))
        return chunks

    def extract_text(self, file_content: bytes, file_type: str) -> str:
        """Extract text from PDF, DOCX, or plain-text bytes.

        Args:
            file_content: The raw uploaded bytes.
            file_type: A MIME type string selecting the extractor.

        Returns:
            The extracted text, or ``""`` for an unsupported type or any
            extraction error (which is logged).
        """
        try:
            if file_type == "application/pdf":
                return self.extract_text_from_pdf(file_content)
            if file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
                return self.extract_text_from_docx(file_content)
            if file_type == "text/plain":
                # "utf-8-sig" drops a leading BOM; errors="replace" keeps a
                # non-UTF-8 file usable instead of discarding all of it.
                return file_content.decode("utf-8-sig", errors="replace")
            raise ValueError(f"Unsupported file type: {file_type}")
        except Exception as e:
            logger.exception("Error extracting text: %s", e)
            return ""

    def extract_text_from_pdf(self, pdf_content: bytes) -> str:
        """OCR every page of a PDF and return the page texts joined by newlines.

        Returns ``""`` (and logs) on any rasterization or OCR error.
        """
        try:
            # Rasterize pages; give up on the PDF renderer after 60 seconds.
            images = convert_from_bytes(pdf_content, timeout=60)
            return "\n".join(pytesseract.image_to_string(image) for image in images)
        except Exception as e:
            logger.exception("Error extracting text from PDF: %s", e)
            return ""

    def extract_text_from_docx(self, docx_content: bytes) -> str:
        """Extract text from DOCX bytes; return ``""`` (and log) on error."""
        try:
            return docx2txt.process(io.BytesIO(docx_content))
        except Exception as e:
            logger.exception("Error extracting text from DOCX: %s", e)
            return ""

    def clean_and_summarize_text(self, text: str) -> str:
        """Summarize *text* chunk by chunk with BART.

        Returns:
            The chunk summaries joined by spaces, ``""`` for blank input, or
            the original *text* unchanged if the model call fails (so later
            stages still have something to work on).
        """
        try:
            summarized_chunks = []
            for chunk in self._chunk_text(text, self.BART_CHUNK_CHARS):
                inputs = self.bart_tokenizer(
                    [chunk], max_length=1024, return_tensors="pt", truncation=True
                )
                # ``**inputs`` passes the attention mask along with the ids.
                summary_ids = self.bart_model.generate(
                    **inputs, num_beams=4, max_length=150, early_stopping=True
                )
                summarized_chunks.append(
                    self.bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
                )
            return " ".join(summarized_chunks)
        except Exception as e:
            logger.exception("Error cleaning and summarizing text: %s", e)
            return text

    def process_with_t5(self, text: str, prompt: str) -> str:
        """Apply *prompt* to each chunk of *text* with FLAN-T5.

        Generation uses sampling (``temperature=0.7``), so output varies
        between calls.

        Returns:
            The per-chunk outputs joined by spaces, ``""`` for blank input, or
            an ``"Error processing text: ..."`` message on failure.
        """
        try:
            processed_chunks = []
            for chunk in self._chunk_text(text, self.T5_CHUNK_CHARS):
                input_text = f"{prompt} {chunk}"
                inputs = self.t5_tokenizer(
                    input_text, return_tensors="pt", max_length=512, truncation=True
                )
                outputs = self.t5_model.generate(
                    **inputs,
                    max_length=150,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7,
                )
                processed_chunks.append(
                    self.t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
                )
            return " ".join(processed_chunks)
        except Exception as e:
            logger.exception("Error processing with T5: %s", e)
            return f"Error processing text: {str(e)}"

    def extract_entities(self, text: str) -> str:
        """Return the unique named entities in *text*, one per line.

        Each line is ``"word (label)"``; lines are sorted so the output is
        stable between runs. Returns ``""`` for blank input and
        ``"Error extracting entities"`` on failure.
        """
        try:
            unique_entities = set()
            for chunk in self._chunk_text(text, self.NER_CHUNK_CHARS):
                for ent in self.ner_pipeline(chunk):
                    # Aggregated results use "entity_group"; raw token-level
                    # results use "entity".
                    label = ent.get("entity_group") or ent.get("entity", "")
                    unique_entities.add((ent["word"], label))
            return "\n".join(f"{word} ({label})" for word, label in sorted(unique_entities))
        except Exception as e:
            logger.exception("Error extracting entities: %s", e)
            return "Error extracting entities"

    def process_document(self, file_content: bytes, file_type: str, prompt: str) -> Dict[str, str]:
        """Run extraction, summarization, prompting, and NER on one document.

        Returns:
            ``{"cleaned": BART summary, "processed": T5 output for *prompt*,
            "entities": entity lines}``. NER runs on the raw extracted text,
            not on the summary.
        """
        raw_text = self.extract_text(file_content, file_type)
        if not raw_text.strip():
            logger.warning("No text could be extracted (file type: %s)", file_type)
        cleaned_text = self.clean_and_summarize_text(raw_text)
        processed_text = self.process_with_t5(cleaned_text, prompt)
        entities = self.extract_entities(raw_text)
        return {
            "cleaned": cleaned_text,
            "processed": processed_text,
            "entities": entities,
        }
def create_gradio_interface():
    """Build the Gradio UI around one shared ``AdvancedDocProcessor``.

    The processor (and so all of its models) is created once here and reused
    for every request.

    Returns:
        The configured ``gr.Interface``; the caller must call ``launch()``.
    """
    processor = AdvancedDocProcessor()

    def process_and_display(file, prompt, output_format):
        """Process one upload and return ``(report text, output file path)``.

        Stops waiting after 5 minutes and returns a timeout message with no
        file. Returns a hint with no file when nothing was uploaded.
        """
        # gr.File yields None when the user submits without uploading.
        if file is None:
            return "Please upload a document (PDF, DOCX, or TXT).", None

        def processing_task():
            file_type = infer_file_type(file)
            results = processor.process_document(file, file_type, prompt)
            if output_format == "txt":
                output_path = save_as_txt(results)
            elif output_format == "docx":
                output_path = save_as_docx(results)
            else:  # pdf
                output_path = save_as_pdf(results)
            report = (f"Cleaned and Summarized Text:\n{results['cleaned']}\n\n"
                      f"Processed Text:\n{results['processed']}\n\n"
                      f"Extracted Entities:\n{results['entities']}")
            return report, output_path

        # Not ``with ThreadPoolExecutor()``: its __exit__ joins the worker, so
        # a timed-out request would still block until processing finished.
        executor = ThreadPoolExecutor(max_workers=1)
        future = executor.submit(processing_task)
        try:
            return future.result(timeout=300)  # 5 minutes timeout
        except TimeoutError:
            return "Processing timed out after 5 minutes.", None
        finally:
            # Return without waiting; a timed-out task keeps running in its
            # background thread until it completes on its own.
            executor.shutdown(wait=False)

    iface = gr.Interface(
        fn=process_and_display,
        inputs=[
            gr.File(label="Upload Document (PDF, DOCX, or TXT)", type="binary"),
            gr.Textbox(label="Enter your prompt for processing", lines=3),
            gr.Radio(["txt", "docx", "pdf"], label="Output Format", value="txt"),
        ],
        outputs=[
            gr.Textbox(label="Processing Results", lines=30),
            gr.File(label="Download Processed Document"),
        ],
        title="Advanced Document Processing Tool",
        description="Upload a document (PDF, DOCX, or TXT) and enter a prompt to process and analyze the text using state-of-the-art NLP models.",
    )
    return iface
def infer_file_type(file_content: bytes) -> str:
    """Guess a MIME type for uploaded bytes from their content.

    Returns:
        ``"application/pdf"`` for data starting with ``%PDF``; the DOCX MIME
        type for a ZIP archive containing a ``word/`` part;
        ``"application/zip"`` for any other or unreadable ZIP archive (e.g.
        XLSX, PPTX); ``"text/plain"`` for everything else.
    """
    if file_content.startswith(b'%PDF'):
        return "application/pdf"
    if file_content.startswith(b'PK\x03\x04'):
        # Every OOXML format and any plain .zip share this signature, so look
        # inside for the Word document parts before calling it DOCX.
        try:
            with zipfile.ZipFile(io.BytesIO(file_content)) as archive:
                names = archive.namelist()
        except (zipfile.BadZipFile, OSError):
            return "application/zip"
        if any(name.startswith("word/") for name in names):
            return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        return "application/zip"
    return "text/plain"
def save_as_txt(results: Dict[str, str]) -> str:
    """Write *results* to a new UTF-8 ``.txt`` temp file and return its path.

    Each entry is written as an upper-cased ``KEY:`` line, the value, and a
    blank line. The file is kept (``delete=False``) so it can be downloaded.
    """
    # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) cannot
    # encode arbitrary OCR/model output and would raise UnicodeEncodeError.
    with tempfile.NamedTemporaryFile(
        mode='w', encoding='utf-8', delete=False, suffix='.txt'
    ) as temp_file:
        temp_file.writelines(
            f"{key.upper()}:\n{value}\n\n" for key, value in results.items()
        )
        return temp_file.name
def save_as_docx(results: Dict[str, str]) -> str:
    """Write *results* to a new ``.docx`` temp file and return its path.

    Each entry becomes a level-1 heading (the capitalized key) followed by one
    paragraph holding the value. The file is kept (``delete=False``).
    """
    doc = docx.Document()
    for key, value in results.items():
        doc.add_heading(key.capitalize(), level=1)
        # python-docx rejects XML-illegal control characters (such as the
        # form feeds Tesseract emits between pages); drop them. Tab, LF and
        # CR are legal XML and are kept.
        doc.add_paragraph(re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]", "", value))
    with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as tmp:
        # Save through the open handle instead of re-opening ``tmp.name``,
        # which fails on Windows while the temp file is still open.
        doc.save(tmp)
    return tmp.name
def save_as_pdf(results: Dict[str, str]) -> str:
    """Write *results* to a new letter-size ``.pdf`` temp file and return its path.

    Each entry becomes a ``Heading1`` paragraph (the capitalized key), a
    ``BodyText`` paragraph holding the value, and a 12-point spacer. The file
    is kept (``delete=False``).
    """
    styles = getSampleStyleSheet()
    story = []
    for key, value in results.items():
        story.append(Paragraph(escape(key.capitalize()), styles['Heading1']))
        # Paragraph parses its text as XML-like markup: escape &, <, > so
        # model/OCR output cannot break the parser, and turn newlines into
        # explicit breaks (Paragraph otherwise folds them into spaces).
        body = escape(value).replace("\n", "<br/>")
        story.append(Paragraph(body, styles['BodyText']))
        story.append(Spacer(1, 12))
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
        # Build into the open handle instead of re-opening ``tmp.name``,
        # which fails on Windows while the temp file is still open.
        SimpleDocTemplate(tmp, pagesize=letter).build(story)
    return tmp.name
# Launch the Gradio app when this file is run as a script. Keep exactly one
# entry point: each call rebuilds the interface and reloads every model.
if __name__ == "__main__":
    iface = create_gradio_interface()
    iface.launch()