will33am's picture
update
01b0a68
raw
history blame
2.14 kB
#!/usr/bin/env python
# coding: utf-8
# In[ ]:
import warnings

import albumentations as A
import gradio as gr
import torch
from albumentations.pytorch.transforms import ToTensorV2
from timm import create_model
# In[ ]:
class TestDataset(torch.utils.data.Dataset):
    """Dataset wrapping a single image so it can be fed through a DataLoader.

    Args:
        image: The image to serve; it is stored as the only list element.
        transforms: Optional callable invoked as ``transforms(image=...)`` and
            expected to return a mapping with an ``"image"`` entry. It is
            skipped when falsy.
    """

    def __init__(self, image, transforms=None):
        self.transforms = transforms
        self.image = [image]

    def __len__(self):
        """Return the number of stored images."""
        return len(self.image)

    def __getitem__(self, idx):
        """Return ``{'image': ...}`` for position ``idx``, transformed if configured."""
        sample = self.image[idx]
        if not self.transforms:
            # No pipeline configured: hand the raw image back unchanged.
            return {'image': sample}
        return {'image': self.transforms(image=sample)["image"]}
def get_test_transform():
    """Build the inference-time albumentations pipeline.

    Returns:
        An ``A.Compose`` that applies ``A.Normalize`` with the fixed mean/std
        constants below, followed by ``ToTensorV2``.
    """
    mean = [0.5176, 0.4169, 0.3637]
    std = [0.3010, 0.2723, 0.2672]
    normalize = A.Normalize(mean, std)
    to_tensor = ToTensorV2(transpose_mask=False, p=1.0)
    # NOTE: no resize step is applied here; the image is used at whatever
    # size the caller supplies.
    return A.Compose([normalize, to_tensor])
# In[ ]:
def predict_image(image):
    """Score a single image as upright or upside-down.

    Args:
        image: The image handed over by the Gradio input component
            (presumably an HxWx3 array -- confirm against the interface).

    Returns:
        A dict with float values under the keys ``'Down'`` and ``'Upside'``.
        ``'Upside'`` is the sigmoid of the model's single output and
        ``'Down'`` is its complement.
    """
    test_dataset = TestDataset(image, transforms=get_test_transform())
    # num_workers=0: the dataset holds exactly one image, so worker processes
    # add start-up cost on every request without any parallelism to gain.
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        pin_memory=False,
        num_workers=0,
        shuffle=False,
    )

    # Inference mode for layers that behave differently during training
    # (e.g. normalization/dropout); without this a batch of one is scored
    # with training-time behaviour.
    model.eval()

    # The loader yields exactly one batch for the single wrapped image.
    data = next(iter(test_loader))
    # Move every tensor in the batch to the CPU, where the model lives.
    data = {key: value.to('cpu') for key, value in data.items()}

    # No gradients are needed for prediction; skip building the autograd graph.
    with torch.no_grad():
        output = torch.sigmoid(model(data['image'])).cpu().numpy()

    return {'Down': float(1 - output[0][0]), 'Upside': float(output[0][0])}
# In[ ]:
# Build the network with a single output and restore the trained weights on CPU.
model = create_model('resnet18', pretrained=False, num_classes=1)
# NOTE(security): torch.load unpickles the file, so 'model.pt' must come from
# a trusted source.
checkpoint = torch.load('model.pt', map_location='cpu')
# strict=False tolerates key mismatches; report them so a checkpoint that does
# not fit the architecture is not silently served with untrained weights.
_load_result = model.load_state_dict(checkpoint, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "model.pt does not match the model exactly: missing={}, unexpected={}".format(
            _load_result.missing_keys, _load_result.unexpected_keys
        )
    )
# The model is only used for prediction, so switch it to inference mode once.
model.eval()
# In[ ]:
# Gradio UI configuration.
title = "Upside-Down Detector"
interpretation = 'default'
enable_queue = True

# Build the web UI around predict_image and start serving it.
# enable_queue was previously assigned but never used; it is now passed to
# launch() so the setting actually takes effect.
gr.Interface(
    fn=predict_image,
    inputs=gr.inputs.Image(shape=(256, 256)),
    outputs=gr.outputs.Label(num_top_classes=2),
    title=title,
    interpretation=interpretation,
).launch(share=True, enable_queue=enable_queue)