# MMSAPI / app.py
# Author: badru — "Update app.py" (commit 8d9cf6c, verified)
from fastapi import FastAPI, UploadFile, File, HTTPException
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import cv2
import numpy as np
from io import BytesIO
import uvicorn
# Load the TrOCR processor (image pre-processing + tokenizer) and the
# vision-encoder/text-decoder OCR model once at import time, so every
# request reuses the same weights instead of reloading them.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
# FastAPI application instance that the routes below attach to.
app = FastAPI()
# Helper function to preprocess the image and detect lines
def detect_lines(image, min_height=20, min_width=100):
    """
    Detect horizontal lines of text in *image* and return them as crops.

    Args:
        image: A PIL image (or anything ``np.array`` accepts). RGB input is
            expected; a 2-D grayscale array is also handled.
        min_height: Minimum bounding-box height in pixels for a region to
            count as a text line.
        min_width: Minimum bounding-box width in pixels.

    Returns:
        A list of NumPy arrays, one crop per detected line, ordered
        top-to-bottom. May be empty if nothing passes the size filter.
    """
    image_np = np.array(image)
    # Generalization: accept grayscale input too. The original always called
    # cvtColor(..., COLOR_RGB2GRAY), which raises on a 2-D array.
    if image_np.ndim == 2:
        gray = image_np
    else:
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    # Otsu picks the threshold automatically (the 128 argument is ignored
    # when THRESH_OTSU is set); INV makes ink white for contour detection.
    _, binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Dilate so adjacent characters on the same line merge into one blob.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    dilated = cv2.dilate(binary, kernel, iterations=1)
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Sort top-to-bottom by each bounding box's y coordinate.
    bounding_boxes = sorted((cv2.boundingRect(c) for c in contours), key=lambda b: b[1])
    # Drop boxes too small to be a text line, then crop each survivor.
    return [
        image_np[y:y + h, x:x + w]
        for x, y, w, h in bounding_boxes
        if h >= min_height and w >= min_width
    ]
@app.post("/process_image")
async def process_image(file: UploadFile = File(...)):
    """
    Process an uploaded image and return its multiline handwritten text.

    The image is decoded, split into candidate text lines with
    ``detect_lines``, each line is transcribed with the TrOCR model, and
    the per-line transcriptions are joined with newlines.

    Args:
        file: The uploaded image file (any format PIL can open).

    Returns:
        ``{"extracted_text": <str>}`` — one line per detected text line,
        each terminated by a newline.

    Raises:
        HTTPException: 400 when the upload cannot be decoded or OCR fails.
    """
    try:
        # Decode the upload; convert to RGB so detect_lines and TrOCR get
        # a consistent 3-channel image regardless of the source format.
        contents = await file.read()
        image = Image.open(BytesIO(contents)).convert("RGB")
        line_images = detect_lines(image, min_height=30, min_width=100)
        # OCR each detected line independently, collecting the decoded text.
        lines = []
        for line_img in line_images:
            line_pil = Image.fromarray(line_img)
            pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
            generated_ids = model.generate(pixel_values)
            lines.append(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
        # Join once instead of repeated += ; keep the trailing newline the
        # original format produced.
        extracted_text = "".join(f"{text}\n" for text in lines)
        return {"extracted_text": extracted_text}
    except HTTPException:
        # Fix: don't re-label HTTPExceptions raised downstream — the original
        # broad handler turned every error into the same generic 400.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing image: {str(e)}") from e
if __name__ == "__main__":
    # Development entry point: serve the API with uvicorn on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=8000)