Spaces:
Runtime error
Runtime error
import os
import shutil
import subprocess
import sys
import tempfile
import zipfile
from tempfile import TemporaryDirectory

import cv2
import gdown
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image

from Utils import Loadlines, decode_predictions, load_model
# Streamlit re-executes this script on every user interaction, so a bare
# module-level load_model() call reloads the OCR model on each rerun.
# st.cache_resource keeps a single loaded instance across reruns; Streamlit
# versions that do not provide it fall back to the plain, uncached call.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _get_ocr_model():
    """Return the OCR model built by ``load_model`` (cached when supported)."""
    return load_model()


ocr_model = _get_ocr_model()
def main():
    """Render the OCR page: choose or upload an image, detect lines, run OCR.

    Up to three images from the ``sample_images`` directory are offered as
    selectable samples, and an uploader accepts a custom image. The chosen
    image is thresholded, passed through the YOLOv3 line detector, and the
    detected lines are handed to ``display_detected_lines``.
    """
    st.title("Arabic Manuscript OCR")

    # Collect the sample images. Sorting makes the button-to-image mapping
    # stable, because os.listdir returns entries in arbitrary order.
    sample_images_dir = "sample_images"
    if os.path.isdir(sample_images_dir):
        sample_images = sorted(
            os.path.join(sample_images_dir, img)
            for img in os.listdir(sample_images_dir)
            if img.endswith(('.png', '.jpg', '.jpeg'))
        )
    else:
        # No sample directory: only the uploader below is usable.
        sample_images = []

    # Normalized size for the sample thumbnails.
    display_size = (200, 300)

    selected_image_pil = None
    # Path of the sample picked via a button, remembered directly so it does
    # not have to be recovered later by comparing image contents.
    selected_sample_path = None

    # One column per sample; zip() stops early when fewer than three samples
    # exist instead of indexing past the end of the list.
    columns = st.columns(3)
    for number, (column, sample_path) in enumerate(zip(columns, sample_images), start=1):
        with column:
            if st.button(f"Select Image {number}", key=f"img{number}"):
                selected_image_pil = display_image(sample_path)
                selected_sample_path = sample_path
            column.image(resize_image(sample_path, display_size), use_column_width=True)

    # Option to upload a new image; an upload takes precedence over a sample.
    uploaded_image = st.file_uploader(
        "Or upload a new image of the Arabic manuscript",
        type=["jpg", "jpeg", "png"],
    )
    if uploaded_image is not None:
        selected_image_pil = Image.open(uploaded_image)
        st.image(selected_image_pil, caption="Uploaded Image", use_column_width=True)

    if selected_image_pil is None:
        return

    thresh_pil = process_image(selected_image_pil)
    st.image(thresh_pil, caption="Thresholded Image", use_column_width=True)

    processed_img_path = process_with_yolo(selected_image_pil)
    if not processed_img_path:
        st.error("Error displaying the processed image.")
        return
    st.image(processed_img_path, caption="Processed with YOLO", use_column_width=True)

    txt_file_path = os.path.join(
        'yolov3/runs/detect/mainlinedection/labels',
        os.path.basename(processed_img_path).replace(".jpg", ".txt"),
    )
    if not os.path.exists(txt_file_path):
        st.error("Annotation file (.txt) not found!")
        return

    # The line crops are taken from the original image: the uploaded file
    # object when there is one, otherwise the path of the selected sample.
    if uploaded_image is not None:
        original_img_path = uploaded_image
    else:
        original_img_path = selected_sample_path
    display_detected_lines(original_img_path, processed_img_path)
# Function to display files in a directory | |
def display_files_in_directory(path="."):
    """Write the entries of *path* to the Streamlit page.

    When *path* does not exist, a single message saying so is written
    instead.
    """
    if not os.path.exists(path):
        st.write(f"Directory {path} does not exist!")
        return

    entries = os.listdir(path)
    st.write(f"Files in directory: {path}")
    for entry in entries:
        st.write(entry)
def resize_image(image_path, size):
    """Open the image at *image_path* and return a copy resized to *size*."""
    with Image.open(image_path) as source:
        resized = source.resize(size)
    return resized
def display_image(img):
    """Show *img* on the page with the caption "Selected Image" and return it.

    A string argument is treated as a file path and opened with PIL; any
    other value is displayed as given.
    """
    picture = Image.open(img) if isinstance(img, str) else img
    st.image(picture, caption="Selected Image", use_column_width=True)
    return picture
