Nanobit committed on
Commit
d69ba2b
1 Parent(s): 9e3f0cb

fix: warn user to install mamba_ssm package (#1019)

Browse files
docker/Dockerfile CHANGED
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
20
 
21
  # If AXOLOTL_EXTRAS is set, append it in brackets
22
  RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
23
- pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
24
  else \
25
- pip install -e .[deepspeed,flash-attn]; \
26
  fi
27
 
28
  # So we can test the Docker image
 
20
 
21
  # If AXOLOTL_EXTRAS is set, append it in brackets
22
  RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
23
+ pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
24
  else \
25
+ pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
26
  fi
27
 
28
  # So we can test the Docker image
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
2
- packaging
3
  peft==0.7.0
4
  transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
5
  tokenizers==0.15.0
@@ -34,6 +34,8 @@ fschat==0.2.34
34
  gradio==3.50.2
35
  tensorboard
36
 
 
 
37
  # remote filesystems
38
  s3fs
39
  gcsfs
 
1
  --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
2
+ packaging==23.2
3
  peft==0.7.0
4
  transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
5
  tokenizers==0.15.0
 
34
  gradio==3.50.2
35
  tensorboard
36
 
37
+ mamba-ssm==1.1.1
38
+
39
  # remote filesystems
40
  s3fs
41
  gcsfs
setup.py CHANGED
@@ -11,17 +11,17 @@ def parse_requirements():
11
  with open("./requirements.txt", encoding="utf-8") as requirements_file:
12
  lines = [r.strip() for r in requirements_file.readlines()]
13
  for line in lines:
 
 
 
 
 
 
14
  if line.startswith("--extra-index-url"):
15
  # Handle custom index URLs
16
  _, url = line.split()
17
  _dependency_links.append(url)
18
- elif (
19
- "flash-attn" not in line
20
- and "flash-attention" not in line
21
- and "deepspeed" not in line
22
- and line
23
- and line[0] != "#"
24
- ):
25
  # Handle standard packages
26
  _install_requires.append(line)
27
 
 
11
  with open("./requirements.txt", encoding="utf-8") as requirements_file:
12
  lines = [r.strip() for r in requirements_file.readlines()]
13
  for line in lines:
14
+ is_extras = (
15
+ "flash-attn" in line
16
+ or "flash-attention" in line
17
+ or "deepspeed" in line
18
+ or "mamba-ssm" in line
19
+ )
20
  if line.startswith("--extra-index-url"):
21
  # Handle custom index URLs
22
  _, url = line.split()
23
  _dependency_links.append(url)
24
+ elif not is_extras and line and line[0] != "#":
 
 
 
 
 
 
25
  # Handle standard packages
26
  _install_requires.append(line)
27
 
src/axolotl/models/mamba/__init__.py CHANGED
@@ -2,8 +2,20 @@
2
  Modeling module for Mamba models
3
  """
4
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def fix_mamba_attn_for_loss():
 
 
7
  from mamba_ssm.models import mixer_seq_simple
8
 
9
  from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
 
2
  Modeling module for Mamba models
3
  """
4
 
5
import importlib.util


def check_mamba_ssm_installed():
    """Raise ImportError if the optional ``mamba_ssm`` package is not installed.

    ``mamba_ssm`` is an extras dependency (installed via the ``mamba-ssm``
    extra), so we probe for it with ``importlib.util.find_spec`` instead of
    importing it directly — this avoids paying the import cost and lets us
    give the user an actionable install hint.

    Raises:
        ImportError: when no ``mamba_ssm`` distribution can be found.
    """
    # NOTE: the original diff did `import importlib`, which does not reliably
    # bind the `importlib.util` submodule; import it explicitly.
    mamba_ssm_spec = importlib.util.find_spec("mamba_ssm")
    if mamba_ssm_spec is None:
        raise ImportError(
            "MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`"
        )
14
+
15
 
16
  def fix_mamba_attn_for_loss():
17
+ check_mamba_ssm_installed()
18
+
19
  from mamba_ssm.models import mixer_seq_simple
20
 
21
  from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed