OpenNLPLab committed: Add bf16
README.md CHANGED
@@ -112,12 +112,12 @@ export use_triton=False
 ```python
 >>> from transformers import AutoModelForCausalLM, AutoTokenizer
 >>> tokenizer = AutoTokenizer.from_pretrained("OpenNLPLab/TransNormerLLM2-7B-300B", trust_remote_code=True)
->>> model = AutoModelForCausalLM.from_pretrained("TransNormerLLM2-7B-300B", device_map="auto", trust_remote_code=True)
+>>> model = AutoModelForCausalLM.from_pretrained("TransNormerLLM2-7B-300B", torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
 >>> inputs = tokenizer('今天是美好的一天', return_tensors='pt')
 >>> pred = model.generate(**inputs, max_new_tokens=8192, repetition_penalty=1.0)
 >>> print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
 ```
-
+* **Note**: we recommend using `bfloat16` with `TransNormerLLM`; `float16` might lead to `nan` errors, so please check your device compatibility!
 
 # Fine-tuning the Model
 
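For quick copy-paste, here is a self-contained version of the updated snippet. This is a minimal sketch, not the README verbatim: it assumes the model is pulled from the Hub under the full id `OpenNLPLab/TransNormerLLM2-7B-300B` (the diff's model line omits the `OpenNLPLab/` prefix used on the tokenizer line), and it adds the `import torch` that `torch.bfloat16` requires but the README snippet does not show.

```python
# Minimal sketch of the bf16 loading path introduced by this commit.
# Assumption: the full Hub id "OpenNLPLab/TransNormerLLM2-7B-300B" is the
# intended checkpoint; the diff's model line uses only the short name.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "OpenNLPLab/TransNormerLLM2-7B-300B", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "OpenNLPLab/TransNormerLLM2-7B-300B",
    torch_dtype=torch.bfloat16,  # bf16 recommended; fp16 may produce nan
    device_map="auto",
    trust_remote_code=True,
)

# Move inputs to the device of the model's first parameters; with
# device_map="auto" the weights may be spread across GPUs.
inputs = tokenizer("今天是美好的一天", return_tensors="pt").to(model.device)
pred = model.generate(**inputs, max_new_tokens=8192, repetition_penalty=1.0)
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
```

On the design choice: `bfloat16` keeps `float32`'s exponent range, which is why it avoids the overflow-to-`nan` failures the note warns about with `float16`. On CUDA, you can verify hardware support before loading with `torch.cuda.is_bf16_supported()`.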