# -------------------------------------------- | |
def process_image(selected_image):
    """Return an Otsu-thresholded version of *selected_image*.

    Args:
        selected_image: PIL image to binarize.

    Returns:
        A PIL image built from the thresholded array.
    """
    # Otsu thresholding needs a single-channel 8-bit array. Palette images
    # and multi-band images other than RGB (e.g. RGBA PNG uploads) would
    # otherwise skip the RGB branch below and reach the threshold step as
    # palette indices or multi-channel data, so normalise them to RGB first.
    has_multiple_bands = len(selected_image.getbands()) > 1
    if selected_image.mode == "P" or (has_multiple_bands and selected_image.mode != "RGB"):
        selected_image = selected_image.convert("RGB")

    # Convert PIL image to a NumPy array for OpenCV.
    opencv_image = np.array(selected_image)

    if opencv_image.ndim == 3 and opencv_image.shape[2] == 3:  # RGB image
        gray_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2GRAY)
    else:  # Image is already single-channel
        gray_image = opencv_image

    # Ensure the image is 8-bit grayscale.
    # NOTE(review): the scaling assumes non-uint8 data lies in [0, 1]
    # (boolean or normalised float) — confirm for other integer modes.
    if gray_image.dtype != np.uint8:
        gray_image = (gray_image * 255).astype(np.uint8)

    # Smooth before thresholding.
    blurred_img = cv2.GaussianBlur(gray_image, (5, 5), 0)

    # Apply Otsu's thresholding; the threshold value itself is not needed.
    _, thresh = cv2.threshold(blurred_img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Convert back to PIL so Streamlit can display it.
    return Image.fromarray(thresh)
# --------------------------------------------------------------------------------- | |
def download_ultralytics_yolov3_folder():
    """Download the YOLOv3 archive from Google Drive and extract it into ``.``.

    Progress messages are written to the Streamlit page. The downloaded zip
    file is removed afterwards, including when extraction fails.

    Raises:
        RuntimeError: If gdown reports that nothing was downloaded.
    """
    st.text("Downloading Ultralytics YOLOv3 folder from Google Drive. This may take a while...")
    url = 'https://drive.google.com/uc?id=1n6YcqHl5Y2xRpWw7DQPrZ2FoAqfcf12I'
    output = 'ultralytics_yolov3.zip'

    # gdown returns the output path on success; a None result means the
    # download did not happen, so there is no archive to extract.
    if gdown.download(url, output, quiet=False) is None:
        raise RuntimeError(f"Download of {url} failed; {output} was not retrieved.")

    try:
        with zipfile.ZipFile(output, 'r') as zip_ref:
            zip_ref.extractall('.')
        st.text("Folder extraction complete!")
    finally:
        # Remove the archive even when extraction raised, so a broken zip is
        # not left behind.
        os.remove(output)
    st.text("Download and extraction completed!")
def process_with_yolo(img_pil):
    """Run the YOLOv3 detection script on *img_pil*.

    The image is written to a temporary ``.jpg`` file, ``yolov3/detect.py``
    is executed on it as a subprocess, and the temporary file is removed
    again once the subprocess has finished.

    Args:
        img_pil: PIL image to run detection on.

    Returns:
        Path of the image written by the detection run inside
        ``yolov3/runs/detect/mainlinedection``, or ``None`` when that file
        does not exist (an error is shown on the page in that case).
    """
    detect_script = 'yolov3/detect.py'
    weights_path = 'yolov3/runs/train/mainline/weights/best.pt'
    run_dir = 'yolov3/runs/detect/mainlinedection'

    # Only fetch the archive when the files this function needs are missing,
    # instead of downloading it again on every call.
    # NOTE(review): presumes the archive is what provides the yolov3/ tree
    # used below — confirm against the archive contents.
    if not (os.path.exists(detect_script) and os.path.exists(weights_path)):
        with st.spinner('Downloading YOLOv3 folder...'):
            download_ultralytics_yolov3_folder()

    # Reserve a temporary .jpg path; the handle is closed straight away
    # because only the path is used from here on.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
    temp_file.close()
    try:
        # JPEG cannot store alpha or palette data, so such images (e.g. RGBA
        # PNG uploads) are converted to RGB before saving.
        jpeg_ready = img_pil if img_pil.mode in ("RGB", "L") else img_pil.convert("RGB")
        jpeg_ready.save(temp_file.name)

        # Remove the previous run directory so this run writes into the same
        # directory name that callers look in.
        if os.path.exists(run_dir):
            shutil.rmtree(run_dir)

        cmd = [
            sys.executable, detect_script,  # same interpreter as this app
            '--source', temp_file.name,
            '--weights', weights_path,
            '--save-txt',
            '--save-conf',
            '--imgsz', '672',
            '--name', 'mainlinedection',
        ]
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            st.error(f"YOLOv3 command failed with code {result.returncode}.")
            # Surface the subprocess diagnostics only when the run failed.
            if result.stderr:
                st.error(result.stderr.decode(errors="replace"))
    finally:
        # The temporary input is no longer needed once detection has run.
        os.unlink(temp_file.name)

    output_path = os.path.join(run_dir, os.path.basename(temp_file.name))
    if os.path.exists(output_path):
        return output_path
    st.error("Processed image not found!")
    return None
def get_detected_boxes(txt_path, img_width, img_height):
    """Read a detection label file and return its boxes in pixel units.

    NOTE(review): assumes each line is
    ``class x_center y_center width height [confidence]`` with coordinates
    normalised to [0, 1], as written by the detection script with
    ``--save-txt --save-conf`` — confirm against the generated files.

    Args:
        txt_path: Path of the label file.
        img_width: Width in pixels of the image the boxes refer to.
        img_height: Height in pixels of the image the boxes refer to.

    Returns:
        A list of ``(x_center, y_center, width, height)`` tuples in pixels,
        in file order. Lines that cannot be parsed are skipped.
    """
    boxes = []
    with open(txt_path, encoding="utf-8") as label_file:
        for line in label_file:
            fields = line.split()
            if len(fields) < 5:
                continue
            try:
                x_center, y_center, width, height = (float(v) for v in fields[1:5])
            except ValueError:
                continue
            boxes.append((
                x_center * img_width,
                y_center * img_height,
                width * img_width,
                height * img_height,
            ))
    return boxes


def display_detected_lines(original_path, output_path):
    """Crop each detected line from the original image, OCR it, and show it.

    Args:
        original_path: Path or file-like object of the original image.
        output_path: Path of the image produced by the detection run; its
            base name is used to locate the matching label file.
    """
    # Derive the label file path from the detection output path.
    txt_path = os.path.join(
        'yolov3/runs/detect/mainlinedection/labels',
        os.path.basename(output_path).replace(".jpg", ".txt"),
    )
    if not os.path.exists(txt_path):
        st.error("Annotation file (.txt) not found!")
        return

    original_image = Image.open(original_path)
    boxes = get_detected_boxes(txt_path, original_image.width, original_image.height)
    if not boxes:
        st.warning("No lines detected by YOLOv3.")
        return

    # The crops are saved as JPEG, which cannot store alpha or palette data.
    if original_image.mode not in ("RGB", "L"):
        original_image = original_image.convert("RGB")

    # The crops live in a temporary directory, so OCR and display must happen
    # before the directory is cleaned up at the end of the ``with`` block.
    with TemporaryDirectory() as temp_dir:
        detected_line_paths = []
        for index, box in enumerate(boxes):
            x_center, y_center, width, height = box
            # Clamp the box to the image so crops never extend past it.
            x_min = max(0, int(x_center - (width / 2)))
            y_min = max(0, int(y_center - (height / 2)))
            x_max = min(original_image.width, int(x_center + (width / 2)))
            y_max = min(original_image.height, int(y_center + (height / 2)))
            if x_max <= x_min or y_max <= y_min:
                # An empty box cannot be cropped or saved.
                continue
            extracted_line = original_image.crop((x_min, y_min, x_max, y_max))

            detected_line_path = os.path.join(temp_dir, f"detected_line_{index}.jpg")
            extracted_line.save(detected_line_path)
            detected_line_paths.append(detected_line_path)

        recognized_texts = perform_ocr_on_detected_lines(detected_line_paths)

        # Show each cropped line followed by its recognized text.
        for img_path, text in zip(detected_line_paths, recognized_texts):
            st.image(img_path, use_column_width=True)
            st.markdown(
                f"<p style='font-size: 18px; font-weight: bold;'>{text}</p>",
                unsafe_allow_html=True
            )
            # Add a small break for better spacing.
            st.markdown("<br>", unsafe_allow_html=True)
def perform_ocr_on_detected_lines(detected_line_paths):
    """
    Run the OCR model over the detected line images.

    Args:
    - detected_line_paths: List of paths to the detected line images.

    Returns:
    - A flat list with the decoded text of every prediction, in the order
      the items produced by ``Loadlines`` are iterated.
    """
    line_batches = Loadlines(detected_line_paths)
    # Each item yielded by the dataset is fed to the model and the decoded
    # texts are concatenated into one list.
    return [
        decoded_text
        for line_batch in line_batches
        for decoded_text in decode_predictions(ocr_model(line_batch))
    ]
# Run the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()