Files changed (1)
  1. modeling_backpack_gpt2.py +42 -2
modeling_backpack_gpt2.py CHANGED
@@ -153,7 +153,7 @@ class BackpackGPT2Model(BackpackGPT2PreTrainedModel):
     def get_sense_network(self):
         return self.sense_network
 
-    def forward(self, input_ids, position_ids):
+    def forward(self, input_ids, position_ids, **kwargs):
         # Compute senses
         sense_input_embeds = self.word_embeddings(input_ids)
         senses = self.sense_network(sense_input_embeds) # (bs, nv, s, d)
@@ -205,8 +205,48 @@ class BackpackGPT2LMHeadModel(BackpackGPT2PreTrainedModel):
 
     def get_lm_head(self):
         return self.lm_head
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None, **kwargs):
+        # prepare_inputs_for_generation has to be overridden for generation to work;
+        # adapted from GPT2LMHeadModel: https://github.com/huggingface/transformers/blob/d533465150532b0c5de167b574e59f64c68b1154/src/transformers/models/gpt2/modeling_gpt2.py#L1007C4-L1007C4
+
+        token_type_ids = kwargs.get("token_type_ids", None)
+        # only keep the last token of input_ids if past_key_values is defined
+        if past_key_values:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs
 
-    def forward(self, input_ids, position_ids=None):
+    def forward(self, input_ids, position_ids=None, **kwargs):
         outputs = self.backpack(input_ids, position_ids=position_ids)
         hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
         lm_logits = self.lm_head(hidden_states) # (bs, s, V)
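
With prepare_inputs_for_generation defined and both forward methods tolerating extra keyword arguments, the model can be driven by the standard generate() loop. A minimal sketch of that usage, assuming the updated modeling file ships with a Hub checkpoint loaded via trust_remote_code; the repo id stanfordnlp/backpack-gpt2, the prompt, and the generation settings below are illustrative assumptions, not part of this diff:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "stanfordnlp/backpack-gpt2"  # assumed checkpoint bundling this modeling file

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

prompt = "The Backpack language model is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    # generate() calls prepare_inputs_for_generation at each step and passes the
    # resulting dict to forward(), whose **kwargs absorbs the extra keys.
    output_ids = model.generate(input_ids, max_new_tokens=20, do_sample=False)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))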
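
To see what the override hands to forward() on an incremental decoding step, it can also be called directly. The token ids and the truthy cache stand-in below are made up for illustration, and the snippet reuses the model object from the previous sketch:

import torch

input_ids = torch.tensor([[15496, 11, 995]])  # (bs=1, s=3), arbitrary token ids
dummy_past = object()                         # any truthy value triggers the "past is defined" branch

step_inputs = model.prepare_inputs_for_generation(input_ids, past_key_values=dummy_past)

print(step_inputs["input_ids"].shape)                # torch.Size([1, 1]): only the newest token is re-fed
print(step_inputs["position_ids"])                   # None: no attention_mask was supplied to derive positions from
print(step_inputs["past_key_values"] is dummy_past)  # True: the cache object is passed through untouched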