abhi-mosaic commited on
Commit
2f88b1b
1 Parent(s): 40e5047

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -5
README.md CHANGED
@@ -39,14 +39,16 @@ It includes options for many training efficiency features such as [FlashAttentio
39
 
40
  ```python
41
  import transformers
42
- model = transformers.AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b-storywriter', trust_remote_code=True, torch_dtype=torch.bfloat16)
43
  ```
44
 
45
- To use the optimized triton implementation of FlashAttention, you can load with `attn_impl='triton'` and move the model to `bfloat16` like so:
46
-
47
  ```python
48
- model = transformers.AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b-storywriter', trust_remote_code=True, torch_dtype=torch.bfloat16, attn_impl='triton')
49
- model.to(device='cuda:0', dtype=torch.bfloat16)
 
 
 
50
  ```
51
 
52
  Although the model was trained with a sequence length of 2048 and finetuned with a sequence length of 65536,
 
39
 
40
  ```python
41
  import transformers
42
+ model = transformers.AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b-storywriter', trust_remote_code=True)
43
  ```
44
 
45
+ To use the optimized [triton implementation](https://github.com/openai/triton) of FlashAttention, you can load the model with `attn_impl='triton'` and move the model to `bfloat16`:
 
46
  ```python
47
+ config = transformers.AutoConfig.from_pretrained('mosaicml/mpt-7b-storywriter', trust_remote_code=True)
48
+ config.attn_config['attn_impl'] = 'triton'
49
+
50
+ model = transformers.AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b-storywriter', config=config, torch_dtype=torch.bfloat16, trust_remote_code=True)
51
+ model.to(device='cuda:0')
52
  ```
53
 
54
  Although the model was trained with a sequence length of 2048 and finetuned with a sequence length of 65536,