masanorihirano committed on
Commit
f4d4880
1 Parent(s): da3b30c
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -43,7 +43,7 @@ def load_lora_model(
43
  model_path,
44
  load_in_8bit=load_8bit,
45
  device_map="auto" if device == "cuda" else {"": device},
46
- max_memory=max_gpu_memory,
47
  torch_dtype=torch.float16,
48
  )
49
  if lora_weight is not None:
 
43
  model_path,
44
  load_in_8bit=load_8bit,
45
  device_map="auto" if device == "cuda" else {"": device},
46
+ max_memory={i: max_gpu_memory for i in range(num_gpus)},
47
  torch_dtype=torch.float16,
48
  )
49
  if lora_weight is not None: