# Dataloaders for TencentPretrain: stream pickled pretraining instances from disk
# and batch them into tensors for each pretraining target (MLM, LM, seq2seq,
# vision, and audio tasks).
import os | |
import random | |
import pickle | |
import torch | |
from tencentpretrain.utils.constants import * | |
from tencentpretrain.utils.tokenizers import * | |
from tencentpretrain.utils.mask import mask_seq | |
from tencentpretrain.utils.augment import SpecAugment | |
class Dataloader(object):
    """Base class that streams pickled instances from a dataset file.

    Instances are read sequentially from the pickle stream; each process
    keeps every world_size-th instance (round-robin sharding by rank) in an
    in-memory buffer of at most ``args.instances_buffer_size`` items.
    Subclasses implement ``__iter__`` to turn buffered instances into
    batched tensors.
    """

    def __init__(self, args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle=False, model_for_dataloader=None):
        self.tokenizer = args.tokenizer
        self.batch_size = batch_size
        self.instances_buffer_size = args.instances_buffer_size
        self.rank = rank
        self.world_size = world_size
        self.gpu_id = gpu_id
        self.shuffle = shuffle
        self.model_for_dataloader = model_for_dataloader
        self.dataset_reader = open(dataset_path, "rb")
        self.read_count = 0
        self.start = 0
        self.end = 0
        self.buffer = []
        self.vocab = args.vocab
        self.whole_word_masking = args.whole_word_masking
        self.span_masking = args.span_masking
        self.span_geo_prob = args.span_geo_prob
        self.span_max_length = args.span_max_length

    def _fill_buf(self):
        """Refill the buffer with this rank's share of instances.

        On EOF the file is rewound so iteration can continue for multiple
        epochs. Afterwards the buffer is optionally shuffled and the
        consumption cursors are reset.
        """
        try:
            self.buffer = []
            while True:
                instance = pickle.load(self.dataset_reader)
                self.read_count += 1
                # Round-robin sharding: rank r keeps instances r, r + world_size, ...
                if (self.read_count - 1) % self.world_size == self.rank:
                    self.buffer.append(instance)
                    if len(self.buffer) >= self.instances_buffer_size:
                        break
        except EOFError:
            # Reached file end; rewind so the next fill starts a new epoch.
            self.dataset_reader.seek(0)

        if self.shuffle:
            random.shuffle(self.buffer)
        self.start = 0
        self.end = len(self.buffer)

    def _empty(self):
        """Return True when the buffer has been fully consumed."""
        return self.start >= self.end

    def __del__(self):
        # Guard: __init__ may have failed (e.g. open() raised) before
        # dataset_reader was assigned; a bare access would raise
        # AttributeError during interpreter teardown.
        reader = getattr(self, "dataset_reader", None)
        if reader is not None:
            reader.close()
class BertDataloader(Dataloader):
    """Yield (src, tgt_mlm, is_next, seg) batches for BERT pretraining (MLM + NSP)."""

    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            # Slice one batch; the final slice of the buffer may be short.
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]
            self.start += self.batch_size
            src = []
            tgt_mlm = []
            is_next = []
            seg = []
            masked_words_num = 0
            for ins in instances:
                # ins[0] = (token ids, number of PAD positions to append).
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                if len(ins) == 4:
                    # Statically masked instance:
                    # (src, [(position, original id), ...], is_next, (seg1_end, seg2_end)).
                    src.append(src_single)
                    masked_words_num += len(ins[1])
                    tgt_mlm.append([0] * len(src_single))
                    for mask in ins[1]:
                        tgt_mlm[-1][mask[0]] = mask[1]
                    is_next.append(ins[2])
                    seg.append([1] * ins[3][0] + [2] * (ins[3][1] - ins[3][0]) + [0] * pad_num)
                else:
                    # Dynamic masking at load time: (src, is_next, (seg1_end, seg2_end)).
                    src_single, tgt_mlm_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                    masked_words_num += len(tgt_mlm_single)
                    src.append(src_single)
                    tgt_mlm.append([0] * len(src_single))
                    for mask in tgt_mlm_single:
                        tgt_mlm[-1][mask[0]] = mask[1]
                    is_next.append(ins[1])
                    seg.append([1] * ins[2][0] + [2] * (ins[2][1] - ins[2][0]) + [0] * pad_num)
            if masked_words_num == 0:
                # Nothing to predict in this batch; skip it.
                continue
            yield torch.LongTensor(src), \
                torch.LongTensor(tgt_mlm), \
                torch.LongTensor(is_next), \
                torch.LongTensor(seg)
