Peijie commited on
Commit
395d6df
·
1 Parent(s): a209c09

add text embeddings and change prediction logits.

Browse files
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import io
2
  import os
3
- os.system("pip uninstall -y gradio")
4
- os.system("pip install gradio==3.41.0")
5
 
 
6
  import json
7
  import base64
8
  import random
@@ -24,6 +25,9 @@ XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
24
  PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
25
  IMAGES_FOLDER = "data/images"
26
  XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
 
 
 
27
  # correct_predictions = [k for k, v in XCLIP_RESULTS.items() if v['prediction']]
28
 
29
  # get the intersection of sachit and xclip (revised)
@@ -225,7 +229,18 @@ def update_selected_image(event: gr.SelectData):
225
  gt_class.state = gt_label
226
 
227
  # --- for initial value only ---
228
- out_dict = xclip_pred(new_desc=None, new_part_mask=None, new_class=None, org_desc=XCLIP_DESC_PATH, image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'), model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE, image_name=current_image.state)
 
 
 
 
 
 
 
 
 
 
 
229
  xclip_label = out_dict['pred_class']
230
  clip_pred_scores = out_dict['pred_score']
231
  xclip_part_scores = out_dict['pred_desc_scores']
@@ -298,7 +313,18 @@ def on_predict_button_click_xclip(textbox_input: str):
298
  descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)
299
 
300
  # Get the new predictions and explanations
301
- out_dict = xclip_pred(new_desc=descriptions_dict, new_part_mask=part_mask, new_class=new_class_name, org_desc=XCLIP_DESC_PATH, image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'), model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE, image_name=current_image.state)
 
 
 
 
 
 
 
 
 
 
 
302
  xclip_label = out_dict['pred_class']
303
  xclip_pred_score = out_dict['pred_score']
304
  xclip_part_scores = out_dict['pred_desc_scores']
@@ -403,5 +429,5 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
403
  xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
404
  xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])
405
 
406
- # demo.launch(server_port=5000, share=True)
407
- demo.launch()
 
1
  import io
2
  import os
3
+ # os.system("pip uninstall -y gradio")
4
+ # os.system("pip install gradio==3.41.0")
5
 
6
+ import torch
7
  import json
8
  import base64
9
  import random
 
25
  PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
26
  IMAGES_FOLDER = "data/images"
27
  XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
28
+ CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
29
+ CUB_IDX2NAME = json.load(open('data/jsons/cub_desc_idx2name.json', 'r'))
30
+ CUB_IDX2NAME = {int(k): v for k, v in CUB_IDX2NAME.items()}
31
  # correct_predictions = [k for k, v in XCLIP_RESULTS.items() if v['prediction']]
32
 
33
  # get the intersection of sachit and xclip (revised)
 
229
  gt_class.state = gt_label
230
 
231
  # --- for initial value only ---
232
+ out_dict = xclip_pred(new_desc=None,
233
+ new_part_mask=None,
234
+ new_class=None,
235
+ org_desc=XCLIP_DESC_PATH,
236
+ image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'),
237
+ model=XCLIP,
238
+ owlvit_processor=OWLVIT_PRECESSOR,
239
+ device=DEVICE,
240
+ image_name=current_image.state,
241
+ cub_embeds=CUB_DESC_EMBEDS,
242
+ cub_idx2name=CUB_IDX2NAME,
243
+ descriptors=XCLIP_DESC)
244
  xclip_label = out_dict['pred_class']
245
  clip_pred_scores = out_dict['pred_score']
246
  xclip_part_scores = out_dict['pred_desc_scores']
 
313
  descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)
314
 
315
  # Get the new predictions and explanations
316
+ out_dict = xclip_pred(new_desc=descriptions_dict,
317
+ new_part_mask=part_mask,
318
+ new_class=new_class_name,
319
+ org_desc=XCLIP_DESC_PATH,
320
+ image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'),
321
+ model=XCLIP,
322
+ owlvit_processor=OWLVIT_PRECESSOR,
323
+ device=DEVICE,
324
+ image_name=current_image.state,
325
+ cub_embeds=CUB_DESC_EMBEDS,
326
+ cub_idx2name=CUB_IDX2NAME,
327
+ descriptors=XCLIP_DESC)
328
  xclip_label = out_dict['pred_class']
