# owsm_finetune / finetune.py
# ms180's picture
# Update finetune.py
# a9329d9 verified
# raw
# history blame
# 10.4 kB
import glob
import sys
from pathlib import Path
import shutil
import os
from espnet2.tasks.s2t import S2TTask
from espnet2.text.sentencepiece_tokenizer import SentencepiecesTokenizer
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.s2t.espnet_model import ESPnetS2TModel
from espnet2.bin.s2t_inference import Speech2Text
import espnetez as ez
import torch
import numpy as np
import logging
import gradio as gr
import librosa
class Logger:
    """Tee-style text stream: every write goes to the console and to a log file.

    Args:
        filename: Path of the log file. It is truncated on open and written as
            UTF-8, so non-ASCII text is logged regardless of the platform locale.
    """

    def __init__(self, filename):
        # Keep the stream that was active at construction time so output is
        # still echoed to the real console after ``sys.stdout`` is replaced.
        self.terminal = sys.stdout
        self.log = open(filename, "w", encoding="utf-8")

    def write(self, message):
        """Write ``message`` to both streams and return its length (stream protocol)."""
        self.terminal.write(message)
        self.log.write(message)
        return len(message)

    def writelines(self, lines):
        """Write each string in ``lines`` to both streams."""
        for line in lines:
            self.write(line)

    def flush(self):
        """Flush both underlying streams."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        # Always report a non-interactive stream, whatever the terminal is.
        return False

    def close(self):
        """Close the log file only; the wrapped terminal stream stays open."""
        if not self.log.closed:
            self.log.close()

    def __getattr__(self, name):
        # Delegate any other stream attribute (``encoding``, ``fileno``, ...)
        # to the wrapped terminal so code probing ``sys.stdout`` keeps working.
        # NOTE: output written via delegated objects (e.g. ``buffer``) bypasses
        # the log file.
        if name == "terminal":
            # ``terminal`` is not set yet (e.g. instance built via __new__);
            # raising here prevents infinite recursion.
            raise AttributeError(name)
        return getattr(self.terminal, name)
# Mirror everything printed to stdout into output.log as well as the console.
sys.stdout = Logger(filename="output.log")
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
def get_dataset(data_path, data_info, test_count=10):
    """Load a Kaldi-style data directory and split it into train / test datasets.

    Files read:
        ``{data_path}/text``      -- one ``<utt_id> <transcript>`` per line
        ``{data_path}/text_ctc``  -- one ``<utt_id> <ctc transcript>`` per line
        ``{data_path}/audio/*``   -- audio files whose stem is the ``<utt_id>``

    Utterances follow the order of ``text``. Audio files with no transcript,
    and ``text_ctc`` entries not listed in ``text``, are ignored.

    Args:
        data_path: Directory containing the files above.
        data_info: Passed unchanged to ``ez.dataset.ESPnetEZDataset``.
        test_count: Number of leading utterances held out as the test split.

    Returns:
        ``(train_dataset, test_dataset, test_items)``, where ``test_items`` is
        the list of ``{"id", "text", "text_ctc", "audio_path"}`` dicts backing
        the test split.

    Raises:
        KeyError: if an utterance in ``text`` has no ``text_ctc`` entry or no
            audio file.
    """
    transcripts = _read_utterance_file(f"{data_path}/text")
    ctc_transcripts = _read_utterance_file(f"{data_path}/text_ctc")

    # Map utterance id -> audio path. glob.escape keeps directory names with
    # glob metacharacters ("[", "*", "?") from being treated as patterns.
    audio_paths = {}
    for audio_path in glob.glob(glob.escape(f"{data_path}/audio") + "/*"):
        audio_id = Path(audio_path).stem
        if audio_id in transcripts:
            audio_paths[audio_id] = audio_path

    items = []
    for audio_id, text in transcripts.items():
        if audio_id not in ctc_transcripts:
            raise KeyError(f"utterance '{audio_id}' has no entry in {data_path}/text_ctc")
        if audio_id not in audio_paths:
            raise KeyError(f"utterance '{audio_id}' has no audio file in {data_path}/audio")
        items.append({
            "id": audio_id,
            "text": text,
            "text_ctc": ctc_transcripts[audio_id],
            "audio_path": audio_paths[audio_id],
        })

    test_items = items[:test_count]
    return (
        ez.dataset.ESPnetEZDataset(items[test_count:], data_info),
        ez.dataset.ESPnetEZDataset(test_items, data_info),
        test_items,
    )


def _read_utterance_file(path):
    """Parse ``<utt_id> <text>`` lines into an insertion-ordered ``{utt_id: text}`` dict.

    Blank lines are skipped; an id with no text maps to ``""``. A repeated id
    keeps its first position but takes the last text seen.
    """
    entries = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split(maxsplit=1)
            if not parts:
                continue
            entries[parts[0]] = parts[1].strip() if len(parts) > 1 else ""
    return entries
class CustomFinetuneModel(ESPnetS2TModel):
    """``ESPnetS2TModel`` rebuilt from an existing model, with periodic console logging.

    The constructor reuses the sub-modules of ``model`` (frontend, specaug,
    normalize, encoder, decoder, CTC, ...) and its CTC settings. Label
    smoothing is set to 0 and CER/WER reporting is disabled.

    Args:
        model: A built ``ESPnetS2TModel`` whose modules and settings are reused.
        log_every: Print mean training ``loss`` / ``acc`` every this many
            training forward passes; a value <= 0 disables the printout.
    """

    def __init__(self, model, log_every=500):
        super().__init__(
            vocab_size=model.vocab_size,
            token_list=model.token_list,
            frontend=model.frontend,
            specaug=model.specaug,
            normalize=model.normalize,
            preencoder=model.preencoder,
            encoder=model.encoder,
            postencoder=model.postencoder,
            decoder=model.decoder,
            ctc=model.ctc,
            ctc_weight=model.ctc_weight,
            interctc_weight=model.interctc_weight,
            ignore_id=model.ignore_id,
            lsm_weight=0.0,
            length_normalized_loss=False,
            report_cer=False,
            report_wer=False,
            sym_space="<space>",
            sym_blank="<blank>",
            sym_sos="<sos>",
            sym_eos="<eos>",
            sym_sop="<sop>",  # start of prev
            sym_na="<na>",  # not available
            extract_feats_in_collect_stats=model.extract_feats_in_collect_stats,
        )
        self.iter_count = 0
        self.log_every = log_every
        # Running sums since the last printout, and how many values each holds.
        self.log_stats = {
            'loss': 0.0,
            'acc': 0.0,
        }
        self._log_counts = {key: 0 for key in self.log_stats}

    def forward(self, *args, **kwargs):
        """Run the parent forward; in training mode also update the running log.

        Forward passes made in eval mode (``self.training`` is False, e.g.
        validation) are not counted, so the printout reflects training only.
        """
        out = super().forward(*args, **kwargs)
        if self.training:
            self._update_log_stats(out[1])
        return out

    def _update_log_stats(self, stats):
        """Accumulate ``loss`` / ``acc`` from a forward's stats and print periodically."""
        for key in self.log_stats:
            value = stats.get(key)
            # NOTE(review): a stat may be None (e.g. attention accuracy for a
            # CTC-only setup) — skip it rather than crash on ``.item()``.
            if value is not None:
                self.log_stats[key] += float(value)
                self._log_counts[key] += 1
        self.iter_count += 1
        if self.log_every <= 0 or self.iter_count % self.log_every != 0:
            return
        # Average over the values actually collected, not over log_every.
        means = {
            key: self.log_stats[key] / count if count else float("nan")
            for key, count in self._log_counts.items()
        }
        print(f"[{self.iter_count}] - loss: {means['loss']:.3f} - acc: {means['acc']:.3f}")
        for key in self.log_stats:
            self.log_stats[key] = 0.0
            self._log_counts[key] = 0
def finetune_model(lang, task, tempdir_path, log_every, max_epoch, scheduler, warmup_steps, optimizer, learning_rate, weight_decay):
    """Main function for finetuning the model.

    Steps:
      1. Decode the test split with the pre-trained model (``baseline_model``).
      2. Fine-tune with ``espnetez`` using the overrides below.
      3. Zip ``{tempdir_path}/exp`` into ``{tempdir_path}/finetune.zip``.
      4. Decode the test split again with ``exp/finetune/valid.acc.ave.pth``.

    Args:
        lang: Language token body, used as ``<{lang}>``.
        task: Task token body, used as ``<{task}>``.
        tempdir_path: Directory with the extracted dataset; outputs go here too.
        log_every: Logging interval passed to ``CustomFinetuneModel``.
        max_epoch, scheduler, warmup_steps, optimizer, learning_rate,
        weight_decay: Overrides written into the fine-tuning config.

    Returns:
        ``([finetune.zip, ref.txt, base.txt, hyp.txt] paths, hyp_text)`` where
        ``hyp_text`` is the content written to ``hyp.txt``.

    Raises:
        gr.Error: if ``tempdir_path`` is empty.
    """
    # Validate the upload before doing any work (the baseline pass reads it).
    if not tempdir_path:
        raise gr.Error("Please upload a zip file first.")

    asset_dir = "assets/owsm_ebf_v3.1_base"
    config_path = f"{asset_dir}/config.yaml"
    checkpoint_path = f"{asset_dir}/owsm_v3.1_base.trained.pth"
    bpe_model_path = f"{asset_dir}/bpe.model"
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    print("Start generating baseline...")
    gr.Info("Start generating baseline...")
    baseline_model(lang, task, tempdir_path)

    print("Start loading dataset...")
    gr.Info("Start Fine-tuning process...")
    # define tokenizer
    tokenizer = SentencepiecesTokenizer(bpe_model_path)
    converter = TokenIDConverter(f"{asset_dir}/tokens.txt")

    def tokenize(text):
        """Convert text to a numpy array of token ids."""
        return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))

    # How each item produced by get_dataset is turned into model inputs.
    data_info = {
        "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
        # Target prefixed with language, task and <notimestamps> tokens.
        "text": lambda d: tokenize(f"<{lang}><{task}><notimestamps> {d['text']}"),
        "text_ctc": lambda d: tokenize(d["text_ctc"]),
        # "<na>" = not available: no previous-context prompt.
        "text_prev": lambda d: tokenize("<na>"),
    }
    train_dataset, test_dataset, test_list = get_dataset(tempdir_path, data_info)
    print("Loaded dataset.")
    gr.Info("Loaded dataset.")

    # load and update configuration
    print("Setting up the training configuration...")
    pretrain_config = ez.config.from_yaml("s2t", config_path)
    finetune_config = ez.config.update_finetune_config(
        "s2t", pretrain_config, f"{asset_dir}/owsm_finetune_base.yaml"
    )
    finetune_config['max_epoch'] = max_epoch
    finetune_config['optim'] = optimizer
    finetune_config['optim_conf']['lr'] = learning_rate
    finetune_config['optim_conf']['weight_decay'] = weight_decay
    finetune_config['scheduler'] = scheduler
    finetune_config['scheduler_conf']['warmup_steps'] = warmup_steps
    finetune_config['multiple_iterator'] = False
    finetune_config['num_iters_per_epoch'] = None

    def build_model_fn(args):
        """Build the pre-trained model and wrap it for fine-tuning.

        ``args`` is part of the ez.Trainer callback signature and is unused.
        """
        model, _ = S2TTask.build_model_from_file(config_path, checkpoint_path, device=device)
        model.train()
        print(f'Trainable parameters: {count_parameters(model)}')
        return CustomFinetuneModel(model, log_every=log_every)

    trainer = ez.Trainer(
        task='s2t',
        train_config=finetune_config,
        train_dataset=train_dataset,
        valid_dataset=test_dataset,
        build_model_fn=build_model_fn,  # provide the pre-trained model
        data_info=data_info,
        output_dir=f"{tempdir_path}/exp/finetune",
        stats_dir=f"{tempdir_path}/exp/stats",
        # Match ``device``: ngpu=1 would request a GPU even on CPU-only hosts.
        ngpu=1 if use_cuda else 0,
    )
    gr.Info("start collect stats")
    print("Start collect stats process...")
    trainer.collect_stats()
    gr.Info("Finished collect stats, starting training.")
    print("Finished collect stats process. Start training.")
    trainer.train()
    gr.Info("Finished Fine-tuning! Archiving experiment files...")
    print("Finished fine-tuning.")

    print("Start archiving experiment files...")
    print("Create zip file for the following files into `finetune.zip`:")
    for f in glob.glob(f"{tempdir_path}/exp/finetune/*"):
        print(f.replace(tempdir_path, ""))
    shutil.make_archive(f"{tempdir_path}/finetune", 'zip', f"{tempdir_path}/exp")
    gr.Info("Finished generating result file in zip!")
    print("Finished archiving experiment files.")

    print("Start generating test result...")
    gr.Info("Start generating output for test set!")
    del trainer
    model = Speech2Text(
        config_path,
        checkpoint_path,
        device=device,
        token_type="bpe",
        bpemodel=bpe_model_path,
        beam_size=5,
        ctc_weight=0.3,
        lang_sym=f"<{lang}>",
        task_sym=f"<{task}>",
    )
    model.s2t_model.eval()
    # Replace the pre-trained weights with the fine-tuned averaged checkpoint;
    # map_location lets a checkpoint saved on GPU load on the current device.
    state_dict = torch.load(f"{tempdir_path}/exp/finetune/valid.acc.ave.pth", map_location=device)
    model.s2t_model.load_state_dict(state_dict)

    hyp_lines = []
    with open(f"{tempdir_path}/hyp.txt", "w", encoding="utf-8") as f_hyp:
        for item in test_list:
            speech = librosa.load(item['audio_path'], sr=16000)[0]
            # [0][3]: text field of the first (best) result — assumed from
            # ESPnet Speech2Text's result tuple layout; confirm if it changes.
            out = model(speech)[0][3]
            f_hyp.write(out + '\n')
            hyp_lines.append(out + '\n')
    hyp = "".join(hyp_lines)
    return [f"{tempdir_path}/finetune.zip", f"{tempdir_path}/ref.txt", f"{tempdir_path}/base.txt", f"{tempdir_path}/hyp.txt"], hyp
def baseline_model(lang, task, tempdir_path):
    """Decode the test split with the pre-trained (not fine-tuned) model.

    Writes reference transcripts to ``{tempdir_path}/ref.txt`` and the
    model's hypotheses to ``{tempdir_path}/base.txt``, one utterance per line,
    UTF-8 encoded.

    Args:
        lang: Language token body, used as ``<{lang}>``.
        task: Task token body, used as ``<{task}>``.
        tempdir_path: Directory with the extracted dataset; outputs go here too.

    Returns:
        ``(ref, base)``: the exact text written to ``ref.txt`` and ``base.txt``.

    Raises:
        gr.Error: if ``tempdir_path`` is empty (or None).
    """
    print("Start loading dataset...")
    if not tempdir_path:
        raise gr.Error("Please upload a zip file first.")

    asset_dir = "assets/owsm_ebf_v3.1_base"
    bpe_model_path = f"{asset_dir}/bpe.model"
    # define tokenizer
    tokenizer = SentencepiecesTokenizer(bpe_model_path)
    converter = TokenIDConverter(f"{asset_dir}/tokens.txt")

    def tokenize(text):
        """Convert text to a numpy array of token ids."""
        return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))

    data_info = {
        "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
        "text": lambda d: tokenize(f"<{lang}><{task}><notimestamps> {d['text']}"),
        "text_ctc": lambda d: tokenize(d["text_ctc"]),
        # "<na>" = not available: no previous-context prompt.
        "text_prev": lambda d: tokenize("<na>"),
    }
    # Only the raw test items are needed here; the dataset objects are unused.
    _, _, test_list = get_dataset(tempdir_path, data_info)
    print("Loaded dataset.")
    gr.Info("Loaded dataset.")

    print("Loading pretrained model...")
    gr.Info("Loading pretrained model...")
    model = Speech2Text(
        f"{asset_dir}/config.yaml",
        f"{asset_dir}/owsm_v3.1_base.trained.pth",
        device="cuda" if torch.cuda.is_available() else "cpu",
        token_type="bpe",
        bpemodel=bpe_model_path,
        beam_size=5,
        ctc_weight=0.3,
        lang_sym=f"<{lang}>",
        task_sym=f"<{task}>",
    )
    model.s2t_model.eval()

    ref_lines = []
    base_lines = []
    # UTF-8 explicitly: transcripts may be in any language.
    with open(f"{tempdir_path}/base.txt", "w", encoding="utf-8") as f_base, \
            open(f"{tempdir_path}/ref.txt", "w", encoding="utf-8") as f_ref:
        for item in test_list:
            f_ref.write(item['text'] + '\n')
            # [0][3]: text field of the first (best) result — assumed from
            # ESPnet Speech2Text's result tuple layout; confirm if it changes.
            out = model(librosa.load(item['audio_path'], sr=16000)[0])[0][3]
            f_base.write(out + '\n')
            ref_lines.append(item['text'] + '\n')
            base_lines.append(out + '\n')
    return "".join(ref_lines), "".join(base_lines)