Spaces:
Sleeping
Sleeping
import gradio as gr | |
import random | |
import os | |
import pickle | |
from PIL import Image | |
class Process:
    """Mutable session state shared by the Gradio callbacks.

    A single module-level instance is read and written by the callbacks.
    Class attributes hold the defaults; ``__init__`` gives every instance its
    own fresh list/dict containers so that state is never shared between
    instances through the class object.
    """

    # Slot holding the correct answer: "Image1" or "Image2" (None after a
    # failed round at a late step).
    gt_image = ""
    # Identifier of the ground-truth candidate currently on screen.
    gt_image_idx = (0, 0)
    raw_image_path = ""
    candidate_image1_path = ""
    candidate_image2_path = ""
    candidate_image1_idx = (0, 0)
    # Misspelled legacy name, kept for backward compatibility only; code
    # should use candidate_image1_idx.
    candidata_image1_idx = (0, 0)
    candidate_image2_idx = (0, 0)
    # Pool each candidate was drawn from: "positive", "positive1",
    # "positive2", "positive_com" or "negative".
    candidate_image1_group = "negative"
    candidate_image2_group = "negative"
    # Currently selected chain label ("Chain_N").
    concept_choices = None
    # Annotation data loaded by load_data_and_produce_list (OCL / Pangea).
    pkl_data = None
    positive_cand = []
    negative_cand = []
    positive1_cand = []
    positive2_cand = []
    positive_common_cand = []
    # Step counter (0..7) for the "Two concepts" candidate schedule.
    schedule = 0
    # Chain label -> concept name(s) for the current dataset/mode.
    idx_to_chain = {}

    def __init__(self):
        # Rebind the mutable defaults per instance; otherwise every instance
        # would append into the single list/dict stored on the class.
        self.positive_cand = []
        self.negative_cand = []
        self.positive1_cand = []
        self.positive2_cand = []
        self.positive_common_cand = []
        self.idx_to_chain = {}
# Single shared state object that every Gradio callback reads and mutates.
process = Process()
# Per-dataset sources for load_data_and_produce_list:
# (selection pickle, annotation pickle, key of the single-concept lists,
#  separator between the two concepts of a "Two concepts" chain).
_DATASET_SOURCES = {
    "ocl_attribute": (
        "Data/OCL_data/OCL_selected_test_attribute_refined.pkl",
        "Data/OCL_data/OCL_annot_test.pkl",
        "selected_individual_pkl",
        "-",
    ),
    "ocl_affordance": (
        "Data/OCL_data/OCL_selected_test_affordance_refined.pkl",
        "Data/OCL_data/OCL_annot_test.pkl",
        "selected_individual_pkl",
        "-",
    ),
    "Pangea": (
        "Data/pangea/pangea_test_refined.pkl",
        "Data/pangea/B123_test_KIN-FULL_with_node.pkl",
        "selected_pkl",
        "_",
    ),
}
_HMDB_IMAGE_DIR = "Data/refined_HMDB"
_HMDB_PKL_PATH = "Data/refined_HMDB.pkl"


