ahuang11 commited on
Commit
09672b1
1 Parent(s): 586d60d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +453 -131
app.py CHANGED
@@ -1,147 +1,469 @@
1
- import io
2
- import random
3
- from typing import List, Tuple
4
 
5
- import aiohttp
 
 
6
  import panel as pn
7
- from PIL import Image
8
- from transformers import CLIPModel, CLIPProcessor
 
9
 
10
- pn.extension(design="bootstrap", sizing_mode="stretch_width")
 
11
 
12
- ICON_URLS = {
13
- "brand-github": "https://github.com/holoviz/panel",
14
- "brand-twitter": "https://twitter.com/Panel_Org",
15
- "brand-linkedin": "https://www.linkedin.com/company/panel-org",
16
- "message-circle": "https://discourse.holoviz.org/",
17
- "brand-discord": "https://discord.gg/AXRHnJU6sP",
18
- }
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- async def random_url(_):
22
- pet = random.choice(["cat", "dog"])
23
- api_url = f"https://api.the{pet}api.com/v1/images/search"
24
- async with aiohttp.ClientSession() as session:
25
- async with session.get(api_url) as resp:
26
- return (await resp.json())[0]["url"]
 
 
27
 
 
 
 
 
 
 
 
28
 
29
- @pn.cache
30
- def load_processor_model(
31
- processor_name: str, model_name: str
32
- ) -> Tuple[CLIPProcessor, CLIPModel]:
33
- processor = CLIPProcessor.from_pretrained(processor_name)
34
- model = CLIPModel.from_pretrained(model_name)
35
- return processor, model
36
 
37
 
38
- async def open_image_url(image_url: str) -> Image:
39
- async with aiohttp.ClientSession() as session:
40
- async with session.get(image_url) as resp:
41
- return Image.open(io.BytesIO(await resp.read()))
 
42
 
 
 
 
43
 
44
- def get_similarity_scores(class_items: List[str], image: Image) -> List[float]:
45
- processor, model = load_processor_model(
46
- "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
47
- )
48
- inputs = processor(
49
- text=class_items,
50
- images=[image],
51
- return_tensors="pt", # pytorch tensors
52
- )
53
- outputs = model(**inputs)
54
- logits_per_image = outputs.logits_per_image
55
- class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy()
56
- return class_likelihoods[0]
57
-
58
-
59
- async def process_inputs(class_names: List[str], image_url: str):
60
- """
61
- High level function that takes in the user inputs and returns the
62
- classification results as panel objects.
63
- """
64
- try:
65
- main.disabled = True
66
- if not image_url:
67
- yield "##### ⚠️ Provide an image URL"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  return
69
-
70
- yield "##### ⚙ Fetching image and running model..."
71
- try:
72
- pil_img = await open_image_url(image_url)
73
- img = pn.pane.Image(pil_img, height=400, align="center")
74
- except Exception as e:
75
- yield f"##### 😔 Something went wrong, please try a different URL!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  return
77
-
78
- class_items = class_names.split(",")
79
- class_likelihoods = get_similarity_scores(class_items, pil_img)
80
-
81
- # build the results column
82
- results = pn.Column("##### 🎉 Here are the results!", img)
83
-
84
- for class_item, class_likelihood in zip(class_items, class_likelihoods):
85
- row_label = pn.widgets.StaticText(
86
- name=class_item.strip(), value=f"{class_likelihood:.2%}", align="center"
87
  )
