fix int/str for conv_dim indexing
#5
by
winglian
- opened
- modeling_hymba.py +3 -3
modeling_hymba.py
CHANGED
@@ -396,7 +396,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
|
|
396 |
|
397 |
if has_mamba_state:
|
398 |
if hasattr(config, 'conv_dim'):
|
399 |
-
conv_dim = config.conv_dim[i]
|
400 |
else:
|
401 |
conv_dim = intermediate_size
|
402 |
self.conv_states += [
|
@@ -1523,7 +1523,7 @@ class HymbaBlock(nn.Module):
|
|
1523 |
num_ssm_param = 1
|
1524 |
|
1525 |
if not hasattr(config, 'conv_dim'):
|
1526 |
-
config.conv_dim = {i:0 for i in range(config.num_hidden_layers)}
|
1527 |
|
1528 |
self.conv1d = nn.Conv1d(
|
1529 |
in_channels=self.intermediate_size,
|
@@ -1534,7 +1534,7 @@ class HymbaBlock(nn.Module):
|
|
1534 |
padding=self.conv_kernel_size - 1
|
1535 |
)
|
1536 |
|
1537 |
-
config.conv_dim[self.layer_idx] = self.intermediate_size
|
1538 |
|
1539 |
self.x_proj = nn.ModuleList([nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) for _ in range(num_ssm_param)])
|
1540 |
self.dt_proj = nn.ModuleList([nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) for _ in range(num_ssm_param)])
|
|
|
396 |
|
397 |
if has_mamba_state:
|
398 |
if hasattr(config, 'conv_dim'):
|
399 |
+
conv_dim = config.conv_dim[str(i)]
|
400 |
else:
|
401 |
conv_dim = intermediate_size
|
402 |
self.conv_states += [
|
|
|
1523 |
num_ssm_param = 1
|
1524 |
|
1525 |
if not hasattr(config, 'conv_dim'):
|
1526 |
+
config.conv_dim = {str(i):0 for i in range(config.num_hidden_layers)}
|
1527 |
|
1528 |
self.conv1d = nn.Conv1d(
|
1529 |
in_channels=self.intermediate_size,
|
|
|
1534 |
padding=self.conv_kernel_size - 1
|
1535 |
)
|
1536 |
|
1537 |
+
config.conv_dim[str(self.layer_idx)] = self.intermediate_size
|
1538 |
|
1539 |
self.x_proj = nn.ModuleList([nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) for _ in range(num_ssm_param)])
|
1540 |
self.dt_proj = nn.ModuleList([nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) for _ in range(num_ssm_param)])
|