zhouyik's picture
Upload folder using huggingface_hub
032e687 verified
import json
import random
def RegionCaptionDataset_load_fn(data_path, repeat_time):
    """Load a region-caption JSON file and resample its records.

    The file must contain a JSON list of dict records, each with at least
    ``file_name``, ``description`` and ``annotation`` keys. Every record is
    mutated in place to gain an ``image`` key copied from ``file_name``
    (this happens before filtering, so dropped records are mutated too).
    Records whose ``description`` and ``annotation`` lengths differ are
    dropped with a printed warning.

    Args:
        data_path: Path to the JSON annotation file.
        repeat_time: Resampling factor.
            ``< 1``: keep a random subset of ``int(len * repeat_time)``
            records.
            ``> 1``: repeat the list ``int(repeat_time)`` times, then append
            a random subset of ``int(len * fractional_part)`` records.
            ``== 1``: keep the list unchanged.

    Returns:
        Tuple ``(ret, hf_ret)``. ``ret`` is the resampled list of record
        dicts. ``hf_ret`` is a parallel list of
        ``{"image": file_name, "description": description}`` dicts.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(data_path, 'r', encoding='utf-8') as f:
        json_file = json.load(f)

    ret = []
    for item in json_file:
        item.update({'image': item['file_name']})
        if len(item["description"]) != len(item["annotation"]):
            print("The number of description is not equal to seg !!!")
        else:
            ret.append(item)

    if repeat_time < 1:
        ret = random.sample(ret, int(len(ret) * repeat_time))
    elif repeat_time > 1:
        int_repeat_time = int(repeat_time)
        remaining_repeat_time = repeat_time - int_repeat_time
        # The fractional extra is drawn from the un-repeated list, so no
        # record appears twice within that extra portion.
        remaining_ret = (
            random.sample(ret, int(len(ret) * remaining_repeat_time))
            if remaining_repeat_time > 0
            else []
        )
        ret = ret * int_repeat_time + remaining_ret

    hf_ret = [
        {"image": item["file_name"], "description": item["description"]}
        for item in ret
    ]
    return ret, hf_ret
def RegionConversationDataset_load_fn(data_path, repeat_time):
    """Load a region-conversation JSON file and resample its records.

    The file must contain a JSON list of dict records, each with at least
    ``file_name`` and ``conversations`` keys. Records that lack an
    ``annotation`` key, or whose ``annotation`` is empty, are dropped with a
    printed warning. Kept records are mutated in place to gain an ``image``
    key copied from ``file_name``.

    Args:
        data_path: Path to the JSON annotation file.
        repeat_time: Resampling factor.
            ``< 1``: keep a random subset of ``int(len * repeat_time)``
            records.
            ``> 1``: repeat the list ``int(repeat_time)`` times, then append
            a random subset of ``int(len * fractional_part)`` records.
            ``== 1``: keep the list unchanged.

    Returns:
        Tuple ``(ret, hf_ret)``. ``ret`` is the resampled list of record
        dicts. ``hf_ret`` is a parallel list of
        ``{'image', 'conversations', 'num_regions'}`` dicts, where
        ``num_regions`` is ``len(record['annotation'])``.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(data_path, 'r', encoding='utf-8') as f:
        json_file = json.load(f)

    ret = []
    for dataset_info in json_file:
        if 'annotation' not in dataset_info or len(dataset_info['annotation']) == 0:
            print("The annotation is not valid, filter out!!!")
            continue
        dataset_info.update({'image': dataset_info['file_name']})
        ret.append(dataset_info)

    if repeat_time < 1:
        ret = random.sample(ret, int(len(ret) * repeat_time))
    elif repeat_time > 1:
        int_repeat_time = int(repeat_time)
        remaining_repeat_time = repeat_time - int_repeat_time
        # The fractional extra is drawn from the un-repeated list, so no
        # record appears twice within that extra portion.
        remaining_ret = (
            random.sample(ret, int(len(ret) * remaining_repeat_time))
            if remaining_repeat_time > 0
            else []
        )
        ret = ret * int_repeat_time + remaining_ret

    hf_ret = [
        {
            'image': dataset_info["file_name"],
            'conversations': dataset_info["conversations"],
            'num_regions': len(dataset_info['annotation']),
        }
        for dataset_info in ret
    ]
    return ret, hf_ret