File size: 1,736 Bytes
be9f2dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import cv2
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import gradio as gr
import albumentations as A 
from torchvision import transforms
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt 

class Model(nn.Module):
    """Single-channel segmentation network: a U-Net with a ResNet-34 encoder.

    Thin ``nn.Module`` wrapper around ``smp.Unet`` configured for 3-channel
    input and one output channel of raw logits (``activation='identity'``).
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): encoder_weights='imagenet' presumably makes smp fetch
        # pretrained encoder weights while constructing the network; in this
        # script they are replaced right away by load_state_dict(), so that
        # fetch only adds startup cost / a network dependency.
        self.model = smp.Unet(
            encoder_name='resnet34',
            encoder_weights='imagenet',
            in_channels=3,
            classes=1,
            activation='identity',
        )

    def forward(self, image):
        """Run the U-Net and drop every size-1 axis from its output.

        Args:
            image: input batch tensor; presumably (N, 3, H, W) with H and W
                divisible by 32 as smp's U-Net expects -- TODO confirm.

        Returns:
            The logits with all singleton axes squeezed away. With a single
            output class and N == 1 this also removes the batch axis.
        """
        logits = self.model(image)
        return logits.squeeze()

# Restore the trained network onto the CPU and switch it to evaluation mode.
# NOTE(review): torch.load unpickles the checkpoint, so 'model.pt' must come
# from a trusted source.
weights = torch.load('model.pt', map_location='cpu')['model_state_dict']
MODEL = Model()
MODEL.load_state_dict(weights)
MODEL.eval()  # in-place; nn.Module.eval() returns the same module object

# Both preprocessing pipelines are fixed, so build them once at import time
# rather than re-creating them on every request.
_PAD_TRANSFORM = A.Compose([
    # Pad up to 736x736 (736 = 23 * 32; smp's U-Net needs sides divisible
    # by 32). PadIfNeeded never shrinks, so larger inputs pass through as-is.
    # NOTE(review): in albumentations 1.x the default border_mode is
    # cv2.BORDER_REFLECT_101, under which `value=255` has no effect; pass
    # border_mode=cv2.BORDER_CONSTANT if white padding was intended (and make
    # sure training used the same). albumentations 2.x renamed `value` to
    # `fill`.
    A.PadIfNeeded(min_height=736, min_width=736, value=255),
])
_TENSOR_TRANSFORM = transforms.Compose([
    # HWC image -> CHW float tensor (uint8 input is scaled to [0, 1]).
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean / std.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def process(image):
    """Pad an image and convert it into a normalized tensor for the model.

    Args:
        image: image array as delivered by the Gradio input; presumably an
            (H, W, 3) RGB uint8 array -- TODO confirm.

    Returns:
        A tuple ``(orig, tensor)``. ``orig`` is a copy of the padded image
        array. ``tensor`` is the padded image converted with ``ToTensor`` and
        normalized with the ImageNet mean/std.
    """
    padded = _PAD_TRANSFORM(image=image)['image']
    # Copy so callers get an array independent of the one fed to ToTensor.
    orig = padded.copy()
    return orig, _TENSOR_TRANSFORM(padded)

def score(image):
    """Predict a per-pixel lane probability map for one image.

    Args:
        image: the image array delivered by the Gradio ``gr.Image`` input, or
            ``None`` when the user submitted without an image.

    Returns:
        A float32 array of sigmoid probabilities in [0, 1] covering the padded
        image (every singleton axis squeezed away), or ``None`` if ``image``
        is ``None``.
    """
    if image is None:
        # Nothing uploaded: leave the output empty instead of crashing inside
        # the preprocessing pipeline.
        return None
    _, tensor = process(image)
    # Pure inference: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        logits = MODEL(tensor.unsqueeze(0)).squeeze()
        # torch.sigmoid is numerically stable, whereas 1 / (1 + np.exp(-x))
        # overflows (RuntimeWarning) for large negative logits.
        probs = torch.sigmoid(logits)
    return probs.numpy()

if __name__ == '__main__':
    # Gradio UI: an image input, an image output showing the probability map
    # returned by score(), and one bundled example file.
    # NOTE(review): gr.Image's `shape` argument exists in gradio 3.x but was
    # removed in gradio 4 -- pin gradio or update this call when upgrading.
    input_widget = gr.Image(shape=(720, 720))
    demo = gr.Interface(
        score,
        input_widget,
        "image",
        title="BDD Lane Detection Demo",
        examples=["sample.jpg"],
    )
    demo.launch()