Spaces:
Sleeping
Sleeping
import glob | |
import sys | |
from pathlib import Path | |
import shutil | |
import os | |
from espnet2.tasks.s2t import S2TTask | |
from espnet2.text.sentencepiece_tokenizer import SentencepiecesTokenizer | |
from espnet2.text.token_id_converter import TokenIDConverter | |
from espnet2.s2t.espnet_model import ESPnetS2TModel | |
from espnet2.bin.s2t_inference import Speech2Text | |
import espnetez as ez | |
import torch | |
import numpy as np | |
import logging | |
import gradio as gr | |
import librosa | |
class Logger:
    """Tee for ``sys.stdout``: every write goes to the terminal and to a log file.

    The log file is opened in write mode, so an existing file with the same
    name is truncated when the Logger is created.
    """

    def __init__(self, filename):
        # The stream that was active when the Logger was created (normally the real stdout).
        self.terminal = sys.stdout
        # Explicit UTF-8 so non-ASCII output (e.g. dataset file names or
        # transcripts) cannot raise UnicodeEncodeError on platforms whose
        # default encoding is not UTF-8.
        self.log = open(filename, "w", encoding="utf-8")

    def write(self, message):
        """Write *message* to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        # Always report "not a terminal"; presumably so libraries that check
        # isatty() avoid terminal-only output (e.g. control codes) in the log.
        return False


# Redirect all subsequent print() output through the tee into output.log.
sys.stdout = Logger("output.log")
def count_parameters(model):
    """Return how many scalar values live in the trainable parameters of *model*.

    Parameters with ``requires_grad`` set to False (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
def _read_id_text_file(path):
    """Parse a ``<utt_id> <text>``-per-line file into an ordered ``{utt_id: text}`` dict.

    Blank lines are skipped, and a line that holds only an id maps to ``""``,
    so a trailing newline or an empty transcript does not abort loading.
    A repeated id keeps its first position but takes the later text.
    """
    entries = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            fields = line.split(maxsplit=1)
            if not fields:
                continue
            entries[fields[0]] = fields[1].strip() if len(fields) > 1 else ""
    return entries


def get_dataset(data_path, data_info, test_count=10):
    """Load the uploaded dataset and split it into train and test portions.

    Expected layout under *data_path*:
      - ``text``: ``<utt_id> <transcript>`` per line; defines utterance order.
      - ``text_ctc``: ``<utt_id> <ctc transcript>`` per line.
      - ``audio/``: audio files whose file stem is the utterance id.

    The first *test_count* utterances (in ``text`` order) form the test split;
    the rest form the training split. Audio files and ``text_ctc`` lines whose
    id does not appear in ``text`` are ignored.

    Args:
        data_path: directory containing the files above.
        data_info: mapping passed unchanged to ``ez.dataset.ESPnetEZDataset``.
        test_count: number of leading utterances held out for testing.

    Returns:
        (train dataset, test dataset, list of test item dicts with keys
        ``'id'``, ``'text'``, ``'text_ctc'``, ``'audio_path'``).

    Raises:
        gr.Error: if an utterance listed in ``text`` has no ``text_ctc`` entry
            or no audio file.
    """
    texts = _read_id_text_file(f"{data_path}/text")
    ctc_texts = _read_id_text_file(f"{data_path}/text_ctc")
    # Index audio files by stem (file name without the last extension).
    audio_paths = {Path(p).stem: p for p in glob.glob(f"{data_path}/audio/*")}

    data = []
    for audio_id, text in texts.items():
        # Report malformed uploads to the user instead of failing with a bare KeyError.
        if audio_id not in ctc_texts:
            raise gr.Error(f"Utterance '{audio_id}' is listed in text but missing from text_ctc.")
        if audio_id not in audio_paths:
            raise gr.Error(f"No audio file found for utterance '{audio_id}'.")
        data.append({
            'id': audio_id,
            'text': text,
            'text_ctc': ctc_texts[audio_id],
            'audio_path': audio_paths[audio_id],
        })

    return (
        ez.dataset.ESPnetEZDataset(data[test_count:], data_info),
        ez.dataset.ESPnetEZDataset(data[:test_count], data_info),
        data[:test_count],
    )
class CustomFinetuneModel(ESPnetS2TModel):
    """ESPnetS2TModel rebuilt from a pre-trained model's modules, with periodic
    console logging of the running training loss and accuracy.

    The frontend, specaug, normalize, pre/post-encoder, encoder, decoder, CTC
    module, CTC weights, ignore_id and token list are taken from *model*.
    ``lsm_weight`` is set to 0.0, ``length_normalized_loss`` to False, and
    CER/WER reporting is turned off.

    Args:
        model: pre-trained ESPnetS2TModel whose modules and settings are reused.
        log_every: print the mean loss/acc every this many training
            iterations; a value <= 0 disables the periodic print.
    """

    def __init__(self, model, log_every=500):
        super().__init__(
            vocab_size=model.vocab_size,
            token_list=model.token_list,
            frontend=model.frontend,
            specaug=model.specaug,
            normalize=model.normalize,
            preencoder=model.preencoder,
            encoder=model.encoder,
            postencoder=model.postencoder,
            decoder=model.decoder,
            ctc=model.ctc,
            ctc_weight=model.ctc_weight,
            interctc_weight=model.interctc_weight,
            ignore_id=model.ignore_id,
            lsm_weight=0.0,
            length_normalized_loss=False,
            report_cer=False,
            report_wer=False,
            sym_space="<space>",
            sym_blank="<blank>",
            sym_sos="<sos>",
            sym_eos="<eos>",
            sym_sop="<sop>",  # start of prev
            sym_na="<na>",  # not available
            extract_feats_in_collect_stats=model.extract_feats_in_collect_stats,
        )
        # Number of training-mode forward calls seen so far.
        self.iter_count = 0
        self.log_every = log_every
        # Sums over the current logging window; reset after each print.
        self.log_stats = {
            'loss': 0.0,
            'acc': 0.0,
        }

    def forward(self, *args, **kwargs):
        """Run the parent forward; in training mode also accumulate and
        periodically print the mean loss/acc. Returns the parent's output unchanged.
        """
        out = super().forward(*args, **kwargs)
        # Only training batches count: validation forwards run in eval mode and
        # would otherwise be mixed into the training averages and iteration count.
        if not self.training:
            return out

        stats = out[1]
        self.log_stats['loss'] += stats['loss'].item()
        acc = stats.get('acc')
        # NOTE(review): acc can presumably be None when the attention branch is
        # skipped (e.g. ctc_weight == 1.0 in ESPnet) — confirm; it then adds nothing.
        if acc is not None:
            self.log_stats['acc'] += acc.item()
        self.iter_count += 1

        # Guard log_every <= 0, which would otherwise divide by zero.
        if self.log_every > 0 and self.iter_count % self.log_every == 0:
            mean_loss = self.log_stats['loss'] / self.log_every
            mean_acc = self.log_stats['acc'] / self.log_every
            print(f"[{self.iter_count}] - loss: {mean_loss:.3f} - acc: {mean_acc:.3f}")
            self.log_stats['loss'] = 0.0
            self.log_stats['acc'] = 0.0
        return out
def finetune_model(lang, task, tempdir_path, log_every, max_epoch, scheduler, warmup_steps, optimizer, learning_rate, weight_decay):
    """Main function for finetuning the model.

    Steps:
      1. Run ``baseline_model``, which writes ``ref.txt`` and ``base.txt``.
      2. Fine-tune the pre-trained OWSM v3.1 base model with espnetez.
      3. Zip ``<tempdir_path>/exp`` into ``<tempdir_path>/finetune.zip``.
      4. Decode the test split with the fine-tuned weights into ``hyp.txt``.

    Args:
        lang: language token without brackets (used as ``<lang>``).
        task: task token without brackets (used as ``<task>``).
        tempdir_path: directory holding the extracted dataset; outputs are
            written here as well.
        log_every: print running loss/acc every this many training iterations.
        max_epoch, scheduler, warmup_steps, optimizer, learning_rate,
        weight_decay: override the matching fields of the fine-tuning config.

    Returns:
        ([finetune.zip, ref.txt, base.txt, hyp.txt] paths, hypothesis text
        with one line per test utterance).

    Raises:
        gr.Error: if no dataset directory was provided.
    """
    # Validate before the (slow) baseline decoding pass.
    if not tempdir_path:
        raise gr.Error("Please upload a zip file first.")

    asset_dir = "assets/owsm_ebf_v3.1_base"
    config_path = f"{asset_dir}/config.yaml"
    model_path = f"{asset_dir}/owsm_v3.1_base.trained.pth"
    bpe_path = f"{asset_dir}/bpe.model"
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    print("Start generating baseline...")
    gr.Info("Start generating baseline...")
    baseline_model(lang, task, tempdir_path)
    print("Start loading dataset...")
    gr.Info("Start Fine-tuning process...")

    # define tokenizer
    tokenizer = SentencepiecesTokenizer(bpe_path)
    converter = TokenIDConverter(f"{asset_dir}/tokens.txt")

    def tokenize(text):
        return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))

    data_info = {
        "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
        "text": lambda d: tokenize(f"<{lang}><{task}><notimestamps> {d['text']}"),
        "text_ctc": lambda d: tokenize(d["text_ctc"]),
        "text_prev": lambda d: tokenize("<na>"),
    }

    # load dataset
    train_dataset, test_dataset, test_list = get_dataset(tempdir_path, data_info)
    print("Loaded dataset.")
    gr.Info("Loaded dataset.")

    # load and update configuration
    print("Setting up the training configuration...")
    pretrain_config = ez.config.from_yaml("s2t", config_path)
    finetune_config = ez.config.update_finetune_config(
        "s2t", pretrain_config, f"{asset_dir}/owsm_finetune_base.yaml"
    )
    finetune_config['max_epoch'] = max_epoch
    finetune_config['optim'] = optimizer
    finetune_config['optim_conf']['lr'] = learning_rate
    finetune_config['optim_conf']['weight_decay'] = weight_decay
    finetune_config['scheduler'] = scheduler
    finetune_config['scheduler_conf']['warmup_steps'] = warmup_steps
    finetune_config['multiple_iterator'] = False
    finetune_config['num_iters_per_epoch'] = None

    def build_model_fn(args):
        model, _ = S2TTask.build_model_from_file(config_path, model_path, device=device)
        model.train()
        print(f'Trainable parameters: {count_parameters(model)}')
        return CustomFinetuneModel(model, log_every=log_every)

    trainer = ez.Trainer(
        task='s2t',
        train_config=finetune_config,
        train_dataset=train_dataset,
        valid_dataset=test_dataset,
        build_model_fn=build_model_fn,  # provide the pre-trained model
        data_info=data_info,
        output_dir=f"{tempdir_path}/exp/finetune",
        stats_dir=f"{tempdir_path}/exp/stats",
        # Follow the same CPU fallback as the model loads; a hard-coded
        # ngpu=1 fails on hosts without CUDA.
        ngpu=1 if use_cuda else 0,
    )
    gr.Info("start collect stats")
    print("Start collect stats process...")
    trainer.collect_stats()
    gr.Info("Finished collect stats, starting training.")
    print("Finished collect stats process. Start training.")
    trainer.train()
    gr.Info("Finished Fine-tuning! Archiving experiment files...")
    print("Finished fine-tuning.")

    print("Start archiving experiment files...")
    print("Create zip file for the following files into `finetune.zip`:")
    for path in glob.glob(f"{tempdir_path}/exp/finetune/*"):
        print(path.replace(tempdir_path, ""))
    shutil.make_archive(f"{tempdir_path}/finetune", 'zip', f"{tempdir_path}/exp")
    gr.Info("Finished generating result file in zip!")
    print("Finished archiving experiment files.")

    print("Start generating test result...")
    gr.Info("Start generating output for test set!")
    del trainer
    model = Speech2Text(
        config_path,
        model_path,
        device=device,
        token_type="bpe",
        bpemodel=bpe_path,
        beam_size=5,
        ctc_weight=0.3,
        lang_sym=f"<{lang}>",
        task_sym=f"<{task}>",
    )
    model.s2t_model.eval()
    # map_location lets the checkpoint load regardless of the device it was saved from.
    state_dict = torch.load(f"{tempdir_path}/exp/finetune/valid.acc.ave.pth", map_location=device)
    model.s2t_model.load_state_dict(state_dict)

    hyp_lines = []
    with open(f"{tempdir_path}/hyp.txt", "w", encoding="utf-8") as f_hyp:
        for item in test_list:
            # [0][3]: field 3 of the first decoding result — presumably the
            # hypothesis text; confirm against espnet2.bin.s2t_inference.
            out = model(librosa.load(item['audio_path'], sr=16000)[0])[0][3]
            f_hyp.write(out + '\n')
            hyp_lines.append(out + '\n')
    hyp = "".join(hyp_lines)
    return [f"{tempdir_path}/finetune.zip", f"{tempdir_path}/ref.txt", f"{tempdir_path}/base.txt", f"{tempdir_path}/hyp.txt"], hyp
