Alexandru Gherghescu committed
Commit 04fbb43
Parent(s): 15c1815

Fix inference code

Files changed:
- inference.py: +1 -1
- modeling_gpt1.py: +10 -35
inference.py
CHANGED
@@ -5,7 +5,7 @@ model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
 
 prompt = 'The mastermind behind the plan was, all along, '
-inputs = tokenizer(prompt, return_tensors='pt')
+inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)
 
 generate_ids = model.generate(inputs.input_ids,
                               max_new_tokens=40,
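For reference, a minimal end-to-end sketch of inference.py after this fix might look as follows. Only part of the file is visible in the hunk above, so the checkpoint id and the final decoding step below are assumptions, not part of the commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint id; the real value is defined earlier in inference.py.
checkpoint = 'gpt1'

model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

prompt = 'The mastermind behind the plan was, all along, '
# The fix: request special tokens explicitly when tokenizing the prompt.
inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)

generate_ids = model.generate(inputs.input_ids,
                              max_new_tokens=40)

# Assumed decoding step (not shown in the diff).
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])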
modeling_gpt1.py
CHANGED
@@ -149,13 +149,6 @@ class GPT1Model(GPT1PreTrainedModel):
             [GPT1DecoderLayer(config) for _ in range(config.num_hidden_layers)]
         )
 
-        causal_mask = torch.full((1, config.max_position_embeddings, config.max_position_embeddings),
-                                 fill_value=float('-inf'))
-        self.register_buffer('causal_mask',
-                             torch.triu(causal_mask, diagonal=1),
-                             persistent=False)
-        self.mask_cache_len = config.max_position_embeddings
-
         self.post_init()
 
     def get_input_embeddings(self):
@@ -164,7 +157,7 @@ class GPT1Model(GPT1PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embs = value
 
-    def forward(self, input_ids,
+    def forward(self, input_ids, *args, **kwargs):
         position_ids = torch.arange(input_ids.size(-1),
                                     dtype=torch.long,
                                     device=input_ids.device).unsqueeze_(0)
@@ -173,18 +166,12 @@ class GPT1Model(GPT1PreTrainedModel):
         position_embeds = self.pos_emb(position_ids)
         hidden_state = self.embs_dropout(input_embeds) + position_embeds
 
-
-
-
-
-        causal_mask = torch.full((seq_len, seq_len),
-                                 fill_value=float('-inf'))
-        self.register_buffer('causal_mask',
-                             torch.triu(causal_mask, diagonal=1),
-                             persistent=False)
+        seq_len = input_ids.size(-1)
+        attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
+        attn_mask = torch.triu(attn_mask, diagonal=1)
 
-        causal_mask =
-
+        causal_mask = attn_mask.to(dtype=input_embeds.dtype,
+                                   device=input_embeds.device)
 
         for layer in self.layers:
             hidden_state = layer(hidden_state, attn_mask=causal_mask)
@@ -225,9 +212,8 @@ class GPT1ForCausalLM(GPT1PreTrainedModel):
     def set_decoder(self, decoder):
         self.model = decoder
 
-    def forward(self, input_ids, labels=None,
-
-        output = self.model(input_ids, attention_mask)
+    def forward(self, input_ids, labels=None, *args, **kwargs):
+        output = self.model(input_ids)
 
         hidden_state = output[0]
         logits = self.lm_head(hidden_state).float()
@@ -247,16 +233,5 @@ class GPT1ForCausalLM(GPT1PreTrainedModel):
             logits=logits
         )
 
-    def prepare_inputs_for_generation(self, input_ids,
-
-        assert attention_mask.size(1) == input_ids.size(1)
-
-        seq_len = attention_mask.size(1)
-
-        attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
-        attn_mask = torch.triu(attn_mask, diagonal=1)
-
-        return {
-            'input_ids': input_ids,
-            'attention_mask': attn_mask
-        }
+    def prepare_inputs_for_generation(self, input_ids, *args, **kwargs):
+        return { 'input_ids': input_ids }
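The main change in modeling_gpt1.py is that GPT1Model.forward now builds the causal mask from the actual sequence length on every call, instead of registering a fixed buffer of size config.max_position_embeddings in __init__ and re-registering it inside forward, which also drops the old mask_cache_len bookkeeping. A small standalone sketch of that construction, separate from the model code, with a toy sequence length chosen purely for illustration:

import torch

# Build an additive causal mask the same way the new forward() does:
# zeros on and below the diagonal, -inf strictly above it.
seq_len = 5  # illustrative; forward() uses input_ids.size(-1)
attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
attn_mask = torch.triu(attn_mask, diagonal=1)

# Adding the mask to raw attention scores before softmax removes attention
# to future positions: row i attends only to positions 0..i.
scores = torch.zeros(seq_len, seq_len)  # dummy scores, for illustration
probs = torch.softmax(scores + attn_mask, dim=-1)
print(probs)  # row i is uniform over the first i+1 columns, 0 elsewhere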
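Likewise, prepare_inputs_for_generation now returns only the input ids, since the model rebuilds its own causal mask inside forward; the removed version asserted on and rebuilt an attention_mask that the new forward ignores. A simplified, hypothetical greedy-decoding loop showing how a generate()-style caller consumes that dict (this is not the transformers implementation, just an illustration of the contract):

import torch

def greedy_generate(model, input_ids, max_new_tokens=40):
    # Toy stand-in for GenerationMixin.generate(): each step asks the model
    # how it wants its inputs shaped, runs a forward pass, and appends the
    # argmax token to the running sequence.
    for _ in range(max_new_tokens):
        model_inputs = model.prepare_inputs_for_generation(input_ids)
        logits = model(**model_inputs).logits          # (batch, seq, vocab)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids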