SamiTechie commited on
Commit
98b6309
1 Parent(s): 6522c57
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +99 -0
  3. requirements.txt +14 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ env/
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import glob
3
+ import os
4
+ import cv2
5
+ import glob
6
+ from models.model import Model
7
+ #from evaluation import evaluation
8
+ datasets = ["HRF","DR-HAGIS","OVRS","ISBI2016_ISIC", "CHASEDB","KEVASIR-SEG"]
9
+ models = ["UNET","SA-UNET", "ATT-UNET","SAM", "SAM-ZERO-SHOT"]
10
+
11
+
12
+
13
+
14
+ def default_mask():
15
+ return None
16
+ def default_metrics():
17
+ return {}
18
+ def on_select_gallery(dataset: gr.SelectData):
19
+ path = glob.glob(os.path.join("datasets",dataset.value[0],"test/images/*"))
20
+ return gr.update(visible=True, value = path)
21
+ def on_select_dataset_path(dataset: gr.SelectData):
22
+ path = glob.glob(os.path.join("datasets",dataset.value[0],"test/images/*"))
23
+ return path
24
+ def on_select_gallery_label(dataset: gr.SelectData):
25
+ path =dataset.value[0]
26
+ return gr.update(visible=True, value = "## "+path+" Dataset")
27
+ def on_select_gallery_set_mask_path(value: gr.SelectData,path,dataset):
28
+ image_path = path[int(value.index)]
29
+ return os.path.join("datasets", dataset,"test","masks",os.path.basename(image_path))
30
+ def set_index(value: gr.SelectData):
31
+ return str(value.index)
32
+ def set_value(value: gr.SelectData):
33
+ return str(value.value)[2:-2]
34
+ def segment(dataset, model, image_index, path):
35
+ return gr.update(visible= True)
36
+ def segment_(dataset, model, image_index,path):
37
+ image_path = path[int(image_index)]
38
+ mask_path = os.path.join("datasets", dataset,"test","masks",os.path.basename(image_path))
39
+ return model.predict(image_path, mask_path)
40
+ def on_select_model(model: gr.SelectData,dataset):
41
+ return Model(model.value, dataset)
42
+
43
+ def update_model_dataset(dataset : gr.SelectData , model):
44
+ return gr.update(value = model.set_dataset(dataset.value[0]))
45
+ def update_model_name(model_name : gr.SelectData, model):
46
+ return gr.update(value = model.set_model(model_name.value))
47
+ def update_dropdown_labels(dataset : gr.SelectData):
48
+ choises = glob.glob(f"models/models_checkpoints/{dataset.value[0]}/*pth")
49
+ choises = list(map(lambda value: value.split("/")[-1].split('.')[0], choises))
50
+ return gr.Dropdown.update(choices = choises, interactive=True, label = "Models")
51
+
52
+ def test(model):
53
+ return gr.update(value=model.evaluation())
54
+
55
+
56
+
57
+ with gr.Blocks() as demo:
58
+ ## States
59
+ model_build = gr.State()
60
+ mask_p_state = gr.State()
61
+ #models_state = gr.State([])
62
+ dataset_state = gr.State("ISBI2016_ISIC")
63
+ path = glob.glob(os.path.join("datasets",dataset_state.value,"test/images/*"))
64
+ path = gr.State(value=path)
65
+ model_state = gr.State(value = Model("UNET", dataset_state.value))
66
+ demo_header = gr.Markdown(value="# A Comparative Analysis of State-of-the-Art Deep learning Models for Medical Image")
67
+ image_index = gr.Markdown(visible = False)
68
+ mask_y_state = gr.State()
69
+
70
+ dataset = gr.Dataset(label="Datasets",components = [gr.Textbox(visible= False)], samples = list(map(lambda item: [item], datasets)),)
71
+ model = gr.Dropdown(interactive=False, label = "Models",value=lambda : "UNET")
72
+ dataset.select(update_dropdown_labels,None,outputs = [model])
73
+ model.select(update_model_name,[model_state] , [model_state])
74
+ dataset_label = gr.Markdown()
75
+ gallery = gr.Gallery(value = path.value,visible=True)
76
+ dataset.select(on_select_gallery,None,outputs = [gallery])
77
+ dataset.select(on_select_gallery_label,None,outputs = [dataset_label])
78
+ dataset.select(set_value,None,outputs = [dataset_state])
79
+ dataset.select(update_model_dataset,[model_state],outputs = [model_state])
80
+ dataset.select(on_select_dataset_path,None,outputs = [path])
81
+ segement_btn = gr.Button(value = "Segment")
82
+ dataset_label = gr.Markdown(value="# Result")
83
+ gallery.select(set_index, None,outputs= [image_index])
84
+ #mask_y_image = gr.Image(scale=2)
85
+ row = gr.Row(visible=False)
86
+ segement_btn.click(segment, [dataset_state, model_state, image_index, path], [row])
87
+ with row:
88
+ mask_p_image = gr.Image(scale=2)
89
+ mask_y_image = gr.Image(scale=2)
90
+ metrics_label = gr.Label(label = "Metrics",inputs=[mask_p_image])
91
+ mask_p_image.change(test, inputs=[model_state],outputs = [metrics_label])
92
+ gallery.select(on_select_gallery_set_mask_path, [path, dataset_state],outputs= [mask_y_state])
93
+ gallery.select(on_select_gallery_set_mask_path, [path, dataset_state],outputs= [mask_y_image])
94
+ segement_btn.click(segment_, [dataset_state, model_state, image_index, path], [mask_p_image])
95
+ dataset.select(default_mask,None, outputs = [mask_p_image])
96
+ dataset.select(default_metrics, None,outputs= [metrics_label])
97
+ gallery.select(default_metrics, None,outputs= [metrics_label])
98
+ gallery.select(default_mask, None,outputs= [mask_p_image])
99
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ albumentations==1.3.1
2
+ gradio==3.44.4
3
+ imageio==2.31.3
4
+ numpy==1.25.2
5
+ opencv_python==4.8.0.76
6
+ opencv_python_headless==4.8.0.76
7
+ Pillow==10.0.1
8
+ pycocotools==2.0.7
9
+ python_box==7.1.1
10
+ scikit_learn==1.3.0
11
+ scipy==1.11.2
12
+ torch==2.0.1+cpu
13
+ torchvision==0.15.2+cpu
14
+ transformers==4.33.0