|
import json
|
|
import random
|
|
|
|
class FeatureCombiner:
    """Assemble paired text captions from a JSON file of per-image descriptions.

    The JSON file must contain an iterable of dict entries. Each entry is
    classified by its ``"image"`` path: paths containing ``"annotations"`` are
    treated as segmentation entries and paths containing ``"images"`` as image
    entries. A segmentation entry is paired with the image entry whose path
    contains the same file stem (with ``"gt"`` replaced by ``"img"``).

    Attributes:
        data: Every entry loaded from the JSON file (never dataset-filtered).
        seg_data: Segmentation entries, restricted to ``dataset`` when given.
        img_data: Image entries, restricted to ``dataset`` when given.
        dataset_field: The ``dataset`` argument passed to the constructor.
        feature_name: Entry keys sampled by :meth:`combine`.
    """

    # Punctuation stripped from the end of every fragment before joining.
    _TRAILING_PUNCTUATION = ";.!?"

    def __init__(self, json_file, dataset=None) -> None:
        """Load the entries and split them into segmentation / image lists.

        Args:
            json_file: Path to the JSON file holding the entries.
            dataset: Optional dataset name. When given, only entries for which
                ``dataset in entry["dataset"]`` holds are kept in ``seg_data``
                and ``img_data``. When ``None`` no dataset filtering is done.
        """
        # JSON text is UTF-8 by specification; do not depend on the locale.
        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        self.data = data
        self.dataset_field = dataset

        seg_data = [entry for entry in data if "annotations" in entry["image"]]
        img_data = [entry for entry in data if "images" in entry["image"]]

        # Only filter when a dataset was requested: ``None in "some string"``
        # raises TypeError, which made the default argument unusable.
        if dataset is not None:
            img_data = [entry for entry in img_data if dataset in entry["dataset"]]
            seg_data = [entry for entry in seg_data if dataset in entry["dataset"]]

        self.img_data = img_data
        self.seg_data = seg_data

        if dataset is None:
            self.feature_name = ["dataset", "location", "size", "trend_shape"]
        else:
            self.feature_name = ["location", "size", "trend_shape"]

    def combine(self, is_seg_only=True, debug=False):
        """Return a ``(caption_seg, caption_img)`` pair built from random entries.

        Each caption starts with the ``"overview"`` of one randomly chosen
        segmentation entry (respectively its paired image entry), followed by
        the dataset name. Unless ``debug`` is true, one further fragment per
        key in ``feature_name`` is appended, each taken from an independently
        chosen random entry. Fragments are joined with ``"; "``.

        Args:
            is_seg_only: When true, feature fragments are sampled from
                ``seg_data``; otherwise from ``data``.
                NOTE(review): ``data`` is not dataset-filtered, so sampled
                entries may have no counterpart in ``img_data`` and trigger
                ``LookupError`` - confirm this is intended.
            debug: When true, only the overview and dataset fragments are used.

        Returns:
            Tuple ``(caption_seg, caption_img)`` of strings.

        Raises:
            NotImplementedError: If the combiner was created without a dataset.
            LookupError: If a sampled entry has no matching image entry.
            IndexError: If the list being sampled from is empty.
        """
        if self.dataset_field is None:
            raise NotImplementedError(
                "combine() is not implemented for dataset=None; "
                "pass a dataset name to the constructor"
            )

        seg_overview_entry = random.choice(self.seg_data)
        overview_seg = seg_overview_entry["overview"]
        overview_img = self._find_img_entry_by_seg_entry(seg_overview_entry)["overview"]

        features_seg = [overview_seg, self.dataset_field]
        features_img = [overview_img, self.dataset_field]

        if not debug:
            data = self.seg_data if is_seg_only else self.data
            # A fresh random entry is drawn for every feature on purpose, so
            # one caption mixes descriptions coming from different samples.
            for feature in self.feature_name:
                entry_seg = random.choice(data)
                entry_img = self._find_img_entry_by_seg_entry(entry_seg)
                features_seg.append(entry_seg[feature])
                features_img.append(entry_img[feature])

        return self._build_caption(features_seg), self._build_caption(features_img)

    @classmethod
    def _build_caption(cls, sentences):
        """Join text fragments into one ``"; "``-separated caption.

        Surrounding whitespace and trailing ``;.!?`` are removed from every
        fragment. Whitespace is stripped first so that a fragment such as
        ``"text. "`` loses its full stop as well.
        """
        cleaned = [
            sentence.strip().rstrip(cls._TRAILING_PUNCTUATION).rstrip()
            for sentence in sentences
        ]
        # The final rstrip drops separators left behind by empty trailing
        # fragments, matching the behaviour of the original loop.
        return "; ".join(cleaned).rstrip("; ")

    def _find_img_entry_by_seg_entry(self, seg_entry):
        """Return the entry of ``img_data`` that belongs to ``seg_entry``.

        The file stem of ``seg_entry["image"]`` (text before the first ``.`` of
        the last ``/``-separated component) has every ``"gt"`` replaced by
        ``"img"``; the first image entry whose path contains that stem followed
        by ``.`` is returned.

        Raises:
            LookupError: If no entry in ``img_data`` matches.
        """
        seg_path = seg_entry["image"]
        file_name = seg_path.split("/")[-1].split(".")[0]
        file_name = file_name.replace("gt", "img")

        needle = file_name + "."
        for entry in self.img_data:
            if needle in entry["image"]:
                return entry

        raise LookupError(
            "No image entry found for segmentation entry {!r}".format(seg_path)
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: load the crack annotations for the Crack500 dataset
    # and print 9999 randomly combined (segmentation, image) caption pairs.
    crack_json = r"/data/leiqin/dataset_curvilinear/dataset_curvilinear/crack/crack.json"
    combiner = FeatureCombiner(crack_json, "Crack500 dataset")

    remaining = 9999
    while remaining > 0:
        seg_caption, img_caption = combiner.combine()
        print(seg_caption, img_caption)
        remaining -= 1
|
|
|