# Spaces status: Runtime error
import html
import io
import os
import shutil
import subprocess
import sys
import tempfile
import zipfile
from tempfile import TemporaryDirectory

import cv2
import gdown
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from skimage import filters

from Utils import Loadlines, decode_predictions, load_model
# Streamlit re-executes this whole script on every user interaction, so a bare
# module-level ``load_model()`` call would rebuild the OCR model on each rerun.
# ``st.cache_resource`` keeps one shared instance across reruns; on Streamlit
# versions that lack it, fall back to calling the loader directly.
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def _load_ocr_model():
    """Load (and, where supported, cache) the OCR model from ``Utils.load_model``."""
    return load_model()


ocr_model = _load_ocr_model()
def main():
    """Render the Arabic manuscript OCR demo page.

    The user picks one of up to three bundled sample images from
    ``sample_images/`` or uploads an image (an upload takes precedence).
    The chosen image is binarized with ``process_image``. Text lines are
    detected on the binarized image with ``process_with_yolo``. The detected
    lines are then cropped from the original image, OCR'd and displayed by
    ``display_detected_lines``.
    """
    st.title("منصة تجريبية لمكتب أضاميم")
    st.title("Arabic Manuscript OCR")
    st.warning("يرجى تحميل صورة ذو دقة عالية للحصول على نتيجة أفضل (حوالي 5000*3000)")

    # Collect the bundled samples. Sorting keeps the button -> image mapping
    # stable, because os.listdir returns entries in arbitrary order.
    sample_images_dir = "sample_images"
    sample_images = []
    if os.path.isdir(sample_images_dir):
        sample_images = sorted(
            os.path.join(sample_images_dir, img)
            for img in os.listdir(sample_images_dir)
            if img.endswith(('.png', '.jpg', '.jpeg'))
        )

    # Normalized thumbnail size for the sample grid
    display_size = (200, 300)

    selected_image_pil = None
    # Source of the chosen image (sample path or uploaded file object). It is
    # tracked directly so the detected lines can be cropped from the original.
    original_img_path = None

    # One column per sample (at most three), each with its own select button.
    columns = st.columns(3)
    for index, (column, sample_path) in enumerate(zip(columns, sample_images), start=1):
        with column:
            if st.button(f"Select Image {index}", key=f"img{index}"):
                selected_image_pil = display_image(sample_path)
                original_img_path = sample_path
            column.image(resize_image(sample_path, display_size), use_column_width=True)

    # Option to upload a new image; it overrides any sample selected above.
    uploaded_image = st.file_uploader("Or upload a new image of the Arabic manuscript", type=["jpg", "jpeg", "png"])
    if uploaded_image:
        selected_image_pil = Image.open(uploaded_image)
        original_img_path = uploaded_image
        st.image(selected_image_pil, caption="Uploaded Image", use_column_width=True)

    if selected_image_pil is None:
        return

    # Line detection runs on the binarized image; OCR later uses crops of the original.
    thresh_pil = process_image(selected_image_pil)
    processed_img_path = process_with_yolo(thresh_pil)
    if not processed_img_path:
        st.error("Error displaying the processed image.")
        return

    st.image(processed_img_path, caption="Processed with YOLO", use_column_width=True)
    txt_file_path = os.path.join('yolov3/runs/detect/mainlinedection/labels', os.path.basename(processed_img_path).replace(".jpg", ".txt"))
    if not os.path.exists(txt_file_path):
        st.error("Annotation file (.txt) not found!")
        return

    display_detected_lines(original_img_path, processed_img_path)
# Function to display files in a directory | |
# def display_files_in_directory(path="."): | |
# if os.path.exists(path): | |
# files = os.listdir(path) | |
# st.write(f"Files in directory: {path}") | |
# for file in files: | |
# st.write(file) | |
# else: | |
# st.write(f"Directory {path} does not exist!") | |
def resize_image(image_path, size):
    """Return a copy of the image stored at *image_path*, scaled to *size*.

    Args:
        image_path: Path of the image file to open.
        size: Target ``(width, height)`` in pixels; the aspect ratio is not
            preserved.

    Returns:
        A new PIL image. The source file is closed before returning.
    """
    with Image.open(image_path) as source_image:
        resized_image = source_image.resize(size)
    return resized_image