88
- row_bar = pn.indicators.Progress(
89
- value=int(class_likelihood * 100),
90
- sizing_mode="stretch_width",
91
- bar_color="secondary",
92
- margin=(0, 10),
93
- design=pn.theme.Material,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  )
95
- results.append(pn.Column(row_label, row_bar))
96
- yield results
97
- finally:
98
- main.disabled = False
99
-
100
-
101
- # create widgets
102
- randomize_url = pn.widgets.Button(name="Randomize URL", align="end")
103
-
104
- image_url = pn.widgets.TextInput(
105
- name="Image URL to classify",
106
- value=pn.bind(random_url, randomize_url),
107
- )
108
- class_names = pn.widgets.TextInput(
109
- name="Comma separated class names",
110
- placeholder="Enter possible class names, e.g. cat, dog",
111
- value="cat, dog, parrot",
112
- )
113
-
114
- input_widgets = pn.Column(
115
- "##### 😊 Click randomize or paste a URL to start classifying!",
116
- pn.Row(image_url, randomize_url),
117
- class_names,
118
- )
119
-
120
- # add interactivity
121
- interactive_result = pn.panel(
122
- pn.bind(process_inputs, image_url=image_url, class_names=class_names),
123
- height=600,
124
- )
125
-
126
- # add footer
127
- footer_row = pn.Row(pn.Spacer(), align="center")
128
- for icon, url in ICON_URLS.items():
129
- href_button = pn.widgets.Button(icon=icon, width=35, height=35)
130
- href_button.js_on_click(code=f"window.open('{url}')")
131
- footer_row.append(href_button)
132
- footer_row.append(pn.Spacer())
133
-
134
- # create dashboard
135
- main = pn.WidgetBox(
136
- input_widgets,
137
- interactive_result,
138
- footer_row,
139
- )
140
-
141
- title = "Panel Demo - Image Classification"
142
- pn.template.BootstrapTemplate(
143
- title=title,
144
- main=main,
145
- main_max_width="min(50%, 698px)",
146
- header_background="#F08080",
147
- ).servable(title=title)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
 
 
2
 
3
+ import duckdb
4
+ import holoviews as hv
5
+ import pandas as pd
6
  import panel as pn
7
+ from bokeh.models import HoverTool
8
+ from langchain.callbacks.base import BaseCallbackHandler
9
+ from langchain.chat_models import ChatOpenAI
10
 
11
+ pn.extension(sizing_mode="stretch_width", notifications=True)
12
+ hv.extension("bokeh")
13
 
 
 
 
 
 
 
 
14
 
15
+ RANDOM_NAME_QUERY = """
16
+ SELECT name, count,
17
+ CASE
18
+ WHEN female_percent >= 0.2 AND female_percent <= 0.8 AND male_percent >= 0.2 AND male_percent <= 0.8 THEN 'unisex'
19
+ WHEN female_percent > 0.6 THEN 'female'
20
+ WHEN male_percent > 0.6 THEN 'male'
21
+ END AS gender
22
+ FROM (
23
+ SELECT
24
+ name,
25
+ MAX(male + female) AS count,
26
+ (SUM(female) / CAST(SUM(male + female) AS REAL)) AS female_percent,
27
+ (SUM(male) / CAST(SUM(male + female) AS REAL)) AS male_percent
28
+ FROM names
29
+ WHERE name LIKE ?
30
+ GROUP BY name
31
+ )
32
+ WHERE count >= ? AND count <= ?
33
+ AND gender = ?
34
+ ORDER BY RANDOM()
35
+ LIMIT 100
36
+ """
37
 
38
+ TOP_NAMES_WILDCARD_QUERY = """
39
+ SELECT name, SUM(male + female) as count
40
+ FROM names
41
+ WHERE lower(name) LIKE ?
42
+ GROUP BY name
43
+ ORDER BY count DESC
44
+ LIMIT 10
45
+ """
46
 
47
+ TOP_NAMES_SELECT_QUERY = """
48
+ SELECT name, SUM(male + female) as count
49
+ FROM names
50
+ WHERE lower(name) = ?
51
+ GROUP BY name
52
+ ORDER BY count DESC
53
+ """
54
 
55
+ DATA_QUERY = """
56
+ SELECT name, year, male, female, SUM(male + female) AS count
57
+ FROM names
58
+ WHERE name in ({placeholders})
59
+ GROUP BY name, year, male, female
60
+ ORDER BY name, year
61
+ """
62
 
