xhluca commited on
Commit
528c0e7
·
1 Parent(s): 978ece5

Update app with explorer

Browse files
Files changed (2) hide show
  1. app.py +316 -1
  2. utils.py +356 -0
app.py CHANGED
@@ -1,3 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
 
3
- st.title('Hello World')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import json
3
+ import os
4
+ import time
5
+ import random
6
+ import string
7
+ import shutil
8
+ import traceback
9
+ import sys
10
+ from pathlib import Path
11
+
12
  import streamlit as st
13
 
14
+ from utils import (
15
+ load_json,
16
+ load_json_no_cache,
17
+ parse_arguments,
18
+ format_chat_message,
19
+ find_screenshot,
20
+ gather_chat_history,
21
+ get_screenshot,
22
+ load_page,
23
+ )
24
+
25
+
26
+ def show_selectbox(demonstration_dir):
27
+ # find all the subdirectories in the current directory
28
+ dirs = [
29
+ d
30
+ for d in os.listdir(demonstration_dir)
31
+ if os.path.isdir(f"{demonstration_dir}/{d}")
32
+ ]
33
+
34
+ if not dirs:
35
+ st.title("No recordings found.")
36
+ return None
37
+
38
+ # sort by date
39
+ dirs.sort(key=lambda x: os.path.getmtime(f"{demonstration_dir}/{x}"), reverse=True)
40
+
41
+ # offer the user a dropdown to select which recording to visualize, set a default
42
+ recording_name = st.sidebar.selectbox("Recording", dirs, index=0)
43
+
44
+ return recording_name
45
+
46
+
47
+ def show_overview(data, recording_name, basedir):
48
+ st.title(f"Recording {recording_name}")
49
+
50
+ screenshot_size = st.session_state.get("screenshot_size_view_mode", "regular")
51
+ show_advanced_info = st.session_state.get("show_advanced_information", False)
52
+
53
+ if screenshot_size == "regular":
54
+ col_layout = [1.5, 1.5, 7, 3.5]
55
+ elif screenshot_size == "small":
56
+ col_layout = [1.5, 1.5, 7, 2]
57
+ else: # screenshot_size == 'large'
58
+ col_layout = [1.5, 1.5, 11]
59
+
60
+ # col_i, col_time, col_act, col_actvis = st.columns(col_layout)
61
+ # screenshots = load_screenshots(data, basedir)
62
+
63
+ for i, d in enumerate(data):
64
+ if i > 0 and show_advanced_info:
65
+ # Use html to add a horizontal line with minimal gap
66
+ st.markdown(
67
+ "<hr style='margin-top: 0.1rem; margin-bottom: 0.1rem;'/>",
68
+ unsafe_allow_html=True,
69
+ )
70
+ if screenshot_size == "large":
71
+ col_time, col_i, col_act = st.columns(col_layout)
72
+ col_actvis = col_act
73
+ else:
74
+ col_time, col_i, col_act, col_actvis = st.columns(col_layout)
75
+ secs_from_start = d["timestamp"] - data[0]["timestamp"]
76
+ # `secs_from_start` is a float including ms, display in MM:SS.mm format
77
+ col_time.markdown(
78
+ f"**{datetime.utcfromtimestamp(secs_from_start).strftime('%M:%S')}**"
79
+ )
80
+
81
+ if not st.session_state.get("enable_html_download", True):
82
+ col_i.markdown(f"**#{i}**")
83
+
84
+ elif d["type"] == "browser" and (page_filename := d["state"]["page"]):
85
+ page_path = f"{basedir}/pages/{page_filename}"
86
+
87
+ col_i.download_button(
88
+ label="#" + str(i),
89
+ data=load_page(page_path),
90
+ file_name=recording_name + "-" + page_filename,
91
+ mime="multipart/related",
92
+ key=f"page{i}",
93
+ )
94
+ else:
95
+ col_i.button(f"#{i}", type='secondary')
96
+
97
+ if d["type"] == "chat":
98
+ col_act.markdown(format_chat_message(d), unsafe_allow_html=True)
99
+ continue
100
+
101
+ # screenshot_filename = d["state"]["screenshot"]
102
+ img = get_screenshot(d, basedir)
103
+ arguments = parse_arguments(d["action"])
104
+ event_type = d["action"]["intent"]
105
+
106
+ action_str = f"**{event_type}**({arguments})"
107
+
108
+ if img:
109
+ col_actvis.image(img)
110
+
111
+ col_act.markdown(action_str)
112
+
113
+ if show_advanced_info:
114
+ colors = {
115
+ "good": "green",
116
+ "broken": "red",
117
+ "delayed": "orange"
118
+ }
119
+ status = d["state"].get("screenshot_status", "unknown")
120
+ color = colors.get(status, "blue")
121
+
122
+ text = ""
123
+ text += f'Status: **:{color}[{status.upper()}]**\n\n'
124
+ text += f'Screenshot: `{d["state"]["screenshot"]}`\\\n'
125
+ text += f'Page: `{d["state"]["page"]}`\n'
126
+
127
+ col_act.markdown(text)
128
+
129
+
130
+ def show_slide(data, example_index, recording_name, basedir):
131
+ columns = st.columns([1, 5], gap="large")
132
+
133
+ chat = gather_chat_history(data, example_index)
134
+ example = data[example_index]
135
+
136
+ # display current message
137
+ if example["type"] == "chat":
138
+ columns[0].markdown(format_chat_message(example), unsafe_allow_html=True)
139
+
140
+ # display chat history
141
+ with columns[0]:
142
+ for d in chat:
143
+ columns[0].markdown(format_chat_message(d), unsafe_allow_html=True)
144
+
145
+ img = find_screenshot(data, example_index, basedir)
146
+
147
+ if example["type"] == "browser":
148
+ event_type = example["action"]["intent"]
149
+ arguments = parse_arguments(example["action"])
150
+ event_type = example["action"]["intent"]
151
+
152
+ action = f"**{event_type}**({arguments})"
153
+
154
+ with columns[1]:
155
+ st.write(action)
156
+
157
+ with columns[1]:
158
+ st.image(img)
159
+
160
+
161
+ def show_presentation(data, recording_name, basedir):
162
+ example_index = st.slider(
163
+ "example",
164
+ min_value=0,
165
+ max_value=len(data) - 1,
166
+ step=1,
167
+ label_visibility="hidden",
168
+ )
169
+
170
+ show_slide(data, example_index, recording_name=recording_name, basedir=basedir)
171
+
172
+
173
+ def show_video(basedir):
174
+ # find the mp4 file in the basedir
175
+ video_filename = None
176
+ for filename in os.listdir(f"{basedir}"):
177
+ if filename.endswith(".mp4"):
178
+ video_filename = filename
179
+ video_path = f"{basedir}/{video_filename}"
180
+ st.video(video_path)
181
+
182
+ if not video_filename:
183
+ st.error("No video found")
184
+
185
+
186
+ def load_recording(basedir):
187
+ # Before loading replay, we need a dropdown that allows us to select replay.json or replay_orig.json
188
+ # Find all files in basedir starting with "replay" and ending with ".json"
189
+ replay_files = sorted(
190
+ [
191
+ f
192
+ for f in os.listdir(basedir)
193
+ if f.startswith("replay") and f.endswith(".json")
194
+ ]
195
+ )
196
+ replay_file = st.sidebar.selectbox("Select replay", replay_files, index=0)
197
+ st.sidebar.checkbox(
198
+ "Advanced Screenshot Info", False, key="show_advanced_information"
199
+ )
200
+ st.sidebar.checkbox(
201
+ "Enable HTML download", True, key="enable_html_download"
202
+ )
203
+ replay_file = replay_file.replace(".json", "")
204
+
205
+ metadata = load_json(basedir, "metadata")
206
+
207
+ extension_version = metadata["version"]
208
+ st.sidebar.markdown(f"**extension version**: {extension_version}")
209
+
210
+ # convert timestamp to readable date string
211
+ recording_start_timestamp = metadata["recordingStart"]
212
+ recording_start_date = datetime.fromtimestamp(
213
+ int(recording_start_timestamp) / 1000
214
+ ).strftime("%Y-%m-%d %H:%M:%S")
215
+ st.sidebar.markdown(f"**started**: {recording_start_date}")
216
+
217
+ # recording_end_timestamp = k["recordingEnd"]
218
+ # calculate duration
219
+ # duration = int(recording_end_timestamp) - int(recording_start_timestamp)
220
+ # duration = time.strftime("%M:%S", time.gmtime(duration / 1000))
221
+
222
+ # Read in the JSON data
223
+ replay_dict = load_json_no_cache(basedir, replay_file)
224
+ form = load_json(basedir, "form")
225
+
226
+ duration = replay_dict["data"][-1]["timestamp"] - replay_dict["data"][0]["timestamp"]
227
+ duration = time.strftime("%M:%S", time.gmtime(duration))
228
+ st.sidebar.markdown(f"**duration**: {duration}")
229
+
230
+ if not replay_dict:
231
+ return None
232
+
233
+ for key in [
234
+ "annotator",
235
+ "description",
236
+ "tasks",
237
+ "upload_date",
238
+ "instructor_sees_screen",
239
+ "uses_ai_generated_output",
240
+ ]:
241
+ if form and key in form:
242
+ if type(form[key]) == list:
243
+ st.sidebar.markdown(f"**{key}**: {', '.join(form[key])}")
244
+ else:
245
+ st.sidebar.markdown(f"**{key}**: {form[key]}")
246
+
247
+ st.sidebar.markdown("---")
248
+ if replay_dict and "status" in replay_dict:
249
+ st.sidebar.markdown(f"**Validation status**: {replay_dict['status']}")
250
+
251
+ processed_meta_path = Path(basedir).joinpath('processed_metadata.json')
252
+ start_frame = 'file not found'
253
+
254
+ if processed_meta_path.exists():
255
+ with open(processed_meta_path) as f:
256
+ processed_meta = json.load(f)
257
+ start_frame = processed_meta.get('start_frame', 'info not in file')
258
+
259
+ st.sidebar.markdown(f"**Recording start frame**: {start_frame}")
260
+
261
+
262
+ # st.sidebar.button("Delete recording", type="primary", on_click=delete_recording, args=[basedir])
263
+
264
+ data = replay_dict["data"]
265
+ return data
266
+
267
+
268
+ def run():
269
+ mode = st.sidebar.radio("Mode", ["Overview"])
270
+ demonstration_dir = "./demonstrations"
271
+
272
+ params = st.experimental_get_query_params()
273
+ # list demonstrations/
274
+ demo_names = os.listdir(demonstration_dir)
275
+
276
+ if params.get("recording"):
277
+ recording_name = params["recording"][0]
278
+
279
+ else:
280
+ recording_name = demo_names[0]
281
+
282
+ recording_name = st.sidebar.selectbox(
283
+ "Recordings",
284
+ demo_names,
285
+ index=demo_names.index(recording_name),
286
+ )
287
+
288
+ if recording_name != params.get("recording", [None])[0]:
289
+ st.experimental_set_query_params(recording=recording_name)
290
+
291
+ if mode == "Overview":
292
+ with st.sidebar:
293
+ # Want a dropdown
294
+ st.selectbox(
295
+ "Screenshot size",
296
+ ["small", "regular", "large"],
297
+ index=1,
298
+ key="screenshot_size_view_mode",
299
+ )
300
+
301
+ if recording_name is not None:
302
+ basedir = f"{demonstration_dir}/{recording_name}"
303
+ data = load_recording(basedir=basedir)
304
+
305
+ if not data:
306
+ st.stop()
307
+
308
+ if mode == "Overview":
309
+ show_overview(data, recording_name=recording_name, basedir=basedir)
310
+ elif mode == "Presentation":
311
+ show_presentation(data, recording_name=recording_name, basedir=basedir)
312
+ elif mode == "Video":
313
+ show_video(basedir=basedir)
314
+
315
+
316
+ if __name__ == "__main__":
317
+ st.set_page_config(layout="wide")
318
+ run()
utils.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ from datetime import datetime
4
+ import os
5
+ import json
6
+ from pathlib import Path
7
+ import sys
8
+ import shutil
9
+ import time
10
+ import traceback
11
+
12
+ import pandas as pd
13
+ import streamlit as st
14
+
15
+ import json
16
+ from PIL import Image, ImageDraw
17
+
18
+
19
+ CACHE_TTL = 60 * 60 * 24 * 14
20
+
21
+ """
22
+ Streamlit app utilities
23
+ """
24
+
25
+
26
+ @st.cache_data(ttl=CACHE_TTL)
27
+ def load_json(basedir, name):
28
+ if not os.path.exists(f"{basedir}/{name}.json"):
29
+ return None
30
+
31
+ with open(f"{basedir}/{name}.json", "r") as f:
32
+ j = json.load(f)
33
+
34
+ return j
35
+
36
+
37
+ def load_json_no_cache(basedir, name):
38
+ if not os.path.exists(f"{basedir}/{name}.json"):
39
+ return None
40
+
41
+ with open(f"{basedir}/{name}.json", "r") as f:
42
+ j = json.load(f)
43
+
44
+ return j
45
+
46
+
47
+ def save_json(basedir, name, data):
48
+ with open(f"{basedir}/{name}.json", "w") as f:
49
+ json.dump(data, f, indent=4)
50
+
51
+
52
+ @st.cache_data
53
+ def load_image(image_file):
54
+ img = Image.open(image_file)
55
+ return img
56
+
57
+
58
+ @st.cache_resource
59
+ def load_page(page_path):
60
+ return open(page_path, "rb")
61
+
62
+
63
+ def shorten(s):
64
+ # shorten to 100 characters
65
+ if len(s) > 100:
66
+ s = s[:100] + "..."
67
+
68
+ return s
69
+
70
+
71
+ @st.cache_data
72
+ def parse_arguments(action):
73
+ s = []
74
+ event_type = action["intent"]
75
+ args = action["arguments"]
76
+
77
+ if event_type == "textInput":
78
+ txt = args["text"]
79
+
80
+ txt = txt.strip()
81
+
82
+ # escape markdown characters
83
+ txt = txt.replace("_", "\\_")
84
+ txt = txt.replace("*", "\\*")
85
+ txt = txt.replace("`", "\\`")
86
+ txt = txt.replace("$", "\\$")
87
+
88
+ txt = shorten(txt)
89
+
90
+ s.append(f'"{txt}"')
91
+ elif event_type == "change":
92
+ s.append(f'{args["value"]}')
93
+ elif event_type == "load":
94
+ url = args["properties"].get("url") or args.get("url")
95
+ short_url = shorten(url)
96
+ s.append(f'"[{short_url}]({url})"')
97
+
98
+ if args["properties"].get("transitionType"):
99
+ s.append(f'*{args["properties"]["transitionType"]}*')
100
+ s.append(f'*{" ".join(args["properties"]["transitionQualifiers"])}*')
101
+ elif event_type == "scroll":
102
+ s.append(f'{args["scrollX"]}, {args["scrollY"]}')
103
+ elif event_type == "say":
104
+ s.append(f'"{args["text"]}"')
105
+ elif event_type == "copy":
106
+ selected = shorten(args["selected"])
107
+ s.append(f'"{selected}"')
108
+ elif event_type == "paste":
109
+ pasted = shorten(args["pasted"])
110
+ s.append(f'"{pasted}"')
111
+ elif event_type == "tabcreate":
112
+ s.append(f'{args["properties"]["tabId"]}')
113
+ elif event_type == "tabremove":
114
+ s.append(f'{args["properties"]["tabId"]}')
115
+ elif event_type == "tabswitch":
116
+ s.append(
117
+ f'{args["properties"]["tabIdOrigin"]} -> {args["properties"]["tabId"]}'
118
+ )
119
+
120
+ if args.get("element"):
121
+
122
+ if event_type == 'click':
123
+ x = round(args['metadata']['mouseX'], 1)
124
+ y = round(args['metadata']['mouseY'], 1)
125
+ uid = args.get('element', {}).get('attributes', {}).get("data-webtasks-id")
126
+ s.append(f"*x =* {x}, *y =* {y}, *uid =* {uid}")
127
+ else:
128
+ top = round(args["element"]["bbox"]["top"], 1)
129
+ left = round(args["element"]["bbox"]["left"], 1)
130
+ right = round(args["element"]["bbox"]["right"], 1)
131
+ bottom = round(args["element"]["bbox"]["bottom"], 1)
132
+
133
+ s.append(f"*top =* {top}, *left =* {left}, *right =* {right}, *bottom =* {bottom}")
134
+
135
+ return ", ".join(s)
136
+
137
+
138
+ @st.cache_resource(max_entries=50_000, ttl=CACHE_TTL)
139
+ def create_visualization(_img, event_type, bbox, x, y, screenshot_path):
140
+ # screenshot_path is not used, but we need it for caching since we can't cache
141
+ # PIL images (hence the leading underscore in the variable name to indicate
142
+ # that it's not hashed)
143
+ _img = _img.convert("RGBA")
144
+ draw = ImageDraw.Draw(_img)
145
+
146
+ # draw a bounding box around the element
147
+ color = {
148
+ "click": "red",
149
+ "hover": "orange",
150
+ "textInput": "blue",
151
+ "change": "green",
152
+ }[event_type]
153
+
154
+ left = bbox["left"]
155
+ top = bbox["top"]
156
+ w = bbox["width"]
157
+ h = bbox["height"]
158
+ draw.rectangle((left, top, left + w, top + h), outline=color, width=2)
159
+
160
+ if event_type in ["click", "hover"]:
161
+ r = 15
162
+ for i in range(1, 5):
163
+ rx = r * i
164
+ draw.ellipse((x - rx, y - rx, x + rx, y + rx), outline=color, width=3)
165
+ draw.ellipse((x - r, y - r, x + r, y + r), fill=color)
166
+
167
+ return _img
168
+
169
+
170
+ @st.cache_data(max_entries=50_000, ttl=CACHE_TTL)
171
+ def get_screenshot_minimal(screenshot_path, event_type, bbox, x, y, new_width=None):
172
+ img = load_image(screenshot_path)
173
+ # vis = None
174
+
175
+ if event_type in ["click", "textInput", "change", "hover"]:
176
+ img = create_visualization(img, event_type, bbox, x, y, screenshot_path)
177
+
178
+ if new_width is not None:
179
+ # Resize to 800px wide
180
+ w, h = img.size
181
+ new_w = new_width
182
+ new_h = int(new_w * h / w)
183
+ img = img.resize((new_w, new_h))
184
+ print(f"Resized '{screenshot_path}' to", new_w, new_h)
185
+
186
+ return img
187
+
188
+
189
+ def get_event_info(d):
190
+ event_type = d["action"]["intent"]
191
+
192
+ try:
193
+ bbox = d["action"]["arguments"]["element"]["bbox"]
194
+ except KeyError:
195
+ bbox = None
196
+
197
+ try:
198
+ x = d["action"]["arguments"]["properties"]["x"]
199
+ y = d["action"]["arguments"]["properties"]["y"]
200
+ except KeyError:
201
+ x = None
202
+ y = None
203
+
204
+ return event_type, bbox, x, y
205
+
206
+
207
+ def get_screenshot(d, basedir, new_width=None):
208
+ screenshot_filename = d["state"]["screenshot"]
209
+
210
+ if not screenshot_filename:
211
+ return None
212
+
213
+ event_type, bbox, x, y = get_event_info(d)
214
+ screenshot_path = f"{basedir}/screenshots/{screenshot_filename}"
215
+
216
+ return get_screenshot_minimal(
217
+ screenshot_path, event_type, bbox, x, y, new_width=new_width
218
+ )
219
+
220
+
221
+ def text_bubble(text, color):
222
+ text = text.replace("\n", "<br>").replace("\t", "&nbsp;" * 8)
223
+ return f'<div style="background-color:{color}; padding: 8px; margin: 6px; border-radius:10px; display:inline-block;">{text}</div>'
224
+
225
+
226
+ def gather_chat_history(data, example_index):
227
+ chat = []
228
+ for i, d in enumerate(data):
229
+ if d["type"] == "chat":
230
+ if i >= example_index:
231
+ break
232
+ chat.append(d)
233
+
234
+ # # leave out just 5 last messages
235
+ # if len(chat) > 5:
236
+ # chat = chat[-5:]
237
+
238
+ return reversed(chat)
239
+
240
+
241
+ def format_chat_message(d):
242
+ if d["speaker"] == "instructor":
243
+ return text_bubble("🧑 " + d["utterance"], "rgba(63, 111, 255, 0.35)")
244
+ else:
245
+ return text_bubble("🤖 " + d["utterance"], "rgba(185,185,185,0.35)")
246
+
247
+
248
+ def find_screenshot(data, example_index, basedir):
249
+ # keep looking at previous screenshots until we find one
250
+ # if there is none, return None
251
+
252
+ for i in range(example_index, -1, -1):
253
+ d = data[i]
254
+ if d["type"] == "chat":
255
+ continue
256
+
257
+ screenshot = get_screenshot(d, basedir)
258
+ if screenshot:
259
+ return screenshot
260
+
261
+ return None
262
+
263
+
264
+ def create_visualization_2(_img, bbox, color, width, x, y):
265
+ _img = _img.convert("RGBA")
266
+ draw = ImageDraw.Draw(_img)
267
+
268
+ if bbox:
269
+ left = bbox["left"]
270
+ top = bbox["top"]
271
+ w = bbox["width"]
272
+ h = bbox["height"]
273
+ draw.rectangle((left, top, left + w, top + h), outline=color, width=width)
274
+
275
+ if x and y:
276
+ r = 8
277
+ for i in range(1, 4):
278
+ rx = r * i
279
+ draw.ellipse((x - rx, y - rx, x + rx, y + rx), outline=color, width=2)
280
+ draw.ellipse((x - r, y - r, x + r, y + r), fill=color)
281
+
282
+ return _img
283
+
284
+
285
+ def rescale_bbox(bbox, scaling_factor):
286
+ return {
287
+ k: bbox[k] * scaling_factor
288
+ for k in ["top", "left", "width", "height", "right", "bottom"]
289
+ if k in bbox
290
+ }
291
+
292
+
293
+ def show_overlay(
294
+ _img,
295
+ pred,
296
+ ref,
297
+ turn_args,
298
+ turn_metadata,
299
+ scale_pred=True,
300
+ show=("pred_coords", "ref", "pred_elem"),
301
+ ):
302
+ scaling_factor = turn_metadata.get("zoomLevel", 1.0)
303
+
304
+ if "pred_elem" in show:
305
+ # First, draw red box around predicted element
306
+ if pred.get("element") and pred["element"].get("bbox"):
307
+ # rescale the bbox by scaling_factor
308
+ bbox = rescale_bbox(pred["element"]["bbox"], scaling_factor)
309
+ _img = create_visualization_2(
310
+ _img, bbox, color="red", width=9, x=None, y=None
311
+ )
312
+
313
+ if "ref" in show:
314
+ # Finally, draw a blue box around the reference element (if it exists)
315
+ if ref.get("element") and ref["element"].get("bbox"):
316
+ # rescale the bbox
317
+ bbox = rescale_bbox(ref["element"]["bbox"], scaling_factor)
318
+ x = turn_args.get("properties", {}).get("x")
319
+ y = turn_args.get("properties", {}).get("y")
320
+ _img = create_visualization_2(_img, bbox, color="blue", width=6, x=x, y=y)
321
+
322
+ if "pred_coords" in show:
323
+ # Second draw a green box and x/y coordinate based on predicted coordinates
324
+ # The predicted coordinates are the raw output of the model,
325
+ # Whereas the predicted element is the inferred element from the predicted coordinates
326
+ if pred["args"].get("x") and pred["args"].get("y"):
327
+ x = pred["args"]["x"]
328
+ y = pred["args"]["y"]
329
+
330
+ if scale_pred:
331
+ x = x * scaling_factor
332
+ y = y * scaling_factor
333
+ else:
334
+ x = None
335
+ y = None
336
+
337
+ # If the predicted element is a bounding box, draw a green box around it
338
+ if all(c in pred["args"] for c in ["top", "left", "right", "bottom"]):
339
+ bbox = {
340
+ "top": pred["args"]["top"],
341
+ "left": pred["args"]["left"],
342
+ "width": (pred["args"]["right"] - pred["args"]["left"]),
343
+ "height": (pred["args"]["bottom"] - pred["args"]["top"]),
344
+ "right": pred["args"]["right"],
345
+ "bottom": pred["args"]["bottom"],
346
+ }
347
+
348
+ if scale_pred:
349
+ bbox = rescale_bbox(bbox, scaling_factor)
350
+ else:
351
+ # Otherwise, do nothing
352
+ bbox = None
353
+
354
+ _img = create_visualization_2(_img, bbox=bbox, color="green", width=3, x=x, y=y)
355
+
356
+ return _img