class MlmDataloader(Dataloader):
    """Yield (src, tgt, seg) batches for masked-LM-only pretraining."""

    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]
            self.start += self.batch_size
            src = []
            tgt = []
            seg = []
            masked_words_num = 0
            for ins in instances:
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                if len(ins) == 3:
                    # Statically masked: (src, [(position, original id), ...], (seg_end,)).
                    src.append(src_single)
                    masked_words_num += len(ins[1])
                    tgt.append([0] * len(src_single))
                    for mask in ins[1]:
                        tgt[-1][mask[0]] = mask[1]
                    seg.append([1] * ins[2][0] + [0] * pad_num)
                else:
                    # Dynamic masking at load time: (src, (seg_end,)).
                    src_single, tgt_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                    masked_words_num += len(tgt_single)
                    src.append(src_single)
                    tgt.append([0] * len(src_single))
                    for mask in tgt_single:
                        tgt[-1][mask[0]] = mask[1]
                    seg.append([1] * ins[1][0] + [0] * pad_num)
            if masked_words_num == 0:
                # No masked positions in this batch; skip it.
                continue
            yield torch.LongTensor(src), \
                torch.LongTensor(tgt), \
                torch.LongTensor(seg)
class AlbertDataloader(BertDataloader):
    """ALBERT uses the same (src, tgt_mlm, is_next, seg) batch layout as BERT,
    so the whole batching logic is inherited from BertDataloader unchanged."""
    pass
class LmDataloader(Dataloader):
    """Yield (src, tgt, seg) batches for causal LM: tgt is src shifted left by one."""

    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            # Slice one batch; the buffer tail may be smaller than batch_size.
            batch_end = self.start + self.batch_size
            if batch_end >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: batch_end]
            self.start += self.batch_size
            src, tgt, seg = [], [], []
            pad_id = self.vocab.get(PAD_TOKEN)
            for ins in instances:
                tokens, pad_num = ins[0]
                for _ in range(pad_num):
                    tokens.append(pad_id)
                real_len = ins[1][0]
                # Next-token prediction: input drops the last token, target drops the first.
                src.append(tokens[:-1])
                tgt.append(tokens[1:])
                seg.append([1] * real_len + [0] * (len(tokens) - 1 - real_len))
            yield torch.LongTensor(src), \
                torch.LongTensor(tgt), \
                torch.LongTensor(seg)
class BilmDataloader(Dataloader):
    """Yield (src, tgt_forward, tgt_backward, seg) batches for bidirectional LM training."""

    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            batch_end = self.start + self.batch_size
            if batch_end >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: batch_end]
            self.start += self.batch_size
            src, tgt_forward, tgt_backward, seg = [], [], [], []
            pad_id = self.vocab.get(PAD_TOKEN)
            for ins in instances:
                tokens, pad_num = ins[0]
                fwd, bwd = ins[1], ins[2]
                # Pad the source and both direction targets to full length in place.
                for _ in range(pad_num):
                    tokens.append(pad_id)
                    fwd.append(pad_id)
                    bwd.append(pad_id)
                src.append(tokens)
                tgt_forward.append(fwd)
                tgt_backward.append(bwd)
                real_len = ins[3][0]
                seg.append([1] * real_len + [0] * (len(tokens) - real_len))
            yield torch.LongTensor(src), \
                torch.LongTensor(tgt_forward), \
                torch.LongTensor(tgt_backward), \
                torch.LongTensor(seg)
