OCR_Sample_RGB / app.py
NassimeBejaia's picture
Update app.py
6b8ee3c
raw
history blame
10.3 kB
import html
import os
import shutil
import subprocess
import sys
import tempfile
import zipfile
from tempfile import TemporaryDirectory

import cv2
import gdown
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image

from Utils import Loadlines, decode_predictions, load_model
@st.cache_resource
def _load_ocr_model():
    """Build the OCR model once and share it across Streamlit reruns and sessions."""
    return load_model()


# Streamlit re-executes this whole script on every widget interaction, so a
# bare module-level load_model() call would rebuild the model each time;
# st.cache_resource makes repeat calls return the cached instance.
ocr_model = _load_ocr_model()
def main():
    """Render the Arabic manuscript OCR page.

    The user picks one of up to three sample images (from ``sample_images/``)
    or uploads their own. The chosen image is shown Otsu-thresholded, run
    through the YOLOv3 line detector, and each detected line is cropped and
    passed to OCR.
    """
    st.title("Arabic Manuscript OCR")

    sample_images_dir = "sample_images"
    # Sorted so "Select Image N" maps to the same file on every rerun
    # (os.listdir order is arbitrary).
    sample_images = sorted(
        os.path.join(sample_images_dir, name)
        for name in os.listdir(sample_images_dir)
        if name.endswith(('.png', '.jpg', '.jpeg'))
    )

    display_size = (200, 300)  # thumbnail size passed to resize_image
    selected_image_pil = None
    # What display_detected_lines should re-open to crop lines: a sample path
    # or the uploaded file object. Tracked directly instead of re-opening every
    # sample and comparing pixel data to work out which one was picked.
    original_img_path = None

    # One column per sample; zip() stops early when there are fewer than three
    # samples instead of raising IndexError.
    columns = st.columns(3)
    for number, (column, image_path) in enumerate(zip(columns, sample_images), start=1):
        with column:
            if st.button(f"Select Image {number}", key=f"img{number}"):
                selected_image_pil = display_image(image_path)
                original_img_path = image_path
            column.image(resize_image(image_path, display_size), use_column_width=True)

    uploaded_image = st.file_uploader("Or upload a new image of the Arabic manuscript", type=["jpg", "jpeg", "png"])
    if uploaded_image is not None:
        # An upload takes precedence over a sample picked in the same run.
        selected_image_pil = Image.open(uploaded_image)
        original_img_path = uploaded_image
        st.image(selected_image_pil, caption="Uploaded Image", use_column_width=True)

    if selected_image_pil is None:
        return

    thresh_pil = process_image(selected_image_pil)
    st.image(thresh_pil, caption="Thresholded Image", use_column_width=True)

    processed_img_path = process_with_yolo(selected_image_pil)
    if not processed_img_path:
        st.error("Error displaying the processed image.")
        return

    st.image(processed_img_path, caption="Processed with YOLO", use_column_width=True)
    # display_detected_lines locates the label file itself and reports
    # "Annotation file (.txt) not found!" when it is missing.
    display_detected_lines(original_img_path, processed_img_path)
def display_files_in_directory(path="."):
    """Debug helper: write each entry of *path* to the page.

    Writes a not-found message instead when *path* does not exist.
    """
    if not os.path.exists(path):
        st.write(f"Directory {path} does not exist!")
        return
    entries = os.listdir(path)
    st.write(f"Files in directory: {path}")
    for entry in entries:
        st.write(entry)
def resize_image(image_path, size):
    """Return a resized copy of the image stored at *image_path*.

    ``size`` is forwarded unchanged to PIL's ``Image.resize``. The source file
    is closed before the copy is returned.
    """
    with Image.open(image_path) as source:
        resized = source.resize(size)
    return resized
def display_image(img):
    """Show *img* on the page captioned "Selected Image" and hand it back.

    Accepts either a file path (opened with PIL first) or an image object,
    which is passed to st.image as-is.
    """
    image = Image.open(img) if isinstance(img, str) else img
    st.image(image, caption="Selected Image", use_column_width=True)
    return image
