File size: 1,726 Bytes
15216b5
68ba513
 
58b19a1
 
 
 
 
15216b5
58b19a1
 
 
 
 
 
 
 
 
 
b38da9b
58b19a1
 
f3dfeae
58b19a1
 
49684df
f3dfeae
f489399
58b19a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6524bf
58b19a1
15216b5
 
 
58b19a1
 
f3dfeae
 
 
034449b
f3dfeae
034449b
f3dfeae
58b19a1
 
f3dfeae
 
f6524bf
034449b
 
f3dfeae
58b19a1
f3dfeae
 
 
 
 
 
 
15216b5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
import monai
import torch
from monai.networks.nets import UNet
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np


# Architecture hyper-parameters. They must describe the same network the
# checkpoint was saved from; load_state_dict (strict by default) raises on
# missing/unexpected keys or shape mismatches otherwise.
_UNET_CONFIG = dict(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    dropout=0.15,
)
_CHECKPOINT_PATH = "best_model.pth"

model = UNet(**_UNET_CONFIG)
# Weights are mapped onto the CPU so the app runs on hosts without a GPU.
# NOTE(review): torch.load unpickles the file; if the installed torch supports
# it (>= 1.13), weights_only=True would restrict loading to plain tensors.
model.load_state_dict(torch.load(_CHECKPOINT_PATH, map_location="cpu"))
# Inference mode: deactivates the dropout=0.15 layers configured above.
model.eval()

def greet(image):
    """Segment an RGB image with the loaded UNet and return a grayscale mask.

    Args:
        image: The image delivered by ``gr.Image(type="numpy", image_mode="RGB")``,
            presumably an ``(H, W, 3)`` uint8 array, or ``None`` when the user
            submits without uploading anything.

    Returns:
        An ``(H, W)`` uint8 array at the input's resolution, where each pixel is
        the per-pixel foreground probability scaled to 0-255. Returns ``None``
        when no image was supplied, which Gradio renders as an empty output.
    """
    if image is None:
        # Gradio passes None for an empty input; answer with an empty output
        # instead of crashing on ``None / 255.0``.
        return None

    # Remember the original size so the mask can be mapped back onto it.
    height, width = image.shape[:2]

    # Scale pixel values to [0, 1] as float32, the same normalisation as before.
    image = (image / 255.0).astype(np.float32)

    inference_transforms = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),  # HWC numpy -> CHW torch tensor
    ])
    batch = inference_transforms(image=image)["image"].unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)  # presumably (1, 1, 512, 512): out_channels=1
        # Bring the prediction back to the input resolution so the mask lines
        # up with the uploaded image.
        logits = torch.nn.functional.interpolate(
            logits, size=(height, width), mode="bilinear", align_corners=False
        )
        # MONAI's UNet has no output activation, so the network emits logits.
        # Sigmoid maps them to [0, 1]; this assumes training used a
        # sigmoid-based loss (BCEWithLogits / Dice(sigmoid=True)). TODO confirm.
        probs = torch.sigmoid(logits)

    # Drop batch and channel axes -> (H, W), then convert to 8-bit grayscale,
    # which the "L"-mode output component can display.
    return (probs[0, 0].numpy() * 255.0).round().astype(np.uint8)
    

# Gradio UI: upload an RGB image and get the model's segmentation back as a
# single-channel ("L") grayscale image produced by greet().
demo = gr.Interface(
    fn=greet,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",
            # greet() expects a numpy array rather than a PIL image or file path.
            type="numpy",
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="L",
        )
    ],
)

demo.launch()