ahuang11 commited on
Commit
7db9e90
1 Parent(s): dcb23b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -531
app.py CHANGED
@@ -1,532 +1,3 @@
1
- from pathlib import Path
2
 
3
- import duckdb
4
- import holoviews as hv
5
- import pandas as pd
6
- import panel as pn
7
- from bokeh.models import HoverTool
8
- from bokeh.models import NumeralTickFormatter
9
- from pydantic import BaseModel, Field
10
- from langchain.callbacks.base import BaseCallbackHandler
11
- from langchain.chat_models import ChatOpenAI
12
- from langchain.llms.openai import OpenAI
13
- from langchain.output_parsers import PydanticOutputParser
14
- from langchain.pydantic_v1 import BaseModel, Field, validator
15
- from langchain.memory import ConversationBufferMemory
16
- from langchain.chains import ConversationChain
17
- from langchain.prompts import PromptTemplate
18
-
19
- pn.extension(sizing_mode="stretch_width", notifications=True)
20
- hv.extension("bokeh")
21
-
22
- INSTRUCTIONS = """
23
- #### Name Chronicles lets you explore the history of names in the United States.
24
- - Enter a name to add to plot!
25
- - Hover over a line for stats or click for the gender distribution.
26
- - Chat with AI for inspiration or get a random name based on input criteria.
27
- - Have ideas? [Open an issue](https://github.com/ahuang11/name-chronicles/issues).
28
- """
29
-
30
- RANDOM_NAME_QUERY = """
31
- SELECT name, count,
32
- CASE
33
- WHEN female_percent >= 0.2 AND female_percent <= 0.8 AND male_percent >= 0.2 AND male_percent <= 0.8 THEN 'unisex'
34
- WHEN female_percent > 0.5 THEN 'female'
35
- WHEN male_percent > 0.5 THEN 'male'
36
- END AS gender
37
- FROM (
38
- SELECT
39
- name,
40
- MAX(male + female) AS count,
41
- (SUM(female) / CAST(SUM(male + female) AS REAL)) AS female_percent,
42
- (SUM(male) / CAST(SUM(male + female) AS REAL)) AS male_percent
43
- FROM names
44
- WHERE name LIKE ?
45
- GROUP BY name
46
- )
47
- WHERE count >= ? AND count <= ?
48
- AND gender = ?
49
- ORDER BY RANDOM()
50
- LIMIT 100
51
- """
52
-
53
- TOP_NAMES_WILDCARD_QUERY = """
54
- SELECT name, SUM(male + female) as count
55
- FROM names
56
- WHERE lower(name) LIKE ?
57
- GROUP BY name
58
- ORDER BY count DESC
59
- LIMIT 10
60
- """
61
-
62
- TOP_NAMES_SELECT_QUERY = """
63
- SELECT name, SUM(male + female) as count
64
- FROM names
65
- WHERE lower(name) = ?
66
- GROUP BY name
67
- ORDER BY count DESC
68
- """
69
-
70
- DATA_QUERY = """
71
- SELECT name, year, male, female, SUM(male + female) AS count
72
- FROM names
73
- WHERE name in ({placeholders})
74
- GROUP BY name, year, male, female
75
- ORDER BY name, year
76
- """
77
-
78
- MAX_LLM_COUNT = 2000
79
-
80
- class FirstNames(BaseModel):
81
- names: list[str] = Field(description="List of first names")
82
-
83
-
84
- class StreamHandler(BaseCallbackHandler):
85
- def __init__(self, container, initial_text="", target_attr="value"):
86
- self.container = container
87
- self.text = initial_text
88
- self.target_attr = target_attr
89
-
90
- def on_llm_new_token(self, token: str, **kwargs) -> None:
91
- self.text += token
92
- setattr(self.container, self.target_attr, self.text)
93
-
94
-
95
- class NameChronicles:
96
- def __init__(self):
97
- super().__init__()
98
- self.llm_use_counter = 0
99
- self.db_path = Path("data/names.db")
100
-
101
- # Main
102
- self.scatter_cycle = hv.Cycle("Category10")
103
- self.curve_cycle = hv.Cycle("Category10")
104
- self.label_cycle = hv.Cycle("Category10")
105
- self.holoviews_pane = pn.pane.HoloViews(
106
- min_height=675, sizing_mode="stretch_both"
107
- )
108
- self.selection = hv.streams.Selection1D()
109
-
110
- # Sidebar
111
-
112
- # Name Widgets
113
- self.names_input = pn.widgets.TextInput(name="Name Input", placeholder="Andrew")
114
- self.names_input.param.watch(self._add_name, "value")
115
-
116
- self.names_choice = pn.widgets.MultiChoice(
117
- name="Selected Names",
118
- options=["Andrew"],
119
- solid=False,
120
- )
121
- self.names_choice.param.watch(self._update_plot, "value")
122
-
123
- # Reset Widgets
124
- self.clear_button = pn.widgets.Button(
125
- name="Clear Names", button_style="outline", button_type="primary"
126
- )
127
- self.clear_button.on_click(
128
- lambda event: setattr(self.names_choice, "value", [])
129
- )
130
- self.refresh_button = pn.widgets.Button(
131
- name="Refresh Plot", button_style="outline", button_type="primary"
132
- )
133
- self.refresh_button.on_click(self._refresh_plot)
134
-
135
- # Randomize Widgets
136
- self.name_pattern = pn.widgets.TextInput(
137
- name="Name Pattern", placeholder="*na*"
138
- )
139
- self.count_range = pn.widgets.IntRangeSlider(
140
- name="Peak Count Range",
141
- value=(0, 100000),
142
- start=0,
143
- end=100000,
144
- step=1000,
145
- margin=(5, 20),
146
- )
147
- self.gender_select = pn.widgets.RadioButtonGroup(
148
- name="Gender",
149
- options=["Female", "Unisex", "Male"],
150
- button_style="outline",
151
- button_type="primary",
152
- )
153
- randomize_name = pn.widgets.Button(
154
- name="Get Name", button_style="outline", button_type="primary"
155
- )
156
- randomize_name.param.watch(self._randomize_name, "clicks")
157
- self.randomize_pane = pn.Card(
158
- self.name_pattern,
159
- self.count_range,
160
- self.gender_select,
161
- randomize_name,
162
- title="Get Random Name",
163
- collapsed=True,
164
- )
165
-
166
- # AI Widgets
167
- self.chat_interface = pn.chat.ChatInterface(
168
- show_button_name=False,
169
- callback=self._prompt_ai,
170
- height=500,
171
- styles={"background": "white"},
172
- disabled=True,
173
- )
174
- self.chat_interface.send(
175
- value=(
176
- "Ask me about name suggestions or their history! "
177
- "To add suggested names, click the button below!"
178
- ),
179
- user="System",
180
- respond=False,
181
- )
182
- self.parse_ai_button = pn.widgets.Button(
183
- name="Parse and Add Names",
184
- button_style="outline",
185
- button_type="primary",
186
- disabled=True,
187
- )
188
- self.last_ai_output = None
189
- pn.state.onload(self._initialize_database)
190
-
191
- # Database Methods
192
-
193
- def _initialize_database(self):
194
- """
195
- Initialize database with data from the Social Security Administration.
196
- """
197
- self.conn = duckdb.connect(":memory:")
198
- df = pd.concat(
199
- [
200
- pd.read_csv(
201
- path,
202
- header=None,
203
- names=["state", "gender", "year", "name", "count"],
204
- )
205
- for path in Path("data").glob("*.TXT")
206
- ]
207
- )
208
- df_processed = (
209
- df.groupby(["gender", "year", "name"], as_index=False)[["count"]]
210
- .sum()
211
- .pivot(index=["name", "year"], columns="gender", values="count")
212
- .reset_index()
213
- .rename(columns={"F": "female", "M": "male"})
214
- .fillna(0)
215
- )
216
- self.conn.execute("DROP TABLE IF EXISTS names")
217
- self.conn.execute("CREATE TABLE names AS SELECT * FROM df_processed")
218
-
219
- if self.names_choice.value == []:
220
- self.names_choice.value = ["Andrew"]
221
- else:
222
- self.names_choice.param.trigger("value")
223
- self.main.objects = [self.holoviews_pane]
224
-
225
- # Start AI
226
- self.callback_handler = pn.chat.langchain.PanelCallbackHandler(
227
- self.chat_interface
228
- )
229
- self.chat_openai = ChatOpenAI(
230
- max_tokens=75,
231
- streaming=True,
232
- callbacks=[self.callback_handler],
233
- )
234
- self.openai = OpenAI(max_tokens=75)
235
- memory = ConversationBufferMemory()
236
- self.conversation_chain = ConversationChain(
237
- llm=self.chat_openai, memory=memory, callbacks=[self.callback_handler]
238
- )
239
- self.chat_interface.disabled = False
240
- self.parse_ai_button.on_click(self._parse_ai_output)
241
- self.pydantic_parser = PydanticOutputParser(pydantic_object=FirstNames)
242
- self.prompt_template = PromptTemplate(
243
- template="{format_instructions}\n{input}\n",
244
- input_variables=["input"],
245
- partial_variables={"format_instructions": self.pydantic_parser.get_format_instructions()},
246
- )
247
-
248
- def _query_names(self, names):
249
- """
250
- Query the database for the given name.
251
- """
252
- dfs = []
253
- for name in names:
254
- if "*" in name or "%" in name:
255
- name = name.replace("*", "%")
256
- top_names_query = TOP_NAMES_WILDCARD_QUERY
257
- else:
258
- top_names_query = TOP_NAMES_SELECT_QUERY
259
- top_names = (
260
- self.conn.execute(top_names_query, [name.lower()])
261
- .fetch_df()["name"]
262
- .tolist()
263
- )
264
- if len(top_names) == 0:
265
- pn.state.notifications.info(f"No names found matching {name!r}")
266
- continue
267
- data_query = DATA_QUERY.format(
268
- placeholders=", ".join(["?"] * len(top_names))
269
- )
270
- df = self.conn.execute(data_query, top_names).fetch_df()
271
- dfs.append(df)
272
-
273
- if len(dfs) > 0:
274
- self.df = pd.concat(dfs).drop_duplicates(
275
- subset=["name", "year", "male", "female"]
276
- )
277
- else:
278
- self.df = pd.DataFrame(columns=["name", "year", "male", "female"])
279
-
280
- # Widget Methods
281
-
282
- def _randomize_name(self, event):
283
- name_pattern = self.name_pattern.value.lower()
284
- if not name_pattern:
285
- name_pattern = "%"
286
- else:
287
- name_pattern = name_pattern.replace("*", "%")
288
- if not name_pattern.startswith("%"):
289
- name_pattern = name_pattern.title()
290
-
291
- count_range = self.count_range.value
292
- gender_select = self.gender_select.value.lower()
293
- random_names = (
294
- self.conn.execute(
295
- RANDOM_NAME_QUERY, [name_pattern, *count_range, gender_select]
296
- ).fetch_df()["name"]
297
- .tolist()
298
- )
299
- print(len(random_names))
300
- if random_names:
301
- for i in range(len(random_names)):
302
- random_name = random_names[i]
303
- if random_name in self.names_choice.value:
304
- continue
305
- self.names_input.value = random_name
306
- break
307
- else:
308
- pn.state.notifications.info(
309
- "All names matching the criteria are already added!"
310
- )
311
- else:
312
- pn.state.notifications.info("No names found matching the criteria!")
313
-
314
- def _add_only_unique_names(self, names):
315
- value = self.names_choice.value.copy()
316
- options = self.names_choice.options.copy()
317
- for name in names:
318
- if " " in name:
319
- name = name.split(" ", 1)[0]
320
- if name not in options:
321
- options.append(name)
322
- if name not in value:
323
- value.append(name)
324
- self.names_choice.param.update(
325
- options=options,
326
- value=value,
327
- )
328
-
329
- def _add_name(self, event):
330
- name = event.new.strip().title()
331
- self.names_input.value = ""
332
- if not name:
333
- return
334
- elif name in self.names_choice.options and name in self.names_choice.value:
335
- pn.state.notifications.info(f"{name!r} already added!")
336
- return
337
- elif len(self.names_choice.value) > 10:
338
- pn.state.notifications.info(
339
- "Maximum of 10 names allowed; please remove some first!"
340
- )
341
- return
342
- self._add_only_unique_names([name])
343
-
344
- async def _prompt_ai(self, contents, user, instance):
345
- if self.llm_use_counter >= MAX_LLM_COUNT:
346
- pn.state.notifications.info(
347
- "Sorry, all the available AI credits have been used!"
348
- )
349
- return
350
-
351
- prompt = (
352
- f"One sentence reply to {contents!r} or concisely suggest other relevant names; "
353
- f"if no name is provided use {self.names_choice.value[-1]!r}."
354
- )
355
- print(prompt)
356
- self.last_ai_output = await self.conversation_chain.apredict(
357
- input=prompt,
358
- callbacks=[self.callback_handler],
359
- )
360
- self.parse_ai_button.disabled = False
361
- self.llm_use_counter += 1
362
-
363
- async def _parse_ai_output(self, _):
364
- if self.llm_use_counter >= MAX_LLM_COUNT:
365
- pn.state.notifications.info(
366
- "Sorry, all the available AI credits have been used!"
367
- )
368
- return
369
-
370
- if self.last_ai_output is None:
371
- pn.state.notifications.info("No available AI output to parse!")
372
- return
373
-
374
- try:
375
- names_prompt = self.prompt_template.format_prompt(input=self.last_ai_output).to_string()
376
- names_text = await self.openai.apredict(names_prompt)
377
- new_names = (await self.pydantic_parser.aparse(names_text)).names
378
- print(new_names)
379
- self._add_only_unique_names(new_names)
380
- except Exception:
381
- pn.state.notifications.error("Failed to parse AI output.")
382
- finally:
383
- self.last_ai_output = None
384
-
385
- # Plot Methods
386
- def _click_plot(self, index):
387
- gender_nd_overlay = hv.NdOverlay(kdims=["Gender"])
388
- if not index:
389
- return hv.NdOverlay(
390
- {
391
- "curve": self._curve_nd_overlay,
392
- "scatter": self._scatter_nd_overlay,
393
- "label": self._label_nd_overlay,
394
- }
395
- )
396
-
397
- name = self._name_indices[index[0]]
398
- df_name = self.df.loc[self.df["name"] == name].copy()
399
- df_name["female"] += df_name["male"]
400
- gender_nd_overlay["Male"] = hv.Area(
401
- df_name, ["year"], ["male"], label="Male"
402
- ).opts(alpha=0.3, color="#add8e6", line_alpha=0)
403
- gender_nd_overlay["Female"] = hv.Area(
404
- df_name, ["year"], ["male", "female"], label="Female"
405
- ).opts(alpha=0.3, color="#ffb6c1", line_alpha=0)
406
- return hv.NdOverlay(
407
- {
408
- "curve": self._curve_nd_overlay[[index[0]]],
409
- "scatter": self._scatter_nd_overlay,
410
- "label": self._label_nd_overlay[[index[0]]].opts(text_color="black"),
411
- "gender": gender_nd_overlay,
412
- },
413
- kdims=["Gender"],
414
- ).opts(legend_position="top_left")
415
-
416
- def _update_plot(self, event):
417
- names = event.new
418
- print(names)
419
- self._query_names(names)
420
-
421
- self._scatter_nd_overlay = hv.NdOverlay()
422
- self._curve_nd_overlay = hv.NdOverlay(kdims=["Name"]).opts(
423
- gridstyle={"xgrid_line_width": 0},
424
- show_grid=True,
425
- fontscale=1.28,
426
- xlabel="Year",
427
- ylabel="Count",
428
- yformatter=NumeralTickFormatter(format="0.0a"),
429
- legend_limit=0,
430
- padding=(0.2, 0.05),
431
- title="Name Chronicles",
432
- responsive=True,
433
- )
434
- self._label_nd_overlay = hv.NdOverlay(kdims=["Name"])
435
- hover_tool = HoverTool(
436
- tooltips=[("Name", "@name"), ("Year", "@year"), ("Count", "@count")],
437
- )
438
- self._name_indices = {}
439
- for i, (name, df_name) in enumerate(self.df.groupby("name")):
440
- df_name_total = df_name.groupby(
441
- ["name", "year", "male", "female"], as_index=False
442
- )["count"].sum()
443
- df_name_total["male"] = df_name_total["male"] / df_name_total["count"]
444
- df_name_total["female"] = df_name_total["female"] / df_name_total["count"]
445
- df_name_peak = df_name.loc[[df_name["count"].idxmax()]]
446
- df_name_peak[
447
- "label"
448
- ] = f'{df_name_peak["name"].item()} ({df_name_peak["year"].item()})'
449
-
450
- hover_tool = HoverTool(
451
- tooltips=[
452
- ("Name", "@name"),
453
- ("Year", "@year"),
454
- ("Count", "@count{(0a)}"),
455
- ("Male", "@male{(0%)}"),
456
- ("Female", "@female{(0%)}"),
457
- ],
458
- )
459
- self._scatter_nd_overlay[i] = hv.Scatter(
460
- df_name_total, ["year"], ["count", "male", "female", "name"], label=name
461
- ).opts(
462
- color=self.scatter_cycle,
463
- size=4,
464
- alpha=0.15,
465
- marker="y",
466
- tools=["tap", hover_tool],
467
- line_width=3,
468
- show_legend=False,
469
- )
470
- self._curve_nd_overlay[i] = hv.Curve(
471
- df_name_total, ["year"], ["count"], label=name
472
- ).opts(
473
- color=self.curve_cycle,
474
- tools=["tap"],
475
- line_width=3,
476
- )
477
- self._label_nd_overlay[i] = hv.Labels(
478
- df_name_peak, ["year", "count"], ["label"], label=name
479
- ).opts(
480
- text_align="right",
481
- text_baseline="bottom",
482
- text_color=self.label_cycle,
483
- )
484
- self._name_indices[i] = name
485
- self.selection.source = self._curve_nd_overlay
486
- if len(self._name_indices) == 1:
487
- self.selection.update(index=[0])
488
- else:
489
- self.selection.update(index=[])
490
- self.dynamic_map = hv.DynamicMap(
491
- self._click_plot, kdims=[], streams=[self.selection]
492
- ).opts(responsive=True)
493
- self._refresh_plot()
494
-
495
- def _refresh_plot(self, event=None):
496
- self.holoviews_pane.object = self.dynamic_map.clone()
497
-
498
- def view(self):
499
- reset_row = pn.Row(self.clear_button, self.refresh_button)
500
- data_url = pn.pane.Markdown(
501
- "<center>Data from the <a href='https://www.ssa.gov/oact/babynames/limits.html' "
502
- "target='_blank'>U.S. Social Security Administration</a></center>",
503
- align="end",
504
- )
505
- sidebar = pn.Column(
506
- INSTRUCTIONS,
507
- self.names_input,
508
- self.names_choice,
509
- reset_row,
510
- pn.layout.Divider(),
511
- self.chat_interface,
512
- self.parse_ai_button,
513
- self.randomize_pane,
514
- data_url,
515
- )
516
- self.main = pn.Column(
517
- pn.widgets.StaticText(
518
- value="Loading, this may take a few seconds...",
519
- sizing_mode="stretch_both",
520
- ),
521
- )
522
- template = pn.template.FastListTemplate(
523
- sidebar_width=500,
524
- sidebar=[sidebar],
525
- main=[self.main],
526
- title="Name Chronicles",
527
- theme="dark",
528
- )
529
- return template
530
-
531
-
532
- NameChronicles().view().servable()
 
1
+ from tastymap import TastyKitchen
2
 
3
+ TastyKitchen().servable()