class MtDataloader(Dataloader):
    """Yield (src, tgt_out, seg, tgt_in, tgt_seg) batches for machine translation / seq2seq."""

    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]
            self.start += self.batch_size
            src = []
            tgt_in = []
            tgt_out = []
            seg = []
            tgt_seg = []
            for ins in instances:
                # ins = ((src tokens, src pad_num), (tgt tokens, tgt pad_num), (src seg_end,)).
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                tgt_single, pad_num = ins[1]
                for _ in range(pad_num):
                    tgt_single.append(self.vocab.get(PAD_TOKEN))
                src.append(src_single)
                # Teacher forcing: decoder input is tgt[:-1], expected output is tgt[1:].
                tgt_in.append(tgt_single[:-1])
                tgt_out.append(tgt_single[1:])
                seg.append([1] * ins[2][0] + [0] * (len(src_single) - ins[2][0]))
                pad_num = max(ins[1][1] - 1, 0)  # left shifted, pad_num >= 0
                tgt_seg.append([1] * (len(tgt_in[-1]) - pad_num) + [0] * pad_num)
            yield torch.LongTensor(src), \
                torch.LongTensor(tgt_out), \
                torch.LongTensor(seg), \
                torch.LongTensor(tgt_in), \
                torch.LongTensor(tgt_seg)
class T5Dataloader(Dataloader):
    """Yield (src, tgt_out, seg, tgt_in, tgt_seg) batches for T5 span-corruption pretraining."""

    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]
            self.start += self.batch_size
            src = []
            tgt_in = []
            tgt_out = []
            seg = []
            tgt_seg = []
            tgt_seq_length = 0
            for _, ins in enumerate(instances):
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                if len(ins) == 3:
                    # Statically masked: (src, [(position, original id), ...], (seg_end,)).
                    tgt_single = ins[1]
                    seg.append([1] * ins[2][0] + [0] * pad_num)
                else:
                    # Dynamic masking at load time: (src, (seg_end,)).
                    src_single, tgt_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                    seg.append([1] * ins[1][0] + [0] * pad_num)
                MASK_ID = self.vocab.get(MASK_TOKEN)
                SENTINEL_ID = self.vocab.get(SENTINEL_TOKEN)
                PAD_ID = self.vocab.get(PAD_TOKEN)
                # Make sure every masked position actually holds [MASK] in the source.
                for src_index, _ in tgt_single:
                    if src_single[src_index] != MASK_ID:
                        src_single[src_index] = MASK_ID
                tgt_in_single = [self.vocab.get(CLS_TOKEN)]
                mask_index = 0
                src_with_sentinel = []
                # Replace each run of consecutive [MASK] tokens with one sentinel token;
                # the decoder input interleaves sentinels with the original masked ids.
                for token_id in src_single:
                    if token_id == MASK_ID:
                        if len(src_with_sentinel) > 0 and src_with_sentinel[-1] == (SENTINEL_ID - 1):
                            # Still inside the current masked span: no new sentinel.
                            pass
                        else:
                            src_with_sentinel.append(SENTINEL_ID)
                            tgt_in_single.append(SENTINEL_ID)
                            # Advance to the next sentinel id while ids remain in the vocab.
                            if SENTINEL_ID < len(self.vocab) - 1:
                                SENTINEL_ID += 1
                        tgt_in_single.append(tgt_single[mask_index][1])
                        mask_index += 1
                    else:
                        src_with_sentinel.append(token_id)
                # Close the decoder input with a final sentinel and [SEP].
                tgt_in_single.append(SENTINEL_ID)
                tgt_in_single.append(self.vocab.get(SEP_TOKEN))
                tgt_seg_single = [1] * len(tgt_in_single)
                # Collapsed spans shortened the source; pad it back to full length.
                while len(src_with_sentinel) < len(src_single):
                    src_with_sentinel.append(PAD_ID)
                if len(tgt_in_single) > tgt_seq_length:
                    tgt_seq_length = len(tgt_in_single)
                src.append(src_with_sentinel)
                tgt_in.append(tgt_in_single)
                tgt_seg.append(tgt_seg_single)
                # Decoder output is the decoder input shifted left, padded at the end.
                tgt_out.append(tgt_in[-1][1:] + [PAD_ID])
            # Right-pad every target in the batch to the longest target length.
            for i in range(len(tgt_in)):
                while len(tgt_in[i]) != tgt_seq_length:
                    tgt_in[i].append(PAD_ID)
                    tgt_out[i].append(PAD_ID)
                    tgt_seg[i].append(0)
            yield torch.LongTensor(src), \
                torch.LongTensor(tgt_out), \
                torch.LongTensor(seg), \
                torch.LongTensor(tgt_in), \
                torch.LongTensor(tgt_seg)
class GsgDataloader(MtDataloader):
    """Gap-sentence generation batches share the seq2seq layout of MtDataloader."""
    pass
