"""Gradio demo for browsing per-class prediction overlays.

Expected directory layout under ``base_path``::

    images/<image_id>/<class>/image.png
    images/<image_id>/<class>/pred_overlay.png
    images/<image_id>/<class>/support_im*overlay.png

The main gallery shows one ``image.png`` per image id (taken from its first
class directory). Selecting an image fills the class dropdown and the
support-set gallery; "Process" shows the ``pred_overlay.png`` for the selected
image/class.
"""
import random
import gradio as gr
import os

base_path = "images/"


def get_main_image_list():
    """Scan ``base_path`` and index the available images by id and class.

    Returns:
        dict mapping ``image_id -> {class_name: {"path": <image.png path>}}``.
        Only image ids with at least one class sub-directory are included;
        non-directory entries (stray files) are skipped at both levels.
    """
    image_ids = {}
    # os.listdir order is arbitrary; sorting keeps the gallery order stable.
    for img_id in sorted(os.listdir(base_path)):
        id_dir = os.path.join(base_path, img_id)
        if not os.path.isdir(id_dir):
            continue
        for img_class in sorted(os.listdir(id_dir)):
            class_dir = os.path.join(id_dir, img_class)
            if not os.path.isdir(class_dir):
                continue
            image_ids.setdefault(img_id, {})[img_class] = {
                "path": os.path.join(class_dir, "image.png")
            }
    return image_ids


def get_images(k=5):
    """Build the image index and the gallery's ``(image_id, caption)`` list.

    Args:
        k: Unused; kept so existing callers passing it keep working.

    Returns:
        ``(image_ids, images)`` where ``image_ids`` is the index from
        ``get_main_image_list`` and ``images`` is a list of
        ``(image_id, "Image <i>")`` tuples in index order.
    """
    image_ids = get_main_image_list()
    images = [(img_id, f"Image {i}") for i, img_id in enumerate(image_ids)]
    return (image_ids, images)


# NOTE(review): the selection state below is module-global, so it is shared by
# every browser session. With ``share=True`` concurrent users overwrite each
# other's selection; consider moving it into ``gr.State``.
current_select = None
dropdown_options = ["1", "2", "3"]
selected_index = None  # gallery index of the currently selected image
selected_class = None  # class name currently chosen for that image

image_ids, images = get_images(k=10)
# Gallery items: (path of the first class's image.png, caption) per image id.
images_to_choose = [
    (image_ids[img_id][next(iter(image_ids[img_id]))]["path"], caption)
    for img_id, caption in images
]


def process_image(gallery_data):
    """Return the prediction overlay path for the current selection.

    Args:
        gallery_data: Current gallery value supplied by Gradio; unused.

    Returns:
        Path to ``<base_path>/<image_id>/<class>/pred_overlay.png``, or
        ``None`` when nothing is selected or that file does not exist.
    """
    if selected_index is None or selected_class is None:
        return None
    key = images[selected_index][0]
    prediction = os.path.join(base_path, key, selected_class, "pred_overlay.png")
    # Handing gr.Image a missing path raises; show an empty image instead.
    if not os.path.isfile(prediction):
        return None
    return prediction


def dropdown_change(evt):
    """Handle a class change in the dropdown.

    Args:
        evt: The dropdown's new value (a class name), passed by Gradio as the
            value of the input component.

    Returns:
        List of support-set overlay paths for the selected image and class,
        or ``None`` if no image is selected or the dropdown was cleared.
    """
    global selected_class
    if selected_index is None or not evt:
        return None
    selected_class = evt
    key = images[selected_index][0]
    return get_support_set(selected_index, os.path.join(key, selected_class))


def get_select_value(evt: gr.SelectData):
    """Handle a click on an image in the main gallery.

    Records the selected image and resets the selected class to that image's
    first class.

    Returns:
        ``(support_images, dropdown_update)``: the support-set overlay paths for
        that class, and a ``gr.update`` setting the dropdown choices to the
        image's classes with the first one selected.
    """
    global selected_index, selected_class
    selected_index = evt.index
    img_id = images[selected_index][0]
    keys = list(image_ids[img_id])
    selected_class = keys[0]
    sup_images = get_support_set(selected_index, os.path.join(img_id, selected_class))
    return sup_images, gr.update(choices=keys, value=selected_class)


def get_support_set(selected_image_index, key):
    """List the support-set overlay images in one image/class directory.

    Args:
        selected_image_index: Gallery index of the selected image; only
            checked against ``None``.
        key: ``"<image_id>/<class>"`` path relative to ``base_path``.

    Returns:
        Sorted paths of files in that directory whose names start with
        ``"support_im"`` and end with ``"overlay.png"``, or ``None`` if no
        image is selected.
    """
    if selected_image_index is None:
        return None
    parent = os.path.join(base_path, key)
    return [
        os.path.join(parent, name)
        for name in sorted(os.listdir(parent))
        if name.startswith("support_im") and name.endswith("overlay.png")
    ]


def generate_images():
    """Return the main gallery items twice: once for the gallery, once for the state."""
    return images_to_choose, images_to_choose


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            imgs = gr.State()
            gallery = gr.Gallery(
                label="Generated images",
                show_label=False,
                elem_id="gallery",
                columns=[3],
                rows=[1],
                object_fit="contain",
                height="auto",
                allow_preview=True,
                preview=True,
            )
            demo.load(generate_images, None, [gallery, imgs])
        with gr.Column(scale=1):
            support_gallery = gr.Gallery(
                label="Support Set Images",
                value=[],
                columns=3,
                rows=2,
                allow_preview=False,
            )
    with gr.Row():
        with gr.Column(scale=1):
            dropdown = gr.Dropdown(label="Select a class", scale=0.5, interactive=True)
        with gr.Column(scale=1):
            process_btn = gr.Button("Process", scale=0)
    gallery.select(get_select_value, None, [support_gallery, dropdown])
    dropdown.input(dropdown_change, dropdown, support_gallery)
    output_image = gr.Image(label="Processed Image", width=500)
    # The 4th positional parameter of .click() is ``api_name``, not another
    # output; the original passed ``support_gallery`` there by mistake.
    process_btn.click(process_image, gallery, output_image)

if __name__ == "__main__":
    demo.launch(share=True)