# icl_xmem / app.py
# Ioana
# app and images
# fdce364
import random
import gradio as gr
import os
# Root of the image data. The trailing "/" is required: paths in this file
# are built by plain string concatenation (base_path + <id> + "/" + ...).
base_path = "images/"
def get_main_image_list(root=None):
    """Index every ``<image_id>/<class>`` directory found under the image root.

    Args:
        root: Directory prefix to scan. It must end with ``/`` because paths
            are built by plain string concatenation. Defaults to the
            module-level ``base_path``.

    Returns:
        A dict mapping image id -> class name ->
        ``{"path": "<root><image_id>/<class>/image.png"}``. Image ids and
        class names are inserted in sorted order. An image directory that
        contains no class directory is left out. The existence of
        ``image.png`` itself is not checked.
    """
    if root is None:
        root = base_path

    image_ids = {}
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # resulting index (and anything displayed from it) stable between runs.
    for img_id in sorted(os.listdir(root)):
        img_dir = root + img_id
        # Stray files next to the image folders (e.g. .DS_Store) would make
        # the nested listdir raise NotADirectoryError.
        if not os.path.isdir(img_dir):
            continue
        for img_class in sorted(os.listdir(img_dir)):
            class_dir = img_dir + "/" + img_class
            if not os.path.isdir(class_dir):
                continue
            image_ids.setdefault(img_id, {})[img_class] = {
                "path": class_dir + "/image.png",
            }
    return image_ids
def get_images(k=5):
    """Build the image index together with an ordered list of gallery labels.

    Args:
        k: Accepted for backward compatibility but currently unused; every
            image returned by ``get_main_image_list()`` is included.

    Returns:
        A tuple ``(image_ids, images)`` where ``image_ids`` is the dict from
        ``get_main_image_list()`` and ``images`` is a list of
        ``(image_id, "Image <i>")`` pairs, ``i`` being the position of the
        id in ``image_ids``.
    """
    image_ids = get_main_image_list()
    # enumerate() walks the keys once instead of rebuilding the key list
    # for every index.
    images = [(img_id, f"Image {i}") for i, img_id in enumerate(image_ids)]
    return (image_ids, images)
# Selection state written by the Gradio callbacks below via ``global``.
# NOTE(review): these are process-wide module globals, so concurrent
# sessions would presumably share one selection - confirm whether
# per-session state is needed.
selected_index = None  # position in `images` of the last clicked gallery item
selected_class = None  # class name currently chosen for that image

# Not referenced anywhere else in the visible code; kept as-is.
current_select = None
dropdown_options = ["1", "2", "3"]

# Image index and (image_id, label) pairs, computed once at import time.
image_ids, images = get_images(k=10)

# Gallery items: (path of the first class's image.png, label) per image.
images_to_choose = [
    (image_ids[img_id][list(image_ids[img_id])[0]]["path"], label)
    for img_id, label in images
]
def process_image(gallery_data):
    """Return the prediction-overlay path for the current selection.

    Args:
        gallery_data: Value of the component wired as the click input. It is
            not used; the selection is read from the module-level
            ``selected_index`` / ``selected_class``.

    Returns:
        ``"<base_path><image_id>/<class>/pred_overlay.png"``, or ``None``
        when no image or no class is currently selected.
    """
    # The class is also overwritten from the dropdown value, so check it
    # separately: concatenating None into the path would raise TypeError.
    if selected_index is None or selected_class is None:
        return None
    image_id = images[selected_index][0]
    return f"{base_path}{image_id}/{selected_class}/pred_overlay.png"
def dropdown_change(evt):
    """Store the class picked in the dropdown and return its support images.

    Args:
        evt: The dropdown's current value, i.e. the selected class name
            (despite the parameter name, this is not an event object).

    Returns:
        The list produced by ``get_support_set`` for the selected image and
        the new class, or ``None`` when no image is selected or the
        dropdown value is empty.
    """
    global selected_class

    # An empty dropdown value must not replace the stored class: it would
    # break the path concatenation here and later in process_image.
    if selected_index is None or evt is None:
        return None

    selected_class = evt
    image_id = images[selected_index][0]
    return get_support_set(selected_index, image_id + "/" + selected_class)
