Vivek commited on
Commit
d006373
1 Parent(s): c34dab9

added all the files

Browse files
Files changed (3) hide show
  1. app.py +67 -0
  2. model_file.py +226 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import transformers
3
+ from transformers import (
4
+ GPT2Config,
5
+ GPT2Tokenizer)
6
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2",pad_token='<|endoftext|>')
7
+ from model_file import FlaxGPT2ForMultipleChoice
8
+ import jax.numpy as jnp
9
+
10
+ st.title('GPT2 for common sense reasoning')
11
+ st.write('Multiple Choice Question Answering using CosmosQA Dataset')
12
+
13
+ context=st.text_area('Context',height=25)
14
+ st.write(context)
15
+ #context = st.text_input('Context :')
16
+
17
+
18
+
19
+
20
+
21
+ question=st.text_input('Question')
22
+
23
+
24
+ buff, col, buff2 = st.beta_columns([5,1,2])
25
+ choice_a=buff.text_input('choice 0:')
26
+ choice_b=buff.text_input('choice 1:')
27
+ choice_c=buff.text_input('choice 2:')
28
+ choice_d=buff.text_input('choice 3:')
29
+
30
+ a={}
31
+ def preprocess(context,question,choice_a,choice_b,choice_c,choice_d):
32
+ a['context&question']=context+question
33
+ a['first_sentence']=[a['context&question'],a['context&question'],a['context&question'],a['context&question']]
34
+ a['second_sentence']=choice_a,choice_b,choice_c,choice_d
35
+ return a
36
+
37
+ preprocessed_data=preprocess(context,question,choice_a,choice_b,choice_c,choice_d)
38
+
39
+ def tokenize(examples):
40
+ b=tokenizer(examples['first_sentence'],examples['second_sentence'],padding='max_length',truncation=True,max_length=256,return_tensors='jax')
41
+ return b
42
+
43
+ tokenized_data=tokenize(preprocessed_data)
44
+
45
+
46
+ model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos",input_shape=(1,4,1))
47
+
48
+ input_id=jnp.array(tokenized_data['input_ids'])
49
+ att_mask=jnp.array(tokenized_data['attention_mask'])
50
+
51
+
52
+ if st.button("Run"):
53
+ with st.spinner(text="Getting results..."):
54
+ outputs=model(input_id,att_mask)
55
+ final_output=jnp.argmax(outputs,axis=-1)
56
+ if final_output==0:
57
+ result='0'
58
+ elif final_output==1:
59
+ result='1'
60
+ elif final_output==2:
61
+ result='2'
62
+ elif final_output==3:
63
+ result='3'
64
+ st.success(f"The answer is choice {result1}")
65
+
66
+
67
+
model_file.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+ import flax
5
+ import flax.linen as nn
6
+ from flax.core.frozen_dict import FrozenDict, unfreeze
7
+
8
+ from typing import Any, Optional, Tuple
9
+
10
+ from transformers import (
11
+ GPT2Config)
12
+
13
+ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
14
+ from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2BlockCollection
15
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput
16
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
17
+ from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2Module
18
+
19
+ from transformers import GPT2Tokenizer
20
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2",pad_token='<|endoftext|>')
21
+
22
+ GPT2_START_DOCSTRING = r"""
23
+ This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
24
+ generic methods the library implements for all its model (such as downloading or saving, resizing the input
25
+ embeddings, pruning heads etc.)
26
+ This model is also a Flax Linen `flax.nn.Module
27
+ <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
28
+ Module and refer to the Flax documentation for all matter related to general usage and behavior.
29
+ Finally, this model supports inherent JAX features such as:
30
+ - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
31
+ - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
32
+ - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
33
+ - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
34
+ Parameters:
35
+ config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
36
+ Initializing with a config file does not load the weights associated with the model, only the
37
+ configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
38
+ model weights.
39
+ """
40
+ GPT2_INPUTS_DOCSTRING = r"""
41
+ Args:
42
+ input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size,input_ids_length)`):
43
+ :obj:`input_ids_length` = ``sequence_length``. Indices of input sequence tokens in the vocabulary.
44
+ Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
45
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
46
+ details.
47
+ `What are input IDs? <../glossary.html#input-ids>`__
48
+ attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
49
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
50
+ - 1 for tokens that are **not masked**,
51
+ - 0 for tokens that are **masked**.
52
+ `What are attention masks? <../glossary.html#attention-mask>`__
53
+ position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
54
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
55
+ config.max_position_embeddings - 1]``.
56
+ past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):
57
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
58
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
59
+ output_attentions (:obj:`bool`, `optional`):
60
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
61
+ tensors for more detail.
62
+ output_hidden_states (:obj:`bool`, `optional`):
63
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
64
+ more detail.
65
+ return_dict (:obj:`bool`, `optional`):
66
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
67
+ """
68
+
69
+ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel): #modify
70
+ """
71
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
72
+ models.
73
+ """
74
+
75
+ config_class = GPT2Config
76
+ base_model_prefix = "transformer"
77
+ module_class: nn.Module = None
78
+
79
+ def __init__(
80
+ self,
81
+ config: GPT2Config,
82
+ input_shape: Tuple = (1, 1),
83
+ seed: int = 0,
84
+ dtype: jnp.dtype = jnp.float32,
85
+ **kwargs,
86
+ ):
87
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
88
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
89
+
90
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
91
+ # init input tensors
92
+ input_ids = jnp.zeros(input_shape, dtype="i4")
93
+ attention_mask = jnp.ones_like(input_ids)
94
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
95
+ params_rng, dropout_rng = jax.random.split(rng)
96
+ rngs = {"params": params_rng, "dropout": dropout_rng}
97
+
98
+ return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
99
+
100
+ def init_cache(self, batch_size, max_length):
101
+ r"""
102
+ Args:
103
+ batch_size (:obj:`int`):
104
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
105
+ max_length (:obj:`int`):
106
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
107
+ cache.
108
+ """
109
+ # init input variables to retrieve cache
110
+ input_ids = jnp.ones((batch_size, max_length))
111
+ attention_mask = jnp.ones_like(input_ids)
112
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
113
+
114
+ init_variables = self.module.init(
115
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
116
+ )
117
+ return init_variables["cache"]
118
+
119
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
120
+ def __call__(
121
+ self,
122
+ input_ids,
123
+ attention_mask=None,
124
+ position_ids=None,
125
+ params: dict = None,
126
+ past_key_values: dict = None,
127
+ dropout_rng: jax.random.PRNGKey = None,
128
+ train: bool = False,
129
+ output_attentions: Optional[bool] = None,
130
+ output_hidden_states: Optional[bool] = None,
131
+ return_dict: Optional[bool] = None,
132
+ ):
133
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
134
+ output_hidden_states = (
135
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
136
+ )
137
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
138
+
139
+
140
+ if position_ids is None:
141
+ if past_key_values is not None:
142
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
143
+
144
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
145
+
146
+ if attention_mask is None:
147
+ attention_mask = jnp.ones_like(input_ids)
148
+
149
+ # Handle any PRNG if needed
150
+ rngs = {}
151
+ if dropout_rng is not None:
152
+ rngs["dropout"] = dropout_rng
153
+
154
+ inputs = {"params": params or self.params}
155
+
156
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPT2Attention module
157
+ if past_key_values:
158
+ inputs["cache"] = past_key_values
159
+ mutable = ["cache"]
160
+ else:
161
+ mutable = False
162
+
163
+ outputs = self.module.apply(
164
+ inputs,
165
+ jnp.array(input_ids, dtype="i4"),
166
+ jnp.array(attention_mask, dtype="i4"),
167
+ jnp.array(position_ids, dtype="i4"),
168
+ not train,
169
+ False,
170
+ output_attentions,
171
+ output_hidden_states,
172
+ return_dict,
173
+ rngs=rngs,
174
+ mutable=mutable,
175
+ )
176
+
177
+ # add updated cache to model output
178
+ if past_key_values is not None and return_dict:
179
+ outputs, past_key_values = outputs
180
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
181
+ return outputs
182
+ elif past_key_values is not None and not return_dict:
183
+ outputs, past_key_values = outputs
184
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
185
+
186
+ return outputs
187
+
188
+ class FlaxGPT2ForMultipleChoiceModule(nn.Module):
189
+ config:GPT2Config
190
+ dtype: jnp.dtype = jnp.float32
191
+ def setup(self):
192
+ self.transformer = FlaxGPT2Module(config=self.config, dtype=self.dtype)
193
+ self.dropout = nn.Dropout(rate=0.2)
194
+ self.classifier = nn.Dense(4, dtype=self.dtype)
195
+
196
+ def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):
197
+ batch_size = input_ids.shape[0]
198
+ rng=jax.random.PRNGKey(0)
199
+ _, dropout_rng = jax.random.split(rng)
200
+ input_ids=input_ids.reshape(4*batch_size,-1)
201
+ position_ids=position_ids.reshape(4*batch_size,-1)
202
+ attention_mask=attention_mask.reshape(4*batch_size,-1)
203
+
204
+ outputs=self.transformer(input_ids, attention_mask,position_ids,return_dict=return_dict)
205
+
206
+
207
+ hidden_states = outputs[0]
208
+ hidden_states= jnp.mean(hidden_states, axis=1)
209
+
210
+
211
+
212
+ hidden_states=hidden_states.reshape(batch_size,-1) #(32,8,768)->(32,8*768)
213
+
214
+ dropout_output = self.dropout(hidden_states,deterministic=deterministic,rng=dropout_rng)
215
+
216
+
217
+
218
+ logits = self.classifier(dropout_output)
219
+ reshaped_logits = logits.reshape(-1, 4)
220
+ #(32,4)
221
+ if not return_dict:
222
+ return (reshaped_logits,) + outputs[2:]
223
+ return reshaped_logits
224
+
225
+ class FlaxGPT2ForMultipleChoice(FlaxGPT2PreTrainedModel):
226
+ module_class = FlaxGPT2ForMultipleChoiceModule
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ jax
2
+ flax
3
+ transformers
4
+ Datasets
5
+ tqdm