aapot committed on
Commit
067016b
1 Parent(s): 7b87c8d

Add convbert pretrain hyperparams

Files changed (1)
  1. configure_pretraining.py +131 -0
configure_pretraining.py ADDED
@@ -0,0 +1,131 @@
+# coding=utf-8
+
+"""Config controlling hyperparameters for pre-training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+
+class PretrainingConfig(object):
+  """Defines pre-training hyperparameters."""
+
+  def __init__(self, model_name, data_dir, **kwargs):
+    self.model_name = model_name
+    self.debug = False  # debug mode
+    self.do_train = True  # pre-train
+    self.do_eval = False  # evaluate generator/discriminator on unlabeled data
+
+    # loss functions
+    self.electra_objective = True  # if False, use the BERT objective instead
+    self.gen_weight = 1.0  # masked language modeling / generator loss
+    self.disc_weight = 50.0  # discriminator loss
+    self.mask_prob = 0.15  # percent of input tokens to mask out / replace
+
+    # optimization
+    self.learning_rate = 1e-4
+    self.lr_decay_power = 1.0  # linear weight decay by default
+    self.weight_decay_rate = 0.01
+    self.num_warmup_steps = 20000
+
+    # training settings
+    self.iterations_per_loop = 200
+    self.save_checkpoints_steps = 50000
+    self.num_train_steps = 1000000
+    self.num_eval_steps = 10000
+
+    # model settings
+    self.model_size = "base"  # one of "small", "medium-small", or "base"
+    # override the default transformer hparams for the provided model size; see
+    # modeling.BertConfig for the possible hparams and util.training_utils for
+    # the defaults
+    self.model_hparam_overrides = (
+        kwargs["model_hparam_overrides"]
+        if "model_hparam_overrides" in kwargs else {})
+    self.embedding_size = None  # bert hidden size by default
+    self.vocab_size = 50265  # number of tokens in the vocabulary
+    self.do_lower_case = False  # lowercase the input?
+
+    # ConvBERT additional config
+    self.conv_kernel_size = 9
+    self.linear_groups = 2
+    self.head_ratio = 2
+    self.conv_type = "sdconv"
+    # generator settings
+    self.uniform_generator = False  # generator is uniform at random
+    self.untied_generator_embeddings = False  # tie generator/discriminator
+                                              # token embeddings?
+    self.untied_generator = True  # tie all generator/discriminator weights?
+    self.generator_layers = 1.0  # frac of discriminator layers for generator
+    self.generator_hidden_size = 0.25  # frac of discrim hidden size for gen
+    self.disallow_correct = False  # force the generator to sample incorrect
+                                   # tokens (so 15% of tokens are always
+                                   # fake)
+    self.temperature = 1.0  # temperature for sampling from generator
+
+    # batch sizes
+    self.max_seq_length = 512
+    self.train_batch_size = 128
+    self.eval_batch_size = 128
+
+    # TPU settings
+    self.use_tpu = True
+    self.tpu_job_name = None
+    self.num_tpu_cores = 8
+    self.tpu_name = "local"  # cloud TPU to use for training
+    self.tpu_zone = None  # GCE zone where the Cloud TPU is located
+    self.gcp_project = None  # project name for the Cloud TPU-enabled project
+
+    # default locations of data files
+    self.pretrain_tfrecords = "/researchdisk/train_tokenized_512/pretrain_data.tfrecord*"
+    self.vocab_file = "./vocab.txt"
+    self.model_dir = "./"
+    results_dir = os.path.join(self.model_dir, "results")
+    self.results_txt = os.path.join(results_dir, "unsup_results.txt")
+    self.results_pkl = os.path.join(results_dir, "unsup_results.pkl")
+
+    # update defaults with passed-in hyperparameters
+    self.update(kwargs)
+
+    self.max_predictions_per_seq = int((self.mask_prob + 0.005) *
+                                       self.max_seq_length)
+
+    # debug-mode settings
+    if self.debug:
+      self.train_batch_size = 8
+      self.num_train_steps = 20
+      self.eval_batch_size = 4
+      self.iterations_per_loop = 1
+      self.num_eval_steps = 2
+
+    # defaults for different-sized models
+    if self.model_size in ["medium-small"]:
+      self.embedding_size = 128
+      self.conv_kernel_size = 9
+      self.linear_groups = 2
+      self.head_ratio = 2
+    elif self.model_size in ["small"]:
+      self.embedding_size = 128
+      self.conv_kernel_size = 9
+      self.linear_groups = 1
+      self.head_ratio = 2
+      self.learning_rate = 3e-4
+    elif self.model_size in ["base"]:
+      self.generator_hidden_size = 1/3
+      self.learning_rate = 1e-4
+      self.train_batch_size = 256
+      self.eval_batch_size = 256
+      self.conv_kernel_size = 9
+      self.linear_groups = 1
+      self.head_ratio = 2
+
+    # passed-in arguments override (for example) debug-mode defaults
+    self.update(kwargs)
+
+  def update(self, kwargs):
+    for k, v in kwargs.items():
+      if k not in self.__dict__:
+        raise ValueError("Unknown hparam " + k)
+      self.__dict__[k] = v
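
For context, a minimal usage sketch of the class added in this commit, assuming the file is importable as configure_pretraining from the working directory; the model_name value is purely illustrative, and data_dir is accepted by __init__ but not used by this version (paths are hardcoded above):

from configure_pretraining import PretrainingConfig

# Build a config with the defaults from this file, overriding a couple of
# hyperparameters. Keyword overrides are applied by PretrainingConfig.update(),
# which raises ValueError for any key that is not an existing attribute.
config = PretrainingConfig(
    model_name="convbert-base",                      # illustrative name
    data_dir="/researchdisk/train_tokenized_512",    # accepted but unused here
    train_batch_size=256,
    save_checkpoints_steps=25000,
)

print(config.model_size)               # "base" (default)
print(config.learning_rate)            # 1e-4, set by the "base" branch
print(config.max_predictions_per_seq)  # int((0.15 + 0.005) * 512) == 79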