63
 
64
+ class StreamHandler(BaseCallbackHandler):
65
+ def __init__(self, container, initial_text="", target_attr="value"):
66
+ self.container = container
67
+ self.text = initial_text
68
+ self.target_attr = target_attr
69
 
70
+ def on_llm_new_token(self, token: str, **kwargs) -> None:
71
+ self.text += token
72
+ setattr(self.container, self.target_attr, self.text)
73
 
74
+
75
+ class NameChronicles:
76
+ def __init__(self, refresh=False):
77
+ super().__init__()
78
+ self.db_path = Path("names.db")
79
+ self._initialize_database(refresh=refresh)
80
+
81
+ # Main
82
+ self.holoviews_pane = pn.pane.HoloViews(sizing_mode="stretch_both")
83
+ self.selection = hv.streams.Selection1D()
84
+
85
+ # Sidebar
86
+
87
+ # Name Widgets
88
+ self.names_input = pn.widgets.TextInput(name="Name Input", placeholder="Andrew")
89
+ self.names_input.param.watch(self._add_name, "value")
90
+
91
+ self.names_choice = pn.widgets.MultiChoice(
92
+ name="Selected Names",
93
+ options=["Andrew"],
94
+ solid=False,
95
+ )
96
+ self.names_choice.param.watch(self._update_plot, "value")
97
+ self.names_choice.value = ["Andrew"]
98
+
99
+ # Reset Widgets
100
+ self.clear_button = pn.widgets.Button(
101
+ name="Clear Names", button_style="outline", button_type="primary"
102
+ )
103
+ self.clear_button.on_click(
104
+ lambda event: setattr(self.names_choice, "value", [])
105
+ )
106
+ self.refresh_button = pn.widgets.Button(
107
+ name="Refresh Plot", button_style="outline", button_type="primary"
108
+ )
109
+ self.refresh_button.on_click(self._refresh_plot)
110
+
111
+ # Randomize Widgets
112
+ self.name_pattern = pn.widgets.TextInput(
113
+ name="Name Pattern", placeholder="*na*"
114
+ )
115
+ self.count_range = pn.widgets.IntRangeSlider(
116
+ name="Peak Count Range",
117
+ value=(10000, 50000),
118
+ start=0,
119
+ end=100000,
120
+ step=1000,
121
+ margin=(5, 20),
122
+ )
123
+ self.gender_select = pn.widgets.RadioButtonGroup(
124
+ name="Gender",
125
+ options=["Female", "Unisex", "Male"],
126
+ button_style="outline",
127
+ button_type="primary",
128
+ )
129
+ randomize_name = pn.widgets.Button(
130
+ name="Get Name", button_style="outline", button_type="primary"
131
+ )
132
+ randomize_name.param.watch(self._randomize_name, "clicks")
133
+ self.randomize_pane = pn.Card(
134
+ self.name_pattern,
135
+ self.count_range,
136
+ self.gender_select,
137
+ randomize_name,
138
+ title="Get Random Name",
139
+ collapsed=True,
140
+ )
141
+
142
+ # AI Widgets
143
+ self.ai_key = pn.widgets.PasswordInput(
144
+ name="OpenAI Key",
145
+ placeholder="",
146
+ )
147
+ self.ai_prompt = pn.widgets.TextInput(
148
+ name="AI Prompt",
149
+ value="Share a little history about the name:",
150
+ )
151
+ ai_button = pn.widgets.Button(
152
+ name="Get Response",
153
+ button_style="outline",
154
+ button_type="primary",
155
+ )
156
+ ai_button.on_click(self._prompt_ai)
157
+ self.ai_response = pn.widgets.TextAreaInput(
158
+ placeholder="",
159
+ disabled=True,
160
+ height=350,
161
+ )
162
+ self.ai_pane = pn.Card(
163
+ self.ai_key,
164
+ self.ai_prompt,
165
+ ai_button,
166
+ self.ai_response,
167
+ collapsed=True,
168
+ title="Ask AI",
169
+ )
170
+
171
+ # Database Methods
172
+
173
+ def _connect_database(self):
174
+ """
175
+ Connect to the database.
176
+ """
177
+ return duckdb.connect(database=str(self.db_path))
178
+
179
+ def _initialize_database(self, refresh):
180
+ """
181
+ Initialize database with data from the Social Security Administration.
182
+ """
183
+ if not refresh and self.db_path.exists():
184
  return
