Crystalcareai commited on
Commit
9d87e98
·
verified ·
1 Parent(s): ba01a71

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +2 -0
generate.py CHANGED
@@ -47,6 +47,8 @@ def custom_generate(
47
  with torch.no_grad():
48
  finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
49
 
 
 
50
  for cur_token_idx in range(max_new_tokens):
51
  # Sample the next token
52
  new_ids = self(
 
47
  with torch.no_grad():
48
  finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
49
 
50
+ if max_new_tokens is None:
51
+ max_new_tokens = 50 # Default value if not specified
52
  for cur_token_idx in range(max_new_tokens):
53
  # Sample the next token
54
  new_ids = self(