def display_image(img):
    """Show an image on the page with the caption "Selected Image".

    Args:
        img: Either a file path (``str``), which is opened with PIL, or an
            image object that ``st.image`` can render directly.

    Returns:
        The displayed image object (the opened PIL image when a path was given).
    """
    shown = Image.open(img) if isinstance(img, str) else img
    st.image(shown, caption="Selected Image", use_column_width=True)
    return shown
# -------------------------------------------- | |
def get_dynamic_kernel(img_height, *, ratio=0.001, min_height=1, max_height=11):
    """Build a vertical ``(k, 1)`` all-ones kernel whose height scales with the image.

    ``k`` starts as ``int(ratio * img_height)``. With the default ratio of
    0.001 that is 0.1% of the image height, not 1%. ``k`` is bumped to the
    next odd number and then clamped to ``[min_height, max_height]``; pass odd
    limits to keep the result odd after clamping. The computed values are
    echoed to the Streamlit page as debug text.

    Args:
        img_height: Image height in pixels.
        ratio: Fraction of the image height used for the kernel height.
        min_height: Smallest allowed kernel height (should be >= 1).
        max_height: Largest allowed kernel height.

    Returns:
        ``np.ones((k, 1), np.uint8)``.
    """
    kernel_height = int(ratio * img_height)
    # Bump even sizes to odd so the kernel has a well-defined centre row
    if kernel_height % 2 == 0:
        kernel_height += 1
    kernel_height = min(max_height, max(min_height, kernel_height))
    st.text(f"img_height : {img_height}")
    st.text(f"kernel_height : {kernel_height}")
    return np.ones((kernel_height, 1), np.uint8)