def baseline_model(lang, task, tempdir_path):
    """Decode the test split with the pre-trained (not fine-tuned) OWSM model.

    Writes the reference transcripts to ``<tempdir_path>/ref.txt`` and the
    pre-trained model's outputs to ``<tempdir_path>/base.txt``, one line per
    test utterance, both UTF-8 encoded.

    Args:
        lang: language token without brackets (used as ``<lang>``).
        task: task token without brackets (used as ``<task>``).
        tempdir_path: directory holding the extracted dataset.

    Returns:
        (ref, base): the same text written to ref.txt and base.txt.

    Raises:
        gr.Error: if no dataset directory was provided.
    """
    print("Start loading dataset...")
    if not tempdir_path:
        raise gr.Error("Please upload a zip file first.")

    asset_dir = "assets/owsm_ebf_v3.1_base"
    bpe_path = f"{asset_dir}/bpe.model"

    # define tokenizer
    tokenizer = SentencepiecesTokenizer(bpe_path)
    converter = TokenIDConverter(f"{asset_dir}/tokens.txt")

    def tokenize(text):
        return np.array(converter.tokens2ids(tokenizer.text2tokens(text)))

    data_info = {
        "speech": lambda d: librosa.load(d["audio_path"], sr=16000)[0],
        "text": lambda d: tokenize(f"<{lang}><{task}><notimestamps> {d['text']}"),
        "text_ctc": lambda d: tokenize(d["text_ctc"]),
        "text_prev": lambda d: tokenize("<na>"),
    }
    # Only the raw test items are used here; the dataset objects are discarded.
    _, _, test_list = get_dataset(tempdir_path, data_info)
    print("Loaded dataset.")
    gr.Info("Loaded dataset.")

    print("Loading pretrained model...")
    gr.Info("Loading pretrained model...")
    model = Speech2Text(
        f"{asset_dir}/config.yaml",
        f"{asset_dir}/owsm_v3.1_base.trained.pth",
        device="cuda" if torch.cuda.is_available() else "cpu",
        token_type="bpe",
        bpemodel=bpe_path,
        beam_size=5,
        ctc_weight=0.3,
        lang_sym=f"<{lang}>",
        task_sym=f"<{task}>",
    )
    model.s2t_model.eval()

    ref_lines = []
    base_lines = []
    # UTF-8 so multilingual transcripts can be written on any platform.
    with open(f"{tempdir_path}/base.txt", "w", encoding="utf-8") as f_base, \
            open(f"{tempdir_path}/ref.txt", "w", encoding="utf-8") as f_ref:
        for item in test_list:
            f_ref.write(item['text'] + '\n')
            # [0][3]: field 3 of the first decoding result — presumably the
            # hypothesis text; confirm against espnet2.bin.s2t_inference.
            out = model(librosa.load(item['audio_path'], sr=16000)[0])[0][3]
            f_base.write(out + '\n')
            ref_lines.append(item['text'] + '\n')
            base_lines.append(out + '\n')
    return "".join(ref_lines), "".join(base_lines)