from torch.utils.data import DataLoader, Dataset, Sampler
from pathlib import Path
import json
from multiprocessing import Pool
from tqdm import tqdm
from PIL import Image
import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data.distributed import DistributedSampler
from transformers import T5Tokenizer, BertTokenizer, BertTokenizerFast, CLIPTokenizer

import text_utils

project_dir = Path(__file__).parent.resolve()
workspace_dir = project_dir.parent.parent
dataset_dir = workspace_dir.joinpath('datasets/').resolve()
# coco_dir = dataset_dir.joinpath('COCO')
# vg_dir = dataset_dir.joinpath('VG')
coco_img_dir = dataset_dir.joinpath('COCO/images/')
coco_data_dir = project_dir.parent.joinpath('CLIP-ViL/CLIP-ViL-Direct/caption/data/')
# coco_feature_dir = coco_dir.joinpath('features')


class COCORetrievalDataset(Dataset):
    def __init__(self, split='karpathy_train', rank=-1, topk=-1, verbose=True, args=None, mode='train'):
        super().__init__()

        self.topk = topk
        self.verbose = verbose
        self.args = args
        self.rank = rank
        self.mode = mode

        # Loading datasets to data
        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        # if self.args.tokenizer is None:
        #     self.args.tokenizer = self.args.decoder_backbone

        # if 'bert' in self.args.tokenizer:
        #     self.tokenizer = BertTokenizerFast.from_pretrained(
        #         self.args.tokenizer,
        #         # max_length=self.args.max_text_length,
        #         # do_lower_case=self.args.do_lower_case
        #     )
        # elif 'clip' in self.args.tokenizer:
        #     self.tokenizer = CLIPTokenizer.from_pretrained(
        #         self.args.tokenizer,
        #         # max_length=self.args.max_text_length,
        #         # do_lower_case=self.args.do_lower_case
        #     )

        self.tokenizer = CLIPTokenizer.from_pretrained(
            self.args.tokenizer,
            # max_length=self.args.max_text_length,
            # do_lower_case=self.args.do_lower_case
        )

        with open(coco_data_dir.joinpath('cocotalk.json')) as f:
            self.vocab = list(json.load(f)['ix_to_word'].values())
        popped = self.vocab.pop(-1)
        assert popped == 'UNK'

        if self.verbose:
            print('vocab size: ', len(self.vocab))

        data_info_path = coco_data_dir.joinpath('dataset_coco.json')
        with open(data_info_path) as f:
            karpathy_data = json.load(f)

        split_rename = {
            'train': 'train',
            'restval': 'train',
            'val': 'val',
            'test': 'test'
        }

        n_images = 0

        data = []
        # self.vocab = set()
        for datum in karpathy_data['images']:
            re_split = split_rename[datum['split']]

            # if re_split == 'train':
            #     for d in datum['sentences']:
            #         self.vocab = self.vocab.union(set(d['tokens']))

            if re_split != self.source.split('_')[-1]:
                continue

            if re_split == 'train':
                # for d in datum['sentences']:
                #     img_id = datum['filename'].split('.')[0]
                #     new_datum = {
                #         'filename': datum['filename'],
                #         'img_id': img_id,
                #         'sent': d['raw'].strip(),
                #         'targets': [d['raw'].strip() for d in datum['sentences']],
                #         'is_train': True,
                #         'cocoid': datum['cocoid']
                #     }
                #     data.append(new_datum)
                img_id = datum['filename'].split('.')[0]
                new_datum = {
                    'filename': datum['filename'],
                    'img_id': img_id,
                    # 'sent': d['raw'],
                    # 'targets': [d['raw'].strip() for d in datum['sentences']],
                    'targets': [" ".join(d['tokens']) for d in datum['sentences']],
                    'is_train': True,
                    'cocoid': datum['cocoid']
                }
                data.append(new_datum)
            else:
                img_id = datum['filename'].split('.')[0]
                new_datum = {
                    'filename': datum['filename'],
                    'img_id': img_id,
                    # 'sent': d['raw'],
                    # 'targets': [d['raw'].strip() for d in datum['sentences']],
                    'targets': [" ".join(d['tokens']) for d in datum['sentences']],
                    'is_train': False,
                    'cocoid': datum['cocoid']
                }
                data.append(new_datum)

            n_images += 1

        if self.verbose:
print(f"{self.source} has {n_images} images") # print(f"Loaded {len(data)} data from", split) self.n_gpus = torch.cuda.device_count() if self.topk > 0: data = data[:self.topk] if self.verbose: print(f"Use only {self.topk} data") self.data = data # if self.verbose: # print("# all sentences:", len(self.data)) if self.args.load_feat: # feat_dir = coco_dir.joinpath('' # self.feat_loader = HybridLoader('/scratch-space/CLIP-ViL/CLIP-ViL-Direct/caption/data/cocotalk_clipscore_vis', ext='.npy', in_memory=False) self.feat_loader = HybridLoader( coco_data_dir.joinpath('cocotalk_clipscore_vis'), ext='.npy', in_memory=False) else: if 'openai/clip' in self.args.encoder_backbone: # from transformers import CLIPProcessor # self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", # size=args.image_size, # do_resize=True, # do_center_crop=False, # ) # self.img_transform = lambda image: self.processor.feature_extractor( # image, # return_tensors='pt')['pixel_values'][0] self.image_mean = [0.48145466, 0.4578275, 0.40821073] self.image_std = [0.26862954, 0.26130258, 0.27577711] # captioning # self.img_transform = T.Compose([ # T.Resize((self.args.image_size, self.args.image_size)) # ]) # retrieval self.img_transform = T.Compose([ T.Resize(self.args.image_size, interpolation=T.functional.InterpolationMode.BICUBIC), T.CenterCrop(self.args.image_size) ]) self.img_tensor_transform = T.Compose([ # T.RandomCrop(224), # T.RandomHorizontalFlip(p=0.3), T.ConvertImageDtype(torch.float), T.Normalize(self.image_mean, self.image_std) ] ) # elif 'google/vit' in self.args.encoder_backbone: # self.image_mean = [0.5, 0.5, 0.5] # self.image_std = [0.5, 0.5, 0.5] # self.img_transform = T.Compose([ # # T.PILToTensor(), # T.Resize((self.args.image_size, self.args.image_size)) # ]) # self.img_tensor_transform = T.Compose([ # # T.RandomCrop(224), # # T.RandomHorizontalFlip(p=0.3), # T.ConvertImageDtype(torch.float), # T.Normalize(self.image_mean, self.image_std) # ] # ) def get_negative_text(self, text): neg_type = random.choice(['repeat', 'remove', 'insert', 'swap', 'shuffle']) if neg_type == 'repeat': text = text_utils.repeat(text) elif neg_type == 'remove': text = text_utils.remove(text) elif neg_type == 'insert': text = text_utils.insert(text, self.vocab) elif neg_type == 'swap': text = text_utils.swap(text, self.vocab) elif neg_type == 'shuffle': text = text_utils.shuffle(text) return text, neg_type def __len__(self): return len(self.data) def __getitem__(self, idx): datum = self.data[idx] return self.process_datum(datum) def process_datum(self, datum): out_dict = {} ###### Image ###### if self.args.load_feat: cocoid = datum['cocoid'] out_dict['cocoid'] = str(cocoid) img_feat = self.feat_loader.get(str(cocoid)) out_dict['img_feat'] = torch.from_numpy(img_feat) else: img_id = datum['img_id'] out_dict['img_id'] = img_id if 'train' in datum['filename']: img_split = 'train2014' elif 'val' in datum['filename']: img_split = 'val2014' img_path = coco_img_dir.joinpath(img_split).joinpath(datum['filename']).with_suffix('.jpg') assert img_path.exists() img_path = str(img_path) out_dict['img_path'] = img_path img_tensor = torchvision.io.read_image(img_path) # out_dict['img_tensor'] = img # img = Image.open(img_path).convert('RGB') # img_tensor = torch.as_tensor(np.asarray(img)) out_dict['img_tensor'] = self.img_transform(img_tensor) # self.img_transform(img_tensor) # out_dict['img_tensor'] = self.img_transform(img) ###### Text ##### # if datum['is_train']: # sent = datum['sent'].strip() sent = 
        sent = random.choice(datum['targets'])

        # target_ids = self.tokenizer.encode(
        #     sent, max_length=self.args.gen_max_length, truncation=True)
        # assert len(target_ids) <= self.args.gen_max_length, len(target_ids)

        out_dict['sent'] = sent
        # out_dict['target_ids'] = torch.LongTensor(target_ids)
        # out_dict['target_length'] = len(target_ids)

        # negative sample
        neg_sent, neg_type = self.get_negative_text(sent)

        # neg_target_ids = self.tokenizer.encode(
        #     neg_sent, max_length=self.args.gen_max_length, truncation=True)
        # assert len(neg_target_ids) <= self.args.gen_max_length, len(neg_target_ids)

        out_dict['neg_sent'] = neg_sent
        out_dict['neg_type'] = neg_type
        # out_dict['neg_target_ids'] = torch.LongTensor(neg_target_ids)
        # out_dict['neg_target_length'] = len(neg_target_ids)

        if 'targets' in datum:
            out_dict['targets'] = datum['targets']

        return out_dict

    def collate_fn(self, batch):
        batch_entry = {}

        B = len(batch)

        # if 'target_ids' in batch[0]:
        #     T_W_L = max(entry['target_length'] for entry in batch)
        #     target_ids = torch.ones(
        #         B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id

        targets = []
        img_ids = []
        img_paths = []
        coco_ids = []

        if self.args.load_feat:
            img_feats = torch.zeros(B, 512, dtype=torch.float)
        else:
            # imgs = []
            img_tensor = torch.zeros(B, 3, self.args.image_size, self.args.image_size, dtype=torch.uint8)

        for i, entry in enumerate(batch):
            if self.args.load_feat:
                coco_ids.append(entry['cocoid'])
                img_feats[i] = entry['img_feat']
            else:
                img_ids.append(entry['img_id'])
                img_paths.append(entry['img_path'])
                img_tensor[i] = entry['img_tensor']

            # if 'target_ids' in entry:
            #     target_ids[i, :entry['target_length']] = entry['target_ids']

            if 'targets' in entry:
                targets.append(entry['targets'])

        if 'sent' in batch[0]:
            # word_mask = target_ids != self.tokenizer.pad_token_id
            # target_ids[~word_mask] = -100
            # batch_entry['target_ids'] = target_ids

            tokenized = self.tokenizer(
                [entry['sent'] for entry in batch],
                truncation=True, padding=True, return_tensors='pt')
            neg_tokenized = self.tokenizer(
                [entry['neg_sent'] for entry in batch],
                truncation=True, padding=True, return_tensors='pt')
            # sent, max_length=self.args.gen_max_length, truncation=True)

            batch_entry['text'] = (tokenized.input_ids, tokenized.attention_mask)
            batch_entry['neg_text'] = (neg_tokenized.input_ids, neg_tokenized.attention_mask)

        if self.args.load_feat:
            batch_entry['coco_ids'] = coco_ids
            batch_entry['img_feats'] = img_feats
        else:
            img_tensor = self.img_tensor_transform(img_tensor)
            batch_entry['img_id'] = img_ids
            batch_entry['img_paths'] = img_paths
            batch_entry['img_tensor'] = img_tensor

        batch_entry['targets'] = targets

        # print('batch created')
        # batch_entry['task'] = 'caption'

        return batch_entry


# def get_loader(args, split='karpathy_train', mode='train',
#                batch_size=32, workers=4, distributed=False, gpu=0,
#                topk=-1):

#     verbose = (gpu == 0)

#     dataset = COCORetrievalDataset(
#         split,
#         rank=gpu,
#         topk=topk,
#         verbose=verbose,
#         args=args,
#         mode=mode)

#     # if distributed:
#     #     sampler = DistributedSampler(dataset)
#     # else:
#     #     sampler = None

#     if mode == 'train':
#         loader = DataLoader(
#             dataset, batch_size=batch_size, shuffle=(sampler is None),
#             num_workers=workers, pin_memory=True, sampler=sampler,
#             collate_fn=dataset.collate_fn)
#     else:
#         loader = DataLoader(
#             dataset,
#             batch_size=batch_size, shuffle=False,
#             num_workers=workers, pin_memory=True,
#             sampler=sampler,
#             collate_fn=dataset.collate_fn,
#             drop_last=False)

#     # if verbose:
#     #     loader.evaluator = COCOCaptionEvaluator()
#     #     loader.task = 'caption'

#     return loader


# class COCOCaptionEvaluator:
#     def __init__(self):
#         import language_evaluation
#         self.evaluator = language_evaluation.CocoEvaluator(verbose=False)

#     def evaluate(self, predicts, answers):
#         results = self.evaluator.run_evaluation(predicts, answers)
#         return results


import six
import os
import h5py


class HybridLoader:
    """
    If db_path is a directory, use normal file loading.
    If it is an lmdb, load from lmdb.
    The loading method depends on the extension.

    in_memory: if True, cache all features in memory.
               For individual .npy/.npz files this is unnecessary because the
               OS page cache already does it for us; it should be useful for
               lmdb or h5. (Copied this idea from vilbert.)
    """

    def __init__(self, db_path, ext='.npy', in_memory=False):
        self.db_path = db_path
        self.ext = ext
        if self.ext == '.npy':
            self.loader = lambda x: np.load(six.BytesIO(x))
        else:
            self.loader = lambda x: np.load(six.BytesIO(x))['feat']

        # if db_path.endswith('.lmdb'):
        #     self.db_type = 'lmdb'
        #     self.lmdb = lmdbdict(db_path, unsafe=True)
        #     self.lmdb._key_dumps = DUMPS_FUNC['ascii']
        #     self.lmdb._value_loads = LOADS_FUNC['identity']
        # elif db_path.endswith('.pth'):  # Assume a key,value dictionary
        #     self.db_type = 'pth'
        #     self.feat_file = torch.load(db_path)
        #     self.loader = lambda x: x
        #     print('HybridLoader: ext is ignored')
        # elif db_path.endswith('h5'):
        #     self.db_type = 'h5'
        #     self.loader = lambda x: np.array(x).astype('float32')
        # else:
        #     self.db_type = 'dir'

        self.in_memory = in_memory
        if self.in_memory:
            self.features = {}

    def get(self, key):
        # if self.in_memory and key in self.features:
        #     # We save f_input because we want to save the
        #     # compressed bytes to save memory
        #     f_input = self.features[key]
        # elif self.db_type == 'lmdb':
        #     f_input = self.lmdb[key]
        # elif self.db_type == 'pth':
        #     f_input = self.feat_file[key]
        # elif self.db_type == 'h5':
        #     f_input = h5py.File(self.db_path, 'r')[key]
        # else:
        #     f_input = open(os.path.join(
        #         self.db_path, key + self.ext), 'rb').read()

        with open(os.path.join(self.db_path, key + self.ext), 'rb') as f:
            f_input = f.read()

        if self.in_memory and key not in self.features:
            self.features[key] = f_input

        # decode the feature array from the raw bytes
        feat = self.loader(f_input)

        return feat
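
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, mirroring the commented-out
# get_loader() above). The values below -- the CLIP checkpoint name, image
# size, batch size, and split -- are assumptions, not taken from the original
# training config; only the attribute names (tokenizer, encoder_backbone,
# image_size, load_feat) match what COCORetrievalDataset actually reads.
# Requires cocotalk.json, dataset_coco.json, and the COCO images to exist at
# the paths configured at the top of this file.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(
        tokenizer='openai/clip-vit-base-patch32',         # assumed CLIP tokenizer checkpoint
        encoder_backbone='openai/clip-vit-base-patch32',  # assumed; selects the CLIP image transforms
        image_size=224,                                   # assumed input resolution
        load_feat=False,                                  # read raw images instead of cached features
    )

    dataset = COCORetrievalDataset(split='karpathy_val', topk=100,
                                   verbose=True, args=args, mode='val')
    loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0,
                        pin_memory=True, collate_fn=dataset.collate_fn)

    batch = next(iter(loader))
    input_ids, attention_mask = batch['text']      # tokenized positive captions
    neg_ids, neg_mask = batch['neg_text']          # tokenized perturbed (negative) captions
    images = batch['img_tensor']                   # normalized float images, (B, 3, H, W)
    print(images.shape, input_ids.shape, neg_ids.shape)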