Corianas committed
Commit
6814bfe
1 Parent(s): b731873

Update model.py

Files changed (1)
  1. model.py +47 -1
model.py CHANGED
@@ -372,7 +372,8 @@ class GPT(nn.Module):
             idx = torch.cat((idx, idx_next), dim=1)
 
         return idx
-
+
+    @torch.no_grad()
     def generate_streaming(self, idx, max_new_tokens, temperature=1.0, top_k=None):
         """
         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
@@ -399,3 +400,48 @@ class GPT(nn.Module):
             # append sampled index to the running sequence and continue
             idx = torch.cat((idx, idx_next), dim=1)
             yield idx_next.item()
+
+    @torch.no_grad()
+    def generate_instructed_streaming(self, idx, idi, max_new_tokens, temperature=1.0, top_k=None):
+        """
+        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+        the sequence max_new_tokens times, feeding the predictions back into the model each time.
+        Yield the generated indices one at a time rather than concatenating them into a single tensor.
+        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+        """
+        idi_length = idi.size(1)
+        max_idx_length = self.config.block_size - idi_length
+
+        # Precompute the effective top_k, clamped to the vocabulary size (logits.size(-1))
+        min_top_k = None
+        if top_k is not None:
+            min_top_k = min(top_k, self.config.vocab_size)
+
+        for _ in range(max_new_tokens):
+            # if the sequence context is growing too long we must crop it at block_size
+            idx_cond = idx if idx.size(1) <= max_idx_length else idx[:, -max_idx_length:]
+
+            # concatenate idi with the cropped idx
+            idx_cond = torch.cat((idi, idx_cond), dim=1)
+
+            # forward the model to get the logits for the index in the sequence
+            logits, _ = self(idx_cond)
+
+            # pluck the logits at the final step and scale by desired temperature
+            logits = logits[:, -1, :] / temperature
+
+            # optionally crop the logits to only the top k options
+            if min_top_k is not None:
+                v, _ = torch.topk(logits, min_top_k)
+                logits[logits < v[:, [-1]]] = -float('Inf')
+
+            # apply softmax to convert logits to (normalized) probabilities
+            probs = F.softmax(logits, dim=-1)
+
+            # sample from the distribution
+            idx_next = torch.multinomial(probs, num_samples=1)
+
+            # yield the next index
+            # append sampled index to the running sequence and continue
+            idx = torch.cat((idx, idx_next), dim=1)
+            yield idx_next.item()
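
For orientation, here is a minimal sketch of how the new generator might be driven from a sampling script. The checkpoint path, the GPTConfig(**model_args) loading pattern, and the use of tiktoken for encoding are assumptions in the style of nanoGPT-derived repos, not part of this commit; only generate_instructed_streaming itself comes from the diff above.

    # sample_instructed.py -- hypothetical driver, not part of this commit
    import torch
    import tiktoken
    from model import GPT, GPTConfig  # the module patched above

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # assumed nanoGPT-style checkpoint layout: {'model_args': ..., 'model': ...}
    ckpt = torch.load('ckpt.pt', map_location=device)
    model = GPT(GPTConfig(**ckpt['model_args']))
    model.load_state_dict(ckpt['model'])
    model.to(device)
    model.eval()  # the docstring recommends eval mode

    enc = tiktoken.get_encoding('gpt2')
    # idi is the instruction prefix that is re-prepended on every step;
    # idx is the running sequence that gets cropped and extended
    idi = torch.tensor([enc.encode('Summarize the following text. ')],
                       dtype=torch.long, device=device)
    idx = torch.tensor([enc.encode('Attention lets a model weigh context tokens.')],
                       dtype=torch.long, device=device)

    # the generator is already wrapped in @torch.no_grad(), so no extra
    # context manager is needed around the loop
    for token in model.generate_instructed_streaming(idx, idi, max_new_tokens=100,
                                                     temperature=0.8, top_k=200):
        print(enc.decode([token]), end='', flush=True)

The point of the extra idi argument shows up in the cropping arithmetic: idx is trimmed to block_size - idi.size(1) before idi is prepended, so the instruction prefix always survives in the context window, whereas plain generate_streaming would eventually crop an instruction away once the running sequence outgrows block_size.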