"""Gradio app for ranking synthetic CXR samples.

For each image id found under ``image_prefix`` the app shows four synthetic
samples with their masks. The annotator types a ranking and clicks *Next*.
Each ranking is written to ``<save_path>/<example index>.txt`` as
``"<image id>,<ranking>"``. On every click the app continues from the index
after the highest one already saved, so an interrupted session can resume.
"""
import os
import random
import time
from glob import glob

import gradio as gr
import torchvision.transforms as transforms
from PIL import Image

image_prefix = "/deep/u/eprakash/AngioSeg/diffusion/cxr_synthetic_data_25_no_transform/synth/"
blank_path = "/deep/u/eprakash/blank.jpg"
save_path = "cxr_ranks"

# Sorted so the index -> image id mapping is identical across restarts: saved
# progress is keyed by index, and the iteration order of a ``set`` of strings
# changes between interpreter runs (hash randomization).
image_ids = glob(os.path.join(image_prefix, "*" + ".png"))
image_ids = sorted(
    {os.path.splitext(os.path.basename(p))[0].split("_")[0] for p in image_ids}
)

# (sample file suffix, mask file suffix) for the four displayed pairs, in the
# order they are returned to the UI.
# NOTE(review): samples are numbered 0-3 but masks 1-4 on disk; this pairing is
# kept exactly as before -- confirm it is intended.
_SAMPLE_SUFFIXES = (
    ("_synthetic_0.png", "_synthetic_mask_1.png"),
    ("_synthetic_1.png", "_synthetic_mask_2.png"),
    ("_synthetic_2.png", "_synthetic_mask_3.png"),
    ("_synthetic_3.png", "_synthetic_mask_4.png"),
)


def load_img(img_path, size=1024):
    """Open ``img_path`` as an RGB PIL image resized to ``size`` x ``size``."""
    img = Image.open(img_path).convert("RGB")
    transform = transforms.Compose([transforms.Resize((size, size))])
    return transform(img)


def find_completed_idxs(save_path=save_path):
    """Return the sorted indices of examples that already have a saved ranking.

    An entry ``<i>.<ext>`` in ``save_path`` marks index ``i`` as done. Entries
    whose name does not start with an integer (e.g. ``.DS_Store``) are ignored.
    Returns ``[-1]`` when nothing has been saved yet or the directory does not
    exist, so ``result[-1] + 1`` is always the next index to annotate.
    """
    try:
        files = os.listdir(save_path)
    except FileNotFoundError:
        return [-1]
    idxs = []
    for name in files:
        try:
            idxs.append(int(name.split(".")[0]))
        except ValueError:
            continue
    return sorted(idxs) if idxs else [-1]


def _save_rank(idx, image_id, rank, save_path=save_path):
    """Write ``"<image_id>,<rank>"`` to ``<save_path>/<idx>.txt``."""
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, f"{idx}.txt"), "w") as fp:
        fp.write(f"{image_id},{rank}\n")


def _load_example(image_id, image_prefix=image_prefix):
    """Return ``[img_1, mask_1, ..., img_4, mask_4]`` PIL images for ``image_id``."""
    images = []
    for sample_suffix, mask_suffix in _SAMPLE_SUFFIXES:
        images.append(load_img(os.path.join(image_prefix, str(image_id) + sample_suffix)))
        images.append(load_img(os.path.join(image_prefix, str(image_id) + mask_suffix)))
    return images


def load_next(rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4, example,
              ids=image_ids, image_prefix=image_prefix, save_path=save_path):
    """Save the ranking for the current example and load the next unranked one.

    Handler for the *Next* button. The image arguments are the current values
    of the image components; Gradio passes them because they are listed as
    inputs, and they are not used. ``example`` is the index currently shown
    (``-1`` for the blank start page, which is never saved).

    Returns one value per output component ``[rank, img_1, mask_1, ...,
    img_4, mask_4, example]``: an empty ranking box, the eight images of the
    next unranked example and its index. When every example is ranked, the
    images are cleared and ``example`` is returned unchanged.
    """
    current = int(example)
    completed = find_completed_idxs(save_path)
    # Persist only real examples (not the -1 start page), and never overwrite a
    # ranking that has already been saved.
    if 0 <= current < len(ids) and current not in completed:
        _save_rank(current, ids[current], rank, save_path)
        completed = find_completed_idxs(save_path)

    next_idx = completed[-1] + 1
    if next_idx >= len(ids):
        # All examples are ranked: clear the display, still one value per output.
        return ["", *([None] * 8), current]
    return ["", *_load_example(ids[next_idx], image_prefix), next_idx]


with gr.Blocks() as demo:
    example = gr.Number(label="Example #. Click next for #-1 (blank starting page).", value=-1, interactive=False)
    rank = gr.Textbox(label="Rankings (Best to worst, comma-separated, no spaces).")
    blank = load_img(blank_path)
    # Filled in output order: img_1, mask_1, img_2, mask_2, ...
    image_slots = []
    with gr.Column(scale=1):
        for i in range(1, 5):
            with gr.Row():
                # Created mask-first so each mask is displayed left of its sample.
                mask = gr.Image(label="Mask", value=blank, interactive=False)
                img = gr.Image(label=f"Sample #{i}", value=blank, interactive=False)
            image_slots += [img, mask]
    next_btn = gr.Button(value="Next")
    components = [rank, *image_slots, example]
    next_btn.click(fn=load_next, inputs=components, outputs=components, queue=False)

demo.queue()

if __name__ == "__main__":
    demo.launch(share=True)