|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import shutil |
|
import time |
|
from json import JSONDecodeError |
|
from logging import getLogger |
|
from pathlib import Path |
|
from typing import Dict, List |
|
|
|
import torch |
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
|
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
from utils import ( |
|
Seq2SeqDataset, |
|
calculate_bleu, |
|
calculate_rouge, |
|
chunks, |
|
lmap, |
|
load_json, |
|
parse_numeric_n_bool_cl_kwargs, |
|
save_json, |
|
use_task_specific_params, |
|
write_txt_file, |
|
) |
|
|
|
|
|
logger = getLogger(__name__) |
|
|
|
|
|
def eval_data_dir( |
|
data_dir, |
|
save_dir: str, |
|
model_name: str, |
|
bs: int = 8, |
|
max_source_length: int = 1024, |
|
type_path="val", |
|
n_obs=None, |
|
fp16=False, |
|
task="summarization", |
|
local_rank=None, |
|
num_return_sequences=1, |
|
dataset_kwargs: Dict = None, |
|
prefix="", |
|
**generate_kwargs, |
|
) -> Dict: |
|
"""Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json""" |
|
model_name = str(model_name) |
|
assert local_rank is not None |
|
torch.distributed.init_process_group(backend="nccl", rank=local_rank) |
|
|
|
save_dir = Path(save_dir) |
|
save_path = save_dir.joinpath(f"rank_{local_rank}_output.json") |
|
torch.cuda.set_device(local_rank) |
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda() |
|
if fp16: |
|
model = model.half() |
|
|
|
use_task_specific_params(model, task) |
|
num_beams = generate_kwargs.pop("num_beams", model.config.num_beams) |
|
if num_return_sequences > num_beams: |
|
num_beams = num_return_sequences |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") |
|
|
|
if max_source_length is None: |
|
max_source_length = tokenizer.model_max_length |
|
if prefix is None: |
|
prefix = prefix or getattr(model.config, "prefix", "") or "" |
|
ds = Seq2SeqDataset( |
|
tokenizer, |
|
data_dir, |
|
max_source_length, |
|
max_target_length=1024, |
|
type_path=type_path, |
|
n_obs=n_obs, |
|
prefix=prefix, |
|
**dataset_kwargs, |
|
) |
|
|
|
|
|
sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True) |
|
data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn) |
|
results = [] |
|
for batch in tqdm(data_loader): |
|
summaries = model.generate( |
|
input_ids=batch["input_ids"].to(model.device), |
|
attention_mask=batch["attention_mask"].to(model.device), |
|
num_return_sequences=num_return_sequences, |
|
num_beams=num_beams, |
|
**generate_kwargs, |
|
) |
|
preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) |
|
ids = batch["ids"] |
|
if num_return_sequences > 1: |
|
preds = chunks(preds, num_return_sequences) |
|
for i, pred in enumerate(preds): |
|
results.append({"pred": pred, "id": ids[i].item()}) |
|
save_json(results, save_path) |
|
return results, sampler.num_replicas |
|
|
|
|
|
def run_generate(): |
|
parser = argparse.ArgumentParser( |
|
epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate" |
|
) |
|
parser.add_argument("--data_dir", type=str, help="like cnn_dm/test.source") |
|
parser.add_argument( |
|
"--model_name", |
|
type=str, |
|
help="like facebook/bart-large-cnn,t5-base, etc.", |
|
default="sshleifer/distilbart-xsum-12-3", |
|
) |
|
parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen") |
|
parser.add_argument("--max_source_length", type=int, default=None) |
|
parser.add_argument( |
|
"--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test" |
|
) |
|
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics") |
|
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") |
|
parser.add_argument( |
|
"--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch" |
|
) |
|
|
|
parser.add_argument( |
|
"--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all." |
|
) |
|
parser.add_argument( |
|
"--num_return_sequences", type=int, default=1, required=False, help="How many sequences to return" |
|
) |
|
parser.add_argument( |
|
"--sync_timeout", |
|
type=int, |
|
default=600, |
|
required=False, |
|
help="How long should master process wait for other processes to finish.", |
|
) |
|
parser.add_argument("--src_lang", type=str, default=None, required=False) |
|
parser.add_argument("--tgt_lang", type=str, default=None, required=False) |
|
parser.add_argument( |
|
"--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples" |
|
) |
|
parser.add_argument("--fp16", action="store_true") |
|
parser.add_argument("--debug", action="store_true") |
|
start_time = time.time() |
|
args, rest = parser.parse_known_args() |
|
generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest) |
|
if generate_kwargs and args.local_rank <= 0: |
|
print(f"parsed the following generate kwargs: {generate_kwargs}") |
|
json_save_dir = Path(args.save_dir + "_tmp") |
|
Path(json_save_dir).mkdir(exist_ok=True) |
|
intermediate_files = list(json_save_dir.glob("rank_*.json")) |
|
if intermediate_files: |
|
raise ValueError(f"Found files at {json_save_dir} please move or remove them.") |
|
|
|
dataset_kwargs = {} |
|
if args.src_lang is not None: |
|
dataset_kwargs["src_lang"] = args.src_lang |
|
if args.tgt_lang is not None: |
|
dataset_kwargs["tgt_lang"] = args.tgt_lang |
|
|
|
Path(args.save_dir).mkdir(exist_ok=True) |
|
results, num_replicas = eval_data_dir( |
|
args.data_dir, |
|
json_save_dir, |
|
args.model_name, |
|
type_path=args.type_path, |
|
bs=args.bs, |
|
fp16=args.fp16, |
|
task=args.task, |
|
local_rank=args.local_rank, |
|
n_obs=args.n_obs, |
|
max_source_length=args.max_source_length, |
|
num_return_sequences=args.num_return_sequences, |
|
prefix=args.prefix, |
|
dataset_kwargs=dataset_kwargs, |
|
**generate_kwargs, |
|
) |
|
|
|
if args.local_rank <= 0: |
|
save_dir = Path(args.save_dir) |
|
save_dir.mkdir(exist_ok=True) |
|
partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout) |
|
preds = combine_partial_results(partial_results) |
|
if args.num_return_sequences > 1: |
|
save_path = save_dir.joinpath("pseudolabel_results.json") |
|
print(f"Saving aggregated results at {save_path}, intermediate in {json_save_dir}/") |
|
save_json(preds, save_path) |
|
return |
|
tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target") |
|
with open(tgt_file) as f: |
|
labels = [x.rstrip() for x in f.readlines()][: len(preds)] |
|
|
|
|
|
calc_bleu = "translation" in args.task |
|
score_fn = calculate_bleu if calc_bleu else calculate_rouge |
|
metric_name = "bleu" if calc_bleu else "rouge" |
|
metrics: Dict = score_fn(preds, labels) |
|
metrics["n_obs"] = len(preds) |
|
runtime = time.time() - start_time |
|
metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 4) |
|
metrics["n_gpus"] = num_replicas |
|
|
|
metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json") |
|
save_json(metrics, metrics_save_path, indent=None) |
|
print(metrics) |
|
write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt")) |
|
if args.debug: |
|
write_txt_file(labels, save_dir.joinpath(f"{args.type_path}.target")) |
|
else: |
|
shutil.rmtree(json_save_dir) |
|
|
|
|
|
def combine_partial_results(partial_results) -> List: |
|
"""Concatenate partial results into one file, then sort it by id.""" |
|
records = [] |
|
for partial_result in partial_results: |
|
records.extend(partial_result) |
|
records = sorted(records, key=lambda x: x["id"]) |
|
preds = [x["pred"] for x in records] |
|
return preds |
|
|
|
|
|
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]: |
|
|
|
start_wait = time.time() |
|
logger.info("waiting for all nodes to finish") |
|
json_data = None |
|
while (time.time() - start_wait) < timeout: |
|
json_files = list(save_dir.glob("rank_*.json")) |
|
if len(json_files) < num_replicas: |
|
continue |
|
try: |
|
|
|
json_data = lmap(load_json, json_files) |
|
return json_data |
|
except JSONDecodeError: |
|
continue |
|
else: |
|
raise TimeoutError("Rank 0 gave up on waiting for other processes") |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
run_generate() |
|
|