File size: 2,335 Bytes
1cbfd2b
 
b578661
626bf47
 
15ac7f4
50ba528
1cbfd2b
b368542
be91971
 
 
b368542
 
 
be91971
 
 
 
 
 
 
 
bc3919a
b368542
 
 
be91971
 
 
 
bc3919a
b368542
 
be91971
b368542
 
bc3919a
be91971
 
bc3919a
b368542
 
be91971
b368542
 
 
 
 
 
20d1027
626bf47
15ac7f4
 
 
626bf47
1cbfd2b
 
50ba528
626bf47
 
50ba528
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
import kornia as K
from kornia.core import Tensor
from kornia.contrib import ImageStitcher
import kornia.feature as KF
import torch
import numpy as np

def preprocess_image(img):
    """Convert an input image into a float ``(1, 3, H, W)`` tensor.

    Accepted inputs:

    * ``numpy.ndarray`` shaped ``(H, W)`` or ``(H, W, C)`` (channel-last).
      Values are always divided by 255.
    * ``torch.Tensor`` shaped ``(H, W)``, ``(C, H, W)``, ``(H, W, C)`` or
      ``(1, C, H, W)``. Values are divided by 255 only when the maximum
      exceeds 1.0. A 3D tensor is treated as channel-first when its first
      dimension is 1 or 3, and as channel-last otherwise.

    Single-channel images are replicated to three channels, and channels
    beyond the third are dropped.

    Args:
        img: Image as a numpy array or a torch tensor.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, H, W)``.

    Raises:
        ValueError: If ``img`` is neither a numpy array nor a tensor, or if
            its shape cannot be interpreted as a single 1-, 3- or
            more-than-3-channel image.
    """
    print(f"Input image type: {type(img)}")
    print(f"Input image shape: {img.shape if hasattr(img, 'shape') else 'No shape attribute'}")

    if isinstance(img, np.ndarray):
        if img.ndim not in (2, 3):
            raise ValueError(f"Unsupported image shape: {img.shape}")
        # Build the tensor directly so the result stays 2D/3D; a leading
        # batch axis here would be mistaken for the channel axis below.
        img = torch.from_numpy(np.ascontiguousarray(img, dtype=np.float32)) / 255.0
        if img.ndim == 3:
            # numpy images are channel-last: (H, W, C) -> (C, H, W).
            img = img.permute(2, 0, 1)
    elif isinstance(img, torch.Tensor):
        img = img.float()
        # Heuristic: values above 1.0 mean the data is on a 0-255 scale.
        if img.numel() > 0 and img.max() > 1.0:
            img = img / 255.0
        # Accept an already-batched single image.
        if img.ndim == 4 and img.shape[0] == 1:
            img = img.squeeze(0)
        # A leading dimension that is not 1 or 3 is taken to be the height
        # of a channel-last image.
        if img.ndim == 3 and img.shape[0] not in (1, 3):
            img = img.permute(2, 0, 1)
    else:
        raise ValueError(f"Unsupported image type: {type(img)}")

    print(f"After conversion to tensor - shape: {img.shape}")

    # Ensure 3D tensor (C, H, W)
    if img.ndim == 2:
        img = img.unsqueeze(0)
    elif img.ndim != 3:
        raise ValueError(f"Unsupported image shape: {tuple(img.shape)}")

    print(f"After ensuring 3D - shape: {img.shape}")

    # Ensure 3 channel image
    channels = img.shape[0]
    if channels == 1:
        img = img.expand(3, -1, -1)
    elif channels > 3:
        img = img[:3]  # Take only the first 3 channels if more than 3
    elif channels != 3:
        raise ValueError(f"Unsupported number of channels: {channels}")

    print(f"After ensuring 3 channels - shape: {img.shape}")

    # Add batch dimension
    img = img.unsqueeze(0)

    print(f"Final tensor shape: {img.shape}")
    return img

# Lazily-built stitcher shared by all requests; see _get_stitcher().
_STITCHER = None


def _get_stitcher():
    """Return the shared LoFTR-based ``ImageStitcher``, building it on first use."""
    global _STITCHER
    if _STITCHER is None:
        # Constructing LoFTR with pretrained weights is request-invariant,
        # so do it once instead of on every button click.
        _STITCHER = ImageStitcher(KF.LoFTR(pretrained='outdoor'), estimator='ransac')
    return _STITCHER


def inference(img_1, img_2):
    """Stitch two images together and return the result as an image array.

    Both inputs are normalised with ``preprocess_image`` and passed to an
    ``ImageStitcher`` that uses LoFTR (``'outdoor'`` weights) with the
    ``'ransac'`` estimator. Gradients are disabled during the forward pass.

    Args:
        img_1: First input image (numpy array or torch tensor).
        img_2: Second input image (numpy array or torch tensor).

    Returns:
        The first element of the stitcher output, converted with
        ``K.tensor_to_image``.

    Raises:
        ValueError: If either image is missing or cannot be preprocessed.
    """
    if img_1 is None or img_2 is None:
        raise ValueError("Two input images are required for stitching.")

    # Preprocess images
    img_1 = preprocess_image(img_1)
    img_2 = preprocess_image(img_2)

    stitcher = _get_stitcher()
    with torch.no_grad():
        result = stitcher(img_1, img_2)

    return K.tensor_to_image(result[0])

# Example image pairs offered in the UI; each row supplies both inputs.
examples = [['examples/foto1B.jpg', 'examples/foto1A.jpg']]

with gr.Blocks(theme='huggingface') as demo_app:
    gr.Markdown("# Image Stitching using Kornia and LoFTR")

    # The two source images sit side by side in one row.
    with gr.Row():
        input_image1 = gr.Image(label="Input Image 1")
        input_image2 = gr.Image(label="Input Image 2")

    output_image = gr.Image(label="Output Image")
    stitch_button = gr.Button("Stitch Images")

    # The same pair of components feeds both the button and the examples;
    # each consumer receives its own list.
    stitch_inputs = (input_image1, input_image2)
    stitch_button.click(
        fn=inference,
        inputs=list(stitch_inputs),
        outputs=output_image,
    )
    gr.Examples(examples=examples, inputs=list(stitch_inputs))

if __name__ == "__main__":
    # Start the web UI only when run as a script.
    demo_app.launch()