JerryLiJinyi's picture
Upload 127 files
10b912d verified
raw
history blame
1.63 kB
import argparse
import json
import numpy as np
import tqdm
from pathlib import Path
from pprint import pprint
from collections import defaultdict, Counter
from transformers import AutoTokenizer
import scrl.utils as utils
from scrl.model import load_checkpoint
from scrl.metrics import compute_token_f1, rouge_scorer, ROUGE_TYPES
from nltk import word_tokenize
from scrl.rewards import load_rewards
from scrl.config import load_config
import time
def main(args):
    """Run the checkpointed model over a JSONL dataset and collect predictions.

    Reads ``args.checkpoint``, ``args.device``, ``args.dataset``,
    ``args.batch_size`` and ``args.output``. Every dataset record is
    accessed via its ``"text"`` and ``"id"`` keys. The elapsed prediction
    time is printed, and the collected records are written as JSONL only
    when ``args.output`` is truthy.
    """
    checkpoint_path = Path(args.checkpoint)
    model = load_checkpoint(checkpoint_path, device=args.device)
    tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")

    records = list(utils.read_jsonl(args.dataset))
    batches = utils.batchify(records, args.batch_size)

    predictions = []
    # The timer starts after model/data loading, so only the prediction
    # loop is measured.
    started = time.time()
    for batch in tqdm.tqdm(batches):
        texts = [record["text"] for record in batch]
        batch_summaries = model.predict(texts, tokenizer, args.device)
        predictions.extend(
            {"id": record["id"], "pred-summary": summary}
            for record, summary in zip(batch, batch_summaries)
        )
    finished = time.time()
    print("Seconds:", finished - started)

    if not args.output:
        return
    utils.write_jsonl(predictions, args.output, "w")
def parse_args():
    """Parse the command-line options for the prediction script.

    Returns the ``argparse.Namespace`` with ``dataset`` and ``checkpoint``
    (both required), ``output`` (optional, ``None`` when omitted),
    ``device`` (defaults to ``"cpu"``) and ``batch_size`` (int, defaults
    to 4).
    """
    cli = argparse.ArgumentParser()
    # (flag, add_argument keyword settings), registered in this order.
    option_specs = (
        ("--dataset", {"required": True}),
        ("--output", {"required": False}),
        ("--checkpoint", {"required": True}),
        ("--device", {"default": "cpu"}),
        ("--batch-size", {"type": int, "default": 4}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
# Script entry point: parse the command line, then run prediction.
if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)