|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Fine-tunes an ELECTRA model on a downstream task.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import argparse |
|
import collections |
|
import json |
|
|
|
import tensorflow as tf |
|
|
|
import configure_finetuning |
|
from finetune import preprocessing |
|
from finetune import task_builder |
|
from model import modeling |
|
from model import optimization |
|
from util import training_utils |
|
from util import utils |
|
|
|
|
|
class FinetuningModel(object):
  """Finetuning model with support for multi-task training."""

  def __init__(self, config: configure_finetuning.FinetuningConfig, tasks,
               is_training, features, num_train_steps):
    """Builds the graph: a shared BERT encoder plus one head per task.

    Args:
      config: FinetuningConfig holding model/training hyperparameters.
      tasks: sequence of task objects; each exposes `name` and a
        `get_prediction_module(bert_model, features, is_training,
        percent_done)` method returning (per-example losses, outputs).
      is_training: bool; enables training-mode behavior in the encoder.
      features: dict of input tensors; must contain "input_ids",
        "input_mask", "segment_ids", and "task_id".
      num_train_steps: total number of training steps, used to compute the
        fraction of training completed.
    """
    bert_config = training_utils.get_bert_config(config)
    self.bert_config = bert_config
    if config.debug:
      # Shrink the encoder to a tiny model for fast debugging runs.
      bert_config.num_hidden_layers = 3
      bert_config.hidden_size = 144
      bert_config.intermediate_size = 144 * 4
      bert_config.num_attention_heads = 4
    assert config.max_seq_length <= bert_config.max_position_embeddings
    # Shared transformer encoder used by all task heads.
    bert_model = modeling.BertModel(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=features["input_ids"],
        input_mask=features["input_mask"],
        token_type_ids=features["segment_ids"],
        use_one_hot_embeddings=config.use_tpu,
        embedding_size=config.embedding_size)
    # Fraction of training completed so far (global_step / num_train_steps);
    # passed to each task's prediction module.
    percent_done = (tf.cast(tf.train.get_or_create_global_step(), tf.float32) /
                    tf.cast(num_train_steps, tf.float32))

    # Build one prediction head per task, each under its own variable scope
    # ("task_specific/<task name>") so task parameters don't collide.
    self.outputs = {"task_id": features["task_id"]}
    losses = []
    for task in tasks:
      with tf.variable_scope("task_specific/" + task.name):
        task_losses, task_outputs = task.get_prediction_module(
            bert_model, features, is_training, percent_done)
        losses.append(task_losses)
        self.outputs[task.name] = task_outputs
    # Each example contributes only the loss of the task it belongs to: the
    # one-hot over task_id masks out every other task's loss column.
    self.loss = tf.reduce_sum(
        tf.stack(losses, -1) *
        tf.one_hot(features["task_id"], len(config.task_names)))
|
|
|
|
|
def model_fn_builder(config: configure_finetuning.FinetuningConfig, tasks,
                     num_train_steps, pretraining_config=None):
  """Returns `model_fn` closure for TPUEstimator.

  Args:
    config: FinetuningConfig with training hyperparameters.
    tasks: tasks to build prediction heads for.
    num_train_steps: total number of training steps.
    pretraining_config: if not None, the latest checkpoint in its
      `model_dir` is used for initialization instead of
      `config.init_checkpoint`.
  """

  def model_fn(features, labels, mode, params):
    """The `model_fn` for TPUEstimator."""
    utils.log("Building model...")
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model = FinetuningModel(
        config, tasks, is_training, features, num_train_steps)

    # Initialize variables from a checkpoint: a pretraining run's latest
    # checkpoint (if provided) takes precedence over config.init_checkpoint.
    init_checkpoint = config.init_checkpoint
    if pretraining_config is not None:
      init_checkpoint = tf.train.latest_checkpoint(pretraining_config.model_dir)
      utils.log("Using checkpoint", init_checkpoint)
    tvars = tf.trainable_variables()
    scaffold_fn = None
    if init_checkpoint:
      assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
          tvars, init_checkpoint)
      if config.use_tpu:
        # On TPU, checkpoint initialization must happen inside the scaffold
        # function so it runs on the TPU host.
        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()
        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    # Only TRAIN and PREDICT are supported; dev-set evaluation is implemented
    # via predict + task-specific scorers (see ModelRunner.evaluate_task).
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(
          model.loss, config.learning_rate, num_train_steps,
          weight_decay_rate=config.weight_decay_rate,
          use_tpu=config.use_tpu,
          warmup_proportion=config.warmup_proportion,
          layerwise_lr_decay_power=config.layerwise_lr_decay,
          n_transformer_layers=model.bert_config.num_hidden_layers
      )
      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=model.loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn,
          # ETAHook logs progress/ETA; loss logging is skipped on TPU where
          # fetching tensors every step is not supported.
          training_hooks=[training_utils.ETAHook(
              {} if config.use_tpu else dict(loss=model.loss),
              num_train_steps, config.iterations_per_loop, config.use_tpu, 10)])
    else:
      assert mode == tf.estimator.ModeKeys.PREDICT
      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          # Outputs are flattened because TPUEstimator predictions must be a
          # flat dict of tensors.
          predictions=utils.flatten_dict(model.outputs),
          scaffold_fn=scaffold_fn)

    utils.log("Building complete")
    return output_spec

  return model_fn
