wondervictor committed
Commit 35a87cf · verified · 1 Parent(s): 9c44f4a

Update app.py

Files changed (1)
  1. app.py +4 -4
app.py CHANGED
@@ -85,20 +85,20 @@ def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
 
     # Initialize SAM2
     if sam2_model is None:
-        sam2_model = build_sam2(model_cfg, sam_path, device="cpu", apply_postprocessing=False)
+        sam2_model = build_sam2(model_cfg, sam_path, device="cuda", apply_postprocessing=False)
         print("SAM2 model initialized.")
 
     # Initialize the CLIP model
     if clip_model is None:
         clip_model, _, _ = open_clip.create_model_and_transforms("convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup")
-        clip_model = clip_model.to("cpu")
+        clip_model = clip_model.to("cuda")
         print("CLIP model initialized.")
 
     # Initialize the Mask Adapter model
     if mask_adapter is None:
-        mask_adapter = build_mask_adapter(cfg, "MASKAdapterHead").cpu()
+        mask_adapter = build_mask_adapter(cfg, "MASKAdapterHead").cuda()
         # Load the adapter state dict
-        adapter_state_dict = torch.load(adapter_pth, map_location=torch.device('cpu'))
+        adapter_state_dict = torch.load(adapter_pth)
         mask_adapter.load_state_dict(adapter_state_dict)
         print("Mask Adapter model initialized.")
 
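For context, a device-agnostic sketch of the initialization this hunk touches is shown below. It is an assumption layered on the commit, not part of it: the commit hard-codes "cuda", while the sketch falls back to CPU when no GPU is available. The build_mask_adapter helper and the module-level caching globals are assumed to come from app.py's existing imports and setup.

import torch
import open_clip
from sam2.build_sam import build_sam2
# build_mask_adapter is assumed to be provided by app.py's existing imports.

sam2_model, clip_model, mask_adapter = None, None, None  # module-level cache, as in app.py

def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
    global sam2_model, clip_model, mask_adapter

    # Prefer the GPU when one is available, otherwise stay on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # SAM2: build_sam2 places the model on the requested device.
    if sam2_model is None:
        sam2_model = build_sam2(model_cfg, sam_path, device=device, apply_postprocessing=False)
        print("SAM2 model initialized.")

    # CLIP: create the model, then move it to the chosen device.
    if clip_model is None:
        clip_model, _, _ = open_clip.create_model_and_transforms(
            "convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup")
        clip_model = clip_model.to(device)
        print("CLIP model initialized.")

    # Mask Adapter: move the module, then load weights onto the same device.
    if mask_adapter is None:
        mask_adapter = build_mask_adapter(cfg, "MASKAdapterHead").to(device)
        # map_location keeps the checkpoint loadable on CPU-only machines.
        adapter_state_dict = torch.load(adapter_pth, map_location=device)
        mask_adapter.load_state_dict(adapter_state_dict)
        print("Mask Adapter model initialized.")

Dropping map_location entirely, as the commit does, relies on torch.load restoring tensors to whatever devices the checkpoint recorded, which is fine on a GPU host; passing map_location=device is the safer choice if the app may also run on CPU.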