
English ASR sequence-to-sequence model. The model outputs normalized text, labels segment timestamps, and separates speech from multiple speakers.

```python
# !pip install transformers sentencepiece

from transformers import SpeechEncoderDecoderModel
from transformers import AutoFeatureExtractor, AutoTokenizer, GenerationConfig
import torchaudio
import torch

# Load the pretrained encoder-decoder model, feature extractor, and tokenizer
model_path = 'nguyenvulebinh/wav2vec2-bartpho'
model = SpeechEncoderDecoderModel.from_pretrained(model_path).eval()
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
if torch.cuda.is_available():
    model = model.cuda()


def decode_tokens(token_ids, skip_special_tokens=True, time_precision=0.02):
    # Ids >= vocab_size are timestamp tokens; everything below is text
    timestamp_begin = tokenizer.vocab_size
    outputs = [[]]
    for token in token_ids:
        if token >= timestamp_begin:
            # Convert the timestamp token offset into seconds
            timestamp = f" |{(token - timestamp_begin) * time_precision:.2f}| "
            outputs.append(timestamp)
            outputs.append([])
        else:
            outputs[-1].append(token)
    # Decode the text chunks and interleave them with the timestamp strings
    outputs = [
        s if isinstance(s, str) else tokenizer.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
    ]
    return "".join(outputs).replace("< |", "<|").replace("| >", "|>")

def decode_wav(audio_wavs, asr_model, prefix=""):
    device = next(asr_model.parameters()).device
    # Pad the batch of raw waveforms and build the attention mask
    input_values = feature_extractor.pad(
        [{"input_values": feature} for feature in audio_wavs],
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors="pt",
    )

    # Beam-search decoding; the trailing EOS token of the encoded prefix is dropped
    output_beam_ids = asr_model.generate(
        input_values['input_values'].to(device),
        attention_mask=input_values['attention_mask'].to(device),
        decoder_input_ids=tokenizer.batch_encode_plus(
            [prefix] * len(audio_wavs), return_tensors="pt"
        )['input_ids'][..., :-1].to(device),
        generation_config=GenerationConfig(decoder_start_token_id=tokenizer.bos_token_id),
        max_length=250,
        num_beams=25,
        no_repeat_ngram_size=4,
        num_return_sequences=1,
        early_stopping=True,
        return_dict_in_generate=True,
        output_scores=True,
    )

    # Convert each generated sequence back into timestamped text
    output_text = [decode_tokens(sequence) for sequence in output_beam_ids.sequences]

    return output_text
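
# Note on the decoding settings: num_beams=25 favors accuracy over speed;
# a smaller beam (e.g. num_beams=4) decodes much faster with some quality loss.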


# Sample audio: https://huggingface.co/nguyenvulebinh/wavlm-bart/resolve/main/sample.wav
print(decode_wav([torchaudio.load('sample.wav')[0].squeeze()], model))

# <|0.06| What are the many parts that make a machine learning system feel like it works so magically cheap? |5.86|>
# <|5.68| Explainability factors important, so they tend to gear towards more simpler models with less parameters, but easier to explain, and on the other spectrum there are |15.86|>
```
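
The feature extractor expects raw waveforms at a fixed sample rate (16 kHz is standard for wav2vec2-style encoders; treat that as an assumption and check `feature_extractor.sampling_rate`). A minimal sketch, with a hypothetical input file, for loading and resampling arbitrary audio before decoding:

```python
import torchaudio

# Hypothetical input file; replace with your own audio path
wav, sr = torchaudio.load('my_audio.wav')
wav = wav.mean(dim=0)  # downmix stereo to mono

# Resample to the rate the feature extractor was trained with (typically 16000)
target_sr = feature_extractor.sampling_rate
if sr != target_sr:
    wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=target_sr)

print(decode_wav([wav], model))
```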

## Citation

This repository builds on ideas from the following paper. Please cite the paper if this model helps produce published results or is incorporated into other software.

```bibtex
@INPROCEEDINGS{10446589,
  author={Nguyen, Thai-Binh and Waibel, Alexander},
  booktitle={ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  title={Synthetic Conversations Improve Multi-Talker ASR},
  year={2024},
  pages={10461-10465},
  keywords={Systematics;Error analysis;Knowledge based systems;Oral communication;Signal processing;Data models;Acoustics;multi-talker;asr;synthetic conversation},
  doi={10.1109/ICASSP48485.2024.10446589}
}
```

## Contact

nguyenvulebinh@gmail.com
