"""Gradio demo for LETR line detection.

Importing / running this module loads the ResNet-50 and ResNet-101 LETR
checkpoints and immediately launches a Gradio interface in which the user
uploads an image, picks an inference resolution and picks a backbone.
"""
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
import torch.nn.functional as F
import gradio as gr

from test import create_letr, draw_fig
from models.preprocessing import *
from models.misc import nested_tensor_from_tensor_list

# Both checkpoints are loaded once, at import time, and shared by all requests.
model = create_letr('resnet50/checkpoint0024.pth')
model101 = create_letr('resnet101/checkpoint0024.pth')

# Per-channel mean / std passed to Normalize for every input pipeline.
NORM_MEAN = [0.538, 0.494, 0.453]
NORM_STD = [0.257, 0.263, 0.273]


def _make_transform(size):
    """Build the ToTensor -> Normalize -> Resize([size]) preprocessing pipeline.

    ``Compose``, ``ToTensor``, ``Normalize`` and ``Resize`` come from the star
    import of ``models.preprocessing`` (not from torchvision).
    """
    return Compose([
        ToTensor(),
        Normalize(NORM_MEAN, NORM_STD),
        Resize([size]),
    ])


# Module-level names kept for backward compatibility.
normalize = _make_transform(256)
normalize_512 = _make_transform(512)
normalize_1100 = _make_transform(1100)

# Radio choice -> pipeline / model. Unlisted values (including no selection,
# i.e. None) fall back to the 256 pipeline and the ResNet-50 model.
_TRANSFORMS = {'512': normalize_512, '1100': normalize_1100}
_MODELS = {'resnet101': model101}


def predict(inp, size, model_name):
    """Run LETR on one image and return the image passed to ``draw_fig``.

    Args:
        inp: Image array from the Gradio Image input; cast to uint8 and
            interpreted as RGB (assumes an (H, W, 3) array — TODO confirm
            for grayscale / RGBA uploads).
        size: '512' or '1100' select those input pipelines; any other value
            (including '256' or None) uses the 256 pipeline.
        model_name: 'resnet101' selects the ResNet-101 model; any other value
            uses the ResNet-50 model.

    Returns:
        The PIL image after ``draw_fig(image, preds, orig_size)`` has run on
        it (presumably with the predicted lines drawn in place — confirm
        against ``draw_fig``).
    """
    image = Image.fromarray(inp.astype('uint8'), 'RGB')
    # Original (height, width), used to map predictions back onto the image.
    orig_size = torch.as_tensor([image.height, image.width])

    transform = _TRANSFORMS.get(size, normalize)
    net = _MODELS.get(model_name, model)

    batch = nested_tensor_from_tensor_list([transform(image)])
    with torch.no_grad():
        preds = net(batch)[0]

    draw_fig(image, preds, orig_size)
    return image


inputs = [
    gr.inputs.Image(),
    gr.inputs.Radio(["256", "512", "1100"]),
    gr.inputs.Radio(["resnet50", "resnet101"]),
]
outputs = gr.outputs.Image()

gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=[
        ["demo.png", '256', "resnet50"],
        ["tappeto-per-calibrazione.jpg", '256', "resnet50"],
    ],
    title="LETR",
    description="Model for line detection...",
).launch()