from torch.utils.data import DataLoader

from bleu import calculate_bleu_score
from load_dataset import load_local_bleu_dataset
from dataset import BilingualDataset
from config import load_config
from load_and_save_model import load_model_tokenizer


def get_bleu_of_model(config) -> float:
    """Compute the BLEU score of the model described by *config*.

    The model and both tokenizers are obtained from ``load_model_tokenizer``.
    The evaluation corpus is read from two files under ``datasets/`` named
    ``<bleu_dataset>.<src_lang>`` and ``<bleu_dataset>.<tgt_lang>``, wrapped
    in a ``BilingualDataset`` and fed one example at a time to
    ``calculate_bleu_score``.

    Args:
        config: Mapping with a ``'dataset'`` section providing the keys
            ``bleu_dataset``, ``src_lang``, ``tgt_lang``,
            ``src_max_seq_len`` and ``tgt_max_seq_len``. The whole mapping
            is also passed to ``load_model_tokenizer``.

    Returns:
        Whatever ``calculate_bleu_score`` returns for the model on the
        evaluation corpus.
    """
    model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(config)

    # Read the dataset section once instead of re-indexing config each time.
    ds_cfg = config['dataset']
    src_lang = ds_cfg['src_lang']
    tgt_lang = ds_cfg['tgt_lang']
    # Source and target files share a base name and differ only in the
    # language-code extension.
    corpus_base = f"datasets/{ds_cfg['bleu_dataset']}"

    bleu_ds_raw = load_local_bleu_dataset(
        src_dataset_filename=f"{corpus_base}.{src_lang}",
        tgt_dataset_filename=f"{corpus_base}.{tgt_lang}",
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )

    bleu_ds = BilingualDataset(
        ds=bleu_ds_raw,
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        src_max_seq_len=ds_cfg['src_max_seq_len'],
        tgt_max_seq_len=ds_cfg['tgt_max_seq_len'],
    )

    # NOTE(review): shuffling an evaluation loader makes the example order
    # differ between runs; kept as-is because calculate_bleu_score is not
    # visible here and may rely on a shuffled order (e.g. if it samples a
    # subset) -- confirm and consider shuffle=False for reproducibility.
    bleu_dataloader = DataLoader(bleu_ds, batch_size=1, shuffle=True)

    return calculate_bleu_score(
        model,
        bleu_dataloader,
        src_tokenizer,
        tgt_tokenizer,
    )


if __name__ == '__main__':
    # A tuple (not a set) so the configs are evaluated and reported in a
    # fixed order on every run; set iteration order over strings varies
    # with hash randomization.
    config_files = (
        'config_final.yaml',
        'config_huge.yaml',
        'config_big.yaml',
        'config_small.yaml',
    )
    for file_name in config_files:
        config = load_config(file_name)
        print(get_bleu_of_model(config), f" is the BLEU of {file_name}", sep='')