alexkueck committed on
Commit
6f8b2c2
1 Parent(s): afbdee4

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +4 -4
utils.py CHANGED
@@ -10,6 +10,7 @@ import requests
10
  import re
11
  import html
12
  import torch
 
13
  import sys
14
  import gc
15
  from pygments.lexers import guess_lexer, ClassNotFound
@@ -18,7 +19,7 @@ from pygments import highlight
18
  from pygments.lexers import guess_lexer,get_lexer_by_name
19
  from pygments.formatters import HtmlFormatter
20
  import transformers
21
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
22
  import datasets
23
  from datasets import load_dataset
24
  import evaluate
@@ -335,12 +336,11 @@ def daten_laden(name):
335
 
336
 
337
  #Quantisation - to speed up training
338
- def bnb_config (load4Bit, double_quant)
339
- compute_dtype = getattr(torch, "float16")
340
  bnb_config = BitsAndBytesConfig(
341
  load_in_4bit= load4Bit,
342
  bnb_4bit_quant_type="nf4",
343
- bnb_4bit_compute_dtype=compute_dtype,
344
  bnb_4bit_use_double_quant=double_quant,
345
  )
346
  return bnb_config
 
10
  import re
11
  import html
12
  import torch
13
+ from torch import cuda, bfloat16
14
  import sys
15
  import gc
16
  from pygments.lexers import guess_lexer, ClassNotFound
 
19
  from pygments.lexers import guess_lexer,get_lexer_by_name
20
  from pygments.formatters import HtmlFormatter
21
  import transformers
22
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
23
  import datasets
24
  from datasets import load_dataset
25
  import evaluate
 
336
 
337
 
338
  #Quantisation - to speed up training
339
def bnb_config(load4Bit, double_quant):
    """Build a ``BitsAndBytesConfig`` quantisation configuration.

    Args:
        load4Bit: Forwarded unchanged as ``load_in_4bit``.
        double_quant: Forwarded unchanged as ``bnb_4bit_use_double_quant``.

    Returns:
        A ``BitsAndBytesConfig`` whose quant type is ``"nf4"`` and whose
        compute dtype is ``bfloat16``.
    """
    # Return the config directly instead of binding it to a local that
    # shares (and shadows) this function's own name.
    return BitsAndBytesConfig(
        load_in_4bit=load4Bit,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=bfloat16,
        bnb_4bit_use_double_quant=double_quant,
    )