"""Config controlling hyperparameters for pre-training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

class PretrainingConfig(object):
  """Defines pre-training hyperparameters."""

  def __init__(self, model_name, data_dir, **kwargs):
    self.model_name = model_name
    self.debug = False  # debug mode for quickly running things
    self.do_train = True  # pre-train the model
    self.do_eval = False  # evaluate the generator/discriminator on unlabeled data

    # loss functions
    self.electra_objective = True  # if False, use the BERT (masked LM) objective instead
    self.gen_weight = 1.0  # weight for the masked language modeling / generator loss
    self.disc_weight = 50.0  # weight for the discriminator loss
    self.mask_prob = 0.15  # fraction of input tokens to mask out / replace

    # optimization
    self.learning_rate = 1e-4
    self.lr_decay_power = 1.0  # linear learning-rate decay by default
    self.weight_decay_rate = 0.01
    self.num_warmup_steps = 20000

    # training settings
    self.iterations_per_loop = 200
    self.save_checkpoints_steps = 50000
    self.num_train_steps = 1000000
    self.num_eval_steps = 10000

    # model settings
    self.model_size = "base"  # one of "small", "medium-small", or "base"
    # override the default transformer hparams for the provided model size
    self.model_hparam_overrides = (
        kwargs["model_hparam_overrides"]
        if "model_hparam_overrides" in kwargs else {})
    self.embedding_size = None  # if None, the model's hidden size is typically used
    self.vocab_size = 50265  # number of tokens in the vocabulary
    self.do_lower_case = False  # lowercase the input?

    # ConvBERT-style mixed-attention settings
    self.conv_kernel_size = 9  # convolution kernel size
    self.linear_groups = 2  # groups for the grouped linear operator
    self.head_ratio = 2  # ratio used to reduce the number of attention heads
    self.conv_type = "sdconv"  # span-based dynamic convolution

    # generator settings
    self.uniform_generator = False  # generator samples uniformly at random
    self.untied_generator_embeddings = False  # tie generator/discriminator token embeddings?
    self.untied_generator = True  # tie all generator/discriminator weights?
    self.generator_layers = 1.0  # frac of discriminator layers used for the generator
    self.generator_hidden_size = 0.25  # frac of discriminator hidden size used for the generator
    self.disallow_correct = False  # force the generator to sample incorrect tokens

    self.temperature = 1.0  # temperature for sampling from the generator

    # batch sizes
    self.max_seq_length = 512
    self.train_batch_size = 128
    self.eval_batch_size = 128

    # TPU settings
    self.use_tpu = True
    self.tpu_job_name = None
    self.num_tpu_cores = 8
    self.tpu_name = "local"
    self.tpu_zone = None
    self.gcp_project = None

    # default locations of data files
    self.pretrain_tfrecords = "/researchdisk/train_tokenized_512/pretrain_data.tfrecord*"
    self.vocab_file = "./vocab.txt"
    self.model_dir = "./"
    results_dir = os.path.join(self.model_dir, "results")
    self.results_txt = os.path.join(results_dir, "unsup_results.txt")
    self.results_pkl = os.path.join(results_dir, "unsup_results.pkl")

    # passed-in arguments override the defaults above
    self.update(kwargs)

    self.max_predictions_per_seq = int((self.mask_prob + 0.005) *
                                       self.max_seq_length)
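    # Worked example (illustrative): with the default mask_prob=0.15 and
    # max_seq_length=512 this evaluates to int(0.155 * 512) = int(79.36) = 79
    # masked positions per sequence; the +0.005 appears to be rounding slack so
    # the truncation does not undershoot the target masking rate.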

    # debug-mode overrides: tiny batch sizes and step counts for quick test runs
    if self.debug:
      self.train_batch_size = 8
      self.num_train_steps = 20
      self.eval_batch_size = 4
      self.iterations_per_loop = 1
      self.num_eval_steps = 2

    # defaults for the different model sizes
    if self.model_size in ["medium-small"]:
      self.embedding_size = 128
      self.conv_kernel_size = 9
      self.linear_groups = 2
      self.head_ratio = 2
    elif self.model_size in ["small"]:
      self.embedding_size = 128
      self.conv_kernel_size = 9
      self.linear_groups = 1
      self.head_ratio = 2
      self.learning_rate = 3e-4
    elif self.model_size in ["base"]:
      self.generator_hidden_size = 1/3
      self.learning_rate = 1e-4
      self.train_batch_size = 256
      self.eval_batch_size = 256
      self.conv_kernel_size = 9
      self.linear_groups = 1
      self.head_ratio = 2

    # passed-in arguments also override the model-size defaults
    self.update(kwargs)

  def update(self, kwargs):
    """Overwrites existing hyperparameters with the provided keyword arguments."""
    for k, v in kwargs.items():
      if k not in self.__dict__:
        raise ValueError("Unknown hparam " + k)
      self.__dict__[k] = v
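

# A minimal usage sketch for running this module directly; the model name and
# data directory below are hypothetical placeholders. Keyword arguments must
# match attributes defined in __init__, otherwise update() raises ValueError.
if __name__ == "__main__":
  config = PretrainingConfig(
      model_name="convbert-base-example",  # hypothetical model name
      data_dir="/path/to/data",  # hypothetical data directory
      model_size="small",  # selects the "small" defaults above
      debug=True,  # shrinks batch sizes and step counts
  )
  print(config.train_batch_size)  # 8, from the debug-mode overrides
  print(config.learning_rate)  # 3e-4, from the "small" defaults
  print(config.max_predictions_per_seq)  # 79 for mask_prob=0.15, max_seq_length=512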