# --------------------------------------------
def process_image(selected_image):
# Convert PIL image to OpenCV format
opencv_image = np.array(selected_image)
# Check if the image is grayscale or RGB
if len(opencv_image.shape) == 3 and opencv_image.shape[2] == 3: # RGB image
gray_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2GRAY)
else: # Image is already grayscale
gray_image = opencv_image
# Ensure the image is 8-bit grayscale
if gray_image.dtype != np.uint8:
gray_image = (gray_image * 255).astype(np.uint8)
# Optionally apply Gaussian Blur
blurred_img = cv2.GaussianBlur(gray_image, (5, 5), 0)
# Apply OTSU's thresholding
_, thresh = cv2.threshold(blurred_img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# Convert thresholded image back to PIL format to display in Streamlit
thresh_pil = Image.fromarray(thresh)
return thresh_pil
# ---------------------------------------------------------------------------------
def download_ultralytics_yolov3_folder(force=False):
    """Download and unpack the YOLOv3 code and trained weights from Google Drive.

    The archive is extracted into the current directory. It is only fetched
    when the detector script or the weights that process_with_yolo runs are
    missing, or when ``force`` is true.

    Args:
        force: Re-download and re-extract even if the files are already present.
    """
    required_files = ('yolov3/detect.py', 'yolov3/runs/train/mainline/weights/best.pt')
    if not force and all(os.path.exists(path) for path in required_files):
        return

    st.text("Downloading Ultralytics YOLOv3 folder from Google Drive. This may take a while...")
    url = 'https://drive.google.com/uc?id=1n6YcqHl5Y2xRpWw7DQPrZ2FoAqfcf12I'
    output = 'ultralytics_yolov3.zip'
    gdown.download(url, output, quiet=False)
    try:
        with zipfile.ZipFile(output, 'r') as zip_ref:
            zip_ref.extractall('.')
        st.text("Folder extraction complete!")
    finally:
        # Remove the archive even when extraction fails, so a corrupt zip is
        # not left behind.
        if os.path.exists(output):
            os.remove(output)
    st.text("Download and extraction completed!")
def process_with_yolo(img_pil):
    """Run the trained YOLOv3 text-line detector on *img_pil*.

    The image is written to a temporary JPEG. ``yolov3/detect.py`` is then run
    on it in a subprocess with ``--save-txt --save-conf`` so the label file is
    written too, and the temporary input is deleted afterwards.

    Returns:
        Path of the annotated output image inside
        ``yolov3/runs/detect/mainlinedection``, or None if it was not produced.
        Failures are reported on the page with st.error.
    """
    with st.spinner('Downloading YOLOv3 folder...'):
        download_ultralytics_yolov3_folder()

    # JPEG cannot encode alpha or palette images (e.g. RGBA PNG uploads).
    if img_pil.mode not in ("RGB", "L"):
        img_pil = img_pil.convert("RGB")

    output_dir = 'yolov3/runs/detect/mainlinedection'
    # Reserve a unique file name, then close the handle right away so no
    # descriptor leaks and the subprocess can open the file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        temp_path = temp_file.name
    try:
        img_pil.save(temp_path)
        # Clear the previous run so output_path below refers to this image's
        # results (detect.py may otherwise choose a new, incremented run dir).
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        cmd = [
            sys.executable, 'yolov3/detect.py',  # same interpreter/venv as the app
            '--source', temp_path,
            '--weights', 'yolov3/runs/train/mainline/weights/best.pt',
            '--save-txt',
            '--save-conf',
            '--imgsz', '672',
            '--name', 'mainlinedection',
        ]
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            st.error(f"YOLOv3 command failed with code {result.returncode}.")
            if result.stderr:
                st.error(result.stderr.decode(errors="replace"))
    finally:
        # The annotated copy is written under output_dir, so the input can go.
        os.unlink(temp_path)

    output_path = os.path.join(output_dir, os.path.basename(temp_path))
    if os.path.exists(output_path):
        return output_path
    st.error("Processed image not found!")
    return None
def get_detected_boxes(txt_path, img_width, img_height):
    """Read YOLO label lines from *txt_path* and return pixel-space boxes.

    Each line is expected in YOLO's normalized text format
    ``class x_center y_center width height [confidence]`` (what detect.py
    writes with --save-txt --save-conf). Blank or short lines are skipped.

    Returns:
        A list of ``(x_center, y_center, width, height)`` tuples in pixels,
        ordered top-to-bottom by y_center so lines come out in reading order.
    """
    boxes = []
    with open(txt_path, encoding="utf-8") as label_file:
        for line in label_file:
            fields = line.split()
            if len(fields) < 5:
                continue
            x_c, y_c, w, h = (float(value) for value in fields[1:5])
            boxes.append((x_c * img_width, y_c * img_height, w * img_width, h * img_height))
    boxes.sort(key=lambda box: box[1])
    return boxes


def display_detected_lines(original_path, output_path):
    """Crop each YOLO-detected text line from the original image, OCR it, and show it.

    Args:
        original_path: Path or file object of the full-resolution source image
            (anything PIL's Image.open accepts).
        output_path: Path of YOLO's annotated output image; only its basename
            is used, to find the matching label file.
    """
    label_name = os.path.splitext(os.path.basename(output_path))[0] + ".txt"
    txt_path = os.path.join('yolov3/runs/detect/mainlinedection/labels', label_name)
    if not os.path.exists(txt_path):
        st.error("Annotation file (.txt) not found!")
        return

    with Image.open(original_path) as opened:
        # Crops are saved as JPEG, which cannot hold alpha/palette modes.
        if opened.mode in ("RGB", "L"):
            original_image = opened.copy()
        else:
            original_image = opened.convert("RGB")

    boxes = get_detected_boxes(txt_path, original_image.width, original_image.height)
    if not boxes:
        st.warning("No lines detected by YOLOv3.")
        return

    with TemporaryDirectory() as temp_dir:
        detected_line_paths = []
        for index, (x_center, y_center, width, height) in enumerate(boxes):
            # Clamp to the image so crops never include out-of-bounds padding.
            x_min = max(0, int(x_center - width / 2))
            y_min = max(0, int(y_center - height / 2))
            x_max = min(original_image.width, int(x_center + width / 2))
            y_max = min(original_image.height, int(y_center + height / 2))
            if x_max <= x_min or y_max <= y_min:
                continue  # degenerate box: an empty crop cannot be saved
            detected_line_path = os.path.join(temp_dir, f"detected_line_{index}.jpg")
            original_image.crop((x_min, y_min, x_max, y_max)).save(detected_line_path)
            detected_line_paths.append(detected_line_path)

        if not detected_line_paths:
            st.warning("No lines detected by YOLOv3.")
            return

        recognized_texts = perform_ocr_on_detected_lines(detected_line_paths)
        # Show results while temp_dir (and the crop files) still exist.
        for img_path, text in zip(detected_line_paths, recognized_texts):
            st.image(img_path, use_column_width=True)
            # The OCR text is interpolated into raw HTML, so escape it.
            st.markdown(
                f"<p style='font-size: 18px; font-weight: bold;'>{html.escape(str(text))}</p>",
                unsafe_allow_html=True,
            )
            st.markdown("<br>", unsafe_allow_html=True)
def perform_ocr_on_detected_lines(detected_line_paths):
    """
    Performs OCR on the provided list of detected line image paths.

    Args:
    - detected_line_paths: List of paths to the detected line images.

    Returns:
    - A list of recognized text, built by extending with the decoded
      predictions of each batch that Loadlines yields (presumably in input
      order — confirm against Loadlines). Empty input returns an empty list
      without building a dataset or calling the model.
    """
    if not detected_line_paths:
        return []
    test_dataset = Loadlines(detected_line_paths)
    prediction_texts = []
    for batch in test_dataset:
        prediction_texts.extend(decode_predictions(ocr_model(batch)))
    return prediction_texts
# Entry point: render the app when this file runs as the main script
# (presumably launched via `streamlit run app.py`).
if __name__ == "__main__":
    main()