Peijie commited on
Commit
e1c03cb
1 Parent(s): f80b905

Assuming GPUZero is always available.

Browse files
Files changed (2) hide show
  1. app.py +7 -8
  2. utils/predict.py +1 -1
app.py CHANGED
@@ -10,29 +10,28 @@ from pathlib import Path
10
  from PIL import Image
11
 
12
  from plots import get_pre_define_colors
13
- # from utils.load_model import load_xclip
14
  from utils.predict import xclip_pred
15
 
16
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
- # def initialize_model():
18
- # global XCLIP, OWLVIT_PRECESSOR
19
- # if XCLIP is None or OWLVIT_PRECESSOR is None:
20
- # XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
21
 
22
  #! Huggingface does not allow load model to main process, so we need to load the model when needed, it may not help in improve the speed of the app.
23
  try:
24
  import spaces
25
  XCLIP, OWLVIT_PRECESSOR = None, None
 
26
  except:
 
27
  print(f"Not at Huggingface demo, load model to main process.")
28
  XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
29
-
 
 
30
  XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
31
  XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
32
  IMAGES_FOLDER = "data/images"
33
  # XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
34
  IMAGE2GT = json.load(open("data/jsons/image2gt.json", 'r'))
35
- CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt').to(DEVICE)
36
  CUB_IDX2NAME = json.load(open('data/jsons/cub_desc_idx2name.json', 'r'))
37
  CUB_IDX2NAME = {int(k): v for k, v in CUB_IDX2NAME.items()}
38
 
 
10
  from PIL import Image
11
 
12
  from plots import get_pre_define_colors
13
+ from utils.load_model import load_xclip
14
  from utils.predict import xclip_pred
15
 
 
 
 
 
 
16
 
17
  #! Huggingface does not allow load model to main process, so we need to load the model when needed, it may not help in improve the speed of the app.
18
  try:
19
  import spaces
20
  XCLIP, OWLVIT_PRECESSOR = None, None
21
+ DEVICE = 'cuda'
22
  except:
23
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
  print(f"Not at Huggingface demo, load model to main process.")
25
  XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
26
+
27
+ print(f"Device: {DEVICE}")
28
+
29
  XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
30
  XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
31
  IMAGES_FOLDER = "data/images"
32
  # XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
33
  IMAGE2GT = json.load(open("data/jsons/image2gt.json", 'r'))
34
+ CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
35
  CUB_IDX2NAME = json.load(open('data/jsons/cub_desc_idx2name.json', 'r'))
36
  CUB_IDX2NAME = {int(k): v for k, v in CUB_IDX2NAME.items()}
37
 
utils/predict.py CHANGED
@@ -91,7 +91,7 @@ def xclip_pred(new_desc: dict,
91
  modified_class_idx = 200
92
  else:
93
  n_classes = 200
94
- query_embeds = cub_embeds
95
  idx2name = cub_idx2name
96
  modified_class_idx = None
97
 
 
91
  modified_class_idx = 200
92
  else:
93
  n_classes = 200
94
+ query_embeds = cub_embeds.to(device)
95
  idx2name = cub_idx2name
96
  modified_class_idx = None
97