masanorihirano
committed on
Commit
•
da3b30c
1
Parent(s):
250b795
fix typo
Browse files
app.py
CHANGED
@@ -42,7 +42,7 @@ def load_lora_model(
|
|
42 |
model = LlamaForCausalLM.from_pretrained(
|
43 |
model_path,
|
44 |
load_in_8bit=load_8bit,
|
45 |
-
device_map="auto
|
46 |
max_memory=max_gpu_memory,
|
47 |
torch_dtype=torch.float16,
|
48 |
)
|
|
|
42 |
model = LlamaForCausalLM.from_pretrained(
|
43 |
model_path,
|
44 |
load_in_8bit=load_8bit,
|
45 |
+
device_map="auto" if device == "cuda" else {"": device},
|
46 |
max_memory=max_gpu_memory,
|
47 |
torch_dtype=torch.float16,
|
48 |
)
|