samin commited on
Commit
bf1d26a
Β·
verified Β·
1 Parent(s): 3951a6d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +25 -61
src/streamlit_app.py CHANGED
@@ -31,6 +31,8 @@ TASK_DESCRIPTIONS = {
31
  }
32
 
33
  DEFAULT_CHUNK = 10
 
 
34
  if "show_counts" not in st.session_state:
35
  st.session_state.show_counts = {task: DEFAULT_CHUNK for task in DATA_DIRS}
36
 
@@ -41,13 +43,7 @@ _num_re = re.compile(r"(\d+)")
41
 
42
  def _natural_key(s: str):
43
  parts = _num_re.split(s)
44
- out = []
45
- for p in parts:
46
- if p.isdigit():
47
- out.append(int(p))
48
- else:
49
- out.append(p.lower())
50
- return out
51
 
52
  def _basename_from_url(url: str) -> str:
53
  try:
@@ -59,7 +55,7 @@ def _sorted_urls(urls):
59
  return sorted(urls, key=lambda u: _natural_key(_basename_from_url(u)))
60
 
61
  # ──────────────────────────────────────────────────────────────────────────────
62
- # Load manifest (local file or baked into Space)
63
  # ──────────────────────────────────────────────────────────────────────────────
64
  @st.cache_data(show_spinner=False)
65
  def load_manifest():
@@ -70,7 +66,6 @@ def load_manifest():
70
  topics = set()
71
 
72
  for it in items:
73
- # sort the image url lists inside each item for determinism
74
  if it.get("cond_image_urls"):
75
  it["cond_image_urls"] = _sorted_urls(it["cond_image_urls"])
76
  if it.get("model_output_urls"):
@@ -80,16 +75,14 @@ def load_manifest():
80
  if it.get("topic"):
81
  topics.add(it["topic"])
82
 
83
- # sort items within each task by item_id (natural order)
84
  for t, lst in per_task.items():
85
  lst.sort(key=lambda it: _natural_key(str(it.get("item_id", ""))))
86
-
87
  return per_task, sorted(list(topics))
88
 
89
  # ──────────────────────────────────────────────────────────────────────────────
90
  # Stable image grid
91
  # ──────────────────────────────────────────────────────────────────────────────
92
- def _display_images(urls, caption_prefix="", max_per_row=3, fixed_height_px=None):
93
  if not urls:
94
  st.write("No images found.")
95
  return
