In [None]:
# This notebook is currently designed for a single GPU with fp16 mixed precision. Note, however, that the hyperparameters are barely tuned.

In [None]:
import json
import random
from pathlib import Path

import torch
from accelerate import Accelerator
from datasets import concatenate_datasets, load_dataset
from datasets.features import Audio
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, SpeechEncoderDecoderModel, Wav2Vec2FeatureExtractor

from wer import calculate_wer # Not what's used in eval.py.
from model import Wav2VecGPT2Model

In [None]:
# Common Voice 7.0, German configuration. use_auth_token=True makes `datasets` use the
# locally stored Hugging Face token for the download.
common_voice = load_dataset(
    path='mozilla-foundation/common_voice_7_0',
    name='de',
    use_auth_token=True,
)

In [None]:
EXPERIMENT_NAME = '00'

# Every artifact of this run (checkpoint, TensorBoard logs, config dump) lives under model_dir.
model_dir = Path('end2end/de') / EXPERIMENT_NAME
log_dir = model_dir / 'logs'
log_dir.mkdir(exist_ok=True, parents=True)

config = {
    'encoder_id': 'jonatasgrosman/wav2vec2-large-xlsr-53-german',
    'decoder_id': 'dbmdz/german-gpt2',
    # Registered as the decoder tokenizer's pad / bos special tokens.
    'decoder_pad_token': '_',
    'decoder_bos_token': '~',
    'num_beams': 1,
    'num_val_examples': 1500,  # validation clips held out; the remainder is added to training
    'batch_size': 8,
    'base_lr': 3e-4,
    'weight_decay': 0.,
    'accumulate_grad': 4,  # optimizer steps every N batches -> effective batch = batch_size * N
    'max_epochs': 10,
    # Longest tokenized validation/test transcript, computed once via:
    # len(max(tokenizer(common_voice['validation']['sentence'] + common_voice['test']['sentence']).input_ids, key=len))
    'max_len': 36,
}

In [None]:
# Decoder tokenizer extended with the custom pad/bos special tokens from config.
# NOTE(review): '_' and '~' are presumably chosen because they already exist in GPT-2's
# byte-level vocabulary, so no embedding resize is needed -- confirm.
tokenizer = AutoTokenizer.from_pretrained(config['decoder_id'])
tokenizer.add_special_tokens({
    'pad_token': config['decoder_pad_token'],
    'bos_token': config['decoder_bos_token'],
})

# Turns raw 16 kHz waveforms into encoder inputs (called from collate_fn).
wave2vec_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config['encoder_id'])

# wav2vec2 encoder + GPT-2 decoder; generation defaults (max_length, num_beams) come from config.
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    config['encoder_id'],
    config['decoder_id'],
    max_length=config['max_len'],
    num_beams=config['num_beams'],
)

# Generation starts from the custom bos token. The pad id is presumably also used when
# shifting labels into decoder inputs -- confirm against the transformers version in use.
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Warm-start from the decoder-only training run (assumes the same architecture and vocab).
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on any host;
# the model has not been handed to accelerator.prepare yet at this point.
DECODER_ONLY_CKPT = Path('decoder_only/de/00/model.pt')
model.load_state_dict(torch.load(DECODER_ONLY_CKPT, map_location='cpu'))

In [None]:
class AudioDataset(Dataset):
    """Map-style dataset yielding ``(waveform, transcript)`` pairs.

    Wraps any indexable collection whose items are dicts of the form
    ``{'audio': {'array': ...}, 'sentence': ...}``.
    """

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        waveform = record['audio']['array']
        transcript = record['sentence']
        return waveform, transcript
 
def collate_fn(examples, max_audio_len=300_000, sampling_rate=16_000):
    """Batch ``(waveform, sentence)`` pairs into padded encoder inputs and label ids.

    Args:
        examples: list of ``(waveform, sentence)`` tuples, as yielded by ``AudioDataset``.
        max_audio_len: clips with this many samples or more are dropped (300_000 samples at
            16 kHz is 18.75 s). Only a few clips are this long, and they may lead to OOM- or
            Index-Errors.
        sampling_rate: sampling rate of the waveforms, forwarded to the feature extractor.

    Returns:
        ``(audio_features, input_ids)``, each padded to the longest item in the batch.
        Because of the length filter, the batch can contain FEWER items than ``examples``,
        so callers must not assume a 1:1 mapping to the input order/count.
    """
    kept = [eg for eg in examples if len(eg[0]) < max_audio_len]

    audio_features = wave2vec_extractor(
        [eg[0] for eg in kept], sampling_rate=sampling_rate, return_tensors='pt', padding='longest'
    ).input_values

    input_ids = tokenizer(
        [eg[1] for eg in kept], return_tensors='pt', padding=True
    ).input_ids

    return audio_features, input_ids