class BartDataloader(Dataloader):
    """Yield (src, tgt_out, seg, tgt_in, tgt_seg) batches for BART-style denoising seq2seq."""

    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]
            self.start += self.batch_size
            src = []
            tgt_in = []
            tgt_out = []
            seg = []
            tgt_seg = []
            for _, ins in enumerate(instances):
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                tgt_single, pad_num = ins[1]
                for _ in range(pad_num):
                    tgt_single.append(self.vocab.get(PAD_TOKEN))
                # Dynamically corrupt the encoder input; the decoder targets stay clean.
                src_single, _ = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking,
                                         self.span_geo_prob, self.span_max_length)
                seg_pos = ins[2][0]
                tgt_in.append(tgt_single[:-1])
                tgt_out.append(tgt_single[1:])
                pad_num = max(ins[1][1] - 1, 0)  # left shifted, pad_num >= 0
                tgt_seg.append([1] * (len(tgt_in[-1]) - pad_num) + [0] * pad_num)
                MASK_ID = self.vocab.get(MASK_TOKEN)
                # Collapse each run of consecutive [MASK] tokens into a single span
                # mask, shrinking the segment length by one per collapsed token.
                src_with_span_mask = []
                for token_id in src_single:
                    if token_id == MASK_ID:
                        if len(src_with_span_mask) > 0 and src_with_span_mask[-1] == MASK_ID:
                            seg_pos -= 1
                        else:
                            src_with_span_mask.append(MASK_ID)
                    else:
                        src_with_span_mask.append(token_id)
                # Re-pad to the original length after span collapsing.
                while len(src_with_span_mask) < len(src_single):
                    src_with_span_mask.append(self.vocab.get(PAD_TOKEN))
                seg.append([1] * seg_pos + [0] * (len(src_single) - seg_pos))
                src.append(src_with_span_mask)
            yield torch.LongTensor(src), \
                torch.LongTensor(tgt_out), \
                torch.LongTensor(seg), \
                torch.LongTensor(tgt_in), \
                torch.LongTensor(tgt_seg)
class ClsDataloader(Dataloader):
    """Yield (src, tgt, seg) batches for sequence classification."""

    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            batch_end = self.start + self.batch_size
            if batch_end >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: batch_end]
            self.start += self.batch_size
            src, tgt, seg = [], [], []
            pad_id = self.vocab.get(PAD_TOKEN)
            for ins in instances:
                tokens, pad_num = ins[0]
                seg_pos = ins[2]
                # One segment -> all 1s; a sentence pair -> 1s then 2s.
                if len(seg_pos) == 1:
                    seg_tags = [1] * seg_pos[0]
                elif len(seg_pos) == 2:
                    seg_tags = [1] * seg_pos[0] + [2] * seg_pos[1]
                for _ in range(pad_num):
                    tokens.append(pad_id)
                    seg_tags.append(0)
                src.append(tokens)
                tgt.append(ins[1])
                seg.append(seg_tags)
            yield torch.LongTensor(src), \
                torch.LongTensor(tgt), \
                torch.LongTensor(seg)
class PrefixlmDataloader(Dataloader):
    """Yield (src, tgt, seg) batches for prefix LM; seg marks prefix (1), generated text (2), padding (0)."""

    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            batch_end = self.start + self.batch_size
            if batch_end >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: batch_end]
            self.start += self.batch_size
            src, tgt, seg = [], [], []
            pad_id = self.vocab.get(PAD_TOKEN)
            for ins in instances:
                tokens, pad_num = ins[0]
                labels = ins[1]
                # Pad the source and its labels to full length in place.
                for _ in range(pad_num):
                    tokens.append(pad_id)
                    labels.append(pad_id)
                prefix_end, text_end = ins[2][0], ins[2][1]
                src.append(tokens)
                tgt.append(labels)
                seg.append([1] * prefix_end + [2] * (text_end - prefix_end) + [0] * (len(tokens) - text_end))
            yield torch.LongTensor(src), \
                torch.LongTensor(tgt), \
                torch.LongTensor(seg)
