| import torch | |
| import gradio as gr | |
| import json | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
TORCHSCRIPT_PATH = "res/screenclassification-resnet-noisystudent+web350k.torchscript"
LABELS_PATH = "res/class_map_enrico.json"
IMG_SIZE = 128

# Load onto the CPU. predict() feeds CPU tensors, and mapping to the CPU also
# lets a module that was serialized from a GPU load on a CPU-only host.
model = torch.jit.load(TORCHSCRIPT_PATH, map_location="cpu")
# This script only runs inference. Switch to evaluation mode so mode-dependent
# layers (e.g. dropout / batch-norm, if the network has any) do not run in
# whatever mode the module happened to be saved in.
model.eval()

with open(LABELS_PATH, "r", encoding="utf-8") as f:
    # Maps label name -> index into the model's output vector. The values are
    # cast with int() at lookup time, so they may be stored as strings.
    label2Idx = json.load(f)["label2Idx"]

# Preprocessing pipeline:
#   1. resize so the shorter side equals IMG_SIZE (aspect ratio is preserved),
#   2. convert to a float tensor in [0, 1],
#   3. shift/scale each of the three channels to [-1, 1].
img_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
def predict(img):
    """Classify a UI screenshot and return a probability for every label.

    Args:
        img: A PIL image (the Gradio input is created with ``type="pil"``),
            or ``None`` when the user submitted nothing.

    Returns:
        A dict mapping each label name in ``label2Idx`` to its softmax
        probability as a Python float, or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        return None

    # Normalize() in img_transforms carries exactly three per-channel
    # statistics, so force 3-channel RGB. RGBA (common for PNG screenshots)
    # or single-channel images would otherwise fail to normalize.
    img_input = img_transforms(img.convert("RGB")).unsqueeze(0)  # add batch dim

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        predictions = F.softmax(model(img_input), dim=-1)[0]

    return {
        label: float(predictions[int(idx)])
        for label, idx in label2Idx.items()
    }
# Sample screenshots offered as one-click inputs in the demo UI.
example_imgs = [
    "res/example.jpg",
    "res/screenlane-snapchat-profile.jpg",
    "res/screenlane-snapchat-settings.jpg",
    "res/example_pair1.jpg",
    "res/example_pair2.jpg",
]

# Input: an uploaded image handed to predict() as a PIL image.
image_input = gr.Image(type="pil")
# Output: the five highest-scoring labels from predict()'s probability dict.
label_output = gr.Label(num_top_classes=5)

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    examples=example_imgs,
)
interface.launch()