|
import random |
|
|
|
import gradio as gr |
|
import os |
|
|
|
# Root directory of the dataset: one sub-directory per image id, each holding
# one sub-directory per class. The trailing "/" matters because every path in
# this module is built by plain string concatenation onto this prefix.
base_path: str = "images/"
|
|
|
def get_main_image_list():
    """Scan ``base_path`` and index the available images by id and class.

    Expected layout: ``<base_path>/<img_id>/<img_class>/image.png``.

    Returns:
        A dict mapping each image id (a sub-directory of ``base_path``) to a
        dict mapping each class (a sub-directory of that image directory) to
        ``{"path": "<base_path><img_id>/<img_class>/image.png"}``. Image
        directories that contain no class sub-directories are omitted, so
        every returned image has at least one class.

    Directory entries are visited in sorted order, so the ordering of the
    result is deterministic. That order drives the gallery order and the
    indices used by the selection handlers. ``os.listdir`` order is
    arbitrary. Non-directory entries (e.g. ``.DS_Store`` or notes files) are
    skipped instead of crashing ``os.listdir`` or producing bogus classes.
    """
    image_ids = {}
    for img_id in sorted(os.listdir(base_path)):
        img_dir = os.path.join(base_path, img_id)
        if not os.path.isdir(img_dir):
            continue
        for img_class in sorted(os.listdir(img_dir)):
            if not os.path.isdir(os.path.join(img_dir, img_class)):
                continue
            # Keep the "/"-joined format used by the rest of the module.
            image_ids.setdefault(img_id, {})[img_class] = {
                "path": f"{base_path}{img_id}/{img_class}/image.png"
            }
    return image_ids
|
|
|
def get_images(k=5):
    """Load the image index and build ``(image_id, label)`` pairs for the gallery.

    Args:
        k: Currently unused; kept so existing calls such as
            ``get_images(k=10)`` keep working. NOTE(review): presumably meant
            to limit or sample the number of images — confirm before wiring it.

    Returns:
        ``(image_ids, images)``:

        - ``image_ids`` is the dict returned by ``get_main_image_list()``.
        - ``images`` is a list of ``(image_id, "Image <i>")`` tuples in that
          dict's iteration order.
    """
    image_ids = get_main_image_list()
    # Enumerate the keys once. The previous form rebuilt
    # list(image_ids.keys()) for every index, which is O(n^2) in the number
    # of images.
    images = [(img_id, f"Image {i}") for i, img_id in enumerate(image_ids)]
    return (image_ids, images)
|
|
|
# Module-level UI state shared by the Gradio event handlers.
# NOTE(review): current_select and dropdown_options are not read anywhere in
# this file as shown — confirm before removing them.
current_select = None
dropdown_options = ["1", "2", "3"]

# Gallery index of the last clicked thumbnail and the class chosen for it.
# Both stay None until the user clicks an image.
selected_index = None
selected_class = None

# Scan base_path once at import time:
# - image_ids is the nested id -> class -> {"path": ...} dict;
# - images holds the (image_id, label) pairs in gallery order.
image_ids, images = get_images(k=10)

# Gallery entries: each image is represented by the image.png of its first class.
images_to_choose = [
    (image_ids[img_id][list(image_ids[img_id])[0]]["path"], label)
    for img_id, label in images
]
|
|
|
def process_image(gallery_data):
    """Return the prediction-overlay path for the current gallery selection.

    ``gallery_data`` is the value Gradio passes from the gallery component and
    is not used. The selection is read from the module-level
    ``selected_index`` and ``selected_class``, which the select and dropdown
    handlers maintain.

    Returns ``None`` while no image has been selected. Otherwise returns
    ``<base_path><img_id>/<selected_class>/pred_overlay.png``.
    """
    if selected_index is not None:
        img_dir = base_path + images[selected_index][0]
        return img_dir + "/" + selected_class + "/pred_overlay.png"
    return None
