# jsnfly: "increase beam size" (commit 9024e1e)
import argparse
import re
from typing import Dict
import torch
from datasets import Audio, Dataset, load_dataset, load_metric
from transformers import AutoFeatureExtractor, AutoTokenizer, SpeechEncoderDecoderModel, pipeline
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.models.encoder_decoder.modeling_encoder_decoder import shift_tokens_right
from transformers.modeling_outputs import Seq2SeqLMOutput
def log_results(result: Dataset, args: argparse.Namespace):
    """DO NOT CHANGE. This function computes and logs the result metrics.

    Computes WER and CER of ``result["prediction"]`` against ``result["target"]``,
    prints them and writes them to ``<dataset_id>_eval_results.txt`` in the current
    working directory. If ``args.log_outputs`` is not None, every row's index,
    prediction and target are also written to ``log_<dataset_id>_predictions.txt``
    and ``log_<dataset_id>_targets.txt``.

    Args:
        result: dataset with string columns ``"prediction"`` and ``"target"``.
        args: parsed command-line arguments; reads ``log_outputs``, ``dataset``,
            ``config`` and ``split``.
    """
    log_outputs = args.log_outputs
    # args.dataset split on "/" plus config and split, joined with "_"; used as a file-name prefix
    dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])

    # load metric
    wer = load_metric("wer")
    cer = load_metric("cer")

    # compute metrics
    wer_result = wer.compute(references=result["target"], predictions=result["prediction"])
    cer_result = cer.compute(references=result["target"], predictions=result["prediction"])

    # print & log results
    result_str = f"WER: {wer_result}\n" f"CER: {cer_result}"
    print(result_str)

    with open(f"{dataset_id}_eval_results.txt", "w") as f:
        f.write(result_str)

    # log all results in text file. Possibly interesting for analysis
    # NOTE(review): this checks `is not None`, so any non-None value -- including False --
    # enables per-example logging; only None disables it.
    if log_outputs is not None:
        pred_file = f"log_{dataset_id}_predictions.txt"
        target_file = f"log_{dataset_id}_targets.txt"

        with open(pred_file, "w") as p, open(target_file, "w") as t:

            # mapping function to write output; called once per row (with its index)
            # purely for the side effect of writing to the two open files
            def write_to_file(batch, i):
                p.write(f"{i}" + "\n")
                p.write(batch["prediction"] + "\n")
                t.write(f"{i}" + "\n")
                t.write(batch["target"] + "\n")

            result.map(write_to_file, with_indices=True)
# Characters removed from both references and predictions by `normalize_text`.
# From https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-german.
CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
                   "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
                   "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
                   "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
                   "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"]

# Character class matching any single character from CHARS_TO_IGNORE. Built once at import
# time because normalize_text runs twice per evaluated example (prediction and target).
_CHARS_TO_IGNORE_REGEX = re.compile(f"[{re.escape(''.join(CHARS_TO_IGNORE))}]")


def normalize_text(text: str) -> str:
    """DO ADAPT FOR YOUR USE CASE. this function normalizes the target text.

    Lower-cases ``text`` and deletes every character listed in ``CHARS_TO_IGNORE``.
    Whitespace is left untouched.

    Args:
        text: raw transcription (reference or model prediction).

    Returns:
        The normalized string.
    """
    return _CHARS_TO_IGNORE_REGEX.sub("", text.lower())
def main(args):
    """Transcribe a dataset split with the model and log WER/CER.

    Loads ``args.dataset`` (config ``args.config``, split ``args.split``), resamples its
    ``"audio"`` column to the feature extractor's sampling rate, transcribes every example
    with an ``automatic-speech-recognition`` pipeline built around ``Wav2VecGPT2Model``,
    normalizes predictions and the ``"sentence"`` references with ``normalize_text`` and
    passes the result to ``log_results``.

    Args:
        args: parsed command-line arguments. Reads ``dataset``, ``config``, ``split``,
            ``model_id``, ``device``, ``chunk_length_s``, ``stride_length_s``, the fields
            used by ``log_results`` and, if present and not None, ``num_beams`` (beam width
            for generation; defaults to 4). ``args.device`` is filled in when it is None.
    """
    # load dataset
    dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)

    # # for testing: only process the first ten examples
    # dataset = dataset.select(range(10))

    # load processor
    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
    sampling_rate = feature_extractor.sampling_rate

    # resample audio
    dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)

    # load model; the beam width used by generation is stored on the model config and can be
    # overridden through `args.num_beams` (callers that don't set it keep the previous value 4)
    model = Wav2VecGPT2Model.from_pretrained(args.model_id)
    num_beams = getattr(args, "num_beams", None)
    model.config.num_beams = 4 if num_beams is None else num_beams

    # load eval pipeline: first GPU if CUDA is available, otherwise CPU (-1)
    if args.device is None:
        args.device = 0 if torch.cuda.is_available() else -1
    asr = pipeline("automatic-speech-recognition", model=model, device=args.device,
                   feature_extractor=feature_extractor, tokenizer=tokenizer)

    # map function to decode audio
    def map_to_pred(batch):
        prediction = asr(
            batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s
        )

        batch["prediction"] = normalize_text(prediction["text"])
        batch["target"] = normalize_text(batch["sentence"])
        return batch

    # run inference on all examples; only the new "prediction"/"target" columns are kept
    result = dataset.map(map_to_pred, remove_columns=dataset.column_names)

    # compute and log_results
    # do not change function below
    log_results(result, args)
