SamiTechie
committed on
Commit
•
98b6309
1
Parent(s):
6522c57
project
Browse files- .gitignore +1 -0
- app.py +99 -0
- requirements.txt +14 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
env/
|
app.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
import cv2
|
5 |
+
import glob
|
6 |
+
from models.model import Model
|
7 |
+
#from evaluation import evaluation
|
8 |
+
# Dataset names; each is expected on disk under datasets/<name>/test/{images,masks}
# (that layout is what the glob/os.path.join calls below rely on).
datasets = ["HRF","DR-HAGIS","OVRS","ISBI2016_ISIC", "CHASEDB","KEVASIR-SEG"]
# Segmentation architectures offered by the app.
# NOTE(review): only `datasets` is referenced later in this file; `models`
# appears unused here — the dropdown choices come from checkpoint files instead.
models = ["UNET","SA-UNET", "ATT-UNET","SAM", "SAM-ZERO-SHOT"]
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
def default_mask():
    """Reset value for a mask Image component (clears the display)."""
    return None
|
16 |
+
def default_metrics():
    """Reset value for the metrics Label component (no metrics shown)."""
    return dict()
|
18 |
+
def on_select_gallery(dataset: gr.SelectData):
    """Populate and reveal the gallery with the selected dataset's test images."""
    pattern = os.path.join("datasets", dataset.value[0], "test/images/*")
    image_paths = glob.glob(pattern)
    return gr.update(visible=True, value=image_paths)
|
21 |
+
def on_select_dataset_path(dataset: gr.SelectData):
    """Return the list of test-image paths for the newly selected dataset."""
    pattern = os.path.join("datasets", dataset.value[0], "test/images/*")
    return glob.glob(pattern)
|
24 |
+
def on_select_gallery_label(dataset: gr.SelectData):
    """Show a markdown heading naming the selected dataset."""
    name = dataset.value[0]
    return gr.update(visible=True, value=f"## {name} Dataset")
|
27 |
+
def on_select_gallery_set_mask_path(value: gr.SelectData, path, dataset):
    """Map the clicked gallery image to its ground-truth mask path.

    The mask is assumed to share the image's filename under
    datasets/<dataset>/test/masks/.
    """
    selected_image = path[int(value.index)]
    filename = os.path.basename(selected_image)
    return os.path.join("datasets", dataset, "test", "masks", filename)
|
30 |
+
def set_index(value: gr.SelectData):
    """Stash the selected gallery index (as a string) in a Markdown component."""
    idx = value.index
    return str(idx)
|
32 |
+
def set_value(value: "gr.SelectData"):
    """Return the selected dataset name as a plain string.

    The Dataset component reports its selection as a one-element list,
    e.g. ``["HRF"]``.  The previous implementation sliced the list's
    ``repr`` (``str(value.value)[2:-2]``), which silently corrupts any
    name containing quote characters and breaks for non-list values;
    index the element directly instead.  (Annotation is a string so the
    function can be defined without ``gr`` in scope — PEP 563 style.)
    """
    selected = value.value
    if isinstance(selected, (list, tuple)) and selected:
        return str(selected[0])
    # Fallback mirroring the old behavior for unexpected payload shapes.
    return str(selected)[2:-2]
|
34 |
+
def segment(dataset, model, image_index, path):
    """Reveal the result row; the actual prediction is produced by segment_()."""
    return gr.update(visible=True)
|
36 |
+
def segment_(dataset, model, image_index, path):
    """Run the current model on the chosen image and return its prediction.

    ``image_index`` is the stringified gallery index; the ground-truth mask
    is resolved by filename under datasets/<dataset>/test/masks/.
    """
    selected_image = path[int(image_index)]
    filename = os.path.basename(selected_image)
    mask_path = os.path.join("datasets", dataset, "test", "masks", filename)
    return model.predict(selected_image, mask_path)
|
40 |
+
def on_select_model(model: gr.SelectData, dataset):
    """Construct a fresh Model wrapper for the chosen architecture and dataset."""
    model_name = model.value
    return Model(model_name, dataset)
|
42 |
+
|
43 |
+
def update_model_dataset(dataset: gr.SelectData, model):
    """Re-point the shared Model state at the newly selected dataset."""
    updated = model.set_dataset(dataset.value[0])
    return gr.update(value=updated)
|
45 |
+
def update_model_name(model_name: gr.SelectData, model):
    """Switch the shared Model state to the architecture picked in the dropdown."""
    updated = model.set_model(model_name.value)
    return gr.update(value=updated)
|
47 |
+
def update_dropdown_labels(dataset: gr.SelectData):
    """Refresh the model dropdown with the checkpoints for the selected dataset.

    Scans ``models/models_checkpoints/<dataset>/*pth`` and lists each
    checkpoint's bare name (directory and extension stripped).

    Fixes: the previous version split paths on a literal ``"/"``, which
    yields the whole path unchanged on Windows where ``glob`` returns
    backslash-separated paths — use ``os.path.basename`` instead.  Also
    renames the misspelled ``choises`` local.
    """
    checkpoints = glob.glob(f"models/models_checkpoints/{dataset.value[0]}/*pth")
    # basename() is separator-agnostic; keep the original split('.')[0]
    # semantics (name up to the first dot) so checkpoint labels are unchanged.
    choices = [os.path.basename(p).split(".")[0] for p in checkpoints]
    return gr.Dropdown.update(choices=choices, interactive=True, label="Models")
