File size: 4,669 Bytes
98b6309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import glob
import os
import cv2
import glob
from models.model import Model
#from evaluation import evaluation
# Dataset names offered in the datasets table; each is also the directory
# name used under ``datasets/`` by the handlers below.
datasets = [
    "HRF",
    "DR-HAGIS",
    "OVRS",
    "ISBI2016_ISIC",
    "CHASEDB",
    "KEVASIR-SEG",
]
# Model names; NOTE(review): not referenced elsewhere in this file.
models = [
    "UNET",
    "SA-UNET",
    "ATT-UNET",
    "SAM",
    "SAM-ZERO-SHOT",
]




def default_mask():
    """Return the empty value used to clear the predicted-mask image."""
    cleared_image = None
    return cleared_image
def default_metrics():
    """Return a fresh, empty mapping used to reset the metrics label."""
    return dict()
def _list_test_images(dataset_name):
    """Return the sorted test-image paths of ``dataset_name``.

    Files are globbed from ``datasets/<dataset_name>/test/images/``. Both the
    gallery and the ``path`` state are built from this helper and sorted, so
    they list the files in the same order and a gallery index selects the
    intended entry of ``path``.
    """
    pattern = os.path.join("datasets", dataset_name, "test", "images", "*")
    return sorted(glob.glob(pattern))


def on_select_gallery(dataset: gr.SelectData):
    """Show the test images of the dataset selected in the datasets table.

    Args:
        dataset: Selection event; ``dataset.value`` is the selected sample row,
            whose first element is the dataset name.

    Returns:
        A Gradio update making the gallery visible with the image paths.
    """
    return gr.update(visible=True, value=_list_test_images(dataset.value[0]))


def on_select_dataset_path(dataset: gr.SelectData):
    """Return the selected dataset's test-image paths for the ``path`` state.

    Uses the same ordering as ``on_select_gallery`` so gallery indices stay
    aligned with this list.
    """
    return _list_test_images(dataset.value[0])
def on_select_gallery_label(dataset: gr.SelectData):
    """Build the markdown heading shown above the gallery for the chosen dataset.

    ``dataset.value[0]`` is the dataset name from the selected sample row.
    """
    dataset_name = dataset.value[0]
    heading = "## " + dataset_name + " Dataset"
    return gr.update(visible=True, value=heading)
def on_select_gallery_set_mask_path(value: gr.SelectData, path, dataset):
    """Map the clicked gallery image to its ground-truth mask path.

    Args:
        value: Gallery selection event; ``value.index`` is the clicked position.
        path: Image paths in gallery order.
        dataset: Name of the active dataset.

    Returns:
        ``datasets/<dataset>/test/masks/<file name of the selected image>``.
    """
    selected_image = path[int(value.index)]
    file_name = os.path.basename(selected_image)
    return os.path.join("datasets", dataset, "test", "masks", file_name)
def set_index(value: gr.SelectData):
    """Return the position of the clicked gallery item, as a string."""
    clicked_position = value.index
    return str(clicked_position)
def set_value(value: gr.SelectData):
    """Return the dataset name of the selected row in the datasets table.

    ``value.value`` is the selected sample, a one-element list such as
    ``["HRF"]``. Indexing it directly (as the other dataset handlers do) is
    more robust than slicing its ``repr`` with ``[2:-2]``, which breaks for
    names that ``repr`` escapes.
    """
    return str(value.value[0])
def segment(dataset, model, image_index, path):
    """Make the results row visible.

    The arguments mirror ``segment_``'s signature but are not used here; the
    prediction itself is produced by ``segment_``, which is bound to the same
    button click.
    """
    reveal_row = gr.update(visible=True)
    return reveal_row
def segment_(dataset, model, image_index, path):
    """Run ``model`` on the currently selected gallery image.

    Args:
        dataset: Name of the active dataset directory under ``datasets/``.
        model: Model object from ``model_state``; its ``predict`` receives the
            image path and the matching ground-truth mask path.
        image_index: Selected gallery index as a string (written by
            ``set_index``), or empty/``None`` if nothing is selected yet.
        path: Test-image paths in gallery order.

    Returns:
        Whatever ``model.predict`` returns; it is displayed in the predicted
        mask image.

    Raises:
        gr.Error: If no image is selected, or the stored index does not fit
            ``path`` (e.g. after switching to a dataset with fewer images).
    """
    # Without this guard, int("") / int(None) fails with an opaque error.
    if image_index is None or str(image_index).strip() == "":
        raise gr.Error("Please select an image from the gallery before segmenting.")
    index = int(image_index)
    # NOTE(review): an in-range index left over from a previous dataset still
    # selects an image of the new dataset; only out-of-range is detectable here.
    if not 0 <= index < len(path):
        raise gr.Error("The selected image is no longer available; please select an image again.")
    image_path = path[index]
    mask_path = os.path.join("datasets", dataset, "test", "masks", os.path.basename(image_path))
    return model.predict(image_path, mask_path)
def on_select_model(model: gr.SelectData, dataset):
    """Construct a ``Model`` for the selected model name and ``dataset``.

    NOTE(review): this handler is not bound to any event in this file.
    """
    selected_name = model.value
    return Model(selected_name, dataset)

