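"""Gradio demo: image stitching with Kornia's ImageStitcher and LoFTR.

Two overlapping input images are matched with the pretrained LoFTR
'outdoor' model, a homography is estimated with RANSAC, and the stitched
result is shown as the output image.
"""
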
import gradio as gr
import kornia as K
from kornia.contrib import ImageStitcher
import kornia.feature as KF
import torch
import numpy as np

def preprocess_image(img):
    print(f"Input image type: {type(img)}")
    print(f"Input image shape: {img.shape if hasattr(img, 'shape') else 'No shape attribute'}")
    
    # Convert numpy array to Tensor and ensure correct shape
    if isinstance(img, np.ndarray):
        img = K.image_to_tensor(img, keepdim=False).float() / 255.0
    elif isinstance(img, torch.Tensor):
        img = img.float()
        if img.max() > 1.0:
            img = img / 255.0
    else:
        raise ValueError(f"Unsupported image type: {type(img)}")
    
    print(f"After conversion to tensor - shape: {img.shape}")
    
    # Ensure a 4D tensor (B, C, H, W)
    if img.ndim == 2:
        # Grayscale (H, W) -> (1, 1, H, W)
        img = img.unsqueeze(0).unsqueeze(0)
    elif img.ndim == 3:
        if img.shape[0] in [1, 3]:
            # (C, H, W) -> (1, C, H, W)
            img = img.unsqueeze(0)
        else:
            # Assume channels-last (H, W, C) -> (1, C, H, W)
            img = img.permute(2, 0, 1).unsqueeze(0)
    elif img.ndim == 4:
        if img.shape[1] not in [1, 3]:
            # Assume channels-last (B, H, W, C) -> (B, C, H, W)
            img = img.permute(0, 3, 1, 2)
    
    print(f"After ensuring 4D - shape: {img.shape}")
    
    # Ensure 3 channel image
    if img.shape[1] == 1:
        img = img.repeat(1, 3, 1, 1)
    elif img.shape[1] > 3:
        img = img[:, :3]  # Take only the first 3 channels if more than 3
    
    print(f"Final tensor shape: {img.shape}")
    return img

def inference(img_1, img_2):
    # Convert both inputs to normalized (B, C, H, W) float tensors
    img_1 = preprocess_image(img_1)
    img_2 = preprocess_image(img_2)

    # LoFTR (pretrained 'outdoor' weights) matches features between the two
    # images; RANSAC estimates the homography used to warp and blend them.
    IS = ImageStitcher(KF.LoFTR(pretrained='outdoor'), estimator='ransac')
    with torch.no_grad():
        result = IS(img_1, img_2)

    # Convert the stitched tensor back to an HWC numpy image for display
    return K.tensor_to_image(result[0])

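# Example image pair used to prefill the inputs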
examples = [
    ['examples/foto1B.jpg', 'examples/foto1A.jpg'],
]

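# Minimal Blocks UI: two image inputs, a stitch button, and one output image.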
with gr.Blocks(theme='huggingface') as demo_app:
    gr.Markdown("# Image Stitching using Kornia and LoFTR")
    with gr.Row():
        input_image1 = gr.Image(label="Input Image 1")
        input_image2 = gr.Image(label="Input Image 2")
    output_image = gr.Image(label="Output Image")
    stitch_button = gr.Button("Stitch Images")
    stitch_button.click(fn=inference, inputs=[input_image1, input_image2], outputs=output_image)
    gr.Examples(examples=examples, inputs=[input_image1, input_image2])

if __name__ == "__main__":
    demo_app.launch(share=True)