Spaces:
Sleeping
Sleeping
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +7 -0
- README.md +3 -0
- app.py +400 -0
- data/image_embeddings/American_Goldfinch_0123_32505.jpg.pt +3 -0
- data/image_embeddings/Black_Tern_0101_144331.jpg.pt +3 -0
- data/image_embeddings/Brandt_Cormorant_0040_23144.jpg.pt +3 -0
- data/image_embeddings/Brown_Thrasher_0014_155421.jpg.pt +3 -0
- data/image_embeddings/Carolina_Wren_0060_186296.jpg.pt +3 -0
- data/image_embeddings/Cedar_Waxwing_0075_179114.jpg.pt +3 -0
- data/image_embeddings/Clark_Nutcracker_0126_85134.jpg.pt +3 -0
- data/image_embeddings/Gray_Catbird_0071_20974.jpg.pt +3 -0
- data/image_embeddings/Heermann_Gull_0097_45783.jpg.pt +3 -0
- data/image_embeddings/House_Wren_0137_187273.jpg.pt +3 -0
- data/image_embeddings/Ivory_Gull_0004_49019.jpg.pt +3 -0
- data/image_embeddings/Northern_Waterthrush_0038_177027.jpg.pt +3 -0
- data/image_embeddings/Pine_Warbler_0113_172456.jpg.pt +3 -0
- data/image_embeddings/Red_Headed_Woodpecker_0032_182815.jpg.pt +3 -0
- data/image_embeddings/Rufous_Hummingbird_0076_59563.jpg.pt +3 -0
- data/image_embeddings/Sage_Thrasher_0062_796462.jpg.pt +3 -0
- data/image_embeddings/Vesper_Sparrow_0030_125663.jpg.pt +3 -0
- data/image_embeddings/Western_Grebe_0064_36613.jpg.pt +3 -0
- data/image_embeddings/White_Eyed_Vireo_0046_158849.jpg.pt +3 -0
- data/image_embeddings/Winter_Wren_0048_189683.jpg.pt +3 -0
- data/images/boxes/American_Goldfinch_0123_32505_all.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_back.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_beak.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_belly.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_breast.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_crown.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_eyes.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_forehead.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_legs.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_nape.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_tail.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_throat.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_visible.jpg +0 -0
- data/images/boxes/American_Goldfinch_0123_32505_wings.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_all.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_back.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_beak.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_belly.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_breast.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_crown.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_eyes.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_forehead.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_legs.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_nape.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_tail.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_throat.jpg +0 -0
- data/images/boxes/Black_Tern_0101_144331_visible.jpg +0 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
temp*
|
2 |
+
|
3 |
+
|
4 |
+
# python temp files
|
5 |
+
__pycache__
|
6 |
+
*.pyc
|
7 |
+
.vscode
|
README.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
---
|
app.py
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
|
4 |
+
import json
|
5 |
+
import base64
|
6 |
+
import random
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import gradio as gr
|
10 |
+
from pathlib import Path
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
from plots import get_pre_define_colors
|
14 |
+
from utils.load_model import load_xclip
|
15 |
+
from utils.predict import xclip_pred
|
16 |
+
|
17 |
+
|
18 |
+
DEVICE = "cpu"
|
19 |
+
XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
|
20 |
+
XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
|
21 |
+
XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
|
22 |
+
PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
|
23 |
+
IMAGES_FOLDER = "data/images"
|
24 |
+
# XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
|
25 |
+
# correct_predictions = [k for k, v in XCLIP_RESULTS.items() if v['prediction']]
|
26 |
+
|
27 |
+
# get the intersection of sachit and xclip (revised)
|
28 |
+
# INTERSECTION = []
|
29 |
+
# IMAGE_RES = 400 * 400 # minimum resolution
|
30 |
+
# TOTAL_SAMPLES = 20
|
31 |
+
# for file_name in XCLIP_RESULTS:
|
32 |
+
# image = Image.open(os.path.join(IMAGES_FOLDER, 'org', file_name)).convert('RGB')
|
33 |
+
# w, h = image.size
|
34 |
+
# if w * h < IMAGE_RES:
|
35 |
+
# continue
|
36 |
+
# else:
|
37 |
+
# INTERSECTION.append(file_name)
|
38 |
+
|
39 |
+
# IMAGE_FILE_LIST = random.sample(INTERSECTION, TOTAL_SAMPLES)
|
40 |
+
IMAGE_FILE_LIST = json.load(open("data/jsons/file_list.json", "r"))
|
41 |
+
# IMAGE_FILE_LIST = IMAGE_FILE_LIST[:19]
|
42 |
+
# IMAGE_FILE_LIST.append('Eastern_Bluebird.jpg')
|
43 |
+
IMAGE_GALLERY = [Image.open(os.path.join(IMAGES_FOLDER, 'org', file_name)).convert('RGB') for file_name in IMAGE_FILE_LIST]
|
44 |
+
|
45 |
+
ORG_PART_ORDER = ['back', 'beak', 'belly', 'breast', 'crown', 'forehead', 'eyes', 'legs', 'wings', 'nape', 'tail', 'throat']
|
46 |
+
ORDERED_PARTS = ['crown', 'forehead', 'nape', 'eyes', 'beak', 'throat', 'breast', 'belly', 'back', 'wings', 'legs', 'tail']
|
47 |
+
COLORS = get_pre_define_colors(12, cmap_set=['Set2', 'tab10'])
|
48 |
+
SACHIT_COLOR = "#ADD8E6"
|
49 |
+
# CUB_BOXES = json.load(open("data/jsons/cub_boxes_owlvit_large.json", "r"))
|
50 |
+
VISIBILITY_DICT = json.load(open("data/jsons/cub_vis_dict_binary.json", 'r'))
|
51 |
+
VISIBILITY_DICT['Eastern_Bluebird.jpg'] = dict(zip(ORDERED_PARTS, [True]*12))
|
52 |
+
|
53 |
+
# --- Image related functions ---
|
54 |
+
def img_to_base64(img):
|
55 |
+
img_pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
|
56 |
+
buffered = io.BytesIO()
|
57 |
+
img_pil.save(buffered, format="JPEG")
|
58 |
+
img_str = base64.b64encode(buffered.getvalue())
|
59 |
+
return img_str.decode()
|
60 |
+
|
61 |
+
def create_blank_image(width=500, height=500, color=(255, 255, 255)):
|
62 |
+
"""Create a blank image of the given size and color."""
|
63 |
+
return np.array(Image.new("RGB", (width, height), color))
|
64 |
+
|
65 |
+
# Convert RGB colors to hex
|
66 |
+
def rgb_to_hex(rgb):
|
67 |
+
return f"#{''.join(f'{x:02x}' for x in rgb)}"
|
68 |
+
|
69 |
+
def load_part_images(file_name: str) -> dict:
|
70 |
+
part_images = {}
|
71 |
+
# start_time = time.time()
|
72 |
+
for part_name in ORDERED_PARTS:
|
73 |
+
base_name = Path(file_name).stem
|
74 |
+
part_image_path = os.path.join(IMAGES_FOLDER, "boxes", f"{base_name}_{part_name}.jpg")
|
75 |
+
if not Path(part_image_path).exists():
|
76 |
+
continue
|
77 |
+
image = np.array(Image.open(part_image_path))
|
78 |
+
part_images[part_name] = img_to_base64(image)
|
79 |
+
# print(f"Time cost to load 12 images: {time.time() - start_time}")
|
80 |
+
# This takes less than 0.01 seconds. So the loading time is not the bottleneck.
|
81 |
+
return part_images
|
82 |
+
|
83 |
+
def generate_xclip_explanations(result_dict:dict, visibility: dict, part_mask: dict = dict(zip(ORDERED_PARTS, [1]*12))):
|
84 |
+
"""
|
85 |
+
The result_dict needs three keys: 'descriptions', 'pred_scores', 'file_name'
|
86 |
+
descriptions: {part_name1: desc_1, part_name2: desc_2, ...}
|
87 |
+
pred_scores: {part_name1: score_1, part_name2: score_2, ...}
|
88 |
+
file_name: str
|
89 |
+
"""
|
90 |
+
|
91 |
+
descriptions = result_dict['descriptions']
|
92 |
+
image_name = result_dict['file_name']
|
93 |
+
part_images = PART_IMAGES_DICT[image_name]
|
94 |
+
MAX_LENGTH = 50
|
95 |
+
exp_length = 400
|
96 |
+
fontsize = 15
|
97 |
+
|
98 |
+
# Start the SVG inside a div
|
99 |
+
svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
|
100 |
+
"<svg width=\"100%\" height=\"100%\">"]
|
101 |
+
|
102 |
+
# Add a row for each visible bird part
|
103 |
+
y_offset = 0
|
104 |
+
for part in ORDERED_PARTS:
|
105 |
+
if visibility[part] and part_mask[part]:
|
106 |
+
# Calculate the length of the bar (scaled to fit within the SVG)
|
107 |
+
part_score = max(result_dict['pred_scores'][part], 0)
|
108 |
+
bar_length = part_score * exp_length
|
109 |
+
|
110 |
+
# Modify the overlay image's opacity on mouseover and mouseout
|
111 |
+
mouseover_action1 = f"document.getElementById('overlayImage').src = 'data:image/jpeg;base64,{part_images[part]}'; document.getElementById('overlayImage').style.opacity = 1;"
|
112 |
+
mouseout_action1 = "document.getElementById('overlayImage').style.opacity = 0;"
|
113 |
+
|
114 |
+
combined_mouseover = f"javascript: {mouseover_action1};"
|
115 |
+
combined_mouseout = f"javascript: {mouseout_action1};"
|
116 |
+
|
117 |
+
# Add the description
|
118 |
+
num_lines = len(descriptions[part]) // MAX_LENGTH + 1
|
119 |
+
for line in range(num_lines):
|
120 |
+
desc_line = descriptions[part][line*MAX_LENGTH:(line+1)*MAX_LENGTH]
|
121 |
+
y_offset += fontsize
|
122 |
+
svg_parts.append(f"""
|
123 |
+
<text x="0" y="{y_offset}" font-size="{fontsize}"
|
124 |
+
onmouseover="{combined_mouseover}"
|
125 |
+
onmouseout="{combined_mouseout}">
|
126 |
+
{desc_line}
|
127 |
+
</text>
|
128 |
+
""")
|
129 |
+
|
130 |
+
# Add the bars
|
131 |
+
svg_parts.append(f"""
|
132 |
+
<rect x="0" y="{y_offset +3}" width="{bar_length}" height="{fontsize*0.7}" fill="{PART_COLORS[part]}"
|
133 |
+
onmouseover="{combined_mouseover}"
|
134 |
+
onmouseout="{combined_mouseout}">
|
135 |
+
</rect>
|
136 |
+
""")
|
137 |
+
# Add the scores
|
138 |
+
svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="{fontsize}" fill="{PART_COLORS[part]}">{part_score:.2f}</text>')
|
139 |
+
|
140 |
+
y_offset += fontsize + 3
|
141 |
+
svg_parts.extend(("</svg>", "</div>"))
|
142 |
+
# Join everything into a single string
|
143 |
+
html = "".join(svg_parts)
|
144 |
+
|
145 |
+
|
146 |
+
return html
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
def generate_sachit_explanations(result_dict:dict):
|
151 |
+
descriptions = result_dict['descriptions']
|
152 |
+
scores = result_dict['scores']
|
153 |
+
MAX_LENGTH = 50
|
154 |
+
exp_length = 400
|
155 |
+
fontsize = 15
|
156 |
+
|
157 |
+
descriptions = zip(scores, descriptions)
|
158 |
+
descriptions = sorted(descriptions, key=lambda x: x[0], reverse=True)
|
159 |
+
|
160 |
+
# Start the SVG inside a div
|
161 |
+
svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
|
162 |
+
"<svg width=\"100%\" height=\"100%\">"]
|
163 |
+
|
164 |
+
# Add a row for each visible bird part
|
165 |
+
y_offset = 0
|
166 |
+
for score, desc in descriptions:
|
167 |
+
|
168 |
+
# Calculate the length of the bar (scaled to fit within the SVG)
|
169 |
+
part_score = max(score, 0)
|
170 |
+
bar_length = part_score * exp_length
|
171 |
+
|
172 |
+
# Split the description into two lines if it's too long
|
173 |
+
num_lines = len(desc) // MAX_LENGTH + 1
|
174 |
+
for line in range(num_lines):
|
175 |
+
desc_line = desc[line*MAX_LENGTH:(line+1)*MAX_LENGTH]
|
176 |
+
y_offset += fontsize
|
177 |
+
svg_parts.append(f"""
|
178 |
+
<text x="0" y="{y_offset}" font-size="{fontsize}" fill="black">
|
179 |
+
{desc_line}
|
180 |
+
</text>
|
181 |
+
""")
|
182 |
+
|
183 |
+
# Add the bar
|
184 |
+
svg_parts.append(f"""
|
185 |
+
<rect x="0" y="{y_offset+3}" width="{bar_length}" height="{fontsize*0.7}" fill="{SACHIT_COLOR}">
|
186 |
+
</rect>
|
187 |
+
""")
|
188 |
+
|
189 |
+
# Add the score
|
190 |
+
svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="fontsize" fill="{SACHIT_COLOR}">{part_score:.2f}</text>') # Added fill color
|
191 |
+
|
192 |
+
y_offset += fontsize + 3
|
193 |
+
|
194 |
+
|
195 |
+
svg_parts.extend(("</svg>", "</div>"))
|
196 |
+
# Join everything into a single string
|
197 |
+
html = "".join(svg_parts)
|
198 |
+
|
199 |
+
|
200 |
+
return html
|
201 |
+
|
202 |
+
# --- Constants created by the functions above ---
|
203 |
+
BLANK_OVERLAY = img_to_base64(create_blank_image())
|
204 |
+
PART_COLORS = {part: rgb_to_hex(COLORS[i]) for i, part in enumerate(ORDERED_PARTS)}
|
205 |
+
blank_image = np.array(Image.open('data/images/final.png').convert('RGB'))
|
206 |
+
PART_IMAGES_DICT = {file_name: load_part_images(file_name) for file_name in IMAGE_FILE_LIST}
|
207 |
+
|
208 |
+
# --- Gradio Functions ---
|
209 |
+
def update_selected_image(event: gr.SelectData):
|
210 |
+
image_height = 400
|
211 |
+
index = event.index
|
212 |
+
|
213 |
+
image_name = IMAGE_FILE_LIST[index]
|
214 |
+
current_image.state = image_name
|
215 |
+
org_image = Image.open(os.path.join(IMAGES_FOLDER, 'org', image_name)).convert('RGB')
|
216 |
+
img_base64 = f"""
|
217 |
+
<div style="position: relative; height: {image_height}px; display: inline-block;">
|
218 |
+
<img id="birdImage" src="data:image/jpeg;base64,{img_to_base64(org_image)}" style="height: {image_height}px; width: auto;">
|
219 |
+
<img id="overlayImage" src="data:image/jpeg;base64,{BLANK_OVERLAY}" style="position:absolute; top:0; left:0; width:auto; height: {image_height}px; opacity: 0;">
|
220 |
+
</div>
|
221 |
+
"""
|
222 |
+
gt_label = XCLIP_RESULTS[image_name]['ground_truth']
|
223 |
+
gt_class.state = gt_label
|
224 |
+
|
225 |
+
# --- for initial value only ---
|
226 |
+
out_dict = xclip_pred(new_desc=None, new_part_mask=None, new_class=None, org_desc=XCLIP_DESC_PATH, image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'), model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE, image_name=current_image.state)
|
227 |
+
xclip_label = out_dict['pred_class']
|
228 |
+
clip_pred_scores = out_dict['pred_score']
|
229 |
+
xclip_part_scores = out_dict['pred_desc_scores']
|
230 |
+
result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
|
231 |
+
xclip_exp = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask=dict(zip(ORDERED_PARTS, [1]*12)))
|
232 |
+
# --- end of intial value ---
|
233 |
+
|
234 |
+
xclip_color = "green" if xclip_label.strip() == gt_label.strip() else "red"
|
235 |
+
xclip_pred_markdown = f"""
|
236 |
+
### <span style='color:{xclip_color}'>XCLIP: {xclip_label} {clip_pred_scores:.4f}</span>
|
237 |
+
"""
|
238 |
+
|
239 |
+
gt_label = f"""
|
240 |
+
## {gt_label}
|
241 |
+
"""
|
242 |
+
current_predicted_class.state = xclip_label
|
243 |
+
|
244 |
+
# Populate the textbox with current descriptions
|
245 |
+
custom_class_name = "class name: custom"
|
246 |
+
descs = XCLIP_DESC[xclip_label]
|
247 |
+
descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
|
248 |
+
descs = {k: descs[k] for k in ORDERED_PARTS}
|
249 |
+
custom_text = [custom_class_name] + list(descs.values())
|
250 |
+
descriptions = ";\n".join(custom_text)
|
251 |
+
textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
|
252 |
+
# modified_exp = gr.HTML().update(value="", visible=True)
|
253 |
+
return gt_label, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox
|
254 |
+
|
255 |
+
def on_edit_button_click_xclip():
|
256 |
+
empty_exp = gr.HTML.update(visible=False)
|
257 |
+
|
258 |
+
# Populate the textbox with current descriptions
|
259 |
+
descs = XCLIP_DESC[current_predicted_class.state]
|
260 |
+
descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
|
261 |
+
descs = {k: descs[k] for k in ORDERED_PARTS}
|
262 |
+
custom_text = ["class name: custom"] + list(descs.values())
|
263 |
+
descriptions = ";\n".join(custom_text)
|
264 |
+
textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
|
265 |
+
|
266 |
+
return textbox, empty_exp
|
267 |
+
|
268 |
+
def convert_input_text_to_xclip_format(textbox_input: str):
|
269 |
+
|
270 |
+
# Split the descriptions by newline to get individual descriptions for each part
|
271 |
+
descriptions_list = textbox_input.split(";\n")
|
272 |
+
# the first line should be "class name: xxx"
|
273 |
+
class_name_line = descriptions_list[0]
|
274 |
+
new_class_name = class_name_line.split(":")[1].strip()
|
275 |
+
|
276 |
+
descriptions_list = descriptions_list[1:]
|
277 |
+
|
278 |
+
# construct descripion dict with part name as key
|
279 |
+
descriptions_dict = {}
|
280 |
+
for desc in descriptions_list:
|
281 |
+
if desc.strip() == "":
|
282 |
+
continue
|
283 |
+
part_name, _ = desc.split(":")
|
284 |
+
descriptions_dict[part_name.strip()] = desc
|
285 |
+
# fill with empty string if the part is not in the descriptions
|
286 |
+
part_mask = {}
|
287 |
+
for part in ORDERED_PARTS:
|
288 |
+
if part not in descriptions_dict:
|
289 |
+
descriptions_dict[part] = ""
|
290 |
+
part_mask[part] = 0
|
291 |
+
else:
|
292 |
+
part_mask[part] = 1
|
293 |
+
return descriptions_dict, part_mask, new_class_name
|
294 |
+
|
295 |
+
def on_predict_button_click_xclip(textbox_input: str):
|
296 |
+
descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)
|
297 |
+
|
298 |
+
# Get the new predictions and explanations
|
299 |
+
out_dict = xclip_pred(new_desc=descriptions_dict, new_part_mask=part_mask, new_class=new_class_name, org_desc=XCLIP_DESC_PATH, image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'), model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE, image_name=current_image.state)
|
300 |
+
xclip_label = out_dict['pred_class']
|
301 |
+
xclip_pred_score = out_dict['pred_score']
|
302 |
+
xclip_part_scores = out_dict['pred_desc_scores']
|
303 |
+
custom_label = out_dict['modified_class']
|
304 |
+
custom_pred_score = out_dict['modified_score']
|
305 |
+
custom_part_scores = out_dict['modified_desc_scores']
|
306 |
+
|
307 |
+
# construct a result dict to generate xclip explanations
|
308 |
+
result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
|
309 |
+
xclip_explanation = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask)
|
310 |
+
modified_result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["modified_descriptions"])), 'pred_scores': custom_part_scores, 'file_name': current_image.state}
|
311 |
+
modified_explanation = generate_xclip_explanations(modified_result_dict, VISIBILITY_DICT[current_image.state], part_mask)
|
312 |
+
|
313 |
+
xclip_color = "green" if xclip_label.strip() == gt_class.state.strip() else "red"
|
314 |
+
xclip_pred_markdown = f"""
|
315 |
+
### <span style='color:{xclip_color}'>XCLIP: {xclip_label} {xclip_pred_score:.4f}</span>
|
316 |
+
"""
|
317 |
+
custom_color = "green" if custom_label.strip() == gt_class.state.strip() else "red"
|
318 |
+
custom_pred_markdown = f"""
|
319 |
+
### <span style='color:{custom_color}'>XCLIP: {custom_label} {custom_pred_score:.4f}</span>
|
320 |
+
"""
|
321 |
+
textbox = gr.Textbox.update(visible=False)
|
322 |
+
# return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_explanation
|
323 |
+
|
324 |
+
modified_exp = gr.HTML().update(value=modified_explanation, visible=True)
|
325 |
+
return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp
|
326 |
+
|
327 |
+
|
328 |
+
custom_css = """
|
329 |
+
html, body {
|
330 |
+
margin: 0;
|
331 |
+
padding: 0;
|
332 |
+
}
|
333 |
+
|
334 |
+
#container {
|
335 |
+
position: relative;
|
336 |
+
width: 400px;
|
337 |
+
height: 400px;
|
338 |
+
border: 1px solid #000;
|
339 |
+
margin: 0 auto; /* This will center the container horizontally */
|
340 |
+
}
|
341 |
+
|
342 |
+
#canvas {
|
343 |
+
position: absolute;
|
344 |
+
top: 0;
|
345 |
+
left: 0;
|
346 |
+
width: 100%;
|
347 |
+
height: 100%;
|
348 |
+
object-fit: cover;
|
349 |
+
}
|
350 |
+
|
351 |
+
"""
|
352 |
+
|
353 |
+
# Define the Gradio interface
|
354 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
|
355 |
+
current_image = gr.State("")
|
356 |
+
current_predicted_class = gr.State("")
|
357 |
+
gt_class = gr.State("")
|
358 |
+
|
359 |
+
with gr.Column():
|
360 |
+
title_text = gr.Markdown("# PEEB - demo")
|
361 |
+
gr.Markdown(
|
362 |
+
"- In this demo, you can edit the descriptions of a class and see how to model react to it."
|
363 |
+
)
|
364 |
+
|
365 |
+
# display the gallery of images
|
366 |
+
with gr.Column():
|
367 |
+
|
368 |
+
gr.Markdown("## Select an image to start!")
|
369 |
+
image_gallery = gr.Gallery(value=IMAGE_GALLERY, label=None, preview=False, allow_preview=False, columns=10, height=250)
|
370 |
+
gr.Markdown("### Custom descritions: \n The first row should be **class name: {some name};**, where you can name your descriptions. \n For the remianing descriptions, please use **;** to separate the descriptions for each part, and use the format **{part name}: {descriptions}**. \n Note that you can delete a part completely, in such cases, all descriptions will remove the corresponding part.")
|
371 |
+
|
372 |
+
with gr.Row():
|
373 |
+
with gr.Column():
|
374 |
+
image_label = gr.Markdown("### Class Name")
|
375 |
+
org_image = gr.HTML()
|
376 |
+
|
377 |
+
with gr.Column():
|
378 |
+
with gr.Row():
|
379 |
+
# xclip_predict_button = gr.Button(label="Predict", value="Predict")
|
380 |
+
xclip_predict_button = gr.Button(value="Predict")
|
381 |
+
xclip_pred_label = gr.Markdown("### XCLIP:")
|
382 |
+
xclip_explanation = gr.HTML()
|
383 |
+
|
384 |
+
with gr.Column():
|
385 |
+
# xclip_edit_button = gr.Button(label="Edit", value="Reset Descriptions")
|
386 |
+
xclip_edit_button = gr.Button(value="Reset Descriptions")
|
387 |
+
custom_pred_label = gr.Markdown(
|
388 |
+
"### Custom Descritpions:"
|
389 |
+
)
|
390 |
+
xclip_textbox = gr.Textbox(lines=12, placeholder="Edit the descriptions here", visible=False)
|
391 |
+
# ai_explanation = gr.Image(type="numpy", visible=True, show_label=False, height=500)
|
392 |
+
custom_explanation = gr.HTML()
|
393 |
+
|
394 |
+
gr.HTML("<br>")
|
395 |
+
|
396 |
+
image_gallery.select(update_selected_image, inputs=None, outputs=[image_label, org_image, xclip_pred_label, xclip_explanation, current_image, xclip_textbox])
|
397 |
+
xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
|
398 |
+
xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])
|
399 |
+
|
400 |
+
demo.launch(server_port=5000, share=True)
|
data/image_embeddings/American_Goldfinch_0123_32505.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4405b6dfc87741cf87aa4887f77308aee46209877a7dcf29caacb4dae12459d5
|
3 |
+
size 1770910
|
data/image_embeddings/Black_Tern_0101_144331.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:218995c5e9d3256313ead069ff11c89a52ce616221880070d722f27c4227ffe2
|
3 |
+
size 1770875
|
data/image_embeddings/Brandt_Cormorant_0040_23144.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c493ed75f6dad68a1336ae3142deea98acb2eec30fbb5345aa1c545660eef4bb
|
3 |
+
size 1770900
|
data/image_embeddings/Brown_Thrasher_0014_155421.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4c051c80027beeebfabab679b596f5a2b7536c016c2c966a5736b03a980b96a5
|
3 |
+
size 1770895
|
data/image_embeddings/Carolina_Wren_0060_186296.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b0c34e05f759b6244ad50ca5529002e26a9370c9db07d22df91e476f827b7724
|
3 |
+
size 1770890
|
data/image_embeddings/Cedar_Waxwing_0075_179114.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9d91e1fd22664d4dbad771f214ae943b60c26a0e52aeefc156eddbddde8cb0fb
|
3 |
+
size 1770890
|
data/image_embeddings/Clark_Nutcracker_0126_85134.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:99e85d16d9b4b0d62e92926a7cefce6fbd5298daa1632df02d1d2bc1c812ccf4
|
3 |
+
size 1770900
|
data/image_embeddings/Gray_Catbird_0071_20974.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e02ea920306d2a41b2f0a46c3205691e1373d3a443714ba31c67bd46fa0baae8
|
3 |
+
size 1770880
|
data/image_embeddings/Heermann_Gull_0097_45783.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:51ecf397a13ffc0ef481b029c7c54498dd9c0dda7db709f9335dba01faebdc65
|
3 |
+
size 1770885
|
data/image_embeddings/House_Wren_0137_187273.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3fab5144fff8e0ff975f9064337dc032d39918bf777d149e02e4952a6ed10d8b
|
3 |
+
size 1770875
|
data/image_embeddings/Ivory_Gull_0004_49019.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:129b38324da3899caa7182fa0a251c81eba2a8ba8e71995139e269d479456e75
|
3 |
+
size 1770870
|
data/image_embeddings/Northern_Waterthrush_0038_177027.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0bd735f0756b810b8c74628ca2285311411cb6fb14639277728a60260e64cda9
|
3 |
+
size 1770925
|
data/image_embeddings/Pine_Warbler_0113_172456.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c48503ff01eb8af79b86315ab9b6abe7d215c32ab37eb5acc54dd99b9877574
|
3 |
+
size 1770885
|
data/image_embeddings/Red_Headed_Woodpecker_0032_182815.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:54ec8a9edf3bc0e5e21a989596469efec44815f9ac30a0cdbde4f5d1f1952619
|
3 |
+
size 1770930
|
data/image_embeddings/Rufous_Hummingbird_0076_59563.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d02487c6d3b10c2bc193547a3ad863b6b02710071e93b2a99e9be17931c9e785
|
3 |
+
size 1770910
|
data/image_embeddings/Sage_Thrasher_0062_796462.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:294ae723107b6cc26f467ef19018f7d0c27befe0ddbf46ea1432a4440cf538c7
|
3 |
+
size 1770890
|
data/image_embeddings/Vesper_Sparrow_0030_125663.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0c3b58049302546a0f19e1a0da37d85ee3841d1f34674a6263b4972229539806
|
3 |
+
size 1770895
|
data/image_embeddings/Western_Grebe_0064_36613.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:66a4a4c3d9e8c61c729eef180dca7c06dc19748be507798548bb629fb8283645
|
3 |
+
size 1770885
|
data/image_embeddings/White_Eyed_Vireo_0046_158849.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:31f5601dd90778785d90da4b079faa4e8082da814b0edb75c46c27f7a59bb0c3
|
3 |
+
size 1770905
|
data/image_embeddings/Winter_Wren_0048_189683.jpg.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aa44fb0827d907160d964837908b8d313bce096d02062be2ea7192e6c2903543
|
3 |
+
size 1770880
|
data/images/boxes/American_Goldfinch_0123_32505_all.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_back.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_beak.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_belly.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_breast.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_crown.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_eyes.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_forehead.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_legs.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_nape.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_tail.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_throat.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_visible.jpg
ADDED
data/images/boxes/American_Goldfinch_0123_32505_wings.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_all.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_back.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_beak.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_belly.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_breast.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_crown.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_eyes.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_forehead.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_legs.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_nape.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_tail.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_throat.jpg
ADDED
data/images/boxes/Black_Tern_0101_144331_visible.jpg
ADDED