HarryLee committed
Commit
22ee7e5
1 Parent(s): 3b57708

Upload evaluate.py

Files changed (1)
  1. evaluate.py +152 -0
evaluate.py ADDED
@@ -0,0 +1,152 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys
import json
from itertools import chain

import numpy as np
import torch
import torch.distributed as dist
from fairseq import distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.utils import reset_logging
from omegaconf import DictConfig

from utils import checkpoint_utils
from utils.eval_utils import eval_step

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("ofa.evaluate")


def apply_half(t):
    # Cast fp32 tensors to half precision; applied to every tensor in a
    # sample via utils.apply_to_sample in main() when fp16 is enabled.
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t


def main(cfg: DictConfig):
    utils.import_user_module(cfg.common)

    reset_logging()
    logger.info(cfg)

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # Load ensemble; model_overrides arrives as a dict literal in string
    # form (e.g. "{}") and is eval'd into a Python dict here.
    overrides = eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Loading the dataset should happen after the checkpoint has been
    # loaded, so we can give it the saved task config.
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    # Move models to GPU
    for model in models:
        model.eval()
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    results = []
    # Keep the score accumulators on the same device as the models so the
    # in-place additions (and the all_reduce below) also work on CPU-only runs.
    score_sum = torch.FloatTensor([0]).cuda() if use_cuda else torch.FloatTensor([0])
    score_cnt = torch.FloatTensor([0]).cuda() if use_cuda else torch.FloatTensor([0])
    for sample in progress:
        if "net_input" not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
        with torch.no_grad():
            result, scores = eval_step(task, generator, models, sample)
        results += result
        score_sum += sum(scores) if scores is not None else 0
        score_cnt += len(scores) if scores is not None else 0
        progress.log({"sentences": sample["nsentences"]})

    # With multiple workers, gather the per-rank result lists and reduce the
    # score statistics across the group before reporting.
    gather_results = None
    if cfg.distributed_training.distributed_world_size > 1:
        gather_results = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(gather_results, results)
        dist.all_reduce(score_sum.data)
        dist.all_reduce(score_cnt.data)
    if score_cnt.item() > 0:
        logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
            score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
        ))

    # Only the main process writes the merged predictions to disk.
    if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset))
        gather_results = list(chain(*gather_results)) if gather_results is not None else results
        with open(output_path, 'w') as fw:
            json.dump(gather_results, fw)


def cli_main():
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser)
    cfg = convert_namespace_to_omegaconf(args)
    distributed_utils.call_main(cfg, main)


if __name__ == "__main__":
    cli_main()
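
A minimal invocation sketch, assuming fairseq's standard generation flags; the data path, task name, checkpoint path, and user-dir values below are illustrative placeholders, not values taken from this commit:

    import sys
    from evaluate import cli_main

    # All flags are standard fairseq generation options; the values are assumptions.
    sys.argv = [
        "evaluate.py",
        "dataset/test_data.tsv",             # positional data argument (placeholder)
        "--task=caption",                    # assumed task registered by the user module
        "--path=checkpoints/model_best.pt",  # assumed checkpoint path
        "--user-dir=ofa_module",             # assumed user-module directory
        "--batch-size=16",
        "--gen-subset=test",
        "--results-path=results",
        "--beam=5",
        "--fp16",
    ]
    cli_main()

The same flags can equally be passed on the command line, since cli_main() parses sys.argv through options.parse_args_and_arch; with these values, predictions would land in results/test_predict.json.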