329
  xclip_pred_score = out_dict['pred_score']
330
  xclip_part_scores = out_dict['pred_desc_scores']
 
429
  xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
430
  xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])
431
 
432
+ demo.launch(server_port=5000, share=True)
433
+ # demo.launch()
data/jsons/cub_desc_idx2name.json ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "Black-footed Albatross",
3
+ "1": "Laysan Albatross",
4
+ "2": "Sooty Albatross",
5
+ "3": "Groove-billed Ani",
6
+ "4": "Crested Auklet",
7
+ "5": "Least Auklet",
8
+ "6": "Parakeet Auklet",
9
+ "7": "Rhinoceros Auklet",
10
+ "8": "Brewer Blackbird",
11
+ "9": "Red-winged Blackbird",
12
+ "10": "Rusty Blackbird",
13
+ "11": "Yellow-headed Blackbird",
14
+ "12": "Bobolink",
15
+ "13": "Indigo Bunting",
16
+ "14": "Lazuli Bunting",
17
+ "15": "Painted Bunting",
18
+ "16": "Cardinal",
19
+ "17": "Spotted Catbird",
20
+ "18": "Gray Catbird",
21
+ "19": "Yellow-breasted Chat",
22
+ "20": "Eastern Towhee",
23
+ "21": "Chuck-will Widow",
24
+ "22": "Brandt Cormorant",
25
+ "23": "Red-faced Cormorant",
26
+ "24": "Pelagic Cormorant",
27
+ "25": "Bronzed Cowbird",
28
+ "26": "Shiny Cowbird",
29
+ "27": "Brown Creeper",
30
+ "28": "American Crow",
31
+ "29": "Fish Crow",
32
+ "30": "Black-billed Cuckoo",
33
+ "31": "Mangrove Cuckoo",
34
+ "32": "Yellow-billed Cuckoo",
35
+ "33": "Gray-crowned-Rosy Finch",
36
+ "34": "Purple Finch",
37
+ "35": "Northern Flicker",
38
+ "36": "Acadian Flycatcher",
39
+ "37": "Great-Crested Flycatcher",
40
+ "38": "Least Flycatcher",
41
+ "39": "Olive-sided Flycatcher",
42
+ "40": "Scissor-tailed Flycatcher",
43
+ "41": "Vermilion Flycatcher",
44
+ "42": "Yellow-bellied Flycatcher",
45
+ "43": "Frigatebird",
46
+ "44": "Northern Fulmar",
47
+ "45": "Gadwall",
48
+ "46": "American Goldfinch",
49
+ "47": "European Goldfinch",
50
+ "48": "Boat-tailed Grackle",
51
+ "49": "Eared Grebe",
52
+ "50": "Horned Grebe",
53
+ "51": "Pied-billed Grebe",
54
+ "52": "Western Grebe",
55
+ "53": "Blue Grosbeak",
56
+ "54": "Evening Grosbeak",
57
+ "55": "Pine Grosbeak",
58
+ "56": "Rose-breasted Grosbeak",
59
+ "57": "Pigeon Guillemot",
60
+ "58": "California Gull",
61
+ "59": "Glaucous-winged Gull",
62
+ "60": "Heermann Gull",
63
+ "61": "Herring Gull",
64
+ "62": "Ivory Gull",
65
+ "63": "Ring-billed Gull",
66
+ "64": "Slaty-backed Gull",
67
+ "65": "Western Gull",
68
+ "66": "Anna Hummingbird",
69
+ "67": "Ruby-throated Hummingbird",
70
+ "68": "Rufous Hummingbird",
71
+ "69": "Green Violetear",
72
+ "70": "Long-tailed Jaeger",
73
+ "71": "Pomarine Jaeger",
74
+ "72": "Blue Jay",
75
+ "73": "Florida Jay",
76
+ "74": "Green Jay",
77
+ "75": "Dark-eyed Junco",
78
+ "76": "Tropical Kingbird",
79
+ "77": "Gray Kingbird",
80
+ "78": "Belted Kingfisher",
81
+ "79": "Green Kingfisher",
82
+ "80": "Pied Kingfisher",
83
+ "81": "Ringed Kingfisher",
84
+ "82": "White-breasted Kingfisher",
85
+ "83": "Red-legged Kittiwake",
86
+ "84": "Horned Lark",
87
+ "85": "Pacific Loon",
88
+ "86": "Mallard",
89
+ "87": "Western Meadowlark",
90
+ "88": "Hooded Merganser",
91
+ "89": "Red-breasted Merganser",
92
+ "90": "Mockingbird",
93
+ "91": "Nighthawk",
94
+ "92": "Clark Nutcracker",
95
+ "93": "White-breasted Nuthatch",
96
+ "94": "Baltimore Oriole",
97
+ "95": "Hooded Oriole",
98
+ "96": "Orchard Oriole",
99
+ "97": "Scott Oriole",
100
+ "98": "Ovenbird",
101
+ "99": "Brown Pelican",
102
+ "100": "White Pelican",
103
+ "101": "Western-Wood Pewee",
104
+ "102": "Sayornis",
105
+ "103": "American Pipit",
106
+ "104": "Whip-poor Will",
107
+ "105": "Horned Puffin",
108
+ "106": "Common Raven",
109
+ "107": "White-necked Raven",
110
+ "108": "American Redstart",
111
+ "109": "Geococcyx",
112
+ "110": "Loggerhead Shrike",
113
+ "111": "Great-Grey Shrike",
114
+ "112": "Baird Sparrow",
115
+ "113": "Black-throated Sparrow",
116
+ "114": "Brewer Sparrow",
117
+ "115": "Chipping Sparrow",
118
+ "116": "Clay-colored Sparrow",
119
+ "117": "House Sparrow",
120
+ "118": "Field Sparrow",
121
+ "119": "Fox Sparrow",
122
+ "120": "Grasshopper Sparrow",
123
+ "121": "Harris Sparrow",
124
+ "122": "Henslow Sparrow",
125
+ "123": "Le-Conte Sparrow",
126
+ "124": "Lincoln Sparrow",
127
+ "125": "Nelson-Sharp-tailed Sparrow",
128
+ "126": "Savannah Sparrow",
129
+ "127": "Seaside Sparrow",
130
+ "128": "Song Sparrow",
131
+ "129": "Tree Sparrow",
132
+ "130": "Vesper Sparrow",
133
+ "131": "White-crowned Sparrow",
134
+ "132": "White-throated Sparrow",
135
+ "133": "Cape-Glossy Starling",
136
+ "134": "Bank Swallow",
137
+ "135": "Barn Swallow",
138
+ "136": "Cliff Swallow",
139
+ "137": "Tree Swallow",
140
+ "138": "Scarlet Tanager",
141
+ "139": "Summer Tanager",
142
+ "140": "Artic Tern",
143
+ "141": "Black Tern",
144
+ "142": "Caspian Tern",
145
+ "143": "Common Tern",
146
+ "144": "Elegant Tern",
147
+ "145": "Least Tern",
148
+ "146": "Green-tailed Towhee",
149
+ "147": "Brown Thrasher",
150
+ "148": "Sage Thrasher",
151
+ "149": "Black-capped Vireo",
152
+ "150": "Blue-headed Vireo",
153
+ "151": "Philadelphia Vireo",
154
+ "152": "Red-eyed Vireo",
155
+ "153": "Warbling Vireo",
156
+ "154": "White-eyed Vireo",
157
+ "155": "Yellow-throated Vireo",
158
+ "156": "Bay-breasted Warbler",
159
+ "157": "Black-and-white Warbler",
160
+ "158": "Black-throated-Blue Warbler",
161
+ "159": "Blue-winged Warbler",
162
+ "160": "Canada Warbler",
163
+ "161": "Cape-May Warbler",
164
+ "162": "Cerulean Warbler",
165
+ "163": "Chestnut-sided Warbler",
166
+ "164": "Golden-winged Warbler",
167
+ "165": "Hooded Warbler",
168
+ "166": "Kentucky Warbler",
169
+ "167": "Magnolia Warbler",
170
+ "168": "Mourning Warbler",
171
+ "169": "Myrtle Warbler",
172
+ "170": "Nashville Warbler",
173
+ "171": "Orange-crowned Warbler",
174
+ "172": "Palm Warbler",
175
+ "173": "Pine Warbler",
176
+ "174": "Prairie Warbler",
177
+ "175": "Prothonotary Warbler",
178
+ "176": "Swainson Warbler",
179
+ "177": "Tennessee Warbler",
180
+ "178": "Wilson Warbler",
181
+ "179": "Worm-eating Warbler",
182
+ "180": "Yellow Warbler",
183
+ "181": "Northern Waterthrush",
184
+ "182": "Louisiana Waterthrush",
185
+ "183": "Bohemian Waxwing",
186
+ "184": "Cedar Waxwing",
187
+ "185": "American-Three-toed Woodpecker",
188
+ "186": "Pileated Woodpecker",
189
+ "187": "Red-bellied Woodpecker",
190
+ "188": "Red-cockaded Woodpecker",
191
+ "189": "Red-headed Woodpecker",
192
+ "190": "Downy Woodpecker",
193
+ "191": "Bewick Wren",
194
+ "192": "Cactus Wren",
195
+ "193": "Carolina Wren",
196
+ "194": "House Wren",
197
+ "195": "Marsh Wren",
198
+ "196": "Rock Wren",
199
+ "197": "Winter Wren",
200
+ "198": "Common Yellowthroat",
201
+ "199": "Forsters Tern"
202
+ }
data/text_embeddings/cub_200_desc.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:059f1ec588c01a202417a09136c17f8026fec533213536b7f70b711ec40b575d
3
+ size 4916405
utils/predict.py CHANGED
@@ -40,7 +40,10 @@ def xclip_pred(new_desc: dict,
40
  device: str,
41
  return_img_embeds: bool = False,
42
  use_precompute_embeddings = True,
43
- image_name: str = None,):
 
 
 
