|
import gradio as gr |
|
import random |
|
import time |
|
import os |
|
from glob import glob |
|
from PIL import Image |
|
import torchvision.transforms as transforms |
|
|
|
# Generated (synthetic) samples and the original masks they were conditioned on.
image_prefix = "/deep/u/eprakash/AngioSeg/diffusion/lung_seg_synthetic_60/synth/"

mask_prefix = "/deep/u/eprakash/AngioSeg/diffusion/lung_seg_synthetic_60/orig/"

# CSV whose first column holds the image ids.
img_list = "/deep/u/eprakash/lung_seg/train_60.csv"

# Directory that receives one "<example index>.txt" ranking file per example.
save_path = "lung_seg_ranks"

# Each id is wrapped as "('<id>',)" because that is the prefix the synthetic and
# mask image file names are built from (see the path construction in load_next).
with open(img_list) as fp:
    image_ids = ["('" + line.strip().split(",")[0] + "',)" for line in fp]

# Only CSV rows 301..500 are ranked in this session.
image_ids = image_ids[301:501]

# Number of examples to rank. Derived from the slice (200 for a full CSV) so it can
# never exceed the number of ids actually loaded; a hard-coded 200 would index past
# the end of image_ids when the CSV is shorter than 501 rows.
num_rank = len(image_ids)
|
|
|
def is_int(s):
    """Report whether ``s`` can be parsed by ``int()``.

    Returns False only when ``int(s)`` raises ValueError; any other exception
    (e.g. TypeError for ``None``) propagates to the caller.
    """
    try:
        int(s)
    except ValueError:
        return False
    return True
|
|
|
def load_img(img_path, size=512):
    """Open ``img_path`` as an RGB PIL image and resize it to ``size`` x ``size``.

    ``transforms.Resize`` is given an explicit (height, width) pair, so the image
    is stretched to a square rather than aspect-preserving.
    """
    resize = transforms.Compose([transforms.Resize((size, size))])
    return resize(Image.open(img_path).convert('RGB'))
|
|
|
def _is_valid_rank_line(line):
    """Return True if ``line`` has exactly 5 comma-separated fields and fields 2-5 are ints."""
    items = line.strip().split(",")
    return len(items) == 5 and all(is_int(item.strip()) for item in items[1:])


def find_completed_idxs(save_path=save_path):
    """Scan ``save_path`` for saved ranking files and classify them.

    A ranking file is named ``<example index>.txt`` and is expected to hold lines
    of the form ``<image id>,<int>,<int>,<int>,<int>``.

    Args:
        save_path: directory containing the ranking files.

    Returns:
        ``(file_list, incorrect_files)``:
        - ``file_list``: the sorted example indices that have a file.
        - ``incorrect_files``: the sorted, de-duplicated subset whose file is empty
          or has a malformed line.

        Index ``-1`` (the blank starting page) is never reported as incorrect.
        When no index-named files exist, ``([-1], [])`` is returned so the
        caller's ``file_list[-1] + 1`` starts at example 0.
    """
    file_list = []
    incorrect_files = set()  # a set so multi-line files are not reported twice
    for f in os.listdir(save_path):
        stem = f.split(".")[0]
        # Skip anything not named after an example index (e.g. ".DS_Store")
        # instead of crashing on int().
        if not is_int(stem):
            continue
        f_name = int(stem)
        file_list.append(f_name)
        if f_name == -1:
            # Placeholder for the blank start page: its content is never validated.
            # Skipping it also avoids indexing missing fields in a short line.
            continue
        with open(os.path.join(save_path, f)) as fp:
            lines = fp.readlines()
        # An empty file (e.g. from an interrupted write) is not a completed ranking.
        if not lines or not all(_is_valid_rank_line(line) for line in lines):
            incorrect_files.add(f_name)
    if not file_list:
        return [-1], []
    return sorted(file_list), sorted(incorrect_files)
|
|
|
# Placeholder image shown on the blank starting page and after the last example.
_BLANK_IMG_PATH = "/deep/u/eprakash/blank.jpg"