185
+
186
+ df = pd.concat(
187
+ [
188
+ pd.read_csv(
189
+ path,
190
+ header=None,
191
+ names=["state", "gender", "year", "name", "count"],
192
+ )
193
+ for path in Path("data").glob("*.TXT")
194
+ ]
195
+ )
196
+ df_processed = (
197
+ df.groupby(["gender", "year", "name"], as_index=False)[["count"]]
198
+ .sum()
199
+ .pivot(index=["name", "year"], columns="gender", values="count")
200
+ .reset_index()
201
+ .rename(columns={"F": "female", "M": "male"})
202
+ .fillna(0)
203
+ )
204
+ with self._connect_database() as conn:
205
+ conn.execute("DROP TABLE IF EXISTS names")
206
+ conn.execute("CREATE TABLE names AS SELECT * FROM df_processed")
207
+
208
+ def _query_names(self, names):
209
+ """
210
+ Query the database for the given name.
211
+ """
212
+ dfs = []
213
+ for name in names:
214
+ if "*" in name or "%" in name:
215
+ name = name.replace("*", "%")
216
+ top_names_query = TOP_NAMES_WILDCARD_QUERY
217
+ else:
218
+ top_names_query = TOP_NAMES_SELECT_QUERY
219
+ with self._connect_database() as conn:
220
+ top_names = (
221
+ conn.execute(top_names_query, [name.lower()])
222
+ .fetch_df()["name"]
223
+ .tolist()
224
+ )
225
+ if len(top_names) == 0:
226
+ pn.state.notifications.info(f"No names found matching {name!r}")
227
+ continue
228
+ data_query = DATA_QUERY.format(
229
+ placeholders=", ".join(["?"] * len(top_names))
230
+ )
231
+ df = conn.execute(data_query, top_names).fetch_df()
232
+ dfs.append(df)
233
+
234
+ if len(dfs) > 0:
235
+ self.df = pd.concat(dfs).drop_duplicates(
236
+ subset=["name", "year", "male", "female"]
237
+ )
238
+ else:
239
+ self.df = pd.DataFrame(columns=["name", "year", "male", "female"])
240
+
241
+ # Widget Methods
242
+
243
+ def _randomize_name(self, event):
244
+ with self._connect_database() as conn:
245
+ name_pattern = self.name_pattern.value.lower()
246
+ if not name_pattern:
247
+ name_pattern = "%"
248
+ else:
249
+ name_pattern = name_pattern.replace("*", "%")
250
+ count_range = self.count_range.value
251
+ gender_select = self.gender_select.value.lower()
252
+ random_names = (
253
+ conn.execute(
254
+ RANDOM_NAME_QUERY, [name_pattern, *count_range, gender_select]
255
+ )
256
+ .fetch_df()["name"]
257
+ .tolist()
258
+ )
259
+ if random_names:
260
+ for i in range(len(random_names)):
261
+ random_name = random_names[i]
262
+ if random_name in self.names_choice.value:
263
+ continue
264
+ self.names_input.value = random_name
265
+ break
266
+ else:
267
+ pn.state.notifications.info(
268
+ "All names matching the criteria are already added!"
269
+ )
270
+ else:
271
+ pn.state.notifications.info("No names found matching the criteria!")
272
+
273
+ def _add_name(self, event):
274
+ name = event.new.strip().title()
275
+ self.names_input.value = ""
276
+ if not name:
277
  return
