alexkueck committed
Commit 7f937d0 · 1 Parent(s): 4128faf

Update utils.py

Files changed (1)
  1. utils.py +6 -4
utils.py CHANGED
@@ -59,13 +59,14 @@ def load_tokenizer_and_model(base_model, load_8bit=False):
     else:
         device = "cpu"
 
-    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast = True)
+    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast = True, use_auth_token=True)
     if device == "cuda":
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=load_8bit,
             torch_dtype=torch.float16,
-            device_map="auto"
+            device_map="auto",
+            use_auth_token=True
         )
     else:
         model = AutoModelForCausalLM.from_pretrained(
@@ -90,11 +91,12 @@ def load_model(base_model, load_8bit=False):
             base_model,
             load_in_8bit=load_8bit,
             torch_dtype=torch.float16,
-            device_map="auto"
+            device_map="auto",
+            use_auth_token=True
         )
     else:
         model = AutoModelForCausalLM.from_pretrained(
-            base_model, device_map={"": device}, low_cpu_mem_usage=True
+            base_model, device_map={"": device}, low_cpu_mem_usage=True, use_auth_token=True
        )
 
     #if not load_8bit:
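
The change threads `use_auth_token=True` through every `from_pretrained` call, so the locally cached Hugging Face Hub token is sent with each download request; this is needed when `base_model` points at a gated or private repository. A minimal usage sketch follows, assuming a prior `huggingface-cli login` and an illustrative gated model id that is not taken from this commit:

```python
# Minimal sketch of loading a gated checkpoint with an auth token.
# Assumptions: `huggingface-cli login` has been run, and the model id
# below is a hypothetical gated repo standing in for `base_model`.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = "meta-llama/Llama-2-7b-hf"  # illustrative gated repo

tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    use_fast=True,
    use_auth_token=True,  # send the cached Hub token for gated/private repos
)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype=torch.float16,
    device_map="auto",    # automatic device placement; requires `accelerate`
    use_auth_token=True,
)
```

Note that recent transformers releases deprecate `use_auth_token` in favor of `token`; the flag used in this commit matches the transformers version the repo was written against.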