Andreas99 committed on
Commit
b093dc1
·
verified ·
1 Parent(s): 6aea200

Update src/train.py

Browse files
Files changed (1) hide show
  1. src/train.py +8 -2
src/train.py CHANGED
@@ -6,7 +6,7 @@ import networkx as nx
6
  from tqdm import tqdm
7
  from peft import (LoraConfig, get_peft_model,
8
  prepare_model_for_kbit_training)
9
- from transformers import AutoModelForCausalLM, AutoTokenizer
10
 
11
 
12
 
@@ -30,11 +30,17 @@ class QloraTrainer_CS:
30
  model_id = self.config['inference']["base_model"]
31
  print(model_id)
32
 
 
 
 
 
 
 
33
  tokenizer = AutoTokenizer.from_pretrained(model_id)
34
  tokenizer.model_max_length = self.config['training']['tokenizer']["max_length"]
35
  if not tokenizer.pad_token:
36
  tokenizer.pad_token = tokenizer.eos_token
37
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
38
  if model.device.type != 'cuda':
39
  model.to('cuda')
40
 
 
6
  from tqdm import tqdm
7
  from peft import (LoraConfig, get_peft_model,
8
  prepare_model_for_kbit_training)
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
10
 
11
 
12
 
 
30
  model_id = self.config['inference']["base_model"]
31
  print(model_id)
32
 
33
+ bnb_config = BitsAndBytesConfig(
34
+ load_in_8bit=True,
35
+ bnb_8bit_use_double_quant=True,
36
+ bnb_8bit_quant_type="nf8",
37
+ bnb_8bit_compute_dtype=torch.bfloat16
38
+ )
39
  tokenizer = AutoTokenizer.from_pretrained(model_id)
40
  tokenizer.model_max_length = self.config['training']['tokenizer']["max_length"]
41
  if not tokenizer.pad_token:
42
  tokenizer.pad_token = tokenizer.eos_token
43
+ model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, torch_dtype=torch.bfloat16)
44
  if model.device.type != 'cuda':
45
  model.to('cuda')
46