import logging
import random
from pathlib import Path

import pandas as pd
import torch
import torchaudio
from torchaudio import transforms as at
from tqdm import tqdm

import datasets
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForSpeechSeq2Seq,
    AutoTokenizer,
    set_seed,
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from audiomentations import (
    AddBackgroundNoise,
    AddGaussianNoise,
    Compose,
    Gain,
    OneOf,
    PitchShift,
    PolarityInversion,
    TimeStretch,
)


check_min_version("4.27.0.dev0")

require_version(
    "datasets>=1.18.0",
    "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt",
)

logger = logging.getLogger(__name__)


def main():
    set_seed(42)

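    # Load the Whisper-medium config; clear the forced decoder ids and suppressed
    # tokens so that language and task are controlled via the tokenizer prefix below.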
    config = AutoConfig.from_pretrained(
        "openai/whisper-medium", revision="main", use_auth_token=True
    )
    config.update({"forced_decoder_ids": None, "suppress_tokens": None})

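    # Enable SpecAugment on the encoder; the feature extractor must then return
    # attention masks so masking is only computed over real (non-padded) frames.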
config.update({"apply_spec_augment": True}) |
|
|
|
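    # The feature extractor turns raw audio into log-Mel inputs; the tokenizer
    # encodes the Vietnamese transcriptions used as labels.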
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        "openai/whisper-medium",
        revision="main",
        use_auth_token=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "openai/whisper-medium",
        use_fast=True,
        revision="main",
        use_auth_token=True,
    )
    tokenizer.set_prefix_tokens(language="vi", task="transcribe")

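    # Preprocessing constants: Whisper expects 16 kHz audio with clips of at most 30 s.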
    max_input_length = 30.0 * 16000
    min_input_length = 0.0 * 16000
    audio_column_name = "audio"
    num_workers = 16
    text_column_name = "text"
    model_input_name = feature_extractor.model_input_names[0]
    forward_attention_mask = True

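    # On-the-fly waveform augmentation: mild time stretch, gain, and pitch shift.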
    augmentation = Compose(
        [
            TimeStretch(min_rate=0.9, max_rate=1.1, p=0.2, leave_length_unchanged=True),
            Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.1),
            PitchShift(min_semitones=-4, max_semitones=4, p=0.2),
        ]
    )

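    # Augment the raw waveform stored in the "audio" column (defined here but not
    # mapped over the dataset in this script).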
    def augment_dataset(batch):
        # the "audio" column holds the raw waveform directly (see Dataset.from_dict
        # below), so the augmented waveform is written back to the column itself
        sample = batch["audio"]
        augmented_waveform = augmentation(sample, sample_rate=16000)
        batch["audio"] = augmented_waveform
        return batch

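    # Convert one example into model inputs: log-Mel features (plus attention mask
    # for SpecAugment) and tokenized transcription labels.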
    def prepare_dataset(batch):
        sample = batch[audio_column_name]
        inputs = feature_extractor(
            sample,
            sampling_rate=16000,
            return_attention_mask=forward_attention_mask,
        )
        batch[model_input_name] = inputs.get(model_input_name)[0]
        batch["input_length"] = len(sample)
        if forward_attention_mask:
            batch["attention_mask"] = inputs.get("attention_mask")[0]

        input_str = batch[text_column_name]
        batch["labels"] = tokenizer(input_str).input_ids
        return batch

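    # Load an audio file with torchaudio and resample it to the target rate if needed.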
    def load_wave(wave_path, sample_rate: int = 16000) -> torch.Tensor:
        waveform, sr = torchaudio.load(wave_path, normalize=True)
        if sample_rate != sr:
            waveform = at.Resample(sr, sample_rate)(waveform)
        return waveform

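    # Read a random fraction of the train/test CSV (columns "path" and "sentence"),
    # load each existing file, and drop clips that are empty or longer than ~30 s.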
    def get_list_files_MITI(phase, sample_rate=16000, audio_max_sample_length=480000, fraction=0.15):
        audio_list = []
        text_list = []
        if phase == 'train':
            csv_file = 'vin_train.csv'
        else:
            csv_file = 'vin_test.csv'
        df = pd.read_csv(csv_file)

        # randomly keep only `fraction` of the rows
        num_samples = int(len(df) * fraction)
        selected_indices = random.sample(range(len(df)), num_samples)

        for index, row in tqdm(df.iterrows(), total=len(df)):
            if index not in selected_indices:
                continue

            new_path = Path(row['path'])
            audio_id = index
            text = row['sentence']

            if new_path.exists():
                audio = load_wave(new_path, sample_rate=sample_rate)[0]
                # skip empty clips and clips longer than audio_max_sample_length samples (~30 s)
                if len(audio) > audio_max_sample_length or len(audio) == 0:
                    print('skip file:', new_path, 'with len audio', len(audio))
                    continue
                audio_list.append(audio)
                text_list.append(text)

        return audio_list, text_list

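    # Build in-memory train/test splits from the sampled waveforms and transcripts.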
    train_audio, train_text = get_list_files_MITI(phase='train')
    test_audio, test_text = get_list_files_MITI(phase='test')

    train_dataset = Dataset.from_dict({"audio": train_audio, "text": train_text})
    test_dataset = Dataset.from_dict({"audio": test_audio, "text": test_text})

    vin_100h = DatasetDict({"train": train_dataset, "test": test_dataset})

    print(vin_100h)

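    # Extract features and labels for both splits, dropping the raw columns afterwards.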
    vectorized_datasets = vin_100h.map(
        prepare_dataset,
        remove_columns=["audio", "text"],
        num_proc=1,
        desc="preprocess dataset",
    )
    print(vectorized_datasets)

    vectorized_datasets.save_to_disk("./vin_10h", num_proc=1)

    return


if __name__ == "__main__":
    main()