Peijie committed on
Commit
401a164
1 Parent(s): e1c03cb

load the model to cpu first

Browse files
Files changed (2) hide show
  1. app.py +3 -8
  2. utils/predict.py +10 -4
app.py CHANGED
@@ -15,14 +15,9 @@ from utils.predict import xclip_pred
15
 
16
 
17
  #! Hugging Face does not allow loading the model in the main process, so we need to load the model when needed; this may not help improve the speed of the app.
18
- try:
19
- import spaces
20
- XCLIP, OWLVIT_PRECESSOR = None, None
21
- DEVICE = 'cuda'
22
- except:
23
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
- print(f"Not at Huggingface demo, load model to main process.")
25
- XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
26
 
27
  print(f"Device: {DEVICE}")
28
 
 
15
 
16
 
17
  #! Hugging Face does not allow loading the model in the main process, so we need to load the model when needed; this may not help improve the speed of the app.
18
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
+ print(f"Not at Huggingface demo, load model to main process.")
20
+ XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
 
 
 
 
 
21
 
22
  print(f"Device: {DEVICE}")
23
 
utils/predict.py CHANGED
@@ -53,8 +53,12 @@ def xclip_pred(new_desc: dict,
53
  cub_embeds: torch.Tensor = None,
54
  cub_idx2name: dict = None,
55
  descriptors: dict = None):
56
- if model is None or owlvit_processor is None:
57
- model, owlvit_processor = load_xclip(device=device)
 
 
 
 
58
 
59
  # reorder the new description and the mask
60
  if new_class is not None:
@@ -78,7 +82,9 @@ def xclip_pred(new_desc: dict,
78
  n_classes = len(getprompt.name2idx)
79
  descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('chatgpt-no-template', max_len=12, pad=True)
80
  query_embeds = encode_descs_xclip(owlvit_processor, model, descs, device)
 
81
  else:
 
82
  if new_class is not None:
83
  if new_class in list(cub_idx2name.values()):
84
  new_class = f"{new_class}_custom"
@@ -87,11 +93,11 @@ def xclip_pred(new_desc: dict,
87
  n_classes = 201
88
  query_tokens = owlvit_processor(text=list(new_desc_.values()), padding="max_length", truncation=True, return_tensors="pt").to(device)
89
  new_class_embed = model.owlvit.get_text_features(**query_tokens)
90
- query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0).to(device)
91
  modified_class_idx = 200
92
  else:
93
  n_classes = 200
94
- query_embeds = cub_embeds.to(device)
95
  idx2name = cub_idx2name
96
  modified_class_idx = None
97
 
 
53
  cub_embeds: torch.Tensor = None,
54
  cub_idx2name: dict = None,
55
  descriptors: dict = None):
56
+ # check whether we are in a Hugging Face Space by trying to move the model to CUDA
57
+ try:
58
+ model.to('cuda')
59
+ device = 'cuda'
60
+ except:
61
+ device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
62
 
63
  # reorder the new description and the mask
64
  if new_class is not None:
 
82
  n_classes = len(getprompt.name2idx)
83
  descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('chatgpt-no-template', max_len=12, pad=True)
84
  query_embeds = encode_descs_xclip(owlvit_processor, model, descs, device)
85
+
86
  else:
87
+ cub_embeds = cub_embeds.to(device)
88
  if new_class is not None:
89
  if new_class in list(cub_idx2name.values()):
90
  new_class = f"{new_class}_custom"
 
93
  n_classes = 201
94
  query_tokens = owlvit_processor(text=list(new_desc_.values()), padding="max_length", truncation=True, return_tensors="pt").to(device)
95
  new_class_embed = model.owlvit.get_text_features(**query_tokens)
96
+ query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0)
97
  modified_class_idx = 200
98
  else:
99
  n_classes = 200
100
+ query_embeds = cub_embeds
101
  idx2name = cub_idx2name
102
  modified_class_idx = None
103