CatDog / app.py
tshrusd
first commmit
a8fc39e
raw history blame
No virus
1.62 kB
import torch
import torch.nn as nn
from torchvision import transforms
import pytorch_lightning as pl
import torchvision.models as models
import gradio as gr
class Net(pl.LightningModule):
    """ResNet-50 trunk (final ``fc`` layer removed) plus a 2-way linear head.

    ``models.resnet50()`` is constructed here without any weights argument,
    so the network is randomly initialised; meaningful predictions require
    loading trained parameters (e.g. from a Lightning checkpoint). The
    attribute names ``feature_extractor`` and ``classifier`` determine the
    state-dict keys, so they must match whatever checkpoint is loaded.
    """

    def __init__(self):
        super().__init__()
        resnet = models.resnet50()
        # Width of the pooled features that the original fc layer consumed.
        feature_dim = resnet.fc.in_features
        # Keep every child module except the last one (the ImageNet fc layer).
        *trunk, _ = resnet.children()
        self.feature_extractor = nn.Sequential(*trunk)
        # Two output logits: one per target class.
        self.classifier = nn.Linear(feature_dim, 2)

    def forward(self, x):
        """Return class logits for the image batch ``x``.

        ``x`` is presumably an (N, C, H, W) image tensor -- TODO confirm
        against callers. On every call the trunk is forced into eval mode and
        run without autograd, so only ``classifier`` can receive gradients and
        the trunk's batch-norm statistics are never updated, even after
        ``.train()``.
        """
        self.feature_extractor.eval()
        with torch.no_grad():
            features = torch.flatten(self.feature_extractor(x), 1)
        return self.classifier(features)
# Build the model directly from the checkpoint. ``load_from_checkpoint`` is a
# classmethod: calling it on a throwaway ``Net()`` instance needlessly
# constructs an extra ResNet-50, and Lightning 2.1+ rejects instance calls
# with a TypeError. Weights are mapped onto the CPU.
net = Net.load_from_checkpoint(
    './resnet50_transfer.ckpt', map_location=torch.device('cpu'))
# Switch the whole module to eval mode (affects batch-norm in the trunk).
net.eval()
# Class names, in classifier output order: logit i is reported as labels[i].
labels = ["Cat", "Dog"]

# Preprocessing for an uploaded image: convert to a tensor, resize so the
# shorter side is 256 px, then take the central 224x224 crop.
# NOTE(review): no mean/std Normalize step is applied -- confirm this matches
# the preprocessing used when the checkpoint was trained.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
])
def inference(img):
    """Classify a single image and return per-class probabilities.

    Args:
        img: Image from the Gradio ``Image`` input (configured with
            ``type="numpy"``), or ``None`` when the user submits without
            uploading an image.

    Returns:
        A ``{label: probability}`` dict (softmax over the model's logits) for
        ``gr.Label``, or ``None`` when no image was provided so the label
        output stays empty instead of the app reporting an error.
    """
    if img is None:
        # val_transform cannot handle None; there is nothing to classify.
        return None
    # Preprocess and add a leading batch dimension of size 1.
    batch = val_transform(img).unsqueeze(0)
    with torch.no_grad():
        probs = torch.softmax(net(batch), dim=1)[0]
    # Pair each class name with its probability (logit i -> labels[i]).
    return {label: float(p) for label, p in zip(labels, probs.tolist())}
# Web UI: a numpy image goes in, class probabilities come out, with every
# class in ``labels`` shown.
iface = gr.Interface(
    fn=inference,
    inputs=gr.components.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=len(labels)),
)

# Start the web server only when executed as a script (e.g. ``python app.py``),
# so importing this module does not launch a server as a side effect.
if __name__ == "__main__":
    iface.launch()