Update generate.py
Browse files- generate.py +2 -0
generate.py
CHANGED
|
@@ -47,6 +47,8 @@ def custom_generate(
|
|
| 47 |
with torch.no_grad():
|
| 48 |
finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
|
| 49 |
|
|
|
|
|
|
|
| 50 |
for cur_token_idx in range(max_new_tokens):
|
| 51 |
# Sample the next token
|
| 52 |
new_ids = self(
|
|
|
|
| 47 |
with torch.no_grad():
|
| 48 |
finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
|
| 49 |
|
| 50 |
+
if max_new_tokens is None:
|
| 51 |
+
max_new_tokens = 50 # Default value if not specified
|
| 52 |
for cur_token_idx in range(max_new_tokens):
|
| 53 |
# Sample the next token
|
| 54 |
new_ids = self(
|