jonathanjordan21
committed on
Commit
•
c7ba421
1
Parent(s):
e1dca0d
Update modeling_mos_mamba.py
Browse files
- modeling_mos_mamba.py +17 -17
modeling_mos_mamba.py
CHANGED
@@ -32,28 +32,28 @@ from .configuration_mos_mamba import MoSMambaConfig
|
|
32 |
import torch.nn.functional as F
|
33 |
|
34 |
|
35 |
-
if is_mamba_ssm_available():
|
36 |
-
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
|
37 |
-
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
38 |
-
else:
|
39 |
-
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
|
40 |
-
|
41 |
-
if is_causal_conv1d_available():
|
42 |
-
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
43 |
-
else:
|
44 |
-
causal_conv1d_update, causal_conv1d_fn = None, None
|
45 |
-
|
46 |
-
|
47 |
-
# try:
|
48 |
# from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
|
49 |
# from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
50 |
-
#
|
51 |
# selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
|
52 |
|
53 |
-
#
|
54 |
# from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
55 |
-
#
|
56 |
# causal_conv1d_update, causal_conv1d_fn = None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
|
59 |
is_fast_path_available = all(
|
@@ -706,7 +706,7 @@ class MoSMambaPreTrainedModel(PreTrainedModel):
|
|
706 |
if module.bias is not None:
|
707 |
if not getattr(module.bias, "_no_reinit", False):
|
708 |
nn.init.zeros_(module.bias)
|
709 |
-
nn.init.uniform_(module.weight, -0.001, 0.001)
|
710 |
|
711 |
elif isinstance(module, nn.Embedding):
|
712 |
nn.init.normal_(module.weight, std=self.config.initializer_range)
|
|
|
32 |
import torch.nn.functional as F
|
33 |
|
34 |
|
35 |
+
# if is_mamba_ssm_available():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
# from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
|
37 |
# from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
38 |
+
# else:
|
39 |
# selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
|
40 |
|
41 |
+
# if is_causal_conv1d_available():
|
42 |
# from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
43 |
+
# else:
|
44 |
# causal_conv1d_update, causal_conv1d_fn = None, None
|
45 |
+
|
46 |
+
|
47 |
+
try:
|
48 |
+
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
|
49 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
50 |
+
except:
|
51 |
+
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
|
52 |
+
|
53 |
+
try:
|
54 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
55 |
+
except:
|
56 |
+
causal_conv1d_update, causal_conv1d_fn = None, None
|
57 |
|
58 |
|
59 |
is_fast_path_available = all(
|
|
|
706 |
if module.bias is not None:
|
707 |
if not getattr(module.bias, "_no_reinit", False):
|
708 |
nn.init.zeros_(module.bias)
|
709 |
+
# nn.init.uniform_(module.weight, -0.001, 0.001)
|
710 |
|
711 |
elif isinstance(module, nn.Embedding):
|
712 |
nn.init.normal_(module.weight, std=self.config.initializer_range)
|