Anyou committed on
Commit b2b0303
1 parent: 2473491

Upload 8 files

data_script/flintstones_hdf5.py ADDED
@@ -0,0 +1,51 @@
+ import argparse
+ import json
+ import os
+ import pickle
+
+ import cv2
+ import h5py
+ import numpy as np
+ from tqdm import tqdm
+
+
+ def main(args):
+     splits = json.load(open(os.path.join(args.data_dir, 'train-val-test_split.json'), 'r'))
+     train_ids, val_ids, test_ids = splits["train"], splits["val"], splits["test"]
+     followings = pickle.load(open(os.path.join(args.data_dir, 'following_cache4.pkl'), 'rb'))
+     annotations = json.load(open(os.path.join(args.data_dir, 'flintstones_annotations_v1-0.json')))
+     descriptions = dict()
+     for sample in annotations:
+         descriptions[sample["globalID"]] = sample["description"]
+
+     f = h5py.File(args.save_path, "w")
+     for subset, ids in {'train': train_ids, 'val': val_ids, 'test': test_ids}.items():
+         ids = [i for i in ids if i in followings and len(followings[i]) == 4]
+         length = len(ids)
+
+         group = f.create_group(subset)
+         images = list()
+         for i in range(5):
+             images.append(
+                 group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
+         text = group.create_dataset('text', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
+         for i, item in enumerate(tqdm(ids, leave=True, desc="saveh5")):
+             globalIDs = [item] + followings[item]
+             txt = list()
+             for j, globalID in enumerate(globalIDs):
+                 img = np.load(os.path.join(args.data_dir, 'video_frames_sampled', '{}.npy'.format(globalID)))
+                 img = np.concatenate(img, axis=0).astype(np.uint8)
+                 img = cv2.imencode('.png', img)[1].tobytes()
+                 img = np.frombuffer(img, np.uint8)
+                 images[j][i] = img
+                 txt.append(descriptions[globalID])
+             text[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in txt])
+     f.close()
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='arguments for flintstones hdf5 file saving')
+     parser.add_argument('--data_dir', type=str, required=True, help='flintstones data directory')
+     parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
+     args = parser.parse_args()
+     main(args)
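
For reference, a minimal read-back sketch (not part of the commit) of the layout this script writes, assuming a file it produced; the path 'flintstones.h5' and index 0 are placeholders. Each stored element is a PNG-encoded, vertically stacked set of 128-px video frames, and the five captions are joined with '|'.

import cv2
import h5py

with h5py.File('flintstones.h5', 'r') as f:
    group = f['train']
    frames = []
    for i in range(5):
        buf = group['image{}'.format(i)][0]                  # PNG bytes as a uint8 array
        frames.append(cv2.imdecode(buf, cv2.IMREAD_COLOR))   # tall image: stacked 128-px frames
    captions = group['text'][0].decode('utf-8').split('|')   # five captions for the story
    print([im.shape for im in frames], captions[0])
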
data_script/pororo_hdf5.py ADDED
@@ -0,0 +1,83 @@
+ import argparse
+ import os
+
+ import cv2
+ import h5py
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+
+
+ def main(args):
+     # descriptions.npy holds a pickled Python dict, so load it with
+     # allow_pickle=True and latin1 encoding and convert it back with .item().
+     descriptions = np.load(os.path.join(args.data_dir, 'descriptions.npy'), allow_pickle=True, encoding='latin1').item()
+     # imgs_list holds the path of each frame; followings_list holds the paths
+     # of the four frames that follow it.
+     imgs_list = np.load(os.path.join(args.data_dir, 'img_cache4.npy'), encoding='latin1')
+     followings_list = np.load(os.path.join(args.data_dir, 'following_cache4.npy'))
+     # train_seen_unseen_ids.npy contains three arrays with the IDs of the
+     # train, validation and test splits.
+     train_ids, val_ids, test_ids = np.load(os.path.join(args.data_dir, 'train_seen_unseen_ids.npy'), allow_pickle=True)
+     train_ids = np.sort(train_ids)
+     val_ids = np.sort(val_ids)
+     test_ids = np.sort(test_ids)
+
+     # Create the output HDF5 file; all processed images and texts are written into it.
+     f = h5py.File(args.save_path, "w")
+     for subset, ids in {'train': train_ids, 'val': val_ids, 'test': test_ids}.items():
+         length = len(ids)
+
+         # One group per split. Each group holds five image datasets
+         # ('image0'..'image4': the current frame and its four followers) plus a
+         # 'text' dataset, so a whole story can be read back from one place.
+         group = f.create_group(subset)
+         images = list()
+         # Each image dataset stores variable-length uint8 arrays (PNG bytes).
+         for i in range(5):
+             images.append(
+                 group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
+         # The 'text' dataset stores the utf-8 captions of the five frames.
+         text = group.create_dataset('text', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
+         # Iterate over the stories of this split and write them to the file.
+         for i, item in enumerate(tqdm(ids, leave=True, desc="saveh5")):
+             # Paths of the current frame and its four following frames
+             # (stored as byte strings, hence the [2:-1] slicing).
+             img_paths = [str(imgs_list[item])[2:-1]] + [str(followings_list[item][i])[2:-1] for i in range(4)]
+             # Open every frame and convert it to an RGB PIL image.
+             imgs = [Image.open(os.path.join(args.data_dir, img_path)).convert('RGB') for img_path in img_paths]
+             for j, img in enumerate(imgs):
+                 # PIL image -> uint8 array -> PNG-encoded bytes -> uint8 buffer,
+                 # stored in the j-th image dataset at position i.
+                 img = np.array(img).astype(np.uint8)
+                 img = cv2.imencode('.png', img)[1].tobytes()
+                 img = np.frombuffer(img, np.uint8)
+                 images[j][i] = img
+             # Look up the caption of every frame by its file name (without '.png').
+             tgt_img_ids = [str(img_path).replace('.png', '') for img_path in img_paths]
+             txt = [descriptions[tgt_img_id][0] for tgt_img_id in tgt_img_ids]
+             # Join the five captions with '|' after stripping newlines and tabs.
+             text[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in txt])
+     f.close()
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='arguments for pororo hdf5 file saving')
+     parser.add_argument('--data_dir', type=str, required=True, help='pororo data directory')
+     parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
+     args = parser.parse_args()
+     main(args)
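
A small inspection sketch (not part of the commit) for sanity-checking the file this script produces; 'pororo.h5' is a placeholder path. Each split group should contain image0..image4 plus text, all with the same length.

import h5py

with h5py.File('pororo.h5', 'r') as f:
    for subset in ('train', 'val', 'test'):
        group = f[subset]
        shapes = {name: group[name].shape for name in group}
        print(subset, shapes)   # e.g. {'image0': (N,), ..., 'image4': (N,), 'text': (N,)}
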
data_script/vist_hdf5.py ADDED
@@ -0,0 +1,111 @@
+ import argparse
+ import json
+ import os
+
+ import cv2
+ import h5py
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+
+
+ def main(args):
+     train_data = json.load(open(os.path.join(args.sis_json_dir, 'train.story-in-sequence.json')))
+     val_data = json.load(open(os.path.join(args.sis_json_dir, 'val.story-in-sequence.json')))
+     test_data = json.load(open(os.path.join(args.sis_json_dir, 'test.story-in-sequence.json')))
+
+     prefix = ["train", "val", "test"]
+     whole_album = {}
+     for i, data in enumerate([train_data, val_data, test_data]):
+         album_mapping = {}
+         for annot_new in data["annotations"]:
+             assert len(annot_new) == 1
+             annot = annot_new[0]
+             if annot['story_id'] not in album_mapping:
+                 album_mapping[annot['story_id']] = {"flickr_id": [annot['photo_flickr_id']],
+                                                     "sis": [annot['original_text']],
+                                                     "length": 1}
+             else:
+                 album_mapping[annot['story_id']]["flickr_id"].append(annot['photo_flickr_id'])
+                 album_mapping[annot['story_id']]["sis"].append(annot['original_text'])
+                 album_mapping[annot['story_id']]["length"] += 1
+         whole_album[prefix[i]] = album_mapping
+
+     # drop stories that do not have exactly five frames, or whose images are missing on disk
+     for p in prefix:
+         deletables = []
+         for story_id, story in whole_album[p].items():
+             if story['length'] != 5:
+                 print("deleting {}".format(story_id))
+                 deletables.append(story_id)
+                 continue
+             d = [os.path.exists(os.path.join(args.img_dir, "{}.jpg".format(_))) for _ in story["flickr_id"]]
+             if sum(d) < 5:
+                 print("deleting {}".format(story_id))
+                 deletables.append(story_id)
+         for i in deletables:
+             del whole_album[p][i]
+
+     train_data = json.load(open(os.path.join(args.dii_json_dir, 'train.description-in-isolation.json')))
+     val_data = json.load(open(os.path.join(args.dii_json_dir, 'val.description-in-isolation.json')))
+     test_data = json.load(open(os.path.join(args.dii_json_dir, 'test.description-in-isolation.json')))
+
+     flickr_id2text = {}
+     for i, data in enumerate([train_data, val_data, test_data]):
+         for l in data['annotations']:
+             assert len(l) == 1
+             if l[0]['photo_flickr_id'] in flickr_id2text:
+                 # keep the longer of the two descriptions for the same photo
+                 flickr_id2text[l[0]['photo_flickr_id']] = \
+                     max([flickr_id2text[l[0]['photo_flickr_id']], l[0]['original_text']], key=len)
+             else:
+                 flickr_id2text[l[0]['photo_flickr_id']] = l[0]['original_text']
+
+     for p in prefix:
+         deletables = []
+         for story_id, story in whole_album[p].items():
+             story['dii'] = []
+             for i, flickr_id in enumerate(story['flickr_id']):
+                 if flickr_id not in flickr_id2text:
+                     print("{} not found in story {}".format(flickr_id, story_id))
+                     deletables.append(story_id)
+                     break
+                 story['dii'].append(flickr_id2text[flickr_id])
+         for i in deletables:
+             del whole_album[p][i]
+
+     f = h5py.File(args.save_path, "w")
+     for p in prefix:
+         group = f.create_group(p)
+         story_dict = whole_album[p]
+         length = len(story_dict)
+         images = list()
+         for i in range(5):
+             images.append(
+                 group.create_dataset('image{}'.format(i), (length,), dtype=h5py.vlen_dtype(np.dtype('uint8'))))
+         sis = group.create_dataset('sis', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
+         dii = group.create_dataset('dii', (length,), dtype=h5py.string_dtype(encoding='utf-8'))
+         for i, (story_id, story) in enumerate(tqdm(story_dict.items(), leave=True, desc="saveh5")):
+             imgs = [Image.open('{}/{}.jpg'.format(args.img_dir, flickr_id)).convert('RGB') for flickr_id in
+                     story['flickr_id']]
+             for j, img in enumerate(imgs):
+                 img = np.array(img).astype(np.uint8)
+                 img = cv2.imencode('.png', img)[1].tobytes()
+                 img = np.frombuffer(img, np.uint8)
+                 images[j][i] = img
+             sis[i] = '|'.join([t.replace('\n', '').replace('\t', '').strip() for t in story['sis']])
+             txt_dii = [t.replace('\n', '').replace('\t', '').strip() for t in story['dii']]
+             txt_dii = sorted(set(txt_dii), key=txt_dii.index)
+             dii[i] = '|'.join(txt_dii)
+     f.close()
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='arguments for vist hdf5 file saving')
+     parser.add_argument('--sis_json_dir', type=str, required=True, help='sis json file directory')
+     parser.add_argument('--dii_json_dir', type=str, required=True, help='dii json file directory')
+     parser.add_argument('--img_dir', type=str, required=True, help='image file directory')
+     parser.add_argument('--save_path', type=str, required=True, help='path to save hdf5')
+     args = parser.parse_args()
+     main(args)
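
To illustrate the story-grouping step above, here is a toy example (not part of the commit) run on a hand-made annotations list in the same shape as the SIS JSON; the IDs and sentences are invented, and the grouping is written with setdefault for brevity.

toy = {"annotations": [
    [{"story_id": "s1", "photo_flickr_id": "100", "original_text": "first sentence"}],
    [{"story_id": "s1", "photo_flickr_id": "101", "original_text": "second sentence"}],
]}
album_mapping = {}
for annot_new in toy["annotations"]:
    annot = annot_new[0]
    entry = album_mapping.setdefault(annot['story_id'], {"flickr_id": [], "sis": [], "length": 0})
    entry["flickr_id"].append(annot['photo_flickr_id'])
    entry["sis"].append(annot['original_text'])
    entry["length"] += 1
print(album_mapping)  # {'s1': {'flickr_id': ['100', '101'], 'sis': [...], 'length': 2}}
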
data_script/vist_img_download.py ADDED
@@ -0,0 +1,61 @@
+ import argparse
+ import json
+ import os
+ from io import BytesIO
+ from multiprocessing import Process
+
+ import requests
+ from PIL import Image
+ from tqdm import tqdm
+
+
+ def download_subprocess(dii, save_dir):
+     for image in tqdm(dii):
+         key, value = image.popitem()
+         try:
+             img_data = requests.get(value).content
+             img = Image.open(BytesIO(img_data)).convert('RGB')
+             w, h = img.size
+             if min(w, h) > 512:
+                 # downscale so that the shorter side becomes 512 pixels
+                 img = img.resize((int(w / (h / 512)), 512) if w > h else (512, int(h / (w / 512))))
+             img.save('{}/{}.jpg'.format(save_dir, key))
+         except Exception:
+             print(key, value)
+
+
+ def main(args):
+     train_data = json.load(open(os.path.join(args.json_dir, 'train.description-in-isolation.json')))
+     val_data = json.load(open(os.path.join(args.json_dir, 'val.description-in-isolation.json')))
+     test_data = json.load(open(os.path.join(args.json_dir, 'test.description-in-isolation.json')))
+     dii = []
+     for subset in [train_data, val_data, test_data]:
+         for image in subset["images"]:
+             try:
+                 dii.append({image['id']: image['url_o']})
+             except KeyError:
+                 dii.append({image['id']: image['url_m']})
+
+     # skip images that have already been downloaded
+     dii = [image for image in dii if not os.path.exists('{}/{}.jpg'.format(args.img_dir, list(image)[0]))]
+     print('total images: {}'.format(len(dii)))
+
+     def splitlist(inlist, chunksize):
+         return [inlist[x:x + chunksize] for x in range(0, len(inlist), chunksize)]
+
+     dii_splitted = splitlist(dii, max(1, len(dii) // args.num_process))
+     process_list = []
+     for dii_sub_list in dii_splitted:
+         p = Process(target=download_subprocess, args=(dii_sub_list, args.img_dir))
+         process_list.append(p)
+         p.daemon = True
+         p.start()
+     for p in process_list:
+         p.join()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description='arguments for vist images downloading')
+     parser.add_argument('--json_dir', type=str, required=True, help='dii json file directory')
+     parser.add_argument('--img_dir', type=str, required=True, help='images saving directory')
+     parser.add_argument('--num_process', type=int, default=32)
+     args = parser.parse_args()
+     main(args)
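
A quick illustration (not part of the commit) of how splitlist divides the download list among worker processes; the list contents and the process count are made up for the example.

def splitlist(inlist, chunksize):
    return [inlist[x:x + chunksize] for x in range(0, len(inlist), chunksize)]

urls = list(range(10))                              # stand-in for the dii list
chunks = splitlist(urls, max(1, len(urls) // 3))    # as with num_process = 3
print([len(c) for c in chunks])                     # [3, 3, 3, 1]: up to num_process + 1 chunks
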
datasets/flintstones.py ADDED
@@ -0,0 +1,93 @@
+ import random
+
+ import cv2
+ import h5py
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from transformers import CLIPTokenizer
+
+ from models.blip_override.blip import init_tokenizer
+
+
+ class StoryDataset(Dataset):
+     """
+     A custom Dataset class for one FlintstonesSV subset (train, val, or test).
+     """
+
+     def __init__(self, subset, args):
+         super(StoryDataset, self).__init__()
+         self.args = args
+
+         self.h5_file = args.get(args.dataset).hdf5_file
+         self.subset = subset
+
+         self.augment = transforms.Compose([
+             transforms.ToPILImage(),
+             transforms.Resize([512, 512]),
+             transforms.ToTensor(),
+             transforms.Normalize([0.5], [0.5])
+         ])
+         self.dataset = args.dataset
+         self.max_length = args.get(args.dataset).max_length
+         self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
+         self.blip_tokenizer = init_tokenizer()
+         msg = self.clip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
+         print("clip {} new tokens added".format(msg))
+         msg = self.blip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
+         print("blip {} new tokens added".format(msg))
+
+         self.blip_image_processor = transforms.Compose([
+             transforms.ToPILImage(),
+             transforms.Resize([224, 224]),
+             transforms.ToTensor(),
+             transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
+         ])
+
+     def open_h5(self):
+         h5 = h5py.File(self.h5_file, "r")
+         self.h5 = h5[self.subset]
+
+     def __getitem__(self, index):
+         if not hasattr(self, 'h5'):
+             self.open_h5()
+
+         images = list()
+         for i in range(5):
+             im = self.h5['image{}'.format(i)][index]
+             im = cv2.imdecode(im, cv2.IMREAD_COLOR)
+             idx = random.randint(0, 4)
+             images.append(im[idx * 128: (idx + 1) * 128])
+
+         source_images = torch.stack([self.blip_image_processor(im) for im in images])
+         images = images[1:] if self.args.task == 'continuation' else images
+         images = torch.stack([self.augment(im) for im in images]) \
+             if self.subset in ['train', 'val'] else torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)
+
+         texts = self.h5['text'][index].decode('utf-8').split('|')
+
+         # tokenize caption using default tokenizer
+         tokenized = self.clip_tokenizer(
+             texts[1:] if self.args.task == 'continuation' else texts,
+             padding="max_length",
+             max_length=self.max_length,
+             truncation=False,
+             return_tensors="pt",
+         )
+         captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
+
+         tokenized = self.blip_tokenizer(
+             texts,
+             padding="max_length",
+             max_length=self.max_length,
+             truncation=False,
+             return_tensors="pt",
+         )
+         source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
+         return images, captions, attention_mask, source_images, source_caption, source_attention_mask
+
+     def __len__(self):
+         if not hasattr(self, 'h5'):
+             self.open_h5()
+         return len(self.h5['text'])
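
A minimal usage sketch (not part of the commit). StoryDataset reads its settings through args.get(args.dataset) and attribute access, so an OmegaConf-style config is assumed here; every value below (paths, max_length, new_tokens, the task name) is a placeholder, not the repository's actual configuration.

from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from datasets.flintstones import StoryDataset

args = OmegaConf.create({
    "dataset": "flintstones",
    "task": "visualization",          # anything other than 'continuation' keeps all five frames
    "flintstones": {
        "hdf5_file": "flintstones.h5",
        "max_length": 85,
        "new_tokens": ["fred", "wilma"],
    },
})

loader = DataLoader(StoryDataset("train", args), batch_size=1, shuffle=True)
images, captions, attention_mask, source_images, source_caption, source_attention_mask = next(iter(loader))
print(images.shape, captions.shape)
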
datasets/pororo.py ADDED
@@ -0,0 +1,144 @@
+ import copy
+ import os
+ import random
+
+ import cv2
+ import h5py
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from transformers import CLIPTokenizer
+
+ from models.blip_override.blip import init_tokenizer
+
+
+ class StoryDataset(Dataset):
+     """
+     A custom Dataset class for one PororoSV subset (train, val, or test).
+     """
+
+     def __init__(self, subset, args):
+         # initialize the parent Dataset class
+         super(StoryDataset, self).__init__()
+         # args is a namespace-style config object with the remaining parameters
+         self.args = args
+         # path of the HDF5 file holding the images and texts of all splits;
+         # args.get(args.dataset) returns the config block of the chosen dataset
+         self.h5_file = args.get(args.dataset).hdf5_file
+         # subset selects which split to read (train, val or test)
+         self.subset = subset
+
+         # image transform: to PIL, resize, to tensor, normalize
+         self.augment = transforms.Compose([
+             transforms.ToPILImage(),
+             # transforms.Resize([256, 256]),
+             transforms.Resize([512, 512]),
+             transforms.ToTensor(),
+             transforms.Normalize([0.5], [0.5])
+         ])
+         # name of the dataset currently in use
+         self.dataset = args.dataset
+         # maximum caption length; captions are padded to this length when tokenized
+         self.max_length = args.get(args.dataset).max_length
+         # CLIP tokenizer used for the diffusion text encoder
+         self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
+         # BLIP tokenizer used for the source captions
+         self.blip_tokenizer = init_tokenizer()
+         msg = self.clip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
+         print("clip {} new tokens added".format(msg))
+         msg = self.blip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
+         print("blip {} new tokens added".format(msg))
+
+         # preprocessing for the BLIP image encoder: to PIL, resize, to tensor, normalize
+         self.blip_image_processor = transforms.Compose([
+             transforms.ToPILImage(),
+             transforms.Resize([224, 224]),
+             transforms.ToTensor(),
+             transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
+         ])
+
+     # open the HDF5 file that belongs to this dataset
+     def open_h5(self):
+         h5 = h5py.File(self.h5_file, "r")
+         self.h5 = h5[self.subset]
+
+     # Fetch one story by index: decode and augment the five images, tokenize the
+     # captions with both the CLIP tokenizer and the BLIP tokenizer, and return
+     # the images, captions and attention masks.
+     def __getitem__(self, index):
+         # open the HDF5 file lazily on first access
+         if not hasattr(self, 'h5'):
+             self.open_h5()
+         # index = 1
+         images = list()
+         for i in range(5):
+             # read one PNG-encoded image from the HDF5 file
+             im = self.h5['image{}'.format(i)][index]
+             # decode the PNG bytes into an array
+             im = cv2.imdecode(im, cv2.IMREAD_COLOR)
+             # pick one random 128-pixel-tall frame from the vertically stacked frames
+             idx = random.randint(0, im.shape[0] // 128 - 1)
+             images.append(im[idx * 128: (idx + 1) * 128])
+         # deep copy so ori_images is unaffected by the transforms below
+         ori_images = copy.deepcopy(images)
+
+         # (optionally dump the original test images to disk)
+         # for i, im in enumerate(images):
+         #     file_path = '/root/lihui/StoryVisualization/ori_test_images/group{:02d}_image{:02d}.png'.format(index + 1, i + 1)
+         #     cv2.imwrite(file_path, im)
+
+         # images for the BLIP encoder
+         source_images = torch.stack([self.blip_image_processor(im) for im in images])
+         # for the continuation task the first image is the given source frame and is dropped here
+         images = images[1:] if self.args.task == 'continuation' else images
+         # train/val: augment and stack into a tensor; test: keep the raw uint8 images
+         images = torch.stack([self.augment(im) for im in images]) \
+             if self.subset in ['train', 'val'] else torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)
+
+         # read the '|'-separated captions at this index and decode them as UTF-8
+         texts = self.h5['text'][index].decode('utf-8').split('|')
+
+         # tokenize caption using default tokenizer
+         tokenized = self.clip_tokenizer(
+             texts[1:] if self.args.task == 'continuation' else texts,
+             padding="max_length",
+             max_length=self.max_length,
+             truncation=False,
+             return_tensors="pt",
+         )
+         captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
+
+         tokenized = self.blip_tokenizer(
+             texts,
+             padding="max_length",
+             max_length=self.max_length,
+             truncation=False,
+             return_tensors="pt",
+         )
+         source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
+         return images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_images
+
+     # Number of samples in the split. For the test split only a single story is
+     # returned here (a commented-out alternative returns 100 instead).
+     def __len__(self):
+         if not hasattr(self, 'h5'):
+             self.open_h5()
+         if self.subset == 'test':
+             return 1
+         # if self.subset == 'test':
+         #     return 100
+         return len(self.h5['text'])
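
A small illustration (not part of the commit) of the random frame slice in __getitem__ above: each decoded pororo image stacks several 128-pixel-tall frames, and one band is picked at random. The array below is a dummy stand-in.

import random
import numpy as np

im = np.zeros((128 * 4, 128, 3), dtype=np.uint8)   # pretend: 4 frames stacked vertically
idx = random.randint(0, im.shape[0] // 128 - 1)    # choose one of the 4 bands
frame = im[idx * 128: (idx + 1) * 128]
print(idx, frame.shape)                            # e.g. 2 (128, 128, 3)
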
datasets/vistdii.py ADDED
@@ -0,0 +1,94 @@
+ import cv2
+ import h5py
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from transformers import CLIPTokenizer
+
+ from models.blip_override.blip import init_tokenizer
+
+
+ class StoryDataset(Dataset):
+     """
+     A custom Dataset class for one VIST subset (train, val, or test),
+     using the description-in-isolation (DII) captions.
+     """
+
+     def __init__(self, subset, args):
+         super(StoryDataset, self).__init__()
+         self.args = args
+
+         self.h5_file = args.get(args.dataset).hdf5_file
+         self.subset = subset
+
+         self.augment = transforms.Compose([
+             transforms.ToPILImage(),
+             transforms.Resize(512),
+             transforms.RandomCrop(512) if self.subset == 'train' else transforms.CenterCrop(512),
+             transforms.ToTensor(),
+             transforms.Normalize([0.5], [0.5])
+         ]) if self.subset in ['train', 'val'] else transforms.Compose([
+             transforms.ToPILImage(),
+             transforms.Resize(64),
+             transforms.CenterCrop(64)
+         ])
+
+         self.dataset = args.dataset
+         self.max_length = args.get(args.dataset).max_length
+         self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
+         self.blip_tokenizer = init_tokenizer()
+
+         self.blip_image_processor = transforms.Compose([
+             transforms.ToPILImage(),
+             transforms.Resize(224),
+             transforms.RandomCrop(224) if self.subset == 'train' else transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
+         ])
+
+     def open_h5(self):
+         h5 = h5py.File(self.h5_file, "r")
+         self.h5 = h5[self.subset]
+
+     def __getitem__(self, index):
+         if not hasattr(self, 'h5'):
+             self.open_h5()
+
+         images = list()
+         for i in range(5):
+             im = self.h5['image{}'.format(i)][index]
+             im = cv2.imdecode(im, cv2.IMREAD_COLOR)
+             images.append(im)
+
+         source_images = torch.stack([self.blip_image_processor(im) for im in images])
+         images = images[1:] if self.args.task == 'continuation' else images
+         images = [self.augment(im) for im in images]
+         images = torch.stack(images) if self.subset in ['train', 'val'] \
+             else torch.from_numpy(np.array([np.array(im) for im in images])).permute(0, 3, 1, 2)
+
+         texts = self.h5['dii'][index].decode('utf-8').split('|')
+
+         # tokenize caption using default tokenizer
+         tokenized = self.clip_tokenizer(
+             texts[1:] if self.args.task == 'continuation' else texts,
+             padding="max_length",
+             max_length=self.max_length,
+             truncation=False,
+             return_tensors="pt",
+         )
+         captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
+
+         tokenized = self.blip_tokenizer(
+             texts,
+             padding="max_length",
+             max_length=self.max_length,
+             truncation=False,
+             return_tensors="pt",
+         )
+         source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
+         return images, captions, attention_mask, source_images, source_caption, source_attention_mask
+
+     def __len__(self):
+         if not hasattr(self, 'h5'):
+             self.open_h5()
+         return len(self.h5['dii'])
datasets/vistsis.py ADDED
@@ -0,0 +1,94 @@
+ import cv2
+ import h5py
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from transformers import CLIPTokenizer
+
+ from models.blip_override.blip import init_tokenizer
+
+
+ class StoryDataset(Dataset):
+     """
+     A custom Dataset class for one VIST subset (train, val, or test),
+     using the story-in-sequence (SIS) captions.
+     """
+
+     def __init__(self, subset, args):
+         super(StoryDataset, self).__init__()
+         self.args = args
+
+         self.h5_file = args.get(args.dataset).hdf5_file
+         self.subset = subset
+
+         self.augment = transforms.Compose([
+             transforms.ToPILImage(),
+             transforms.Resize(512),
+             transforms.RandomCrop(512) if self.subset == 'train' else transforms.CenterCrop(512),
+             transforms.ToTensor(),
+             transforms.Normalize([0.5], [0.5])
+         ]) if self.subset in ['train', 'val'] else transforms.Compose([
+             transforms.ToPILImage(),
+             transforms.Resize(64),
+             transforms.CenterCrop(64)
+         ])
+
+         self.dataset = args.dataset
+         self.max_length = args.get(args.dataset).max_length
+         self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
+         self.blip_tokenizer = init_tokenizer()
+
+         self.blip_image_processor = transforms.Compose([
+             transforms.ToPILImage(),
+             transforms.Resize(224),
+             transforms.RandomCrop(224) if self.subset == 'train' else transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
+         ])
+
+     def open_h5(self):
+         h5 = h5py.File(self.h5_file, "r")
+         self.h5 = h5[self.subset]
+
+     def __getitem__(self, index):
+         if not hasattr(self, 'h5'):
+             self.open_h5()
+
+         images = list()
+         for i in range(5):
+             im = self.h5['image{}'.format(i)][index]
+             im = cv2.imdecode(im, cv2.IMREAD_COLOR)
+             images.append(im)
+
+         source_images = torch.stack([self.blip_image_processor(im) for im in images])
+         images = images[1:] if self.args.task == 'continuation' else images
+         images = [self.augment(im) for im in images]
+         images = torch.stack(images) if self.subset in ['train', 'val'] \
+             else torch.from_numpy(np.array([np.array(im) for im in images])).permute(0, 3, 1, 2)
+
+         texts = self.h5['sis'][index].decode('utf-8').split('|')
+
+         # tokenize caption using default tokenizer
+         tokenized = self.clip_tokenizer(
+             texts[1:] if self.args.task == 'continuation' else texts,
+             padding="max_length",
+             max_length=self.max_length,
+             truncation=False,
+             return_tensors="pt",
+         )
+         captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
+
+         tokenized = self.blip_tokenizer(
+             texts,
+             padding="max_length",
+             max_length=self.max_length,
+             truncation=False,
+             return_tensors="pt",
+         )
+         source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
+         return images, captions, attention_mask, source_images, source_caption, source_attention_mask
+
+     def __len__(self):
+         if not hasattr(self, 'h5'):
+             self.open_h5()
+         return len(self.h5['sis'])