|
|
|
|
|
class ModelRunner(object):
  """Fine-tunes a model on a supervised task."""

  def __init__(self, config: configure_finetuning.FinetuningConfig, tasks,
               pretraining_config=None):
    """Sets up preprocessing and a TPUEstimator for the given tasks.

    Args:
      config: FinetuningConfig.
      tasks: tasks to fine-tune on.
      pretraining_config: optional; forwarded to model_fn_builder so the
        model initializes from the latest pretraining checkpoint.
    """
    self._config = config
    self._tasks = tasks
    self._preprocessor = preprocessing.Preprocessor(config, self._tasks)

    is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    tpu_cluster_resolver = None
    if config.use_tpu and config.tpu_name:
      tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
          config.tpu_name, zone=config.tpu_zone, project=config.gcp_project)
    tpu_config = tf.estimator.tpu.TPUConfig(
        iterations_per_loop=config.iterations_per_loop,
        num_shards=config.num_tpu_cores,
        per_host_input_for_training=is_per_host,
        tpu_job_name=config.tpu_job_name)
    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=config.model_dir,
        save_checkpoints_steps=config.save_checkpoints_steps,
        # Checkpoint by step count only; disable time-based checkpointing.
        save_checkpoints_secs=None,
        tpu_config=tpu_config)

    if self._config.do_train:
      (self._train_input_fn,
       self.train_steps) = self._preprocessor.prepare_train()
    else:
      self._train_input_fn, self.train_steps = None, 0
    model_fn = model_fn_builder(
        config=config,
        tasks=self._tasks,
        num_train_steps=self.train_steps,
        pretraining_config=pretraining_config)
    # NOTE: TPUEstimator also runs on CPU/GPU when use_tpu is False.
    self._estimator = tf.estimator.tpu.TPUEstimator(
        use_tpu=config.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=config.train_batch_size,
        eval_batch_size=config.eval_batch_size,
        predict_batch_size=config.predict_batch_size)

  def train(self):
    """Trains the model for `self.train_steps` steps."""
    utils.log("Training for {:} steps".format(self.train_steps))
    self._estimator.train(
        input_fn=self._train_input_fn, max_steps=self.train_steps)

  def evaluate(self):
    """Evaluates all tasks on the dev set; returns {task name: results}."""
    return {task.name: self.evaluate_task(task) for task in self._tasks}

  def evaluate_task(self, task, split="dev", return_results=True):
    """Evaluate the current model.

    Args:
      task: the task to evaluate.
      split: data split to evaluate on ("dev" by default).
      return_results: if True, log and return the metrics dict; if False,
        return the (updated) scorer object instead.
    """
    utils.log("Evaluating", task.name)
    eval_input_fn, _ = self._preprocessor.prepare_predict([task], split)
    results = self._estimator.predict(input_fn=eval_input_fn,
                                      yield_single_examples=True)
    scorer = task.get_scorer()
    for r in results:
      # Skip examples whose task_id equals len(tasks) — presumably padding
      # examples added during preprocessing; TODO confirm against
      # preprocessing.Preprocessor.
      if r["task_id"] != len(self._tasks):
        # Un-flatten the prediction dict (inverse of utils.flatten_dict in
        # the model_fn) before scoring.
        r = utils.nest_dict(r, self._config.task_names)
        scorer.update(r[task.name])
    if return_results:
      utils.log(task.name + ": " + scorer.results_str())
      utils.log()
      return dict(scorer.get_results())
    else:
      return scorer

  def write_classification_outputs(self, tasks, trial, split):
    """Write classification predictions to disk.

    Args:
      tasks: tasks to run prediction for.
      trial: trial number (used in the output filename and to limit how many
        trials write test outputs).
      split: data split to predict on.
    """
    utils.log("Writing out predictions for", tasks, split)
    predict_input_fn, _ = self._preprocessor.prepare_predict(tasks, split)
    results = self._estimator.predict(input_fn=predict_input_fn,
                                      yield_single_examples=True)

    # Maps task name -> {example id: logits (or predictions when the task
    # head emits no logits)}.
    logits = collections.defaultdict(dict)
    for r in results:
      # Skip examples whose task_id equals len(tasks) — presumably padding;
      # TODO confirm (same check as evaluate_task).
      if r["task_id"] != len(self._tasks):
        r = utils.nest_dict(r, self._config.task_names)
        task_name = self._config.task_names[r["task_id"]]
        logits[task_name][r[task_name]["eid"]] = (
            r[task_name]["logits"] if "logits" in r[task_name]
            else r[task_name]["predictions"])
    for task_name in logits:
      utils.log("Pickling predictions for {:} {:} examples ({:})".format(
          len(logits[task_name]), task_name, split))
      if trial <= self._config.n_writes_test:
        utils.write_pickle(logits[task_name], self._config.test_predictions(
            task_name, split, trial))
