File size: 2,371 Bytes
d863531
 
 
 
 
 
 
 
 
 
b4d55e3
d863531
 
 
 
b6a4ee3
 
d863531
 
 
 
 
 
 
 
 
e708547
 
 
 
 
 
 
 
 
 
 
d863531
 
 
b6a4ee3
d863531
 
 
 
e708547
 
 
 
 
 
d863531
 
 
b6a4ee3
 
 
 
d863531
b4d55e3
d863531
b4d55e3
d863531
 
e708547
 
 
b6a4ee3
e708547
b4d55e3
770d74c
5440d10
b4d55e3
d863531
 
 
 
e708547
b6a4ee3
 
e708547
2eef120
 
d863531
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from PIL import Image, ImageDraw

import torch
from torchvision import transforms
import torch.nn.functional as F

import gradio as gr

# import sys
# sys.path.insert(0, './')
from test import create_letr, get_lines_and_draw
from models.preprocessing import *
from models.misc import nested_tensor_from_tensor_list


# Two LETR models built once at import time from the checkpoint paths below,
# one per backbone choice handled in `predict`.
model = create_letr('resnet50/checkpoint0024.pth')
model101 = create_letr('resnet101/checkpoint0024.pth')


def _build_preprocessing(target_size):
    """Compose tensor conversion, normalization and a resize to *target_size*.

    New transform objects and new constant lists are created on every call,
    so the returned pipelines share no state with each other.
    """
    steps = [
        ToTensor(),
        # presumably per-channel mean and std -- confirm against Normalize
        Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
        Resize([target_size]),
    ]
    return Compose(steps)


# PREPARE PREPROCESSING
# One pipeline per size handled in `predict`; `normalize` (256) is the one
# used when no other size matches.
normalize = _build_preprocessing(256)
normalize_512 = _build_preprocessing(512)
normalize_1100 = _build_preprocessing(1100)


def predict(inp, size, model_name):
    """Run a LETR model on one image and return the image plus its lines.

    Args:
        inp: Image as an array supporting ``astype``; it is cast to uint8 and
            interpreted as RGB.
        size: Preprocessing choice. ``'1100'`` and ``'512'`` select the
            matching pipelines; any other value uses ``normalize``.
        model_name: ``'resnet101'`` selects ``model101``; any other value
            uses ``model``.

    Returns:
        A tuple ``(image, text)`` where ``image`` is the PIL image that was
        handed to ``get_lines_and_draw`` and ``text`` is ``str()`` of the
        value that call returned.
    """
    image = Image.fromarray(inp.astype('uint8'), 'RGB')
    # Original (height, width), passed on to get_lines_and_draw.
    orig_size = torch.as_tensor([int(image.height), int(image.width)])

    # Pick the preprocessing pipeline; unrecognised sizes use the default.
    preprocess = (
        normalize_1100 if size == '1100'
        else normalize_512 if size == '512'
        else normalize
    )
    batch = nested_tensor_from_tensor_list([preprocess(image)])

    network = model101 if model_name == 'resnet101' else model
    with torch.no_grad():  # inference only: no gradient tracking needed
        prediction = network(batch)[0]

    # NOTE(review): the same `image` object is returned below, so
    # get_lines_and_draw presumably draws onto it in place -- confirm.
    detected = get_lines_and_draw(image, prediction, orig_size)

    return image, str(detected)


# Input widgets, listed in the order of predict's parameters:
# image, size choice, model choice.
inputs = [
    gr.inputs.Image(),
    gr.inputs.Radio(["256", "512", "1100"]),
    gr.inputs.Radio(["resnet50", "resnet101"]),
]

# Output widgets, one per value returned by predict.
outputs = [
    gr.outputs.Image(label='Image with Lines', type='numpy'),
    gr.outputs.Textbox(label='Lines points List'),
]

# Example rows: one value per input widget, in the same order as `inputs`.
_EXAMPLES = [
    ["demo.png", '256', "resnet50"],
    ["tappeto-per-calibrazione.jpg", '256', "resnet50"],
]
_TITLE = "LETR: Line Segment Detection Using Transformers without Edges"
_DESCRIPTION = "It is an end-to-end line segment detection algorithm using Transformers [published on CVPR 2021](https://github.com/mlpc-ucsd/LETR)."

# Build the interface and start serving it as soon as the module is executed.
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=_EXAMPLES,
    title=_TITLE,
    description=_DESCRIPTION,
)
demo.launch()