class ClsMlmDataloader(Dataloader):
    """Yield (src, tgt_mlm, tgt_cls, seg) batches for joint classification + MLM training."""

    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]
            self.start += self.batch_size
            src = []
            tgt_mlm = []
            tgt_cls = []
            seg = []
            masked_words_num = 0
            for ins in instances:
                src_single, pad_num = ins[0]
                # Last element: segment positions; second-to-last: classification label.
                seg_pos_single = ins[-1]
                tgt_cls.append(ins[-2])
                # One segment -> all 1s; a sentence pair -> 1s then 2s.
                if len(seg_pos_single) == 1:
                    seg_single = [1] * seg_pos_single[0]
                elif len(seg_pos_single) == 2:
                    seg_single = [1] * seg_pos_single[0] + [2] * seg_pos_single[1]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                    seg_single.append(0)
                seg.append(seg_single)
                if len(ins) == 4:
                    # Statically masked: (src, [(position, original id), ...], label, seg_pos).
                    src.append(src_single)
                    masked_words_num += len(ins[1])
                    tgt_mlm.append([0] * len(src_single))
                    for mask in ins[1]:
                        tgt_mlm[-1][mask[0]] = mask[1]
                else:
                    # Dynamic masking at load time: (src, label, seg_pos).
                    src_single, tgt_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                    src.append(src_single)
                    masked_words_num += len(tgt_single)
                    tgt_mlm.append([0] * len(src_single))
                    for mask in tgt_single:
                        tgt_mlm[-1][mask[0]] = mask[1]
            if masked_words_num == 0:
                # No masked positions in this batch; skip it.
                continue
            yield torch.LongTensor(src), \
                torch.LongTensor(tgt_mlm), \
                torch.LongTensor(tgt_cls), \
                torch.LongTensor(seg)