|
51 |
+
|
52 |
+
def test(model):
    """Evaluate the current model and push the resulting metrics to the Label."""
    metrics = model.evaluation()
    return gr.update(value=metrics)
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
# UI layout and event wiring: a dataset picker, a checkpoint dropdown, an
# image gallery, and a (hidden-until-used) result row showing the predicted
# mask, the ground-truth mask, and the evaluation metrics.
# NOTE(review): several handlers are registered on the same component event;
# the registration order below is load-bearing for the observed behavior.
with gr.Blocks() as demo:
    ## States
    model_build = gr.State()   # NOTE(review): never read or written below — looks dead
    mask_p_state = gr.State()  # NOTE(review): never read or written below — looks dead
    #models_state = gr.State([])
    dataset_state = gr.State("ISBI2016_ISIC")  # name of the currently selected dataset
    # Initial gallery contents: test images of the default dataset.
    path = glob.glob(os.path.join("datasets",dataset_state.value,"test/images/*"))
    path = gr.State(value=path)  # `path` is rebound from plain list to State here
    model_state = gr.State(value = Model("UNET", dataset_state.value))
    demo_header = gr.Markdown(value="# A Comparative Analysis of State-of-the-Art Deep learning Models for Medical Image")
    # Hidden Markdown abused as storage for the selected gallery index (a string).
    image_index = gr.Markdown(visible = False)
    mask_y_state = gr.State()  # ground-truth mask path of the selected image

    # Dataset picker rendered as a one-column table of names.
    dataset = gr.Dataset(label="Datasets",components = [gr.Textbox(visible= False)], samples = list(map(lambda item: [item], datasets)),)
    # Disabled until a dataset is chosen; update_dropdown_labels enables it.
    model = gr.Dropdown(interactive=False, label = "Models",value=lambda : "UNET")
    dataset.select(update_dropdown_labels,None,outputs = [model])
    model.select(update_model_name,[model_state] , [model_state])
    dataset_label = gr.Markdown()
    gallery = gr.Gallery(value = path.value,visible=True)
    # Selecting a dataset refreshes gallery, heading, state, model and paths.
    dataset.select(on_select_gallery,None,outputs = [gallery])
    dataset.select(on_select_gallery_label,None,outputs = [dataset_label])
    dataset.select(set_value,None,outputs = [dataset_state])
    dataset.select(update_model_dataset,[model_state],outputs = [model_state])
    dataset.select(on_select_dataset_path,None,outputs = [path])
    segement_btn = gr.Button(value = "Segment")
    # NOTE(review): `dataset_label` is reassigned here, shadowing the heading
    # Markdown wired to on_select_gallery_label above — confirm intended.
    dataset_label = gr.Markdown(value="# Result")
    gallery.select(set_index, None,outputs= [image_index])
    #mask_y_image = gr.Image(scale=2)
    row = gr.Row(visible=False)  # result row, revealed by segment()
    segement_btn.click(segment, [dataset_state, model_state, image_index, path], [row])
    with row:
        mask_p_image = gr.Image(scale=2)  # predicted mask
        mask_y_image = gr.Image(scale=2)  # ground-truth mask
        # NOTE(review): `inputs=` is not a documented gr.Label argument — verify.
        metrics_label = gr.Label(label = "Metrics",inputs=[mask_p_image])
    # A new prediction image triggers metric (re)evaluation.
    mask_p_image.change(test, inputs=[model_state],outputs = [metrics_label])
    gallery.select(on_select_gallery_set_mask_path, [path, dataset_state],outputs= [mask_y_state])
    gallery.select(on_select_gallery_set_mask_path, [path, dataset_state],outputs= [mask_y_image])
    segement_btn.click(segment_, [dataset_state, model_state, image_index, path], [mask_p_image])
    # Reset result widgets whenever the dataset or the selected image changes.
    dataset.select(default_mask,None, outputs = [mask_p_image])
    dataset.select(default_metrics, None,outputs= [metrics_label])
    gallery.select(default_metrics, None,outputs= [metrics_label])
    gallery.select(default_mask, None,outputs= [mask_p_image])
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
albumentations==1.3.1
|
2 |
+
gradio==3.44.4
|
3 |
+
imageio==2.31.3
|
4 |
+
numpy==1.25.2
|
5 |
+
opencv_python==4.8.0.76
|
6 |
+
opencv_python_headless==4.8.0.76
|
7 |
+
Pillow==10.0.1
|
8 |
+
pycocotools==2.0.7
|
9 |
+
python_box==7.1.1
|
10 |
+
scikit_learn==1.3.0
|
11 |
+
scipy==1.11.2
|
12 |
+
torch==2.0.1+cpu
|
13 |
+
torchvision==0.15.2+cpu
|
14 |
+
transformers==4.33.0
|