Alexandru Gherghescu committed
Commit
04fbb43
1 Parent(s): 15c1815

Fix inference code

Files changed (2):
  1. inference.py +1 -1
  2. modeling_gpt1.py +10 -35
inference.py CHANGED
@@ -5,7 +5,7 @@ model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
 
 prompt = 'The mastermind behind the plan was, all along, '
-inputs = tokenizer(prompt, return_tensors='pt')
+inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)
 
 generate_ids = model.generate(inputs.input_ids,
                               max_new_tokens=40,
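For reference, a minimal end-to-end sketch of how the patched inference.py is meant to be run; the checkpoint id and the final decoding step are assumptions for illustration, not part of this diff:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical placeholder: substitute the actual model repo id for this checkpoint.
checkpoint = '<model-repo-id>'

model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

prompt = 'The mastermind behind the plan was, all along, '
# add_special_tokens=True is the one-line change made by this commit.
inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)

generate_ids = model.generate(inputs.input_ids, max_new_tokens=40)

# Assumed decoding step (the diff cuts off before this point).
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])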
modeling_gpt1.py CHANGED
@@ -149,13 +149,6 @@ class GPT1Model(GPT1PreTrainedModel):
             [GPT1DecoderLayer(config) for _ in range(config.num_hidden_layers)]
         )
 
-        causal_mask = torch.full((1, config.max_position_embeddings, config.max_position_embeddings),
-                                 fill_value=float('-inf'))
-        self.register_buffer('causal_mask',
-                             torch.triu(causal_mask, diagonal=1),
-                             persistent=False)
-        self.mask_cache_len = config.max_position_embeddings
-
         self.post_init()
 
     def get_input_embeddings(self):
@@ -164,7 +157,7 @@ class GPT1Model(GPT1PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embs = value
 
-    def forward(self, input_ids, attention_mask=None, *args, **kwargs):
+    def forward(self, input_ids, *args, **kwargs):
         position_ids = torch.arange(input_ids.size(-1),
                                     dtype=torch.long,
                                     device=input_ids.device).unsqueeze_(0)
@@ -173,18 +166,12 @@ class GPT1Model(GPT1PreTrainedModel):
         position_embeds = self.pos_emb(position_ids)
         hidden_state = self.embs_dropout(input_embeds) + position_embeds
 
-        if attention_mask is not None and attention_mask.size(1) > self.mask_cache_len:
-            seq_len = attention_mask.size(1)
-            self.mask_cache_len = seq_len
-
-            causal_mask = torch.full((seq_len, seq_len),
-                                     fill_value=float('-inf'))
-            self.register_buffer('causal_mask',
-                                 torch.triu(causal_mask, diagonal=1),
-                                 persistent=False)
+        seq_len = input_ids.size(-1)
+        attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
+        attn_mask = torch.triu(attn_mask, diagonal=1)
 
-        causal_mask = self.causal_mask.to(dtype=input_embeds.dtype,
-                                          device=input_embeds.device)
+        causal_mask = attn_mask.to(dtype=input_embeds.dtype,
+                                   device=input_embeds.device)
 
         for layer in self.layers:
             hidden_state = layer(hidden_state, attn_mask=causal_mask)
@@ -225,9 +212,8 @@ class GPT1ForCausalLM(GPT1PreTrainedModel):
     def set_decoder(self, decoder):
        self.model = decoder
 
-    def forward(self, input_ids, labels=None, attention_mask=None,
-                *args, **kwargs):
-        output = self.model(input_ids, attention_mask)
+    def forward(self, input_ids, labels=None, *args, **kwargs):
+        output = self.model(input_ids)
 
         hidden_state = output[0]
         logits = self.lm_head(hidden_state).float()
@@ -247,16 +233,5 @@ class GPT1ForCausalLM(GPT1PreTrainedModel):
             logits=logits
         )
 
-    def prepare_inputs_for_generation(self, input_ids, attention_mask,
-                                      *args, **kwargs):
-        assert attention_mask.size(1) == input_ids.size(1)
-
-        seq_len = attention_mask.size(1)
-
-        attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
-        attn_mask = torch.triu(attn_mask, diagonal=1)
-
-        return {
-            'input_ids': input_ids,
-            'attention_mask': attn_mask
-        }
+    def prepare_inputs_for_generation(self, input_ids, *args, **kwargs):
+        return { 'input_ids': input_ids }
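The core of the change in modeling_gpt1.py is that GPT1Model.forward now rebuilds the additive causal mask from input_ids on every call instead of reading a registered causal_mask buffer, which is why prepare_inputs_for_generation can shrink to returning only input_ids. A standalone sketch of that mask construction (the printed values are illustrative):

import torch

# Same construction as the new forward(): a (seq_len, seq_len) matrix with
# -inf strictly above the diagonal and 0 elsewhere, so each position can
# attend only to itself and earlier positions.
seq_len = 4
attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
attn_mask = torch.triu(attn_mask, diagonal=1)

print(attn_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])

Building the mask per call also drops the module-level buffer sized to max_position_embeddings and the mask_cache_len bookkeeping that the old code carried around.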