def update_model_dataset(dataset: gr.SelectData, model):
    """Point ``model`` at the newly selected dataset.

    Whatever ``model.set_dataset`` returns becomes the new ``model_state``
    value (presumably the updated model; confirm against ``Model``).
    """
    selected_dataset = dataset.value[0]
    new_state = model.set_dataset(selected_dataset)
    return gr.update(value=new_state)
def update_model_name(model_name: gr.SelectData, model):
    """Switch ``model`` to the architecture chosen in the dropdown.

    Whatever ``model.set_model`` returns becomes the new ``model_state``
    value (presumably the updated model; confirm against ``Model``).
    """
    chosen = model_name.value
    new_state = model.set_model(chosen)
    return gr.update(value=new_state)
def update_dropdown_labels(dataset: gr.SelectData):
    """Fill the model dropdown with the checkpoints available for a dataset.

    Checkpoints are found as ``models/models_checkpoints/<dataset>/*.pth``;
    each choice is the checkpoint file name without directory or extension.

    Args:
        dataset: Selection event from the datasets table; ``dataset.value[0]``
            is the dataset name.

    Returns:
        A Gradio update setting the dropdown choices and enabling it.
    """
    pattern = os.path.join("models", "models_checkpoints", dataset.value[0], "*.pth")
    # basename/splitext instead of split("/") and split(".") so Windows path
    # separators work and dotted names like "unet.v2.pth" are not truncated.
    choices = sorted(
        os.path.splitext(os.path.basename(checkpoint))[0]
        for checkpoint in glob.glob(pattern)
    )
    # gr.update matches the rest of this file (gr.Dropdown.update was removed
    # in newer Gradio releases).
    return gr.update(choices=choices, interactive=True, label="Models")

def test(model):
    """Evaluate ``model`` and wrap the result in an update for the metrics label."""
    metrics = model.evaluation()
    return gr.update(value=metrics)



with gr.Blocks() as demo:
    # ---- Per-session state -----------------------------------------------
    # NOTE(review): `model_build` and `mask_p_state` are not wired to any event.
    model_build = gr.State()
    mask_p_state = gr.State()
    # Dataset shown on first load; replaced by `set_value` when a row is picked.
    dataset_state = gr.State("ISBI2016_ISIC")
    # Sorted for a deterministic initial order. The gallery below is built from
    # this same list, so a gallery index addresses the same file in `path`.
    initial_paths = sorted(
        glob.glob(os.path.join("datasets", dataset_state.value, "test/images/*"))
    )
    path = gr.State(value=initial_paths)
    model_state = gr.State(value=Model("UNET", dataset_state.value))
    demo_header = gr.Markdown(
        value="# A Comparative Analysis of State-of-the-Art Deep learning Models for Medical Image"
    )
    # Hidden component holding the selected gallery index as a string.
    image_index = gr.Markdown(visible=False)
    mask_y_state = gr.State()

    # ---- Dataset and model pickers ---------------------------------------
    dataset = gr.Dataset(
        label="Datasets",
        components=[gr.Textbox(visible=False)],
        samples=[[name] for name in datasets],
    )
    # Disabled until a dataset is selected and its checkpoints are listed.
    model = gr.Dropdown(interactive=False, label="Models", value=lambda: "UNET")
    dataset.select(update_dropdown_labels, None, outputs=[model])
    model.select(update_model_name, [model_state], [model_state])

    # ---- Test-image gallery ----------------------------------------------
    dataset_label = gr.Markdown()
    gallery = gr.Gallery(value=initial_paths, visible=True)
    dataset.select(on_select_gallery, None, outputs=[gallery])
    dataset.select(on_select_gallery_label, None, outputs=[dataset_label])
    dataset.select(set_value, None, outputs=[dataset_state])
    dataset.select(update_model_dataset, [model_state], outputs=[model_state])
    dataset.select(on_select_dataset_path, None, outputs=[path])
    segment_btn = gr.Button(value="Segment")

    # ---- Results ---------------------------------------------------------
    # Separate name so the gallery's `dataset_label` handle is not shadowed.
    result_header = gr.Markdown(value="# Result")
    gallery.select(set_index, None, outputs=[image_index])
    # Hidden until the first click on Segment (`segment` reveals it).
    row = gr.Row(visible=False)
    segment_btn.click(segment, [dataset_state, model_state, image_index, path], [row])
    with row:
        mask_p_image = gr.Image(scale=2)  # output of segment_ / model.predict
        mask_y_image = gr.Image(scale=2)  # ground-truth mask from test/masks
        # gr.Label has no `inputs` parameter; the metrics are driven by the
        # change listener below.
        metrics_label = gr.Label(label="Metrics")
        mask_p_image.change(test, inputs=[model_state], outputs=[metrics_label])
    gallery.select(on_select_gallery_set_mask_path, [path, dataset_state], outputs=[mask_y_state])
    gallery.select(on_select_gallery_set_mask_path, [path, dataset_state], outputs=[mask_y_image])
    segment_btn.click(segment_, [dataset_state, model_state, image_index, path], [mask_p_image])
    # Clear the previous prediction and metrics whenever the dataset or the
    # selected image changes.
    dataset.select(default_mask, None, outputs=[mask_p_image])
    dataset.select(default_metrics, None, outputs=[metrics_label])
    gallery.select(default_metrics, None, outputs=[metrics_label])
    gallery.select(default_mask, None, outputs=[mask_p_image])
demo.launch()