badru commited on
Commit
8d9cf6c
·
verified ·
1 Parent(s): 73dd09c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -58
app.py CHANGED
@@ -1,94 +1,90 @@
1
- import streamlit as st
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
- from fastapi import FastAPI, UploadFile, File
5
- from fastapi.responses import JSONResponse
6
- import uvicorn
7
- import numpy as np
8
  import cv2
9
- import io
10
-
11
- # Create a FastAPI app instance
12
- app = FastAPI()
13
 
14
  # Load the model and processor
15
- @st.cache_resource
16
- def load_model():
17
- processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
18
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
19
- return processor, model
20
-
21
- processor, model = load_model()
22
 
 
23
 
24
- # Function to preprocess image and detect lines (used for multiline OCR)
25
  def detect_lines(image, min_height=20, min_width=100):
 
 
 
 
26
  image_np = np.array(image)
 
 
27
  gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
 
 
28
  _, binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
 
 
29
  kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
30
  dilated = cv2.dilate(binary, kernel, iterations=1)
31
- contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
32
- bounding_boxes = sorted([cv2.boundingRect(c) for c in contours], key=lambda b: b[1])
33
- line_images = [image_np[y:y+h, x:x+w] for x, y, w, h in bounding_boxes if h >= min_height and w >= min_width]
34
- return line_images
35
-
36
-
37
- # FastAPI endpoint to handle image processing
38
- @app.post("/process_image")
39
- async def process_image(image: UploadFile = File(...)):
40
- try:
41
- # Read the uploaded image
42
- image_data = await image.read()
43
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
44
 
45
- # Detect lines and process each line
46
- line_images = detect_lines(image, min_height=30, min_width=600)
47
- extracted_text = ""
48
- for line_img in line_images:
49
- line_pil = Image.fromarray(line_img)
50
- pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
51
- generated_ids = model.generate(pixel_values)
52
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
53
- extracted_text += generated_text + "\n"
54
 
55
- # Return extracted text as JSON
56
- return JSONResponse(content={"extracted_text": extracted_text.strip()})
 
57
 
58
- except Exception as e:
59
- return JSONResponse(content={"error": str(e)}, status_code=500)
 
 
 
60
 
 
 
 
 
 
61
 
62
- # Streamlit UI for testing (optional)
63
- st.title("OCR API Service with Multiline Support")
64
 
65
- uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
66
- if uploaded_file is not None:
 
 
 
67
  try:
68
- image = Image.open(uploaded_file).convert("RGB")
69
- st.image(image, caption="Uploaded Image", use_column_width=True)
 
70
 
71
  # Detect lines in the image
72
- st.write("Detecting lines...")
73
  line_images = detect_lines(image, min_height=30, min_width=100)
74
- st.write(f"Detected {len(line_images)} lines in the image.")
75
 
76
  # Perform OCR on each detected line
77
  extracted_text = ""
78
  for idx, line_img in enumerate(line_images):
 
79
  line_pil = Image.fromarray(line_img)
 
 
80
  pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
 
 
81
  generated_ids = model.generate(pixel_values)
82
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
83
- extracted_text += f"{generated_text}\n"
 
 
84
 
85
- # Display extracted text
86
- st.subheader("Extracted Text:")
87
- st.text_area("Output Text", extracted_text.strip(), height=300)
88
 
89
  except Exception as e:
90
- st.error(f"An error occurred: {e}")
91
 
92
- # Run the FastAPI app
93
  if __name__ == "__main__":
94
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
 
 
 
 
4
  import cv2
5
+ import numpy as np
6
+ from io import BytesIO
7
+ import uvicorn
 
8
 
9
  # Load the model and processor
10
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
11
+ model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
 
 
 
 
 
12
 
13
+ app = FastAPI()
14
 
15
+ # Helper function to preprocess the image and detect lines
16
  def detect_lines(image, min_height=20, min_width=100):
17
+ """
18
+ Detects lines of text in the given image.
19
+ """
20
+ # Convert the PIL image to a NumPy array
21
  image_np = np.array(image)
22
+
23
+ # Convert to grayscale
24
  gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
25
+
26
+ # Apply binary thresholding
27
  _, binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
28
+
29
+ # Dilate to merge nearby text
30
  kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
31
  dilated = cv2.dilate(binary, kernel, iterations=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # Find contours
34
+ contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
 
 
 
 
 
 
35
 
36
+ # Sort contours top-to-bottom
37
+ bounding_boxes = [cv2.boundingRect(c) for c in contours]
38
+ bounding_boxes = sorted(bounding_boxes, key=lambda b: b[1]) # Sort by y-coordinate
39
 
40
+ # Filter out small contours and merge nearby ones
41
+ filtered_boxes = []
42
+ for x, y, w, h in bounding_boxes:
43
+ if h >= min_height and w >= min_width: # Filter small boxes
44
+ filtered_boxes.append((x, y, w, h))
45
 
46
+ # Extract individual lines as images
47
+ line_images = []
48
+ for (x, y, w, h) in filtered_boxes:
49
+ line = image_np[y:y + h, x:x + w]
50
+ line_images.append(line)
51
 
52
+ return line_images
 
53
 
54
+ @app.post("/process_image")
55
+ async def process_image(file: UploadFile = File(...)):
56
+ """
57
+ API endpoint to process the uploaded image and extract multiline text.
58
+ """
59
  try:
60
+ # Read the uploaded image
61
+ contents = await file.read()
62
+ image = Image.open(BytesIO(contents)).convert("RGB")
63
 
64
  # Detect lines in the image
 
65
  line_images = detect_lines(image, min_height=30, min_width=100)
 
66
 
67
  # Perform OCR on each detected line
68
  extracted_text = ""
69
  for idx, line_img in enumerate(line_images):
70
+ # Convert the line image to PIL format
71
  line_pil = Image.fromarray(line_img)
72
+
73
+ # Prepare the image for OCR
74
  pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
75
+
76
+ # Generate text from the line image
77
  generated_ids = model.generate(pixel_values)
78
+ line_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
79
+
80
+ # Append the extracted text
81
+ extracted_text += f"{line_text}\n"
82
 
83
+ # Return the extracted text as a JSON response
84
+ return {"extracted_text": extracted_text}
 
85
 
86
  except Exception as e:
87
+ raise HTTPException(status_code=400, detail=f"Error processing image: {str(e)}")
88
 
 
89
  if __name__ == "__main__":
90
  uvicorn.run(app, host="0.0.0.0", port=8000)