| import os
|
| import re
|
| import random
|
| import torch
|
| import torchaudio
|
|
|
# Module-level guard flag — presumably toggled on first matplotlib use
# elsewhere in the file (lazy backend setup); not read in this chunk — TODO confirm.
MATPLOTLIB_FLAG = False
|
|
|
|
|
def load_audio(audiopath, sampling_rate):
    """Load an audio file as a mono waveform tensor at the target sample rate.

    Args:
        audiopath: Path to an audio file readable by ``torchaudio.load``.
        sampling_rate: Desired output sample rate in Hz.

    Returns:
        A ``(1, T)`` float tensor with samples clamped to ``[-1, 1]``,
        or ``None`` if resampling fails.
    """
    audio, sr = torchaudio.load(audiopath)

    # Down-mix multi-channel audio by keeping only the first channel
    # (no averaging), preserving a (1, T) shape.
    if audio.size(0) > 1:
        audio = audio[0].unsqueeze(0)

    if sr != sampling_rate:
        try:
            audio = torchaudio.functional.resample(audio, sr, sampling_rate)
        except Exception as e:
            # Report the actual failure reason; previously `e` was captured
            # but never surfaced, making failures undiagnosable.
            print(f"Warning: {audiopath}, wave shape: {audio.shape}, sample_rate: {sr}, error: {e}")
            return None

    # In-place clamp guards against out-of-range samples (e.g. after resampling).
    audio.clip_(-1, 1)
    return audio
|
|
|
|
|
def tokenize_by_CJK_char(line: str) -> str:
    """Tokenize a line so each CJK character becomes its own token.

    Non-CJK runs (e.g. Latin words) are kept intact; the whole result is
    upper-cased and single-space separated.

    Example:
        input = "你好世界是 hello world 的中文"
        output = "你 好 世 界 是 HELLO WORLD 的 中 文"

    Args:
        line:
            The input text.

    Return:
        A new string tokenized by CJK char.
    """
    # Capturing group keeps each CJK character as its own element of the split.
    cjk_splitter = re.compile(
        r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
    )
    pieces = cjk_splitter.split(line.strip().upper())
    # Drop empty fragments produced by adjacent CJK chars, then rejoin.
    return " ".join(piece.strip() for piece in pieces if piece.strip())
|
|
|
|
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a boolean mask marking the padded positions of each sequence.

    Args:
        lengths (torch.Tensor): Batch of sequence lengths, shape (B,).
        max_len (int): Mask width; when <= 0, the batch maximum is used.

    Returns:
        torch.Tensor: Bool mask of shape (B, max_len), True at padded indices.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    if max_len <= 0:
        max_len = lengths.max().item()
    positions = torch.arange(max_len, dtype=torch.int64, device=lengths.device)
    # Broadcast (1, max_len) against (B, 1): a position is padding once it
    # reaches or passes that sequence's length.
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)
|
|
|
|
|
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """Element-wise natural log with a lower clamp to avoid log(0).

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Floor applied to the input before the
            logarithm. Defaults to 1e-7.

    Returns:
        Tensor: log(max(x, clip_val)), element-wise.
    """
    # Clamp first so values at or below zero map to log(clip_val) instead of -inf/nan.
    floored = x.clamp(min=clip_val)
    return floored.log()
|
|
|