import json
import random
from math import ceil

import numpy as np
import pandas as pd
import torch
import torchaudio
import decord
from torch.utils.data import Dataset, DataLoader

# Have decord return torch tensors instead of numpy arrays.
decord.bridge.set_bridge("torch")


class AudioVisualDataset(Dataset):
    """Samples paired audio-visual clips from video databases.

    Params:
        datafiles: list of JSON datafiles, each with a top-level "data" list
        min_video_frames: videos with fewer frames than this are dropped
        video_resize: (height, width) used to resize frames for CLIP processing
        sampling_rate: audio sampling rate in Hz
        sample_av_clip: whether to randomly crop a clip from longer videos
        max_clip_len: max length (seconds) of audio-visual clip to be sampled
        num_sample_frames: number of image frames to be uniformly sampled from video
        freqm: maximum width of the SpecAugment frequency mask (0 disables it)
        timem: maximum width of the SpecAugment time mask (0 disables it)
        return_label: whether to also return multi-hot AudioSet labels
    """

    def __init__(
        self,
        datafiles=[
            "/mnt/bn/data-xubo/dataset/audioset_videos/datafiles/audioset_balanced_train.json",
        ],
        min_video_frames=30,
        video_resize=[224, 224],
        sampling_rate=16000,
        sample_av_clip=True,
        max_clip_len=10,
        num_sample_frames=10,
        freqm=48,
        timem=192,
        return_label=False,
    ):
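        # Each datafile is assumed (inferred from the keys accessed in this
        # class) to contain entries of the form:
        #   {"data": [{"mp4": "/path/to/clip.mp4",
        #              "video_shape": [num_frames, ...],
        #              "labels": "/m/09x0r,/m/04rlf"},
        #             ...]}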
        all_data_json = []
        for datafile in datafiles:
            with open(datafile, "r") as fp:
                data_json = json.load(fp)["data"]
                all_data_json.extend(data_json)

        # Drop video clips that are shorter than min_video_frames.
        self.all_data_json = [
            data
            for data in all_data_json
            if int(data["video_shape"][0]) >= min_video_frames
        ]

        self.max_clip_len = max_clip_len
        self.video_resize = video_resize
        self.sampling_rate = sampling_rate
        self.sample_av_clip = sample_av_clip
        self.num_sample_frames = num_sample_frames
        self.corresponding_audio_len = self.sampling_rate * self.max_clip_len

        # Filterbank normalization statistics and SpecAugment mask widths.
        self.freqm = freqm
        self.timem = timem
        self.norm_mean = -4.2677393
        self.norm_std = 4.5689974
        self.melbins = 128
        self.TARGET_LEN = 1024

        self.return_label = return_label
        if self.return_label:
            self.audioset_label2idx = self._prepare_audioset()

    def __len__(self):
        return len(self.all_data_json)

    def _read_audio_video(self, index):
        video_path = self.all_data_json[index]["mp4"]
        try:
            # Mono audio resampled to the target sampling rate.
            ar = decord.AudioReader(
                video_path, sample_rate=self.sampling_rate, mono=True
            )

            # Video frames resized for CLIP processing.
            vr = decord.VideoReader(
                video_path,
                height=self.video_resize[0],
                width=self.video_resize[1],
            )

            labels = self.all_data_json[index]["labels"]
            return vr, ar, labels

        except Exception as e:
            # Fall back to a randomly chosen clip if this one fails to decode.
            print(f"Error {e} occurred when loading {video_path}")
            random_index = random.randint(0, len(self.all_data_json) - 1)
            return self._read_audio_video(index=random_index)

    def _prepare_audioset(self):
        # Map AudioSet label codes to class indices using the official
        # class_labels_indices.csv metadata (e.g. "/m/09x0r" -> 0 for "Speech").
        df1 = pd.read_csv(
            "/mnt/bn/lqhaoheliu/datasets/audioset/metadata/class_labels_indices.csv",
            delimiter=",",
            skiprows=0,
        )
        label_set = df1.to_numpy()
        code2id = {}
        for i in range(len(label_set)):
            code2id[label_set[i][1]] = label_set[i][0]
        return code2id

    def __getitem__(self, index):
        # Read audio and video
        vr, ar, labels = self._read_audio_video(index)

        # Audio waveform: (1, num_samples) -> (num_samples,)
        audio_data = ar[:]
        audio_len = audio_data.shape[1] / self.sampling_rate
        audio_data = audio_data.squeeze(0)

        # Video clip bookkeeping
        full_vid_length = len(vr)
        video_rate = ceil(vr.get_avg_fps())
        samples_per_frame = float(self.sampling_rate) / video_rate
        start_frame = 0

        # Randomly choose a start frame when the video is longer than max_clip_len.
        if audio_len > self.max_clip_len and self.sample_av_clip:
            start_frame = random.randint(
                0, max(full_vid_length - video_rate * self.max_clip_len, 0)
            )
        end_frame = min(start_frame + video_rate * self.max_clip_len, full_vid_length)
        video_data = vr.get_batch(range(start_frame, end_frame))

        # Crop the audio so it stays aligned with the sampled video clip.
        if audio_len > self.max_clip_len and self.sample_av_clip:
            corresponding_audio_start = int(start_frame * samples_per_frame)
            audio_data = audio_data[corresponding_audio_start:]

        # Zero-pad or truncate to exactly max_clip_len seconds of audio.
        if audio_data.shape[0] < self.corresponding_audio_len:
            zero_data = torch.zeros(self.corresponding_audio_len)
            zero_data[: audio_data.shape[0]] = audio_data
            audio_data = zero_data
        elif audio_data.shape[0] > self.corresponding_audio_len:
            audio_data = audio_data[: self.corresponding_audio_len]

        # Uniformly sample num_sample_frames image frames from the video clip.
        interval = video_data.shape[0] // self.num_sample_frames
        video_data = video_data[::interval][: self.num_sample_frames]

        assert (
            video_data.shape[0] == self.num_sample_frames
        ), f"number of sampled image frames is {video_data.shape[0]}"

        assert (
            audio_data.shape[0] == self.corresponding_audio_len
        ), f"number of audio samples is {audio_data.shape[0]}"

        # Scale frames to [0, 1] and reorder to (frames, channels, height, width).
        video_data = video_data / 255.0
        video_data = video_data.permute(0, 3, 1, 2)

        # Compute a 128-bin log-mel filterbank (Kaldi-compatible, 10 ms frame shift).
        audio_data = audio_data.unsqueeze(0)
        audio_data = audio_data - audio_data.mean()
        fbank = torchaudio.compliance.kaldi.fbank(
            audio_data,
            htk_compat=True,
            sample_frequency=self.sampling_rate,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=self.melbins,
            dither=0.0,
            frame_shift=10,
        )

        # Zero-pad or truncate the filterbank to TARGET_LEN frames.
        n_frames = fbank.shape[0]
        p = self.TARGET_LEN - n_frames
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            fbank = m(fbank)
        elif p < 0:
            fbank = fbank[0 : self.TARGET_LEN, :]

        # SpecAugment (frequency and time masking), then normalization.
        freqm = torchaudio.transforms.FrequencyMasking(self.freqm)
        timem = torchaudio.transforms.TimeMasking(self.timem)
        fbank = fbank.transpose(0, 1).unsqueeze(0)
        if self.freqm != 0:
            fbank = freqm(fbank)
        if self.timem != 0:
            fbank = timem(fbank)
        fbank = torch.transpose(fbank.squeeze(), 0, 1)
        fbank = (fbank - self.norm_mean) / (self.norm_std * 2)
        fbank = fbank.unsqueeze(0)

        if self.return_label:
            # Multi-hot target vector over the 527 AudioSet classes.
            label_indices = np.zeros(527)
            for label_str in labels.split(","):
                label_indices[int(self.audioset_label2idx[label_str])] = 1.0
            label_indices = torch.FloatTensor(label_indices)

            data_dict = {
                "labels": label_indices,
                "images": video_data,
                "fbank": fbank,
            }
        else:
            data_dict = {
                "images": video_data,
                "fbank": fbank,
            }

        return data_dict


def collate_fn(list_data_dict):
    r"""Collate mini-batch data into batched tensors for training.

    Args:
        list_data_dict: e.g., [
            {'images': (frames_num, channels_num, height, width),
             'fbank': (1, time_steps, mel_bins)},
            {'images': (frames_num, channels_num, height, width),
             'fbank': (1, time_steps, mel_bins)},
            ...]

    Returns:
        data_dict: e.g. {
            'images': (batch_size, frames_num, channels_num, height, width),
            'fbank': (batch_size, 1, time_steps, mel_bins)}
    """
    data_dict = {}
    for key in list_data_dict[0].keys():
        # Gather this key's tensors across the mini-batch and stack them along
        # a new batch dimension.
        data_dict[key] = [sample[key] for sample in list_data_dict]
        data_dict[key] = torch.stack(data_dict[key])

    return data_dict
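

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the training pipeline): wraps the
# dataset in a DataLoader with the custom collate_fn above. It assumes the
# default datafile path hard-coded in __init__ exists on this machine;
# batch_size and num_workers are arbitrary example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dataset = AudioVisualDataset(
        freqm=0,  # disable SpecAugment masking for a quick sanity check
        timem=0,
        return_label=False,
    )
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn,
    )
    batch = next(iter(loader))
    # Expected shapes:
    #   images: (4, num_sample_frames, 3, 224, 224)
    #   fbank:  (4, 1, 1024, 128)
    print({key: value.shape for key, value in batch.items()})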
|