def process_image(selected_image):
    """Binarize an image with local (adaptive) mean thresholding.

    Steps:
      1. Convert the image to 8-bit grayscale.
      2. Threshold it with ``skimage.filters.threshold_local`` (method
         ``'mean'``). The neighbourhood is about 1/30 of the average image
         side, forced odd and at least 3. The offset is 5% of the global mean
         intensity, so both parameters scale with the image.

    Args:
        selected_image: A PIL image, or anything ``np.array`` turns into a 2-D
            grayscale array or an ``(H, W, 3|4)`` RGB(A) array.

    Returns:
        A PIL image built from the boolean mask ``gray > local_threshold``.
        Pixels brighter than their local threshold are True (white); darker
        pixels are False (black).
    """
    # Palette / alpha / CMYK images would otherwise reach NumPy as palette
    # indices or as a 2- or 4-channel array; let PIL produce plain RGB first.
    if isinstance(selected_image, Image.Image) and selected_image.mode in ("P", "LA", "RGBA", "CMYK", "YCbCr"):
        selected_image = selected_image.convert("RGB")

    opencv_image = np.array(selected_image)

    if opencv_image.ndim == 3:
        channels = opencv_image.shape[2]
        if channels == 3:
            gray_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2GRAY)
        elif channels == 4:
            gray_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGBA2GRAY)
        else:
            # 1- or 2-channel (gray + alpha) arrays: keep the gray channel
            gray_image = opencv_image[..., 0]
    else:
        gray_image = opencv_image

    # Ensure the image is 8-bit grayscale
    if gray_image.dtype != np.uint8:
        if gray_image.dtype == np.bool_ or np.issubdtype(gray_image.dtype, np.floating):
            # Boolean masks / floats assumed in [0, 1]: scale to 0..255, clipped
            # so out-of-range values cannot wrap around.
            gray_image = np.clip(gray_image * 255, 0, 255).astype(np.uint8)
        else:
            # Wider integer images (e.g. 16-bit PNGs): rescale by the observed
            # peak instead of multiplying by 255, which would overflow uint8.
            peak = max(int(gray_image.max()), 255)
            gray_image = (np.clip(gray_image.astype(np.float64), 0, None) * (255.0 / peak)).astype(np.uint8)

    height, width = gray_image.shape

    # Neighbourhood ~1/30 of the average side. threshold_local needs an odd
    # size, and a 1-pixel block would compare each pixel only with itself.
    block_size = ((height + width) // 2) // 30
    if block_size % 2 == 0:
        block_size += 1
    block_size = max(block_size, 3)

    # Offset: a small fraction of the global mean intensity
    offset = np.mean(gray_image) * 0.05

    adaptive_threshold = filters.threshold_local(gray_image, block_size, offset=offset, method='mean')
    thresh = gray_image > adaptive_threshold

    # Back to PIL so Streamlit / YOLO can consume it
    return Image.fromarray(thresh)
def get_detected_boxes(txt_path, img_width, img_height):
    """Read detections from a label file and convert them to pixel units.

    Each non-blank line is expected to be
    ``class x_center y_center width height [confidence]``. All fields must be
    numeric. The four box values are multiplied by the image size, so they are
    expected to be normalized to [0, 1]. The class id and confidence are
    ignored.

    Args:
        txt_path: Path of the label file.
        img_width: Width in pixels of the image the boxes refer to.
        img_height: Height in pixels of the image the boxes refer to.

    Returns:
        A list of ``[x_center, y_center, width, height]`` lists in pixels, in
        file order. Blank lines are skipped.
    """
    boxes = []
    with open(txt_path, 'r') as label_file:
        for line in label_file:
            fields = list(map(float, line.split()))
            if not fields:
                # Tolerate blank lines instead of failing on the unpack below
                continue
            x_center, y_center, width, height = fields[1:5]
            boxes.append([
                x_center * img_width,
                y_center * img_height,
                width * img_width,
                height * img_height,
            ])
    return boxes
# --------------------------------------------------------------------------------- | |
def download_ultralytics_yolov3_folder():
    """Download the zipped YOLOv3 folder from Google Drive and unpack it.

    The archive is fetched with ``gdown`` to ``ultralytics_yolov3.zip`` in
    the working directory and extracted there. The archive is then deleted,
    even if extraction fails. Progress messages are written to the Streamlit
    page.
    """
    st.text("Downloading Ultralytics YOLOv3 folder from Google Drive. This may take a while...")
    url = 'https://drive.google.com/uc?id=1n6YcqHl5Y2xRpWw7DQPrZ2FoAqfcf12I'
    output = 'ultralytics_yolov3.zip'
    try:
        gdown.download(url, output, quiet=False)
        with zipfile.ZipFile(output, 'r') as zip_ref:
            zip_ref.extractall('.')
        st.text("Folder extraction complete!")
    finally:
        # Remove the archive on failure too, so a partial or corrupt download
        # does not linger in the working directory.
        if os.path.exists(output):
            os.remove(output)
    st.text("Download and extraction completed!")
def process_with_yolo(img_pil):
    """Detect text lines in *img_pil* by running the YOLOv3 ``detect.py`` script.

    Steps:
      1. Download the YOLOv3 folder, but only if ``yolov3/detect.py`` or the
         trained weights are missing.
      2. Write the image to a temporary ``.jpg`` file.
      3. Remove results of any previous run under
         ``yolov3/runs/detect/mainlinedection``.
      4. Run ``detect.py`` on the file with ``--save-txt --save-conf
         --imgsz 672``, using the current Python interpreter.

    The temporary input file is always deleted afterwards.

    Args:
        img_pil: PIL image to run detection on.

    Returns:
        Path of the annotated output image inside
        ``yolov3/runs/detect/mainlinedection``. Returns ``None`` if that image
        was not produced, after reporting an error on the Streamlit page.
    """
    detect_script = 'yolov3/detect.py'
    weights_path = 'yolov3/runs/train/mainline/weights/best.pt'
    output_dir = 'yolov3/runs/detect/mainlinedection'

    # Fetching the folder is slow; skip it when the files needed are present.
    if not (os.path.exists(detect_script) and os.path.exists(weights_path)):
        with st.spinner('Downloading YOLOv3 folder...'):
            download_ultralytics_yolov3_folder()

    # Reserve a temp file name, then close our handle before PIL writes the
    # image to that path, so no descriptor is left open.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
    temp_file.close()
    try:
        img_pil.save(temp_file.name)

        # Clear the previous run so the fixed output paths below refer to
        # this run's results.
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)

        cmd = [
            sys.executable, detect_script,  # same interpreter/environment as the app
            '--source', temp_file.name,
            '--weights', weights_path,
            '--save-txt',
            '--save-conf',
            '--imgsz', '672',
            '--name', 'mainlinedection',
        ]
        result = subprocess.run(cmd, capture_output=True)
    finally:
        # The outputs are looked up under output_dir, so the input copy can go.
        os.unlink(temp_file.name)

    if result.returncode != 0:
        st.error(f"YOLOv3 command failed with code {result.returncode}.")
        # Show the script's error output to help diagnose the failure
        if result.stderr:
            st.error(result.stderr.decode(errors="replace"))

    output_path = os.path.join(output_dir, os.path.basename(temp_file.name))
    if os.path.exists(output_path):
        return output_path
    st.error("Processed image not found!")
    return None