def _image_pair(sample_label, sample_path, mask_path):
    """Build the read-only [sample, mask] pair of image components for one row."""
    return [
        gr.Image(label=sample_label, value=load_img(sample_path), interactive=False),
        gr.Image(label="Mask", value=load_img(mask_path), interactive=False),
    ]


def load_next(rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4, example,
              ids=image_ids, image_prefix=image_prefix, save_path=save_path,
              mask_prefix=mask_prefix):
    """Persist the ranking typed for the current example, then load the next one.

    Handler for the "Next" button. The image arguments are the current components
    and are replaced wholesale in the return value.

    Args:
        rank: text from the rankings box, appended verbatim after the image id.
        example: index of the example currently shown (-1 = blank start page).
        ids: id strings of the form "('<id>',)" used to build file names.
        image_prefix: directory prefix of the "<id>_synthetic_<n>.png" samples.
        save_path: directory that receives "<example>.txt" ranking files.
        mask_prefix: directory prefix of the "<id>_mask.png" masks.

    Returns:
        [rank, img_1, mask_1, ..., img_4, mask_4, example] for the next example.
        Malformed saved rankings are revisited first. Once every example is done,
        the images are replaced by blanks, the rank box shows "DONE!", and
        example becomes -1.
    """
    # Create the output directory on first use so listdir/open cannot fail on it.
    os.makedirs(save_path, exist_ok=True)
    idx = int(example)

    file_list, incorrect_files = find_completed_idxs(save_path)
    print(str(file_list) + " " + str(incorrect_files))

    # Save only if this example has no file yet or its saved ranking is malformed.
    if idx not in file_list or idx in incorrect_files:
        # ids entries look like "('<id>',)"; this keeps the quoted "'<id>'" part.
        image_id = str(ids[idx]).split(",")[0].split("(")[1]
        with open(os.path.join(save_path, str(idx) + ".txt"), "w") as r_fp:
            r_fp.write(image_id + "," + rank + "\n")

    # Revisit a malformed ranking if there is one, otherwise go past the last saved one.
    file_list, incorrect_files = find_completed_idxs(save_path)
    if incorrect_files:
        example = incorrect_files[-1]
    else:
        example = file_list[-1] + 1

    images = []
    # ">=" (not "==") plus the len(ids) bound keeps stale out-of-range files
    # from causing an IndexError on ids[example].
    if example >= num_rank or example >= len(ids):
        rank = "DONE!"
        example = -1
        for n in range(1, 5):
            images += _image_pair(f"Sample #{n}", _BLANK_IMG_PATH, _BLANK_IMG_PATH)
    else:
        rank = ""
        id_str = str(ids[example])
        # All four samples were generated from the same mask.
        mask_path = mask_prefix + id_str + "_mask.png"
        for n in range(4):
            sample_path = image_prefix + id_str + f"_synthetic_{n}.png"
            images += _image_pair(f"Sample #{n + 1}", sample_path, mask_path)

    return [rank, *images, example]
|
|
|
with gr.Blocks() as demo:
    # Example index shown at start-up; -1 is the blank starting page.
    last_idx = -1
    example = gr.Number(label="Example #. Click next for #-1 (blank starting page).", value=last_idx, interactive=False)
    rank = gr.Textbox(label="Rankings (Best to worst, comma-separated, no spaces).")

    # Four rows, each with the mask on the left and one synthetic sample on the right;
    # all start out showing the blank placeholder image.
    blank_path = "/deep/u/eprakash/blank.jpg"
    pairs = []
    with gr.Column(scale=1):
        for sample_no in range(1, 5):
            with gr.Row():
                mask_c = gr.Image(label="Mask", value=load_img(blank_path), interactive=False)
                img_c = gr.Image(label=f"Sample #{sample_no}", value=load_img(blank_path), interactive=False)
            pairs.append((img_c, mask_c))
    (img_1, mask_1), (img_2, mask_2), (img_3, mask_3), (img_4, mask_4) = pairs

    next_btn = gr.Button(value="Next")
    # load_next takes and returns the same components in the same order.
    io_components = [rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4, example]
    next_btn.click(fn=load_next, inputs=io_components, outputs=io_components, queue=False)

demo.queue()
demo.launch(share=True)
|
|
|
|