def get_select_value(evt: gr.SelectData):
    """Handle a click on a gallery item.

    Stores the clicked index and the image's first class in the module-level
    ``selected_index`` / ``selected_class``.

    Args:
        evt: Gradio selection event; ``evt.index`` is used as an index into
            ``images``.

    Returns:
        A tuple ``(support_images, dropdown_update)``: the support-set paths
        for the image's first class, and a ``gr.update`` that sets the
        dropdown choices to the image's classes with the first one selected.
    """
    global selected_index, selected_class

    clicked = evt.index
    selected_index = clicked

    image_id = images[clicked][0]
    class_names = list(image_ids[image_id])
    first_class = class_names[0]
    selected_class = first_class

    support = get_support_set(clicked, image_id + "/" + first_class)
    dropdown_update = gr.update(choices=class_names, value=first_class)
    return support, dropdown_update
def get_support_set(selected_image_index, key, root=None):
    """List the support-set overlay images stored for one image/class pair.

    Args:
        selected_image_index: Index of the selected gallery image. It is only
            compared against ``None``.
        key: Relative directory of the pair, in the form
            ``"<image_id>/<class>"``.
        root: Directory prefix ending in ``/``. Defaults to the module-level
            ``base_path``.

    Returns:
        ``None`` when ``selected_image_index`` is ``None``. Otherwise the
        name-sorted list of paths under ``<root><key>/`` whose file name
        starts with ``"support_im"`` and ends with ``"overlay.png"``.
    """
    if selected_image_index is None:
        return None
    if root is None:
        root = base_path

    parent = root + key + "/"
    # os.listdir() gives no ordering guarantee; sort so the support images
    # come back in the same order on every call.
    return [
        parent + name
        for name in sorted(os.listdir(parent))
        if name.startswith("support_im") and name.endswith("overlay.png")
    ]
def generate_images():
    """Return the precomputed gallery items twice.

    Nothing is generated here: the same module-level ``images_to_choose``
    list is returned once for each of the two outputs wired to the page-load
    event.
    """
    gallery_items = images_to_choose
    return gallery_items, gallery_items
# UI layout and event wiring.
# NOTE(review): the original indentation was lost; the nesting below is
# reconstructed from the context managers - confirm that `output_image`
# belongs at Blocks level rather than inside the second row.
with gr.Blocks() as demo:
    # Top row: selectable image gallery (left), support-set gallery (right).
    with gr.Row():
        with gr.Column(scale=1):
            imgs = gr.State()
            gallery = gr.Gallery(
                label="Generated images",
                show_label=False,
                elem_id="gallery",
                # Plain ints, consistent with the support gallery below.
                columns=3,
                rows=1,
                object_fit="contain",
                height="auto",
                allow_preview=True,
                preview=True,
            )
            # Fill the gallery (and mirror the items into State) on page load.
            demo.load(generate_images, None, [gallery, imgs])
        with gr.Column(scale=1):
            support_gallery = gr.Gallery(
                label="Support Set Images",
                value=[],
                columns=3,
                rows=2,
                allow_preview=False,
            )

    # Second row: class picker and the button that shows the prediction.
    with gr.Row():
        with gr.Column(scale=1):
            # `scale` is left at its default: the previous 0.5 was not an
            # integer, which is what Gradio expects for `scale`.
            dropdown = gr.Dropdown(label="Select a class", interactive=True)
        with gr.Column(scale=1):
            process_btn = gr.Button("Process", scale=0)

    # Clicking a gallery image refreshes the support set and class choices.
    gallery.select(get_select_value, None, [support_gallery, dropdown])
    # Picking another class refreshes the support set only.
    dropdown.input(dropdown_change, dropdown, support_gallery)

    output_image = gr.Image(label="Processed Image", width=500)
    # Only fn / inputs / outputs are passed: a fourth positional argument
    # would be taken as `api_name`, which must not be a component.
    process_btn.click(process_image, gallery, output_image)

if __name__ == "__main__":
    demo.launch(share=True)