def display_detected_lines(original_path, output_path):
    """Crop each detected text line from the original image, OCR it and show it.

    The label file is
    ``yolov3/runs/detect/mainlinedection/labels/<basename of output_path>.txt``
    (with ``.jpg`` replaced by ``.txt``). Boxes are processed as follows:
      1. Clamp each box to the image and skip empty ones.
      2. Crop the box from the original (un-thresholded) image and save it as
         a JPEG in a temporary directory.
      3. Pass the crops to ``perform_ocr_on_detected_lines``.
      4. Render each crop followed by its recognized text.

    Args:
        original_path: Path or file-like object of the original image
            (anything ``PIL.Image.open`` accepts).
        output_path: Path of the annotated YOLO image; only its basename is
            used, to locate the label file.
    """
    txt_path = os.path.join('yolov3/runs/detect/mainlinedection/labels', os.path.basename(output_path).replace(".jpg", ".txt"))
    if not os.path.exists(txt_path):
        st.error("Annotation file (.txt) not found!")
        return

    with Image.open(original_path) as opened_image:
        # JPEG cannot store alpha or palette modes (e.g. RGBA / P PNG uploads),
        # so normalize to RGB before the crops are saved as .jpg files.
        if opened_image.mode in ("RGB", "L"):
            original_image = opened_image.copy()
        else:
            original_image = opened_image.convert("RGB")

    img_width, img_height = original_image.size
    boxes = get_detected_boxes(txt_path, img_width, img_height)
    crop_boxes = [
        crop_box
        for crop_box in (_line_crop_box(box, img_width, img_height) for box in boxes)
        if crop_box is not None
    ]
    if not crop_boxes:
        st.warning("No lines detected by YOLOv3.")
        return

    with TemporaryDirectory() as temp_dir:
        original_line_paths = []
        for index, crop_box in enumerate(crop_boxes):
            original_line_path = os.path.join(temp_dir, f"original_line_{index}.jpg")
            original_image.crop(crop_box).save(original_line_path)
            original_line_paths.append(original_line_path)

        # OCR runs on the original (not thresholded) line crops
        recognized_texts = perform_ocr_on_detected_lines(original_line_paths)

        # NOTE(review): lines are shown in label-file order, which may not be
        # top-to-bottom reading order; consider sorting crop_boxes by y_min.
        for original_img_path, text in zip(original_line_paths, recognized_texts):
            st.image(original_img_path, use_column_width=True)
            # Escape the text: it is embedded in raw HTML rendered with unsafe_allow_html
            st.markdown(
                f"<p style='font-size: 18px; font-weight: bold;'>{html.escape(str(text))}</p>",
                unsafe_allow_html=True,
            )
            # Add a small break for better spacing
            st.markdown("<br>", unsafe_allow_html=True)


def _line_crop_box(box, img_width, img_height):
    """Turn a pixel ``[x_center, y_center, width, height]`` box into a clamped PIL crop box, or None if empty."""
    x_center, y_center, width, height = box
    # Clamp to the image so crops near the border do not include out-of-image area
    x_min = max(0, int(x_center - width / 2))
    y_min = max(0, int(y_center - height / 2))
    x_max = min(img_width, int(x_center + width / 2))
    y_max = min(img_height, int(y_center + height / 2))
    if x_max <= x_min or y_max <= y_min:
        return None
    return (x_min, y_min, x_max, y_max)
def perform_ocr_on_detected_lines(detected_line_paths):
    """Run the OCR model over line images and collect the decoded text.

    The paths are wrapped in ``Loadlines``. For each batch it yields,
    ``ocr_model`` is applied and ``decode_predictions`` turns the output into
    text.

    Args:
        detected_line_paths: Paths of the cropped line images.

    Returns:
        A flat list of recognized texts, in the order the batches (and the
        entries within each batch) are produced.
    """
    line_batches = Loadlines(detected_line_paths)
    return [
        text
        for batch in line_batches
        for text in decode_predictions(ocr_model(batch))
    ]
if __name__ == "__main__":
    # Render the app when this file is executed as a script
    main()