import random

import numpy as np
from datasets import load_metric
from transformers import logging

# Load the word error rate (WER) metric from the local metric script.
wer_metric = load_metric("./model-bin/metrics/wer")
# print(wer_metric)


def compute_metrics_fn(processor):
    """Build a compute_metrics callable that decodes predictions with `processor`."""

    def compute(pred):
        # Greedy decoding: take the highest-scoring token id at each time step.
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        # Replace the -100 values used to mask label positions with the
        # tokenizer's pad token id so the labels can be decoded.
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

        pred_str = processor.batch_decode(pred_ids)
        # We do not want to group tokens when computing the metrics.
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        # Log one random reference/prediction pair as a quick sanity check.
        random_idx = random.randint(0, len(label_str) - 1)
        logging.get_logger().info(
            '\n\n\nRandom sample predict:\nTruth: {}\nPredict: {}'.format(
                label_str[random_idx], pred_str[random_idx]))

        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer}

    return compute
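

# Minimal usage sketch: `compute_metrics_fn(processor)` returns a closure that a
# Hugging Face `Trainer` can call on each evaluation. The names `model`,
# `training_args`, `train_dataset`, and `eval_dataset` below are placeholders
# assumed to be defined in the surrounding training script.
#
#   from transformers import Trainer
#
#   trainer = Trainer(
#       model=model,
#       args=training_args,
#       compute_metrics=compute_metrics_fn(processor),
#       train_dataset=train_dataset,
#       eval_dataset=eval_dataset,
#   )
#   trainer.train()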