jonathanjordan21 commited on
Commit
c7ba421
1 Parent(s): e1dca0d

Update modeling_mos_mamba.py

Browse files
Files changed (1) hide show
  1. modeling_mos_mamba.py +17 -17
modeling_mos_mamba.py CHANGED
@@ -32,28 +32,28 @@ from .configuration_mos_mamba import MoSMambaConfig
32
  import torch.nn.functional as F
33
 
34
 
35
- if is_mamba_ssm_available():
36
- from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
37
- from mamba_ssm.ops.triton.selective_state_update import selective_state_update
38
- else:
39
- selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
40
-
41
- if is_causal_conv1d_available():
42
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
43
- else:
44
- causal_conv1d_update, causal_conv1d_fn = None, None
45
-
46
-
47
- # try:
48
  # from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
49
  # from mamba_ssm.ops.triton.selective_state_update import selective_state_update
50
- # except:
51
  # selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
52
 
53
- # try:
54
  # from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
55
- # except:
56
  # causal_conv1d_update, causal_conv1d_fn = None, None
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  is_fast_path_available = all(
@@ -706,7 +706,7 @@ class MoSMambaPreTrainedModel(PreTrainedModel):
706
  if module.bias is not None:
707
  if not getattr(module.bias, "_no_reinit", False):
708
  nn.init.zeros_(module.bias)
709
- nn.init.uniform_(module.weight, -0.001, 0.001)
710
 
711
  elif isinstance(module, nn.Embedding):
712
  nn.init.normal_(module.weight, std=self.config.initializer_range)
 
32
  import torch.nn.functional as F
33
 
34
 
35
+ # if is_mamba_ssm_available():
 
 
 
 
 
 
 
 
 
 
 
 
36
  # from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
37
  # from mamba_ssm.ops.triton.selective_state_update import selective_state_update
38
+ # else:
39
  # selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
40
 
41
+ # if is_causal_conv1d_available():
42
  # from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
43
+ # else:
44
  # causal_conv1d_update, causal_conv1d_fn = None, None
45
+
46
+
47
+ try:
48
+ from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
49
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
50
+ except:
51
+ selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
52
+
53
+ try:
54
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
55
+ except:
56
+ causal_conv1d_update, causal_conv1d_fn = None, None
57
 
58
 
59
  is_fast_path_available = all(
 
706
  if module.bias is not None:
707
  if not getattr(module.bias, "_no_reinit", False):
708
  nn.init.zeros_(module.bias)
709
+ # nn.init.uniform_(module.weight, -0.001, 0.001)
710
 
711
  elif isinstance(module, nn.Embedding):
712
  nn.init.normal_(module.weight, std=self.config.initializer_range)