# Test the SoundStorm AR S1 model on samples from the dump files.
import argparse
import time
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

from AR.data.dataset import Text2SemanticDataset
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config


def parse_args():
    # parse args and config
    parser = argparse.ArgumentParser(
        description="Run the SoundStorm AR S1 model on the test set.")

    parser.add_argument(
        '--config_file',
        type=str,
        default='conf/default.yaml',
        help='path to the config file')

    # args for the dataset
    parser.add_argument(
        '--test_semantic_path',
        type=str,
        default='dump/test/semantic_token.tsv')
    parser.add_argument(
        '--test_phoneme_path', type=str, default='dump/test/phonemes.npy')

    parser.add_argument(
        '--ckpt_path',
        type=str,
        default='exp/default/ckpt/epoch=99-step=49000.ckpt',
        help='checkpoint file of the SoundStorm AR S1 model')

    parser.add_argument("--output_dir", type=str, help="output directory")

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    config = load_yaml_config(args.config_file)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    batch_size = 1
    hz = 50
    max_sec = config['data']['max_sec']

    # get the dataset
    test_dataset = Text2SemanticDataset(
        phoneme_path=args.test_phoneme_path,
        semantic_path=args.test_semantic_path,
        # max_sec should match the value used during training, otherwise quality
        # may degrade (repeated or dropped words, etc.). But setting it too small
        # here would filter out long samples outright, so use a large value and
        # truncate at inference time instead.
        max_sec=100,
        max_sample=8,
        pad_val=config['data']['pad_val'])

    # get the model
    t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
        checkpoint_path=args.ckpt_path, config=config)
    t2s_model.cuda()
    t2s_model.eval()

    # create a DataLoader that yields batch_size items per step,
    # using the dataset's own collate function
    dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=test_dataset.collate)

    item_names = test_dataset.__get_item_names__()

    # iterate batch by batch; with batch_size=1 and shuffle=False the loop index
    # lines up with __get_item_names__()
    semantic_data = [['item_name', 'semantic_audio']]
    for i, batch in enumerate(dataloader):
        # this mapping is only valid for batch_size = 1
        utt_id = item_names[i]
        # only the first sample is synthesized; the loop breaks afterwards
        if i == 0:
            print("utt_id:", utt_id)
            # with batch_size > 1 the sequences would be zero-padded
            # keep consistent with validation_step()
            semantic_len = batch['semantic_ids'].size(1)
            # use at most the first 150 tokens of batch['semantic_ids'] as the
            # prompt; across repeated syntheses the first prompt_len generated
            # tokens stay identical to the prompt itself
            prompt_len = min(int(semantic_len * 0.5), 150)
            # what should the prompt be for plain-text input? => see t2s.py
            prompt = batch['semantic_ids'][:, :prompt_len]
            # a zero prompt still yields semantic tokens with the correct text
            # content, but with scrambled timbre, which shows that semantic
            # tokens do carry timbre information:
            # prompt = torch.ones(
            #     batch['semantic_ids'].size(0), 1, dtype=torch.int32) * 0
            # print("prompt:", prompt)
            # print("prompt.shape:", prompt.shape)
            np.save(output_dir / 'prompt.npy', prompt.detach().cpu().numpy())

            st = time.time()
            with torch.no_grad():
                # compute the test accuracy for this batch
                loss, acc = t2s_model.model.forward(
                    batch['phoneme_ids'].cuda(),
                    batch['phoneme_ids_len'].cuda(),
                    batch['semantic_ids'].cuda(),
                    batch['semantic_ids_len'].cuda())
                print("top_3_acc of this batch:", acc)
                pred_semantic = t2s_model.model.infer(
                    batch['phoneme_ids'].cuda(),
                    batch['phoneme_ids_len'].cuda(),
                    prompt.cuda(),
                    top_k=config['inference']['top_k'],
                    # hz * max_sec matches the train dataloader; the generated
                    # length of 1002 presumably includes some padding
                    early_stop_num=hz * max_sec)
            # batch_size = 1
            pred_semantic = pred_semantic[0]
            print(f'{time.time() - st} sec used in T2S')
            semantic_token = pred_semantic.detach().cpu().numpy().tolist()
            semantic_token_str = ' '.join(str(x) for x in semantic_token)
            semantic_data.append([utt_id, semantic_token_str])
        else:
            break

    delimiter = '\t'
    filename = output_dir / "semantic_token.tsv"
    with open(filename, 'w', encoding='utf-8') as writer:
        for row in semantic_data:
            line = delimiter.join(row)
            writer.write(line + '\n')


if __name__ == "__main__":
    main()
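
# A possible invocation, assuming this file is saved as test.py (the script
# name and the --output_dir value below are illustrative, not fixed by the
# repo; the other defaults come from parse_args() above):
#   python test.py \
#       --config_file conf/default.yaml \
#       --ckpt_path exp/default/ckpt/epoch=99-step=49000.ckpt \
#       --output_dir exp/default/test_output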