|
|
|
def dropdown_change(evt):
    """Handle the user picking a different class in the dropdown.

    ``evt`` is the dropdown's new value (a class name). When a gallery image is
    selected, this stores ``evt`` in the module-level ``selected_class``. It
    then returns the support-set overlay paths for that image and class.

    While no image is selected, it returns ``None`` and leaves
    ``selected_class`` untouched.
    """
    global selected_class

    if selected_index is not None:
        selected_class = evt
        img_key = images[selected_index][0]
        return get_support_set(selected_index, img_key + "/" + selected_class)
    return None
|
|
|
def get_select_value(evt: gr.SelectData):
    """Handle a click on a gallery thumbnail.

    It does four things:

    - Records the clicked position in the module-level ``selected_index``.
    - Resets ``selected_class`` to the first class of that image.
    - Returns the support-set overlay paths for that class.
    - Returns a dropdown update listing all of the image's classes, with the
      first one selected.
    """
    global selected_index, selected_class

    clicked = evt.index
    selected_index = clicked

    img_id = images[clicked][0]
    class_names = list(image_ids[img_id])
    first_class = class_names[0]
    selected_class = first_class

    support = get_support_set(clicked, img_id + "/" + first_class)
    return support, gr.update(choices=class_names, value=first_class)
|
|
|
def get_support_set(selected_image_index, key):
    """List the support-set overlay images for one image/class directory.

    Args:
        selected_image_index: Gallery index of the selected image, or ``None``
            when nothing is selected. Only compared against ``None``, so ``0``
            is a valid selection.
        key: ``"<img_id>/<img_class>"``, relative to ``base_path``.

    Returns:
        ``None`` if ``selected_image_index`` is ``None``. Otherwise, a sorted
        list of paths ``<base_path><key>/<name>``, one for every entry whose
        name starts with ``"support_im"`` and ends with ``"overlay.png"``.

    Raises:
        FileNotFoundError: if the directory does not exist (from ``os.listdir``).
    """
    if selected_image_index is None:
        return None

    parent = base_path + key + "/"
    # os.listdir returns entries in arbitrary order; sort so the support
    # gallery shows the same order on every run and platform.
    return [
        parent + name
        for name in sorted(os.listdir(parent))
        if name.startswith("support_im") and name.endswith("overlay.png")
    ]
|
|
|
def generate_images():
    """Supply the gallery contents when the page loads.

    Returns the same ``images_to_choose`` list twice, once for each of the
    two outputs it is wired to on load.
    """
    gallery_items = images_to_choose
    return gallery_items, gallery_items
|
|
|
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Holds a copy of the gallery contents; both are filled on load.
            imgs = gr.State()
            gallery = gr.Gallery(
                label="Generated images",
                show_label=False,
                elem_id="gallery",
                columns=[3],
                rows=[1],
                object_fit="contain",
                height="auto",
                allow_preview=True,
                preview=True,
            )
            demo.load(generate_images, None, [gallery, imgs])

        with gr.Column(scale=1):
            support_gallery = gr.Gallery(
                label="Support Set Images",
                value=[],
                columns=3,
                rows=2,
                allow_preview=False,
            )

    with gr.Row():
        with gr.Column(scale=1):
            # NOTE(review): Gradio documents `scale` as an int; 0.5 may be
            # warned about or rounded — confirm the intended value.
            dropdown = gr.Dropdown(label="Select a class", scale=0.5, interactive=True)
        with gr.Column(scale=1):
            process_btn = gr.Button("Process", scale=0)

    # Clicking a thumbnail refreshes the support set and the class choices.
    gallery.select(get_select_value, None, [support_gallery, dropdown])
    # A user-driven class change refreshes the support set.
    dropdown.input(dropdown_change, dropdown, support_gallery)

    output_image = gr.Image(label="Processed Image", width=500)

    # process_image returns a single path, so output_image is the only output.
    # Previously support_gallery was passed as a 4th positional argument,
    # which Gradio binds to `api_name`, not to the outputs.
    process_btn.click(process_image, gallery, output_image)

if __name__ == "__main__":
    # share=True asks Gradio to create a publicly shareable link to this app.
    demo.launch(share=True)
|
|
|
|