delete files
Browse files- results_tensorboard/events.out.tfevents.1626288298.t1v-n-8cb15980-w-0.712426.3.v2 +0 -3
- results_tensorboard/events.out.tfevents.1626289169.t1v-n-8cb15980-w-0.716987.3.v2 +0 -3
- results_tensorboard/events.out.tfevents.1626327484.t1v-n-8cb15980-w-0.750616.3.v2 +0 -3
- results_tensorboard/events.out.tfevents.1626327887.t1v-n-8cb15980-w-0.752609.3.v2 +0 -3
- results_tensorboard/events.out.tfevents.1626332656.t1v-n-8cb15980-w-0.759220.3.v2 +0 -3
- results_tensorboard/events.out.tfevents.1626333695.t1v-n-8cb15980-w-0.761814.3.v2 +0 -3
- results_tensorboard/events.out.tfevents.1626333832.t1v-n-8cb15980-w-0.763322.3.v2 +0 -3
- train.py +0 -240
results_tensorboard/events.out.tfevents.1626288298.t1v-n-8cb15980-w-0.712426.3.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:f41c2fcf7a5ad8f782c33825800d708bad20828d2790507cafaa45daf459f37d
|
3 |
-
size 588
|
|
|
|
|
|
|
|
results_tensorboard/events.out.tfevents.1626289169.t1v-n-8cb15980-w-0.716987.3.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:054d9f29e34dc409faa7d550c77e842a61072adc0c91d10b6498b98fe0307085
|
3 |
-
size 588
|
|
|
|
|
|
|
|
results_tensorboard/events.out.tfevents.1626327484.t1v-n-8cb15980-w-0.750616.3.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:68aac888ec5f8514ef03c8e2c6d595ae71a355c889643e9423ac98676c00c233
|
3 |
-
size 588
|
|
|
|
|
|
|
|
results_tensorboard/events.out.tfevents.1626327887.t1v-n-8cb15980-w-0.752609.3.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:7657f9a1a65b35a4544634bbf3dc9352eec8b111862d47d2015fba24536579f5
|
3 |
-
size 596
|
|
|
|
|
|
|
|
results_tensorboard/events.out.tfevents.1626332656.t1v-n-8cb15980-w-0.759220.3.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:781634a9c25358932557a2489094d955a8b955179a710e6f7aebfb2bdd4333dc
|
3 |
-
size 866
|
|
|
|
|
|
|
|
results_tensorboard/events.out.tfevents.1626333695.t1v-n-8cb15980-w-0.761814.3.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:a6e77accc72524145a973a76b92872232f60143270fb2a5ded477efb5c5d043e
|
3 |
-
size 176
|
|
|
|
|
|
|
|
results_tensorboard/events.out.tfevents.1626333832.t1v-n-8cb15980-w-0.763322.3.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:593046ea18ca7b10638c83f187d1f66a534142780b3b663c774efa0a7135c26c
|
3 |
-
size 456
|
|
|
|
|
|
|
|
train.py
DELETED
@@ -1,240 +0,0 @@
|
|
1 |
-
import jax
|
2 |
-
print(jax.local_device_count())
|
3 |
-
import jax.numpy as jnp
|
4 |
-
|
5 |
-
import flax
|
6 |
-
import flax.linen as nn
|
7 |
-
from flax.training.common_utils import get_metrics,onehot,shard,shard_prng_key
|
8 |
-
from flax.training import train_state
|
9 |
-
from flax.metrics.tensorboard import SummaryWriter
|
10 |
-
from flax.training import checkpoints
|
11 |
-
|
12 |
-
|
13 |
-
import logging
|
14 |
-
import optax
|
15 |
-
import math
|
16 |
-
from tqdm import tqdm
|
17 |
-
|
18 |
-
from pathlib import Path
|
19 |
-
from typing import Callable
|
20 |
-
from itertools import chain
|
21 |
-
from flax.metrics import tensorboard
|
22 |
-
|
23 |
-
from datasets import load_dataset,load_metric
|
24 |
-
from transformers import GPT2Config,GPT2Tokenizer
|
25 |
-
|
26 |
-
from model_file import FlaxGPT2ForMultipleChoice
|
27 |
-
|
28 |
-
logger = logging.getLogger()
|
29 |
-
logger.setLevel(logging.INFO)
|
30 |
-
from flax.jax_utils import unreplicate
|
31 |
-
|
32 |
-
def main():
|
33 |
-
|
34 |
-
|
35 |
-
tokenizer=GPT2Tokenizer.from_pretrained('gpt2',pad_token='<|endoftext|>')
|
36 |
-
|
37 |
-
dataset=load_dataset('cosmos_qa')
|
38 |
-
|
39 |
-
def preprocess(example):
|
40 |
-
example['context&question']=example['context']+example['question']
|
41 |
-
example['first_sentence']=[example['context&question'],example['context&question'],example['context&question'],example['context&question']]
|
42 |
-
example['second_sentence']=example['answer0'],example['answer1'],example['answer2'],example['answer3']
|
43 |
-
return example
|
44 |
-
|
45 |
-
train_dataset=dataset['train'].map(preprocess)
|
46 |
-
validation_dataset=dataset['validation'].map(preprocess)
|
47 |
-
test_dataset=dataset['test'].map(preprocess)
|
48 |
-
|
49 |
-
#Remove after experiment
|
50 |
-
len_train_dataset=100
|
51 |
-
len_validation_dataset=100
|
52 |
-
len_test_dataset=100
|
53 |
-
|
54 |
-
train_dataset=train_dataset.select(range(len_train_dataset))
|
55 |
-
test_dataset=test_dataset.select(range(len_validation_dataset))
|
56 |
-
validation_dataset=validation_dataset.select(range(len_test_dataset))
|
57 |
-
|
58 |
-
#remove_cols=train_dataset.column_names
|
59 |
-
|
60 |
-
def tokenize(examples):
|
61 |
-
a=tokenizer(examples['first_sentence'],examples['second_sentence'],padding='max_length',truncation=True,max_length=256,return_tensors='jax')
|
62 |
-
a['labels']=examples['label']
|
63 |
-
return a
|
64 |
-
|
65 |
-
train_dataset=train_dataset.map(tokenize)
|
66 |
-
validation_dataset=validation_dataset.map(tokenize)
|
67 |
-
test_dataset=test_dataset.map(tokenize)
|
68 |
-
|
69 |
-
remov_col=['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels', 'context&question', 'first_sentence', 'second_sentence']
|
70 |
-
|
71 |
-
train_dataset=train_dataset.remove_columns(remov_col)
|
72 |
-
validation_dataset=validation_dataset.remove_columns(remov_col)
|
73 |
-
test_dataset=test_dataset.remove_columns(remov_col)
|
74 |
-
|
75 |
-
per_device_batch_size=4
|
76 |
-
seed=0
|
77 |
-
num_train_epochs=1
|
78 |
-
learning_rate=2e-5
|
79 |
-
|
80 |
-
|
81 |
-
total_batch_size = per_device_batch_size * jax.local_device_count()
|
82 |
-
print('The overall batch size (both for training and eval) is', total_batch_size)
|
83 |
-
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
|
84 |
-
num_validation_steps=len(validation_dataset)//total_batch_size*num_train_epochs
|
85 |
-
|
86 |
-
learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)
|
87 |
-
|
88 |
-
class TrainState(train_state.TrainState):
|
89 |
-
logits_function:Callable=flax.struct.field(pytree_node=False)
|
90 |
-
loss_function:Callable=flax.struct.field(pytree_node=False)
|
91 |
-
|
92 |
-
def adamw(weight_decay):
|
93 |
-
return optax.adamw(learning_rate=learning_rate_function,b1=0.9,b2=0.99,eps=1e-6,weight_decay=weight_decay)
|
94 |
-
|
95 |
-
decay_path=lambda p:not any(x in p for x in ['bias','LayerNorm.weight'])
|
96 |
-
|
97 |
-
def traverse(function):
|
98 |
-
def mask(data):
|
99 |
-
flat=flax.traverse_util.flatten_dict(data)
|
100 |
-
return flax.traverse_util.unflatten_dict({k:function(k,v) for k,v in flat.items()})
|
101 |
-
return mask
|
102 |
-
|
103 |
-
gradient_transformation=optax.chain(
|
104 |
-
optax.masked(adamw(0.0),mask=traverse(lambda path,_:decay_path(path))),
|
105 |
-
optax.masked(adamw(0.01),mask=traverse(lambda path,_:not decay_path(path))))
|
106 |
-
|
107 |
-
def loss_function(logits,labels):
|
108 |
-
logits=flax.linen.log_softmax(logits)
|
109 |
-
xentropy=optax.softmax_cross_entropy(logits,onehot(labels,num_classes=4))
|
110 |
-
return jnp.mean(xentropy)
|
111 |
-
|
112 |
-
def eval_function(logits):
|
113 |
-
return logits.argmax(-1)
|
114 |
-
|
115 |
-
model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2',input_shape=(1,4,1))
|
116 |
-
|
117 |
-
state=TrainState.create(apply_fn=model.__call__,
|
118 |
-
params=model.params,
|
119 |
-
tx=gradient_transformation,
|
120 |
-
logits_function=eval_function,
|
121 |
-
loss_function=loss_function)
|
122 |
-
|
123 |
-
def train_step(state,batch,dropout_rng):
|
124 |
-
targets=batch.pop("label")
|
125 |
-
dropout_rng,new_dropout_rng=jax.random.split(dropout_rng)
|
126 |
-
def loss_function(params):
|
127 |
-
logits=state.apply_fn(**batch,params=params,dropout_rng=dropout_rng,train=True)[0]
|
128 |
-
loss=state.loss_function(logits,targets)
|
129 |
-
return loss
|
130 |
-
grad_function=jax.value_and_grad(loss_function)
|
131 |
-
loss,grad=grad_function(state.params)
|
132 |
-
grad=jax.lax.pmean(grad,"batch")
|
133 |
-
new_state=state.apply_gradients(grads=grad)
|
134 |
-
#Added.
|
135 |
-
logits=new_state.apply_fn(**batch,params=new_state.params,dropout_rng=dropout_rng,train=True)[0]
|
136 |
-
accuracy=jnp.equal(jnp.argmax(logits,axis=-1),targets)
|
137 |
-
metrics=jax.lax.pmean({"loss":loss,"learning_rate":learning_rate_function(state.step),'accuracy':accuracy},axis_name="batch")
|
138 |
-
return new_state,metrics,new_dropout_rng
|
139 |
-
|
140 |
-
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
|
141 |
-
|
142 |
-
def eval_step(state, batch):
|
143 |
-
targets=batch.pop('label')
|
144 |
-
logits = state.apply_fn(**batch, params=state.params, train=False)
|
145 |
-
loss=state.loss_function(logits,targets)
|
146 |
-
predictions=state.logits_function(logits)
|
147 |
-
eval_accuracy=jnp.equal(predictions,targets)
|
148 |
-
#eval_acc=jnp.equal(predictions,targets)
|
149 |
-
metrics=jax.lax.pmean({"loss":loss,'accuracy':eval_accuracy},axis_name="batch")
|
150 |
-
#return state.logits_function(logits) #(8,4)
|
151 |
-
return targets,predictions,metrics
|
152 |
-
|
153 |
-
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
|
154 |
-
|
155 |
-
def glue_train_data_loader(rng,dataset,batch_size):
|
156 |
-
steps_per_epoch=len_train_dataset//batch_size
|
157 |
-
perms=jax.random.permutation(rng,len(dataset))
|
158 |
-
perms=perms[:steps_per_epoch*batch_size]
|
159 |
-
perms=perms.reshape((steps_per_epoch,batch_size))
|
160 |
-
for perm in perms:
|
161 |
-
batch=dataset[perm]
|
162 |
-
batch={k:jnp.array(v) for k,v in batch.items()}
|
163 |
-
batch=shard(batch)
|
164 |
-
yield batch
|
165 |
-
|
166 |
-
rng=jax.random.PRNGKey(seed)
|
167 |
-
dropout_rngs=jax.random.split(rng,jax.local_device_count())
|
168 |
-
|
169 |
-
def glue_eval_data_loader(dataset, batch_size):
|
170 |
-
for i in range(len_validation_dataset // batch_size):
|
171 |
-
batch = dataset[i * batch_size : (i + 1) * batch_size]
|
172 |
-
batch = {k: jnp.array(v) for k, v in batch.items()}
|
173 |
-
batch = shard(batch)
|
174 |
-
|
175 |
-
yield batch
|
176 |
-
|
177 |
-
state = flax.jax_utils.replicate(state)
|
178 |
-
#metrics_list = list_metrics()
|
179 |
-
|
180 |
-
actual_task = "mnli"
|
181 |
-
metric = load_metric('glue', "mnli")
|
182 |
-
actual_taskmetric = load_metric('glue', actual_task)
|
183 |
-
|
184 |
-
workdir='./results_tensorboard'
|
185 |
-
summary_writer = tensorboard.SummaryWriter(workdir)
|
186 |
-
#summary_writer.hparams(dict(GPT2Config()))
|
187 |
-
|
188 |
-
logger.info(f"***** Running training *****")
|
189 |
-
logger.info(f" Num examples = {len_train_dataset}")
|
190 |
-
logger.info(f" Num Epochs = {1}")
|
191 |
-
logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
|
192 |
-
logger.info(f" Total train batch size = {total_batch_size}")
|
193 |
-
logger.info(f" Total optimization steps = {num_train_steps}")
|
194 |
-
|
195 |
-
for i, epoch in enumerate(tqdm(range(1, 2), desc=f"Epoch ...", position=0, leave=True)):
|
196 |
-
rng, input_rng = jax.random.split(rng)
|
197 |
-
train_acc_metrics=[]
|
198 |
-
train_loss_metrics=[]
|
199 |
-
eval_acc_metrics=[]
|
200 |
-
eval_loss_metrics=[]
|
201 |
-
# train
|
202 |
-
with tqdm(total=len_train_dataset // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
|
203 |
-
for idx,batch in enumerate(glue_train_data_loader(input_rng, train_dataset, total_batch_size)):
|
204 |
-
state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
|
205 |
-
train_acc_metrics.append(jax.device_get(train_metric['accuracy']).mean().item())
|
206 |
-
train_loss_metrics.append(flax.jax_utils.unreplicate(train_metric)['loss'].item())
|
207 |
-
if idx%1==0:
|
208 |
-
summary_writer.scalar('train_loss',flax.jax_utils.unreplicate(train_metric)['loss'].item(),idx)
|
209 |
-
summary_writer.scalar('train_accuracy', jax.device_get(train_metric['accuracy']).mean().item(),idx)
|
210 |
-
if idx%1==0:
|
211 |
-
# params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
212 |
-
if jax.process_index() == 0:
|
213 |
-
params = jax.device_get(unreplicate(state.params))
|
214 |
-
model.save_pretrained(
|
215 |
-
'.',
|
216 |
-
params=state.params,
|
217 |
-
push_to_hub=True,
|
218 |
-
commit_message=f"Saving weights of epoch {epoch} at step {idx}",)
|
219 |
-
progress_bar_train.update(1)
|
220 |
-
|
221 |
-
# evaluate
|
222 |
-
with tqdm(total=len_validation_dataset // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
|
223 |
-
for idx,batch in enumerate(glue_eval_data_loader(validation_dataset, total_batch_size)):
|
224 |
-
labels,predictions,eval_metric=parallel_eval_step(state, batch)
|
225 |
-
eval_acc_metrics.append(jax.device_get(eval_metric['accuracy']).mean().item())
|
226 |
-
eval_loss_metrics.append(flax.jax_utils.unreplicate(eval_metric)['loss'].item())
|
227 |
-
progress_bar_eval.update(1)
|
228 |
-
if idx%1==0:
|
229 |
-
logger.info(f"eval_step_loss{idx}:{flax.jax_utils.unreplicate(eval_metric)['loss'].item()} eval_step_acc{idx}:{jax.device_get(eval_metric['accuracy']).mean().item()}")
|
230 |
-
summary_writer.scalar('eval_loss',flax.jax_utils.unreplicate(eval_metric)['loss'].item(),idx)
|
231 |
-
summary_writer.scalar('eval_accuracy', jax.device_get(eval_metric['accuracy']).mean().item(),idx)
|
232 |
-
|
233 |
-
#correct
|
234 |
-
logger.info(f"Epoch {epoch} done")
|
235 |
-
logger.info(f"Train loss:{jax.device_get(jnp.array(train_loss_metrics)).mean().item()} Train accuracy:{jax.device_get(jnp.array(train_acc_metrics)).mean().item()}")
|
236 |
-
logger.info(f"Eval loss:{jax.device_get(jnp.array(eval_loss_metrics)).mean().item()} Eval accuracy:{jax.device_get(jnp.array(eval_acc_metrics)).mean().item()}")
|
237 |
-
summary_writer.flush()
|
238 |
-
|
239 |
-
if __name__ == "__main__":
|
240 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|