278
+ elif name in self.names_choice.options and name in self.names_choice.value:
279
+ pn.state.notifications.info(f"{name!r} already added!")
280
+ return
281
+ elif len(self.names_choice.value) > 10:
282
+ pn.state.notifications.info(
283
+ "Maximum of 10 names allowed; please remove some first!"
 
 
 
 
284
  )
285
+ return
286
+ value = self.names_choice.value.copy()
287
+ options = self.names_choice.options.copy()
288
+ if name not in options:
289
+ options.append(name)
290
+ if name not in value:
291
+ value.append(name)
292
+ self.names_choice.param.update(
293
+ options=options,
294
+ value=value,
295
+ )
296
+
297
+ def _prompt_ai(self, event):
298
+ if not self.ai_key.value:
299
+ pn.state.notifications.info("Please enter an API key!")
300
+ return
301
+
302
+ if not self.ai_prompt.value:
303
+ pn.state.notifications.info("Please enter a prompt!")
304
+ return
305
+
306
+ stream_handler = StreamHandler(self.ai_response)
307
+ chat = ChatOpenAI(
308
+ max_tokens=500,
309
+ openai_api_key=self.ai_key.value,
310
+ streaming=True,
311
+ callbacks=[stream_handler],
312
+ )
313
+ self.ai_response.loading = True
314
+ try:
315
+ if self.selection.index:
316
+ names = [self._name_indices[self.selection.index[0]]]
317
+ else:
318
+ names = self.names_choice.value[:3]
319
+ chat.predict(f"{self.ai_prompt.value} {names}")
320
+ finally:
321
+ self.ai_response.loading = False
322
+
323
+ # Plot Methods
324
+
325
+ def _click_plot(self, index):
326
+ gender_nd_overlay = hv.NdOverlay(kdims=["Gender"])
327
+ if not index:
328
+ return hv.NdOverlay(
329
+ {
330
+ "curve": self._curve_nd_overlay,
331
+ "scatter": self._scatter_nd_overlay,
332
+ "label": self._label_nd_overlay,
333
+ }
334
  )
