fix int/str for conv_dim indexing

#5
by winglian - opened
Files changed (1) hide show
  1. modeling_hymba.py +3 -3
modeling_hymba.py CHANGED
@@ -396,7 +396,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
396
 
397
  if has_mamba_state:
398
  if hasattr(config, 'conv_dim'):
399
- conv_dim = config.conv_dim[i]
400
  else:
401
  conv_dim = intermediate_size
402
  self.conv_states += [
@@ -1523,7 +1523,7 @@ class HymbaBlock(nn.Module):
1523
  num_ssm_param = 1
1524
 
1525
  if not hasattr(config, 'conv_dim'):
1526
- config.conv_dim = {i:0 for i in range(config.num_hidden_layers)}
1527
 
1528
  self.conv1d = nn.Conv1d(
1529
  in_channels=self.intermediate_size,
@@ -1534,7 +1534,7 @@ class HymbaBlock(nn.Module):
1534
  padding=self.conv_kernel_size - 1
1535
  )
1536
 
1537
- config.conv_dim[self.layer_idx] = self.intermediate_size
1538
 
1539
  self.x_proj = nn.ModuleList([nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) for _ in range(num_ssm_param)])
1540
  self.dt_proj = nn.ModuleList([nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) for _ in range(num_ssm_param)])
 
396
 
397
  if has_mamba_state:
398
  if hasattr(config, 'conv_dim'):
399
+ conv_dim = config.conv_dim[str(i)]
400
  else:
401
  conv_dim = intermediate_size
402
  self.conv_states += [
 
1523
  num_ssm_param = 1
1524
 
1525
  if not hasattr(config, 'conv_dim'):
1526
+ config.conv_dim = {str(i):0 for i in range(config.num_hidden_layers)}
1527
 
1528
  self.conv1d = nn.Conv1d(
1529
  in_channels=self.intermediate_size,
 
1534
  padding=self.conv_kernel_size - 1
1535
  )
1536
 
1537
+ config.conv_dim[str(self.layer_idx)] = self.intermediate_size
1538
 
1539
  self.x_proj = nn.ModuleList([nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) for _ in range(num_ssm_param)])
1540
  self.dt_proj = nn.ModuleList([nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) for _ in range(num_ssm_param)])