fix: warn user to install mamba_ssm package (#1019)
Browse files- docker/Dockerfile +2 -2
- requirements.txt +3 -1
- setup.py +7 -7
- src/axolotl/models/mamba/__init__.py +12 -0
docker/Dockerfile
CHANGED
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
|
|
20 |
|
21 |
# If AXOLOTL_EXTRAS is set, append it in brackets
|
22 |
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
23 |
-
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
24 |
else \
|
25 |
-
pip install -e .[deepspeed,flash-attn]; \
|
26 |
fi
|
27 |
|
28 |
# So we can test the Docker image
|
|
|
20 |
|
21 |
# If AXOLOTL_EXTRAS is set, append it in brackets
|
22 |
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
23 |
+
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
|
24 |
else \
|
25 |
+
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
|
26 |
fi
|
27 |
|
28 |
# So we can test the Docker image
|
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
2 |
-
packaging
|
3 |
peft==0.7.0
|
4 |
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
|
5 |
tokenizers==0.15.0
|
@@ -34,6 +34,8 @@ fschat==0.2.34
|
|
34 |
gradio==3.50.2
|
35 |
tensorboard
|
36 |
|
|
|
|
|
37 |
# remote filesystems
|
38 |
s3fs
|
39 |
gcsfs
|
|
|
1 |
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
2 |
+
packaging==23.2
|
3 |
peft==0.7.0
|
4 |
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
|
5 |
tokenizers==0.15.0
|
|
|
34 |
gradio==3.50.2
|
35 |
tensorboard
|
36 |
|
37 |
+
mamba-ssm==1.1.1
|
38 |
+
|
39 |
# remote filesystems
|
40 |
s3fs
|
41 |
gcsfs
|
setup.py
CHANGED
@@ -11,17 +11,17 @@ def parse_requirements():
|
|
11 |
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
12 |
lines = [r.strip() for r in requirements_file.readlines()]
|
13 |
for line in lines:
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
if line.startswith("--extra-index-url"):
|
15 |
# Handle custom index URLs
|
16 |
_, url = line.split()
|
17 |
_dependency_links.append(url)
|
18 |
-
elif
|
19 |
-
"flash-attn" not in line
|
20 |
-
and "flash-attention" not in line
|
21 |
-
and "deepspeed" not in line
|
22 |
-
and line
|
23 |
-
and line[0] != "#"
|
24 |
-
):
|
25 |
# Handle standard packages
|
26 |
_install_requires.append(line)
|
27 |
|
|
|
11 |
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
12 |
lines = [r.strip() for r in requirements_file.readlines()]
|
13 |
for line in lines:
|
14 |
+
is_extras = (
|
15 |
+
"flash-attn" in line
|
16 |
+
or "flash-attention" in line
|
17 |
+
or "deepspeed" in line
|
18 |
+
or "mamba-ssm" in line
|
19 |
+
)
|
20 |
if line.startswith("--extra-index-url"):
|
21 |
# Handle custom index URLs
|
22 |
_, url = line.split()
|
23 |
_dependency_links.append(url)
|
24 |
+
elif not is_extras and line and line[0] != "#":
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
# Handle standard packages
|
26 |
_install_requires.append(line)
|
27 |
|
src/axolotl/models/mamba/__init__.py
CHANGED
@@ -2,8 +2,20 @@
|
|
2 |
Modeling module for Mamba models
|
3 |
"""
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
def fix_mamba_attn_for_loss():
|
|
|
|
|
7 |
from mamba_ssm.models import mixer_seq_simple
|
8 |
|
9 |
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|
|
|
2 |
Modeling module for Mamba models
|
3 |
"""
|
4 |
|
5 |
+
import importlib
|
6 |
+
|
7 |
+
|
8 |
+
def check_mamba_ssm_installed():
|
9 |
+
mamba_ssm_spec = importlib.util.find_spec("mamba_ssm")
|
10 |
+
if mamba_ssm_spec is None:
|
11 |
+
raise ImportError(
|
12 |
+
"MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`"
|
13 |
+
)
|
14 |
+
|
15 |
|
16 |
def fix_mamba_attn_for_loss():
|
17 |
+
check_mamba_ssm_installed()
|
18 |
+
|
19 |
from mamba_ssm.models import mixer_seq_simple
|
20 |
|
21 |
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|