def _read_pickle(path):
    """Return the object stored in the pickle file at *path*.

    NOTE: unpickling can execute arbitrary code; only use on trusted local files.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_hmdb_candidates(exp_mode, chain):
    """Fill the candidate pools on ``process`` from the HMDB pickle.

    Each record's "name" is rewritten in place to the full image path before
    it is classified.

    "One concept" mode puts records whose "label" contains *chain* into
    ``positive_cand`` and the rest into ``negative_cand``. It stops once there
    are more than 30 positives and more than 100 negatives.

    Other modes split *chain* on "-" and match both halves against the image
    path. Records matching both go to ``positive_common_cand``; records
    matching one half go to ``positive1_cand`` or ``positive2_cand``. Up to
    101 non-matching records go to ``negative_cand``.
    """
    data = _read_pickle(_HMDB_PKL_PATH)
    if exp_mode == "One concept":
        positives, negatives = [], []
        for record in data:
            record["name"] = os.path.join(_HMDB_IMAGE_DIR, record["name"])
            if chain in record["label"]:
                positives.append(record)
            else:
                negatives.append(record)
            if len(positives) > 30 and len(negatives) > 100:
                break
        process.positive_cand = positives
        process.negative_cand = negatives
        return

    # Loop-invariant: split the pair once, not once per record.
    parts = chain.split("-")
    first, second = parts[0], parts[1]
    positives1, positives2, common, negatives = [], [], [], []
    for record in data:
        record["name"] = os.path.join(_HMDB_IMAGE_DIR, record["name"])
        name = record["name"]
        has_first, has_second = first in name, second in name
        if has_first and has_second:
            common.append(record)
        elif has_first:
            positives1.append(record)
        elif has_second:
            positives2.append(record)
        elif len(negatives) <= 100:
            negatives.append(record)
    process.positive1_cand = positives1
    process.positive2_cand = positives2
    process.positive_common_cand = common
    process.negative_cand = negatives


def load_data_and_produce_list(dataset, exp_mode, concept_choices):
    """Load the candidate pools for the selected chain into ``process``.

    Args:
        dataset: "ocl_attribute", "ocl_affordance", "Pangea" or "hmdb". Any
            other value leaves ``process`` untouched.
        exp_mode: "One concept" fills ``positive_cand``. Any other mode fills
            ``positive1_cand``, ``positive2_cand`` and ``positive_common_cand``.
            ``negative_cand`` is filled in both cases.
        concept_choices: chain label resolved through ``process.idx_to_chain``.

    For OCL and Pangea, this also stores the annotation pickle in
    ``process.pkl_data``; the image loaders use it to resolve candidates.
    """
    if dataset in _DATASET_SOURCES:
        selection_path, annot_path, single_key, separator = _DATASET_SOURCES[dataset]
        data = _read_pickle(selection_path)
        process.pkl_data = _read_pickle(annot_path)
        chain = process.idx_to_chain[concept_choices]
        if exp_mode == "One concept":
            process.positive_cand = data[single_key][chain]
        else:
            parts = chain.split(separator)
            paired = data["selected_paired_pkl"][chain]
            process.positive1_cand = paired[parts[0]]
            process.positive2_cand = paired[parts[1]]
            # The shared pool is stored under the full pair name.
            process.positive_common_cand = paired[chain]
        process.negative_cand = data["negative_pkl"]
    elif dataset == "hmdb":
        _load_hmdb_candidates(exp_mode, process.idx_to_chain[concept_choices])
# (width, height) in pixels that every displayed image is resized to.
TARGET_SIZE = (200, 200)
# Pangea annotation source name -> image sub-directory under Data/pangea/pangea.
_PANGEA_SOURCE_DIRS = {
    'ActvityNet_hico_style_batch1': 'ActivityNet_hico_batch1',
    'charadesEgo_hico_style': 'charadesego_frame',
    'HAG_hico_style_new': 'hag_frame',
    'HACS_hico_style': 'hacs_frame',
    'kinetics_hico_style': 'kinetics_dataset/k700-2020/train',
}


def _open_sample_image(dataset, ref):
    """Open the image referenced by *ref*, crop it if needed, and resize it.

    The form of *ref* depends on *dataset*:

    - "ocl_attribute" / "ocl_affordance": ``(image_key, object_index)``.
      ``process.pkl_data[image_key]`` supplies "name" (a path relative to
      Data/OCL_data/data) and ``["objects"][object_index]["box"]``, which is
      used as the crop box.
    - "Pangea": a key into ``process.pkl_data``. In the resulting record,
      element 0 selects the sub-directory and element 1 is the path inside it.
    - anything else (HMDB): a dict whose "name" is the image path.

    The source file is opened in a ``with`` block, so its handle is released
    as soon as the resized copy exists.
    """
    if dataset in ("ocl_attribute", "ocl_affordance"):
        record = process.pkl_data[ref[0]]
        path = os.path.join("Data/OCL_data/data", record["name"])
        box = record["objects"][ref[1]]["box"]
    elif dataset == "Pangea":
        record = process.pkl_data[ref]
        path = os.path.join("Data/pangea/pangea", _PANGEA_SOURCE_DIRS[record[0]], record[1])
        box = None
    else:
        path, box = ref["name"], None
    with Image.open(path) as img:
        region = img if box is None else img.crop(box)
        return region.resize(TARGET_SIZE)


def load_images(dataset, raw_image_path, candidate_image1_path, candidate_image2_path):
    """Load the raw image and both candidate images for *dataset*.

    Each argument after *dataset* is a sample reference in the form that
    ``_open_sample_image`` expects for that dataset.

    Returns:
        ``(raw_image, candidate_image1, candidate_image2)`` as PIL images
        resized to TARGET_SIZE.
    """
    return tuple(
        _open_sample_image(dataset, ref)
        for ref in (raw_image_path, candidate_image1_path, candidate_image2_path)
    )
def load_candidate_images(dataset, cand_image, candidate_image1_path, candidate_image2_path):
    """Reuse *cand_image* as the raw image and load two new candidates.

    This is used when the previous correct candidate becomes the next raw
    image, so only the two candidate references are read from disk.

    Reference formats:
    - OCL datasets: ``(image_key, object_index)`` into ``process.pkl_data``;
      the image is cropped to that object's "box".
    - "Pangea": a key into ``process.pkl_data``.
    - anything else: a dict with the image path under "name".

    Returns:
        ``(cand_image, candidate_image1, candidate_image2)``.
    """
    def open_candidate(ref):
        if dataset in ("ocl_attribute", "ocl_affordance"):
            entry = process.pkl_data[ref[0]]
            full_path = os.path.join("Data/OCL_data/data", entry["name"])
            return Image.open(full_path).crop(entry['objects'][ref[1]]['box']).resize(TARGET_SIZE)
        if dataset == "Pangea":
            source_dirs = {
                'ActvityNet_hico_style_batch1': 'ActivityNet_hico_batch1',
                'charadesEgo_hico_style': 'charadesego_frame',
                'HAG_hico_style_new': 'hag_frame',
                'HACS_hico_style': 'hacs_frame',
                'kinetics_hico_style': 'kinetics_dataset/k700-2020/train',
            }
            entry = process.pkl_data[ref]
            full_path = os.path.join("Data/pangea/pangea", source_dirs[entry[0]], entry[1])
            return Image.open(full_path).resize(TARGET_SIZE)
        return Image.open(ref['name']).resize(TARGET_SIZE)

    first = open_candidate(candidate_image1_path)
    second = open_candidate(candidate_image2_path)
    return cand_image, first, second
class InferenceDemo:
    """Placeholder wrapper; constructing it only prints a confirmation line."""

    def __init__(self, args, dataset, exp_mode, concept_choices):
        # Arguments are accepted for interface compatibility but not used.
        print("init success")
def get_concept_choices(dataset, exp_mode):
    """Return a Gradio update that lists the selectable chain labels.

    "One concept" mode offers Chain_0..Chain_7; any other mode offers
    Chain_0..Chain_3. Every dataset uses the same counts, so *dataset* does
    not change the result.
    """
    chain_count = 8 if exp_mode == "One concept" else 4
    labels = [f"Chain_{n}" for n in range(chain_count)]
    return gr.update(choices=labels)
# Chain label index -> concept(s), per dataset and test mode. "Two concepts"
# entries join a pair with the separator that load_data_and_produce_list
# splits on ("-" for OCL/HMDB, "_" for Pangea).
_CHAIN_CONCEPTS = {
    "ocl_attribute": {
        "One concept": ["furry", "metal", "fresh", "cooked", "natural", "ripe", "painted", "rusty"],
        "Two concepts": ["furry-metal", "fresh-cooked", "natural-ripe", "painted-rusty"],
    },
    "ocl_affordance": {
        "One concept": ['break', 'carry', 'clean', 'cut', 'open', 'push', 'sit', 'write'],
        "Two concepts": ['sit-write', 'push-carry', 'cut-clean', 'open-break'],
    },
    "Pangea": {
        "One concept": ["hit-18.1", "run-51.3.2", "dress-41.1.1-1-1", "drive-11.5",
                        "cooking-45.3", "build-26.1", "shake-22.3-2", "cut-21.1-1"],
        "Two concepts": ['run-51.3.2_hit-18.1', 'drive-11.5_dress-41.1.1-1-1',
                         'cooking-45.3_build-26.1', 'shake-22.3-2_cut-21.1-1'],
    },
    "hmdb": {
        "One concept": ["brush_hair", "dive", "clap", "hug", "shake_hands", "sit", "smoke", "eat"],
        "Two concepts": ["brush_hair-dive", "clap-hug", "shake_hands-sit", "smoke-eat"],
    },
}

# Concept list shown in the "Candidate Concepts" label. Datasets without an
# entry show the chain concepts of the current mode instead.
_DISPLAY_CONCEPTS = {
    "Pangea": ["hit", "run", "dress", "drive", "cooking", "build", "shake", "cut"],
    "ocl_attribute": ["furry", "metal", "fresh", "cooked", "natural", "ripe", "painted", "rusty"],
    "ocl_affordance": ['break', 'carry', 'clean', 'cut', 'open', 'push', 'sit', 'write'],
}


def load_images_and_concepts(dataset, exp_mode, concept_choices):
    """Start a new round for the selected chain and return the first images.

    Builds ``process.idx_to_chain`` for *dataset* and *exp_mode*, loads the
    candidate pools, and resets the two-concept schedule. It then draws a
    raw positive image, a different positive candidate (when the pool has
    more than one item) and a negative candidate. The positive is placed in
    a random slot.

    Args:
        dataset: dataset dropdown value. Unknown values use the HMDB chains.
        exp_mode: "One concept"; any other value uses the paired chains.
        concept_choices: chain label ("Chain_N").

    Returns:
        ``(raw_image, candidate_image1, candidate_image2, concepts_text)``.
        If *concept_choices* is not a chain of the current mode (for example
        a label left over from the other mode), returns
        ``(None, None, None, "")``.
    """
    process.concept_choices = concept_choices
    mode_key = "One concept" if exp_mode == "One concept" else "Two concepts"
    chain_concepts = _CHAIN_CONCEPTS.get(dataset, _CHAIN_CONCEPTS["hmdb"])[mode_key]
    process.idx_to_chain = {f"Chain_{i}": concept for i, concept in enumerate(chain_concepts)}
    if concept_choices not in process.idx_to_chain:
        return None, None, None, ""

    # A new chain always starts its two-concept schedule from the beginning.
    process.schedule = 0
    load_data_and_produce_list(dataset, exp_mode, concept_choices)

    if mode_key == "One concept":
        positive_pool, positive_group = process.positive_cand, "positive"
    else:
        positive_pool, positive_group = process.positive1_cand, "positive1"
    raw = random.choice(positive_pool)
    # Avoid showing the raw image again as its own match when possible.
    others = [item for item in positive_pool if item != raw]
    positive = random.choice(others or positive_pool)
    negative = random.choice(process.negative_cand)

    process.raw_image_path = raw
    if random.random() < 0.5:
        process.candidate_image1_idx = process.candidate_image1_path = positive
        process.candidate_image2_idx = process.candidate_image2_path = negative
        process.candidate_image1_group, process.candidate_image2_group = positive_group, "negative"
        process.gt_image = "Image1"
    else:
        process.candidate_image1_idx = process.candidate_image1_path = negative
        process.candidate_image2_idx = process.candidate_image2_path = positive
        process.candidate_image1_group, process.candidate_image2_group = "negative", positive_group
        process.gt_image = "Image2"
    # Previously read the misspelled `candidata_image1_idx`, which left this
    # at (0, 0) whenever Image1 was the answer.
    process.gt_image_idx = positive

    raw_image, candidate_image1, candidate_image2 = load_images(
        dataset, raw, process.candidate_image1_path, process.candidate_image2_path)
    display = _DISPLAY_CONCEPTS.get(dataset, chain_concepts)
    return raw_image, candidate_image1, candidate_image2, str(display)
def _pick_excluding(pool, excluded):
    """Return a random element of *pool* that is not equal to *excluded*.

    If every element equals *excluded* (for example a one-image pool), fall
    back to the whole pool instead of calling ``random.choice`` on an empty
    list. An empty *pool* still raises ``IndexError``.
    """
    options = [item for item in pool if item != excluded]
    return random.choice(options or pool)


def _assign_candidates(positive, positive_group, gt_slot):
    """Put *positive* in *gt_slot* and a random negative in the other slot.

    Updates every candidate field on ``process`` together (path, idx and
    group), plus ``gt_image`` and ``gt_image_idx``. This keeps the state that
    ``filter_images`` and the next round read in line with what is shown.
    """
    negative = random.choice(process.negative_cand)
    if gt_slot == "Image1":
        first, second = positive, negative
        groups = (positive_group, "negative")
    else:
        first, second = negative, positive
        groups = ("negative", positive_group)
    process.candidate_image1_idx = process.candidate_image1_path = first
    process.candidate_image2_idx = process.candidate_image2_path = second
    process.candidate_image1_group, process.candidate_image2_group = groups
    process.gt_image = gt_slot
    process.gt_image_idx = positive


def _next_two_concept_positive(previous_gt_idx):
    """Choose the next positive for "Two concepts" mode and advance the schedule.

    ``process.schedule`` cycles from 0 to 7:
    - Steps 0-2 draw from ``positive1_cand``.
    - Steps 4-6 draw from ``positive2_cand``.
    - Steps 3 and 7 draw from ``positive_common_cand`` when it is non-empty.
      Otherwise they switch the raw image to a fresh draw from the next
      concept's pool (positive2 at step 3, positive1 at step 7).
    - After step 7 the schedule wraps to 0.

    Returns:
        ``(positive, group, new_raw)``. ``new_raw`` is None when the previous
        ground-truth image stays on screen as the raw image.
    """
    step = process.schedule
    new_raw = None
    if step in (3, 7) and process.positive_common_cand:
        positive = _pick_excluding(process.positive_common_cand, previous_gt_idx)
        group = "positive_com"
    elif step in (3, 7):
        if step == 3:
            pool, group = process.positive2_cand, "positive2"
        else:
            pool, group = process.positive1_cand, "positive1"
        new_raw = random.choice(pool)
        # The raw image is new here, so it (not the old ground truth) is
        # the one that must not reappear as the positive candidate.
        positive = _pick_excluding(pool, new_raw)
    elif step < 3:
        positive = _pick_excluding(process.positive1_cand, previous_gt_idx)
        group = "positive1"
    else:
        positive = _pick_excluding(process.positive2_cand, previous_gt_idx)
        group = "positive2"
    process.schedule = 0 if step >= 7 else step + 1
    return positive, group, new_raw


def count_and_reload_images(dataset, exp_mode, select_input, show_result, steps, raw_image, candidate_image1, candidate_image2):
    """Score the user's choice and, unless the round is over, show the next pair.

    Args:
        dataset: dataset dropdown value.
        exp_mode: "One concept" or the two-concept mode.
        select_input: "Image1", "Image2", "Uncertain", or None. None is what
            this callback writes back to reset the radio; it is a no-op.
        show_result: current result label value.
        steps: current step count (int-convertible).
        raw_image, candidate_image1, candidate_image2: images on screen. The
            one in the ground-truth slot becomes the next raw image.

    A correct answer, "Uncertain", or any answer before step 6 advances to a
    new pair. "Uncertain" also removes the negative candidate via
    ``filter_images`` and does not count as a step. A wrong answer at step 6
    or later shows "Error, Please reset!" and clears ``process.gt_image``.

    Returns:
        ``(select_input, show_result, steps, raw_image, candidate_image1,
        candidate_image2)`` for the Gradio outputs.
    """
    if select_input is None:
        return select_input, show_result, steps, raw_image, candidate_image1, candidate_image2

    answered_correctly = select_input == process.gt_image
    uncertain = select_input == 'Uncertain'
    if not (answered_correctly or uncertain or int(steps) < 6):
        show_result = "Error, Please reset!"
        process.gt_image = None
        return select_input, show_result, steps, raw_image, candidate_image1, candidate_image2

    if uncertain:
        negative_sample = 'Image2' if process.gt_image == 'Image1' else 'Image1'
        filter_images(dataset, exp_mode, process.concept_choices, negative_sample)

    if answered_correctly:
        show_result = "Success!"
    elif uncertain:
        show_result = 'Skip'
    else:
        show_result = "Error!"

    # The previous ground-truth candidate becomes the next raw image.
    candidate_image = candidate_image1 if process.gt_image == "Image1" else candidate_image2
    previous_gt_idx = process.gt_image_idx

    if exp_mode == "One concept":
        positive = _pick_excluding(process.positive_cand, previous_gt_idx)
        group, new_raw = "positive", None
    else:
        positive, group, new_raw = _next_two_concept_positive(previous_gt_idx)

    gt_slot = "Image1" if random.random() < 0.5 else "Image2"
    _assign_candidates(positive, group, gt_slot)

    if new_raw is None:
        raw_image, candidate_image1, candidate_image2 = load_candidate_images(
            dataset, candidate_image, process.candidate_image1_path, process.candidate_image2_path)
    else:
        process.raw_image_path = new_raw
        raw_image, candidate_image1, candidate_image2 = load_images(
            dataset, new_raw, process.candidate_image1_path, process.candidate_image2_path)

    if not uncertain:
        steps = int(steps) + 1
    # Reset the radio so the same option can be picked again next round.
    select_input = None
    return select_input, show_result, steps, raw_image, candidate_image1, candidate_image2
# Selection pickle edited by filter_images, per dataset:
# (pickle path, key of the single-concept lists, separator between the two
#  concepts of a "Two concepts" chain; must match load_data_and_produce_list).
_FILTER_SOURCES = {
    "ocl_attribute": ("Data/OCL_data/OCL_selected_test_attribute_refined.pkl", "selected_individual_pkl", "-"),
    "ocl_affordance": ("Data/OCL_data/OCL_selected_test_affordance_refined.pkl", "selected_individual_pkl", "-"),
    "Pangea": ("Data/pangea/pangea_test_refined.pkl", "selected_pkl", "_"),
}


def filter_images(dataset, exp_mode, concept_choices, image_filtered):
    """Remove a displayed candidate from its dataset's selection pickle on disk.

    Args:
        dataset: only the OCL datasets and "Pangea" can be filtered; anything
            else prints "Error".
        exp_mode: "One concept" or the two-concept mode; selects which lists
            in the pickle can hold the candidate.
        concept_choices: chain label resolved through ``process.idx_to_chain``.
        image_filtered: "Image1", "Image2", or None (no-op).

    The list to remove from is chosen by the group recorded on ``process``
    when the candidate was drawn. An unknown group (or chain) prints "Error".
    The pickle is rewritten only when something was removed. The write goes
    to a temporary file first, then ``os.replace``, so an interrupted write
    cannot truncate the pickle. The in-memory pools on ``process`` are not
    changed; the removal takes effect the next time the pickle is loaded.

    Returns:
        None (the Gradio output resets the radio).
    """
    if image_filtered is None:
        return None
    if dataset not in _FILTER_SOURCES:
        print("Error")
        return None

    pkl_path, single_key, separator = _FILTER_SOURCES[dataset]
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)  # trusted local file

    if image_filtered == "Image1":
        idx, group = process.candidate_image1_idx, process.candidate_image1_group
    else:
        idx, group = process.candidate_image2_idx, process.candidate_image2_group
    print(idx)

    chain = process.idx_to_chain.get(concept_choices)
    pool = None
    if group == "negative":
        pool = data["negative_pkl"]
    elif chain is None:
        pass
    elif exp_mode == "One concept":
        if group == "positive":
            pool = data[single_key][chain]
    else:
        # Previously OCL split on "_" and Pangea on "-" (swapped), which
        # produced wrong keys / IndexError for every paired chain.
        first, _, second = chain.partition(separator)
        paired_keys = {"positive1": first, "positive2": second, "positive_com": chain}
        if group in paired_keys:
            pool = data["selected_paired_pkl"][chain][paired_keys[group]]

    if pool is None:
        print('Error')
        return None
    if idx in pool:
        pool.remove(idx)
        tmp_path = pkl_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(data, f)
        os.replace(tmp_path, pkl_path)
    return None
with gr.Blocks() as demo:
    # Page header: title plus links to the paper and the code.
    title_markdown = ("""
