davanstrien HF staff commited on
Commit
b99b870
·
1 Parent(s): beaaa22
Files changed (3) hide show
  1. app.py +52 -0
  2. playground.ipynb +32 -0
  3. requirements.txt +340 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from bertopic import BERTopic
4
+ from datasets import load_dataset
5
+ from functools import lru_cache
6
+
7
+
8
+ def prep_dataset():
9
+ dataset = load_dataset("OpenAssistant/oasst1", split="train")
10
+ assistant_ds = dataset.filter(lambda x: x["role"] == "assistant")
11
+ assistant_ds_en = assistant_ds.filter(lambda x: x["lang"] == "en")
12
+ return assistant_ds_en["text"]
13
+
14
+
15
+ topic_model = BERTopic.load("davanstrien/chat_topics")
16
+
17
+ fig = topic_model.visualize_topics()
18
+
19
+
20
+ def plot_docs():
21
+ docs = prep_dataset()
22
+ return topic_model.visualize_documents(docs)
23
+
24
+
25
+ def search_topic(text):
26
+ similar_topics, _ = topic_model.find_topics(text, top_n=5)
27
+ topic_info = topic_model.get_topic_info()
28
+ return topic_info[topic_info["Topic"].isin(similar_topics)]
29
+
30
+
31
+ def plot_topic_words(num_topics=9, n_words=5):
32
+ return topic_model.visualize_barchart(top_n_topics=num_topics, n_words=n_words)
33
+
34
+
35
+ with gr.Blocks() as demo:
36
+ with gr.Tab("Topic words"):
37
+ topic_number = gr.Slider(
38
+ minimum=3, maximum=20, value=9, step=1, label="Number of topics"
39
+ )
40
+ plot = gr.Plot(plot_topic_words())
41
+ topic_number.change(plot_topic_words, [topic_number], plot)
42
+ with gr.Tab("Topic search"):
43
+ text = gr.Textbox(lines=1, label="Search text")
44
+ df = gr.DataFrame()
45
+ text.change(search_topic, [text], df)
46
+ with gr.Tab("Topic distribution"):
47
+ gr.Plot(fig)
48
+
49
+ # with gr.Tab("Doc visualization"):
50
+ # gr.Plot(plot_docs())
51
+
52
+ demo.launch(debug=True)
playground.ipynb ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": []
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": []
16
+ }
17
+ ],
18
+ "metadata": {
19
+ "kernelspec": {
20
+ "display_name": ".venv",
21
+ "language": "python",
22
+ "name": "python3"
23
+ },
24
+ "language_info": {
25
+ "name": "python",
26
+ "version": "3.11.3"
27
+ },
28
+ "orig_nbformat": 4
29
+ },
30
+ "nbformat": 4,
31
+ "nbformat_minor": 2
32
+ }
requirements.txt ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is autogenerated by pip-compile with Python 3.11
3
+ # by the following command:
4
+ #
5
+ # pip-compile --resolver=backtracking
6
+ #
7
+ aiofiles==23.1.0
8
+ # via gradio
9
+ aiohttp==3.8.4
10
+ # via
11
+ # datasets
12
+ # fsspec
13
+ # gradio
14
+ aiosignal==1.3.1
15
+ # via aiohttp
16
+ altair==5.0.1
17
+ # via gradio
18
+ anyio==3.7.0
19
+ # via
20
+ # httpcore
21
+ # starlette
22
+ async-timeout==4.0.2
23
+ # via aiohttp
24
+ attrs==23.1.0
25
+ # via
26
+ # aiohttp
27
+ # jsonschema
28
+ bertopic==0.15.0
29
+ # via -r requirements.in
30
+ certifi==2023.5.7
31
+ # via
32
+ # httpcore
33
+ # httpx
34
+ # requests
35
+ charset-normalizer==3.1.0
36
+ # via
37
+ # aiohttp
38
+ # requests
39
+ click==8.1.3
40
+ # via
41
+ # nltk
42
+ # uvicorn
43
+ contourpy==1.0.7
44
+ # via matplotlib
45
+ cycler==0.11.0
46
+ # via matplotlib
47
+ cython==0.29.35
48
+ # via hdbscan
49
+ datasets==2.12.0
50
+ # via -r requirements.in
51
+ dill==0.3.6
52
+ # via
53
+ # datasets
54
+ # multiprocess
55
+ fastapi==0.96.0
56
+ # via gradio
57
+ fastjsonschema==2.17.1
58
+ # via nbformat
59
+ ffmpy==0.3.0
60
+ # via gradio
61
+ filelock==3.12.0
62
+ # via
63
+ # huggingface-hub
64
+ # torch
65
+ # transformers
66
+ fonttools==4.39.4
67
+ # via matplotlib
68
+ frozenlist==1.3.3
69
+ # via
70
+ # aiohttp
71
+ # aiosignal
72
+ fsspec[http]==2023.5.0
73
+ # via
74
+ # datasets
75
+ # gradio-client
76
+ # huggingface-hub
77
+ gradio==3.33.1
78
+ # via -r requirements.in
79
+ gradio-client==0.2.5
80
+ # via gradio
81
+ h11==0.14.0
82
+ # via
83
+ # httpcore
84
+ # uvicorn
85
+ hdbscan==0.8.29
86
+ # via bertopic
87
+ httpcore==0.17.2
88
+ # via httpx
89
+ httpx==0.24.1
90
+ # via
91
+ # gradio
92
+ # gradio-client
93
+ huggingface-hub==0.15.1
94
+ # via
95
+ # datasets
96
+ # gradio
97
+ # gradio-client
98
+ # sentence-transformers
99
+ # transformers
100
+ idna==3.4
101
+ # via
102
+ # anyio
103
+ # httpx
104
+ # requests
105
+ # yarl
106
+ jinja2==3.1.2
107
+ # via
108
+ # altair
109
+ # gradio
110
+ # torch
111
+ joblib==1.2.0
112
+ # via
113
+ # hdbscan
114
+ # nltk
115
+ # pynndescent
116
+ # scikit-learn
117
+ jsonschema==4.17.3
118
+ # via
119
+ # altair
120
+ # nbformat
121
+ jupyter-core==5.3.0
122
+ # via nbformat
123
+ kiwisolver==1.4.4
124
+ # via matplotlib
125
+ linkify-it-py==2.0.2
126
+ # via markdown-it-py
127
+ llvmlite==0.40.0
128
+ # via
129
+ # numba
130
+ # pynndescent
131
+ markdown-it-py[linkify]==2.2.0
132
+ # via
133
+ # gradio
134
+ # mdit-py-plugins
135
+ markupsafe==2.1.3
136
+ # via
137
+ # gradio
138
+ # jinja2
139
+ matplotlib==3.7.1
140
+ # via gradio
141
+ mdit-py-plugins==0.3.3
142
+ # via gradio
143
+ mdurl==0.1.2
144
+ # via markdown-it-py
145
+ mpmath==1.3.0
146
+ # via sympy
147
+ multidict==6.0.4
148
+ # via
149
+ # aiohttp
150
+ # yarl
151
+ multiprocess==0.70.14
152
+ # via datasets
153
+ nbformat==5.9.0
154
+ # via -r requirements.in
155
+ networkx==3.1
156
+ # via torch
157
+ nltk==3.8.1
158
+ # via sentence-transformers
159
+ numba==0.57.0
160
+ # via
161
+ # pynndescent
162
+ # umap-learn
163
+ numpy==1.24.3
164
+ # via
165
+ # altair
166
+ # bertopic
167
+ # contourpy
168
+ # datasets
169
+ # gradio
170
+ # hdbscan
171
+ # matplotlib
172
+ # numba
173
+ # pandas
174
+ # pyarrow
175
+ # scikit-learn
176
+ # scipy
177
+ # sentence-transformers
178
+ # torchvision
179
+ # transformers
180
+ # umap-learn
181
+ orjson==3.9.0
182
+ # via gradio
183
+ packaging==23.1
184
+ # via
185
+ # datasets
186
+ # gradio-client
187
+ # huggingface-hub
188
+ # matplotlib
189
+ # plotly
190
+ # transformers
191
+ pandas==2.0.2
192
+ # via
193
+ # altair
194
+ # bertopic
195
+ # datasets
196
+ # gradio
197
+ pillow==9.5.0
198
+ # via
199
+ # gradio
200
+ # matplotlib
201
+ # torchvision
202
+ platformdirs==3.5.1
203
+ # via jupyter-core
204
+ plotly==5.14.1
205
+ # via bertopic
206
+ pyarrow==12.0.0
207
+ # via datasets
208
+ pydantic==1.10.8
209
+ # via
210
+ # fastapi
211
+ # gradio
212
+ pydub==0.25.1
213
+ # via gradio
214
+ pygments==2.15.1
215
+ # via gradio
216
+ pynndescent==0.5.10
217
+ # via umap-learn
218
+ pyparsing==3.0.9
219
+ # via matplotlib
220
+ pyrsistent==0.19.3
221
+ # via jsonschema
222
+ python-dateutil==2.8.2
223
+ # via
224
+ # matplotlib
225
+ # pandas
226
+ python-multipart==0.0.6
227
+ # via gradio
228
+ pytz==2023.3
229
+ # via pandas
230
+ pyyaml==6.0
231
+ # via
232
+ # datasets
233
+ # gradio
234
+ # huggingface-hub
235
+ # transformers
236
+ regex==2023.6.3
237
+ # via
238
+ # nltk
239
+ # transformers
240
+ requests==2.31.0
241
+ # via
242
+ # datasets
243
+ # fsspec
244
+ # gradio
245
+ # gradio-client
246
+ # huggingface-hub
247
+ # responses
248
+ # torchvision
249
+ # transformers
250
+ responses==0.18.0
251
+ # via datasets
252
+ safetensors==0.3.1
253
+ # via -r requirements.in
254
+ scikit-learn==1.2.2
255
+ # via
256
+ # bertopic
257
+ # hdbscan
258
+ # pynndescent
259
+ # sentence-transformers
260
+ # umap-learn
261
+ scipy==1.10.1
262
+ # via
263
+ # hdbscan
264
+ # pynndescent
265
+ # scikit-learn
266
+ # sentence-transformers
267
+ # umap-learn
268
+ semantic-version==2.10.0
269
+ # via gradio
270
+ sentence-transformers==2.2.2
271
+ # via bertopic
272
+ sentencepiece==0.1.99
273
+ # via sentence-transformers
274
+ six==1.16.0
275
+ # via python-dateutil
276
+ sniffio==1.3.0
277
+ # via
278
+ # anyio
279
+ # httpcore
280
+ # httpx
281
+ starlette==0.27.0
282
+ # via fastapi
283
+ sympy==1.12
284
+ # via torch
285
+ tenacity==8.2.2
286
+ # via plotly
287
+ threadpoolctl==3.1.0
288
+ # via scikit-learn
289
+ tokenizers==0.13.3
290
+ # via transformers
291
+ toolz==0.12.0
292
+ # via altair
293
+ torch==2.0.1
294
+ # via
295
+ # sentence-transformers
296
+ # torchvision
297
+ torchvision==0.15.2
298
+ # via sentence-transformers
299
+ tqdm==4.65.0
300
+ # via
301
+ # bertopic
302
+ # datasets
303
+ # huggingface-hub
304
+ # nltk
305
+ # sentence-transformers
306
+ # transformers
307
+ # umap-learn
308
+ traitlets==5.9.0
309
+ # via
310
+ # jupyter-core
311
+ # nbformat
312
+ transformers==4.29.2
313
+ # via sentence-transformers
314
+ typing-extensions==4.6.3
315
+ # via
316
+ # gradio
317
+ # gradio-client
318
+ # huggingface-hub
319
+ # pydantic
320
+ # torch
321
+ tzdata==2023.3
322
+ # via pandas
323
+ uc-micro-py==1.0.2
324
+ # via linkify-it-py
325
+ umap-learn==0.5.3
326
+ # via bertopic
327
+ urllib3==2.0.2
328
+ # via
329
+ # requests
330
+ # responses
331
+ uvicorn==0.22.0
332
+ # via gradio
333
+ websockets==11.0.3
334
+ # via
335
+ # gradio
336
+ # gradio-client
337
+ xxhash==3.2.0
338
+ # via datasets
339
+ yarl==1.9.2
340
+ # via aiohttp