# Hugging Face Spaces status: Running
import cv2 | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
import tempfile | |
def equalize_exposure(images, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Return contrast-equalized copies of BGR images using CLAHE on lightness.

    Each image is converted to the LAB colour space. CLAHE (Contrast Limited
    Adaptive Histogram Equalization) is applied to the L (lightness) channel
    only, and the result is converted back to BGR. The a/b colour channels are
    passed through untouched.

    Args:
        images: Iterable of 3-channel BGR images (presumably uint8 arrays,
            since CLAHE only accepts 8- or 16-bit input — confirm at callers).
        clip_limit: CLAHE contrast-clipping threshold. Defaults to 2.0.
        tile_grid_size: CLAHE tile grid size. Defaults to (8, 8).

    Returns:
        A new list of equalized BGR images, in the same order as *images*.
        The input arrays are not modified.
    """
    # The CLAHE configuration is the same for every image, so build the
    # object once instead of once per loop iteration.
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    equalized_images = []
    for img in images:
        l, a, b = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2LAB))
        # Equalize lightness only; keep the original colour channels.
        img_eq = cv2.merge((clahe.apply(l), a, b))
        equalized_images.append(cv2.cvtColor(img_eq, cv2.COLOR_LAB2BGR))
    return equalized_images
def _load_bgr(file):
    """Open an uploaded image (path or file object) and return it as a BGR array."""
    # Image.open() loads lazily and keeps its file handle open; the context
    # manager releases it once the pixels have been copied out.
    with Image.open(file) as img_pil:
        rgb = np.array(img_pil.convert('RGB'))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def _crop_to_content(img):
    """Crop a BGR image to the bounding box of its non-black pixels."""
    coords = cv2.findNonZero(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
    if coords is None:
        # Entirely black image: there is no content box to crop to.
        return img
    x, y, w, h = cv2.boundingRect(coords)
    # Mapping the axis-aligned box (x, y, w, h) onto (0, 0, w, h) is a pure
    # translation, so a plain slice gives the same pixels as a
    # perspective warp without an extra interpolation pass.
    return img[y:y + h, x:x + w]


def _save_temp_png(image):
    """Write a PIL image to a new temporary .png file and return its path."""
    # delete=False keeps the file on disk after the handle is closed so it
    # can be offered for download; the with-block makes sure the handle
    # itself is closed instead of leaking a descriptor per call.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        image.save(tmp, format='PNG')
    return tmp.name


def stitch_images(image_files):
    """Stitch uploaded images into a single panorama.

    Steps: load each file as BGR, equalize exposure with
    ``equalize_exposure``, stitch with OpenCV's panorama ``Stitcher``, crop
    away the black border around the result, and write a PNG copy to a
    temporary file.

    Args:
        image_files: Sequence of image file paths (or file-like objects)
            accepted by ``PIL.Image.open``. May be None or empty when
            nothing has been uploaded.

    Returns:
        ``(stitched_image, png_path)``. ``stitched_image`` is a PIL RGB image
        and ``png_path`` is the temporary PNG file. Returns ``(None, None)``
        when fewer than two images are given or stitching fails.
    """
    # The UI passes None when no files have been uploaded.
    images = [_load_bgr(file) for file in (image_files or [])]
    if len(images) < 2:
        print("Need at least two images to stitch.")
        return None, None

    images_eq = equalize_exposure(images)

    stitcher = cv2.Stitcher_create(cv2.Stitcher_PANORAMA)
    stitcher.setPanoConfidenceThresh(0.8)
    stitcher.setWaveCorrection(False)
    status, stitched = stitcher.stitch(images_eq)
    if status != cv2.Stitcher_OK:
        print(f"Image stitching failed ({status})")
        return None, None

    cropped = _crop_to_content(stitched)
    stitched_image = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
    return stitched_image, _save_temp_png(stitched_image)
# Gradio Interface: upload images, press "Stitch", view and download the panorama.
with gr.Blocks() as interface:
    gr.Markdown("<h1 style='color: #2196F3; text-align: center;'>Image Stitcher 🧵</h1>")
    gr.Markdown("<h3 style='color: #2196F3; text-align: center;'>Upload the images you want to stitch</h3>")
    # Restrict uploads to images so non-image files never reach Image.open().
    image_upload = gr.Files(type="filepath", label="Upload Images", file_types=["image"])
    stitch_button = gr.Button("Stitch", variant="primary")
    stitched_image = gr.Image(type="pil", label="Stitched Image")
    # File component that serves the PNG path returned by stitch_images.
    download_button = gr.File(label="Download Stitched Image")
    stitch_button.click(stitch_images, inputs=image_upload, outputs=[stitched_image, download_button])

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    interface.launch()