import streamlit as st import pandas as pd import numpy as np import cv2 from PIL import Image import os import zipfile import gdown from tempfile import TemporaryDirectory from Utils import Loadlines, decode_predictions, load_model import shutil import tempfile import subprocess # Load the OCR model outside the function to prevent reloading it every time ocr_model = load_model() def main(): st.title("Arabic Manuscript OCR") # download_folder() # Load sample images sample_images_dir = "sample_images" sample_images = [os.path.join(sample_images_dir, img) for img in os.listdir(sample_images_dir) if img.endswith(('.png', '.jpg', '.jpeg'))] # Display images in a grid and let user select col1, col2, col3 = st.columns(3) # Adjust based on how many images you have; this example assumes 3 # Normalized size for display display_size = (200, 300) # Placeholder for selected image selected_image_pil = None processed_img_path = None with col1: if st.button("Select Image 1", key="img1"): selected_image_pil = display_image(sample_images[0]) col1.image(resize_image(sample_images[0], display_size), use_column_width=True) with col2: if st.button("Select Image 2", key="img2"): selected_image_pil = display_image(sample_images[1]) col2.image(resize_image(sample_images[1], display_size), use_column_width=True) with col3: if st.button("Select Image 3", key="img3"): selected_image_pil = display_image(sample_images[2]) col3.image(resize_image(sample_images[2], display_size), use_column_width=True) # Option to upload a new image uploaded_image = st.file_uploader("Or upload a new image of the Arabic manuscript", type=["jpg", "jpeg", "png"]) if uploaded_image: selected_image_pil = Image.open(uploaded_image) st.image(selected_image_pil, caption="Uploaded Image", use_column_width=True) if selected_image_pil: thresh_pil = process_image(selected_image_pil) st.image(thresh_pil, caption="Thresholded Image", use_column_width=True) # processed_img_path = process_with_yolo(selected_image_pil) processed_img_path = process_with_yolo(thresh_pil) if processed_img_path: st.image(processed_img_path, caption="Processed with YOLO", use_column_width=True) txt_file_path = os.path.join('yolov3/runs/detect/mainlinedection/labels', os.path.basename(processed_img_path).replace(".jpg", ".txt")) if os.path.exists(txt_file_path): if uploaded_image: original_img_path = uploaded_image # If you have saved it to a location else: original_img_path = next((img for img in sample_images if Image.open(img) == selected_image_pil), None) # Match selected image with sample images display_detected_lines(original_img_path, processed_img_path) else: st.error("Annotation file (.txt) not found!") else: st.error("Error displaying the processed image.") # display_files_in_directory('/home/user/app/detected_lines/') # Function to display files in a directory def display_files_in_directory(path="."): if os.path.exists(path): files = os.listdir(path) st.write(f"Files in directory: {path}") for file in files: st.write(file) else: st.write(f"Directory {path} does not exist!") def resize_image(image_path, size): """Function to resize an image""" with Image.open(image_path) as img: img = img.resize(size) return img def display_image(img): """Function to display an image. If img is a path, open it. Otherwise, just display it.""" if isinstance(img, str): # img is a file path img = Image.open(img) st.image(img, caption="Selected Image", use_column_width=True) return img # -------------------------------------------- def process_image(selected_image): # Convert PIL image to OpenCV format opencv_image = np.array(selected_image) # Check if the image is grayscale or RGB if len(opencv_image.shape) == 3 and opencv_image.shape[2] == 3: # RGB image gray_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2GRAY) else: # Image is already grayscale gray_image = opencv_image # Ensure the image is 8-bit grayscale if gray_image.dtype != np.uint8: gray_image = (gray_image * 255).astype(np.uint8) # Optionally apply Gaussian Blur blurred_img = cv2.GaussianBlur(gray_image, (5, 5), 0) # Apply OTSU's thresholding _, thresh = cv2.threshold(blurred_img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) # Convert thresholded image back to PIL format to display in Streamlit thresh_pil = Image.fromarray(thresh) return thresh_pil def get_detected_boxes(txt_path, img_width, img_height): with open(txt_path, 'r') as f: lines = f.readlines() boxes = [] for line in lines: parts = list(map(float, line.strip().split())) # Assuming the format is: class x_center y_center width height confidence x_center, y_center, width, height = parts[1:5] # Convert to pixel values x_center *= img_width y_center *= img_height width *= img_width height *= img_height boxes.append([x_center, y_center, width, height]) return boxes # --------------------------------------------------------------------------------- def download_ultralytics_yolov3_folder(): st.text("Downloading Ultralytics YOLOv3 folder from Google Drive. This may take a while...") # url = 'https://drive.google.com/file/d/1n6YcqHl5Y2xRpWw7DQPrZ2FoAqfcf12I/view?usp=share_link' url = 'https://drive.google.com/uc?id=1n6YcqHl5Y2xRpWw7DQPrZ2FoAqfcf12I' output = 'ultralytics_yolov3.zip' gdown.download(url, output, quiet=False) # Extracting the zip file with zipfile.ZipFile(output, 'r') as zip_ref: zip_ref.extractall('.') st.text("Folder extraction complete!") os.remove(output) # Optional: remove the downloaded zip file after extraction. st.text("Download and extraction completed!") def process_with_yolo(img_pil): with st.spinner('Downloading YOLOv3 folder...'): download_ultralytics_yolov3_folder() # display_files_in_directory('/home/user/app/') # Save the image to a temporary file temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") img_pil.save(temp_file.name) if os.path.exists('yolov3/runs/detect/mainlinedection'): shutil.rmtree('yolov3/runs/detect/mainlinedection') cmd = [ 'python', 'yolov3/detect.py', '--source', temp_file.name, # use the temp file path here '--weights', 'yolov3/runs/train/mainline/weights/best.pt', '--save-txt', '--save-conf', '--imgsz', '672', '--name', 'mainlinedection' ] process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = process.communicate() if process.returncode != 0: st.error(f"YOLOv3 command failed with code {process.returncode}.") # After processing the image, at the end of the function... output_path = os.path.join('yolov3/runs/detect/mainlinedection', os.path.basename(temp_file.name)) if os.path.exists(output_path): return output_path else: st.error("Processed image not found!") return None # Optional: You can print the output to Streamlit, though it might be extensive. st.write(stdout.decode()) if stderr: st.error(stderr.decode()) # Optional: Close and delete the temporary file temp_file.close() os.unlink(temp_file.name) def display_detected_lines(original_path, output_path): # Derive the txt_path from the output_path txt_path = os.path.join('yolov3/runs/detect/mainlinedection/labels', os.path.basename(output_path).replace(".jpg", ".txt")) if os.path.exists(txt_path): # Load both original and thresholded images original_image = Image.open(original_path) thresholded_image = process_image(original_image) # This is your function that returns a thresholded PIL image boxes = get_detected_boxes(txt_path, original_image.width, original_image.height) if not boxes: st.warning("No lines detected by YOLOv3.") return # Create a temporary directory to store the detected lines with TemporaryDirectory() as temp_dir: detected_line_paths = [] # List to store paths of the detected line images for index, box in enumerate(boxes): x_center, y_center, width, height = box x_min = int(x_center - (width / 2)) y_min = int(y_center - (height / 2)) x_max = int(x_center + (width / 2)) y_max = int(y_center + (height / 2)) # Crop the thresholded image instead of the original extracted_line = thresholded_image.crop((x_min, y_min, x_max, y_max)) # Save the detected line image to the temporary directory detected_line_path = os.path.join(temp_dir, f"detected_line_{index}.jpg") extracted_line.save(detected_line_path) detected_line_paths.append(detected_line_path) # Perform OCR on detected lines recognized_texts = perform_ocr_on_detected_lines(detected_line_paths) # print("Decoded OCR Results:", recognized_texts) # st.text(f"Detected Line: {recognized_texts}") # Display the results for img_path, text in zip(detected_line_paths, recognized_texts): # st.image(img_path, caption=f"Detected Line: {text}", use_column_width=True) st.image(img_path, use_column_width=True) st.markdown( f"
{text}
", unsafe_allow_html=True ) # Add a small break for better spacing st.markdown("