jonathanjordan21 commited on
Commit
3f92511
1 Parent(s): f59f64a

Update modeling_mos_mamba.py

Browse files
Files changed (1) hide show
  1. modeling_mos_mamba.py +18 -5
modeling_mos_mamba.py CHANGED
@@ -26,22 +26,35 @@ from torch.nn import CrossEntropyLoss
26
  from transformers.activations import ACT2FN
27
  from transformers.modeling_utils import PreTrainedModel
28
  from transformers.utils import ModelOutput
29
- from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
30
  from .configuration_mos_mamba import MoSMambaConfig
31
 
32
  import torch.nn.functional as F
33
 
34
 
35
- if is_mamba_ssm_available():
 
 
 
 
 
 
 
 
 
 
 
 
36
  from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
37
  from mamba_ssm.ops.triton.selective_state_update import selective_state_update
38
- else:
39
  selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
40
 
41
- if is_causal_conv1d_available():
42
  from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
43
- else:
44
  causal_conv1d_update, causal_conv1d_fn = None, None
 
45
 
46
  is_fast_path_available = all(
47
  (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
 
26
  from transformers.activations import ACT2FN
27
  from transformers.modeling_utils import PreTrainedModel
28
  from transformers.utils import ModelOutput
29
+ # from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
30
  from .configuration_mos_mamba import MoSMambaConfig
31
 
32
  import torch.nn.functional as F
33
 
34
 
35
+ # if is_mamba_ssm_available():
36
+ # from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
37
+ # from mamba_ssm.ops.triton.selective_state_update import selective_state_update
38
+ # else:
39
+ # selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
40
+
41
+ # if is_causal_conv1d_available():
42
+ # from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
43
+ # else:
44
+ # causal_conv1d_update, causal_conv1d_fn = None, None
45
+
46
+
47
+ try:
48
  from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
49
  from mamba_ssm.ops.triton.selective_state_update import selective_state_update
50
+ except:
51
  selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
52
 
53
+ try:
54
  from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
55
+ except:
56
  causal_conv1d_update, causal_conv1d_fn = None, None
57
+
58
 
59
  is_fast_path_available = all(
60
  (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)