Update modeling_jamba.py - LoRA support in Mamba

#6
by tomeras1 - opened
Files changed (1) hide show
  1. modeling_jamba.py +12 -4
modeling_jamba.py CHANGED
@@ -943,14 +943,22 @@ class JambaMambaMixer(nn.Module):
943
  # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
944
  # linear layers, and requires to call the forward pass directly.
945
  # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
946
- dt_proj_bias = self.dt_proj.bias
947
- self.dt_proj.bias = None
 
 
 
 
 
948
  discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
949
- self.dt_proj.bias = dt_proj_bias
 
 
 
950
 
951
  A = -torch.exp(self.A_log.float())
952
  # 3.c perform the recurrence y ← SSM(A, B, C)(x)
953
- time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
954
  if cache_params is not None and cache_params.seqlen_offset > 0:
955
  scan_outputs = selective_state_update(
956
  cache_params.ssm_states[self.layer_idx],
 
943
  # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
944
  # linear layers, and requires to call the forward pass directly.
945
  # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
946
+ if hasattr(self.dt_proj, "base_layer"):
947
+ # In case of LoRA, we need to access the base layer to get the weight
948
+ time_proj_bias = self.dt_proj.base_layer.bias
949
+ self.dt_proj.base_layer.bias = None
950
+ else:
951
+ time_proj_bias = self.dt_proj.bias
952
+ self.dt_proj.bias = None
953
  discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
954
+ if hasattr(self.dt_proj, "base_layer"):
955
+ self.dt_proj.base_layer.bias = time_proj_bias
956
+ else:
957
+ self.dt_proj.bias = time_proj_bias
958
 
959
  A = -torch.exp(self.A_log.float())
960
  # 3.c perform the recurrence y ← SSM(A, B, C)(x)
961
+ time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
962
  if cache_params is not None and cache_params.seqlen_offset > 0:
963
  scan_outputs = selective_state_update(
964
  cache_params.ssm_states[self.layer_idx],