@@ -98,26 +91,22 @@ def _display_images(urls, caption_prefix="", max_per_row=3, fixed_height_px=None
98
  for i, url in enumerate(urls):
99
  col = cols[i % max_per_row]
100
  with col:
101
- if fixed_height_px:
102
- # Reserve space and avoid reflow while image loads
103
- st.markdown(
104
- f"""
105
- <div class="img-frame" style="height:{fixed_height_px}px; display:flex; align-items:center; justify-content:center; overflow:hidden; border-radius:12px;">
106
- <img src="{url}" alt="{_basename_from_url(url)}" style="max-height:100%; width:100%; object-fit:contain;" />
107
- </div>
108
- <div class="img-cap" style="font-size:0.85rem; opacity:0.8; margin-top:4px;">
109
- {caption_prefix} {_basename_from_url(url)}
110
- </div>
111
- """,
112
- unsafe_allow_html=True,
113
- )
114
- else:
115
- st.image(url, caption=f"{caption_prefix} {_basename_from_url(url)}", use_container_width=True)
116
 
117
  # ──────────────────────────────────────────────────────────────────────────────
118
- # Global CSS to reduce β€œvibrating” / layout reflow
119
  # ──────────────────────────────────────────────────────────────────────────────
120
- def _inject_css(_fixed_height_px: int | None):
121
  css = """
122
  <style>
123
  .block-container { padding-top: 0.75rem; }
@@ -143,41 +132,23 @@ def _inject_css(_fixed_height_px: int | None):
143
  def main():
144
  st.title("πŸ–ΌοΈ ImagenHub2 Data Visualization")
145
  st.markdown("Each task starts with **10** items β€” click **Show more** to load **+10**.")
 
146
 
147
- # Load manifest first (to get topic list)
148
  with st.spinner("Loading manifest…"):
149
  per_task, _all_topics = load_manifest()
150
 
151
- # Sidebar
152
  st.sidebar.header("Filters")
153
- fixed_height_on = st.sidebar.toggle(
154
- "Stabilize grid with fixed image height",
155
- value=True,
156
- help="Pre-allocate space for images to prevent page β€˜vibrating’."
157
- )
158
- fixed_height_px = st.sidebar.number_input(
159
- "Fixed image height (px)",
160
- min_value=120, max_value=1200, value=320, step=20,
161
- disabled=not fixed_height_on
162
- )
163
- _inject_css(fixed_height_px if fixed_height_on else None)
164
-
165
  selected_tasks = st.sidebar.multiselect("Select Tasks", DATA_DIRS, default=DATA_DIRS)
166
- search_query = st.sidebar.text_input("πŸ” Search in prompts", "")
167
- # Topic filter behaves like task filter (multiselect)
168
- topic_filter = st.sidebar.multiselect(
169
- "Select Topics",
170
- _all_topics,
171
- default=[],
172
- help="Filter items by one or more topic IDs."
173
- )
174
  subtopic_filter = st.sidebar.text_input("Filter by subtopic (optional)", "")
175
 
176
  st.sidebar.header("Task Descriptions")
177
  for t in selected_tasks:
178
  st.sidebar.write(f"**{t}**: {TASK_DESCRIPTIONS.get(t, '')}")
179
 
180
- # Tabs per selected task
181
  tabs = st.tabs(selected_tasks) if selected_tasks else []
182
  for task, tab in zip(selected_tasks, tabs):
183
  with tab:
@@ -185,7 +156,6 @@ def main():
185
  limit = st.session_state.show_counts.get(task, DEFAULT_CHUNK)
186
  all_items = per_task.get(task, [])
187
 
188
- # Apply filters
189
  def _match(it):
190
  sq = search_query.strip().lower()
191
  if sq and (sq not in it.get("prompt", "").lower()
@@ -226,16 +196,10 @@ def main():
226
 
227
  if cond_urls:
228
  st.write("**Condition Images:**")
229
- _display_images(
230
- cond_urls, "Condition", max_per_row=3,
231
- fixed_height_px=(fixed_height_px if fixed_height_on else None)
232
- )
233
  if model_urls:
234
  st.write("**Model Output:**")
235
- _display_images(
236
- model_urls, "Model", max_per_row=3,
237
- fixed_height_px=(fixed_height_px if fixed_height_on else None)
238
- )
239
  st.divider()
240
 
241
  # Pagination controls
 
31
  }
32
 
33
  DEFAULT_CHUNK = 10
34
+ FIXED_HEIGHT_PX = 320 # always stabilized
35
+
36
  if "show_counts" not in st.session_state:
37
  st.session_state.show_counts = {task: DEFAULT_CHUNK for task in DATA_DIRS}
38
 
 
43
 
44
  def _natural_key(s: str):
45
  parts = _num_re.split(s)
46
+ return [int(p) if p.isdigit() else p.lower() for p in parts]
 
 
 
 
 
 
47
 
48
  def _basename_from_url(url: str) -> str:
49
  try:
 
55
  return sorted(urls, key=lambda u: _natural_key(_basename_from_url(u)))
56
 
57
  # ──────────────────────────────────────────────────────────────────────────────
58
+ # Load manifest
59
  # ──────────────────────────────────────────────────────────────────────────────
60
  @st.cache_data(show_spinner=False)
61
  def load_manifest():
 
66
  topics = set()
67
 
68
  for it in items:
 
69
  if it.get("cond_image_urls"):
70
  it["cond_image_urls"] = _sorted_urls(it["cond_image_urls"])
71
  if it.get("model_output_urls"):
 
75
  if it.get("topic"):
76
  topics.add(it["topic"])
77
 
 
78
  for t, lst in per_task.items():
79
  lst.sort(key=lambda it: _natural_key(str(it.get("item_id", ""))))
 
80
  return per_task, sorted(list(topics))
81
 
82
  # ──────────────────────────────────────────────────────────────────────────────
83
  # Stable image grid
84
  # ──────────────────────────────────────────────────────────────────────────────
85
+ def _display_images(urls, caption_prefix="", max_per_row=3):
86
  if not urls:
87
  st.write("No images found.")
88
  return
 
91
  for i, url in enumerate(urls):
92
  col = cols[i % max_per_row]
93
  with col:
94
+ st.markdown(
95
+ f"""
96
+ <div class="img-frame" style="height:{FIXED_HEIGHT_PX}px; display:flex; align-items:center; justify-content:center; overflow:hidden; border-radius:12px;">
97
+ <img src="{url}" alt="{_basename_from_url(url)}" style="max-height:100%; width:100%; object-fit:contain;" />
98
+ </div>
99
+ <div class="img-cap" style="text-align:center; font-size:0.9rem; opacity:0.8; margin-top:6px;">
100
+ {caption_prefix} {_basename_from_url(url)}
101
+ </div>
102
+ """,
103
+ unsafe_allow_html=True,
104
+ )
 
 
 
 
105
 
106
  # ──────────────────────────────────────────────────────────────────────────────
107
+ # Global CSS
108
  # ──────────────────────────────────────────────────────────────────────────────
109
+ def _inject_css():
110
  css = """
111
  <style>
112
  .block-container { padding-top: 0.75rem; }
 
132
  def main():
133
  st.title("πŸ–ΌοΈ ImagenHub2 Data Visualization")
134
  st.markdown("Each task starts with **10** items β€” click **Show more** to load **+10**.")
135
+ _inject_css()
136
 
 
137
  with st.spinner("Loading manifest…"):
138
  per_task, _all_topics = load_manifest()
139
 
140
+ # Sidebar filters
141
  st.sidebar.header("Filters")
 
 
 
 
 
 
 
 
 
 
 
 
142
  selected_tasks = st.sidebar.multiselect("Select Tasks", DATA_DIRS, default=DATA_DIRS)
143
+ search_query = st.sidebar.text_input("πŸ” Search in prompts", "")
144
+ topic_filter = st.sidebar.multiselect("Select Topics", _all_topics, default=[])
 
 
 
 
 
 
145
  subtopic_filter = st.sidebar.text_input("Filter by subtopic (optional)", "")
146
 
147
  st.sidebar.header("Task Descriptions")
148
  for t in selected_tasks:
149
  st.sidebar.write(f"**{t}**: {TASK_DESCRIPTIONS.get(t, '')}")
150
 
151
+ # Tabs per task
152
  tabs = st.tabs(selected_tasks) if selected_tasks else []
153
  for task, tab in zip(selected_tasks, tabs):
154
  with tab:
 
156
  limit = st.session_state.show_counts.get(task, DEFAULT_CHUNK)
157
  all_items = per_task.get(task, [])
158
 
 
159
  def _match(it):
160
  sq = search_query.strip().lower()
161
  if sq and (sq not in it.get("prompt", "").lower()
 
196
 
197
  if cond_urls:
198
  st.write("**Condition Images:**")
199
+ _display_images(cond_urls, "Condition")
 
 
 
200
  if model_urls:
201
  st.write("**Model Output:**")
202
+ _display_images(model_urls, "Model")
 
 
 
203
  st.divider()
204
 
205
  # Pagination controls