from io import BytesIO

import logging
import os
import random
import string
import warnings

import numpy as np
import soundfile as sf
import torch
import torchaudio
from PIL import Image, ImageFile

from data import data_utils
from data.audio_utils import get_audio_features, int16_to_float32, float32_to_int16, AUDIO_CFG
from data.ofa_dataset import OFADataset

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
|
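# Each sample handed to collate() is a dict with the keys produced by
# CaptionDataset.__getitem__ below: "id", "source", "target" and
# "prev_output_tokens" (1-D token tensors), plus "patch_image", "patch_video",
# "patch_audio", "patch_mask" and "patch_type" tensors whose shapes match
# across samples so they can be stacked or concatenated into a batch.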
def collate(samples, pad_idx, eos_idx):
    if len(samples) == 0:
        return {}

    def merge(key):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    id = np.array([s["id"] for s in samples])
    src_tokens = merge("source")
    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])

    # Stack the per-sample image/video/audio tensors into batch tensors.
    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
    patch_videos = torch.stack([sample['patch_video'] for sample in samples], dim=0)
    patch_types = torch.cat([sample['patch_type'] for sample in samples])
    patch_audios = torch.stack([sample['patch_audio'] for sample in samples], dim=0)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target")
        tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens")
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "patch_images": patch_images,
            "patch_masks": patch_masks,
            "prev_output_tokens": prev_output_tokens,
            "patch_videos": patch_videos,
            "patch_types": patch_types,
            "patch_audios": patch_audios,
        },
        "target": target,
    }

    return batch
|
|
class CaptionDataset(OFADataset):
    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=128,
        max_tgt_length=30,
        patch_image_size=224,
        imagenet_default_mean_and_std=False,
        scst=False,
        image_dir='/gpfsscratch/rech/dyf/ugz83ue/data',
        audio_cfg=AUDIO_CFG,
        max_audio_len=480000,
        num_frames=4,
        sample_rate=48000,
        audio_sample_rate=False,
        ast=False,
        mode='train',
        mel_bins=64,
    ):
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size
        self.scst = scst

        self.image_dir = image_dir
        self.sample_rate = sample_rate

        # Translation table used to strip punctuation from captions.
        self.transtab = str.maketrans({key: None for key in string.punctuation})

        self.num_frames = num_frames

        self.audio_cfg = audio_cfg
        self.max_audio_len = max_audio_len
        self.audio_sample_rate = audio_sample_rate

        if type(bpe).__name__ == 'GPT2BPE':
            self.prompt = " what does the video describe?"
        else:
            raise NotImplementedError

        # AST-style fbank/SpecAugment settings; the raw-waveform path in
        # __getitem__ below does not use them.
        self.ast = ast
        self.target_length = 1024
        self.mode = split
        self.freqm_p = 24
        self.timem_p = 96
        self.skip_norm = False
        self.noise = False
        self.norm_mean = -4.2677393
        self.norm_std = 4.5689974
        self.freqm = torchaudio.transforms.FrequencyMasking(self.freqm_p)
        self.timem = torchaudio.transforms.TimeMasking(self.timem_p)
        self.mel_bins = mel_bins
|
    def __getitem__(self, index):
        uniq_id, image, caption = self.dataset[index]

        image_path = os.path.join(self.image_dir, image)
        data_path = image_path

        try:
            if not self.audio_sample_rate:
                # Load with soundfile, downmix to mono, and emulate 16-bit quantization.
                audio_data, orig_sr = sf.read(data_path)
                if audio_data.ndim > 1:
                    audio_data = np.mean(audio_data, axis=1)
                audio_data = int16_to_float32(float32_to_int16(audio_data))
                audio_data = torch.tensor(audio_data).float()
            else:
                # Load with torchaudio and resample the first channel to self.sample_rate.
                audio_data, orig_sr = torchaudio.load(data_path)
                audio_data = torchaudio.transforms.Resample(orig_sr, self.sample_rate)(audio_data[0])

            sample = {}
            sample = get_audio_features(
                sample, audio_data, self.max_audio_len,
                data_truncating='rand_trunc',
                data_filling='repeatpad',
                audio_cfg=self.audio_cfg
            )
        except Exception as e:
            # Fall back to a randomly chosen item if this audio file cannot be loaded.
            new_index = random.randint(0, len(self) - 1)
            logger.warning(
                f"Caught exception {e} when loading audio {data_path}, "
                f"randomly sampling a new item as replacement"
            )
            return self.__getitem__(new_index)

        waveform = sample['waveform']
        patch_audio = waveform

        # patch_type 2 marks an audio sample; image and video patches are zero placeholders.
        patch_type = torch.tensor([2])
        patch_image = torch.zeros((3, self.patch_image_size, self.patch_image_size))
        patch_video = torch.zeros((3, self.num_frames, self.patch_image_size, self.patch_image_size))
        patch_mask = torch.tensor([True])

        if self.split == 'train' and not self.scst:
            # Strip punctuation and truncate the target caption during training.
            caption = caption.translate(self.transtab).strip()
            caption_token_list = caption.strip().split()
            tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
        else:
            # Keep all reference captions, joined by '&&', for evaluation / SCST.
            caption = ' '.join(caption.strip().split())
            caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
            tgt_caption = '&&'.join(caption_list)
        src_item = self.encode_text(self.prompt)
        tgt_item = self.encode_text(" {}".format(tgt_caption))

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, tgt_item])

        example = {
            "id": uniq_id,
            "source": src_item,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "patch_type": patch_type,
            "patch_video": patch_video,
            "patch_audio": patch_audio,
        }
        return example
|
    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch containing the data of the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
|
|
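if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original training pipeline:
    # it builds two synthetic samples with the schema produced by __getitem__
    # and batches them with collate(). The pad/eos indices (1, 2), shapes and
    # the _fake_sample helper are illustrative placeholders only.
    def _fake_sample(idx, src_len, tgt_len):
        return {
            "id": idx,
            "source": torch.randint(4, 100, (src_len,)),
            "target": torch.randint(4, 100, (tgt_len,)),
            "prev_output_tokens": torch.randint(4, 100, (tgt_len,)),
            "patch_image": torch.zeros((3, 224, 224)),
            "patch_video": torch.zeros((3, 4, 224, 224)),
            "patch_audio": torch.zeros(480000),
            "patch_mask": torch.tensor([True]),
            "patch_type": torch.tensor([2]),
        }

    batch = collate([_fake_sample(0, 5, 7), _fake_sample(1, 6, 4)], pad_idx=1, eos_idx=2)
    print({k: v.shape for k, v in batch["net_input"].items() if torch.is_tensor(v)})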