Vivek committed
Commit ae69f53
1 parent: a55bc40

delete files

results_tensorboard/events.out.tfevents.1626288298.t1v-n-8cb15980-w-0.712426.3.v2 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:f41c2fcf7a5ad8f782c33825800d708bad20828d2790507cafaa45daf459f37d
- size 588

results_tensorboard/events.out.tfevents.1626289169.t1v-n-8cb15980-w-0.716987.3.v2 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:054d9f29e34dc409faa7d550c77e842a61072adc0c91d10b6498b98fe0307085
- size 588

results_tensorboard/events.out.tfevents.1626327484.t1v-n-8cb15980-w-0.750616.3.v2 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:68aac888ec5f8514ef03c8e2c6d595ae71a355c889643e9423ac98676c00c233
- size 588

results_tensorboard/events.out.tfevents.1626327887.t1v-n-8cb15980-w-0.752609.3.v2 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:7657f9a1a65b35a4544634bbf3dc9352eec8b111862d47d2015fba24536579f5
- size 596

results_tensorboard/events.out.tfevents.1626332656.t1v-n-8cb15980-w-0.759220.3.v2 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:781634a9c25358932557a2489094d955a8b955179a710e6f7aebfb2bdd4333dc
- size 866

results_tensorboard/events.out.tfevents.1626333695.t1v-n-8cb15980-w-0.761814.3.v2 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:a6e77accc72524145a973a76b92872232f60143270fb2a5ded477efb5c5d043e
- size 176

results_tensorboard/events.out.tfevents.1626333832.t1v-n-8cb15980-w-0.763322.3.v2 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:593046ea18ca7b10638c83f187d1f66a534142780b3b663c774efa0a7135c26c
- size 456

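For reference: the seven deleted files above are Git LFS pointers, not the TensorBoard event data itself. Each pointer is three "key value" lines recording the LFS spec version, the SHA-256 object id, and the byte size of the real object held in LFS storage. A minimal sketch of reading one; the parse_lfs_pointer helper is hypothetical, not part of this repository:

    # Hypothetical helper, not part of this repo: parse a Git LFS pointer file
    # like the ones deleted above. The real object lives in LFS storage.
    def parse_lfs_pointer(text: str) -> dict:
        fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
        return {
            "version": fields["version"],    # LFS spec URL
            "oid": fields["oid"],            # e.g. "sha256:f41c2fcf..."
            "size": int(fields["size"]),     # size in bytes of the stored object
        }
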
train.py DELETED
@@ -1,240 +0,0 @@
- import jax
- print(jax.local_device_count())
- import jax.numpy as jnp
-
- import flax
- import flax.linen as nn
- from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
- from flax.training import train_state
- from flax.training import checkpoints
- from flax.metrics import tensorboard
- from flax.metrics.tensorboard import SummaryWriter
- from flax.jax_utils import unreplicate
-
- import logging
- import math
- import optax
- from tqdm import tqdm
- from pathlib import Path
- from typing import Callable
- from itertools import chain
-
- from datasets import load_dataset, load_metric
- from transformers import GPT2Config, GPT2Tokenizer
-
- from model_file import FlaxGPT2ForMultipleChoice
-
- logger = logging.getLogger()
- logger.setLevel(logging.INFO)
-
- def main():
-
-     tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|endoftext|>')
-
-     dataset = load_dataset('cosmos_qa')
-
-     def preprocess(example):
-         example['context&question'] = example['context'] + example['question']
-         # Repeat the shared context+question once per answer choice.
-         example['first_sentence'] = [example['context&question']] * 4
-         example['second_sentence'] = example['answer0'], example['answer1'], example['answer2'], example['answer3']
-         return example
-
-     train_dataset = dataset['train'].map(preprocess)
-     validation_dataset = dataset['validation'].map(preprocess)
-     test_dataset = dataset['test'].map(preprocess)
-
-     # Temporary: restrict each split to 100 examples for a quick experiment.
-     len_train_dataset = 100
-     len_validation_dataset = 100
-     len_test_dataset = 100
-
-     train_dataset = train_dataset.select(range(len_train_dataset))
-     validation_dataset = validation_dataset.select(range(len_validation_dataset))
-     test_dataset = test_dataset.select(range(len_test_dataset))
-
-     def tokenize(examples):
-         tokenized = tokenizer(examples['first_sentence'], examples['second_sentence'], padding='max_length', truncation=True, max_length=256, return_tensors='jax')
-         tokenized['labels'] = examples['label']
-         return tokenized
-
-     train_dataset = train_dataset.map(tokenize)
-     validation_dataset = validation_dataset.map(tokenize)
-     test_dataset = test_dataset.map(tokenize)
-
-     remov_col = ['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels', 'context&question', 'first_sentence', 'second_sentence']
-
-     train_dataset = train_dataset.remove_columns(remov_col)
-     validation_dataset = validation_dataset.remove_columns(remov_col)
-     test_dataset = test_dataset.remove_columns(remov_col)
-
-     per_device_batch_size = 4
-     seed = 0
-     num_train_epochs = 1
-     learning_rate = 2e-5
-
-     total_batch_size = per_device_batch_size * jax.local_device_count()
-     print('The overall batch size (both for training and eval) is', total_batch_size)
-     num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
-     num_validation_steps = len(validation_dataset) // total_batch_size * num_train_epochs
-
-     learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)
-
-     class TrainState(train_state.TrainState):
-         logits_function: Callable = flax.struct.field(pytree_node=False)
-         loss_function: Callable = flax.struct.field(pytree_node=False)
-
-     def adamw(weight_decay):
-         return optax.adamw(learning_rate=learning_rate_function, b1=0.9, b2=0.99, eps=1e-6, weight_decay=weight_decay)
-
-     decay_path = lambda p: not any(x in p for x in ['bias', 'LayerNorm.weight'])
-
-     def traverse(function):
-         def mask(data):
-             flat = flax.traverse_util.flatten_dict(data)
-             return flax.traverse_util.unflatten_dict({k: function(k, v) for k, v in flat.items()})
-         return mask
-
-     # Chain two masked AdamW transforms so the two parameter subsets
-     # selected by decay_path train with different weight-decay settings.
-     gradient_transformation = optax.chain(
-         optax.masked(adamw(0.0), mask=traverse(lambda path, _: decay_path(path))),
-         optax.masked(adamw(0.01), mask=traverse(lambda path, _: not decay_path(path))))
-
-     def loss_function(logits, labels):
-         # optax.softmax_cross_entropy expects raw (unnormalized) logits, so
-         # no explicit log_softmax is applied first.
-         xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=4))
-         return jnp.mean(xentropy)
-
-     def eval_function(logits):
-         return logits.argmax(-1)
-
-     model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2', input_shape=(1, 4, 1))
-
-     state = TrainState.create(apply_fn=model.__call__,
-                               params=model.params,
-                               tx=gradient_transformation,
-                               logits_function=eval_function,
-                               loss_function=loss_function)
-
-     def train_step(state, batch, dropout_rng):
-         targets = batch.pop("label")
-         dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
-
-         def loss_function(params):
-             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-             loss = state.loss_function(logits, targets)
-             return loss
-
-         grad_function = jax.value_and_grad(loss_function)
-         loss, grad = grad_function(state.params)
-         grad = jax.lax.pmean(grad, "batch")
-         new_state = state.apply_gradients(grads=grad)
-         # Recompute logits with the updated parameters to track training accuracy.
-         logits = new_state.apply_fn(**batch, params=new_state.params, dropout_rng=dropout_rng, train=True)[0]
-         accuracy = jnp.equal(jnp.argmax(logits, axis=-1), targets)
-         metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_function(state.step), "accuracy": accuracy}, axis_name="batch")
-         return new_state, metrics, new_dropout_rng
-
-     parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
-
-     def eval_step(state, batch):
-         targets = batch.pop('label')
-         # Take the logits from the model output, as in train_step.
-         logits = state.apply_fn(**batch, params=state.params, train=False)[0]
-         loss = state.loss_function(logits, targets)
-         predictions = state.logits_function(logits)
-         eval_accuracy = jnp.equal(predictions, targets)
-         metrics = jax.lax.pmean({"loss": loss, "accuracy": eval_accuracy}, axis_name="batch")
-         return targets, predictions, metrics
-
-     parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
-
-     def glue_train_data_loader(rng, dataset, batch_size):
-         steps_per_epoch = len(dataset) // batch_size
-         perms = jax.random.permutation(rng, len(dataset))
-         perms = perms[:steps_per_epoch * batch_size]
-         perms = perms.reshape((steps_per_epoch, batch_size))
-         for perm in perms:
-             batch = dataset[perm]
-             batch = {k: jnp.array(v) for k, v in batch.items()}
-             batch = shard(batch)
-             yield batch
-
-     def glue_eval_data_loader(dataset, batch_size):
-         for i in range(len(dataset) // batch_size):
-             batch = dataset[i * batch_size : (i + 1) * batch_size]
-             batch = {k: jnp.array(v) for k, v in batch.items()}
-             batch = shard(batch)
-             yield batch
-
-     rng = jax.random.PRNGKey(seed)
-     dropout_rngs = jax.random.split(rng, jax.local_device_count())
-
-     state = flax.jax_utils.replicate(state)
-
-     workdir = './results_tensorboard'
-     summary_writer = tensorboard.SummaryWriter(workdir)
-
-     logger.info("***** Running training *****")
-     logger.info(f"  Num examples = {len_train_dataset}")
-     logger.info(f"  Num epochs = {num_train_epochs}")
-     logger.info(f"  Instantaneous batch size per device = {per_device_batch_size}")
-     logger.info(f"  Total train batch size = {total_batch_size}")
-     logger.info(f"  Total optimization steps = {num_train_steps}")
-
-     for epoch in tqdm(range(1, num_train_epochs + 1), desc="Epoch ...", position=0, leave=True):
-         rng, input_rng = jax.random.split(rng)
-         train_acc_metrics = []
-         train_loss_metrics = []
-         eval_acc_metrics = []
-         eval_loss_metrics = []
-
-         # train
-         with tqdm(total=len_train_dataset // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
-             for idx, batch in enumerate(glue_train_data_loader(input_rng, train_dataset, total_batch_size)):
-                 state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
-                 train_acc_metrics.append(jax.device_get(train_metric['accuracy']).mean().item())
-                 train_loss_metrics.append(flax.jax_utils.unreplicate(train_metric)['loss'].item())
-                 if idx % 1 == 0:  # log every step; raise the modulus to log less often
-                     summary_writer.scalar('train_loss', flax.jax_utils.unreplicate(train_metric)['loss'].item(), idx)
-                     summary_writer.scalar('train_accuracy', jax.device_get(train_metric['accuracy']).mean().item(), idx)
-                 if idx % 1 == 0:  # push a checkpoint every step; raise the modulus to save less often
-                     if jax.process_index() == 0:
-                         params = jax.device_get(unreplicate(state.params))
-                         model.save_pretrained(
-                             '.',
-                             params=params,
-                             push_to_hub=True,
-                             commit_message=f"Saving weights of epoch {epoch} at step {idx}",
-                         )
-                 progress_bar_train.update(1)
-
-         # evaluate
-         with tqdm(total=len_validation_dataset // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
-             for idx, batch in enumerate(glue_eval_data_loader(validation_dataset, total_batch_size)):
-                 labels, predictions, eval_metric = parallel_eval_step(state, batch)
-                 eval_acc_metrics.append(jax.device_get(eval_metric['accuracy']).mean().item())
-                 eval_loss_metrics.append(flax.jax_utils.unreplicate(eval_metric)['loss'].item())
-                 progress_bar_eval.update(1)
-                 if idx % 1 == 0:
-                     logger.info(f"eval_step_loss {idx}: {flax.jax_utils.unreplicate(eval_metric)['loss'].item()} eval_step_acc {idx}: {jax.device_get(eval_metric['accuracy']).mean().item()}")
-                     summary_writer.scalar('eval_loss', flax.jax_utils.unreplicate(eval_metric)['loss'].item(), idx)
-                     summary_writer.scalar('eval_accuracy', jax.device_get(eval_metric['accuracy']).mean().item(), idx)
-
-         logger.info(f"Epoch {epoch} done")
-         logger.info(f"Train loss: {jax.device_get(jnp.array(train_loss_metrics)).mean().item()} Train accuracy: {jax.device_get(jnp.array(train_acc_metrics)).mean().item()}")
-         logger.info(f"Eval loss: {jax.device_get(jnp.array(eval_loss_metrics)).mean().item()} Eval accuracy: {jax.device_get(jnp.array(eval_acc_metrics)).mean().item()}")
-         summary_writer.flush()
-
-
- if __name__ == "__main__":
-     main()
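
Note on the deleted train.py: its subtlest piece is the optimizer, which chains two optax.masked(adamw(...)) transforms so that one parameter subset is updated with weight decay and the complement without it. Below is a minimal, self-contained sketch of that pattern under illustrative assumptions; the tiny params tree and the mask helpers are stand-ins, not the script's real GPT-2 parameters.

    import flax
    import jax.numpy as jnp
    import optax

    # Toy parameter tree standing in for model.params.
    params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}}

    def decay_mask(data):
        # True for leaves that should receive weight decay (everything but biases here).
        flat = flax.traverse_util.flatten_dict(data)
        return flax.traverse_util.unflatten_dict({k: k[-1] != "bias" for k in flat})

    def no_decay_mask(data):
        # Complement of decay_mask: True only for bias leaves.
        flat = flax.traverse_util.flatten_dict(data)
        return flax.traverse_util.unflatten_dict({k: k[-1] == "bias" for k in flat})

    # Each masked transform updates only its own subset and passes the rest
    # through unchanged, so chaining the two applies AdamW exactly once per leaf.
    tx = optax.chain(
        optax.masked(optax.adamw(1e-3, weight_decay=0.01), mask=decay_mask),
        optax.masked(optax.adamw(1e-3, weight_decay=0.0), mask=no_decay_mask),
    )

    opt_state = tx.init(params)
    grads = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.ones((2,))}}
    updates, opt_state = tx.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)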