add text embeddings and change prediction logits.
Files changed:
- app.py +32 -6
- data/jsons/cub_desc_idx2name.json +202 -0
- data/text_embeddings/cub_200_desc.pt +3 -0
- utils/predict.py +42 -25
app.py
CHANGED
@@ -1,8 +1,9 @@
 import io
 import os
-os.system("pip uninstall -y gradio")
-os.system("pip install gradio==3.41.0")
+# os.system("pip uninstall -y gradio")
+# os.system("pip install gradio==3.41.0")
 
+import torch
 import json
 import base64
 import random
@@ -24,6 +25,9 @@ XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
 PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
 IMAGES_FOLDER = "data/images"
 XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
+CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
+CUB_IDX2NAME = json.load(open('data/jsons/cub_desc_idx2name.json', 'r'))
+CUB_IDX2NAME = {int(k): v for k, v in CUB_IDX2NAME.items()}
 # correct_predictions = [k for k, v in XCLIP_RESULTS.items() if v['prediction']]
 
 # get the intersection of sachit and xclip (revised)
@@ -225,7 +229,18 @@ def update_selected_image(event: gr.SelectData):
     gt_class.state = gt_label
 
     # --- for initial value only ---
-    out_dict = xclip_pred(new_desc=None,
+    out_dict = xclip_pred(new_desc=None,
+                          new_part_mask=None,
+                          new_class=None,
+                          org_desc=XCLIP_DESC_PATH,
+                          image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'),
+                          model=XCLIP,
+                          owlvit_processor=OWLVIT_PRECESSOR,
+                          device=DEVICE,
+                          image_name=current_image.state,
+                          cub_embeds=CUB_DESC_EMBEDS,
+                          cub_idx2name=CUB_IDX2NAME,
+                          descriptors=XCLIP_DESC)
     xclip_label = out_dict['pred_class']
     clip_pred_scores = out_dict['pred_score']
     xclip_part_scores = out_dict['pred_desc_scores']
@@ -298,7 +313,18 @@ def on_predict_button_click_xclip(textbox_input: str):
     descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)
 
     # Get the new predictions and explanations
-    out_dict = xclip_pred(new_desc=descriptions_dict,
+    out_dict = xclip_pred(new_desc=descriptions_dict,
+                          new_part_mask=part_mask,
+                          new_class=new_class_name,
+                          org_desc=XCLIP_DESC_PATH,
+                          image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'),
+                          model=XCLIP,
+                          owlvit_processor=OWLVIT_PRECESSOR,
+                          device=DEVICE,
+                          image_name=current_image.state,
+                          cub_embeds=CUB_DESC_EMBEDS,
+                          cub_idx2name=CUB_IDX2NAME,
+                          descriptors=XCLIP_DESC)
     xclip_label = out_dict['pred_class']
     xclip_pred_score = out_dict['pred_score']
     xclip_part_scores = out_dict['pred_desc_scores']
@@ -403,5 +429,5 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
     xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
     xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])
 
-
-demo.launch()
+demo.launch(server_port=5000, share=True)
+# demo.launch()
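Note: at startup, app.py now loads the precomputed CUB descriptor embeddings and the index-to-name map once and passes them into xclip_pred, so a prediction no longer re-encodes every class description. A minimal sketch of that startup step, assuming the two data files added in this commit; the commented shape is an assumption about the tensor layout (one row per class-part description), not something the diff states:

import torch

# Loaded once at import time, then passed to xclip_pred(cub_embeds=..., cub_idx2name=...).
CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
print(CUB_DESC_EMBEDS.shape)  # assumed: (200 * 12, embed_dim), one row per (class, part) description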
data/jsons/cub_desc_idx2name.json
ADDED
@@ -0,0 +1,202 @@
+{
+    "0": "Black-footed Albatross",
+    "1": "Laysan Albatross",
+    "2": "Sooty Albatross",
+    "3": "Groove-billed Ani",
+    "4": "Crested Auklet",
+    "5": "Least Auklet",
+    "6": "Parakeet Auklet",
+    "7": "Rhinoceros Auklet",
+    "8": "Brewer Blackbird",
+    "9": "Red-winged Blackbird",
+    "10": "Rusty Blackbird",
+    "11": "Yellow-headed Blackbird",
+    "12": "Bobolink",
+    "13": "Indigo Bunting",
+    "14": "Lazuli Bunting",
+    "15": "Painted Bunting",
+    "16": "Cardinal",
+    "17": "Spotted Catbird",
+    "18": "Gray Catbird",
+    "19": "Yellow-breasted Chat",
+    "20": "Eastern Towhee",
+    "21": "Chuck-will Widow",
+    "22": "Brandt Cormorant",
+    "23": "Red-faced Cormorant",
+    "24": "Pelagic Cormorant",
+    "25": "Bronzed Cowbird",
+    "26": "Shiny Cowbird",
+    "27": "Brown Creeper",
+    "28": "American Crow",
+    "29": "Fish Crow",
+    "30": "Black-billed Cuckoo",
+    "31": "Mangrove Cuckoo",
+    "32": "Yellow-billed Cuckoo",
+    "33": "Gray-crowned-Rosy Finch",
+    "34": "Purple Finch",
+    "35": "Northern Flicker",
+    "36": "Acadian Flycatcher",
+    "37": "Great-Crested Flycatcher",
+    "38": "Least Flycatcher",
+    "39": "Olive-sided Flycatcher",
+    "40": "Scissor-tailed Flycatcher",
+    "41": "Vermilion Flycatcher",
+    "42": "Yellow-bellied Flycatcher",
+    "43": "Frigatebird",
+    "44": "Northern Fulmar",
+    "45": "Gadwall",
+    "46": "American Goldfinch",
+    "47": "European Goldfinch",
+    "48": "Boat-tailed Grackle",
+    "49": "Eared Grebe",
+    "50": "Horned Grebe",
+    "51": "Pied-billed Grebe",
+    "52": "Western Grebe",
+    "53": "Blue Grosbeak",
+    "54": "Evening Grosbeak",
+    "55": "Pine Grosbeak",
+    "56": "Rose-breasted Grosbeak",
+    "57": "Pigeon Guillemot",
+    "58": "California Gull",
+    "59": "Glaucous-winged Gull",
+    "60": "Heermann Gull",
+    "61": "Herring Gull",
+    "62": "Ivory Gull",
+    "63": "Ring-billed Gull",
+    "64": "Slaty-backed Gull",
+    "65": "Western Gull",
+    "66": "Anna Hummingbird",
+    "67": "Ruby-throated Hummingbird",
+    "68": "Rufous Hummingbird",
+    "69": "Green Violetear",
+    "70": "Long-tailed Jaeger",
+    "71": "Pomarine Jaeger",
+    "72": "Blue Jay",
+    "73": "Florida Jay",
+    "74": "Green Jay",
+    "75": "Dark-eyed Junco",
+    "76": "Tropical Kingbird",
+    "77": "Gray Kingbird",
+    "78": "Belted Kingfisher",
+    "79": "Green Kingfisher",
+    "80": "Pied Kingfisher",
+    "81": "Ringed Kingfisher",
+    "82": "White-breasted Kingfisher",
+    "83": "Red-legged Kittiwake",
+    "84": "Horned Lark",
+    "85": "Pacific Loon",
+    "86": "Mallard",
+    "87": "Western Meadowlark",
+    "88": "Hooded Merganser",
+    "89": "Red-breasted Merganser",
+    "90": "Mockingbird",
+    "91": "Nighthawk",
+    "92": "Clark Nutcracker",
+    "93": "White-breasted Nuthatch",
+    "94": "Baltimore Oriole",
+    "95": "Hooded Oriole",
+    "96": "Orchard Oriole",
+    "97": "Scott Oriole",
+    "98": "Ovenbird",
+    "99": "Brown Pelican",
+    "100": "White Pelican",
+    "101": "Western-Wood Pewee",
+    "102": "Sayornis",
+    "103": "American Pipit",
+    "104": "Whip-poor Will",
+    "105": "Horned Puffin",
+    "106": "Common Raven",
+    "107": "White-necked Raven",
+    "108": "American Redstart",
+    "109": "Geococcyx",
+    "110": "Loggerhead Shrike",
+    "111": "Great-Grey Shrike",
+    "112": "Baird Sparrow",
+    "113": "Black-throated Sparrow",
+    "114": "Brewer Sparrow",
+    "115": "Chipping Sparrow",
+    "116": "Clay-colored Sparrow",
+    "117": "House Sparrow",
+    "118": "Field Sparrow",
+    "119": "Fox Sparrow",
+    "120": "Grasshopper Sparrow",
+    "121": "Harris Sparrow",
+    "122": "Henslow Sparrow",
+    "123": "Le-Conte Sparrow",
+    "124": "Lincoln Sparrow",
+    "125": "Nelson-Sharp-tailed Sparrow",
+    "126": "Savannah Sparrow",
+    "127": "Seaside Sparrow",
+    "128": "Song Sparrow",
+    "129": "Tree Sparrow",
+    "130": "Vesper Sparrow",
+    "131": "White-crowned Sparrow",
+    "132": "White-throated Sparrow",
+    "133": "Cape-Glossy Starling",
+    "134": "Bank Swallow",
+    "135": "Barn Swallow",
+    "136": "Cliff Swallow",
+    "137": "Tree Swallow",
+    "138": "Scarlet Tanager",
+    "139": "Summer Tanager",
+    "140": "Artic Tern",
+    "141": "Black Tern",
+    "142": "Caspian Tern",
+    "143": "Common Tern",
+    "144": "Elegant Tern",
+    "145": "Least Tern",
+    "146": "Green-tailed Towhee",
+    "147": "Brown Thrasher",
+    "148": "Sage Thrasher",
+    "149": "Black-capped Vireo",
+    "150": "Blue-headed Vireo",
+    "151": "Philadelphia Vireo",
+    "152": "Red-eyed Vireo",
+    "153": "Warbling Vireo",
+    "154": "White-eyed Vireo",
+    "155": "Yellow-throated Vireo",
+    "156": "Bay-breasted Warbler",
+    "157": "Black-and-white Warbler",
+    "158": "Black-throated-Blue Warbler",
+    "159": "Blue-winged Warbler",
+    "160": "Canada Warbler",
+    "161": "Cape-May Warbler",
+    "162": "Cerulean Warbler",
+    "163": "Chestnut-sided Warbler",
+    "164": "Golden-winged Warbler",
+    "165": "Hooded Warbler",
+    "166": "Kentucky Warbler",
+    "167": "Magnolia Warbler",
+    "168": "Mourning Warbler",
+    "169": "Myrtle Warbler",
+    "170": "Nashville Warbler",
+    "171": "Orange-crowned Warbler",
+    "172": "Palm Warbler",
+    "173": "Pine Warbler",
+    "174": "Prairie Warbler",
+    "175": "Prothonotary Warbler",
+    "176": "Swainson Warbler",
+    "177": "Tennessee Warbler",
+    "178": "Wilson Warbler",
+    "179": "Worm-eating Warbler",
+    "180": "Yellow Warbler",
+    "181": "Northern Waterthrush",
+    "182": "Louisiana Waterthrush",
+    "183": "Bohemian Waxwing",
+    "184": "Cedar Waxwing",
+    "185": "American-Three-toed Woodpecker",
+    "186": "Pileated Woodpecker",
+    "187": "Red-bellied Woodpecker",
+    "188": "Red-cockaded Woodpecker",
+    "189": "Red-headed Woodpecker",
+    "190": "Downy Woodpecker",
+    "191": "Bewick Wren",
+    "192": "Cactus Wren",
+    "193": "Carolina Wren",
+    "194": "House Wren",
+    "195": "Marsh Wren",
+    "196": "Rock Wren",
+    "197": "Winter Wren",
+    "198": "Common Yellowthroat",
+    "199": "Forsters Tern"
+}
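Note: json.load returns string keys, which is why app.py rebuilds this map with int keys before using it. A small usage sketch (the predicted index and custom class name below are hypothetical):

import json

with open('data/jsons/cub_desc_idx2name.json', 'r') as f:
    idx2name = {int(k): v for k, v in json.load(f).items()}

pred_idx = 9                       # hypothetical argmax from the classification head
print(idx2name[pred_idx])          # "Red-winged Blackbird"

# utils/predict.py reserves index 200 for a user-edited custom class:
idx2name = idx2name | {200: "My Custom Bird"}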
data/text_embeddings/cub_200_desc.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:059f1ec588c01a202417a09136c17f8026fec533213536b7f70b711ec40b575d
+size 4916405
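Note: this file is a Git LFS pointer, not the tensor itself; the binary payload is fetched on checkout. Its size is roughly consistent with 200 classes x 12 part descriptors of 512-dim float32 features, though that layout is an assumption, not stated anywhere in the diff:

# Back-of-the-envelope check (assumed layout: 200 * 12 float32 vectors of dim 512).
n_classes, n_parts, dim = 200, 12, 512
print(n_classes * n_parts * dim * 4)  # 4915200 bytes, close to the 4916405 bytes in the pointer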
utils/predict.py
CHANGED
@@ -40,7 +40,10 @@ def xclip_pred(new_desc: dict,
                device: str,
                return_img_embeds: bool = False,
                use_precompute_embeddings = True,
-               image_name: str = None,
+               image_name: str = None,
+               cub_embeds: torch.Tensor = None,
+               cub_idx2name: dict = None,
+               descriptors: dict = None):
     # reorder the new description and the mask
     if new_class is not None:
         new_desc_ = {k: new_desc[k] for k in ORG_PART_ORDER}
@@ -49,34 +52,49 @@
     else:
         desc_mask = [1] * 12
 
-
-
-
-        getprompt.
-
-
-
-
-
-
-
+    if cub_embeds is None:
+        # replace the description if the new class is in the description, otherwise add a new class
+        getprompt = GetPromptList(org_desc)
+        if new_class not in getprompt.desc and new_class is not None:
+            getprompt.name2idx[new_class] = len(getprompt.name2idx)
+        if new_class is not None:
+            getprompt.desc[new_class] = list(new_desc_.values())
+
+        idx2name = dict(zip(getprompt.name2idx.values(), getprompt.name2idx.keys()))
+        modified_class_idx = getprompt.name2idx[new_class] if new_class is not None else None
+
+        n_classes = len(getprompt.name2idx)
+        descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('chatgpt-no-template', max_len=12, pad=True)
+        query_embeds = encode_descs_xclip(owlvit_processor, model, descs, device)
+    else:
+        if new_class is not None:
+            if new_class in list(cub_idx2name.values()):
+                new_class = f"{new_class}_custom"
+            idx2name = cub_idx2name | {200: new_class}
+            descriptors |= {new_class: list(new_desc_.values())}
+            n_classes = 201
+            query_tokens = owlvit_processor(text=list(new_desc_.values()), padding="max_length", truncation=True, return_tensors="pt").to(device)
+            new_class_embed = model.owlvit.get_text_features(**query_tokens)
+            query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0)
+            modified_class_idx = 200
+        else:
+            n_classes = 200
+            query_embeds = cub_embeds
+            idx2name = cub_idx2name
+            modified_class_idx = None
+
     model.cls_head.num_classes = n_classes
 
-    descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('chatgpt-no-template', max_len=12, pad=True)
-    query_embeds = encode_descs_xclip(owlvit_processor, model, descs, device)
-
     with torch.no_grad():
-        image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
-        # image_input['pixel_values'] = image_input['pixel_values'].squeeze(1)
 
         part_embeds = owlvit_processor(text=[ORG_PART_ORDER], return_tensors="pt").to(device)
-        if return_img_embeds:
-            feature_map, _ = model.image_embedder(pixel_values = image_input['pixel_values'])
         if use_precompute_embeddings:
             image_embeds = torch.load(f'data/image_embeddings/{image_name}.pt').to(device)
-
+            image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
         else:
-
+            image_embeds, _ = model.image_embedder(pixel_values = image_input['pixel_values'])
+
+        pred_logits, part_logits, output_dict = model(image_embeds, part_embeds, query_embeds, None)
 
     b, c, n = part_logits.shape
     mask = torch.tensor(desc_mask, dtype=float).unsqueeze(0).unsqueeze(0).repeat(b, c, 1).to(device)
@@ -100,18 +118,17 @@ def xclip_pred(new_desc: dict,
     else:
         modified_score = None
         modified_part_scores_dict = None
-        modified_part_scores_dict = None
 
     output_dict = {"pred_class": pred_class_name,
                    "pred_score": softmax_score_top1,
                    "pred_desc_scores": part_scores_dict,
-                   "descriptions":
+                   "descriptions": descriptors[pred_class_name],
                    "modified_class": new_class,
                    "modified_score": modified_score,
                    "modified_desc_scores": modified_part_scores_dict,
-                   "modified_descriptions":
+                   "modified_descriptions": descriptors.get(new_class),
                    }
-    return output_dict if
+    return (output_dict, image_embeds) if return_img_embeds else output_dict
 
 
 # def sachit_pred(new_desc: list,
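Note: the key change in xclip_pred is the fast path: when cub_embeds is supplied, only the 12 user-edited descriptions are run through the text encoder and concatenated onto the precomputed bank, instead of re-encoding all 200 x 12 CUB descriptions on every request. A minimal standalone sketch of that pattern, with a toy encoder standing in for model.owlvit.get_text_features (the function and variable names here are illustrative, not the app's API):

import torch

def extend_embedding_bank(cached_bank: torch.Tensor, new_descs: list, encode_fn) -> torch.Tensor:
    # Encode only the new class's descriptions and append them to the cached bank.
    new_embeds = encode_fn(new_descs)              # (12, D)
    return torch.cat([cached_bank, new_embeds], dim=0)

def toy_encoder(texts):                            # stand-in for the OWL-ViT text tower
    return torch.randn(len(texts), 512)

cached = torch.randn(200 * 12, 512)                # plays the role of CUB_DESC_EMBEDS
custom_descs = [f"part description {i}" for i in range(12)]
print(extend_embedding_bank(cached, custom_descs, toy_encoder).shape)  # torch.Size([2412, 512])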