class Wav2VecGPT2Model(SpeechEncoderDecoderModel):
    """
    Basically the same as `SpeechEncoderDecoderModel` but position embeddings (initialized with GPT2's position
    embeddings) are added to encoder output

    After the optional ``enc_to_dec_proj`` projection, learned position embeddings are added to the
    encoder hidden states, and the sum is layer-normalized before the decoder cross-attends to it.
    The decoder is accessed through GPT-2 specific attributes (``transformer.wpe``,
    ``config.layer_norm_epsilon``), so it is presumably required to be a GPT-2 style model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One learned vector per encoder frame position. The table has 1024 rows, so encoder
        # outputs longer than 1024 frames fail the embedding lookup in `forward`.
        self.encoder_outputs_pos_emb = nn.Embedding(1024, self.decoder.config.hidden_size)
        with torch.no_grad():
            # initialize from the decoder's own position embeddings (shapes must be compatible)
            self.encoder_outputs_pos_emb.weight.copy_(self.decoder.transformer.wpe.weight)
        self.enc_to_dec_proj_ln = nn.LayerNorm(self.decoder.config.hidden_size,
                                               eps=self.decoder.config.layer_norm_epsilon)

    def __getattribute__(self, name):
        # Fake class so it is recognized as seq2seq model.
        # Only `self.__class__` is redirected to the parent class; every other lookup uses the
        # normal mechanism, and `type(self)` still returns Wav2VecGPT2Model.
        if name == '__class__':
            return SpeechEncoderDecoderModel
        return SpeechEncoderDecoderModel.__getattribute__(self, name)

    def forward(
        self,
        inputs=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        input_values=None,
        input_features=None,
        return_dict=None,
        **kwargs,
    ):
        """Run the encoder (unless ``encoder_outputs`` is given) and the decoder, optionally with a loss.

        Extra keyword arguments prefixed with ``decoder_`` are forwarded (prefix stripped) to the
        decoder; all others go to the encoder.

        Returns:
            ``Seq2SeqLMOutput`` when ``return_dict`` is truthy; otherwise the tuple
            ``(loss,)`` (only when ``labels`` is given) + decoder outputs + encoder outputs.

        Raises:
            ValueError: if no encoder input is available (neither ``inputs``, ``encoder_outputs``,
                ``input_values`` nor ``input_features``), or both ``input_values`` and
                ``input_features`` are given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None and inputs is None:
            if input_values is not None and input_features is not None:
                raise ValueError("You cannot specify both input_values and input_features at the same time")
            elif input_values is not None:
                inputs = input_values
            elif input_features is not None:
                inputs = input_features
            else:
                raise ValueError("You have to specify either input_values or input_features")

            encoder_outputs = self.encoder(
                inputs,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )

        encoder_hidden_states = encoder_outputs[0]

        # optionally project encoder_hidden_states
        if (
            self.encoder_output_dim != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # Add position embeddings OUT OF PLACE. Without the projection above,
        # `encoder_hidden_states` is the very tensor stored in `encoder_outputs`; an in-place `+=`
        # would modify the caller's encoder outputs. When those outputs are passed back in on
        # later calls (e.g. step-by-step decoding), the embeddings would be added again each time.
        positions = torch.arange(0, encoder_hidden_states.shape[1], device=encoder_hidden_states.device)
        encoder_hidden_states = encoder_hidden_states + self.encoder_outputs_pos_emb(positions)
        encoder_hidden_states = self.enc_to_dec_proj_ln(encoder_hidden_states)

        # compute correct encoder attention mask (input-sample mask -> encoder-frame mask)
        if attention_mask is not None:
            encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
                encoder_hidden_states.shape[1], attention_mask
            )
        else:
            encoder_attention_mask = None

        # teacher forcing: derive decoder inputs from the labels when none were supplied
        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        # Label positions equal to -100 (CrossEntropyLoss's default ignore_index) are ignored.
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
if __name__ == "__main__":
    # Command-line entry point: parse evaluation options and run `main`.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id", type=str, required=True, help="Model identifier. Should be loadable with 🤗 Transformers"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
    )
    parser.add_argument(
        "--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice"
    )
    parser.add_argument("--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`")
    parser.add_argument(
        "--chunk_length_s",
        type=float,
        default=None,
        help="Chunk length in seconds. If not set, it is passed to the pipeline as None (no chunking).",
    )
    parser.add_argument(
        "--stride_length_s",
        type=float,
        default=None,
        help="Stride of the audio chunks in seconds. Only relevant together with --chunk_length_s; "
        "if not set, the pipeline's own default is used.",
    )
    # default=None instead of store_true's implicit False: log_results enables logging whenever
    # `args.log_outputs is not None`, so False would still turn logging on.
    parser.add_argument(
        "--log_outputs",
        action="store_true",
        default=None,
        help="If defined, write outputs to log file for analysis.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=None,
        help="The device to run the pipeline on. -1 for CPU, 0 for the first GPU and so on. "
        "Defaults to the first GPU if CUDA is available, otherwise CPU.",
    )
    args = parser.parse_args()

    main(args)