44
  # reorder the new description and the mask
45
  if new_class is not None:
46
  new_desc_ = {k: new_desc[k] for k in ORG_PART_ORDER}
@@ -49,34 +52,49 @@ def xclip_pred(new_desc: dict,
49
  else:
50
  desc_mask = [1] * 12
51
 
52
- # replace the description if the new class is in the description, otherwise add a new class
53
- getprompt = GetPromptList(org_desc)
54
- if new_class not in getprompt.desc and new_class is not None:
55
- getprompt.name2idx[new_class] = len(getprompt.name2idx)
56
- if new_class is not None:
57
- getprompt.desc[new_class] = list(new_desc_.values())
58
-
59
- idx2name = dict(zip(getprompt.name2idx.values(), getprompt.name2idx.keys()))
60
- modified_class_idx = getprompt.name2idx[new_class] if new_class is not None else None
61
-
62
- n_classes = len(getprompt.name2idx)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  model.cls_head.num_classes = n_classes
64
 
65
- descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('chatgpt-no-template', max_len=12, pad=True)
66
- query_embeds = encode_descs_xclip(owlvit_processor, model, descs, device)
67
-
68
  with torch.no_grad():
69
- image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
70
- # image_input['pixel_values'] = image_input['pixel_values'].squeeze(1)
71
 
72
  part_embeds = owlvit_processor(text=[ORG_PART_ORDER], return_tensors="pt").to(device)
73
- if return_img_embeds:
74
- feature_map, _ = model.image_embedder(pixel_values = image_input['pixel_values'])
75
  if use_precompute_embeddings:
76
  image_embeds = torch.load(f'data/image_embeddings/{image_name}.pt').to(device)
77
- pred_logits, part_logits, output_dict = model(image_embeds, part_embeds, query_embeds, None)
78
  else:
79
- pred_logits, part_logits, output_dict = model(image_input, part_embeds, query_embeds, None)
 
 
80
 
81
  b, c, n = part_logits.shape
82
  mask = torch.tensor(desc_mask, dtype=float).unsqueeze(0).unsqueeze(0).repeat(b, c, 1).to(device)
@@ -100,18 +118,17 @@ def xclip_pred(new_desc: dict,
100
  else:
101
  modified_score = None
102
  modified_part_scores_dict = None
103
- modified_part_scores_dict = None
104
 
105
  output_dict = {"pred_class": pred_class_name,
106
  "pred_score": softmax_score_top1,
107
  "pred_desc_scores": part_scores_dict,
108
- "descriptions": getprompt.desc[pred_class_name],
109
  "modified_class": new_class,
110
  "modified_score": modified_score,
111
  "modified_desc_scores": modified_part_scores_dict,
112
- "modified_descriptions": getprompt.desc[new_class] if new_class is not None else None,
113
  }
114
- return output_dict if not return_img_embeds else (output_dict, feature_map)
115
 
116
 
117
  # def sachit_pred(new_desc: list,
 
40
  device: str,
41
  return_img_embeds: bool = False,
42
  use_precompute_embeddings = True,
43
+ image_name: str = None,
44
+ cub_embeds: torch.Tensor = None,
45
+ cub_idx2name: dict = None,
46
+ descriptors: dict = None):
47
  # reorder the new description and the mask
48
  if new_class is not None:
49
  new_desc_ = {k: new_desc[k] for k in ORG_PART_ORDER}
 
52
  else:
53
  desc_mask = [1] * 12
54
 
55
+ if cub_embeds is None:
56
+ # replace the description if the new class is in the description, otherwise add a new class
57
+ getprompt = GetPromptList(org_desc)
58
+ if new_class not in getprompt.desc and new_class is not None:
59
+ getprompt.name2idx[new_class] = len(getprompt.name2idx)
60
+ if new_class is not None:
61
+ getprompt.desc[new_class] = list(new_desc_.values())
62
+
63
+ idx2name = dict(zip(getprompt.name2idx.values(), getprompt.name2idx.keys()))
64
+ modified_class_idx = getprompt.name2idx[new_class] if new_class is not None else None
65
+
66
+ n_classes = len(getprompt.name2idx)
67
+ descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('chatgpt-no-template', max_len=12, pad=True)
68
+ query_embeds = encode_descs_xclip(owlvit_processor, model, descs, device)
69
+ else:
70
+ if new_class is not None:
71
+ if new_class in list(cub_idx2name.values()):
72
+ new_class = f"{new_class}_custom"
73
+ idx2name = cub_idx2name | {200: new_class}
74
+ descriptors |= {new_class: list(new_desc_.values())}
75
+ n_classes = 201
76
+ query_tokens = owlvit_processor(text=list(new_desc_.values()), padding="max_length", truncation=True, return_tensors="pt").to(device)
77
+ new_class_embed = model.owlvit.get_text_features(**query_tokens)
78
+ query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0)
79
+ modified_class_idx = 200
80
+ else:
81
+ n_classes = 200
82
+ query_embeds = cub_embeds
83
+ idx2name = cub_idx2name
84
+ modified_class_idx = None
85
+
86
  model.cls_head.num_classes = n_classes