# MLLM Association
[[Paper]](https://mvig-rhos.com) [[Code]](https://github.com/lihong2303/MLLMs_Association)
""")
    gr.Markdown(title_markdown)

    # Raw image followed by the two candidates to compare it against.
    with gr.Row():
        with gr.Column():
            raw_image = gr.Image(label="Raw Image", interactive=False)
        with gr.Column():
            candidate_image1 = gr.Image(label="Candidate Image 1", interactive=False)
        with gr.Column():
            candidate_image2 = gr.Image(label="Candidate Image 2", interactive=False)
    with gr.Row():
        candidate_concepts = gr.Label(value="", label="Candidate Concepts")
        filter_Images = gr.Radio(choices=["Image1", "Image2"], label="Filter low quality image")
    # Dataset, test mode and chain selection.
    with gr.Row():
        dataset = gr.Dropdown(choices=["ocl_attribute", "ocl_affordance", "hmdb", "Pangea"],
                              label="Select a dataset", interactive=True)
        exp_mode = gr.Dropdown(choices=["One concept", "Two concepts"],
                               label="Select a test mode", interactive=True)
        concept_choices = gr.Dropdown(choices=[], label="Select the chain", interactive=True)
    # User answer, step counter and outcome of the last answer.
    with gr.Row():
        select_input = gr.Radio(choices=["Image1", "Image2", "Uncertain"], label="Select candidate image")
        steps = gr.Label(value="0", label="Steps")
        show_result = gr.Label(value="", label="Selected Result")

    exp_mode.change(fn=get_concept_choices, inputs=[dataset, exp_mode], outputs=concept_choices)
    concept_choices.change(fn=load_images_and_concepts,
                           inputs=[dataset, exp_mode, concept_choices],
                           outputs=[raw_image, candidate_image1, candidate_image2, candidate_concepts])
    filter_Images.change(fn=filter_images,
                         inputs=[dataset, exp_mode, concept_choices, filter_Images],
                         outputs=[filter_Images])
    select_input.change(fn=count_and_reload_images,
                        inputs=[dataset, exp_mode, select_input, show_result, steps,
                                raw_image, candidate_image1, candidate_image2],
                        outputs=[select_input, show_result, steps,
                                 raw_image, candidate_image1, candidate_image2])
demo.queue()
if __name__ == "__main__":
    # Serve the app with Gradio's default launch settings.
    demo.launch()