Spaces:
Runtime error
Runtime error
fix: cur_len mismatch
Browse files
- modeling_asteroid.py +1 -0
modeling_asteroid.py
CHANGED
|
@@ -85,6 +85,7 @@ class CustomMixin(GenerationMixin):
|
|
| 85 |
needs_additional_steps = -1 * torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
| 86 |
tf_inputs = input_ids[:]
|
| 87 |
input_ids = input_ids[:, :-(channels - 1)]
|
|
|
|
| 88 |
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :-(channels - 1)]
|
| 89 |
base_length = input_ids.shape[1]
|
| 90 |
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
|
|
|
| 85 |
needs_additional_steps = -1 * torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
| 86 |
tf_inputs = input_ids[:]
|
| 87 |
input_ids = input_ids[:, :-(channels - 1)]
|
| 88 |
+ cur_len = input_ids.shape[1]
|
| 89 |
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :-(channels - 1)]
|
| 90 |
base_length = input_ids.shape[1]
|
| 91 |
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|