# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import logging
import os
from typing import Any, List, Optional

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import data_utils, Dictionary
from fairseq.data.fairseq_dataset import FairseqDataset

logger = logging.getLogger(__name__)


def load_audio(manifest_path, max_keep, min_keep):
    """Load a manifest tsv with columns: wav_path, wav_nframe[, wav_class].

    Args:
        manifest_path: str
        max_keep: int
        min_keep: int
    Returns:
        root, names, inds, tot, sizes, classes
    """
    n_long, n_short = 0, 0
    names, inds, sizes, classes = [], [], [], []
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            assert len(items) >= 2, line
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                names.append(items[0])
                if len(items) > 2:
                    classes.append(items[2])
                inds.append(ind)
                sizes.append(sz)
        tot = ind + 1
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
        )
    )
    if len(classes) == 0:
        logger.warning("no classes loaded; this is expected only for inference")
    return root, names, inds, tot, sizes, classes
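

# Illustrative manifest layout read by load_audio (the root directory, file
# names, frame counts, and class labels below are hypothetical, not files
# shipped with the repository):
#
#     /data/speech
#     spk1/utt_0001.wav    52480    dialect_a
#     spk1/utt_0002.wav    31200    dialect_b
#
# The first line is the audio root; each following line is tab-separated as
# "<relative_wav_path>\t<num_samples>[\t<class_label>]". The class column may
# be omitted at inference time, in which case no labels are returned.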


def sample_from_feature(x: np.ndarray, max_segment_length: int = 300):
    """Sample a segment of at most `max_segment_length` frames (roughly 300-400
    feature frames, or 51200-76800 waveform samples) from an utterance.

    Args:
        x (np.ndarray): feature or waveform (frames[, features]), e.g., log mel
            filter bank or waveform
        max_segment_length (int, optional): maximum segment length. Defaults to 300.

    Returns:
        np.ndarray: segmented features
    """
    if len(x) <= max_segment_length:
        return x
    start = np.random.randint(0, x.shape[0] - max_segment_length)
    return x[start: start + max_segment_length]


class SpeechToClassDataset(FairseqDataset):
    """Dataset for the speech-to-class (s2c) task: waveform input, class label target."""

    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        shuffle: bool = True,
        normalize: bool = False,
        tgt_dict: Optional[Dictionary] = None,
        max_length: Optional[int] = None,
    ):
        self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.wav_classes = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size
        )
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.label_processors = label_processors
        self.normalize = normalize
        self.tgt_dict = tgt_dict
        self.max_length = max_length

        logger.info(f"max_length={max_length}, normalize={normalize}")

    def get_audio(self, index):
        import soundfile as sf  # lazy import: only needed when audio is actually read

        wav_path = os.path.join(self.audio_root, self.audio_names[index])
        wav, cur_sample_rate = sf.read(wav_path)
        if self.max_length is not None:
            wav = sample_from_feature(wav, self.max_length)
        wav = torch.from_numpy(wav).float()
        wav = self.postprocess(wav, cur_sample_rate)
        return wav

    def get_label(self, index):
        label = self.wav_classes[index]
        if self.label_processors is not None:
            label = self.label_processors(label)
        return label

    def __getitem__(self, index):
        wav = self.get_audio(index)
        label = None
        # labels are only available when every manifest row carried a class column
        if len(self.wav_classes) == len(self.audio_names):
            label = self.get_label(index)
        return {"id": index, "source": wav, "label": label}

    def __len__(self):
        return len(self.wav_sizes)

    def collater(self, samples):
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        audio_size = max(audio_sizes)

        collated_audios, padding_mask = self.collater_audio(audios, audio_size)

        decoder_label = None
        decoder_target = None
        decoder_target_lengths = None
        if samples[0]["label"] is not None:
            targets_by_label = [[s["label"] for s in samples]]
            targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label)

            decoder_label = [
                (targets_list[0][i, :lengths_list[0][i]]).long()
                for i in range(targets_list[0].size(0))
            ]

            decoder_target = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            decoder_target_lengths = torch.tensor(
                [x.size(0) for x in decoder_label], dtype=torch.long
            )

        # single-token decoder input per sample (EOS moved to the beginning)
        prev_output_tokens = data_utils.collate_tokens(
            [torch.LongTensor([-1]) for _ in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=True,
        )

        net_input = {
            "source": collated_audios,
            "padding_mask": padding_mask,
            "prev_output_tokens": prev_output_tokens,
            "task_name": "s2c",
        }
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
            "target": decoder_target,
            "target_lengths": decoder_target_lengths,
            "task_name": "s2c",
            "ntokens": len(samples),
        }
        return batch

    def collater_audio(self, audios, audio_size):
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        # True marks padded (zero-filled) positions
        padding_mask = torch.BoolTensor(collated_audios.shape).fill_(False)
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
                padding_mask[i, diff:] = True
            else:
                raise Exception("Diff should not be larger than 0")
        return collated_audios, padding_mask
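
    # Illustrative collation example (values are hypothetical): for two waveforms
    # of 48000 and 32000 samples, collater_audio zero-pads both to length 48000
    # and returns
    #     collated_audios: FloatTensor of shape (2, 48000)
    #     padding_mask:    BoolTensor of shape (2, 48000) that is True only on the
    #                      padded tail of the second row (positions 32000:48000).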

    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, [self.tgt_dict.pad()])
        for targets, pad in itr:
            targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        return self.wav_sizes[index]

    @property
    def sizes(self):
        return np.array(self.wav_sizes)

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        order.append(self.wav_sizes)
        return np.lexsort(order)[::-1]

    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
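

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module).
    # The manifest path, dictionary file, and label processor below are
    # hypothetical assumptions; adapt them to your data.
    tgt_dict = Dictionary.load("dict.txt")  # hypothetical label dictionary
    dataset = SpeechToClassDataset(
        manifest_path="train.tsv",  # hypothetical manifest in the format sketched above
        sample_rate=16000,
        # map the class string to a length-1 tensor of dictionary ids
        label_processors=lambda x: torch.LongTensor([tgt_dict.index(x)]),
        normalize=False,
        tgt_dict=tgt_dict,
        max_length=48000,
    )
    batch = dataset.collater([dataset[i] for i in range(min(2, len(dataset)))])
    print(batch["net_input"]["source"].shape, batch["target"])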