# Image_Stitcher / app.py
# Author: ammariii08 — commit 26b3f96 ("Update app.py")
import cv2
import numpy as np
import gradio as gr
from PIL import Image
import tempfile
def equalize_exposure(images, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Equalize local contrast of each image with CLAHE on its lightness channel.

    Each BGR image is converted to L*a*b*, CLAHE is applied to the L channel
    only (the a and b colour channels are passed through unchanged), and the
    result is converted back to BGR.

    Args:
        images: Iterable of BGR images accepted by ``cv2.cvtColor`` with
            ``COLOR_BGR2LAB``. (Presumably 8-bit 3-channel arrays — the caller
            in this file builds them from RGB PIL images.)
        clip_limit: CLAHE contrast-limiting threshold (``clipLimit``).
        tile_grid_size: Number of CLAHE tiles along each axis (``tileGridSize``).

    Returns:
        A new list of equalized BGR images, in the same order as ``images``.
    """
    # Every image uses identical CLAHE settings, so build the object once
    # instead of once per image.
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    equalized_images = []
    for img in images:
        l, a, b = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2LAB))
        # Only lightness is equalized so colours are not shifted.
        img_eq = cv2.merge((clahe.apply(l), a, b))
        equalized_images.append(cv2.cvtColor(img_eq, cv2.COLOR_LAB2BGR))
    return equalized_images
def _load_bgr_images(image_files):
    """Open each uploaded file with Pillow and return it as a BGR NumPy array."""
    images = []
    for file in image_files:
        # The context manager releases the file handle even for formats
        # Pillow would otherwise keep open after loading.
        with Image.open(file) as img_pil:
            rgb = np.array(img_pil.convert('RGB'))
        images.append(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    return images


def _crop_to_content(image):
    """Crop a BGR image to the bounding box of its non-black pixels.

    Returns the image unchanged if every pixel is black (nothing to crop to).
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    coords = cv2.findNonZero(gray)
    if coords is None:
        # findNonZero yields None when there are no non-zero pixels;
        # boundingRect would fail on it.
        return image
    x, y, w, h = cv2.boundingRect(coords)
    # Mapping this axis-aligned box onto (0, 0)-(w, h) is a pure translation,
    # so a slice gives exactly the pixels a perspective warp would.
    return image[y:y + h, x:x + w]


def _save_png_to_temp(image):
    """Write a PIL image to a new temporary ``.png`` file and return its path.

    ``delete=False`` keeps the file after the handle is closed so it can be
    offered for download; nothing in this file removes it afterwards.
    """
    # Writing through the handle inside ``with`` closes it deterministically
    # (and avoids reopening an already-open temp file by name, which fails
    # on Windows).
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
        image.save(temp_file, format='PNG')
    return temp_file.name


def stitch_images(image_files):
    """Stitch uploaded images into a panorama and save it as a PNG.

    Pipeline: load every file as BGR, equalize exposure with
    ``equalize_exposure``, stitch with OpenCV's panorama stitcher, crop the
    result to the bounding box of its non-black pixels, and write it to a
    temporary PNG.

    Args:
        image_files: Paths (or file objects accepted by ``PIL.Image.open``)
            of the images to stitch. ``None`` or an empty sequence is treated
            as "no images".

    Returns:
        ``(stitched_image, png_path)`` — the panorama as an RGB
        ``PIL.Image.Image`` and the path of the temporary PNG holding it — or
        ``(None, None)`` if fewer than two images were given or stitching
        failed.
    """
    # Guard before iterating: nothing uploaded may arrive as None.
    if not image_files:
        print("Need at least two images to stitch.")
        return None, None

    images = _load_bgr_images(image_files)
    if len(images) < 2:
        print("Need at least two images to stitch.")
        return None, None

    # Reduce brightness differences between shots before feature matching.
    images_eq = equalize_exposure(images)

    stitcher = cv2.Stitcher_create(cv2.Stitcher_PANORAMA)
    # Minimum match confidence for an image to stay in the panorama.
    stitcher.setPanoConfidenceThresh(0.8)
    # Disable wave (horizon-straightening) correction.
    stitcher.setWaveCorrection(False)

    status, stitched = stitcher.stitch(images_eq)
    if status != cv2.Stitcher_OK:
        print(f"Image stitching failed ({status})")
        return None, None

    # Drop the black border that surrounds the warped panorama.
    cropped = _crop_to_content(stitched)

    stitched_image = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
    return stitched_image, _save_png_to_temp(stitched_image)
# Gradio Interface
with gr.Blocks() as interface:
    gr.Markdown("<h1 style='color: #2196F3; text-align: center;'>Image Stitcher 🧵</h1>")
    gr.Markdown("<h3 style='color: #2196F3; text-align: center;'>Upload the images you want to stitch</h3>")
    # Multiple uploads are delivered to stitch_images as file paths.
    image_upload = gr.Files(type="filepath", label="Upload Images")
    stitch_button = gr.Button("Stitch", variant="primary")
    # Receives the PIL image returned by stitch_images.
    stitched_image = gr.Image(type="pil", label="Stitched Image")
    # Receives the temporary PNG path returned by stitch_images.
    download_button = gr.File(label="Download Stitched Image")
    # stitch_images returns (image, path); the order matches `outputs`.
    stitch_button.click(stitch_images, inputs=image_upload, outputs=[stitched_image, download_button])

# Launch only when run as a script, so importing this module has no
# server-starting side effect.
if __name__ == "__main__":
    interface.launch()