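"""Evaluate a fine-tuned (merged) Whisper model on a test set and report CER or WER."""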
import argparse
import functools
import gc
import os

import evaluate
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding, remove_punctuation, to_simple
from utils.reader import CustomDataset
from utils.utils import print_arguments, add_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("test_data", type=str, default="dataset/test.json", help="Path to the test set")
add_arg("model_path", type=str, default="models/whisper-tiny-finetune", help="Path to the merged model, or a model name on Hugging Face")
add_arg("batch_size", type=int, default=16, help="Batch size for evaluation")
add_arg("num_workers", type=int, default=8, help="Number of worker threads for data loading")
add_arg("language", type=str, default="Chinese", help="Language, full name or abbreviation; if None, evaluation is multilingual")
add_arg("remove_pun", type=bool, default=True, help="Whether to remove punctuation")
add_arg("to_simple", type=bool, default=True, help="Whether to convert to Simplified Chinese")
add_arg("timestamps", type=bool, default=False, help="Whether to use timestamp data during evaluation")
add_arg("min_audio_len", type=float, default=0.5, help="Minimum audio length, in seconds")
add_arg("max_audio_len", type=float, default=30, help="Maximum audio length, in seconds")
add_arg("local_files_only", type=bool, default=True, help="Whether to load the model from local files only, without trying to download")
add_arg("task", type=str, default="transcribe", choices=['transcribe', 'translate'], help="Task for the model")
add_arg("metric", type=str, default="cer", choices=['cer', 'wer'], help="Evaluation metric")
args = parser.parse_args()
print_arguments(args)
# Check that the model path is valid
assert 'openai' == os.path.dirname(args.model_path) or os.path.exists(args.model_path), \
    f"Model path {args.model_path} does not exist. Check that the model was merged successfully, or that it is a model name that exists on Hugging Face."
# Get the Whisper processor, which bundles the feature extractor and the tokenizer
processor = WhisperProcessor.from_pretrained(args.model_path,
                                             language=args.language,
                                             task=args.task,
                                             no_timestamps=not args.timestamps,
                                             local_files_only=args.local_files_only)
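# forced_decoder_ids forces the language/task prompt tokens at the start of generation,
# matching the language and task the processor was configured with above.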
forced_decoder_ids = processor.get_decoder_prompt_ids()
# Load the model
model = WhisperForConditionalGeneration.from_pretrained(args.model_path,
                                                        device_map="auto",
                                                        local_files_only=args.local_files_only)
model.eval()
# Load the test data
test_dataset = CustomDataset(data_list_path=args.test_data,
                             processor=processor,
                             timestamps=args.timestamps,
                             min_duration=args.min_audio_len,
                             max_duration=args.max_audio_len)
print(f"Test set size: {len(test_dataset)}")
# Data collator that pads each batch
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
eval_dataloader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=args.num_workers, collate_fn=data_collator)
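# The collator pads the audio features and label ids within each batch; padded label
# positions are marked with -100 and restored to pad_token_id before decoding below.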
# Load the evaluation metric
metric = evaluate.load(args.metric)
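# Note: the "cer" and "wer" metrics from the evaluate library typically require the
# jiwer package to be installed.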
# Start evaluation
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.no_grad():
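            # The first 4 label tokens are the decoder prompt (typically
            # <|startoftranscript|>, language, task and <|notimestamps|> when timestamps
            # are disabled), so generation starts from the same prefix used in training;
            # at most 255 new tokens are generated per sample.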
            generated_tokens = (
                model.generate(
                    input_features=batch["input_features"].cuda(),
                    decoder_input_ids=batch["labels"][:, :4].cuda(),
                    forced_decoder_ids=forced_decoder_ids,
                    max_new_tokens=255).cpu().numpy())
            labels = batch["labels"].cpu().numpy()
            # Replace the -100 loss-masking value with the pad token so the labels can be decoded
            labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)
            # Convert the predicted and reference tokens to text
            decoded_preds = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
            # Remove punctuation
            if args.remove_pun:
                decoded_preds = remove_punctuation(decoded_preds)
                decoded_labels = remove_punctuation(decoded_labels)
            # Convert Traditional Chinese to Simplified Chinese
            if args.to_simple:
                decoded_preds = to_simple(decoded_preds)
                decoded_labels = to_simple(decoded_labels)
            metric.add_batch(predictions=decoded_preds, references=decoded_labels)
    # Free the per-batch tensors
    del generated_tokens, labels, batch
    gc.collect()
# Compute the final evaluation result
m = metric.compute()
print(f"Evaluation result: {args.metric}={round(m, 5)}")