Image-Stitching / app.py
mischeiwiller's picture
Update app.py
be91971 verified
raw
history blame
2.34 kB
import gradio as gr
import kornia as K
from kornia.core import Tensor
from kornia.contrib import ImageStitcher
import kornia.feature as KF
import torch
import numpy as np
def preprocess_image(img):
    """Convert an image into a batched 3-channel float tensor ``(1, 3, H, W)``.

    Args:
        img: A ``numpy.ndarray`` in ``(H, W)`` or ``(H, W, C)`` layout, or a
            ``torch.Tensor`` in ``(H, W)``, ``(C, H, W)`` or ``(H, W, C)``
            layout. A leading batch axis of size 1 is tolerated and dropped.

    Returns:
        A float tensor of shape ``(1, 3, H, W)``. NumPy inputs are always
        divided by 255; tensor inputs are divided by 255 only when their
        maximum value exceeds 1.0.

    Raises:
        ValueError: If ``img`` is neither an array nor a tensor, is empty, or
            does not reduce to a 2D or 3D image.
    """
    print(f"Input image type: {type(img)}")
    print(f"Input image shape: {img.shape if hasattr(img, 'shape') else 'No shape attribute'}")

    # Convert to a float tensor and remember whether the layout is already
    # known to be channels-first.
    if isinstance(img, np.ndarray):
        if img.size == 0:
            raise ValueError("Input image is empty")
        # keepdim=True keeps the result unbatched (C, H, W). With keepdim=False
        # a batch axis is prepended, which the channel checks below would
        # mistake for a single-channel image.
        img = K.image_to_tensor(img, keepdim=True).float() / 255.0
        channels_first = True
    elif isinstance(img, torch.Tensor):
        if img.numel() == 0:
            raise ValueError("Input image is empty")
        img = img.float()
        if img.max() > 1.0:
            img = img / 255.0
        channels_first = False
    else:
        raise ValueError(f"Unsupported image type: {type(img)}")
    print(f"After conversion to tensor - shape: {img.shape}")

    # Drop a singleton batch axis so the remaining logic only sees one image.
    if img.ndim == 4 and img.shape[0] == 1:
        img = img[0]

    # Ensure 3D tensor (C, H, W)
    if img.ndim == 2:
        img = img.unsqueeze(0)
    elif img.ndim == 3:
        # Only tensors of unknown layout are guessed at: a first axis that is
        # not 1 or 3 is taken to mean channels-last (H, W, C).
        if not channels_first and img.shape[0] not in (1, 3):
            img = img.permute(2, 0, 1)
    else:
        raise ValueError(f"Expected a 2D or 3D image, got shape {tuple(img.shape)}")
    print(f"After ensuring 3D - shape: {img.shape}")

    # Ensure 3 channel image
    if img.shape[0] == 1:
        img = img.expand(3, -1, -1)
    elif img.shape[0] > 3:
        img = img[:3]  # Take only the first 3 channels if more than 3
    print(f"After ensuring 3 channels - shape: {img.shape}")

    # Add batch dimension
    img = img.unsqueeze(0)
    print(f"Final tensor shape: {img.shape}")
    return img
# Stitcher shared by all requests; built on first use by _get_stitcher().
_STITCHER = None


def _get_stitcher():
    """Return the shared ImageStitcher, constructing it on the first call."""
    global _STITCHER
    if _STITCHER is None:
        # Building the LoFTR matcher loads pretrained weights, so do it once
        # instead of on every request.
        _STITCHER = ImageStitcher(KF.LoFTR(pretrained='outdoor'), estimator='ransac')
    return _STITCHER


def inference(img_1, img_2):
    """Stitch two images into one using LoFTR matching and RANSAC.

    Args:
        img_1: First input image, in any form accepted by ``preprocess_image``.
        img_2: Second input image, in any form accepted by ``preprocess_image``.

    Returns:
        The first item of the stitcher output, converted with
        ``K.tensor_to_image``.

    Raises:
        ValueError: Propagated from ``preprocess_image`` for unsupported input.
    """
    # Preprocess images
    img_1 = preprocess_image(img_1)
    img_2 = preprocess_image(img_2)

    stitcher = _get_stitcher()
    # Inference only: no gradients are needed.
    with torch.no_grad():
        result = stitcher(img_1, img_2)
    return K.tensor_to_image(result[0])
# Example inputs offered in the UI: each entry is one pair of image paths.
examples = [["examples/foto1B.jpg", "examples/foto1A.jpg"]]
# Assemble the Gradio interface: two image inputs, one image output, and a
# button that runs `inference` on the inputs.
# NOTE(review): the original indentation was lost; this assumes only the two
# input images sit inside the gr.Row() -- confirm against the deployed layout.
with gr.Blocks(theme="huggingface") as demo_app:
    gr.Markdown("# Image Stitching using Kornia and LoFTR")

    with gr.Row():
        input_image1 = gr.Image(label="Input Image 1")
        input_image2 = gr.Image(label="Input Image 2")

    output_image = gr.Image(label="Output Image")
    stitch_button = gr.Button("Stitch Images")

    # The same two components feed both the button callback and the examples.
    stitch_inputs = [input_image1, input_image2]
    stitch_button.click(fn=inference, inputs=stitch_inputs, outputs=output_image)
    gr.Examples(examples=examples, inputs=stitch_inputs)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo_app.launch()