"""Gradio demo serving a TorchScript image classifier."""

import json

import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms

TORCHSCRIPT_PATH = "res/screenclassification-resnet-noisystudent+web350k.torchscript"
LABELS_PATH = "res/class_map_enrico.json"
IMG_SIZE = 128

model = torch.jit.load(TORCHSCRIPT_PATH)
# Inference only: make sure layers such as dropout / batch-norm (if the
# scripted module contains any) run in evaluation mode.
model.eval()

with open(LABELS_PATH, "r", encoding="utf-8") as f:
    # Mapping of class label -> output index (indices are passed through
    # int() below, so they may be stored as strings or numbers).
    label2Idx = json.load(f)["label2Idx"]

img_transforms = transforms.Compose([
    # An int size rescales the shorter image side to IMG_SIZE and keeps the
    # aspect ratio, so the resulting tensor is not necessarily square.
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    # Maps each of the three channels from [0, 1] to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def predict(img):
    """Classify a PIL image and return per-label confidences.

    Args:
        img: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the user submitted without providing an image.

    Returns:
        A dict mapping every label in ``label2Idx`` to its softmax
        probability as a Python float, or ``None`` when ``img`` is ``None``
        (presumably rendered by ``gr.Label`` as an empty result -- confirm
        against the installed Gradio version).
    """
    if img is None:
        return None

    # Normalize() above is configured for exactly three channels; convert
    # defensively in case a grayscale / RGBA / palette image reaches us.
    if img.mode != "RGB":
        img = img.convert("RGB")

    img_input = img_transforms(img).unsqueeze(0)  # add a batch dimension

    # No gradients are needed for inference; skipping autograd bookkeeping
    # saves memory and time on every request.
    with torch.no_grad():
        predictions = F.softmax(model(img_input), dim=-1)[0]

    return {
        label: float(predictions[int(idx)])
        for label, idx in label2Idx.items()
    }


example_imgs = [
    "res/example.jpg",
    "res/screenlane-snapchat-profile.jpg",
    "res/screenlane-snapchat-settings.jpg",
    "res/example_pair1.jpg",
    "res/example_pair2.jpg",
]

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    examples=example_imgs,
)

interface.launch()