In [None]:
# Cast the audio column of both splits to 16 kHz, the rate collate_fn passes to the extractor.
audio_16k = Audio(sampling_rate=16_000)
train, val = (
    common_voice[split].cast_column('audio', audio_16k)
    for split in ('train', 'validation')
)

In [None]:
# Deterministically hold out `num_val_examples` validation clips; the remaining validation
# clips are folded into the training set.
random.seed(419)
num_val = config['num_val_examples']
val_inds = list(range(len(common_voice['validation'])))
random.shuffle(val_inds)
held_out_inds, extra_train_inds = val_inds[:num_val], val_inds[num_val:]

train_ds = AudioDataset(concatenate_datasets([train, val.select(extra_train_inds)]))
val_ds = AudioDataset(val.select(held_out_inds))

# Shared loader settings; only shuffling differs between train and validation.
loader_kwargs = dict(batch_size=config['batch_size'], collate_fn=collate_fn, num_workers=4)
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [None]:
# fp16 mixed precision via Accelerate, which also selects the device to run on.
accelerator = Accelerator(fp16=True)
device = accelerator.device
print(f'Using {device}.')

In [None]:
# AdamW over all model parameters; lr and weight decay come from config.
optimizer = AdamW(
    params=model.parameters(),
    lr=config['base_lr'],
    weight_decay=config['weight_decay'],
)

In [None]:
# Replace each object with the version returned by Accelerate; all later cells use these.
prepared = accelerator.prepare(model, optimizer, train_dl, val_dl)
model, optimizer, train_dl, val_dl = prepared

In [None]:
# Persist the exact config of this run next to its logs.
with open(log_dir / 'config.json', 'w') as config_file:
    json.dump(config, config_file, indent=4)

writer = SummaryWriter(log_dir)
best_val_wer = float('inf')  # guarantees the first epoch's model is saved
global_train_step = 0
accumulate_grad = config['accumulate_grad']
steps_per_epoch = len(train_dl)

# Clear gradients possibly left over from an interrupted earlier run of this cell.
optimizer.zero_grad()

for epoch in range(config['max_epochs']):
    # ---- Training ----
    model.train()
    for batch_step, (audio_features, input_ids) in enumerate(train_dl):
        global_train_step += 1

        out = model(labels=input_ids, input_values=audio_features)
        # Divide so the accumulated gradient is the mean over the micro-batches, not their sum.
        accelerator.backward(out.loss / accumulate_grad)
        writer.add_scalar('train_loss', out.loss.item(), global_train_step)

        # Also step on the epoch's last batch so leftover gradients don't leak into the next epoch.
        is_last_batch = batch_step + 1 == steps_per_epoch
        if (batch_step + 1) % accumulate_grad == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()
        if batch_step % 300 == 0:
            print(out.loss.item())

    # ---- Validation ----
    model.eval()
    val_preds, val_golds = [], []
    with torch.no_grad():
        for audio_features, input_ids in val_dl:
            generated = model.generate(audio_features)
            val_preds += tokenizer.batch_decode(generated)
            # Golds are taken from the SAME batch as the predictions: collate_fn drops over-long
            # clips, so a precomputed list of dataset sentences could silently misalign pairs.
            val_golds += tokenizer.batch_decode(
                input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
    # Strip the bos prefix ('~') and the trailing pad tokens ('_') from the generated text.
    val_preds = [pred.lstrip('~').rstrip('_') for pred in val_preds]
    wer = calculate_wer(val_preds, val_golds)
    writer.add_scalar('val_wer', wer, epoch)
    writer.flush()
    print('WER: ', wer)

    if wer < best_val_wer:
        torch.save(model.state_dict(), model_dir / 'model.pt')
        print('Saved Model.')
        best_val_wer = wer

writer.close()