# brightertiger's picture
# Upload app.py
# be9f2dd
import cv2
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import gradio as gr
import albumentations as A
from torchvision import transforms
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
class Model(nn.Module):
    """U-Net segmentation network with a ResNet-34 encoder.

    Thin wrapper around ``smp.Unet`` configured for 3-channel input, a single
    output class and ``'identity'`` activation, with the encoder created
    using the ``'imagenet'`` weights option.
    """

    def __init__(self):
        """Build the wrapped ``smp.Unet`` with the fixed configuration."""
        super().__init__()
        self.model = smp.Unet(
            encoder_name='resnet34',
            encoder_weights='imagenet',
            in_channels=3,
            classes=1,
            activation='identity',
        )

    def forward(self, image):
        """Run the wrapped network on ``image`` and squeeze its output.

        ``squeeze()`` is called without a dimension, so every size-1 axis of
        the network output is removed. NOTE(review): for a batch of one this
        also drops the batch axis -- callers appear to rely on that; confirm
        before narrowing it to the channel axis only.
        """
        return self.model(image).squeeze()
# Load the checkpoint on CPU and keep only its 'model_state_dict' entry, so
# the rest of the checkpoint is not kept alive by a module-level name.
weights = torch.load('model.pt', map_location='cpu')['model_state_dict']

# Build the network, restore the saved parameters and switch it to
# evaluation mode once at import time; `score` reuses this instance.
MODEL = Model()
MODEL.load_state_dict(weights)
MODEL.eval()  # eval() returns the module itself, so no rebinding is needed
def process(image):
    """Pad an image to at least 736x736 and turn it into a normalized tensor.

    Args:
        image: Image array accepted by the albumentations pipeline
            (presumably an H x W x 3 uint8 array from the UI -- confirm
            against the caller).

    Returns:
        A ``(orig, image)`` tuple: a copy of the padded array, and the padded
        image after ``ToTensor`` and ``Normalize``.
    """
    # 736 is the next multiple of 32 above the 720 px size requested in the
    # UI -- presumably chosen to suit the U-Net's downsampling; confirm.
    # NOTE(review): depending on the albumentations version, `value` may only
    # be honoured when the border mode is constant; verify that the padding
    # actually used here matches what the model was trained with.
    pad = A.Compose([A.PadIfNeeded(min_height=736, min_width=736, value=255)])
    padded = pad(image=image)['image']

    # Keep an untouched copy of the padded pixels for the caller.
    orig = padded.copy()

    # Per-channel mean / std passed to Normalize.
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return orig, to_tensor(padded)
def score(image):
    """Return the model's sigmoid-activated prediction for ``image``.

    The image is preprocessed with ``process``, passed through the
    module-level ``MODEL`` as a batch of one, and the raw output is mapped
    to the (0, 1) range with a sigmoid.

    Args:
        image: Input image as accepted by ``process``.

    Returns:
        A NumPy array holding the sigmoid of the squeezed model output
        (one value per pixel of the padded image, assuming the model keeps
        the spatial size -- confirm).
    """
    # Only the normalized tensor is needed; the padded copy is discarded.
    _, tensor = process(image)

    # Inference only: disabling autograd avoids building a graph (less memory
    # and time) and removes the need for the legacy `.data` accessor.
    with torch.no_grad():
        logits = MODEL(tensor.unsqueeze(0)).squeeze()
        # torch.sigmoid is numerically stable, unlike 1 / (1 + exp(-x)),
        # whose exp() overflows for large-magnitude negative logits.
        pred = torch.sigmoid(logits)
    return pred.numpy()
if __name__ == '__main__':
    # Wire `score` into a Gradio interface: one image input (requested at
    # 720x720), one image output, plus a bundled example file.
    demo = gr.Interface(
        fn=score,
        inputs=gr.Image(shape=(720, 720)),
        outputs="image",
        title="BDD Lane Detection Demo",
        examples=["sample.jpg"],
    )
    demo.launch()