tanlei0
"v1"
8222fd4
raw
history blame contribute delete
No virus
4.26 kB
import json
import random
class FeatureCombiner():
    """Combine per-image feature sentences from a JSON annotation file into
    caption strings for segmentation/image pairs.

    The JSON file is expected to contain a list of entries, each with at
    least an ``"image"`` path, a ``"dataset"`` string, an ``"overview"``
    sentence, and the feature sentences named in ``self.feature_name``
    (``"location"``, ``"size"``, ``"trend_shape"``).
    """

    def __init__(self, json_file, dataset=None) -> None:
        """Load entries from *json_file*, split them into segmentation and
        image entries, and optionally restrict both to *dataset*.

        :param json_file: path to the JSON file holding the entry list.
        :param dataset: dataset name used both as a substring filter on each
            entry's ``"dataset"`` field and as a fixed caption component.
            When ``None``, no filtering is applied and ``"dataset"`` is
            sampled like any other feature.
        """
        with open(json_file, "r") as f:
            data = json.load(f)
        self.data = data
        # Segmentation masks live under an "annotations" directory,
        # raw images under an "images" directory (path-substring check).
        self.seg_data = [entry for entry in self.data if "annotations" in entry["image"]]
        self.img_data = [entry for entry in self.data if "images" in entry["image"]]
        # BUGFIX: the original filtered unconditionally, raising
        # ``TypeError: 'in <string>' requires string as left operand`` when
        # dataset is None, even though the None case is handled below.
        if dataset is not None:
            self.img_data = [entry for entry in self.img_data if dataset in entry["dataset"]]
            self.seg_data = [entry for entry in self.seg_data if dataset in entry["dataset"]]
        self.dataset_field = dataset
        if self.dataset_field is None:
            # No fixed dataset: sample the "dataset" sentence like a feature.
            self.feature_name = ["dataset", "location", "size", "trend_shape"]
        else:
            self.feature_name = ["location", "size", "trend_shape"]

    def combine(self, is_seg_only=True, debug=False):
        """Build one ``(caption_seg, caption_img)`` pair.

        A random segmentation entry supplies the overview sentence; its
        matching image entry supplies the image overview.  Unless *debug*
        is True, one random entry per feature name supplies the remaining
        sentences (so different features may come from different entries).

        :param is_seg_only: sample feature sentences from seg entries only;
            otherwise sample from all loaded entries.
        :param debug: when True, skip per-feature sampling entirely and
            return captions built from overview + dataset only.
        :returns: tuple of two ``"; "``-joined caption strings.
        :raises NotImplementedError: when no dataset filter was given.
        """
        features_seg = []
        features_img = []
        # Pick the overview pair; every entry is assumed to carry an
        # "overview" key.
        seg_overview_entry = random.choice(self.seg_data)
        overview_seg = seg_overview_entry["overview"]
        overview_img = self._find_img_entry_by_seg_entry(seg_overview_entry)["overview"]
        features_seg.append(overview_seg)
        features_img.append(overview_img)
        if self.dataset_field is not None:
            features_seg.append(self.dataset_field)
            features_img.append(self.dataset_field)
        else:
            # BUGFIX: the original raised a bare string, which is itself a
            # TypeError in Python 3; raise a proper exception instead.
            raise NotImplementedError("combining without a dataset filter is not implemented")
        if not debug:
            # Choose the sampling pool according to is_seg_only.
            data = self.seg_data if is_seg_only else self.data
            for feature in self.feature_name:
                # Sample each feature independently; entries are assumed to
                # carry every key listed in self.feature_name.
                entry_seg = random.choice(data)
                entry_img = self._find_img_entry_by_seg_entry(entry_seg)
                features_seg.append(entry_seg[feature])
                features_img.append(entry_img[feature])
        caption_seg = self._build_caption(features_seg)
        caption_img = self._build_caption(features_img)
        return caption_seg, caption_img

    @staticmethod
    def _build_caption(sentences):
        """Join *sentences* into one caption: strip trailing punctuation and
        whitespace from each sentence, separate them with "; ", and drop the
        trailing separator.  (Extracted from two duplicated loops.)
        """
        caption = ""
        for sentence in sentences:
            sentence = sentence.rstrip(';.!?')  # drop trailing punctuation
            sentence = sentence.strip()         # drop surrounding whitespace
            caption += sentence + "; "
        return caption.rstrip('; ')

    def _find_img_entry_by_seg_entry(self, seg_entry):
        """Return the image entry matching *seg_entry* by file stem.

        A seg path like ``.../x_gt.png`` matches the image entry whose path
        contains ``"x_img."``; paths without ``"gt"`` match on their own stem.

        :raises LookupError: when no image entry matches.
        """
        seg_path = seg_entry["image"]
        file_name = seg_path.split('/')[-1].split(".")[0]
        if "gt" in file_name:
            file_name = file_name.replace("gt", "img")
        for entry in self.img_data:
            if file_name + "." in entry["image"]:
                return entry
        # BUGFIX: raising a string is a TypeError in Python 3.
        raise LookupError(f"no image entry found for seg path {seg_path!r}")
if __name__ == "__main__":
json_file = r"/data/leiqin/dataset_curvilinear/dataset_curvilinear/crack/crack.json"
fc = FeatureCombiner(json_file,"Crack500 dataset")
for _ in range(9999):
caption1, caption2 = fc.combine()
print(caption1, caption2)