|
|
|
|
|
def write_results(config: configure_finetuning.FinetuningConfig, results):
  """Write evaluation metrics to disk.

  Writes a human-readable metrics summary to `config.results_txt` and
  pickles the raw results to `config.results_pkl`.

  Args:
    config: FinetuningConfig supplying the output paths.
    results: list of per-trial {task name: metrics dict} dicts.
  """
  utils.log("Writing results to", config.results_txt)
  utils.mkdir(config.results_txt.rsplit("/", 1)[0])
  # BUGFIX: the pickle was previously written twice (once here and once
  # after the text file); the redundant duplicate call has been removed.
  utils.write_pickle(results, config.results_pkl)
  with tf.io.gfile.GFile(config.results_txt, "w") as f:
    results_str = ""
    for trial_results in results:
      for task_name, task_results in trial_results.items():
        # "time" and "global_step" are bookkeeping entries, not task metrics.
        if task_name == "time" or task_name == "global_step":
          continue
        results_str += task_name + ": " + " - ".join(
            ["{:}: {:.2f}".format(k, v)
             for k, v in task_results.items()]) + "\n"
    f.write(results_str)
|
|
|
|
|
def run_finetuning(config: configure_finetuning.FinetuningConfig):
  """Run finetuning: train/evaluate `config.num_trials` models.

  Each trial trains a fresh model (different random seed via a fresh model
  dir), evaluates on the dev set, and optionally writes test-set predictions.
  A negative `config.num_trials` means run indefinitely.
  """

  # Setup for training.
  results = []
  trial = 1
  heading_info = "model={:}, trial {:}/{:}".format(
      config.model_name, trial, config.num_trials)
  # The lambda reads heading_info from the enclosing scope, so reassigning
  # heading_info below updates all subsequent headings.
  heading = lambda msg: utils.heading(msg + ": " + heading_info)
  heading("Config")
  utils.log_config(config)
  generic_model_dir = config.model_dir
  tasks = task_builder.get_tasks(config)

  while config.num_trials < 0 or trial <= config.num_trials:
    # BUGFIX: refresh the heading each iteration; previously heading_info was
    # computed once before the loop, so every trial logged "trial 1/N".
    heading_info = "model={:}, trial {:}/{:}".format(
        config.model_name, trial, config.num_trials)
    # Each trial gets its own model directory.
    config.model_dir = generic_model_dir + "_" + str(trial)
    if config.do_train:
      utils.rmkdir(config.model_dir)

    model_runner = ModelRunner(config, tasks)
    if config.do_train:
      heading("Start training")
      model_runner.train()
      utils.log()

    if config.do_eval:
      heading("Run dev set evaluation")
      results.append(model_runner.evaluate())
      write_results(config, results)
      if config.write_test_outputs and trial <= config.n_writes_test:
        heading("Running on the test set and writing the predictions")
        for task in tasks:
          # Classification-style GLUE tasks: dump logits/predictions.
          if task.name in ["cola", "mrpc", "mnli", "sst", "rte", "qnli", "qqp",
                           "sts"]:
            for split in task.get_test_splits():
              model_runner.write_classification_outputs([task], trial, split)
          elif task.name == "squad":
            scorer = model_runner.evaluate_task(task, "test", False)
            scorer.write_predictions()
            preds = utils.load_json(config.qa_preds_file("squad"))
            null_odds = utils.load_json(config.qa_na_file("squad"))
            # Blank out answers whose no-answer odds exceed the threshold
            # (SQuAD 2.0-style unanswerable handling).
            for q, _ in preds.items():
              if null_odds[q] > config.qa_na_threshold:
                preds[q] = ""
            utils.write_json(preds, config.test_predictions(
                task.name, "test", trial))
          else:
            utils.log("Skipping task", task.name,
                      "- writing predictions is not supported for this task")

    # Free disk space unless this is the last trial or models are kept.
    if trial != config.num_trials and (not config.keep_all_models):
      utils.rmrf(config.model_dir)
    trial += 1
|
|
|
|
|
def main():
  """Parse command-line flags and launch a finetuning run."""
  parser = argparse.ArgumentParser(description=__doc__)
  # Flag definitions kept in one table for easy scanning/extension.
  flag_specs = [
      ("--data-dir", dict(required=True,
                          help="Location of data files (model weights, etc).")),
      ("--model-name", dict(required=True,
                            help="The name of the model being fine-tuned.")),
      ("--hparams", dict(default="{}",
                         help="JSON dict of model hyperparameters.")),
  ]
  for flag, kwargs in flag_specs:
    parser.add_argument(flag, **kwargs)
  cli_args = parser.parse_args()
  # --hparams may be either an inline JSON string or a path to a .json file.
  hparams = (utils.load_json(cli_args.hparams)
             if cli_args.hparams.endswith(".json")
             else json.loads(cli_args.hparams))
  tf.logging.set_verbosity(tf.logging.ERROR)
  run_finetuning(configure_finetuning.FinetuningConfig(
      cli_args.model_name, cli_args.data_dir, **hparams))


if __name__ == "__main__":
  main()
|
|