madhavatreplit committed
Commit 1e1a20a
1 Parent(s): b6d9ff2

Update README.md for flash attn

Files changed (1)
  1. README.md +8 -2
README.md CHANGED
@@ -105,10 +105,16 @@ triton==2.0.0.dev20221202
 
 Then, move the model to `bfloat16` and use it as follows:
 ```python
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoConfig
+
+config = AutoConfig.from_pretrained(
+    "replit/replit-code-v1-3b",
+    trust_remote_code=True
+)
+config.attn_config['attn_impl'] = 'triton'
 
 # load model
-model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True, attn_impl='triton')
+model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', config=config, trust_remote_code=True)
 model.to(device='cuda:0', dtype=torch.bfloat16)
 
 # forward pass
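
Reviewer note: this commit moves the `attn_impl` choice out of the `from_pretrained` keyword arguments and into `config.attn_config`. A minimal sketch of that pattern as a reusable helper follows. The function names `override_attn_impl` and `load_with_attn_impl` are hypothetical and do not appear in the README. The hunk does not show a `torch` import, so the sketch adds one on the assumption that `torch.bfloat16` needs it in scope.

```python
def override_attn_impl(config, attn_impl="triton"):
    """Set the attention implementation on a config object in place.

    Assumes `config.attn_config` is a dict-like mapping, as the
    README change in this commit implies.
    """
    config.attn_config["attn_impl"] = attn_impl
    return config


def load_with_attn_impl(model_id="replit/replit-code-v1-3b",
                        attn_impl="triton",
                        device="cuda:0"):
    """Hypothetical helper mirroring the new side of the diff."""
    # Imported lazily so that defining the helper needs neither
    # torch nor transformers to be installed.
    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    # Load the config first so attn_impl can be overridden before
    # the model is instantiated.
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    override_attn_impl(config, attn_impl)

    # Pass the modified config explicitly instead of an attn_impl kwarg.
    model = AutoModelForCausalLM.from_pretrained(
        model_id, config=config, trust_remote_code=True
    )
    model.to(device=device, dtype=torch.bfloat16)
    return model
```

Calling `load_with_attn_impl()` with its defaults should reproduce the steps on the new side of the diff; it needs a CUDA device and the Triton dependency pinned earlier in the README.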