mhenrichsen, Ubuntu, Mads Henrichsen, and winglian committed
Commit cf66547
Parent: 06edf17

flash attn pip install (#426)

* flash attn pip

* add packaging

* add packaging to apt get

* install flash attn in dockerfile

* remove unused whls

* add wheel

* clean up pr

fix packaging requirement for ci
upgrade pip for ci
skip build isolation for requirements to get flash-attn working
install flash-attn separately

* install wheel for ci

* no flash-attn for basic cicd

* install flash-attn as pip extras

---------

Co-authored-by: Ubuntu <mgh@mgh-vm.wsyvwcia0jxedeyrchqg425tpb.ax.internal.cloudapp.net>
Co-authored-by: mhenrichsen <some_email@hey.com>
Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
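
The bullets above ("add packaging", "upgrade pip for ci", "add wheel", "skip build isolation") reflect a quirk of installing flash-attn from PyPI: its source build expects `torch` and `packaging` to be importable at setup time, which pip's default build isolation hides. A minimal sketch of the install order those bullets describe, assuming torch is already present:

```bash
# flash-attn's source build imports torch and packaging at setup time,
# so install the build prerequisites first ...
pip3 install -U pip packaging wheel
# ... then skip build isolation so the build can see the installed torch
pip3 install --no-build-isolation flash-attn==2.0.8
```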

.github/workflows/main.yml CHANGED
@@ -13,17 +13,17 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: cu118
+          - cuda: 118
             cuda_version: 11.8.0
             python_version: "3.9"
             pytorch: 2.0.1
             axolotl_extras:
-          - cuda: cu118
+          - cuda: 118
             cuda_version: 11.8.0
             python_version: "3.10"
             pytorch: 2.0.1
             axolotl_extras:
-          - cuda: cu118
+          - cuda: 118
             cuda_version: 11.8.0
             python_version: "3.9"
             pytorch: 2.0.1
@@ -49,10 +49,11 @@ jobs:
         with:
           context: .
           build-args: |
-            BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
+            BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
+            CUDA=${{ matrix.cuda }}
           file: ./docker/Dockerfile
           push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+          tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
           labels: ${{ steps.metadata.outputs.labels }}
   build-axolotl-runpod:
     needs: build-axolotl
README.md CHANGED
@@ -69,7 +69,7 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
 ```bash
 git clone https://github.com/OpenAccess-AI-Collective/axolotl
 
-pip3 install -e .
+pip3 install -e .[flash-attn]
 pip3 install -U git+https://github.com/huggingface/peft.git
 
 # finetune lora
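
One caveat on the new quickstart line: in zsh and some other shells, square brackets are glob characters, so the extras spec needs quoting. A portable variant:

```bash
# quote the extras spec so the shell doesn't try to glob-expand the brackets
pip3 install -e '.[flash-attn]'
```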
docker/Dockerfile CHANGED
@@ -16,9 +16,9 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN cd axolotl && \
     if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[$AXOLOTL_EXTRAS]; \
+        pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
     else \
-        pip install -e .; \
+        pip install -e .[flash-attn]; \
     fi
 
 # fix so that git fetch/pull from remote works
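
For illustration, the two branches of that conditional resolve to the following pip invocations, using `extras` (the deepspeed group defined in setup.py below) as an example value for AXOLOTL_EXTRAS:

```bash
# AXOLOTL_EXTRAS="extras": flash-attn plus the requested extras group
pip install -e .[flash-attn,extras]
# AXOLOTL_EXTRAS unset or empty: flash-attn only
pip install -e .[flash-attn]
```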
docker/Dockerfile-base CHANGED
@@ -31,26 +31,6 @@ WORKDIR /workspace
 RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
     python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
 
-
-FROM base-builder AS flash-attn-builder
-
-WORKDIR /workspace
-
-ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
-
-RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
-    cd flash-attention && \
-    git checkout v2.0.4 && \
-    python3 setup.py bdist_wheel && \
-    cd csrc/fused_dense_lib && \
-    python3 setup.py bdist_wheel && \
-    cd ../xentropy && \
-    python3 setup.py bdist_wheel && \
-    cd ../rotary && \
-    python3 setup.py bdist_wheel && \
-    cd ../layer_norm && \
-    python3 setup.py bdist_wheel
-
 FROM base-builder AS deepspeed-builder
 
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
@@ -90,13 +70,8 @@ RUN mkdir -p /workspace/wheels/bitsandbytes
 COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
 COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
 COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
-COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy_cuda_lib-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary_emb-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels
 
-RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy_cuda_lib-*.whl wheels/rotary_emb-*.whl wheels/dropout_layer_norm-*.whl
+RUN pip3 install wheels/deepspeed-*.whl
 RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
 RUN git lfs install --skip-repo
 RUN pip3 install awscli && \
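
The removed flash-attn-builder stage used to compile flash-attn v2.0.4 and its csrc companion wheels (fused_dense_lib, xentropy_cuda_lib, rotary_emb, dropout_layer_norm) from source; after this change only deepspeed's wheel is installed here, and flash-attn arrives via the pip extra in the runtime Dockerfile. A quick sanity check one could run in the finished image (hypothetical, not part of the Dockerfile):

```bash
# verify the pip-installed flash-attn imports against the image's torch/CUDA build
python3 -c "import torch, flash_attn; print(torch.version.cuda, flash_attn.__version__)"
```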
requirements.txt CHANGED
@@ -6,6 +6,7 @@ addict
 fire
 PyYAML==6.0
 datasets
+flash-attn==2.0.8
 sentencepiece
 wandb
 einops
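
Note the interplay with setup.py below: the pin lives in requirements.txt, so installing straight from the requirements file still pulls flash-attn, while a bare `pip install -e .` will not, because setup.py filters it into an extra. For example:

```bash
# pulls flash-attn==2.0.8 along with everything else in the file
pip3 install -r requirements.txt
```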
setup.py CHANGED
@@ -7,6 +7,7 @@ with open("./requirements.txt", encoding="utf-8") as requirements_file:
     # don't include peft yet until we check the int4
     # need to manually install peft for now...
     reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
+    reqs = [r for r in reqs if "flash-attn" not in r]
     reqs = [r for r in reqs if r and r[0] != "#"]
     for r in reqs:
         install_requires.append(r)
@@ -25,8 +26,10 @@ setup(
         "gptq_triton": [
             "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
         ],
+        "flash-attn": [
+            "flash-attn==2.0.8",
+        ],
         "extras": [
-            "flash-attn",
             "deepspeed",
         ],
     },
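
With the filter in place, `flash-attn` is kept out of `install_requires` (so a plain editable install stays buildable on machines without CUDA) and is re-exposed as its own extra, which can be combined with the existing ones. For example:

```bash
# flash-attn alone
pip3 install -e '.[flash-attn]'
# flash-attn together with the "extras" group (deepspeed)
pip3 install -e '.[flash-attn,extras]'
```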