pirroh committed
Commit 749980f
1 Parent(s): 266de40

Update README.md (#5)


- Update README.md (2c290064b0dc519eee5624dfd3625b729ac5aea2)

Files changed (1)
  1. README.md +7 -2
README.md CHANGED
@@ -95,8 +95,13 @@ from transformers import AutoModelForCausalLM
 model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
 ```
 
-To use the optimized Triton implementation of FlashAttention on GPUs with BF16 precision, move the model to `bfloat16` and use it as follows:
+To use the optimized Triton implementation of FlashAttention on GPUs with BF16 precision, first install the following dependencies:
+```
+flash-attn==0.2.8
+triton==2.0.0.dev20221202
+```
 
+Then, move the model to `bfloat16` and use it as follows:
 ```python
 from transformers import AutoModelForCausalLM
 
@@ -106,7 +111,7 @@ model.to(device='cuda:0', dtype=torch.bfloat16)
 
 # forward pass
 x = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
-x = x.to(device='cuda:0', dtype=torch.bfloat16)
+x = x.to(device='cuda:0')
 y = model(x)
 
 ```
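
For context, the updated instructions can be stitched into a single runnable snippet. This is a minimal sketch assembled only from the lines shown in the diff above; any further configuration the full README may require (for example, selecting the Triton attention implementation at load time) is not part of these hunks and is omitted here.

```python
# Sketch assembled from the updated README hunks above.
# Assumes the pinned dependencies from the diff are installed, e.g.:
#   pip install flash-attn==0.2.8 triton==2.0.0.dev20221202
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)

# Move the model weights to BF16 on the GPU, as the updated README instructs
model.to(device='cuda:0', dtype=torch.bfloat16)

# Forward pass: token ids stay integer, so the input is moved to the GPU
# without a dtype cast (this is the fix in the second hunk)
x = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
x = x.to(device='cuda:0')
y = model(x)
```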