winglian commited on
Commit
161bcb6
·
unverified ·
1 Parent(s): d25c34c

Dockerfile torch fix (#987)

Browse files

* add torch to requirements.txt at build time to force version to stick

* fix xformers check

* better handling of xformers based on installed torch version

* fix for ci w/o torch

.github/workflows/base.yml CHANGED
@@ -28,7 +28,7 @@ jobs:
28
  - cuda: "118"
29
  cuda_version: 11.8.0
30
  python_version: "3.10"
31
- pytorch: 2.1.0
32
  torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
33
  steps:
34
  - name: Checkout
 
28
  - cuda: "118"
29
  cuda_version: 11.8.0
30
  python_version: "3.10"
31
+ pytorch: 2.1.1
32
  torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
33
  steps:
34
  - name: Checkout
.github/workflows/main.yml CHANGED
@@ -27,7 +27,7 @@ jobs:
27
  - cuda: 118
28
  cuda_version: 11.8.0
29
  python_version: "3.10"
30
- pytorch: 2.1.0
31
  axolotl_extras:
32
  runs-on: [self-hosted, gpu, docker]
33
  steps:
@@ -80,7 +80,7 @@ jobs:
80
  - cuda: 118
81
  cuda_version: 11.8.0
82
  python_version: "3.10"
83
- pytorch: 2.1.0
84
  axolotl_extras:
85
  runs-on: [self-hosted, gpu, docker]
86
  steps:
 
27
  - cuda: 118
28
  cuda_version: 11.8.0
29
  python_version: "3.10"
30
+ pytorch: 2.1.1
31
  axolotl_extras:
32
  runs-on: [self-hosted, gpu, docker]
33
  steps:
 
80
  - cuda: 118
81
  cuda_version: 11.8.0
82
  python_version: "3.10"
83
+ pytorch: 2.1.1
84
  axolotl_extras:
85
  runs-on: [self-hosted, gpu, docker]
86
  steps:
docker/Dockerfile CHANGED
@@ -19,7 +19,6 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
19
  WORKDIR /workspace/axolotl
20
 
21
  # If AXOLOTL_EXTRAS is set, append it in brackets
22
- RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
23
  RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
24
  pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
25
  else \
 
19
  WORKDIR /workspace/axolotl
20
 
21
  # If AXOLOTL_EXTRAS is set, append it in brackets
 
22
  RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
23
  pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
24
  else \
setup.py CHANGED
@@ -1,5 +1,7 @@
1
  """setup.py for axolotl"""
2
 
 
 
3
  from setuptools import find_packages, setup
4
 
5
 
@@ -22,12 +24,13 @@ def parse_requirements():
22
  # Handle standard packages
23
  _install_requires.append(line)
24
 
25
- # TODO(wing) remove once xformers release supports torch 2.1.0
26
- if "torch==2.1.0" in _install_requires:
27
- _install_requires.pop(_install_requires.index("xformers>=0.0.22"))
28
- _install_requires.append(
29
- "xformers @ git+https://github.com/facebookresearch/xformers.git@main"
30
- )
 
31
 
32
  return _install_requires, _dependency_links
33
 
 
1
  """setup.py for axolotl"""
2
 
3
+ from importlib.metadata import PackageNotFoundError, version
4
+
5
  from setuptools import find_packages, setup
6
 
7
 
 
24
  # Handle standard packages
25
  _install_requires.append(line)
26
 
27
+ try:
28
+ torch_version = version("torch")
29
+ if torch_version.startswith("2.1.1"):
30
+ _install_requires.pop(_install_requires.index("xformers==0.0.22"))
31
+ _install_requires.append("xformers==0.0.23")
32
+ except PackageNotFoundError:
33
+ pass
34
 
35
  return _install_requires, _dependency_links
36