# Chen YiJia
# first epoch
# 5cedadf
# raw
# history blame
# No virus
# 1.03 kB
from datasets.load import load_dataset
import logging
import sacrebleu
import pandas as pd
from simpletransformers.t5 import T5Model, T5Args
# Evaluation script: translates the IWSLT 2017 zh-en test split in both
# directions with an mT5 model and reports corpus-level BLEU for each.

# Downloads (or reads from the local cache) the zh-en configuration.
raw_datasets = load_dataset('iwslt2017', 'iwslt2017-zh-en')

# Keep this script's own logging at INFO while quieting the "transformers"
# logger down to warnings.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Generation settings used by model.predict().
model_args = T5Args()
model_args.max_length = 512
model_args.length_penalty = 1
model_args.num_beams = 10

# Loads the checkpoint stored in the local "outputs" directory
# (presumably the result of a previous training run -- confirm).
model = T5Model("mt5", "outputs", args=model_args)

# Each test row is a {'en': ..., 'zh': ...} pair; split into parallel lists.
en_zh_test = pd.DataFrame(raw_datasets['test']['translation'])
zh_truth = en_zh_test['zh'].tolist()
en_input = en_zh_test['en'].tolist()

# NOTE(review): the sentences are passed to predict() without any task prefix
# in either direction -- confirm this matches how the model was trained.
zh_preds = model.predict(en_input)
# sacrebleu expects a list of reference *streams* (list of lists of strings),
# so the single reference set is wrapped in a list. The 'zh' tokenizer is
# required for Chinese output: the default tokenizer does not segment Chinese
# text, which makes n-gram matching meaningless.
en_zh_bleu = sacrebleu.corpus_bleu(zh_preds, [zh_truth], tokenize='zh')
print("----------------------------------------------")
print("English to Chinese: ", en_zh_bleu.score)

en_preds = model.predict(zh_truth)
# English output is scored with sacrebleu's default tokenizer; the references
# are again wrapped as a single reference stream.
zh_en_bleu = sacrebleu.corpus_bleu(en_preds, [en_input])
print("----------------------------------------------")
print("Chinese to English: ", zh_en_bleu.score)