Files changed (1)
  1. README.md +27 -0
README.md CHANGED
@@ -191,11 +191,14 @@ print(tokenizer.decode(outputs[0]))
 
 #### FP16
 
+ The original model was trained in `bfloat16`, therefore running such a large model in `float16` can lead to drastically reduced performance. We advise users to run this model in `bfloat16` or `float32` if they have enough compute resources. Check the next section on how to run the model in `bfloat16`.
+
 <details>
 <summary> Click to expand </summary>
 
 ```python
 # pip install accelerate
+ # Not recommended - we advise users to run this model in `bfloat16`
 import torch
 from transformers import T5Tokenizer, T5ForConditionalGeneration
 
@@ -211,8 +214,32 @@ print(tokenizer.decode(outputs[0]))
 
 </details>
 
+ #### BFLOAT16
+
+ <details>
+ <summary> Click to expand </summary>
+
+ ```python
+ # pip install accelerate
+ import torch
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+ tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
+ model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xxl", device_map="auto", torch_dtype=torch.bfloat16)
+
+ input_text = "translate English to German: How old are you?"
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
+
+ outputs = model.generate(input_ids)
+ print(tokenizer.decode(outputs[0]))
+ ```
+
+ </details>
+
 #### INT8
 
+ The original model was trained in `bfloat16`, therefore running such a large model in `int8` (the underlying technique behind `int8` quantization first casts the weights to `float16`) can lead to drastically reduced performance. We advise users to run this model in `bfloat16` or `float32` if they have enough compute resources. Check the previous section on how to run the model in `bfloat16`.
+
 <details>
 <summary> Click to expand </summary>
 
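
The INT8 code block itself falls outside the lines shown in this diff. For context only, here is a minimal sketch of how 8-bit loading is typically done with `transformers` and `bitsandbytes`; the exact contents of the card's INT8 block are an assumption here, not part of this change.

```python
# pip install bitsandbytes accelerate
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
# load_in_8bit quantizes the weights to 8-bit via bitsandbytes;
# as the note above warns, expect reduced quality versus bfloat16/float32.
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xxl", device_map="auto", load_in_8bit=True
)

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))
```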