"""Gradio demo serving a ResNet-50 transfer-learning Cat/Dog classifier.

The trained weights are restored from a PyTorch Lightning checkpoint on CPU,
and the web interface is launched as soon as this script runs.
"""
import torch
import torch.nn as nn
from torchvision import transforms
import pytorch_lightning as pl
import torchvision.models as models
import gradio as gr


class Net(pl.LightningModule):
    """ResNet-50 feature extractor followed by a linear classification head."""

    def __init__(self):
        super().__init__()
        # Build the ResNet-50 architecture only (no weights are downloaded
        # here); the trained parameters are restored from the checkpoint.
        backbone = models.resnet50()
        num_filters = backbone.fc.in_features
        # Keep every child module except the final fully-connected layer.
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # Replace the original head with a 2-way head (Cat vs. Dog).
        num_target_classes = 2
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        """Return raw class logits for a batch of images.

        The backbone is treated as frozen: it is forced into eval mode and run
        without gradient tracking, so only ``classifier`` can receive
        gradients. ``x`` is presumably a float tensor of shape (N, 3, H, W)
        — TODO confirm against the training pipeline.
        """
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        return x


CHECKPOINT_PATH = './resnet50_transfer.ckpt'

# load_from_checkpoint is a classmethod that builds its own instance, so call
# it on the class: creating a throwaway Net first wastes a ResNet-50, and newer
# Lightning versions reject calling it on an instance altogether.
net = Net.load_from_checkpoint(CHECKPOINT_PATH, map_location=torch.device('cpu'))
net.eval()

# labels[i] names the class of output logit i.
labels = ['Cat', 'Dog']

# NOTE(review): no Normalize step is applied; this must match the
# preprocessing used at training time — confirm before changing.
val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
)


def inference(img):
    """Classify one image and return per-label confidences for ``gr.Label``.

    Args:
        img: Image from the Gradio ``Image`` input (``type="numpy"``), or
            ``None`` when the user submits without uploading an image.

    Returns:
        A dict mapping each entry of ``labels`` to its softmax probability,
        or ``None`` when no image was provided.
    """
    if img is None:
        # Gradio passes None for an empty input; ToTensor would raise on it.
        return None
    batch = val_transform(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        logits = net(batch)
    scores = torch.softmax(logits, dim=1)[0].tolist()
    return {label: float(score) for label, score in zip(labels, scores)}


iface = gr.Interface(
    fn=inference,
    inputs=gr.components.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=len(labels)),
)
iface.launch()