lorocksUMD committed on
Commit
a433de4
1 Parent(s): 77ceba1

Update llava/model/builder.py

Browse files
Files changed (1) hide show
  1. llava/model/builder.py +13 -12
llava/model/builder.py CHANGED
@@ -29,18 +29,19 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
29
  # if device != "cuda":
30
  # kwargs['device_map'] = {"": device}
31
 
32
- # if load_8bit:
33
- # kwargs['load_in_8bit'] = True
34
- # elif load_4bit:
35
- # kwargs['load_in_4bit'] = True
36
- # kwargs['quantization_config'] = BitsAndBytesConfig(
37
- # load_in_4bit=True,
38
- # bnb_4bit_compute_dtype=torch.float16,
39
- # bnb_4bit_use_double_quant=True,
40
- # bnb_4bit_quant_type='nf4'
41
- # )
42
- # else:
43
- # kwargs['torch_dtype'] = torch.float16
 
44
 
45
  if use_flash_attn:
46
  kwargs['attn_implementation'] = 'flash_attention_2'
 
29
  # if device != "cuda":
30
  # kwargs['device_map'] = {"": device}
31
 
32
+ load_8bit = True
33
+ if load_8bit:
34
+ kwargs['load_in_8bit'] = True
35
+ elif load_4bit:
36
+ kwargs['load_in_4bit'] = True
37
+ kwargs['quantization_config'] = BitsAndBytesConfig(
38
+ load_in_4bit=True,
39
+ bnb_4bit_compute_dtype=torch.float16,
40
+ bnb_4bit_use_double_quant=True,
41
+ bnb_4bit_quant_type='nf4'
42
+ )
43
+ else:
44
+ kwargs['torch_dtype'] = torch.float16
45
 
46
  if use_flash_attn:
47
  kwargs['attn_implementation'] = 'flash_attention_2'