Li Zhaoxu committed on
Commit c295b91
1 Parent(s): 7a9157d
app.py CHANGED
@@ -58,11 +58,11 @@ cfg = Config(parse_args())
 
 model_config = cfg.model_cfg
 model_cls = registry.get_model_class(model_config.arch)
-model = model_cls.from_config(model_config).to('cpu')
+model = model_cls.from_config(model_config).to('cuda:0')
 
 vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
 vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
-chat = Chat(model, vis_processor,device='cpu')
+chat = Chat(model, vis_processor)
 print('Initialization Finished')
 
 # ========================================
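The net effect of this hunk is that the demo now initializes on the first GPU instead of the CPU, and the explicit device='cpu' argument to Chat is dropped, presumably letting Chat fall back to its own CUDA default. A more portable variant, as a sketch only (assuming torch is imported and Chat still accepts the device keyword seen on the removed line), would guard against GPU-less machines:

import torch

# Use the first GPU when one is present; otherwise stay on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = model_cls.from_config(model_config).to(device)
chat = Chat(model, vis_processor, device=device)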
minigpt4/configs/models/minigpt4_vicuna0.yaml CHANGED
@@ -5,7 +5,7 @@ model:
   image_size: 224
   drop_path_rate: 0
   use_grad_checkpoint: False
-  vit_precision: "fp32"
+  vit_precision: "fp16"
   freeze_vit: True
   freeze_qformer: True
 
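vit_precision controls the dtype the frozen ViT visual encoder is loaded in; switching it from "fp32" to "fp16" roughly halves the encoder's weight memory (2 bytes per parameter instead of 4) at a small cost in numerical precision, which is usually acceptable for a frozen, inference-only encoder. As a rough sketch of how such a flag typically maps onto a torch dtype (vit_dtype and visual_encoder here are illustrative names, not the repo's actual code):

import torch

# Illustrative mapping from the config string to a torch dtype;
# the real MiniGPT-4 loader may apply "fp16" differently.
def vit_dtype(vit_precision: str) -> torch.dtype:
    return torch.float16 if vit_precision == "fp16" else torch.float32

# e.g. visual_encoder = visual_encoder.to(vit_dtype("fp16"))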
minigpt4/models/base_model.py CHANGED
@@ -178,14 +178,14 @@ class BaseModel(nn.Module):
         if low_resource:
             llama_model = PhiForCausalLM.from_pretrained(
                 llama_model_path,
-                torch_dtype=torch.float32,
+                torch_dtype=torch.float16,
                 load_in_8bit=True,
                 device_map={'': low_res_device}
             )
         else:
             llama_model = PhiForCausalLM.from_pretrained(
                 llama_model_path,
-                torch_dtype=torch.float32,
+                torch_dtype=torch.float16,
             )
 
         if lora_r > 0:
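Both branches now load the Phi language model in half precision. In the low_resource branch, load_in_8bit=True quantizes the linear-layer weights through bitsandbytes while torch_dtype sets the dtype of the remaining modules; in the full branch, fp16 alone cuts weight memory in half (for a ~2.7B-parameter checkpoint like Phi-2, roughly 5.4 GB instead of 10.8 GB). A standalone sketch of the non-low_resource load, with an assumed placeholder checkpoint:

import torch
from transformers import PhiForCausalLM

# Half-precision load, mirroring the non-low_resource branch.
# "microsoft/phi-2" is an assumed placeholder; the repo passes llama_model_path.
llama_model = PhiForCausalLM.from_pretrained(
    "microsoft/phi-2",
    torch_dtype=torch.float16,
)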