#!/usr/bin/env python
# coding: utf-8
"""Gradio demo: predict whether an image is upside-down.

A single-logit ``resnet18`` (built with ``timm``) is restored from
``model.pt`` and served through a Gradio ``Interface``.  The sigmoid of the
logit is reported as the score of the ``'Upside'`` label and its complement
as the score of the ``'Down'`` label.
"""

import warnings

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from timm import create_model
import torch
import gradio as gr


class TestDataset(torch.utils.data.Dataset):
    """Dataset wrapping exactly one in-memory image.

    Args:
        image: The image to serve.  Presumably an HxWxC numpy array as
            delivered by the Gradio image input -- TODO confirm.
        transforms: Optional albumentations-style callable invoked as
            ``transforms(image=image)`` and returning a mapping with an
            ``"image"`` entry.
    """

    def __init__(self, image, transforms=None):
        # Stored as a one-element list so the usual Dataset indexing works.
        self.image = [image]
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return ``{'image': <image>}`` for ``idx``, transformed if configured."""
        image = self.image[idx]
        if self.transforms:
            augmented = self.transforms(image=image)
            image = augmented["image"]
        return {'image': image}

    def __len__(self):
        """Return the number of wrapped images."""
        return len(self.image)


def get_test_transform():
    """Build the inference-time preprocessing pipeline.

    Returns:
        An ``A.Compose`` that normalises the image with fixed per-channel
        mean/std and converts it to a tensor via ``ToTensorV2``.
    """
    # NOTE(review): presumably per-channel statistics of the training data --
    # confirm against the training pipeline.
    MEAN = [0.5176, 0.4169, 0.3637]
    STD = [0.3010, 0.2723, 0.2672]
    # No resize step here: the Gradio image input below is declared with
    # shape=(256, 256).
    return A.Compose([
        A.Normalize(mean=MEAN, std=STD),
        ToTensorV2(transpose_mask=False, p=1.0),
    ])


def predict_image(image):
    """Score a single image as upside-down or not.

    Args:
        image: The image supplied by the Gradio image input.

    Returns:
        ``{'Down': 1 - p, 'Upside': p}`` where ``p`` is the sigmoid of the
        model's single output logit.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        # Fail with a clear message instead of an obscure error from inside
        # the preprocessing pipeline.
        raise ValueError("No image provided.")

    test_dataset = TestDataset(image, transforms=get_test_transform())
    # num_workers=0: the dataset holds a single image, so starting worker
    # processes for every request would only add overhead.
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        pin_memory=False,
        num_workers=0,
        shuffle=False,
    )

    # The dataset always contains exactly one item, hence exactly one batch.
    batch = next(iter(test_loader))
    images = batch['image'].to('cpu')

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(images)

    upside = torch.sigmoid(logits)[0, 0].item()
    return {'Down': 1.0 - upside, 'Upside': upside}


model = create_model('resnet18', pretrained=False, num_classes=1)
checkpoint = torch.load('model.pt', map_location='cpu')
# strict=False is kept so a checkpoint with extra/missing entries still loads,
# but mismatches are surfaced instead of being silently ignored.
_load_result = model.load_state_dict(checkpoint, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "model.pt does not match the model exactly: "
        f"missing keys={list(_load_result.missing_keys)}, "
        f"unexpected keys={list(_load_result.unexpected_keys)}"
    )
# Switch BatchNorm/Dropout layers to inference behaviour; without this the
# network normalises with per-request batch statistics and keeps updating
# its running statistics while serving.
model.eval()


title = "Upside-Down Detector"
interpretation = 'default'
# NOTE(review): assigned but never passed to Interface()/launch(); left as is
# to avoid changing serving behaviour -- confirm whether queueing was intended.
enable_queue = True

gr.Interface(
    fn=predict_image,
    inputs=gr.inputs.Image(shape=(256, 256)),
    outputs=gr.outputs.Label(num_top_classes=2),
    title=title,
    interpretation=interpretation,
).launch()