"""Gradio demo for LETR: Line Segment Detection Using Transformers.

Loads the 'resnet50' and 'resnet101' LETR checkpoints at import time and
serves a web UI that runs the selected model on an uploaded image, using the
selected resize setting, returning the image and the detected line list.
"""
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
import torch.nn.functional as F
import gradio as gr

from test import create_letr, get_lines_and_draw
from models.preprocessing import *
from models.misc import nested_tensor_from_tensor_list

model = create_letr('resnet50/checkpoint0024.pth')
model101 = create_letr('resnet101/checkpoint0024.pth')

# Per-channel normalization statistics shared by every preprocessing pipeline.
NORM_MEAN = (0.538, 0.494, 0.453)
NORM_STD = (0.257, 0.263, 0.273)


def _make_transform(size):
    """Build the ToTensor -> Normalize -> Resize([size]) preprocessing pipeline."""
    # Fresh lists per pipeline, matching the original per-pipeline literals.
    return Compose([
        ToTensor(),
        Normalize(list(NORM_MEAN), list(NORM_STD)),
        Resize([size]),
    ])


normalize = _make_transform(256)
normalize_512 = _make_transform(512)
normalize_1100 = _make_transform(1100)

# Maps the "size" radio value to its pipeline; any other value (including an
# unselected radio) falls back to the 256 pipeline.
_TRANSFORMS = {
    '256': normalize,
    '512': normalize_512,
    '1100': normalize_1100,
}


def predict(inp, size, model_name):
    """Run LETR on one image and return the image plus the detected lines.

    Args:
        inp: Image array supplied by the Gradio image input; it is cast to
            uint8 and interpreted as RGB (presumably H x W x 3 -- confirm
            against the gradio version in use).
        size: Resize setting from the radio ('256', '512' or '1100'); any
            other value uses the 256 pipeline.
        model_name: 'resnet101' selects the ResNet-101 checkpoint; any other
            value selects the ResNet-50 checkpoint.

    Returns:
        A tuple ``(image, lines_text)``: the PIL image that was handed to
        ``get_lines_and_draw`` (which presumably draws the lines onto it in
        place -- confirm) and ``str()`` of the lines it returned.
    """
    image = Image.fromarray(inp.astype('uint8'), 'RGB')
    # Original (height, width), handed to the post-processing step.
    orig_size = torch.as_tensor([int(image.height), int(image.width)])

    img = _TRANSFORMS.get(size, normalize)(image)
    batch = nested_tensor_from_tensor_list([img])

    net = model101 if model_name == 'resnet101' else model
    with torch.no_grad():
        outputs = net(batch)[0]

    lines = get_lines_and_draw(image, outputs, orig_size)
    return image, str(lines)


# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio component API;
# they are removed in newer Gradio releases.
inputs = [
    gr.inputs.Image(),
    gr.inputs.Radio(["256", "512", "1100"]),
    gr.inputs.Radio(["resnet50", "resnet101"]),
]
outputs = [
    gr.outputs.Image(label='Image with Lines', type='numpy'),
    gr.outputs.Textbox(label='Lines points List'),
]

gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=[
        ["demo.png", '256', "resnet50"],
        ["tappeto-per-calibrazione.jpg", '256', "resnet50"],
    ],
    title="LETR: Line Segment Detection Using Transformers without Edges",
    description="It is an end-to-end line segment detection algorithm using Transformers [published on CVPR 2021](https://github.com/mlpc-ucsd/LETR).",
).launch()