# coding=utf-8
"""Config controlling hyperparameters for pre-training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
class PretrainingConfig(object):
"""Defines pre-training hyperparameters."""
def __init__(self, model_name, data_dir, **kwargs):
self.model_name = model_name
self.debug = False # debug mode
self.do_train = True # pre-train
self.do_eval = False # evaluate generator/discriminator on unlabeled data
# loss functions
self.electra_objective = True # if False, use the BERT objective instead
self.gen_weight = 1.0 # masked language modeling / generator loss
self.disc_weight = 50.0 # discriminator loss
self.mask_prob = 0.15 # percent of input tokens to mask out / replace
# optimization
self.learning_rate = 1e-4
self.lr_decay_power = 1.0 # linear weight decay by default
self.weight_decay_rate = 0.01
self.num_warmup_steps = 20000
# training settings
self.iterations_per_loop = 200
self.save_checkpoints_steps = 50000
self.num_train_steps = 1000000
self.num_eval_steps = 10000
# model settings
    self.model_size = "base"  # one of "small", "medium-small", or "base"
    # override the default transformer hparams for the provided model size; see
    # modeling.BertConfig for the possible hparams and util.training_utils for
    # the defaults
    self.model_hparam_overrides = (
        kwargs["model_hparam_overrides"]
        if "model_hparam_overrides" in kwargs else {})
    self.embedding_size = None  # bert hidden size by default
    self.vocab_size = 50265  # number of tokens in the vocabulary
    self.do_lower_case = False  # lowercase the input?

    # ConvBERT additional config
    self.conv_kernel_size = 9
    self.linear_groups = 2
    self.head_ratio = 2
    self.conv_type = "sdconv"

    # generator settings
    self.uniform_generator = False  # generator is uniform at random
    self.untied_generator_embeddings = False  # tie generator/discriminator
                                              # token embeddings?
    self.untied_generator = True  # tie all generator/discriminator weights?
    self.generator_layers = 1.0  # frac of discriminator layers for generator
    self.generator_hidden_size = 0.25  # frac of discrim hidden size for gen
    self.disallow_correct = False  # force the generator to sample incorrect
                                   # tokens (so 15% of tokens are always
                                   # fake)
    self.temperature = 1.0  # temperature for sampling from generator

    # batch sizes
    self.max_seq_length = 512
    self.train_batch_size = 128
    self.eval_batch_size = 128

    # TPU settings
    self.use_tpu = True
    self.tpu_job_name = None
    self.num_tpu_cores = 8
    self.tpu_name = "local"  # cloud TPU to use for training
    self.tpu_zone = None  # GCE zone where the Cloud TPU is located
    self.gcp_project = None  # project name for the Cloud TPU-enabled project

    # default locations of data files
    self.pretrain_tfrecords = "/researchdisk/train_tokenized_512/pretrain_data.tfrecord*"
    self.vocab_file = "./vocab.txt"
    self.model_dir = "./"
    results_dir = os.path.join(self.model_dir, "results")
    self.results_txt = os.path.join(results_dir, "unsup_results.txt")
    self.results_pkl = os.path.join(results_dir, "unsup_results.pkl")

    # update defaults with passed-in hyperparameters
    self.update(kwargs)
    self.max_predictions_per_seq = int((self.mask_prob + 0.005) *
                                       self.max_seq_length)

    # debug-mode settings
    if self.debug:
      self.train_batch_size = 8
      self.num_train_steps = 20
      self.eval_batch_size = 4
      self.iterations_per_loop = 1
      self.num_eval_steps = 2
    # defaults for different-sized models
    if self.model_size in ["medium-small"]:
      self.embedding_size = 128
      self.conv_kernel_size = 9
      self.linear_groups = 2
      self.head_ratio = 2
    elif self.model_size in ["small"]:
      self.embedding_size = 128
      self.conv_kernel_size = 9
      self.linear_groups = 1
      self.head_ratio = 2
      self.learning_rate = 3e-4
    elif self.model_size in ["base"]:
      self.generator_hidden_size = 1/3
      self.learning_rate = 1e-4
      self.train_batch_size = 256
      self.eval_batch_size = 256
      self.conv_kernel_size = 9
      self.linear_groups = 1
      self.head_ratio = 2

    # passed-in-arguments override (for example) debug-mode defaults
    self.update(kwargs)

  def update(self, kwargs):
    for k, v in kwargs.items():
      if k not in self.__dict__:
        raise ValueError("Unknown hparam " + k)
      self.__dict__[k] = v
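

# ---------------------------------------------------------------------------
# Example usage (an illustrative sketch, not part of the original config file).
# The model_name/data_dir values and the keyword overrides below are
# hypothetical; any keyword that is not already an attribute set in __init__
# makes update() raise ValueError("Unknown hparam ...").
if __name__ == "__main__":
  config = PretrainingConfig(
      model_name="convbert-base-finnish",  # hypothetical run name
      data_dir="/tmp/pretrain-data",       # accepted but unused; data paths are hard-coded above
      model_size="base",
      num_train_steps=100000)              # overrides the 1,000,000-step default
  print(config.train_batch_size)           # 256, set by the "base" size defaults
  print(config.learning_rate)              # 1e-4 for "base"
  print(config.max_predictions_per_seq)    # int((0.15 + 0.005) * 512) == 79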