class VisionDataloader(Dataloader):
    """Base dataloader for image tasks: adds patch/image geometry and a
    torchvision preprocessing pipeline built from args.image_preprocess."""

    def __init__(self, args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle=False, model_for_dataloader=None):
        super(VisionDataloader, self).__init__(args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle, model_for_dataloader)
        self.patch_size = args.patch_size
        self.image_height = args.image_height
        self.image_width = args.image_width

        from torchvision import transforms
        from tencentpretrain.utils.misc import ZeroOneNormalize
        preprocess_pipeline = []
        # NOTE(review): "corp" looks like a typo for "crop"; config files must use
        # this exact spelling for the option to take effect -- confirm upstream.
        if "corp" in args.image_preprocess:
            preprocess_pipeline.append(transforms.RandomResizedCrop(max(self.image_height, self.image_width)))
        if "horizontal_flip" in args.image_preprocess:
            preprocess_pipeline.append(transforms.RandomHorizontalFlip())
        preprocess_pipeline.append(transforms.Resize((self.image_height, self.image_width)))
        # Scales raw images into [0, 1] floats before optional normalization.
        preprocess_pipeline.append(ZeroOneNormalize())
        if "normalize" in args.image_preprocess:
            # Fixed RGB mean/std constants (presumably CLIP's published values -- verify).
            preprocess_pipeline.append(transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
        self.transform = transforms.Compose(preprocess_pipeline)
class VitDataloader(VisionDataloader):
    """Image-classification batches for ViT: preprocessed images, labels, segment masks."""

    def __iter__(self):
        """
        instances: (tgt, image_path)
        tgt: The category the image belongs to
        image_path: Path of the image sample
        Returns:
            src_image: [batch_size x channel_size x width x hight]
            seg: [batch_size x (patch_num + 1)]
            tgt: [batch_size]
        """
        from torchvision.io import read_image
        from torchvision.io.image import ImageReadMode
        # patch_num + 1 accounts for the extra [CLS] position.
        seg_len = (self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1
        while True:
            while self._empty():
                self._fill_buf()
            batch_end = self.start + self.batch_size
            if batch_end >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: batch_end]
            self.start += self.batch_size
            src, tgt, seg = [], [], []
            for ins in instances:
                image = read_image(ins[1], ImageReadMode.RGB).cuda(self.gpu_id)
                src.append(self.transform(image))
                tgt.append(ins[0])
                seg.append([1] * seg_len)
            yield torch.stack(src, 0), \
                torch.LongTensor(tgt), \
                torch.LongTensor(seg)
class ViltDataloader(VisionDataloader):
    """Yield multimodal batches for ViLT pretraining (text MLM + image-text matching)."""

    def __iter__(self):
        """
        instances: (src_text, seg_text, image_path)
        src_text: Tokens of the text sample
        seg_text: Segment input of text sample
        src_image: Path of the image sample
        Returns:
            src_text: [batch_size x seq_length]
            src_image: [batch_size x channel_size x width x hight]
            tgt_mlm: [batch_size x (seq_length + patch_num + 1)]
            tgt_match: [batch_size]
            seg: [batch_size x (seq_length + patch_num + 1)]
        """
        from torchvision.io import read_image
        from torchvision.io.image import ImageReadMode
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]
            self.start += self.batch_size
            src_text = []
            src_image = []
            tgt_mlm = []
            tgt_match = []
            seg = []
            masked_words_num = 0
            for ins in instances:
                src_text_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_text_single.append(self.vocab.get(PAD_TOKEN))
                # Dynamic MLM masking on the text side only.
                src_text_single, tgt_mlm_single = mask_seq(src_text_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                src_text.append(src_text_single)
                masked_words_num += len(tgt_mlm_single)
                tgt_mlm.append([0] * len(src_text_single))
                for mask in tgt_mlm_single:
                    tgt_mlm[-1][mask[0]] = mask[1]
                # Image-text matching: keep the paired image half the time (label 1),
                # otherwise substitute a random image from the buffer (label 0).
                if random.random() < 0.5:
                    image = read_image(ins[2], ImageReadMode.RGB)
                    tgt_match.append(1)
                else:
                    image = read_image(random.choice(self.buffer)[2], ImageReadMode.RGB)
                    tgt_match.append(0)
                # Image patches (plus one [CLS]-like slot) form segment 2; no MLM targets there.
                seg_image = [2] * ((self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1)
                tgt_mlm[-1].extend([0] * len(seg_image))
                image = image.cuda(self.gpu_id)
                src_image_single = self.transform(image)
                src_image.append(src_image_single)
                seg.append([1] * ins[1][0] + [0] * pad_num + seg_image)
            if masked_words_num == 0:
                # No masked positions in this batch; skip it.
                continue
            yield torch.LongTensor(src_text), \
                torch.stack(src_image, 0), \
                torch.LongTensor(tgt_mlm), \
                torch.LongTensor(tgt_match), \
                torch.LongTensor(seg)
class ClipDataloader(VisionDataloader):
    """Paired text/image batches for CLIP-style contrastive pretraining."""

    def __iter__(self):
        """
        instances: (src_text, seg_text, image_path)
        src_text: Tokens of the text sample
        src_image: Path of the image sample
        seg_text: Segment input of text sample
        Returns:
            src_text: [batch_size x seq_length]
            src_image: [batch_size x channel_size x width x hight]
            seg_text: [batch_size x seq_length]
            seg_image: [batch_size x (patch_num + 1)]
        """
        from torchvision.io import read_image
        from torchvision.io.image import ImageReadMode
        # patch_num + 1 accounts for the extra [CLS] position.
        image_seg_len = (self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1
        while True:
            while self._empty():
                self._fill_buf()
            batch_end = self.start + self.batch_size
            if batch_end >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: batch_end]
            self.start += self.batch_size
            src_text, src_image, seg_text, seg_image = [], [], [], []
            pad_id = self.vocab.get(PAD_TOKEN)
            for ins in instances:
                tokens, pad_num = ins[0]
                for _ in range(pad_num):
                    tokens.append(pad_id)
                src_text.append(tokens)
                seg_text.append([1] * ins[1][0] + [0] * pad_num)
                image = read_image(ins[2], ImageReadMode.RGB).cuda(self.gpu_id)
                src_image.append(self.transform(image))
                seg_image.append([1] * image_seg_len)
            yield torch.LongTensor(src_text), \
                torch.stack(src_image, 0), \
                torch.LongTensor(seg_text), \
                torch.LongTensor(seg_image)
class AudioDataloader(Dataloader):
    """Base dataloader for audio tasks: fbank feature settings, CMVN switches,
    and optional SpecAugment."""

    def __init__(self, args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle=False, model_for_dataloader=None):
        super(AudioDataloader, self).__init__(args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle, model_for_dataloader)
        self.dataset_folder = os.path.dirname(dataset_path)
        self.sampling_rate = args.sampling_rate
        # CMVN switches default to on; individually disabled below when the
        # corresponding key is absent from args.audio_preprocess.
        self.normalize_means, self.normalize_vars, self.ceptral_normalize = True, True, True
        self.padding_value = 0.0
        self.audio_feature_size = args.audio_feature_size
        self.conv_layers_num = args.conv_layers_num
        self.max_audio_frames = args.max_audio_frames
        self.specaugment = None
        if "normalize_means" not in args.audio_preprocess:
            self.normalize_means = False
        if "normalize_vars" not in args.audio_preprocess:
            self.normalize_vars = False
        if "ceptral_normalize" not in args.audio_preprocess:
            self.ceptral_normalize = False
        # NOTE(review): "sepcaugment" is likely a typo for "specaugment"; callers
        # must set args.sepcaugment for this branch to fire -- confirm against
        # the config-parsing code before changing.
        if "sepcaugment" in args:
            self.specaugment = SpecAugment(args)
def utterance_cmvn(x, normalize_means=True, normalize_vars=True, gpu_id=None):
    """Apply utterance-level cepstral mean/variance normalization.

    Statistics are computed along dim 0 (time frames). The variance is
    floored at 1e-10 before the square root to avoid division by zero;
    gpu_id places the floor tensor on the same CUDA device as the input.
    """
    mean = x.mean(axis=0)
    square_sums = (x ** 2).sum(axis=0)
    if normalize_means:
        x = x - mean
    if normalize_vars:
        # Population variance via E[x^2] - E[x]^2 (uses the pre-centering mean).
        var = square_sums / x.size(0) - mean ** 2
        floor = torch.full(var.size(), 1e-10)
        if gpu_id is not None:
            floor = floor.cuda(gpu_id)
        std = torch.sqrt(torch.maximum(var, floor))
        x = x / std
    return x
class S2tDataloader(AudioDataloader):
    """Yield (src_audio, tgt_out, seg_audio, tgt_in, tgt_seg) batches for speech-to-text."""

    def __iter__(self):
        import torchaudio
        import torchaudio.compliance.kaldi as ta_kaldi
        # One frame of padding values, reused to pad every utterance up to max_audio_frames.
        padding_vector = torch.FloatTensor(self.audio_feature_size * [self.padding_value] if self.audio_feature_size > 1 else self.padding_value).unsqueeze(0).cuda(self.gpu_id)
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]
            self.start += self.batch_size
            tgt_in = []
            tgt_out = []
            src_audio = []
            seg_audio = []
            tgt_seg = []
            for ins in instances:
                text_single, pad_num = ins[0]
                for _ in range(pad_num):
                    text_single.append(self.vocab.get(PAD_TOKEN))
                waveform, _ = torchaudio.load(ins[2])  # waveform, sample_rate
                waveform = waveform * (2 ** 15)  # Kaldi compliance: 16-bit signed integers
                waveform = waveform.cuda(self.gpu_id)
                feature = ta_kaldi.fbank(waveform, num_mel_bins=self.audio_feature_size,
                                         sample_frequency=self.sampling_rate)
                if self.ceptral_normalize:
                    feature = utterance_cmvn(feature, self.normalize_means, self.normalize_vars, self.gpu_id)
                difference = self.max_audio_frames - feature.size(0)
                if difference < 0:
                    # Utterance exceeds the frame budget: drop it entirely.
                    continue
                else:
                    src_audio.append(torch.cat([feature] + [padding_vector] * difference))
                    # Frame counts after the conv subsampling front-end
                    # (divided by conv_layers_num and then by 2).
                    src_pad_num = int(self.max_audio_frames / self.conv_layers_num / 2) - int(feature.size(0) / self.conv_layers_num / 2)
                    seg_audio.append([1] * int(feature.size(0) / self.conv_layers_num / 2) + [0] * src_pad_num)
                tgt_out.append(text_single[1:])
                # NOTE(review): when pad_num == 0 this overwrites the final real token
                # with PAD before building tgt_in -- confirm this is intended.
                text_single[-pad_num-1] = self.vocab.get(PAD_TOKEN)
                tgt_in.append(text_single[:-1])
                pad_num = max(pad_num - 1, 0)  # left shifted, pad_num >= 0
                tgt_seg.append([1] * (len(tgt_in[-1]) - pad_num) + [0] * pad_num)
            if len(src_audio) == 0:
                # Every utterance in this slice was too long; fetch another batch.
                continue
            if self.specaugment:
                src_audio = self.specaugment(src_audio)
            yield torch.stack(src_audio, 0), \
                torch.LongTensor(tgt_out), \
                torch.LongTensor(seg_audio), \
                torch.LongTensor(tgt_in), \
                torch.LongTensor(tgt_seg)
class BeitDataloader(VisionDataloader):
    """Yield BEiT pretraining batches: images plus masked visual-token targets
    produced by a VQGAN image tokenizer."""

    def __init__(self, args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle=False, model_for_dataloader=None):
        super(BeitDataloader, self).__init__(args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle, model_for_dataloader)
        from tencentpretrain.utils.image_tokenizer import build_vqgan_model
        # The VQGAN tokenizer model is supplied pre-built by the caller.
        self.vqgan = self.model_for_dataloader

    def mask(self, image_tokens, mask_rate=0.15):
        """Sample mask_rate of the token positions (never position 0) and
        return (tgt, mask_index): tgt holds the true token id at masked
        positions and 0 elsewhere."""
        mask_num = int(len(image_tokens) * mask_rate)
        mask_index = random.sample(range(1, len(image_tokens)), mask_num)
        tgt = [0] * len(image_tokens)
        for idx in mask_index:
            tgt[idx] = image_tokens[idx]
        return tgt, mask_index

    def __iter__(self):
        """
        instances: (tgt, image_path)
        tgt: The category the image belongs to
        image_path: Path of the image sample
        Returns:
            src_image: [batch_size x channel_size x width x hight]
            seg: [batch_size x (patch_num + 1)]
            tgt: [batch_size]
        """
        from torchvision.io import read_image
        from torchvision.io.image import ImageReadMode
        from tencentpretrain.utils.image_tokenizer import image_tokenize
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]
            self.start += self.batch_size
            src = []
            tgt = []
            seg = []
            mask = []
            for ins in instances:
                # NOTE(review): ins is used directly as the image path here, which
                # conflicts with the docstring's (tgt, image_path) layout -- confirm
                # the instance format produced by the BEiT preprocessing step.
                image = read_image(ins, ImageReadMode.RGB)
                image = image.cuda(self.gpu_id)
                image = self.transform(image)
                src.append(image)
                # Prepend a 0 placeholder for the first (never-masked) position.
                image_tokens = [0] + image_tokenize(self.vqgan, image)
                tgt_single, mask_index = self.mask(image_tokens)
                tgt.append(tgt_single)
                mask.append(mask_index)
                seg.append([1] * ((self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1))
            yield torch.stack(src, 0), \
                torch.LongTensor(tgt), \
                torch.LongTensor(seg), \
                mask
class DalleDataloader(VisionDataloader):
    """Yield autoregressive (src, tgt, seg) batches for DALL-E style training:
    text tokens followed by VQGAN image tokens in one sequence."""

    def __init__(self, args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle=False, model_for_dataloader=None):
        super(DalleDataloader, self).__init__(args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle, model_for_dataloader)
        from tencentpretrain.utils.image_tokenizer import build_vqgan_model
        # The VQGAN tokenizer model is supplied pre-built by the caller.
        self.vqgan = self.model_for_dataloader
        # Offset that shifts image token ids past the text vocabulary.
        self.vocab_bias = args.tokenizer.vocab_bias

    def __iter__(self):
        from torchvision.io import read_image
        from torchvision.io.image import ImageReadMode
        from tencentpretrain.utils.image_tokenizer import image_tokenize
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]
            self.start += self.batch_size
            src = []
            tgt = []
            seg = []
            for ins in instances:
                src_single, pad_num = ins[0]
                image = read_image(ins[2], ImageReadMode.RGB)
                image = image.cuda(self.gpu_id)
                image = self.transform(image)
                # Shift image tokens into the image region of the joint vocabulary.
                image_tokens = [i + self.vocab_bias for i in image_tokenize(self.vqgan, image)]
                src_single.extend(image_tokens)
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                # Segment 1 = text, segment 2 = image tokens, 0 = padding.
                seg_single = [1] * ins[1][0] + [2] * len(image_tokens) + [0] * pad_num
                src.append(src_single)
                # Next-token targets: input shifted left, [SEP] appended at the end.
                tgt.append(src_single[1:] + [self.vocab.get(SEP_TOKEN)])
                seg.append(seg_single)
            yield torch.LongTensor(src), \
                torch.LongTensor(tgt), \
                torch.LongTensor(seg)