import os  # when loading file paths
import pandas as pd  # for lookup in annotation file
import spacy  # for tokenizer
import torch
import torchvision.transforms as transforms
from PIL import Image  # Load img
from torch.nn.utils.rnn import pad_sequence  # pad batch
from torch.utils.data import DataLoader, Dataset

# We want to convert text -> numerical values
# 1. We need a Vocabulary mapping each word to an index
# 2. We need to setup a Pytorch dataset to load the data
# 3. Setup padding of every batch (all examples should be
#    of same seq_len) and setup dataloader

# Download with: python -m spacy download en_core_web_sm
spacy_eng = spacy.load("en_core_web_sm")


class Vocabulary:
    def __init__(self, freq_threshold):
        # Special tokens: padding, start-of-sentence, end-of-sentence, unknown
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1

                # Only add a word once it reaches the frequency threshold
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)

        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]


class FlickrDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Get img, caption columns
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # Initialize vocabulary and build vocab
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        # Wrap the numericalized caption with <SOS> ... <EOS>
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)


class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)  # [B, C, H, W]
        targets = [item[1] for item in batch]  # B tensors, each a different length L
        targets = pad_sequence(targets, batch_first=False,  # [L, B], padded to the same L
                               padding_value=self.pad_idx)

        return imgs, targets


def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=32,
    num_workers=8,
    shuffle=True,
    pin_memory=True,
):
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)

    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset


if __name__ == "__main__":
    transform = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    )

    loader, dataset = get_loader(
        "flickr8k/images/", "flickr8k/captions.txt", transform=transform
    )

    for idx, (imgs, captions) in enumerate(loader):
        print(imgs.shape)
        print(captions.shape)
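
    # A minimal sketch: decode one padded caption back to words with vocab.itos,
    # assuming the loop above ran at least once so `captions` holds the last
    # batch with shape [L, B]. The output will include <SOS>, <EOS>, and <PAD>.
    first_caption = captions[:, 0]  # token indices for one caption, shape [L]
    words = [dataset.vocab.itos[token.item()] for token in first_caption]
    print(" ".join(words))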