Markus28 committed
Commit 284316d (parent: 6546b2c)

feat: added README

Files changed (1):
  1. README.md +33 -0

README.md ADDED
# BERT with Flash-Attention
### Installing dependencies
To run the model on a GPU, you need to install Flash Attention.
You can either install it from PyPI (which may not work with fused-dense) or build it from source.
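For the PyPI route, the upstream project recommends installing with build isolation disabled (this assumes `torch` is already installed in the environment):
```console
pip install flash-attn --no-build-isolation
```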
To install from source, clone the GitHub repository:
```console
git clone git@github.com:Dao-AILab/flash-attention.git
```
The code provided here should work with commit `43950dd`.
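If you want to pin to that exact commit, checking it out explicitly should work (`git -C` simply runs the command inside the cloned directory):
```console
git -C flash-attention checkout 43950dd
```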
Change to the cloned repo and install:
```console
cd flash-attention && python setup.py install
```
This will compile the Flash Attention kernels, which will take some time.
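To sanity-check the build afterwards, a quick import test should be enough (this assumes the package installs under the name `flash_attn`):
```console
python -c "import flash_attn; print(flash_attn.__version__)"
```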
If you would like to use fused MLPs (e.g. in combination with activation checkpointing),
you can also install fused-dense from source:
```console
cd csrc/fused_dense_lib && python setup.py install
```
### Configuration
The config adds some new parameters (a short loading sketch follows the list):
- `use_flash_attn`: If `True`, always use flash attention. If `None`, use flash attention when a GPU is available. If `False`, never use flash attention (works on CPU).
- `window_size`: Size (left and right) of the local attention window. If `(-1, -1)`, global attention is used.
- `dense_seq_output`: If `True`, only the hidden states of the masked-out tokens (around 15%) are passed to the classifier heads. I set this to `True` for pretraining.
- `fused_mlp`: Whether to use fused-dense. Useful for reducing VRAM in combination with activation checkpointing.
- `mlp_checkpoint_lvl`: One of `{0, 1, 2}`. Increasing this increases the amount of activation checkpointing within the MLP. Keep this at `0` for pretraining and use gradient accumulation instead. For embedding training, increase this as much as needed.
- `last_layer_subset`: If `True`, the last layer is only computed for a subset of tokens. I left this at `False`.
- `use_qk_norm`: Whether or not to use QK-normalization.
- `num_loras`: Number of LoRAs to use when initializing a `BertLoRA` model. Has no effect on other models.
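As a rough sketch of how these options might be set when loading the model with Hugging Face Transformers (the repository id below is a placeholder, and the exact classes come from this model's remote code, so adjust to the checkpoint you are actually using):
```python
# Minimal sketch, assuming the model is loaded via AutoModel with trust_remote_code=True.
# The repo id is a placeholder, not this model's actual repository name.
from transformers import AutoConfig, AutoModel

repo_id = "your-org/bert-with-flash-attention"  # placeholder

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)

# The parameters described above are plain config attributes.
config.use_flash_attn = None       # use flash attention only when a GPU is available
config.window_size = (-1, -1)      # global attention (no local window)
config.dense_seq_output = False
config.fused_mlp = False           # requires the fused-dense extension built above
config.mlp_checkpoint_lvl = 0      # keep at 0 for pretraining; raise for embedding training
config.last_layer_subset = False
config.use_qk_norm = False
# config.num_loras only matters when initializing a BertLoRA model

model = AutoModel.from_pretrained(repo_id, config=config, trust_remote_code=True)
```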