Vivek commited on
Commit
7b062dd
1 Parent(s): bce676a

slight update

Browse files
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. src/gptneo_piqa.py +234 -0
  3. src/model_file.py +209 -0
  4. src/requirements.txt +7 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
src/gptneo_piqa.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ print(jax.local_device_count())
3
+ import jax.numpy as jnp
4
+
5
+ import flax
6
+ import flax.linen as nn
7
+ from flax.training.common_utils import get_metrics,onehot,shard,shard_prng_key
8
+ from flax.training import train_state
9
+ from flax.metrics.tensorboard import SummaryWriter
10
+ from flax.training import checkpoints
11
+
12
+ from datasets import load_dataset,load_metric
13
+ from transformers import GPT2Tokenizer
14
+
15
+ from tqdm import tqdm
16
+
17
+ import logging
18
+ import optax
19
+ import math
20
+ from pathlib import Path
21
+ from typing import Callable
22
+ from itertools import chain
23
+ from flax.metrics import tensorboard
24
+ from datasets import load_dataset,load_metric
25
+
26
+ from transformers import GPTNeoConfig,GPT2Tokenizer
27
+
28
+ from model_file import FlaxGPTNeoForMultipleChoice
29
+
30
+ logger = logging.getLogger()
31
+ logger.setLevel(logging.INFO)
32
+
33
+
34
+ tokenizer=GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-1.3B',pad_token='<|endoftext|>')
35
+
36
+ dataset=load_dataset('piqa')
37
+ num_choices=2
38
+
39
+ def preprocess(example):
40
+ example['first_sentence']=[example['goal']]*num_choices
41
+ example['second_sentence']=[example[f'sol{i}'] for i in [1,2]]
42
+ return example
43
+
44
+ train_dataset=dataset['train'].map(preprocess)
45
+ validation_dataset=dataset['validation'].map(preprocess)
46
+ test_dataset=dataset['test'].map(preprocess)
47
+
48
+ len_train_dataset=16113
49
+ len_validation_dataset=1838
50
+ len_test_dataset=3084
51
+
52
+ train_dataset=train_dataset.select(range(len_train_dataset))
53
+ test_dataset=test_dataset.select(range(len_test_dataset))
54
+ validation_dataset=validation_dataset.select(range(len_validation_dataset))
55
+
56
+ remove_col=train_dataset.column_names
57
+
58
+ def tokenize(examples):
59
+ tokenized_examples=tokenizer(examples['first_sentence'],examples['second_sentence'],padding='max_length',truncation=True,max_length=256,return_tensors='jax')
60
+ tokenized_examples['labels']=int(examples['label'])
61
+ return tokenized_examples
62
+
63
+ train_dataset=train_dataset.map(tokenize)
64
+ validation_dataset=validation_dataset.map(tokenize)
65
+
66
+ train_dataset=train_dataset.remove_columns(remove_col)
67
+ validation_dataset=validation_dataset.remove_columns(remove_col)
68
+ test_dataset=test_dataset.remove_columns(remove_col)
69
+
70
+ per_device_batch_size=2
71
+ seed=0
72
+ num_train_epochs=3
73
+ learning_rate=2e-5
74
+
75
+ model = FlaxGPTNeoForMultipleChoice.from_pretrained('EleutherAI/gpt-neo-1.3B',input_shape=(1,num_choices,1))
76
+
77
+ total_batch_size = per_device_batch_size * jax.local_device_count()
78
+ print('The overall batch size (both for training and eval) is', total_batch_size)
79
+ num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
80
+ num_validation_steps=len(validation_dataset)//total_batch_size*num_train_epochs
81
+
82
+ learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=3e-7, transition_steps=num_train_steps)
83
+
84
+ class TrainState(train_state.TrainState):
85
+ logits_function:Callable=flax.struct.field(pytree_node=False)
86
+ loss_function:Callable=flax.struct.field(pytree_node=False)
87
+
88
+ def adamw(weight_decay):
89
+ return optax.adafactor(learning_rate=learning_rate_function)
90
+
91
+ decay_path=lambda p:not any(x in p for x in ['bias','LayerNorm.weight'])
92
+
93
+ def traverse(function):
94
+ def mask(data):
95
+ flat=flax.traverse_util.flatten_dict(data)
96
+ return flax.traverse_util.unflatten_dict({k:function(k,v) for k,v in flat.items()})
97
+ return mask
98
+ gradient_transformation=optax.chain(
99
+ optax.masked(adamw(0.0),mask=traverse(lambda path,_:decay_path(path))),
100
+ optax.masked(adamw(0.01),mask=traverse(lambda path,_:not decay_path(path))))
101
+
102
+ def loss_function(logits,labels):
103
+ logits=flax.linen.log_softmax(logits)
104
+ xentropy=optax.softmax_cross_entropy(logits,onehot(labels,num_classes=num_choices))
105
+ return jnp.mean(xentropy)
106
+
107
+ def eval_function(logits):
108
+ return logits.argmax(-1)
109
+
110
+ state=TrainState.create(apply_fn=model.__call__,
111
+ params=model.params,
112
+ tx=gradient_transformation,
113
+ logits_function=eval_function,
114
+ loss_function=loss_function)
115
+
116
+ def train_step(state,batch,dropout_rng):
117
+ targets=batch.pop("labels")
118
+ dropout_rng,new_dropout_rng=jax.random.split(dropout_rng)
119
+ def loss_function(params):
120
+ logits=state.apply_fn(**batch,params=params,dropout_rng=dropout_rng,train=True)[0]
121
+ loss=state.loss_function(logits,targets)
122
+ return loss
123
+ grad_function=jax.value_and_grad(loss_function)
124
+ loss,grad=grad_function(state.params)
125
+ grad=jax.lax.pmean(grad,"batch")
126
+ new_state=state.apply_gradients(grads=grad)
127
+ #Added.
128
+ logits=new_state.apply_fn(**batch,params=new_state.params,dropout_rng=dropout_rng,train=True)[0]
129
+ accuracy=jnp.equal(jnp.argmax(logits,axis=-1),targets)
130
+ #metrics=jax.lax.pmean({"loss":loss,"learning_rate":learning_rate_function(state.step),'accuracy':accuracy},axis_name="batch")
131
+ metrics=jax.lax.pmean({"loss":jax.device_get(loss),"learning_rate":jax.device_get(learning_rate_function(state.step)),'accuracy':jax.device_get(accuracy)},axis_name="batch")
132
+ return new_state,metrics,new_dropout_rng
133
+
134
+ parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
135
+
136
+ def eval_step(state, batch):
137
+ targets=batch.pop('labels')
138
+ logits = state.apply_fn(**batch, params=state.params, train=False)
139
+ loss=state.loss_function(logits,targets)
140
+ predictions=state.logits_function(logits)
141
+ eval_accuracy=jnp.equal(predictions,targets)
142
+ #eval_acc=jnp.equal(predictions,targets)
143
+ #metrics=jax.lax.pmean({"loss":loss,'accuracy':eval_accuracy},axis_name="batch")
144
+ metrics=jax.lax.pmean({"loss":jax.device_get(loss),'accuracy':jax.device_get(eval_accuracy)},axis_name="batch")
145
+ #return state.logits_function(logits) #(8,4)
146
+ return targets,predictions,metrics
147
+
148
+ parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
149
+
150
+ def glue_train_data_loader(rng,dataset,batch_size):
151
+ steps_per_epoch=len_train_dataset//batch_size
152
+ perms=jax.random.permutation(rng,len_train_dataset)
153
+ perms=perms[:steps_per_epoch*batch_size]
154
+ perms=perms.reshape((steps_per_epoch,batch_size))
155
+ for perm in perms:
156
+ batch=dataset[perm]
157
+ #print(jnp.array(batch['label']))
158
+ batch={k:jnp.array(v) for k,v in batch.items()}
159
+ batch=shard(batch)
160
+ yield batch
161
+
162
+ rng=jax.random.PRNGKey(seed)
163
+ dropout_rngs=jax.random.split(rng,jax.local_device_count())
164
+
165
+ def glue_eval_data_loader(dataset, batch_size):
166
+ for i in range(len_validation_dataset // batch_size):
167
+ batch = dataset[i * batch_size : (i + 1) * batch_size]
168
+ batch = {k: jnp.array(v) for k, v in batch.items()}
169
+ batch = shard(batch)
170
+ yield batch
171
+
172
+ state = flax.jax_utils.replicate(state)
173
+
174
+ actual_task = "mnli"
175
+ metric = load_metric('glue', "mnli")
176
+ actual_taskmetric = load_metric('glue', actual_task)
177
+
178
+ workdir='../results_tensorboard'
179
+ summary_writer = tensorboard.SummaryWriter(workdir)
180
+
181
+ logger.info(f"***** Running training *****")
182
+ logger.info(f" Num examples = {len_train_dataset}")
183
+ logger.info(f" Num Epochs = {num_train_epochs}")
184
+ logger.info(f" Instantaneous batch size per device = {per_device_batch_size}")
185
+ logger.info(f" Total train batch size = {total_batch_size}")
186
+ logger.info(f" Total optimization steps = {num_train_steps}")
187
+
188
+ for i, epoch in enumerate(tqdm(range(1, num_train_epochs+1), desc=f"Epoch ...", position=0, leave=True)):
189
+ rng, input_rng = jax.random.split(rng)
190
+ train_acc_metrics=[]
191
+ train_loss_metrics=[]
192
+ eval_acc_metrics=[]
193
+ eval_loss_metrics=[]
194
+ # train
195
+ with tqdm(total=len_train_dataset // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
196
+ for idx,batch in enumerate(glue_train_data_loader(input_rng, train_dataset, total_batch_size)):
197
+ state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
198
+ train_acc_metrics.append(jax.device_get(train_metric['accuracy']).mean().item())
199
+ train_loss_metrics.append(flax.jax_utils.unreplicate(train_metric)['loss'].item())
200
+ if idx%5==0:
201
+ summary_writer.scalar('train_loss',flax.jax_utils.unreplicate(train_metric)['loss'].item(),idx)
202
+ summary_writer.scalar('train_accuracy', jax.device_get(train_metric['accuracy']).mean().item(),idx)
203
+ if idx%20==0:
204
+ logger.info(f"train_step_loss{idx}: {flax.jax_utils.unreplicate(train_metric)['loss'].item()} train_step_acc{idx}: {jax.device_get(train_metric['accuracy']).mean().item()} ")
205
+
206
+ progress_bar_train.update(1)
207
+
208
+ # evaluate
209
+ with tqdm(total=len_validation_dataset // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
210
+ for idx,batch in enumerate(glue_eval_data_loader(validation_dataset, total_batch_size)):
211
+ labels,predictions,eval_metric=parallel_eval_step(state, batch)
212
+ eval_acc_metrics.append(jax.device_get(eval_metric['accuracy']).mean().item())
213
+ eval_loss_metrics.append(flax.jax_utils.unreplicate(eval_metric)['loss'].item())
214
+ progress_bar_eval.update(1)
215
+ if idx%5==0:
216
+ logger.info(f"eval_step_loss {idx} : {flax.jax_utils.unreplicate(eval_metric)['loss'].item()} eval_step_acc {idx} : {jax.device_get(eval_metric['accuracy']).mean().item()}")
217
+ summary_writer.scalar('eval_loss : ', flax.jax_utils.unreplicate(eval_metric)['loss'].item(),idx)
218
+ summary_writer.scalar('eval_accuracy : ', jax.device_get(eval_metric['accuracy']).mean().item(),idx)
219
+
220
+ logger.info(f"---------------------Epoch {epoch} done-----------------")
221
+ logger.info(f"Train loss: {jax.device_get(jnp.array(train_loss_metrics)).mean().item()} Train accuracy: {jax.device_get(jnp.array(train_acc_metrics)).mean().item()}")
222
+ logger.info(f"Eval loss: {jax.device_get(jnp.array(eval_loss_metrics)).mean().item()} Eval accuracy: {jax.device_get(jnp.array(eval_acc_metrics)).mean().item()}")
223
+
224
+ if jax.process_index() == 0:
225
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
226
+
227
+ model.save_pretrained(
228
+ '../',
229
+ params=params,
230
+ push_to_hub=True,
231
+ commit_message=f"Piqa:Saving weights of epoch {epoch} at step {idx}",)
232
+
233
+ summary_writer.flush()
234
+
src/model_file.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+ import flax
5
+ import flax.linen as nn
6
+ from flax.core.frozen_dict import FrozenDict, unfreeze
7
+
8
+ from pathlib import Path
9
+ from typing import Callable
10
+ from itertools import chain
11
+ from typing import Any, Optional, Tuple
12
+ from flax.core.frozen_dict import FrozenDict, unfreeze
13
+
14
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
15
+ from transformers import GPTNeoConfig,GPT2Tokenizer,file_utils
16
+ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
17
+ from transformers.models.gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoBlockCollection
18
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput
19
+ from transformers.models.gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoModule
20
+ from transformers.models.gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoPreTrainedModel
21
+
22
+ num_choice=2
23
+
24
+ tokenizer=GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-1.3B',pad_token='<|endoftext|>')
25
+
26
+ GPT_NEO_START_DOCSTRING = r"""
27
+ This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
28
+ generic methods the library implements for all its model (such as downloading or saving, resizing the input
29
+ embeddings, pruning heads etc.)
30
+ This model is also a Flax Linen `flax.nn.Module
31
+ <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
32
+ Module and refer to the Flax documentation for all matter related to general usage and behavior.
33
+ Finally, this model supports inherent JAX features such as:
34
+ - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
35
+ - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
36
+ - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
37
+ - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
38
+ Parameters:
39
+ config (:class:`~transformers.GPTNeoConfig`): Model configuration class with all the parameters of the model.
40
+ Initializing with a config file does not load the weights associated with the model, only the
41
+ configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
42
+ model weights.
43
+ """
44
+
45
+ GPT_NEO_INPUTS_DOCSTRING = r"""
46
+ Args:
47
+ input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, input_ids_length)`):
48
+ :obj:`input_ids_length` = ``sequence_length``. Indices of input sequence tokens in the vocabulary.
49
+ Indices can be obtained using :class:`~transformers.GPTNeoTokenizer`. See
50
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
51
+ details.
52
+ `What are input IDs? <../glossary.html#input-ids>`__
53
+ attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
54
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
55
+ - 1 for tokens that are **not masked**,
56
+ - 0 for tokens that are **masked**.
57
+ `What are attention masks? <../glossary.html#attention-mask>`__
58
+ position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
59
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
60
+ config.max_position_embeddings - 1]``.
61
+ past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):
62
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
63
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
64
+ output_attentions (:obj:`bool`, `optional`):
65
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
66
+ tensors for more detail.
67
+ output_hidden_states (:obj:`bool`, `optional`):
68
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
69
+ more detail.
70
+ return_dict (:obj:`bool`, `optional`):
71
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
72
+ """
73
+
74
+ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel): #modify
75
+ """
76
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
77
+ models.
78
+ """
79
+ config_class = GPTNeoConfig
80
+ base_model_prefix = "transformer"
81
+ module_class: nn.Module = None
82
+ def __init__(
83
+ self,
84
+ config: GPTNeoConfig,
85
+ input_shape: Tuple = (1, 1),
86
+ seed: int = 0,
87
+ dtype: jnp.dtype = jnp.float32,
88
+ **kwargs,
89
+ ):
90
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
91
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
92
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
93
+ # init input tensors
94
+ input_ids = jnp.zeros(input_shape, dtype="i4")
95
+ attention_mask = jnp.ones_like(input_ids)
96
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
97
+ params_rng, dropout_rng = jax.random.split(rng)
98
+ rngs = {"params": params_rng, "dropout": dropout_rng}
99
+ return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
100
+ def init_cache(self, batch_size, max_length):
101
+ r"""
102
+ Args:
103
+ batch_size (:obj:`int`):
104
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
105
+ max_length (:obj:`int`):
106
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
107
+ cache.
108
+ """
109
+ # init input variables to retrieve cache
110
+ input_ids = jnp.ones((batch_size, max_length))
111
+ attention_mask = jnp.ones_like(input_ids)
112
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
113
+ init_variables = self.module.init(
114
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
115
+ )
116
+ return init_variables["cache"]
117
+ @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
118
+ def __call__(
119
+ self,
120
+ input_ids,
121
+ attention_mask=None,
122
+ position_ids=None,
123
+ params: dict = None,
124
+ past_key_values: dict = None,
125
+ dropout_rng: jax.random.PRNGKey = None,
126
+ train: bool = False,
127
+ output_attentions: Optional[bool] = None,
128
+ output_hidden_states: Optional[bool] = None,
129
+ return_dict: Optional[bool] = None,
130
+ ):
131
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
132
+ output_hidden_states = (
133
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
134
+ )
135
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
136
+
137
+ if position_ids is None:
138
+ if past_key_values is not None:
139
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
140
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
141
+ if attention_mask is None:
142
+ attention_mask = jnp.ones_like(input_ids)
143
+ # Handle any PRNG if needed
144
+ rngs = {}
145
+ if dropout_rng is not None:
146
+ rngs["dropout"] = dropout_rng
147
+ inputs = {"params": params or self.params}
148
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPT2Attention module
149
+ if past_key_values:
150
+ inputs["cache"] = past_key_values
151
+ mutable = ["cache"]
152
+ else:
153
+ mutable = False
154
+ outputs = self.module.apply(
155
+ inputs,
156
+ jnp.array(input_ids, dtype="i4"),
157
+ jnp.array(attention_mask, dtype="i4"),
158
+ jnp.array(position_ids, dtype="i4"),
159
+ not train,
160
+ False,
161
+ output_attentions,
162
+ output_hidden_states,
163
+ return_dict,
164
+ rngs=rngs,
165
+ mutable=mutable,
166
+ )
167
+ # add updated cache to model output
168
+ if past_key_values is not None and return_dict:
169
+ outputs, past_key_values = outputs
170
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
171
+ return outputs
172
+ elif past_key_values is not None and not return_dict:
173
+ outputs, past_key_values = outputs
174
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
175
+ return outputs
176
+
177
+ class FlaxGPTNeoForMultipleChoiceModule(nn.Module):
178
+ config:GPTNeoConfig
179
+ dtype: jnp.dtype = jnp.float32
180
+ def setup(self):
181
+ self.transformer = FlaxGPTNeoModule(config=self.config, dtype=self.dtype)
182
+ self.dropout = nn.Dropout(rate=0.2)
183
+ self.classifier = nn.Dense(num_choice, dtype=self.dtype)
184
+ def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):
185
+ batch_size = input_ids.shape[0]
186
+ rng=jax.random.PRNGKey(0)
187
+ _, dropout_rng = jax.random.split(rng)
188
+ input_ids=input_ids.reshape(num_choice*batch_size,-1)
189
+ position_ids=position_ids.reshape(num_choice*batch_size,-1)
190
+ attention_mask=attention_mask.reshape(num_choice*batch_size,-1)
191
+ outputs=self.transformer(input_ids, attention_mask,position_ids,return_dict=return_dict)
192
+
193
+ hidden_states = outputs[0]
194
+ hidden_states= jnp.mean(hidden_states, axis=1)
195
+
196
+
197
+ hidden_states=hidden_states.reshape(batch_size,-1) #(32,8,768)->(32,8*768)
198
+ dropout_output = self.dropout(hidden_states,deterministic=deterministic,rng=dropout_rng)
199
+
200
+ logits = self.classifier(dropout_output)
201
+ reshaped_logits = logits.reshape(-1, num_choice)
202
+ #(32,4)
203
+ if not return_dict:
204
+ return (reshaped_logits,) + outputs[2:]
205
+ return reshaped_logits
206
+
207
+ class FlaxGPTNeoForMultipleChoice(FlaxGPTNeoPreTrainedModel):
208
+ module_class = FlaxGPTNeoForMultipleChoiceModule
209
+
src/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ jax
2
+ flax
3
+ transformers
4
+ Datasets
5
+ tqdm
6
+ tensorflow
7
+ sklearn