aatu18 commited on
Commit
5974cda
·
verified ·
1 Parent(s): cfd99e7

test device_map auto

Browse files
Files changed (1) hide show
  1. methods/utilities.py +5 -4
methods/utilities.py CHANGED
@@ -30,10 +30,11 @@ def load_llama_llm(AUTH_TOKEN):
30
  token=AUTH_TOKEN
31
  )
32
  model = AutoModelForCausalLM.from_pretrained(
33
- model_id,
34
- #torch_dtype=torch.float16,
35
- trust_remote_code=True,
36
- token=AUTH_TOKEN
 
37
  )
38
  model = model.to('cuda')
39
  model = model.eval()
 
30
  token=AUTH_TOKEN
31
  )
32
  model = AutoModelForCausalLM.from_pretrained(
33
+ model_id,
34
+ torch_dtype=torch.float16,
35
+ device_map='auto',
36
+ trust_remote_code=True,
37
+ token=AUTH_TOKEN
38
  )
39
  model = model.to('cuda')
40
  model = model.eval()