335
+
336
+ name = self._name_indices[index[0]]
337
+ df_name = self.df.loc[self.df["name"] == name].copy()
338
+ df_name["female"] += df_name["male"]
339
+ gender_nd_overlay["Male"] = hv.Area(
340
+ df_name, ["year"], ["male"], label="Male"
341
+ ).opts(alpha=0.3, color="#add8e6", line_alpha=0)
342
+ gender_nd_overlay["Female"] = hv.Area(
343
+ df_name, ["year"], ["male", "female"], label="Female"
344
+ ).opts(alpha=0.3, color="#ffb6c1", line_alpha=0)
345
+ return hv.NdOverlay(
346
+ {
347
+ "curve": self._curve_nd_overlay[[index[0]]],
348
+ "scatter": self._scatter_nd_overlay,
349
+ "label": self._label_nd_overlay[[index[0]]].opts(text_color="black"),
350
+ "gender": gender_nd_overlay,
351
+ },
352
+ kdims=["Gender"],
353
+ ).opts(legend_position="top_left")
354
+
355
+ @staticmethod
356
+ def _format_y(value):
357
+ return f"{value / 1000}k"
358
+
359
+ def _update_plot(self, event):
360
+ names = event.new
361
+ print(names)
362
+ self._query_names(names)
363
+
364
+ self._scatter_nd_overlay = hv.NdOverlay()
365
+ self._curve_nd_overlay = hv.NdOverlay(kdims=["Name"]).opts(
366
+ gridstyle={"xgrid_line_width": 0},
367
+ show_grid=True,
368
+ fontscale=1.28,
369
+ xlabel="Year",
370
+ ylabel="Count",
371
+ yformatter=self._format_y,
372
+ legend_limit=0,
373
+ padding=(0.2, 0.05),
374
+ title="Name Chronicles",
375
+ responsive=True,
376
+ )
377
+ self._label_nd_overlay = hv.NdOverlay(kdims=["Name"])
378
+ hover_tool = HoverTool(
379
+ tooltips=[("Name", "@name"), ("Year", "@year"), ("Count", "@count")],
380
+ )
381
+ self._name_indices = {}
382
+ scatter_cycle = hv.Cycle("Category10")
383
+ curve_cycle = hv.Cycle("Category10")
384
+ label_cycle = hv.Cycle("Category10")
385
+ for i, (name, df_name) in enumerate(self.df.groupby("name")):
386
+ df_name_total = df_name.groupby(
387
+ ["name", "year", "male", "female"], as_index=False
388
+ )["count"].sum()
389
+ df_name_total["male"] = df_name_total["male"] / df_name_total["count"]
390
+ df_name_total["female"] = df_name_total["female"] / df_name_total["count"]
391
+ df_name_peak = df_name.loc[[df_name["count"].idxmax()]]
392
+ df_name_peak[
393
+ "label"
394
+ ] = f'{df_name_peak["name"].item()} ({df_name_peak["year"].item()})'
395
+
396
+ hover_tool = HoverTool(
397
+ tooltips=[
398
+ ("Name", "@name"),
399
+ ("Year", "@year"),
400
+ ("Count", "@count{(0a)}"),
401
+ ("Male", "@male{(0%)}"),
402
+ ("Female", "@female{(0%)}"),
403
+ ],
404
+ )
405
+ self._scatter_nd_overlay[i] = hv.Scatter(
406
+ df_name_total, ["year"], ["count", "male", "female", "name"], label=name
407
+ ).opts(
408
+ color=scatter_cycle,
409
+ size=4,
410
+ alpha=0.15,
411
+ marker="y",
412
+ tools=["tap", hover_tool],
413
+ line_width=3,
414
+ show_legend=False,
415
+ )
416
+ self._curve_nd_overlay[i] = hv.Curve(
417
+ df_name_total, ["year"], ["count"], label=name
418
+ ).opts(
419
+ color=curve_cycle,
420
+ tools=["tap"],
421
+ line_width=3,
422
+ )
423
+ self._label_nd_overlay[i] = hv.Labels(
424
+ df_name_peak, ["year", "count"], ["label"], label=name
425
+ ).opts(
426
+ text_align="right",
427
+ text_baseline="bottom",
428
+ text_color=label_cycle,
429
+ )
430
+ self._name_indices[i] = name
431
+ self.selection.source = self._curve_nd_overlay
432
+ if len(self._name_indices) == 1:
433
+ self.selection.update(index=[0])
434
+ else:
435
+ self.selection.update(index=[])
436
+ self.dynamic_map = hv.DynamicMap(
437
+ self._click_plot, kdims=[], streams=[self.selection]
438
+ ).opts(responsive=True)
439
+ self._refresh_plot()
440
+
441
+ def _refresh_plot(self, event=None):
442
+ self.holoviews_pane.object = self.dynamic_map.clone()
443
+
444
+ def view(self):
445
+ reset_row = pn.Row(self.clear_button, self.refresh_button)
446
+ data_url = pn.pane.Markdown(
447
+ "<center>Data from the <a href='https://www.ssa.gov/oact/babynames/limits.html' "
448
+ "target='_blank'>U.S. Social Security Administration</a></center>",
449
+ align="end",
450
+ )
451
+ sidebar = pn.Column(
452
+ self.names_input,
453
+ self.names_choice,
454
+ reset_row,
455
+ pn.layout.Divider(),
456
+ self.randomize_pane,
457
+ self.ai_pane,
458
+ data_url,
459
+ )
460
+ template = pn.template.FastListTemplate(
461
+ sidebar=[sidebar],
462
+ main=[self.holoviews_pane],
463
+ title="Name Chronicles",
464
+ theme="dark",
465
+ )
466
+ return template
467
+
468
+
469
+ NameChronicles().view().servable()