Update model.py
model.py
CHANGED
@@ -372,7 +372,8 @@ class GPT(nn.Module):
             idx = torch.cat((idx, idx_next), dim=1)
 
         return idx
-
+
+    @torch.no_grad()
     def generate_streaming(self, idx, max_new_tokens, temperature=1.0, top_k=None):
         """
         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
@@ -399,3 +400,48 @@ class GPT(nn.Module):
             # append sampled index to the running sequence and continue
             idx = torch.cat((idx, idx_next), dim=1)
             yield idx_next.item()
+
+    @torch.no_grad()
+    def generate_instructed_streaming(self, idx, idi, max_new_tokens, temperature=1.0, top_k=None):
+        """
+        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+        the sequence max_new_tokens times, feeding the predictions back into the model each time.
+        Yield the generated indices one at a time rather than concatenating them into a single tensor.
+        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+        """
+        idi_length = idi.size(1)
+        max_idx_length = self.config.block_size - idi_length
+
+        # Precompute the minimum top_k value for logits.size(-1)
+        min_top_k = None
+        if top_k is not None:
+            min_top_k = min(top_k, self.config.vocab_size)
+
+        for _ in range(max_new_tokens):
+            # if the sequence context is growing too long we must crop it at block_size
+            idx_cond = idx if idx.size(1) <= max_idx_length else idx[:, -max_idx_length:]
+
+            # concatenate idi with the cropped idx
+            idx_cond = torch.cat((idi, idx_cond), dim=1)
+
+            # forward the model to get the logits for the index in the sequence
+            logits, _ = self(idx_cond)
+
+            # pluck the logits at the final step and scale by desired temperature
+            logits = logits[:, -1, :] / temperature
+
+            # optionally crop the logits to only the top k options
+            if min_top_k is not None:
+                v, _ = torch.topk(logits, min_top_k)
+                logits[logits < v[:, [-1]]] = -float('Inf')
+
+            # apply softmax to convert logits to (normalized) probabilities
+            probs = F.softmax(logits, dim=-1)
+
+            # sample from the distribution
+            idx_next = torch.multinomial(probs, num_samples=1)
+
+            # yield the next index
+            # append sampled index to the running sequence and continue
+            idx = torch.cat((idx, idx_next), dim=1)
+            yield idx_next.item()