| import random | |
| import gradio as gr | |
| import os | |
# Root folder scanned at start-up. Paths built in this file assume the layout
# <base_path>/<image_id>/<class>/ holding image.png, pred_overlay.png and
# support_im*overlay.png files.
base_path = "images/"


def get_main_image_list(root=None) -> dict:
    """Index the image folders under *root* (defaults to ``base_path``).

    Returns a dict mapping each image id (a sub-directory name of *root*) to a
    dict that maps each class sub-directory name to
    ``{"path": "<root>/<image_id>/<class>/image.png"}``.

    Image ids and classes are visited in sorted order so the gallery order and
    the "first" class of each image are deterministic (``os.listdir`` order is
    arbitrary). Plain files at either level (e.g. ``.DS_Store``) are skipped,
    and image folders without any class sub-directory are left out entirely.
    """
    if root is None:
        root = base_path
    image_ids = {}
    for img_id in sorted(os.listdir(root)):
        img_dir = os.path.join(root, img_id)
        if not os.path.isdir(img_dir):
            # A stray file here would make os.listdir(img_dir) raise.
            continue
        classes = {}
        for img_class in sorted(os.listdir(img_dir)):
            class_dir = os.path.join(img_dir, img_class)
            if os.path.isdir(class_dir):
                classes[img_class] = {"path": os.path.join(class_dir, "image.png")}
        # Only keep images that have at least one class: callers index the
        # first class of every entry.
        if classes:
            image_ids[img_id] = classes
    return image_ids
def get_images(k=5):
    """Load the image index and build gallery captions.

    Returns ``(image_ids, images)``: ``image_ids`` is the dict returned by
    ``get_main_image_list()``, and ``images`` is a list of
    ``(image_id, "Image <i>")`` tuples in that dict's iteration order, so
    ``images[i]`` lines up with gallery position ``i``.

    NOTE(review): ``k`` is accepted but not used; every indexed image is
    returned. It is kept so existing ``get_images(k=...)`` calls keep working.
    Confirm whether limiting or sampling ``k`` images was intended.
    """
    image_ids = get_main_image_list()
    # enumerate over the dict avoids rebuilding list(keys) on every iteration.
    images = [(img_id, f"Image {i}") for i, img_id in enumerate(image_ids)]
    return (image_ids, images)
# Module-level state shared by the Gradio callbacks below.
current_select = None
dropdown_options = ["1", "2", "3"]
selected_index = None  # gallery position of the query image last clicked
selected_class = None  # class folder currently chosen for that image

image_ids, images = get_images(k=10)

# One (thumbnail path, caption) pair per query image. The thumbnail is the
# "path" entry of the image's first class folder.
images_to_choose = [
    (list(image_ids[img_id].values())[0]["path"], caption)
    for img_id, caption in images
]
def process_image(gallery_data):
    """Return the prediction overlay path for the selected image and class.

    ``gallery_data`` is the gallery value Gradio passes in. It is unused:
    the selection is read from the module-level ``selected_index`` and
    ``selected_class``, which the gallery/dropdown callbacks set.

    Returns None when nothing has been selected yet. Raises ``gr.Error``,
    which is shown to the user, when ``pred_overlay.png`` is missing. This
    replaces handing Gradio a path to a file that does not exist.
    """
    if selected_index is None or selected_class is None:
        return None
    key = images[selected_index][0]
    prediction = os.path.join(base_path, key, selected_class, "pred_overlay.png")
    if not os.path.isfile(prediction):
        raise gr.Error(f"No prediction found at {prediction}")
    return prediction
def dropdown_change(evt):
    """Handle a class being chosen in the dropdown.

    ``evt`` is the dropdown's value (a class folder name). The function stores
    it in ``selected_class`` and returns the support-set overlay paths for the
    currently selected image under that class.

    Returns None without touching ``selected_class`` when no image is selected
    yet, or when the dropdown value is empty/None. An empty value previously
    crashed on the string concatenation below.
    """
    global selected_class
    if selected_index is None or not evt:
        return None
    selected_class = evt
    key = images[selected_index][0]
    support_images = get_support_set(selected_index, key + "/" + selected_class)
    return support_images
def get_select_value(evt: gr.SelectData):
    """Handle a click on a query image in the main gallery.

    Stores the clicked position in ``selected_index`` and resets
    ``selected_class`` to the image's first class folder. Returns the
    support-set overlay paths for that class, plus a dropdown update that
    lists every class of the image with the first one pre-selected.
    """
    global selected_index, selected_class
    selected_index = evt.index
    img_id = images[selected_index][0]
    class_names = [*image_ids[img_id]]
    first_class = class_names[0]
    selected_class = first_class
    support = get_support_set(selected_index, f"{img_id}/{first_class}")
    return support, gr.update(choices=class_names, value=first_class)
def get_support_set(selected_image_index, key, root=None):
    """Return the support-set overlay image paths for one image/class folder.

    ``key`` is ``"<image_id>/<class>"``, relative to *root* (defaults to
    ``base_path``). The function returns only files whose names start with
    ``support_im`` and end with ``overlay.png``. They are sorted by name so the
    support gallery order is stable, because ``os.listdir`` order is arbitrary.

    Returns None when ``selected_image_index`` is None (no image selected).
    """
    if selected_image_index is None:
        return None
    if root is None:
        root = base_path
    parent = os.path.join(root, key)
    return [
        os.path.join(parent, name)
        for name in sorted(os.listdir(parent))
        if name.startswith("support_im") and name.endswith("overlay.png")
    ]
def generate_images():
    """Supply the precomputed thumbnail list to the gallery and to the State.

    Both outputs receive the very same ``images_to_choose`` list object.
    """
    return (images_to_choose,) * 2
# Two-row layout:
# - Row 1: the query-image gallery (left) and the support-set gallery (right).
# - Row 2: the class dropdown and the Process button.
# The processed (prediction) image is shown below both rows.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            imgs = gr.State()
            gallery = gr.Gallery(
                label="Generated images",
                show_label=False,
                elem_id="gallery",
                columns=[3],
                rows=[1],
                object_fit="contain",
                height="auto",
                allow_preview=True,
                preview=True,
            )
            # Fill the gallery, and keep a copy in `imgs`, when the page loads.
            demo.load(generate_images, None, [gallery, imgs])
        with gr.Column(scale=1):
            support_gallery = gr.Gallery(
                label="Support Set Images",
                value=[],
                columns=3,
                rows=2,
                allow_preview=False,
            )
    with gr.Row():
        with gr.Column(scale=1):
            # `scale` is an integer flex weight in Gradio. The former 0.5 only
            # produced a warning, so the default is used instead.
            dropdown = gr.Dropdown(label="Select a class", interactive=True)
        with gr.Column(scale=1):
            process_btn = gr.Button("Process", scale=0)
    gallery.select(get_select_value, None, [support_gallery, dropdown])
    dropdown.input(dropdown_change, dropdown, support_gallery)
    output_image = gr.Image(label="Processed Image", width=500)
    # The 4th positional argument of .click() is `api_name`, so the stray
    # `support_gallery` that used to be passed there was dropped.
    process_btn.click(process_image, gallery, output_image)

if __name__ == "__main__":
    # share=True requests a public Gradio share link in addition to localhost.
    demo.launch(share=True)