Peijie committed
Commit 92ef913
1 Parent(s): a410a68

update to support gradio 4+
Files changed (4)
  1. app.py +24 -7
  2. requirements.txt +1 -1
  3. utils/load_model.py +9 -2
  4. utils/predict.py +10 -3
app.py CHANGED
@@ -1,4 +1,4 @@
-import spaces
+
 import io
 import os
 debug = False
@@ -29,7 +29,7 @@ PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
 IMAGES_FOLDER = "data/images"
 # XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
 IMAGE2GT = json.load(open("data/jsons/image2gt.json", 'r'))
-CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
+CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt').to(DEVICE)
 CUB_IDX2NAME = json.load(open('data/jsons/cub_desc_idx2name.json', 'r'))
 CUB_IDX2NAME = {int(k): v for k, v in CUB_IDX2NAME.items()}
 # correct_predictions = [k for k, v in XCLIP_RESULTS.items() if v['prediction']]
@@ -269,12 +269,20 @@ def update_selected_image(event: gr.SelectData):
     descs = {k: descs[k] for k in ORDERED_PARTS}
     custom_text = [custom_class_name] + list(descs.values())
     descriptions = ";\n".join(custom_text)
-    textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
+    # textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
+    textbox = gr.Textbox(value=descriptions,
+                         lines=12,
+                         visible=True,
+                         label="XCLIP descriptions",
+                         interactive=True,
+                         info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
+                         show_label=False)
     # modified_exp = gr.HTML().update(value="", visible=True)
     return gt_label, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox
 
 def on_edit_button_click_xclip():
-    empty_exp = gr.HTML.update(visible=False)
+    # empty_exp = gr.HTML.update(visible=False)
+    empty_exp = gr.HTML(visible=False)
 
     # Populate the textbox with current descriptions
     descs = XCLIP_DESC[current_predicted_class.state]
@@ -282,7 +290,14 @@ def on_edit_button_click_xclip():
     descs = {k: descs[k] for k in ORDERED_PARTS}
     custom_text = ["class name: custom"] + list(descs.values())
     descriptions = ";\n".join(custom_text)
-    textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
+    # textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
+    textbox = gr.Textbox(value=descriptions,
+                         lines=12,
+                         visible=True,
+                         label="XCLIP descriptions",
+                         interactive=True,
+                         info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
+                         show_label=False)
 
     return textbox, empty_exp
 
@@ -350,10 +365,12 @@ def on_predict_button_click_xclip(textbox_input: str):
     custom_pred_markdown = f"""
     ### <span style='color:{custom_color}'> {custom_label} &nbsp;&nbsp;&nbsp; {custom_pred_score:.4f}</span>
     """
-    textbox = gr.Textbox.update(visible=False)
+    # textbox = gr.Textbox.update(visible=False)
+    textbox = gr.Textbox(visible=False)
     # return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_explanation
 
-    modified_exp = gr.HTML().update(value=modified_explanation, visible=True)
+    # modified_exp = gr.HTML().update(value=modified_explanation, visible=True)
+    modified_exp = gr.HTML(value=modified_explanation, visible=True)
     return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp
 
 
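In Gradio 4 the per-component update() classmethods were removed; an event handler now returns a component instance, and Gradio applies its constructor arguments as updates to the bound output. The rewritten handlers above follow that pattern. A minimal self-contained sketch of the migration (component names here are illustrative, not taken from app.py):

    import gradio as gr

    def show_editor():
        # Gradio 3.x style (removed in 4.x):
        #   return gr.Textbox.update(visible=True, lines=12)
        # Gradio 4+ style: return a component instance; its constructor
        # arguments are applied as updates to the existing component.
        return gr.Textbox(visible=True, lines=12, interactive=True)

    with gr.Blocks() as demo:
        edit_btn = gr.Button("Edit")
        desc_box = gr.Textbox(visible=False)
        edit_btn.click(fn=show_editor, inputs=None, outputs=desc_box)

    demo.launch()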
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 torch
 torchvision
-gradio==3.41.0
+gradio
 numpy
 Pillow
 transformers
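Dropping the pin lets the Space pick up the current Gradio 4.x release, but it also accepts any future major version; a bounded specifier such as gradio>=4.0,<5 (an alternative, not what this commit uses) would express the 4.x requirement more precisely.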
utils/load_model.py CHANGED
@@ -1,12 +1,19 @@
 
 
-import spaces
+try:
+    import spaces
+    gpu_decorator = spaces.GPU
+except ImportError:
+    # Define a no-operation decorator as fallback
+    def gpu_decorator(func):
+        return func
+
 import torch
 from transformers import OwlViTProcessor, OwlViTForObjectDetection
 
 from .model import OwlViTForClassification
 
-@spaces.GPU
+@gpu_decorator
 def load_xclip(device: str = "cuda:0",
                n_classes: int = 183,
                use_teacher_logits: bool = False,
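The try/except keeps the module importable outside Hugging Face Spaces, where the spaces package that provides the ZeroGPU decorator is typically not installed. Because the fallback simply returns the function unchanged, decorated code behaves identically in local runs. A minimal sketch (the warmup function is illustrative):

    try:
        import spaces                  # available on HF Spaces with ZeroGPU
        gpu_decorator = spaces.GPU
    except ImportError:
        def gpu_decorator(func):       # no-op fallback for local execution
            return func

    @gpu_decorator
    def warmup():
        return "ok"

    print(warmup())  # "ok" in both environments; ZeroGPU attaches a GPU first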
utils/predict.py CHANGED
@@ -1,4 +1,11 @@
-import spaces
+try:
+    import spaces
+    gpu_decorator = spaces.GPU
+except ImportError:
+    # Define a no-operation decorator as fallback
+    def gpu_decorator(func):
+        return func
+
 import PIL
 import torch
 
@@ -30,7 +37,7 @@ def encode_descs_xclip(owlvit_det_processor: callable, model: callable, descs: l
     # text_embeds = torch.cat(text_embeds, dim=0)
     # text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)
     # return text_embeds.to(device)
-@spaces.GPU
+@gpu_decorator
 def xclip_pred(new_desc: dict,
                new_part_mask: dict,
                new_class: str,
@@ -76,7 +83,7 @@ def xclip_pred(new_desc: dict,
         n_classes = 201
         query_tokens = owlvit_processor(text=list(new_desc_.values()), padding="max_length", truncation=True, return_tensors="pt").to(device)
         new_class_embed = model.owlvit.get_text_features(**query_tokens)
-        query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0)
+        query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0).to(device)
         modified_class_idx = 200
     else:
         n_classes = 200
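The trailing .to(device) in the last hunk guards against a device mismatch: the concatenated query embeddings are later used against model weights on the GPU, so the result must end up on the model's device. A minimal sketch of the pattern (shapes and names are illustrative, not taken from predict.py):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    cached = torch.randn(200, 512)   # precomputed class text embeddings
    new_one = torch.randn(1, 512)    # freshly encoded custom-class embedding

    # Concatenate, then move the result to the model's device so later
    # matmuls do not raise "expected all tensors to be on the same device".
    query_embeds = torch.cat([cached, new_one], dim=0).to(device)
    print(query_embeds.device, query_embeds.shape)  # e.g. cuda:0, (201, 512)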