87
 
 
 
 
88
  with torch.no_grad():
 
 
89
 
90
  part_embeds = owlvit_processor(text=[ORG_PART_ORDER], return_tensors="pt").to(device)
 
 
91
  if use_precompute_embeddings:
92
  image_embeds = torch.load(f'data/image_embeddings/{image_name}.pt').to(device)
93
+ image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
94
  else:
95
+ image_embeds, _ = model.image_embedder(pixel_values = image_input['pixel_values'])
96
+
97
+ pred_logits, part_logits, output_dict = model(image_embeds, part_embeds, query_embeds, None)
98
 
99
  b, c, n = part_logits.shape
100
  mask = torch.tensor(desc_mask, dtype=float).unsqueeze(0).unsqueeze(0).repeat(b, c, 1).to(device)
 
118
  else:
119
  modified_score = None
120
  modified_part_scores_dict = None
 
121
 
122
  output_dict = {"pred_class": pred_class_name,
123
  "pred_score": softmax_score_top1,
124
  "pred_desc_scores": part_scores_dict,
125
+ "descriptions": descriptors[pred_class_name],
126
  "modified_class": new_class,
127
  "modified_score": modified_score,
128
  "modified_desc_scores": modified_part_scores_dict,
129
+ "modified_descriptions": descriptors.get(new_class),
130
  }
131
+ return (output_dict, image_embeds) if return_img_embeds else output_dict
132
 
133
 
134
  # def sachit_pred(new_desc: list,