# Vietnamese_ASR/src/prepare_data.py
import logging
import datasets
from datasets import DatasetDict, load_dataset, concatenate_datasets
from tqdm import tqdm
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoModelForSpeechSeq2Seq,
AutoTokenizer,
set_seed,
)
from transformers.utils.versions import require_version
from transformers.utils import check_min_version
from audiomentations import (
AddBackgroundNoise,
AddGaussianNoise,
Compose,
Gain,
OneOf,
PitchShift,
PolarityInversion,
TimeStretch,
)
check_min_version("4.27.0.dev0")
require_version(
"datasets>=1.18.0",
"To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt",
)
logger = logging.getLogger(__name__)
from datasets import Dataset, DatasetDict
import torchaudio
from torchaudio import transforms as at
import pandas as pd
import torch
import numpy as np
from pathlib import Path
import random
def main():
    # Set seed for reproducibility (this also fixes the random subset selection below).
set_seed(42)
    # 5. Load pretrained config, tokenizer, and feature extractor
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can
    # concurrently download model & vocab.
config = AutoConfig.from_pretrained(
"openai/whisper-medium", revision="main", use_auth_token=True
)
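    # Clear the default forced decoder ids and suppressed tokens; the language/task
    # prefix is set on the tokenizer below instead.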
config.update({"forced_decoder_ids": None, "suppress_tokens": None})
    # SpecAugment for Whisper models. The reference script guards this behind
    # `getattr(config, "model_type", None) == "whisper"`; the checkpoint here is
    # always Whisper, so it is enabled unconditionally.
config.update({"apply_spec_augment": True})
feature_extractor = AutoFeatureExtractor.from_pretrained(
"openai/whisper-medium",
revision="main",
use_auth_token=True,
)
tokenizer = AutoTokenizer.from_pretrained(
"openai/whisper-medium",
use_fast=True,
revision="main",
use_auth_token=True,
)
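    # Make the tokenizer prepend the Vietnamese-transcription prefix tokens when
    # building label ids.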
tokenizer.set_prefix_tokens(language="vi", task="transcribe")
# 7. Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the targets.
max_input_length = 30.0 * 16000
min_input_length = 0.0 * 16000
audio_column_name = "audio"
num_workers = 16
text_column_name = "text"
model_input_name = feature_extractor.model_input_names[0]
# if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis
forward_attention_mask = True
# noise_dir = "../noise/ESC-50-master/audio/"
# define augmentation
augmentation = Compose(
[
TimeStretch(min_rate=0.9, max_rate=1.1, p=0.2, leave_length_unchanged=True),
Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.1),
PitchShift(min_semitones=-4, max_semitones=4, p=0.2),
]
)
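    # Each transform in the Compose pipeline fires independently with its own
    # probability `p`, so most examples pass through lightly modified or unchanged.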
    def augment_dataset(batch):
        # the "audio" column holds a raw 16 kHz waveform (see get_list_files_MITI)
        sample = np.asarray(batch["audio"], dtype=np.float32)
        # apply the waveform-level augmentation pipeline
        augmented_waveform = augmentation(sample, sample_rate=16000)
        batch["audio"] = augmented_waveform
        return batch
def prepare_dataset(batch):
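        """Turn one example into Whisper training features: log-Mel input features,
        an attention mask (used to guide SpecAugment), and tokenized label ids."""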
# process audio
sample = batch[audio_column_name]
inputs = feature_extractor(
sample,
            sampling_rate=16000,
return_attention_mask=forward_attention_mask,
)
        # keep the extracted features and the raw audio length (in samples)
batch[model_input_name] = inputs.get(model_input_name)[0]
batch["input_length"] = len(sample)
if forward_attention_mask:
batch["attention_mask"] = inputs.get("attention_mask")[0]
# process targets
input_str = batch[text_column_name]
batch["labels"] = tokenizer(input_str).input_ids
return batch
def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:
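        """Load an audio file with torchaudio and resample it to `sample_rate` (16 kHz by default)."""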
waveform, sr = torchaudio.load(wave_path, normalize=True)
if sample_rate != sr:
waveform = at.Resample(sr, sample_rate)(waveform)
return waveform
def get_list_files_MITI(phase, sample_rate=16000, audio_max_sample_length=480000, fraction=0.15):
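        """Read the split's CSV manifest (columns: 'path', 'sentence'), keep a random
        `fraction` of its rows, and return parallel lists of waveforms and transcripts."""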
audio_list = []
text_list = []
if phase == 'train':
csv_file = 'vin_train.csv'
else:
csv_file = 'vin_test.csv'
df = pd.read_csv(csv_file)
# Calculate the number of samples to select based on the fraction
num_samples = int(len(df) * fraction)
# Randomly select the indices of samples
selected_indices = random.sample(range(len(df)), num_samples)
        for index, row in tqdm(df.iterrows(), total=len(df)):
if index not in selected_indices:
continue
new_path = Path(row['path'])
audio_id = index
text = row['sentence']
if new_path.exists():
audio = load_wave(new_path, sample_rate=sample_rate)[0]
                if len(audio) > audio_max_sample_length or len(audio) == 0:
print('skip file:', new_path, 'with len audio', len(audio))
continue
audio_list.append(audio)
text_list.append(text)
return audio_list, text_list
    # Assumes two CSV manifests, 'vin_train.csv' and 'vin_test.csv', in the current working directory
# Get the training dataset
train_audio, train_text = get_list_files_MITI(phase='train')
# Get the testing dataset
test_audio, test_text = get_list_files_MITI(phase='test')
# Create the Dataset objects
train_dataset = Dataset.from_dict({"audio": train_audio, "text": train_text})
test_dataset = Dataset.from_dict({"audio": test_audio, "text": test_text})
# Create the DatasetDict
vin_100h = DatasetDict({"train": train_dataset, "test": test_dataset})
print(vin_100h)
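    # NOTE: augment_dataset is defined above but never applied in this script. A
    # minimal sketch of how it could be mapped over the training split only (an
    # assumption, not part of the original pipeline):
    #
    #   vin_100h["train"] = vin_100h["train"].map(
    #       augment_dataset,
    #       num_proc=1,
    #       desc="augment train dataset",
    #   )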
vectorized_datasets = vin_100h.map(
prepare_dataset,
remove_columns=["audio", "text"],
num_proc=1,
desc="preprocess train dataset",
)
print(vectorized_datasets)
vectorized_datasets.save_to_disk(
"./vin_10h", num_proc=1
)
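    # The processed features can be reloaded later with datasets.load_from_disk("./vin_10h").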
return
if __name__ == "__main__":
main()