kisejin committed
Commit 19b102a (verified)
1 Parent(s): 71464bf

Upload 261 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +2 -0
  2. BERTopic/.flake8 +2 -0
  3. BERTopic/.gitattributes +1 -0
  4. BERTopic/.github/CONTRIBUTING.md +53 -0
  5. BERTopic/.github/workflows/testing.yml +31 -0
  6. BERTopic/.gitignore +83 -0
  7. BERTopic/LICENSE +21 -0
  8. BERTopic/Makefile +24 -0
  9. BERTopic/README.md +297 -0
  10. BERTopic/bertopic/__init__.py +7 -0
  11. BERTopic/bertopic/_bertopic.py +0 -0
  12. BERTopic/bertopic/_save_utils.py +492 -0
  13. BERTopic/bertopic/_utils.py +149 -0
  14. BERTopic/bertopic/backend/__init__.py +35 -0
  15. BERTopic/bertopic/backend/_base.py +69 -0
  16. BERTopic/bertopic/backend/_cohere.py +94 -0
  17. BERTopic/bertopic/backend/_flair.py +78 -0
  18. BERTopic/bertopic/backend/_gensim.py +66 -0
  19. BERTopic/bertopic/backend/_hftransformers.py +96 -0
  20. BERTopic/bertopic/backend/_multimodal.py +194 -0
  21. BERTopic/bertopic/backend/_openai.py +88 -0
  22. BERTopic/bertopic/backend/_sentencetransformers.py +66 -0
  23. BERTopic/bertopic/backend/_sklearn.py +68 -0
  24. BERTopic/bertopic/backend/_spacy.py +94 -0
  25. BERTopic/bertopic/backend/_use.py +58 -0
  26. BERTopic/bertopic/backend/_utils.py +135 -0
  27. BERTopic/bertopic/backend/_word_doc.py +49 -0
  28. BERTopic/bertopic/cluster/__init__.py +5 -0
  29. BERTopic/bertopic/cluster/_base.py +41 -0
  30. BERTopic/bertopic/cluster/_utils.py +70 -0
  31. BERTopic/bertopic/dimensionality/__init__.py +5 -0
  32. BERTopic/bertopic/dimensionality/_base.py +26 -0
  33. BERTopic/bertopic/plotting/__init__.py +28 -0
  34. BERTopic/bertopic/plotting/_approximate_distribution.py +99 -0
  35. BERTopic/bertopic/plotting/_barchart.py +127 -0
  36. BERTopic/bertopic/plotting/_datamap.py +152 -0
  37. BERTopic/bertopic/plotting/_distribution.py +110 -0
  38. BERTopic/bertopic/plotting/_documents.py +227 -0
  39. BERTopic/bertopic/plotting/_heatmap.py +138 -0
  40. BERTopic/bertopic/plotting/_hierarchical_documents.py +336 -0
  41. BERTopic/bertopic/plotting/_hierarchy.py +308 -0
  42. BERTopic/bertopic/plotting/_term_rank.py +135 -0
  43. BERTopic/bertopic/plotting/_topics.py +162 -0
  44. BERTopic/bertopic/plotting/_topics_over_time.py +123 -0
  45. BERTopic/bertopic/plotting/_topics_per_class.py +130 -0
  46. BERTopic/bertopic/representation/__init__.py +68 -0
  47. BERTopic/bertopic/representation/_base.py +38 -0
  48. BERTopic/bertopic/representation/_cohere.py +193 -0
  49. BERTopic/bertopic/representation/_keybert.py +198 -0
  50. BERTopic/bertopic/representation/_langchain.py +203 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/df_en_review_vntourism.csv filter=lfs diff=lfs merge=lfs -text
+ data/merge_final.csv filter=lfs diff=lfs merge=lfs -text
BERTopic/.flake8 ADDED
@@ -0,0 +1,2 @@
+ [flake8]
+ max-line-length = 160
BERTopic/.gitattributes ADDED
@@ -0,0 +1 @@
+ *.ipynb linguist-documentation
BERTopic/.github/CONTRIBUTING.md ADDED
@@ -0,0 +1,53 @@
+ # Contributing to BERTopic
+
+ Hi! Thank you for considering contributing to BERTopic. Thanks to the modular nature of BERTopic, many new add-ons, backends, representation models, sub-models, and LLMs can quickly be added to keep up with the incredibly fast-moving field.
+
+ Whether contributions are new features, better documentation, bug fixes, or improvements to the repository itself, anything is appreciated!
+
+ ## 📚 Guidelines
+
+ ### 🤖 Contributing Code
+
+ To contribute to this project, we follow an `issue -> pull request` approach for main features and bug fixes. This means that any new feature, bug fix, or anything else that touches code directly needs to start from an issue first. That way, the main discussion about what needs to be added/fixed can happen in the issue before a pull request is created. This makes sure that we are on the same page before you start coding your pull request. If you start working on an issue, please assign it to yourself, but only after there is agreement with the maintainer, [@MaartenGr](https://github.com/MaartenGr).
+
+ When there is agreement on the approach, a pull request can be created in which the fix/feature is added. This follows a ["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
+ Please do not try to push directly to this repo unless you are a maintainer.
+
+ There are exceptions to the `issue -> pull request` approach, typically small changes that do not need prior agreement, such as:
+ * Documentation
+ * Spelling/grammar issues
+ * Docstrings
+ * etc.
+
+ There is a large focus on documentation in this repository, so please make sure to add extensive descriptions of features when creating the pull request.
+
+ Note that the main focus of pull requests and code should be:
+ * Easy readability
+ * Clear communication
+ * Sufficient documentation
+
+ ## 🚀 Quick Start
+
+ To start contributing, make sure to first start from a fresh environment. Using an environment manager, such as `conda` or `pyenv`, helps in making sure that your code is reproducible and tracks the versions you have in your environment.
+
+ If you are using conda, you can approach it as follows:
+
+ 1. Create and activate a new conda environment (e.g., `conda create -n bertopic python=3.9`)
+ 2. Install the requirements (e.g., `pip install .[dev]`)
+    * This also installs the documentation and testing packages
+ 3. (Optional) Run `make docs` to build the documentation
+ 4. (Optional) Run `make test` to run the unit tests and `make coverage` to check their coverage
+
+ ❗Note: Unit testing the package can take quite some time since it needs to run several variants of the BERTopic pipeline.
+
+ ## 🤓 Collaborative Efforts
+
+ When you run into any issue with the above or need help getting started with a pull request, feel free to reach out in the issues! As with all repositories, this one has its particularities as a result of the maintainer's views. Each repository is quite different, and so are its processes.
+
+ ## 🏆 Recognition
+
+ If your contribution has made its way into a new release of BERTopic, you will be credited in the changelog of that release! Regardless of the size of the contribution, any help is greatly appreciated.
+
+ ## 🎈 Release
+
+ BERTopic mostly follows [semantic versioning](https://semver.org/) for its releases. Even though BERTopic has been around for a few years now, it is still pre-1.0 software. With the rapid changes in the field and as a way to keep up, this versioning is on purpose. Backwards compatibility is taken into account, but integrating new features, and thereby keeping up with the field, takes priority. Especially since BERTopic focuses on modularity, flexibility is necessary.
BERTopic/.github/workflows/testing.yml ADDED
@@ -0,0 +1,31 @@
+ name: Code Checks
+
+ on:
+   push:
+     branches:
+       - master
+       - dev
+   pull_request:
+     branches:
+       - master
+       - dev
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+     strategy:
+       matrix:
+         python-version: [3.8, 3.9]
+
+     steps:
+       - uses: actions/checkout@v2
+       - name: Set up Python ${{ matrix.python-version }}
+         uses: actions/setup-python@v1
+         with:
+           python-version: ${{ matrix.python-version }}
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install -e ".[test]"
+       - name: Run Checking Mechanisms
+         run: make check
BERTopic/.gitignore ADDED
@@ -0,0 +1,83 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+ model_dir
+ model_dir/
+ test
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Sphinx documentation
+ docs/_build/
+
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Artifacts
+ .idea
+ .idea/
+ .vscode
+ .DS_Store
BERTopic/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023, Maarten P. Grootendorst
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
BERTopic/Makefile ADDED
@@ -0,0 +1,24 @@
+ test:
+ 	pytest
+
+ coverage:
+ 	pytest --cov
+
+ install:
+ 	python -m pip install -e .
+
+ install-test:
+ 	python -m pip install -e ".[dev]"
+
+ docs:
+ 	mkdocs serve
+
+ pypi:
+ 	python setup.py sdist
+ 	python setup.py bdist_wheel --universal
+ 	twine upload dist/*
+
+ clean:
+ 	rm -rf **/.ipynb_checkpoints **/.pytest_cache **/__pycache__ **/**/__pycache__ .ipynb_checkpoints .pytest_cache
+
+ check: test clean
BERTopic/README.md ADDED
@@ -0,0 +1,297 @@
+ [![PyPI - Python](https://img.shields.io/badge/python-v3.7+-blue.svg)](https://pypi.org/project/bertopic/)
+ [![Build](https://img.shields.io/github/actions/workflow/status/MaartenGr/BERTopic/testing.yml?branch=master)](https://github.com/MaartenGr/BERTopic/actions)
+ [![docs](https://img.shields.io/badge/docs-Passing-green.svg)](https://maartengr.github.io/BERTopic/)
+ [![PyPI - PyPi](https://img.shields.io/pypi/v/BERTopic)](https://pypi.org/project/bertopic/)
+ [![PyPI - License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/MaartenGr/VLAC/blob/master/LICENSE)
+ [![arXiv](https://img.shields.io/badge/arXiv-2203.05794-<COLOR>.svg)](https://arxiv.org/abs/2203.05794)
+
+
+ # BERTopic
+
+ <img src="images/logo.png" width="35%" height="35%" align="right" />
+
+ BERTopic is a topic modeling technique that leverages 🤗 transformers and c-TF-IDF to create dense clusters,
+ allowing for easily interpretable topics whilst keeping important words in the topic descriptions.
+
+ BERTopic supports all kinds of topic modeling techniques:
+ <table>
+   <tr>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/guided/guided.html">Guided</a></td>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/supervised/supervised.html">Supervised</a></td>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/semisupervised/semisupervised.html">Semi-supervised</a></td>
+   </tr>
+   <tr>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/manual/manual.html">Manual</a></td>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/distribution/distribution.html">Multi-topic distributions</a></td>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/hierarchicaltopics/hierarchicaltopics.html">Hierarchical</a></td>
+   </tr>
+   <tr>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/topicsperclass/topicsperclass.html">Class-based</a></td>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/topicsovertime/topicsovertime.html">Dynamic</a></td>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/online/online.html">Online/Incremental</a></td>
+   </tr>
+   <tr>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/multimodal/multimodal.html">Multimodal</a></td>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html">Multi-aspect</a></td>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/representation/llm.html">Text Generation/LLM</a></td>
+   </tr>
+   <tr>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/zeroshot/zeroshot.html">Zero-shot <b>(new!)</b></a></td>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/merge/merge.html">Merge Models <b>(new!)</b></a></td>
+     <td><a href="https://maartengr.github.io/BERTopic/getting_started/seed_words/seed_words.html">Seed Words <b>(new!)</b></a></td>
+   </tr>
+ </table>
+
+ Corresponding Medium posts can be found [here](https://towardsdatascience.com/topic-modeling-with-bert-779f7db187e6?source=friends_link&sk=0b5a470c006d1842ad4c8a3057063a99), [here](https://towardsdatascience.com/interactive-topic-modeling-with-bertopic-1ea55e7d73d8?sk=03c2168e9e74b6bda2a1f3ed953427e4), and [here](https://towardsdatascience.com/using-whisper-and-bertopic-to-model-kurzgesagts-videos-7d8a63139bdf?sk=b1e0fd46f70cb15e8422b4794a81161d). For a more detailed overview, you can read the [paper](https://arxiv.org/abs/2203.05794) or see a [brief overview](https://maartengr.github.io/BERTopic/algorithm/algorithm.html).
+
+ ## Installation
+
+ Installation, with sentence-transformers, can be done using [PyPI](https://pypi.org/project/bertopic/):
+
+ ```bash
+ pip install bertopic
+ ```
+
+ If you want to install BERTopic with other embedding models, you can choose one of the following:
+
+ ```bash
+ # Choose an embedding backend
+ pip install bertopic[flair,gensim,spacy,use]
+
+ # Topic modeling with images
+ pip install bertopic[vision]
+ ```
+
+ ## Getting Started
+ For an in-depth overview of the features of BERTopic,
+ you can check the [**full documentation**](https://maartengr.github.io/BERTopic/) or follow along
+ with one of the examples below:
+
+ | Name | Link |
+ |---|---|
+ | Start Here - **Best Practices in BERTopic** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1BoQ_vakEVtojsd2x_U6-_x52OOuqruj2?usp=sharing) |
+ | **🆕 New!** - Topic Modeling on Large Data (GPU Acceleration) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1W7aEdDPxC29jP99GGZphUlqjMFFVKtBC?usp=sharing) |
+ | **🆕 New!** - Topic Modeling with Llama 2 🦙 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QCERSMUjqGetGGujdrvv_6_EeoIcd_9M?usp=sharing) |
+ | **🆕 New!** - Topic Modeling with Quantized LLMs | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DdSHvVPJA3rmNfBWjCo2P1E9686xfxFx?usp=sharing) |
+ | Topic Modeling with BERTopic | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1FieRA9fLdkQEGDIMYl0I3MCjSUKVF8C-?usp=sharing) |
+ | (Custom) Embedding Models in BERTopic | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/18arPPe50szvcCp_Y6xS56H2tY0m-RLqv?usp=sharing) |
+ | Advanced Customization in BERTopic | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ClTYut039t-LDtlcd-oQAdXWgcsSGTw9?usp=sharing) |
+ | (semi-)Supervised Topic Modeling with BERTopic | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1bxizKzv5vfxJEB29sntU__ZC7PBSIPaQ?usp=sharing) |
+ | Dynamic Topic Modeling with Trump's Tweets | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1un8ooI-7ZNlRoK0maVkYhmNRl0XGK88f?usp=sharing) |
+ | Topic Modeling arXiv Abstracts | [![Kaggle](https://img.shields.io/static/v1?style=for-the-badge&message=Kaggle&color=222222&logo=Kaggle&logoColor=20BEFF&label=)](https://www.kaggle.com/maartengr/topic-modeling-arxiv-abstract-with-bertopic) |
+
+
+ ## Quick Start
+ We start by extracting topics from the well-known 20 newsgroups dataset containing English documents:
+
+ ```python
+ from bertopic import BERTopic
+ from sklearn.datasets import fetch_20newsgroups
+
+ docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+
+ topic_model = BERTopic()
+ topics, probs = topic_model.fit_transform(docs)
+ ```
+
+ After generating topics and their probabilities, we can access all of the topics together with their topic representations:
+
+ ```python
+ >>> topic_model.get_topic_info()
+
+ Topic  Count  Name
+ -1     4630   -1_can_your_will_any
+ 49     693    49_windows_drive_dos_file
+ 32     466    32_jesus_bible_christian_faith
+ 2      441    2_space_launch_orbit_lunar
+ 22     381    22_key_encryption_keys_encrypted
+ ...
+ ```
+
+ The `-1` topic refers to all outlier documents and is typically ignored. Each word in a topic describes the underlying theme of that topic and can be used
+ for interpreting that topic. Next, let's take a look at the most frequent topic that was generated, topic 49:
+
+ ```python
+ >>> topic_model.get_topic(49)
+
+ [('windows', 0.006152228076250982),
+  ('drive', 0.004982897610645755),
+  ('dos', 0.004845038866360651),
+  ('file', 0.004140142872194834),
+  ('disk', 0.004131678774810884),
+  ('mac', 0.003624848635985097),
+  ('memory', 0.0034840976976789903),
+  ('software', 0.0034415334250699077),
+  ('email', 0.0034239554442333257),
+  ('pc', 0.003047105930670237)]
+ ```
+
+ Using `.get_document_info`, we can also extract information on a document level, such as their corresponding topics, probabilities, whether they are representative documents for a topic, etc.:
+
+ ```python
+ >>> topic_model.get_document_info(docs)
+
+ Document                             Topic  Name                       Top_n_words              Probability  ...
+ I am sure some bashers of Pens...     0     0_game_team_games_season   game - team - games...   0.200010     ...
+ My brother is in the market for...   -1     -1_can_your_will_any       can - your - will...     0.420668     ...
+ Finally you said what you dream...   -1     -1_can_your_will_any       can - your - will...     0.807259     ...
+ Think! It's the SCSI card doing...    49    49_windows_drive_dos_file  windows - drive - dos... 0.071746     ...
+ 1) I have an old Jasmine drive...     49    49_windows_drive_dos_file  windows - drive - dos... 0.038983     ...
+ ```
+
+ **`🔥 Tip`**: Use `BERTopic(language="multilingual")` to select a model that supports 50+ languages.
+
+ ## Fine-tune Topic Representations
+
+ In BERTopic, there are a number of different [topic representations](https://maartengr.github.io/BERTopic/getting_started/representation/representation.html) that we can choose from. They are all quite different from one another and give interesting perspectives and variations of topic representations. A great start is `KeyBERTInspired`, which for many users increases the coherence and reduces the stopwords in the resulting topic representations:
+
+ ```python
+ from bertopic.representation import KeyBERTInspired
+
+ # Fine-tune your topic representations
+ representation_model = KeyBERTInspired()
+ topic_model = BERTopic(representation_model=representation_model)
+ ```
+
+ However, you might want to use something more powerful to describe your clusters. You can even use ChatGPT or other models from OpenAI to generate labels, summaries, phrases, keywords, and more:
+
+ ```python
+ import openai
+ from bertopic.representation import OpenAI
+
+ # Fine-tune topic representations with GPT
+ client = openai.OpenAI(api_key="sk-...")
+ representation_model = OpenAI(client, model="gpt-3.5-turbo", chat=True)
+ topic_model = BERTopic(representation_model=representation_model)
+ ```
+
+ **`🔥 Tip`**: Instead of iterating over all of these different topic representations, you can model them simultaneously with [multi-aspect topic representations](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) in BERTopic, as sketched below.
+
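For reference, a minimal sketch of that multi-aspect setup, based on the linked documentation (the aspect key names below are arbitrary):

```python
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance

# One main representation plus an extra aspect, computed in the same training run
representation_model = {
    "Main": KeyBERTInspired(),
    "MMR": MaximalMarginalRelevance(diversity=0.3),
}
topic_model = BERTopic(representation_model=representation_model)
```

After fitting, the extra aspects are available through `topic_model.topic_aspects_`.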
+
+
+ ## Visualizations
+ After having trained our BERTopic model, we can iteratively go through hundreds of topics to get a good
+ understanding of the topics that were extracted. However, that takes quite some time and lacks a global representation. Instead, we can use one of the [many visualization options](https://maartengr.github.io/BERTopic/getting_started/visualization/visualization.html) in BERTopic.
+ For example, we can visualize the topics that were generated in a way very similar to
+ [LDAvis](https://github.com/cpsievert/LDAvis):
+
+ ```python
+ topic_model.visualize_topics()
+ ```
+
+ <img src="images/topic_visualization.gif" width="60%" height="60%" align="center" />
+
+ ## Modularity
+ By default, the [main steps](https://maartengr.github.io/BERTopic/algorithm/algorithm.html) for topic modeling with BERTopic are sentence-transformers, UMAP, HDBSCAN, and c-TF-IDF, run in sequence. However, BERTopic assumes only limited dependence between these steps, which makes it quite modular. In other words, BERTopic not only allows you to build your own topic model but also to explore several topic modeling techniques on top of your customized topic model:
+
+ https://user-images.githubusercontent.com/25746895/218420473-4b2bb539-9dbe-407a-9674-a8317c7fb3bf.mp4
+
+ You can swap out any of these models or even remove them entirely. The following steps are completely modular (see the sketch after this list):
+
+ 1. [Embedding](https://maartengr.github.io/BERTopic/getting_started/embeddings/embeddings.html) documents
+ 2. [Reducing dimensionality](https://maartengr.github.io/BERTopic/getting_started/dim_reduction/dim_reduction.html) of embeddings
+ 3. [Clustering](https://maartengr.github.io/BERTopic/getting_started/clustering/clustering.html) reduced embeddings into topics
+ 4. [Tokenization](https://maartengr.github.io/BERTopic/getting_started/vectorizers/vectorizers.html) of topics
+ 5. [Weighting](https://maartengr.github.io/BERTopic/getting_started/ctfidf/ctfidf.html) of tokens
+ 6. [Representing topics](https://maartengr.github.io/BERTopic/getting_started/representation/representation.html) with one or [multiple](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) representations
+
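A minimal, hedged sketch of what this modularity looks like in practice; `umap_model`, `hdbscan_model`, and `vectorizer_model` are BERTopic's own parameters, while the scikit-learn models chosen here are just one possible combination:

```python
from bertopic import BERTopic
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import CountVectorizer

# Swap the default UMAP and HDBSCAN sub-models for scikit-learn alternatives
topic_model = BERTopic(
    umap_model=PCA(n_components=5),                          # dimensionality reduction
    hdbscan_model=KMeans(n_clusters=50),                     # clustering
    vectorizer_model=CountVectorizer(stop_words="english"),  # tokenization
)
topics, _ = topic_model.fit_transform(docs)  # probabilities are None for non-HDBSCAN clusterers
```

Any model exposing the scikit-learn `fit`/`transform` interface (for reduction) or `fit`/`predict` interface (for clustering) can be dropped in.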
+ ## Functionality
+ BERTopic has many functions that can quickly become overwhelming. To alleviate this issue, below you will find an overview
+ of all methods with a short description of their purpose.
+
+ ### Common
+ Below, you will find an overview of common functions in BERTopic.
+
+ | Method | Code |
+ |-----------------------|---|
+ | Fit the model | `.fit(docs)` |
+ | Fit the model and predict documents | `.fit_transform(docs)` |
+ | Predict new documents | `.transform([new_doc])` |
+ | Access single topic | `.get_topic(topic=12)` |
+ | Access all topics | `.get_topics()` |
+ | Get topic freq | `.get_topic_freq()` |
+ | Get all topic information | `.get_topic_info()` |
+ | Get all document information | `.get_document_info(docs)` |
+ | Get representative docs per topic | `.get_representative_docs()` |
+ | Update topic representation | `.update_topics(docs, n_gram_range=(1, 3))` |
+ | Generate topic labels | `.generate_topic_labels()` |
+ | Set topic labels | `.set_topic_labels(my_custom_labels)` |
+ | Merge topics | `.merge_topics(docs, topics_to_merge)` |
+ | Reduce nr of topics | `.reduce_topics(docs, nr_topics=30)` |
+ | Reduce outliers | `.reduce_outliers(docs, topics)` |
+ | Find topics | `.find_topics("vehicle")` |
+ | Save model | `.save("my_model", serialization="safetensors")` |
+ | Load model | `BERTopic.load("my_model")` |
+ | Get parameters | `.get_params()` |
+
+
+ ### Attributes
+ After having trained your BERTopic model, several attributes are saved within your model. These attributes, in part,
+ refer to how model information is stored on an estimator during fitting. The attributes below all end in `_` and are
+ public attributes that can be used to access model information; a short usage sketch follows the table.
+
+ | Attribute | Description |
+ |------------------------|---------------------------------------------------------------------------------------------|
+ | `.topics_` | The topics that are generated for each document after training or updating the topic model. |
+ | `.probabilities_` | The probabilities that are generated for each document if HDBSCAN is used. |
+ | `.topic_sizes_` | The size of each topic. |
+ | `.topic_mapper_` | A class for tracking topics and their mappings anytime they are merged/reduced. |
+ | `.topic_representations_` | The top *n* terms per topic and their respective c-TF-IDF values. |
+ | `.c_tf_idf_` | The topic-term matrix as calculated through c-TF-IDF. |
+ | `.topic_aspects_` | The different aspects, or representations, of each topic. |
+ | `.topic_labels_` | The default labels for each topic. |
+ | `.custom_labels_` | Custom labels for each topic as generated through `.set_topic_labels`. |
+ | `.topic_embeddings_` | The embeddings for each topic if `embedding_model` was used. |
+ | `.representative_docs_` | The representative documents for each topic if HDBSCAN is used. |
+
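As a small illustration, assuming `topic_model` has been fitted on `docs` as in the Quick Start above:

```python
# Inspect fitted attributes directly
print(topic_model.topics_[:10])   # topic assigned to each of the first 10 documents
print(topic_model.topic_sizes_)   # {topic_id: number_of_documents}
print(topic_model.topic_labels_)  # {topic_id: default_label}
```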
+ ### Variations
+ There are many different use cases in which topic modeling can be used. As such, several variations of BERTopic have been developed such that one package can be used across many use cases.
+
+ | Method | Code |
+ |-----------------------|---|
+ | [Topic Distribution Approximation](https://maartengr.github.io/BERTopic/getting_started/distribution/distribution.html) | `.approximate_distribution(docs)` |
+ | [Online Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/online/online.html) | `.partial_fit(doc)` |
+ | [Semi-supervised Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/semisupervised/semisupervised.html) | `.fit(docs, y=y)` |
+ | [Supervised Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/supervised/supervised.html) | `.fit(docs, y=y)` |
+ | [Manual Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/manual/manual.html) | `.fit(docs, y=y)` |
+ | [Multimodal Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/multimodal/multimodal.html) | `.fit(docs, images=images)` |
+ | [Topic Modeling per Class](https://maartengr.github.io/BERTopic/getting_started/topicsperclass/topicsperclass.html) | `.topics_per_class(docs, classes)` |
+ | [Dynamic Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/topicsovertime/topicsovertime.html) | `.topics_over_time(docs, timestamps)` |
+ | [Hierarchical Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/hierarchicaltopics/hierarchicaltopics.html) | `.hierarchical_topics(docs)` |
+ | [Guided Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/guided/guided.html) | `BERTopic(seed_topic_list=seed_topic_list)` |
+ | [Zero-shot Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/zeroshot/zeroshot.html) | `BERTopic(zeroshot_topic_list=zeroshot_topic_list)` |
+ | [Merge Multiple Models](https://maartengr.github.io/BERTopic/getting_started/merge/merge.html) | `BERTopic.merge_models([topic_model_1, topic_model_2])` |
+
+
+ ### Visualizations
+ Evaluating topic models can be rather difficult due to the somewhat subjective nature of evaluation.
+ Visualizing different aspects of the topic model helps in understanding the model and makes it easier
+ to tweak the model to your liking.
+
+ | Method | Code |
+ |-----------------------|---|
+ | Visualize Topics | `.visualize_topics()` |
+ | Visualize Documents | `.visualize_documents()` |
+ | Visualize Document Hierarchy | `.visualize_hierarchical_documents()` |
+ | Visualize Topic Hierarchy | `.visualize_hierarchy()` |
+ | Visualize Topic Tree | `.get_topic_tree(hierarchical_topics)` |
+ | Visualize Topic Terms | `.visualize_barchart()` |
+ | Visualize Topic Similarity | `.visualize_heatmap()` |
+ | Visualize Term Score Decline | `.visualize_term_rank()` |
+ | Visualize Topic Probability Distribution | `.visualize_distribution(probs[0])` |
+ | Visualize Topics over Time | `.visualize_topics_over_time(topics_over_time)` |
+ | Visualize Topics per Class | `.visualize_topics_per_class(topics_per_class)` |
+
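These methods return Plotly figures, so, as a quick sketch assuming a fitted `topic_model`, they can be saved as standalone HTML for sharing:

```python
fig = topic_model.visualize_topics()
fig.write_html("topics.html")  # Plotly figures can be written to an interactive HTML file
```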
+ ## Citation
+ To cite the [BERTopic paper](https://arxiv.org/abs/2203.05794), please use the following BibTeX reference:
+
+ ```bibtex
+ @article{grootendorst2022bertopic,
+   title={BERTopic: Neural topic modeling with a class-based TF-IDF procedure},
+   author={Grootendorst, Maarten},
+   journal={arXiv preprint arXiv:2203.05794},
+   year={2022}
+ }
+ ```
BERTopic/bertopic/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from bertopic._bertopic import BERTopic
+
+ __version__ = "0.16.0"
+
+ __all__ = [
+     "BERTopic",
+ ]
BERTopic/bertopic/_bertopic.py ADDED
The diff for this file is too large to render. See raw diff
 
BERTopic/bertopic/_save_utils.py ADDED
@@ -0,0 +1,492 @@
+ import os
+ import sys
+ import json
+ import numpy as np
+
+ from pathlib import Path
+ from tempfile import TemporaryDirectory
+
+
+ # HuggingFace Hub
+ try:
+     from huggingface_hub import (
+         create_repo, get_hf_file_metadata,
+         hf_hub_download, hf_hub_url,
+         repo_type_and_id_from_hf_id, upload_folder)
+     _has_hf_hub = True
+ except ImportError:
+     _has_hf_hub = False
+
+ # Typing
+ if sys.version_info >= (3, 8):
+     from typing import Literal
+ else:
+     from typing_extensions import Literal
+ from typing import Union, Mapping, Any
+
+ # Pytorch check
+ try:
+     import torch
+     _has_torch = True
+ except ImportError:
+     _has_torch = False
+
+ # Image check
+ try:
+     from PIL import Image
+     _has_vision = True
+ except ImportError:
+     _has_vision = False
+
+
+ TOPICS_NAME = "topics.json"
+ CONFIG_NAME = "config.json"
+
+ HF_WEIGHTS_NAME = "topic_embeddings.bin"  # default pytorch pkl
+ HF_SAFE_WEIGHTS_NAME = "topic_embeddings.safetensors"  # safetensors version
+
+ CTFIDF_WEIGHTS_NAME = "ctfidf.bin"  # default pytorch pkl
+ CTFIDF_SAFE_WEIGHTS_NAME = "ctfidf.safetensors"  # safetensors version
+ CTFIDF_CFG_NAME = "ctfidf_config.json"
+
+ MODEL_CARD_TEMPLATE = """
+ ---
+ tags:
+ - bertopic
+ library_name: bertopic
+ pipeline_tag: {PIPELINE_TAG}
+ ---
+
+ # {MODEL_NAME}
+
+ This is a [BERTopic](https://github.com/MaartenGr/BERTopic) model.
+ BERTopic is a flexible and modular topic modeling framework that allows for the generation of easily interpretable topics from large datasets.
+
+ ## Usage
+
+ To use this model, please install BERTopic:
+
+ ```
+ pip install -U bertopic
+ ```
+
+ You can use the model as follows:
+
+ ```python
+ from bertopic import BERTopic
+ topic_model = BERTopic.load("{PATH}")
+
+ topic_model.get_topic_info()
+ ```
+
+ ## Topic overview
+
+ * Number of topics: {NR_TOPICS}
+ * Number of training documents: {NR_DOCUMENTS}
+
+ <details>
+ <summary>Click here for an overview of all topics.</summary>
+
+ {TOPICS}
+
+ </details>
+
+ ## Training hyperparameters
+
+ {HYPERPARAMS}
+
+ ## Framework versions
+
+ {FRAMEWORKS}
+ """
+
+
+ def push_to_hf_hub(
+         model,
+         repo_id: str,
+         commit_message: str = 'Add BERTopic model',
+         token: str = None,
+         revision: str = None,
+         private: bool = False,
+         create_pr: bool = False,
+         model_card: bool = True,
+         serialization: str = "safetensors",
+         save_embedding_model: Union[str, bool] = True,
+         save_ctfidf: bool = False,
+         ):
+     """ Push your BERTopic model to the HuggingFace Hub
+
+     Arguments:
+         model: The BERTopic model to push
+         repo_id: The name of your HuggingFace repository
+         commit_message: A commit message
+         token: Token to add if not already logged in
+         revision: Repository revision
+         private: Whether to create a private repository
+         create_pr: Whether to upload the model as a Pull Request
+         model_card: Whether to automatically create a model card
+         serialization: The type of serialization.
+                        Either `safetensors` or `pytorch`.
+         save_embedding_model: A pointer towards a HuggingFace model to be loaded in with
+                               SentenceTransformers. E.g.,
+                               `sentence-transformers/all-MiniLM-L6-v2`
+         save_ctfidf: Whether to save c-TF-IDF information
+     """
+     if not _has_hf_hub:
+         raise ValueError("Make sure you have the huggingface hub installed via `pip install --upgrade huggingface_hub`")
+
+     # Create repo if it doesn't exist yet and infer complete repo_id
+     repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
+     _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
+     repo_id = f"{repo_owner}/{repo_name}"
+
+     # Temporarily save model and push to HF
+     with TemporaryDirectory() as tmpdir:
+
+         # Save model weights and config
+         model.save(tmpdir, serialization=serialization, save_embedding_model=save_embedding_model, save_ctfidf=save_ctfidf)
+
+         # Add README if it does not exist
+         try:
+             get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
+         except Exception:
+             if model_card:
+                 readme_text = generate_readme(model, repo_id)
+                 readme_path = Path(tmpdir) / "README.md"
+                 readme_path.write_text(readme_text, encoding='utf8')
+
+         # Upload model
+         return upload_folder(repo_id=repo_id, folder_path=tmpdir, revision=revision,
+                              create_pr=create_pr, commit_message=commit_message)
+
+
+ def load_local_files(path):
+     """ Load local BERTopic files """
+     # Load json configs
+     topics = load_cfg_from_json(path / TOPICS_NAME)
+     params = load_cfg_from_json(path / CONFIG_NAME)
+
+     # Load topic embeddings (initialized to None in case no weight file is found)
+     tensors = None
+     safetensor_path = path / HF_SAFE_WEIGHTS_NAME
+     if safetensor_path.is_file():
+         tensors = load_safetensors(safetensor_path)
+     else:
+         torch_path = path / HF_WEIGHTS_NAME
+         if torch_path.is_file():
+             tensors = torch.load(torch_path, map_location="cpu")
+
+     # c-TF-IDF
+     try:
+         ctfidf_tensors = None
+         safetensor_path = path / CTFIDF_SAFE_WEIGHTS_NAME
+         if safetensor_path.is_file():
+             ctfidf_tensors = load_safetensors(safetensor_path)
+         else:
+             torch_path = path / CTFIDF_WEIGHTS_NAME
+             if torch_path.is_file():
+                 ctfidf_tensors = torch.load(torch_path, map_location="cpu")
+         ctfidf_config = load_cfg_from_json(path / CTFIDF_CFG_NAME)
+     except Exception:
+         ctfidf_config, ctfidf_tensors = None, None
+
+     # Load images
+     images = None
+     if _has_vision:
+         try:
+             Image.open(path / "images/0.jpg")
+             _has_images = True
+         except Exception:
+             _has_images = False
+
+         if _has_images:
+             topic_list = list(topics["topic_representations"].keys())
+             images = {}
+             for topic in topic_list:
+                 image = Image.open(path / f"images/{topic}.jpg")
+                 images[int(topic)] = image
+
+     return topics, params, tensors, ctfidf_tensors, ctfidf_config, images
+
+
+ def load_files_from_hf(path):
+     """ Load files from HuggingFace. """
+     path = str(path)
+
+     # Configs
+     topics = load_cfg_from_json(hf_hub_download(path, TOPICS_NAME, revision=None))
+     params = load_cfg_from_json(hf_hub_download(path, CONFIG_NAME, revision=None))
+
+     # Topic Embeddings
+     try:
+         tensors = hf_hub_download(path, HF_SAFE_WEIGHTS_NAME, revision=None)
+         tensors = load_safetensors(tensors)
+     except Exception:
+         tensors = hf_hub_download(path, HF_WEIGHTS_NAME, revision=None)
+         tensors = torch.load(tensors, map_location="cpu")
+
+     # c-TF-IDF
+     try:
+         ctfidf_config = load_cfg_from_json(hf_hub_download(path, CTFIDF_CFG_NAME, revision=None))
+         try:
+             ctfidf_tensors = hf_hub_download(path, CTFIDF_SAFE_WEIGHTS_NAME, revision=None)
+             ctfidf_tensors = load_safetensors(ctfidf_tensors)
+         except Exception:
+             ctfidf_tensors = hf_hub_download(path, CTFIDF_WEIGHTS_NAME, revision=None)
+             ctfidf_tensors = torch.load(ctfidf_tensors, map_location="cpu")
+     except Exception:
+         ctfidf_config, ctfidf_tensors = None, None
+
+     # Load images if they exist
+     images = None
+     if _has_vision:
+         try:
+             hf_hub_download(path, "images/0.jpg", revision=None)
+             _has_images = True
+         except Exception:
+             _has_images = False
+
+         if _has_images:
+             topic_list = list(topics["topic_representations"].keys())
+             images = {}
+             for topic in topic_list:
+                 image = Image.open(hf_hub_download(path, f"images/{topic}.jpg", revision=None))
+                 images[int(topic)] = image
+
+     return topics, params, tensors, ctfidf_tensors, ctfidf_config, images
+
+
+ def generate_readme(model, repo_id: str):
+     """ Generate a README for the HuggingFace model card """
+     model_card = MODEL_CARD_TEMPLATE
+     topic_table_head = "| Topic ID | Topic Keywords | Topic Frequency | Label | \n|----------|----------------|-----------------|-------| \n"
+
+     # Get statistics
+     model_name = repo_id.split("/")[-1]
+     params = {param: value for param, value in model.get_params().items() if "model" not in param}
+     params = "\n".join([f"* {param}: {value}" for param, value in params.items()])
+     topics = sorted(list(set(model.topics_)))
+     nr_topics = str(len(set(model.topics_)))
+
+     if model.topic_sizes_ is not None:
+         nr_documents = str(sum(model.topic_sizes_.values()))
+     else:
+         nr_documents = ""
+
+     # Topic information; `topic_freq` and `topic_labels` are indexed positionally
+     # (by `index`) rather than by topic id, since topic ids can start at -1
+     topic_keywords = [" - ".join(list(zip(*model.get_topic(topic)))[0][:5]) for topic in topics]
+     topic_freq = [model.get_topic_freq(topic) for topic in topics]
+     topic_labels = model.custom_labels_ if model.custom_labels_ else [model.topic_labels_[topic] for topic in topics]
+     topics = [f"| {topic} | {topic_keywords[index]} | {topic_freq[index]} | {topic_labels[index]} | \n" for index, topic in enumerate(topics)]
+     topics = topic_table_head + "".join(topics)
+     frameworks = "\n".join([f"* {param}: {value}" for param, value in get_package_versions().items()])
+
+     # Fill statistics into the model card
+     model_card = model_card.replace("{MODEL_NAME}", model_name)
+     model_card = model_card.replace("{PATH}", repo_id)
+     model_card = model_card.replace("{NR_TOPICS}", nr_topics)
+     model_card = model_card.replace("{TOPICS}", topics.strip())
+     model_card = model_card.replace("{NR_DOCUMENTS}", nr_documents)
+     model_card = model_card.replace("{HYPERPARAMS}", params)
+     model_card = model_card.replace("{FRAMEWORKS}", frameworks)
+
+     # Fill pipeline tag
+     has_visual_aspect = check_has_visual_aspect(model)
+     if not has_visual_aspect:
+         model_card = model_card.replace("{PIPELINE_TAG}", "text-classification")
+     else:
+         model_card = model_card.replace("pipeline_tag: {PIPELINE_TAG}\n", "")  # TODO: add a proper tag for this instance
+
+     return model_card
+
+
+ def save_hf(model, save_directory, serialization: str):
+     """ Save topic embeddings, either safely (using safetensors) or using legacy pytorch """
+     tensors = torch.from_numpy(np.array(model.topic_embeddings_, dtype=np.float32))
+     tensors = {"topic_embeddings": tensors}
+
+     if serialization == "safetensors":
+         save_safetensors(save_directory / HF_SAFE_WEIGHTS_NAME, tensors)
+     if serialization == "pytorch":
+         assert _has_torch, "`pip install torch` to save as .bin"
+         torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
+
+
+ def save_ctfidf(model,
+                 save_directory: str,
+                 serialization: str):
+     """ Save the c-TF-IDF sparse matrix """
+     indptr = torch.from_numpy(model.c_tf_idf_.indptr)
+     indices = torch.from_numpy(model.c_tf_idf_.indices)
+     data = torch.from_numpy(model.c_tf_idf_.data)
+     shape = torch.from_numpy(np.array(model.c_tf_idf_.shape))
+     diag = torch.from_numpy(np.array(model.ctfidf_model._idf_diag.data))
+     tensors = {
+         "indptr": indptr,
+         "indices": indices,
+         "data": data,
+         "shape": shape,
+         "diag": diag
+     }
+
+     if serialization == "safetensors":
+         save_safetensors(save_directory / CTFIDF_SAFE_WEIGHTS_NAME, tensors)
+     if serialization == "pytorch":
+         assert _has_torch, "`pip install torch` to save as .bin"
+         torch.save(tensors, save_directory / CTFIDF_WEIGHTS_NAME)
+
+
+ def save_ctfidf_config(model, path):
+     """ Save parameters to recreate the CountVectorizer and c-TF-IDF models """
+     config = {}
+
+     # Recreate ClassTfidfTransformer
+     config["ctfidf_model"] = {
+         "bm25_weighting": model.ctfidf_model.bm25_weighting,
+         "reduce_frequent_words": model.ctfidf_model.reduce_frequent_words
+     }
+
+     # Recreate CountVectorizer
+     cv_params = model.vectorizer_model.get_params()
+     del cv_params["tokenizer"], cv_params["preprocessor"], cv_params["dtype"]
+     if not isinstance(cv_params["analyzer"], str):
+         del cv_params["analyzer"]
+
+     config["vectorizer_model"] = {
+         "params": cv_params,
+         "vocab": model.vectorizer_model.vocabulary_
+     }
+
+     with path.open('w') as f:
+         json.dump(config, f, indent=2)
+
+
+ def save_config(model, path: str, embedding_model):
+     """ Save the BERTopic configuration """
+     path = Path(path)
+     params = model.get_params()
+     config = {param: value for param, value in params.items() if "model" not in param}
+
+     # Embedding model tag to be used in sentence-transformers
+     if isinstance(embedding_model, str):
+         config["embedding_model"] = embedding_model
+
+     with path.open('w') as f:
+         json.dump(config, f, indent=2)
+
+     return config
+
+
+ def check_has_visual_aspect(model):
+     """ Check whether the model has a visual (image) aspect """
+     if _has_vision:
+         for aspect, value in model.topic_aspects_.items():
+             if isinstance(value[0], Image.Image):
+                 return True
+     return False
+
+
+ def save_images(model, path: str):
+     """ Save topic images """
+     if _has_vision:
+         visual_aspects = None
+         for aspect, value in model.topic_aspects_.items():
+             if isinstance(value[0], Image.Image):
+                 visual_aspects = model.topic_aspects_[aspect]
+                 break
+
+         if visual_aspects is not None:
+             path.mkdir(exist_ok=True, parents=True)
+             for topic, image in visual_aspects.items():
+                 image.save(path / f"{topic}.jpg")
+
+
+ def save_topics(model, path: str):
+     """ Save topic-specific information """
+     path = Path(path)
+
+     if _has_vision:
+         selected_topic_aspects = {}
+         for aspect, value in model.topic_aspects_.items():
+             if not isinstance(value[0], Image.Image):
+                 selected_topic_aspects[aspect] = value
+             else:
+                 selected_topic_aspects["Visual_Aspect"] = True
+     else:
+         selected_topic_aspects = model.topic_aspects_
+
+     topics = {
+         "topic_representations": model.topic_representations_,
+         "topics": [int(topic) for topic in model.topics_],
+         "topic_sizes": model.topic_sizes_,
+         "topic_mapper": np.array(model.topic_mapper_.mappings_, dtype=int).tolist(),
+         "topic_labels": model.topic_labels_,
+         "custom_labels": model.custom_labels_,
+         "_outliers": int(model._outliers),
+         "topic_aspects": selected_topic_aspects
+     }
+
+     with path.open('w') as f:
+         json.dump(topics, f, indent=2, cls=NumpyEncoder)
+
+
+ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
+     """ Load a configuration from json """
+     with open(json_file, "r", encoding="utf-8") as reader:
+         text = reader.read()
+     return json.loads(text)
+
+
+ class NumpyEncoder(json.JSONEncoder):
+     def default(self, obj):
+         if isinstance(obj, np.integer):
+             return int(obj)
+         if isinstance(obj, np.floating):
+             return float(obj)
+         return super(NumpyEncoder, self).default(obj)
+
+
+ def get_package_versions():
+     """ Get versions of the main dependencies of BERTopic """
+     try:
+         import platform
+         from numpy import __version__ as np_version
+
+         try:
+             from importlib.metadata import version
+             hdbscan_version = version('hdbscan')
+         except Exception:
+             hdbscan_version = None
+
+         from umap import __version__ as umap_version
+         from pandas import __version__ as pandas_version
+         from sklearn import __version__ as sklearn_version
+         from sentence_transformers import __version__ as sbert_version
+         from numba import __version__ as numba_version
+         from transformers import __version__ as transformers_version
+
+         from plotly import __version__ as plotly_version
+         return {"Numpy": np_version, "HDBSCAN": hdbscan_version, "UMAP": umap_version,
+                 "Pandas": pandas_version, "Scikit-Learn": sklearn_version,
+                 "Sentence-transformers": sbert_version, "Transformers": transformers_version,
+                 "Numba": numba_version, "Plotly": plotly_version, "Python": platform.python_version()}
+     except Exception as e:
+         return e
+
+
+ def load_safetensors(path):
+     """ Load safetensors and check whether it is installed """
+     try:
+         import safetensors.torch
+         return safetensors.torch.load_file(path, device="cpu")
+     except ImportError:
+         raise ValueError("`pip install safetensors` to load .safetensors")
+
+
+ def save_safetensors(path, tensors):
+     """ Save safetensors and check whether it is installed """
+     try:
+         import safetensors.torch
+         safetensors.torch.save_file(tensors, path)
+     except ImportError:
+         raise ValueError("`pip install safetensors` to save as .safetensors")
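For orientation, a brief, hedged sketch of how these utilities are reached through BERTopic's public API (per the BERTopic serialization docs; `docs` is assumed to be a list of strings):

```python
from bertopic import BERTopic

topic_model = BERTopic().fit(docs)

# Local save with safetensors serialization (routed through save_hf/save_ctfidf above)
topic_model.save("my_model", serialization="safetensors", save_ctfidf=True)

# Or push to the Hugging Face Hub, which wraps push_to_hf_hub above
topic_model.push_to_hf_hub(repo_id="user/my_bertopic_model")
```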
BERTopic/bertopic/_utils.py ADDED
@@ -0,0 +1,149 @@
+ import numpy as np
+ import pandas as pd
+ import logging
+ from collections.abc import Iterable
+ from scipy.sparse import csr_matrix
+ from scipy.spatial.distance import squareform
+
+
+ class MyLogger:
+     def __init__(self, level):
+         self.logger = logging.getLogger('BERTopic')
+         self.set_level(level)
+         self._add_handler()
+         self.logger.propagate = False
+
+     def info(self, message):
+         self.logger.info(f"{message}")
+
+     def warning(self, message):
+         self.logger.warning(f"WARNING: {message}")
+
+     def set_level(self, level):
+         levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
+         if level in levels:
+             self.logger.setLevel(level)
+
+     def _add_handler(self):
+         sh = logging.StreamHandler()
+         sh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(message)s'))
+         self.logger.addHandler(sh)
+
+         # Remove duplicate handlers
+         if len(self.logger.handlers) > 1:
+             self.logger.handlers = [self.logger.handlers[0]]
+
+
+ def check_documents_type(documents):
+     """ Check whether the input documents are indeed a list of strings """
+     if isinstance(documents, pd.DataFrame):
+         raise TypeError("Make sure to supply a list of strings, not a dataframe.")
+     elif isinstance(documents, Iterable) and not isinstance(documents, str):
+         # `all` so that every document is verified to be a string
+         if not all(isinstance(doc, str) for doc in documents):
+             raise TypeError("Make sure that the iterable only contains strings.")
+     else:
+         raise TypeError("Make sure that the documents variable is an iterable containing strings only.")
+
+
+ def check_embeddings_shape(embeddings, docs):
+     """ Check if the embeddings have the correct shape """
+     if embeddings is not None:
+         if not any([isinstance(embeddings, np.ndarray), isinstance(embeddings, csr_matrix)]):
+             raise ValueError("Make sure to input embeddings as a numpy array or scipy.sparse.csr.csr_matrix.")
+         else:
+             if embeddings.shape[0] != len(docs):
+                 raise ValueError("Make sure that the embeddings are a numpy array with shape: "
+                                  "(len(docs), vector_dim) where vector_dim is the dimensionality "
+                                  "of the vector embeddings.")
+
+
+ def check_is_fitted(topic_model):
+     """ Check whether the model was fitted by verifying that topics were generated
+
+     Arguments:
+         topic_model: BERTopic instance for which the check is performed.
+
+     Returns:
+         None
+
+     Raises:
+         ValueError: If the model was not fitted.
+     """
+     msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
+            "appropriate arguments before using this estimator.")
+
+     if topic_model.topics_ is None:
+         raise ValueError(msg % {'name': type(topic_model).__name__})
+
+
+ class NotInstalled:
+     """
+     This object is used to notify the user that additional dependencies need to be
+     installed in order to use the requested module.
+     """
+
+     def __init__(self, tool, dep, custom_msg=None):
+         self.tool = tool
+         self.dep = dep
+
+         msg = f"In order to use {self.tool} you will need to install it via:\n\n"
+         if custom_msg is not None:
+             msg += custom_msg
+         else:
+             msg += f"pip install bertopic[{self.dep}]\n\n"
+         self.msg = msg
+
+     def __getattr__(self, *args, **kwargs):
+         raise ModuleNotFoundError(self.msg)
+
+     def __call__(self, *args, **kwargs):
+         raise ModuleNotFoundError(self.msg)
+
+
+ def validate_distance_matrix(X, n_samples):
+     """ Validate the distance matrix and convert it to a condensed distance matrix
+     if necessary.
+
+     A valid distance matrix is either a square matrix of shape (n_samples, n_samples)
+     with zeros on the diagonal and non-negative values, or a condensed distance matrix
+     of shape (n_samples * (n_samples - 1) / 2,) containing the upper triangular of the
+     distance matrix.
+
+     Arguments:
+         X: Distance matrix to validate.
+         n_samples: Number of samples in the dataset.
+
+     Returns:
+         X: Validated distance matrix.
+
+     Raises:
+         ValueError: If the distance matrix is not valid.
+     """
+     # Make sure it is the 1-D condensed distance matrix with zeros on the diagonal
+     s = X.shape
+     if len(s) == 1:
+         # check it has the correct size
+         n = s[0]
+         if n != (n_samples * (n_samples - 1) / 2):
+             raise ValueError("The condensed distance matrix must have "
+                              "shape (n*(n-1)/2,).")
+     elif len(s) == 2:
+         # check it has the correct size
+         if (s[0] != n_samples) or (s[1] != n_samples):
+             raise ValueError("The distance matrix must be of shape "
+                              "(n, n) where n is the number of samples.")
+         # force zero diagonal and convert to condensed
+         np.fill_diagonal(X, 0)
+         X = squareform(X)
+     else:
+         raise ValueError("The distance matrix must be either a 1-D condensed "
+                          "distance matrix of shape (n*(n-1)/2,) or a "
+                          "2-D square distance matrix of shape (n, n), "
+                          "where n is the number of documents. "
+                          "Got a distance matrix of shape %s" % str(s))
+
+     # Make sure its entries are non-negative
+     if np.any(X < 0):
+         raise ValueError("Distance matrix cannot contain negative values.")
+
+     return X
BERTopic/bertopic/backend/__init__.py ADDED
@@ -0,0 +1,35 @@
+from ._base import BaseEmbedder
+from ._word_doc import WordDocEmbedder
+from ._utils import languages
+from bertopic._utils import NotInstalled
+
+# OpenAI Embeddings
+try:
+    from bertopic.backend._openai import OpenAIBackend
+except ModuleNotFoundError:
+    msg = "`pip install openai` \n\n"
+    OpenAIBackend = NotInstalled("OpenAI", "OpenAI", custom_msg=msg)
+
+# Cohere Embeddings
+try:
+    from bertopic.backend._cohere import CohereBackend
+except ModuleNotFoundError:
+    msg = "`pip install cohere` \n\n"
+    CohereBackend = NotInstalled("Cohere", "Cohere", custom_msg=msg)
+
+# Multimodal Embeddings
+try:
+    from bertopic.backend._multimodal import MultiModalBackend
+except ModuleNotFoundError:
+    msg = "`pip install bertopic[vision]` \n\n"
+    MultiModalBackend = NotInstalled("Vision", "Vision", custom_msg=msg)
+
+
+__all__ = [
+    "BaseEmbedder",
+    "WordDocEmbedder",
+    "OpenAIBackend",
+    "CohereBackend",
+    "MultiModalBackend",
+    "languages"
+]
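
The `NotInstalled` placeholders above fail lazily rather than at import time; a short sketch of what that looks like, assuming the `cohere` package is not installed:

```python
from bertopic.backend import CohereBackend

# With `cohere` missing, CohereBackend is a NotInstalled placeholder and
# any attribute access or call raises ModuleNotFoundError with install hints
try:
    CohereBackend(client=None)
except ModuleNotFoundError as err:
    print(err)  # "In order to use Cohere you will need to install it via: ..."
```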
BERTopic/bertopic/backend/_base.py ADDED
@@ -0,0 +1,69 @@
+import numpy as np
+from typing import List
+
+
+class BaseEmbedder:
+    """ The Base Embedder used for creating embedding models
+
+    Arguments:
+        embedding_model: The main embedding model to be used for extracting
+                         document and word embeddings
+        word_embedding_model: The embedding model used for extracting word
+                              embeddings only. If this model is selected,
+                              then the `embedding_model` is purely used for
+                              creating document embeddings.
+    """
+    def __init__(self,
+                 embedding_model=None,
+                 word_embedding_model=None):
+        self.embedding_model = embedding_model
+        self.word_embedding_model = word_embedding_model
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        pass
+
+    def embed_words(self,
+                    words: List[str],
+                    verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            words: A list of words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Word embeddings with shape (n, m) with `n` words
+            that each have an embedding size of `m`
+        """
+        return self.embed(words, verbose)
+
+    def embed_documents(self,
+                        document: List[str],
+                        verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            document: A list of documents to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document embeddings with shape (n, m) with `n` documents
+            that each have an embedding size of `m`
+        """
+        return self.embed(document, verbose)
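
A minimal sketch (not part of the diff) of how a custom backend plugs into this base class; `DummyEmbedder` and its fixed dimensionality are hypothetical:

```python
import numpy as np
from typing import List
from bertopic.backend import BaseEmbedder

class DummyEmbedder(BaseEmbedder):
    """Hypothetical backend: hashes tokens into a fixed 16-dim bag-of-words vector."""
    def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
        embeddings = np.zeros((len(documents), 16))
        for i, doc in enumerate(documents):
            for word in doc.split():
                embeddings[i, hash(word) % 16] += 1.0
        return embeddings

# Usage: BERTopic(embedding_model=DummyEmbedder())
```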
BERTopic/bertopic/backend/_cohere.py ADDED
@@ -0,0 +1,94 @@
+import time
+import numpy as np
+from tqdm import tqdm
+from typing import Any, List, Mapping
+from bertopic.backend import BaseEmbedder
+
+
+class CohereBackend(BaseEmbedder):
+    """ Cohere Embedding Model
+
+    Arguments:
+        client: A `cohere` client.
+        embedding_model: A Cohere model. Default is "large".
+                         For an overview of models see:
+                         https://docs.cohere.ai/docs/generation-card
+        delay_in_seconds: If a `batch_size` is given, use this to set
+                          the delay in seconds between batches.
+        batch_size: The size of each batch.
+        embed_kwargs: Kwargs passed to `cohere.Client.embed`.
+                      Can be used to define additional parameters
+                      such as `input_type`
+
+    Examples:
+
+    ```python
+    import cohere
+    from bertopic.backend import CohereBackend
+
+    client = cohere.Client("APIKEY")
+    cohere_model = CohereBackend(client)
+    ```
+
+    If you want to specify `input_type`:
+
+    ```python
+    cohere_model = CohereBackend(
+        client,
+        embedding_model="embed-english-v3.0",
+        embed_kwargs={"input_type": "clustering"}
+    )
+    ```
+    """
+    def __init__(self,
+                 client,
+                 embedding_model: str = "large",
+                 delay_in_seconds: float = None,
+                 batch_size: int = None,
+                 embed_kwargs: Mapping[str, Any] = {}):
+        super().__init__()
+        self.client = client
+        self.embedding_model = embedding_model
+        self.delay_in_seconds = delay_in_seconds
+        self.batch_size = batch_size
+        self.embed_kwargs = embed_kwargs
+
+        if self.embed_kwargs.get("model"):
+            self.embedding_model = embed_kwargs.get("model")
+        else:
+            self.embed_kwargs["model"] = self.embedding_model
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        # Batch-wise embedding extraction
+        if self.batch_size is not None:
+            embeddings = []
+            for batch in tqdm(self._chunks(documents), disable=not verbose):
+                response = self.client.embed(batch, **self.embed_kwargs)
+                embeddings.extend(response.embeddings)
+
+                # Delay subsequent calls
+                if self.delay_in_seconds:
+                    time.sleep(self.delay_in_seconds)
+
+        # Extract embeddings all at once
+        else:
+            response = self.client.embed(documents, **self.embed_kwargs)
+            embeddings = response.embeddings
+        return np.array(embeddings)
+
+    def _chunks(self, documents):
+        for i in range(0, len(documents), self.batch_size):
+            yield documents[i:i + self.batch_size]
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+ from typing import Union, List
4
+ from flair.data import Sentence
5
+ from flair.embeddings import DocumentEmbeddings, TokenEmbeddings, DocumentPoolEmbeddings
6
+
7
+ from bertopic.backend import BaseEmbedder
8
+
9
+
10
+ class FlairBackend(BaseEmbedder):
11
+ """ Flair Embedding Model
12
+
13
+ The Flair embedding model used for generating document and
14
+ word embeddings.
15
+
16
+ Arguments:
17
+ embedding_model: A Flair embedding model
18
+
19
+ Examples:
20
+
21
+ ```python
22
+ from bertopic.backend import FlairBackend
23
+ from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings
24
+
25
+ # Create a Flair Embedding model
26
+ glove_embedding = WordEmbeddings('crawl')
27
+ document_glove_embeddings = DocumentPoolEmbeddings([glove_embedding])
28
+
29
+ # Pass the Flair model to create a new backend
30
+ flair_embedder = FlairBackend(document_glove_embeddings)
31
+ ```
32
+ """
33
+ def __init__(self, embedding_model: Union[TokenEmbeddings, DocumentEmbeddings]):
34
+ super().__init__()
35
+
36
+ # Flair word embeddings
37
+ if isinstance(embedding_model, TokenEmbeddings):
38
+ self.embedding_model = DocumentPoolEmbeddings([embedding_model])
39
+
40
+ # Flair document embeddings + disable fine tune to prevent CUDA OOM
41
+ # https://github.com/flairNLP/flair/issues/1719
42
+ elif isinstance(embedding_model, DocumentEmbeddings):
43
+ if "fine_tune" in embedding_model.__dict__:
44
+ embedding_model.fine_tune = False
45
+ self.embedding_model = embedding_model
46
+
47
+ else:
48
+ raise ValueError("Please select a correct Flair model by either using preparing a token or document "
49
+ "embedding model: \n"
50
+ "`from flair.embeddings import TransformerDocumentEmbeddings` \n"
51
+ "`roberta = TransformerDocumentEmbeddings('roberta-base')`")
52
+
53
+ def embed(self,
54
+ documents: List[str],
55
+ verbose: bool = False) -> np.ndarray:
56
+ """ Embed a list of n documents/words into an n-dimensional
57
+ matrix of embeddings
58
+
59
+ Arguments:
60
+ documents: A list of documents or words to be embedded
61
+ verbose: Controls the verbosity of the process
62
+
63
+ Returns:
64
+ Document/words embeddings with shape (n, m) with `n` documents/words
65
+ that each have an embeddings size of `m`
66
+ """
67
+ embeddings = []
68
+ for document in tqdm(documents, disable=not verbose):
69
+ try:
70
+ sentence = Sentence(document) if document else Sentence("an empty document")
71
+ self.embedding_model.embed(sentence)
72
+ except RuntimeError:
73
+ sentence = Sentence("an empty document")
74
+ self.embedding_model.embed(sentence)
75
+ embedding = sentence.embedding.detach().cpu().numpy()
76
+ embeddings.append(embedding)
77
+ embeddings = np.asarray(embeddings)
78
+ return embeddings
BERTopic/bertopic/backend/_gensim.py ADDED
@@ -0,0 +1,66 @@
+import numpy as np
+from tqdm import tqdm
+from typing import List
+from bertopic.backend import BaseEmbedder
+from gensim.models.keyedvectors import Word2VecKeyedVectors
+
+
+class GensimBackend(BaseEmbedder):
+    """ Gensim Embedding Model
+
+    The Gensim embedding model is typically used for word embeddings with
+    GloVe, Word2Vec or FastText.
+
+    Arguments:
+        embedding_model: A Gensim embedding model
+
+    Examples:
+
+    ```python
+    from bertopic.backend import GensimBackend
+    import gensim.downloader as api
+
+    ft = api.load('fasttext-wiki-news-subwords-300')
+    ft_embedder = GensimBackend(ft)
+    ```
+    """
+    def __init__(self, embedding_model: Word2VecKeyedVectors):
+        super().__init__()
+
+        if isinstance(embedding_model, Word2VecKeyedVectors):
+            self.embedding_model = embedding_model
+        else:
+            raise ValueError("Please select a correct Gensim model: \n"
+                             "`import gensim.downloader as api` \n"
+                             "`ft = api.load('fasttext-wiki-news-subwords-300')`")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        vector_shape = self.embedding_model.get_vector(list(self.embedding_model.index_to_key)[0]).shape[0]
+        empty_vector = np.zeros(vector_shape)
+
+        # Extract word embeddings and pool to document-level
+        embeddings = []
+        for doc in tqdm(documents, disable=not verbose, position=0, leave=True):
+            embedding = [self.embedding_model.get_vector(word) for word in doc.split()
+                         if word in self.embedding_model.key_to_index]
+
+            if len(embedding) > 0:
+                embeddings.append(np.mean(embedding, axis=0))
+            else:
+                embeddings.append(empty_vector)
+
+        embeddings = np.array(embeddings)
+        return embeddings
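
The pooling step above simply averages the word vectors that exist in the vocabulary and skips out-of-vocabulary tokens; a self-contained sketch of that behavior with toy 3-dim vectors (no Gensim required):

```python
import numpy as np

# Toy vocabulary: word -> 3-dim vector
vocab = {"good": np.array([1.0, 0.0, 0.0]),
         "hotel": np.array([0.0, 1.0, 0.0])}

doc = "good hotel xyzzy"  # "xyzzy" is out-of-vocabulary and is skipped
vectors = [vocab[w] for w in doc.split() if w in vocab]
doc_embedding = np.mean(vectors, axis=0) if vectors else np.zeros(3)
print(doc_embedding)  # [0.5 0.5 0. ]
```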
BERTopic/bertopic/backend/_hftransformers.py ADDED
@@ -0,0 +1,96 @@
+import numpy as np
+
+from tqdm import tqdm
+from typing import List
+from torch.utils.data import Dataset
+from sklearn.preprocessing import normalize
+from transformers.pipelines import Pipeline
+
+from bertopic.backend import BaseEmbedder
+
+
+class HFTransformerBackend(BaseEmbedder):
+    """ Hugging Face transformers model
+
+    This uses the `transformers.pipelines.pipeline` to define and create
+    a feature generation pipeline from which embeddings can be extracted.
+
+    Arguments:
+        embedding_model: A Hugging Face feature extraction pipeline
+
+    Examples:
+
+    To use a Hugging Face transformers model, load in a pipeline and point
+    to any model found on their model hub (https://huggingface.co/models):
+
+    ```python
+    from bertopic.backend import HFTransformerBackend
+    from transformers.pipelines import pipeline
+
+    hf_model = pipeline("feature-extraction", model="distilbert-base-cased")
+    embedding_model = HFTransformerBackend(hf_model)
+    ```
+    """
+    def __init__(self, embedding_model: Pipeline):
+        super().__init__()
+
+        if isinstance(embedding_model, Pipeline):
+            self.embedding_model = embedding_model
+        else:
+            raise ValueError("Please select a correct transformers pipeline. For example: "
+                             "pipeline('feature-extraction', model='distilbert-base-cased', device=0)")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        dataset = MyDataset(documents)
+
+        embeddings = []
+        for document, features in tqdm(zip(documents, self.embedding_model(dataset, truncation=True, padding=True)),
+                                       total=len(dataset), disable=not verbose):
+            embeddings.append(self._embed(document, features))
+
+        return np.array(embeddings)
+
+    def _embed(self,
+               document: str,
+               features: np.ndarray) -> np.ndarray:
+        """ Mean pooling
+
+        Arguments:
+            document: The document for which to extract the attention mask
+            features: The embeddings for each token
+
+        Adopted from:
+        https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2#usage-huggingface-transformers
+        """
+        token_embeddings = np.array(features)
+        attention_mask = self.embedding_model.tokenizer(document, truncation=True, padding=True, return_tensors="np")["attention_mask"]
+        input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), token_embeddings.shape)
+        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, 1)
+        sum_mask = np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=input_mask_expanded.sum(1).max())
+        embedding = normalize(sum_embeddings / sum_mask)[0]
+        return embedding
+
+
+class MyDataset(Dataset):
+    """ Dataset to pass to `transformers.pipelines.pipeline` """
+    def __init__(self, docs):
+        self.docs = docs
+
+    def __len__(self):
+        return len(self.docs)
+
+    def __getitem__(self, idx):
+        return self.docs[idx]
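
The `_embed` method above is attention-masked mean pooling; a standalone numpy sketch of the same computation on toy shapes (no transformers required):

```python
import numpy as np

# Toy inputs: 1 document, 4 token positions, 3-dim token embeddings
token_embeddings = np.random.rand(1, 4, 3)
attention_mask = np.array([[1, 1, 1, 0]])  # last position is padding

# Broadcast the mask over the embedding dimension, then average
# only over the non-padded positions
mask = np.broadcast_to(attention_mask[..., None], token_embeddings.shape)
summed = (token_embeddings * mask).sum(axis=1)
counts = np.clip(mask.sum(axis=1), 1e-9, None)
doc_embedding = summed / counts  # shape (1, 3)
```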
BERTopic/bertopic/backend/_multimodal.py ADDED
@@ -0,0 +1,194 @@
+
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from typing import List, Union
+from sentence_transformers import SentenceTransformer
+
+from bertopic.backend import BaseEmbedder
+
+
+class MultiModalBackend(BaseEmbedder):
+    """ Multimodal backend using Sentence-transformers
+
+    The sentence-transformers embedding model used for
+    generating word, document, and image embeddings.
+
+    Arguments:
+        embedding_model: A sentence-transformers embedding model that
+                         can either embed both images and text or only text.
+                         If it only embeds text, then `image_model` needs
+                         to be used to embed the images.
+        image_model: A sentence-transformers embedding model that is used
+                     to embed only images.
+        batch_size: The size of the image batches to pass
+
+    Examples:
+
+    To create a model, you can load in a string pointing to a
+    sentence-transformers model:
+
+    ```python
+    from bertopic.backend import MultiModalBackend
+
+    sentence_model = MultiModalBackend("clip-ViT-B-32")
+    ```
+
+    or you can instantiate a model yourself:
+    ```python
+    from bertopic.backend import MultiModalBackend
+    from sentence_transformers import SentenceTransformer
+
+    embedding_model = SentenceTransformer("clip-ViT-B-32")
+    sentence_model = MultiModalBackend(embedding_model)
+    ```
+    """
+    def __init__(self,
+                 embedding_model: Union[str, SentenceTransformer],
+                 image_model: Union[str, SentenceTransformer] = None,
+                 batch_size: int = 32):
+        super().__init__()
+        self.batch_size = batch_size
+
+        # Text or Text+Image model
+        if isinstance(embedding_model, SentenceTransformer):
+            self.embedding_model = embedding_model
+        elif isinstance(embedding_model, str):
+            self.embedding_model = SentenceTransformer(embedding_model)
+        else:
+            raise ValueError("Please select a correct SentenceTransformers model: \n"
+                             "`from sentence_transformers import SentenceTransformer` \n"
+                             "`model = SentenceTransformer('clip-ViT-B-32')`")
+
+        # Image Model
+        self.image_model = None
+        if image_model is not None:
+            if isinstance(image_model, SentenceTransformer):
+                self.image_model = image_model
+            elif isinstance(image_model, str):
+                self.image_model = SentenceTransformer(image_model)
+            else:
+                raise ValueError("Please select a correct SentenceTransformers model: \n"
+                                 "`from sentence_transformers import SentenceTransformer` \n"
+                                 "`model = SentenceTransformer('clip-ViT-B-32')`")
+
+        try:
+            self.tokenizer = self.embedding_model._first_module().processor.tokenizer
+        except AttributeError:
+            self.tokenizer = self.embedding_model.tokenizer
+        except Exception:
+            self.tokenizer = None
+
+    def embed(self,
+              documents: List[str],
+              images: List[str] = None,
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            images: A list of paths to the images to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        # Embed documents
+        doc_embeddings = None
+        if documents[0] is not None:
+            doc_embeddings = self.embed_documents(documents)
+
+        # Embed images
+        image_embeddings = None
+        if isinstance(images, list):
+            image_embeddings = self.embed_images(images, verbose)
+
+        # Average embeddings
+        averaged_embeddings = None
+        if doc_embeddings is not None and image_embeddings is not None:
+            averaged_embeddings = np.mean([doc_embeddings, image_embeddings], axis=0)
+
+        if averaged_embeddings is not None:
+            return averaged_embeddings
+        elif doc_embeddings is not None:
+            return doc_embeddings
+        elif image_embeddings is not None:
+            return image_embeddings
+
+    def embed_documents(self,
+                        documents: List[str],
+                        verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        truncated_docs = [self._truncate_document(doc) for doc in documents]
+        embeddings = self.embedding_model.encode(truncated_docs, show_progress_bar=verbose)
+        return embeddings
+
+    def embed_words(self, words: List[str], verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            words: A list of words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Word embeddings with shape (n, m) with `n` words
+            that each have an embedding size of `m`
+        """
+        embeddings = self.embedding_model.encode(words, show_progress_bar=verbose)
+        return embeddings
+
+    def embed_images(self, images, verbose):
+        if self.batch_size:
+            nr_iterations = int(np.ceil(len(images) / self.batch_size))
+
+            # Embed images per batch
+            embeddings = []
+            for i in tqdm(range(nr_iterations), disable=not verbose):
+                start_index = i * self.batch_size
+                end_index = (i * self.batch_size) + self.batch_size
+
+                images_to_embed = [Image.open(image) if isinstance(image, str) else image for image in images[start_index:end_index]]
+                if self.image_model is not None:
+                    img_emb = self.image_model.encode(images_to_embed)
+                else:
+                    img_emb = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
+                embeddings.extend(img_emb.tolist())
+
+                # Close images
+                if isinstance(images[0], str):
+                    for image in images_to_embed:
+                        image.close()
+            embeddings = np.array(embeddings)
+        else:
+            images_to_embed = [Image.open(filepath) for filepath in images]
+            if self.image_model is not None:
+                embeddings = self.image_model.encode(images_to_embed)
+            else:
+                embeddings = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
+        return embeddings
+
+    def _truncate_document(self, document):
+        if self.tokenizer:
+            tokens = self.tokenizer.encode(document)
+
+            if len(tokens) > 77:
+                # Skip the starting token, only include 75 tokens
+                truncated_tokens = tokens[1:76]
+                document = self.tokenizer.decode(truncated_tokens)
+
+                # Recursive call here, because encode(decode()) can yield a different result
+                return self._truncate_document(document)
+
+        return document
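
A hedged usage sketch of the averaging behavior in `embed` above; the image file paths are placeholders:

```python
from bertopic.backend import MultiModalBackend

model = MultiModalBackend("clip-ViT-B-32", batch_size=32)

docs = ["a photo of a beach", "a photo of a mountain"]
images = ["beach.jpg", "mountain.jpg"]  # placeholder paths

# Text and image embeddings are computed separately, then averaged elementwise
embeddings = model.embed(docs, images)
```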
BERTopic/bertopic/backend/_openai.py ADDED
@@ -0,0 +1,88 @@
+import time
+import openai
+import numpy as np
+from tqdm import tqdm
+from typing import List, Mapping, Any
+from bertopic.backend import BaseEmbedder
+
+
+class OpenAIBackend(BaseEmbedder):
+    """ OpenAI Embedding Model
+
+    Arguments:
+        client: An `openai.OpenAI` client.
+        embedding_model: An OpenAI model. Default is "text-embedding-ada-002".
+                         For an overview of models see:
+                         https://platform.openai.com/docs/models/embeddings
+        delay_in_seconds: If a `batch_size` is given, use this to set
+                          the delay in seconds between batches.
+        batch_size: The size of each batch.
+        generator_kwargs: Kwargs passed to `openai.Embedding.create`.
+                          Can be used to define custom engines or
+                          deployment_ids.
+
+    Examples:
+
+    ```python
+    import openai
+    from bertopic.backend import OpenAIBackend
+
+    client = openai.OpenAI(api_key="sk-...")
+    openai_embedder = OpenAIBackend(client, "text-embedding-ada-002")
+    ```
+    """
+    def __init__(self,
+                 client: openai.OpenAI,
+                 embedding_model: str = "text-embedding-ada-002",
+                 delay_in_seconds: float = None,
+                 batch_size: int = None,
+                 generator_kwargs: Mapping[str, Any] = {}):
+        super().__init__()
+        self.client = client
+        self.embedding_model = embedding_model
+        self.delay_in_seconds = delay_in_seconds
+        self.batch_size = batch_size
+        self.generator_kwargs = generator_kwargs
+
+        if self.generator_kwargs.get("model"):
+            self.embedding_model = generator_kwargs.get("model")
+        elif not self.generator_kwargs.get("engine"):
+            self.generator_kwargs["model"] = self.embedding_model
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        # Prepare documents, replacing empty strings with a single space
+        prepared_documents = [" " if doc == "" else doc for doc in documents]
+
+        # Batch-wise embedding extraction
+        if self.batch_size is not None:
+            embeddings = []
+            for batch in tqdm(self._chunks(prepared_documents), disable=not verbose):
+                response = self.client.embeddings.create(input=batch, **self.generator_kwargs)
+                embeddings.extend([r.embedding for r in response.data])
+
+                # Delay subsequent calls
+                if self.delay_in_seconds:
+                    time.sleep(self.delay_in_seconds)
+
+        # Extract embeddings all at once
+        else:
+            response = self.client.embeddings.create(input=prepared_documents, **self.generator_kwargs)
+            embeddings = [r.embedding for r in response.data]
+        return np.array(embeddings)
+
+    def _chunks(self, documents):
+        for i in range(0, len(documents), self.batch_size):
+            yield documents[i:i + self.batch_size]
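
The constructor above lets a `"model"` entry in `generator_kwargs` override the positional `embedding_model`; a hedged sketch of that precedence (the key and deployment name are placeholders):

```python
import openai
from bertopic.backend import OpenAIBackend

client = openai.OpenAI(api_key="sk-...")  # placeholder key

# "model" inside generator_kwargs takes precedence over embedding_model
embedder = OpenAIBackend(
    client,
    embedding_model="text-embedding-ada-002",
    generator_kwargs={"model": "my-ada-deployment"},  # hypothetical deployment name
)
print(embedder.embedding_model)  # "my-ada-deployment"
```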
BERTopic/bertopic/backend/_sentencetransformers.py ADDED
@@ -0,0 +1,66 @@
+import numpy as np
+from typing import List, Union
+from sentence_transformers import SentenceTransformer
+
+from bertopic.backend import BaseEmbedder
+
+
+class SentenceTransformerBackend(BaseEmbedder):
+    """ Sentence-transformers embedding model
+
+    The sentence-transformers embedding model used for generating document and
+    word embeddings.
+
+    Arguments:
+        embedding_model: A sentence-transformers embedding model
+
+    Examples:
+
+    To create a model, you can load in a string pointing to a
+    sentence-transformers model:
+
+    ```python
+    from bertopic.backend import SentenceTransformerBackend
+
+    sentence_model = SentenceTransformerBackend("all-MiniLM-L6-v2")
+    ```
+
+    or you can instantiate a model yourself:
+    ```python
+    from bertopic.backend import SentenceTransformerBackend
+    from sentence_transformers import SentenceTransformer
+
+    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+    sentence_model = SentenceTransformerBackend(embedding_model)
+    ```
+    """
+    def __init__(self, embedding_model: Union[str, SentenceTransformer]):
+        super().__init__()
+
+        self._hf_model = None
+        if isinstance(embedding_model, SentenceTransformer):
+            self.embedding_model = embedding_model
+        elif isinstance(embedding_model, str):
+            self.embedding_model = SentenceTransformer(embedding_model)
+            self._hf_model = embedding_model
+        else:
+            raise ValueError("Please select a correct SentenceTransformers model: \n"
+                             "`from sentence_transformers import SentenceTransformer` \n"
+                             "`model = SentenceTransformer('all-MiniLM-L6-v2')`")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
+        return embeddings
BERTopic/bertopic/backend/_sklearn.py ADDED
@@ -0,0 +1,68 @@
+from bertopic.backend import BaseEmbedder
+from sklearn.utils.validation import check_is_fitted, NotFittedError
+
+
+class SklearnEmbedder(BaseEmbedder):
+    """ Scikit-learn based embedding model
+
+    This component allows the usage of scikit-learn pipelines for generating document and
+    word embeddings.
+
+    Arguments:
+        pipe: A scikit-learn pipeline that can `.transform()` text.
+
+    Examples:
+
+    Scikit-learn is very flexible and allows for many representations.
+    A relatively simple pipeline is shown below.
+
+    ```python
+    from sklearn.pipeline import make_pipeline
+    from sklearn.decomposition import TruncatedSVD
+    from sklearn.feature_extraction.text import TfidfVectorizer
+
+    from bertopic.backend import SklearnEmbedder
+
+    pipe = make_pipeline(
+        TfidfVectorizer(),
+        TruncatedSVD(100)
+    )
+
+    sklearn_embedder = SklearnEmbedder(pipe)
+    topic_model = BERTopic(embedding_model=sklearn_embedder)
+    ```
+
+    This pipeline first constructs a sparse representation based on TF-IDF and then
+    makes it dense by applying SVD. Alternatively, you might also construct something
+    more elaborate. As long as you construct a scikit-learn compatible pipeline, you
+    should be able to pass it to BERTopic.
+
+    !!! Warning
+        One caveat to be aware of is that scikit-learn's base `Pipeline` class does not
+        support the `.partial_fit()` API. If you have a pipeline that theoretically should
+        be able to support online learning, then you might want to explore
+        the [scikit-partial](https://github.com/koaning/scikit-partial) project.
+    """
+    def __init__(self, pipe):
+        super().__init__()
+        self.pipe = pipe
+
+    def embed(self, documents, verbose=False):
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: No-op variable that is kept around to keep the API consistent.
+                     If you want feedback on training times, use the sklearn API instead.
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        try:
+            check_is_fitted(self.pipe)
+            embeddings = self.pipe.transform(documents)
+        except NotFittedError:
+            embeddings = self.pipe.fit_transform(documents)
+
+        return embeddings
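
A small sketch of the fit-once behavior in `embed` above: the first call fits the pipeline, later calls only transform:

```python
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from bertopic.backend import SklearnEmbedder

pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(2))
embedder = SklearnEmbedder(pipe)

docs = ["a cat sat", "a dog ran", "cats and dogs", "the sun is out"]
first = embedder.embed(docs)   # pipeline is not fitted yet -> fit_transform
second = embedder.embed(docs)  # pipeline is now fitted -> transform only
```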
BERTopic/bertopic/backend/_spacy.py ADDED
@@ -0,0 +1,94 @@
+import numpy as np
+from tqdm import tqdm
+from typing import List
+from bertopic.backend import BaseEmbedder
+
+
+class SpacyBackend(BaseEmbedder):
+    """ Spacy embedding model
+
+    The Spacy embedding model used for generating document and
+    word embeddings.
+
+    Arguments:
+        embedding_model: A spacy embedding model
+
+    Examples:
+
+    To create a Spacy backend, you need to create an nlp object and
+    pass it through this backend:
+
+    ```python
+    import spacy
+    from bertopic.backend import SpacyBackend
+
+    nlp = spacy.load("en_core_web_md", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer'])
+    spacy_model = SpacyBackend(nlp)
+    ```
+
+    To load in a transformer model use the following:
+
+    ```python
+    import spacy
+    from thinc.api import set_gpu_allocator, require_gpu
+    from bertopic.backend import SpacyBackend
+
+    nlp = spacy.load("en_core_web_trf", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer'])
+    set_gpu_allocator("pytorch")
+    require_gpu(0)
+    spacy_model = SpacyBackend(nlp)
+    ```
+
+    If you run into gpu/memory-issues, please use:
+
+    ```python
+    import spacy
+    from bertopic.backend import SpacyBackend
+
+    spacy.prefer_gpu()
+    nlp = spacy.load("en_core_web_trf", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer'])
+    spacy_model = SpacyBackend(nlp)
+    ```
+    """
+    def __init__(self, embedding_model):
+        super().__init__()
+
+        if "spacy" in str(type(embedding_model)):
+            self.embedding_model = embedding_model
+        else:
+            raise ValueError("Please select a correct Spacy model by either using a string such as 'en_core_web_md' "
+                             "or create an nlp model using: `nlp = spacy.load('en_core_web_md')`")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        # Handle empty documents; spaCy models automatically map
+        # empty strings to the zero vector
+        empty_document = " "
+
+        # Extract embeddings
+        embeddings = []
+        for doc in tqdm(documents, position=0, leave=True, disable=not verbose):
+            embedding = self.embedding_model(doc or empty_document)
+            if embedding.has_vector:
+                embedding = embedding.vector
+            else:
+                embedding = embedding._.trf_data.tensors[-1][0]
+
+            if not isinstance(embedding, np.ndarray) and hasattr(embedding, 'get'):
+                # Convert cupy array to numpy array
+                embedding = embedding.get()
+            embeddings.append(embedding)
+
+        return np.array(embeddings)
BERTopic/bertopic/backend/_use.py ADDED
@@ -0,0 +1,58 @@
+import numpy as np
+from tqdm import tqdm
+from typing import List
+
+from bertopic.backend import BaseEmbedder
+
+
+class USEBackend(BaseEmbedder):
+    """ Universal Sentence Encoder
+
+    USE encodes text into high-dimensional vectors that
+    are used for semantic similarity in BERTopic.
+
+    Arguments:
+        embedding_model: A USE embedding model
+
+    Examples:
+
+    ```python
+    import tensorflow_hub
+    from bertopic.backend import USEBackend
+
+    embedding_model = tensorflow_hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
+    use_embedder = USEBackend(embedding_model)
+    ```
+    """
+    def __init__(self, embedding_model):
+        super().__init__()
+
+        try:
+            embedding_model(["test sentence"])
+            self.embedding_model = embedding_model
+        except TypeError:
+            raise ValueError("Please select a correct USE model: \n"
+                             "`import tensorflow_hub` \n"
+                             "`embedding_model = tensorflow_hub.load(path_to_model)`")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        embeddings = np.array(
+            [
+                self.embedding_model([doc]).cpu().numpy()[0]
+                for doc in tqdm(documents, disable=not verbose)
+            ]
+        )
+        return embeddings
BERTopic/bertopic/backend/_utils.py ADDED
@@ -0,0 +1,135 @@
+from ._base import BaseEmbedder
+
+# Imports for light-weight variant of BERTopic
+from bertopic.backend._sklearn import SklearnEmbedder
+from sklearn.pipeline import make_pipeline
+from sklearn.decomposition import TruncatedSVD
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.pipeline import Pipeline as ScikitPipeline
+
+
+languages = [
+    "arabic",
+    "bulgarian",
+    "catalan",
+    "czech",
+    "danish",
+    "german",
+    "greek",
+    "english",
+    "spanish",
+    "estonian",
+    "persian",
+    "finnish",
+    "french",
+    "canadian french",
+    "galician",
+    "gujarati",
+    "hebrew",
+    "hindi",
+    "croatian",
+    "hungarian",
+    "armenian",
+    "indonesian",
+    "italian",
+    "japanese",
+    "georgian",
+    "korean",
+    "kurdish",
+    "lithuanian",
+    "latvian",
+    "macedonian",
+    "mongolian",
+    "marathi",
+    "malay",
+    "burmese",
+    "norwegian bokmal",
+    "dutch",
+    "polish",
+    "portuguese",
+    "brazilian portuguese",
+    "romanian",
+    "russian",
+    "slovak",
+    "slovenian",
+    "albanian",
+    "serbian",
+    "swedish",
+    "thai",
+    "turkish",
+    "ukrainian",
+    "urdu",
+    "vietnamese",
+    "chinese (simplified)",
+    "chinese (traditional)",
+]
+
+
+def select_backend(embedding_model,
+                   language: str = None) -> BaseEmbedder:
+    """ Select an embedding model based on language or a specific sentence-transformers model.
+    When selecting a language, we choose `all-MiniLM-L6-v2` for English and
+    `paraphrase-multilingual-MiniLM-L12-v2` for all other languages as it supports 100+ languages.
+
+    Returns:
+        model: Either a Sentence-Transformer or Flair model
+    """
+    # BERTopic language backend
+    if isinstance(embedding_model, BaseEmbedder):
+        return embedding_model
+
+    # Scikit-learn backend
+    if isinstance(embedding_model, ScikitPipeline):
+        return SklearnEmbedder(embedding_model)
+
+    # Flair word embeddings
+    if "flair" in str(type(embedding_model)):
+        from bertopic.backend._flair import FlairBackend
+        return FlairBackend(embedding_model)
+
+    # Spacy embeddings
+    if "spacy" in str(type(embedding_model)):
+        from bertopic.backend._spacy import SpacyBackend
+        return SpacyBackend(embedding_model)
+
+    # Gensim embeddings
+    if "gensim" in str(type(embedding_model)):
+        from bertopic.backend._gensim import GensimBackend
+        return GensimBackend(embedding_model)
+
+    # USE embeddings
+    # Note: check both substrings explicitly; `"tensorflow" and "saved_model" in ...`
+    # would only ever test the second one.
+    if "tensorflow" in str(type(embedding_model)) and "saved_model" in str(type(embedding_model)):
+        from bertopic.backend._use import USEBackend
+        return USEBackend(embedding_model)
+
+    # Sentence Transformer embeddings
+    if "sentence_transformers" in str(type(embedding_model)) or isinstance(embedding_model, str):
+        from ._sentencetransformers import SentenceTransformerBackend
+        return SentenceTransformerBackend(embedding_model)
+
+    # Hugging Face embeddings
+    if "transformers" in str(type(embedding_model)) and "pipeline" in str(type(embedding_model)):
+        from ._hftransformers import HFTransformerBackend
+        return HFTransformerBackend(embedding_model)
+
+    # Select embedding model based on language
+    if language:
+        try:
+            from ._sentencetransformers import SentenceTransformerBackend
+            if language.lower() in ["english", "en"]:
+                return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")
+            elif language.lower() in languages or language == "multilingual":
+                return SentenceTransformerBackend("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
+            else:
+                raise ValueError(f"{language} is currently not supported. However, you can "
+                                 f"create any embeddings yourself and pass them through fit_transform(docs, embeddings)\n"
+                                 "Else, please select a language from the following list:\n"
+                                 f"{languages}")
+
+        # Only for light-weight installation
+        except ModuleNotFoundError:
+            pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(100))
+            return SklearnEmbedder(pipe)
+
+    from ._sentencetransformers import SentenceTransformerBackend
+    return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")
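
A quick sketch of the dispatch above: plain strings fall through to the sentence-transformers backend, while language-based selection picks a default model:

```python
from bertopic.backend._utils import select_backend

# A plain string is treated as a sentence-transformers model name
backend = select_backend("all-MiniLM-L6-v2")

# With no model given, language-based selection picks a default
backend = select_backend(None, language="english")
```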
BERTopic/bertopic/backend/_word_doc.py ADDED
@@ -0,0 +1,49 @@
+import numpy as np
+from typing import List
+from bertopic.backend._base import BaseEmbedder
+from bertopic.backend._utils import select_backend
+
+
+class WordDocEmbedder(BaseEmbedder):
+    """ Combine a document- and word-level embedder
+    """
+    def __init__(self,
+                 embedding_model,
+                 word_embedding_model):
+        super().__init__()
+
+        self.embedding_model = select_backend(embedding_model)
+        self.word_embedding_model = select_backend(word_embedding_model)
+
+    def embed_words(self,
+                    words: List[str],
+                    verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            words: A list of words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Word embeddings with shape (n, m) with `n` words
+            that each have an embedding size of `m`
+        """
+        return self.word_embedding_model.embed(words, verbose)
+
+    def embed_documents(self,
+                        document: List[str],
+                        verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            document: A list of documents to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document embeddings with shape (n, m) with `n` documents
+            that each have an embedding size of `m`
+        """
+        return self.embedding_model.embed(document, verbose)
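
A hedged usage sketch of the split above, pairing a document-level and a word-level model (both names are ordinary sentence-transformers checkpoints):

```python
from bertopic.backend import WordDocEmbedder

# Document embeddings from one model, word embeddings from another
word_doc_embedder = WordDocEmbedder(
    embedding_model="all-MiniLM-L6-v2",
    word_embedding_model="all-MiniLM-L6-v2",
)
doc_embs = word_doc_embedder.embed_documents(["topic modeling is fun"])
word_embs = word_doc_embedder.embed_words(["topic", "modeling"])
```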
BERTopic/bertopic/cluster/__init__.py ADDED
@@ -0,0 +1,5 @@
+from ._base import BaseCluster
+
+__all__ = [
+    "BaseCluster",
+]
BERTopic/bertopic/cluster/_base.py ADDED
@@ -0,0 +1,41 @@
+import numpy as np
+
+
+class BaseCluster:
+    """ The Base Cluster class
+
+    Using this class directly in BERTopic will make it skip
+    over the cluster step. As a result, topics need to be passed
+    to BERTopic in the form of its `y` parameter in order to create
+    topic representations.
+
+    Examples:
+
+    This will skip over the cluster step in BERTopic:
+
+    ```python
+    from bertopic import BERTopic
+    from bertopic.cluster import BaseCluster
+
+    empty_cluster_model = BaseCluster()
+
+    topic_model = BERTopic(hdbscan_model=empty_cluster_model)
+    ```
+
+    Then, this class can be used to perform manual topic modeling.
+    That is, topic modeling on topics that were already generated before,
+    without the need to learn them:
+
+    ```python
+    topic_model.fit(docs, y=y)
+    ```
+    """
+    def fit(self, X, y=None):
+        if y is not None:
+            self.labels_ = y
+        else:
+            self.labels_ = None
+        return self
+
+    def transform(self, X: np.ndarray) -> np.ndarray:
+        return X
BERTopic/bertopic/cluster/_utils.py ADDED
@@ -0,0 +1,70 @@
+import hdbscan
+import numpy as np
+
+
+def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
+    """ Function used to select the HDBSCAN-like model for generating
+    predictions and probabilities.
+
+    Arguments:
+        model: The cluster model.
+        func: The function to use. Options:
+                  - "approximate_predict"
+                  - "all_points_membership_vectors"
+                  - "membership_vector"
+        embeddings: Input embeddings for "approximate_predict"
+                    and "membership_vector"
+    """
+
+    # Approximate predict
+    if func == "approximate_predict":
+        if isinstance(model, hdbscan.HDBSCAN):
+            predictions, probabilities = hdbscan.approximate_predict(model, embeddings)
+            return predictions, probabilities
+
+        str_type_model = str(type(model)).lower()
+        if "cuml" in str_type_model and "hdbscan" in str_type_model:
+            from cuml.cluster import hdbscan as cuml_hdbscan
+            predictions, probabilities = cuml_hdbscan.approximate_predict(model, embeddings)
+            return predictions, probabilities
+
+        predictions = model.predict(embeddings)
+        return predictions, None
+
+    # All points membership
+    if func == "all_points_membership_vectors":
+        if isinstance(model, hdbscan.HDBSCAN):
+            return hdbscan.all_points_membership_vectors(model)
+
+        str_type_model = str(type(model)).lower()
+        if "cuml" in str_type_model and "hdbscan" in str_type_model:
+            from cuml.cluster import hdbscan as cuml_hdbscan
+            return cuml_hdbscan.all_points_membership_vectors(model)
+
+        return None
+
+    # Membership vector
+    if func == "membership_vector":
+        if isinstance(model, hdbscan.HDBSCAN):
+            probabilities = hdbscan.membership_vector(model, embeddings)
+            return probabilities
+
+        str_type_model = str(type(model)).lower()
+        if "cuml" in str_type_model and "hdbscan" in str_type_model:
+            from cuml.cluster.hdbscan.prediction import approximate_predict
+            probabilities = approximate_predict(model, embeddings)
+            return probabilities
+
+        return None
+
+
+def is_supported_hdbscan(model):
+    """ Check whether the input model is a supported HDBSCAN-like model """
+    if isinstance(model, hdbscan.HDBSCAN):
+        return True
+
+    str_type_model = str(type(model)).lower()
+    if "cuml" in str_type_model and "hdbscan" in str_type_model:
+        return True
+
+    return False
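
A small sketch of how the delegator above is called for a vanilla HDBSCAN model (note `prediction_data=True`, which `hdbscan.approximate_predict` requires):

```python
import numpy as np
import hdbscan
from bertopic.cluster._utils import hdbscan_delegator, is_supported_hdbscan

embeddings = np.random.rand(100, 5)
model = hdbscan.HDBSCAN(min_cluster_size=5, prediction_data=True).fit(embeddings)

assert is_supported_hdbscan(model)
predictions, probabilities = hdbscan_delegator(model, "approximate_predict", embeddings)
```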
BERTopic/bertopic/dimensionality/__init__.py ADDED
@@ -0,0 +1,5 @@
+from ._base import BaseDimensionalityReduction
+
+__all__ = [
+    "BaseDimensionalityReduction",
+]
BERTopic/bertopic/dimensionality/_base.py ADDED
@@ -0,0 +1,26 @@
+import numpy as np
+
+
+class BaseDimensionalityReduction:
+    """ The Base Dimensionality Reduction class
+
+    You can use this to skip over the dimensionality reduction step in BERTopic.
+
+    Examples:
+
+    This will skip over the reduction step in BERTopic:
+
+    ```python
+    from bertopic import BERTopic
+    from bertopic.dimensionality import BaseDimensionalityReduction
+
+    empty_reduction_model = BaseDimensionalityReduction()
+
+    topic_model = BERTopic(umap_model=empty_reduction_model)
+    ```
+    """
+    def fit(self, X: np.ndarray = None):
+        return self
+
+    def transform(self, X: np.ndarray) -> np.ndarray:
+        return X
BERTopic/bertopic/plotting/__init__.py ADDED
@@ -0,0 +1,28 @@
+from ._topics import visualize_topics
+from ._heatmap import visualize_heatmap
+from ._barchart import visualize_barchart
+from ._documents import visualize_documents
+from ._term_rank import visualize_term_rank
+from ._hierarchy import visualize_hierarchy
+from ._datamap import visualize_document_datamap
+from ._distribution import visualize_distribution
+from ._topics_over_time import visualize_topics_over_time
+from ._topics_per_class import visualize_topics_per_class
+from ._hierarchical_documents import visualize_hierarchical_documents
+from ._approximate_distribution import visualize_approximate_distribution
+
+
+__all__ = [
+    "visualize_topics",
+    "visualize_heatmap",
+    "visualize_barchart",
+    "visualize_documents",
+    "visualize_term_rank",
+    "visualize_hierarchy",
+    "visualize_distribution",
+    "visualize_document_datamap",
+    "visualize_topics_over_time",
+    "visualize_topics_per_class",
+    "visualize_hierarchical_documents",
+    "visualize_approximate_distribution"
+]
BERTopic/bertopic/plotting/_approximate_distribution.py ADDED
@@ -0,0 +1,99 @@
+import numpy as np
+import pandas as pd
+
+try:
+    from pandas.io.formats.style import Styler
+    HAS_JINJA = True
+except (ModuleNotFoundError, ImportError):
+    HAS_JINJA = False
+
+
+def visualize_approximate_distribution(topic_model,
+                                       document: str,
+                                       topic_token_distribution: np.ndarray,
+                                       normalize: bool = False):
+    """ Visualize the topic distribution calculated by `.approximate_topic_distribution`
+    on a token level. It thereby indicates the extent to which a certain word or phrase
+    belongs to a specific topic. The assumption here is that a single word can belong
+    to multiple similar topics and as such can give information about the broader set
+    of topics within a single document.
+
+    NOTE:
+        This function will return a stylized pandas dataframe if Jinja2 is installed.
+        If not, it will only return a pandas dataframe without color highlighting.
+        To install Jinja2:
+
+        `pip install jinja2`
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        document: The document for which you want to visualize
+                  the approximated topic distribution.
+        topic_token_distribution: The topic-token distribution of the document as
+                                  extracted by `.approximate_topic_distribution`
+        normalize: Whether to normalize, between 0 and 1 (summing to 1), the
+                   topic distribution values.
+
+    Returns:
+        df: A stylized dataframe indicating the best fitting topics
+            for each token.
+
+    Examples:
+
+    ```python
+    # Calculate the topic distributions on a token level
+    # Note that we need to have `calculate_token_level=True`
+    topic_distr, topic_token_distr = topic_model.approximate_distribution(
+        docs, calculate_token_level=True
+    )
+
+    # Visualize the approximated topic distributions
+    df = topic_model.visualize_approximate_distribution(docs[0], topic_token_distr[0])
+    df
+    ```
+
+    To revert this stylized dataframe back to a regular dataframe,
+    you can run the following:
+
+    ```python
+    df.data.columns = [column.strip() for column in df.data.columns]
+    df = df.data
+    ```
+    """
+    # Tokenize document
+    analyzer = topic_model.vectorizer_model.build_tokenizer()
+    tokens = analyzer(document)
+
+    if len(tokens) == 0:
+        raise ValueError("Make sure that your document contains at least 1 token.")
+
+    # Prepare dataframe with results
+    if normalize:
+        df = pd.DataFrame(topic_token_distribution / topic_token_distribution.sum()).T
+    else:
+        df = pd.DataFrame(topic_token_distribution).T
+
+    # Pad duplicate tokens with whitespace so the column names stay unique
+    df.columns = [f"{token}{' '*i}" for i, token in enumerate(tokens)]
+    df.index = list(topic_model.topic_labels_.values())[topic_model._outliers:]
+    df = df.loc[(df.sum(axis=1) != 0), :]
+
+    # Style the resulting dataframe
+    def text_color(val):
+        color = 'white' if val == 0 else 'black'
+        return 'color: %s' % color
+
+    def highlight_color(data, color='white'):
+        attr = 'background-color: {}'.format(color)
+        return pd.DataFrame(np.where(data == 0, attr, ''), index=data.index, columns=data.columns)
+
+    if len(df) == 0:
+        return df
+    elif HAS_JINJA:
+        df = (
+            df.style
+            .format("{:.3f}")
+            .background_gradient(cmap='Blues', axis=None)
+            .applymap(text_color)
+            .apply(highlight_color, axis=None)
+        )
+    return df
BERTopic/bertopic/plotting/_barchart.py ADDED
@@ -0,0 +1,127 @@
+import itertools
+import numpy as np
+from typing import List, Union
+
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+
+
+def visualize_barchart(topic_model,
+                       topics: List[int] = None,
+                       top_n_topics: int = 8,
+                       n_words: int = 5,
+                       custom_labels: Union[bool, str] = False,
+                       title: str = "<b>Topic Word Scores</b>",
+                       width: int = 250,
+                       height: int = 250) -> go.Figure:
+    """ Visualize a barchart of selected topics
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        topics: A selection of topics to visualize.
+        top_n_topics: Only select the top n most frequent topics.
+        n_words: Number of words to show in a topic
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+        title: Title of the plot.
+        width: The width of each figure.
+        height: The height of each figure.
+
+    Returns:
+        fig: A plotly figure
+
+    Examples:
+
+    To visualize the barchart of selected topics
+    simply run:
+
+    ```python
+    topic_model.visualize_barchart()
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_barchart()
+    fig.write_html("path/to/file.html")
+    ```
+    <iframe src="../../getting_started/visualization/bar_chart.html"
+    style="width:1100px; height: 660px; border: 0px;"></iframe>
+    """
+    colors = itertools.cycle(["#D55E00", "#0072B2", "#CC79A7", "#E69F00", "#56B4E9", "#009E73", "#F0E442"])
+
+    # Select topics based on top_n and topics args
+    freq_df = topic_model.get_topic_freq()
+    freq_df = freq_df.loc[freq_df.Topic != -1, :]
+    if topics is not None:
+        topics = list(topics)
+    elif top_n_topics is not None:
+        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
+    else:
+        topics = sorted(freq_df.Topic.to_list()[0:6])
+
+    # Initialize figure
+    if isinstance(custom_labels, str):
+        subplot_titles = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topics]
+        subplot_titles = ["_".join([label[0] for label in labels[:4]]) for labels in subplot_titles]
+        subplot_titles = [label if len(label) < 30 else label[:27] + "..." for label in subplot_titles]
+    elif topic_model.custom_labels_ is not None and custom_labels:
+        subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topics]
+    else:
+        subplot_titles = [f"Topic {topic}" for topic in topics]
+    columns = 4
+    rows = int(np.ceil(len(topics) / columns))
+    fig = make_subplots(rows=rows,
+                        cols=columns,
+                        shared_xaxes=False,
+                        horizontal_spacing=.1,
+                        vertical_spacing=.4 / rows if rows > 1 else 0,
+                        subplot_titles=subplot_titles)
+
+    # Add barchart for each topic
+    row = 1
+    column = 1
+    for topic in topics:
+        words = [word + " " for word, _ in topic_model.get_topic(topic)][:n_words][::-1]
+        scores = [score for _, score in topic_model.get_topic(topic)][:n_words][::-1]
+
+        fig.add_trace(
+            go.Bar(x=scores,
+                   y=words,
+                   orientation='h',
+                   marker_color=next(colors)),
+            row=row, col=column)
+
+        if column == columns:
+            column = 1
+            row += 1
+        else:
+            column += 1
+
+    # Stylize graph
+    fig.update_layout(
+        template="plotly_white",
+        showlegend=False,
+        title={
+            'text': f"{title}",
+            'x': .5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        width=width*4,
+        height=height*rows if rows > 1 else height * 1.3,
+        hoverlabel=dict(
+            bgcolor="white",
+            font_size=16,
+            font_family="Rockwell"
+        ),
+    )
+
+    fig.update_xaxes(showgrid=True)
+    fig.update_yaxes(showgrid=True)
+
+    return fig
BERTopic/bertopic/plotting/_datamap.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+import numpy as np
+import pandas as pd
+from typing import List, Union
+from umap import UMAP
+from warnings import warn
+try:
+    import datamapplot
+    from matplotlib.figure import Figure
+except ImportError:
+    warn("Data map plotting is unavailable unless datamapplot is installed.")
+
+    # Create a dummy figure type for typing
+    class Figure(object):
+        pass
+
+
+def visualize_document_datamap(topic_model,
+                               docs: List[str],
+                               topics: List[int] = None,
+                               embeddings: np.ndarray = None,
+                               reduced_embeddings: np.ndarray = None,
+                               custom_labels: Union[bool, str] = False,
+                               title: str = "Documents and Topics",
+                               sub_title: Union[str, None] = None,
+                               width: int = 1200,
+                               height: int = 1200,
+                               **datamap_kwds) -> Figure:
+    """ Visualize documents and their topics in 2D as a static plot for publication using
+    DataMapPlot.
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        docs: The documents you used when calling either `fit` or `fit_transform`
+        topics: A selection of topics to visualize.
+                Not to be confused with the topics that you get from `.fit_transform`.
+                For example, if you want to visualize only topics 1 through 5:
+                `topics = [1, 2, 3, 4, 5]`. Documents not in these topics will be shown
+                as noise points.
+        embeddings: The embeddings of all documents in `docs`.
+        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+        title: Title of the plot.
+        sub_title: Sub-title of the plot.
+        width: The width of the figure.
+        height: The height of the figure.
+        **datamap_kwds: All further keyword args will be passed on to DataMapPlot's
+                        `create_plot` function. See the DataMapPlot documentation
+                        for more details.
+
+    Returns:
+        figure: A Matplotlib Figure object.
+
+    Examples:
+
+    To visualize the topics simply run:
+
+    ```python
+    topic_model.visualize_document_datamap(docs)
+    ```
+
+    Do note that this re-calculates the embeddings and reduces them to 2D.
+    The advised and preferred pipeline for using this function is as follows:
+
+    ```python
+    from sklearn.datasets import fetch_20newsgroups
+    from sentence_transformers import SentenceTransformer
+    from bertopic import BERTopic
+    from umap import UMAP
+
+    # Prepare embeddings
+    docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
+    embeddings = sentence_model.encode(docs, show_progress_bar=False)
+
+    # Train BERTopic
+    topic_model = BERTopic().fit(docs, embeddings)
+
+    # Reduce dimensionality of embeddings, this step is optional
+    # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
+
+    # Run the visualization with the original embeddings
+    topic_model.visualize_document_datamap(docs, embeddings=embeddings)
+
+    # Or, if you have reduced the original embeddings already:
+    topic_model.visualize_document_datamap(docs, reduced_embeddings=reduced_embeddings)
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_document_datamap(docs, reduced_embeddings=reduced_embeddings)
+    fig.savefig("path/to/file.png", bbox_inches="tight")
+    ```
+    <img src="../../getting_started/visualization/datamapplot.png"
+         alt="DataMapPlot of 20-Newsgroups" width=800 height=800>
+    """
+
+    topic_per_doc = topic_model.topics_
+
+    df = pd.DataFrame({"topic": np.array(topic_per_doc)})
+    df["doc"] = docs
+    df["topic"] = topic_per_doc
+
+    # Extract embeddings if not already done
+    if embeddings is None and reduced_embeddings is None:
+        embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+    else:
+        embeddings_to_reduce = embeddings
+
+    # Reduce input embeddings
+    if reduced_embeddings is None:
+        umap_model = UMAP(n_neighbors=15, n_components=2, min_dist=0.15, metric='cosine').fit(embeddings_to_reduce)
+        embeddings_2d = umap_model.embedding_
+    else:
+        embeddings_2d = reduced_embeddings
+
+    unique_topics = set(topic_per_doc)
+
+    # Prepare text and names
+    if isinstance(custom_labels, str):
+        names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in unique_topics]
+        names = [" ".join([label[0] for label in labels[:4]]) for labels in names]
+        names = [label if len(label) < 30 else label[:27] + "..." for label in names]
+    elif topic_model.custom_labels_ is not None and custom_labels:
+        names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
+    else:
+        names = [f"Topic-{topic}: " + " ".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
+
+    topic_name_mapping = {topic_num: topic_name for topic_num, topic_name in zip(unique_topics, names)}
+    topic_name_mapping[-1] = "Unlabelled"
+
+    # If a set of topics is chosen, set everything else to "Unlabelled"
+    if topics is not None:
+        selected_topics = set(topics)
+        for topic_num in topic_name_mapping:
+            if topic_num not in selected_topics:
+                topic_name_mapping[topic_num] = "Unlabelled"
+
+    # Map in topic names and plot
+    named_topic_per_doc = pd.Series(topic_per_doc).map(topic_name_mapping).values
+
+    figure, axes = datamapplot.create_plot(
+        embeddings_2d,
+        named_topic_per_doc,
+        figsize=(width / 100, height / 100),
+        dpi=100,
+        title=title,
+        sub_title=sub_title,
+        **datamap_kwds,
+    )
+
+    return figure
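A minimal usage sketch for the function above, assuming `datamapplot` is installed (dataset and UMAP settings are illustrative):

```python
from sklearn.datasets import fetch_20newsgroups
from sentence_transformers import SentenceTransformer
from umap import UMAP
from bertopic import BERTopic

docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))["data"]
embeddings = SentenceTransformer("all-MiniLM-L6-v2").encode(docs, show_progress_bar=False)
topic_model = BERTopic().fit(docs, embeddings)

# Pre-reducing to 2D skips the UMAP fit inside the function
reduced_embeddings = UMAP(n_neighbors=15, n_components=2, min_dist=0.15, metric="cosine").fit_transform(embeddings)
fig = topic_model.visualize_document_datamap(docs, reduced_embeddings=reduced_embeddings)
fig.savefig("datamap.png", bbox_inches="tight")  # illustrative output path
```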
BERTopic/bertopic/plotting/_distribution.py ADDED
@@ -0,0 +1,110 @@
+import numpy as np
+from typing import Union
+import plotly.graph_objects as go
+
+
+def visualize_distribution(topic_model,
+                           probabilities: np.ndarray,
+                           min_probability: float = 0.015,
+                           custom_labels: Union[bool, str] = False,
+                           title: str = "<b>Topic Probability Distribution</b>",
+                           width: int = 800,
+                           height: int = 600) -> go.Figure:
+    """ Visualize the distribution of topic probabilities
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        probabilities: An array of probability scores
+        min_probability: The minimum probability score to visualize.
+                         All others are ignored.
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+        title: Title of the plot.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Examples:
+
+    Make sure to fit the model before and only input the
+    probabilities of a single document:
+
+    ```python
+    topic_model.visualize_distribution(probabilities[0])
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_distribution(probabilities[0])
+    fig.write_html("path/to/file.html")
+    ```
+    <iframe src="../../getting_started/visualization/probabilities.html"
+    style="width:1000px; height: 500px; border: 0px;"></iframe>
+    """
+    if len(probabilities.shape) != 1:
+        raise ValueError("This visualization cannot be used if you have set `calculate_probabilities` to False "
+                         "as it uses the topic probabilities of all topics.")
+    if len(probabilities[probabilities > min_probability]) == 0:
+        raise ValueError("There are no probabilities that exceed `min_probability`. "
+                         "Lower `min_probability` to prevent this error.")
+
+    # Get values and indices that equal or exceed the minimum probability
+    labels_idx = np.argwhere(probabilities >= min_probability).flatten()
+    vals = probabilities[labels_idx].tolist()
+
+    # Create labels
+    if isinstance(custom_labels, str):
+        labels = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in labels_idx]
+        labels = ["_".join([label[0] for label in lbl[:4]]) for lbl in labels]
+        labels = [label if len(label) < 30 else label[:27] + "..." for label in labels]
+    elif topic_model.custom_labels_ is not None and custom_labels:
+        labels = [topic_model.custom_labels_[idx + topic_model._outliers] for idx in labels_idx]
+    else:
+        labels = []
+        for idx in labels_idx:
+            words = topic_model.get_topic(idx)
+            if words:
+                label = [word[0] for word in words[:5]]
+                label = f"<b>Topic {idx}</b>: {'_'.join(label)}"
+                label = label[:40] + "..." if len(label) > 40 else label
+                labels.append(label)
+            else:
+                vals.remove(probabilities[idx])
+
+    # Create Figure
+    fig = go.Figure(go.Bar(
+        x=vals,
+        y=labels,
+        marker=dict(
+            color='#C8D2D7',
+            line=dict(
+                color='#6E8484',
+                width=1),
+        ),
+        orientation='h')
+    )
+
+    fig.update_layout(
+        xaxis_title="Probability",
+        title={
+            'text': f"{title}",
+            'y': .95,
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        template="simple_white",
+        width=width,
+        height=height,
+        hoverlabel=dict(
+            bgcolor="white",
+            font_size=16,
+            font_family="Rockwell"
+        ),
+    )
+
+    return fig
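Because this plot needs the full topic-probability vector of a single document, a sketch of the intended call pattern looks like this (dataset choice is illustrative):

```python
from sklearn.datasets import fetch_20newsgroups
from bertopic import BERTopic

docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))["data"]

# Without calculate_probabilities=True only one probability per document
# is available and visualize_distribution raises a ValueError
topic_model = BERTopic(calculate_probabilities=True)
topics, probs = topic_model.fit_transform(docs)

fig = topic_model.visualize_distribution(probs[0], min_probability=0.01)
fig.write_html("distribution.html")  # illustrative output path
```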
BERTopic/bertopic/plotting/_documents.py ADDED
@@ -0,0 +1,227 @@
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+
+from umap import UMAP
+from typing import List, Union
+
+
+def visualize_documents(topic_model,
+                        docs: List[str],
+                        topics: List[int] = None,
+                        embeddings: np.ndarray = None,
+                        reduced_embeddings: np.ndarray = None,
+                        sample: float = None,
+                        hide_annotations: bool = False,
+                        hide_document_hover: bool = False,
+                        custom_labels: Union[bool, str] = False,
+                        title: str = "<b>Documents and Topics</b>",
+                        width: int = 1200,
+                        height: int = 750):
+    """ Visualize documents and their topics in 2D
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        docs: The documents you used when calling either `fit` or `fit_transform`
+        topics: A selection of topics to visualize.
+                Not to be confused with the topics that you get from `.fit_transform`.
+                For example, if you want to visualize only topics 1 through 5:
+                `topics = [1, 2, 3, 4, 5]`.
+        embeddings: The embeddings of all documents in `docs`.
+        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
+        sample: The percentage of documents in each topic that you would like to keep.
+                Value can be between 0 and 1. Setting this value to, for example,
+                0.1 (10% of documents in each topic) makes it easier to visualize
+                millions of documents as a subset is chosen.
+        hide_annotations: Hide the names of the traces on top of each cluster.
+        hide_document_hover: Hide the content of the documents when hovering over
+                             specific points. Helps to speed up generation of the visualization.
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+        title: Title of the plot.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Examples:
+
+    To visualize the topics simply run:
+
+    ```python
+    topic_model.visualize_documents(docs)
+    ```
+
+    Do note that this re-calculates the embeddings and reduces them to 2D.
+    The advised and preferred pipeline for using this function is as follows:
+
+    ```python
+    from sklearn.datasets import fetch_20newsgroups
+    from sentence_transformers import SentenceTransformer
+    from bertopic import BERTopic
+    from umap import UMAP
+
+    # Prepare embeddings
+    docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
+    embeddings = sentence_model.encode(docs, show_progress_bar=False)
+
+    # Train BERTopic
+    topic_model = BERTopic().fit(docs, embeddings)
+
+    # Reduce dimensionality of embeddings, this step is optional
+    # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
+
+    # Run the visualization with the original embeddings
+    topic_model.visualize_documents(docs, embeddings=embeddings)
+
+    # Or, if you have reduced the original embeddings already:
+    topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
+    fig.write_html("path/to/file.html")
+    ```
+
+    <iframe src="../../getting_started/visualization/documents.html"
+    style="width:1000px; height: 800px; border: 0px;"></iframe>
+    """
+    topic_per_doc = topic_model.topics_
+
+    # Sample the data to optimize for visualization and dimensionality reduction
+    if sample is None or sample > 1:
+        sample = 1
+
+    indices = []
+    for topic in set(topic_per_doc):
+        s = np.where(np.array(topic_per_doc) == topic)[0]
+        size = len(s) if len(s) < 100 else int(len(s) * sample)
+        indices.extend(np.random.choice(s, size=size, replace=False))
+    indices = np.array(indices)
+
+    df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
+    df["doc"] = [docs[index] for index in indices]
+    df["topic"] = [topic_per_doc[index] for index in indices]
+
+    # Extract embeddings if not already done
+    if sample is None:
+        if embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+        else:
+            embeddings_to_reduce = embeddings
+    else:
+        if embeddings is not None:
+            embeddings_to_reduce = embeddings[indices]
+        elif embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+
+    # Reduce input embeddings
+    if reduced_embeddings is None:
+        umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
+        embeddings_2d = umap_model.embedding_
+    elif sample is not None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings[indices]
+    elif sample is None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings
+
+    unique_topics = set(topic_per_doc)
+    if topics is None:
+        topics = unique_topics
+
+    # Combine data
+    df["x"] = embeddings_2d[:, 0]
+    df["y"] = embeddings_2d[:, 1]
+
+    # Prepare text and names
+    if isinstance(custom_labels, str):
+        names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in unique_topics]
+        names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
+        names = [label if len(label) < 30 else label[:27] + "..." for label in names]
+    elif topic_model.custom_labels_ is not None and custom_labels:
+        names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
+    else:
+        names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
+
+    # Visualize
+    fig = go.Figure()
+
+    # Outliers and non-selected topics
+    non_selected_topics = set(unique_topics).difference(topics)
+    if len(non_selected_topics) == 0:
+        non_selected_topics = [-1]
+
+    selection = df.loc[df.topic.isin(non_selected_topics), :]
+    selection["text"] = ""
+    selection.loc[len(selection), :] = [None, None, selection.x.mean(), selection.y.mean(), "Other documents"]
+
+    fig.add_trace(
+        go.Scattergl(
+            x=selection.x,
+            y=selection.y,
+            hovertext=selection.doc if not hide_document_hover else None,
+            hoverinfo="text",
+            mode='markers+text',
+            name="other",
+            showlegend=False,
+            marker=dict(color='#CFD8DC', size=5, opacity=0.5)
+        )
+    )
+
+    # Selected topics
+    for name, topic in zip(names, unique_topics):
+        if topic in topics and topic != -1:
+            selection = df.loc[df.topic == topic, :]
+            selection["text"] = ""
+
+            if not hide_annotations:
+                selection.loc[len(selection), :] = [None, None, selection.x.mean(), selection.y.mean(), name]
+
+            fig.add_trace(
+                go.Scattergl(
+                    x=selection.x,
+                    y=selection.y,
+                    hovertext=selection.doc if not hide_document_hover else None,
+                    hoverinfo="text",
+                    text=selection.text,
+                    mode='markers+text',
+                    name=name,
+                    textfont=dict(
+                        size=12,
+                    ),
+                    marker=dict(size=5, opacity=0.5)
+                )
+            )
+
+    # Add grid in a 'plus' shape
+    x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
+    y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
+    fig.add_shape(type="line",
+                  x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
+                  line=dict(color="#CFD8DC", width=2))
+    fig.add_shape(type="line",
+                  x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
+                  line=dict(color="#9E9E9E", width=2))
+    fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
+    fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
+
+    # Stylize layout
+    fig.update_layout(
+        template="simple_white",
+        title={
+            'text': f"{title}",
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        width=width,
+        height=height
+    )
+
+    fig.update_xaxes(visible=False)
+    fig.update_yaxes(visible=False)
+    return fig
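For large corpora, the `sample` and `hide_document_hover` arguments documented above keep the figure manageable; a sketch (all values illustrative):

```python
from sklearn.datasets import fetch_20newsgroups
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic

docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))["data"]
embeddings = SentenceTransformer("all-MiniLM-L6-v2").encode(docs, show_progress_bar=False)
topic_model = BERTopic().fit(docs, embeddings)

# Keep roughly 10% of each topic's documents and skip hover text
# so the resulting HTML stays small
fig = topic_model.visualize_documents(docs, embeddings=embeddings,
                                      sample=0.1, hide_document_hover=True)
fig.write_html("documents.html")  # illustrative output path
```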
BERTopic/bertopic/plotting/_heatmap.py ADDED
@@ -0,0 +1,138 @@
+import numpy as np
+from typing import List, Union
+from scipy.cluster.hierarchy import fcluster, linkage
+from sklearn.metrics.pairwise import cosine_similarity
+
+import plotly.express as px
+import plotly.graph_objects as go
+
+
+def visualize_heatmap(topic_model,
+                      topics: List[int] = None,
+                      top_n_topics: int = None,
+                      n_clusters: int = None,
+                      custom_labels: Union[bool, str] = False,
+                      title: str = "<b>Similarity Matrix</b>",
+                      width: int = 800,
+                      height: int = 800) -> go.Figure:
+    """ Visualize a heatmap of the topics' similarity matrix
+
+    Based on the cosine similarity matrix between topic embeddings,
+    a heatmap is created showing the similarity between topics.
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        topics: A selection of topics to visualize.
+        top_n_topics: Only select the top n most frequent topics.
+        n_clusters: Create n clusters and order the similarity
+                    matrix by those clusters.
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+        title: Title of the plot.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Returns:
+        fig: A plotly figure
+
+    Examples:
+
+    To visualize the similarity matrix of
+    topics simply run:
+
+    ```python
+    topic_model.visualize_heatmap()
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_heatmap()
+    fig.write_html("path/to/file.html")
+    ```
+    <iframe src="../../getting_started/visualization/heatmap.html"
+    style="width:1000px; height: 720px; border: 0px;"></iframe>
+    """
+
+    # Select topic embeddings
+    if topic_model.topic_embeddings_ is not None:
+        embeddings = np.array(topic_model.topic_embeddings_)[topic_model._outliers:]
+    else:
+        embeddings = topic_model.c_tf_idf_[topic_model._outliers:]
+
+    # Select topics based on top_n and topics args
+    freq_df = topic_model.get_topic_freq()
+    freq_df = freq_df.loc[freq_df.Topic != -1, :]
+    if topics is not None:
+        topics = list(topics)
+    elif top_n_topics is not None:
+        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
+    else:
+        topics = sorted(freq_df.Topic.to_list())
+
+    # Order heatmap by similar clusters of topics
+    sorted_topics = topics
+    if n_clusters:
+        if n_clusters >= len(set(topics)):
+            raise ValueError("Make sure to set `n_clusters` lower than "
+                             "the total number of unique topics.")
+
+        distance_matrix = cosine_similarity(embeddings[topics])
+        Z = linkage(distance_matrix, 'ward')
+        clusters = fcluster(Z, t=n_clusters, criterion='maxclust')
+
+        # Extract new order of topics
+        mapping = {cluster: [] for cluster in clusters}
+        for topic, cluster in zip(topics, clusters):
+            mapping[cluster].append(topic)
+        mapping = [cluster for cluster in mapping.values()]
+        sorted_topics = [topic for cluster in mapping for topic in cluster]
+
+    # Select embeddings
+    indices = np.array([topics.index(topic) for topic in sorted_topics])
+    embeddings = embeddings[indices]
+    distance_matrix = cosine_similarity(embeddings)
+
+    # Create labels
+    if isinstance(custom_labels, str):
+        new_labels = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in sorted_topics]
+        new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
+        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
+    elif topic_model.custom_labels_ is not None and custom_labels:
+        new_labels = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in sorted_topics]
+    else:
+        new_labels = [[[str(topic), None]] + topic_model.get_topic(topic) for topic in sorted_topics]
+        new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
+        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
+
+    fig = px.imshow(distance_matrix,
+                    labels=dict(color="Similarity Score"),
+                    x=new_labels,
+                    y=new_labels,
+                    color_continuous_scale='GnBu'
+                    )
+
+    fig.update_layout(
+        title={
+            'text': f"{title}",
+            'y': .95,
+            'x': 0.55,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        width=width,
+        height=height,
+        hoverlabel=dict(
+            bgcolor="white",
+            font_size=16,
+            font_family="Rockwell"
+        ),
+    )
+    fig.update_layout(showlegend=True)
+    fig.update_layout(legend_title_text='Trend')
+
+    return fig
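A sketch of the cluster-ordered variant described above (the cluster count is illustrative):

```python
from sklearn.datasets import fetch_20newsgroups
from bertopic import BERTopic

docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))["data"]
topic_model = BERTopic().fit(docs)

# Order the similarity matrix by 5 clusters of similar topics;
# n_clusters must stay below the number of unique topics
fig = topic_model.visualize_heatmap(n_clusters=5)
fig.write_html("heatmap.html")  # illustrative output path
```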
BERTopic/bertopic/plotting/_hierarchical_documents.py ADDED
@@ -0,0 +1,336 @@
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+import math
+
+from umap import UMAP
+from typing import List, Union
+
+
+def visualize_hierarchical_documents(topic_model,
+                                     docs: List[str],
+                                     hierarchical_topics: pd.DataFrame,
+                                     topics: List[int] = None,
+                                     embeddings: np.ndarray = None,
+                                     reduced_embeddings: np.ndarray = None,
+                                     sample: Union[float, int] = None,
+                                     hide_annotations: bool = False,
+                                     hide_document_hover: bool = True,
+                                     nr_levels: int = 10,
+                                     level_scale: str = 'linear',
+                                     custom_labels: Union[bool, str] = False,
+                                     title: str = "<b>Hierarchical Documents and Topics</b>",
+                                     width: int = 1200,
+                                     height: int = 750) -> go.Figure:
+    """ Visualize documents and their topics in 2D at different levels of hierarchy
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        docs: The documents you used when calling either `fit` or `fit_transform`
+        hierarchical_topics: A dataframe that contains a hierarchy of topics
+                             represented by their parents and their children
+        topics: A selection of topics to visualize.
+                Not to be confused with the topics that you get from `.fit_transform`.
+                For example, if you want to visualize only topics 1 through 5:
+                `topics = [1, 2, 3, 4, 5]`.
+        embeddings: The embeddings of all documents in `docs`.
+        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
+        sample: The percentage of documents in each topic that you would like to keep.
+                Value can be between 0 and 1. Setting this value to, for example,
+                0.1 (10% of documents in each topic) makes it easier to visualize
+                millions of documents as a subset is chosen.
+        hide_annotations: Hide the names of the traces on top of each cluster.
+        hide_document_hover: Hide the content of the documents when hovering over
+                             specific points. Helps to speed up generation of visualizations.
+        nr_levels: The number of levels to be visualized in the hierarchy. First, the distances
+                   in `hierarchical_topics.Distance` are split into `nr_levels` lists of distances.
+                   Then, for each list of distances, the merged topics are selected that have a
+                   distance less than or equal to the maximum distance of the selected list of distances.
+                   NOTE: To get all possible merge steps, make sure that `nr_levels` is equal to
+                   the length of `hierarchical_topics`.
+        level_scale: Whether to apply a linear or logarithmic (log) scale to the levels of the
+                     distance vector. Linear scaling will perform an equal number of merges at
+                     each level while logarithmic scaling will perform more merges in earlier
+                     levels to provide more resolution at higher levels (this can be used when
+                     the number of topics is large).
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+                       NOTE: Custom labels are only generated for the original
+                       un-merged topics.
+        title: Title of the plot.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Examples:
+
+    To visualize the topics simply run:
+
+    ```python
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics)
+    ```
+
+    Do note that this re-calculates the embeddings and reduces them to 2D.
+    The advised and preferred pipeline for using this function is as follows:
+
+    ```python
+    from sklearn.datasets import fetch_20newsgroups
+    from sentence_transformers import SentenceTransformer
+    from bertopic import BERTopic
+    from umap import UMAP
+
+    # Prepare embeddings
+    docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
+    embeddings = sentence_model.encode(docs, show_progress_bar=False)
+
+    # Train BERTopic and extract hierarchical topics
+    topic_model = BERTopic().fit(docs, embeddings)
+    hierarchical_topics = topic_model.hierarchical_topics(docs)
+
+    # Reduce dimensionality of embeddings, this step is optional
+    # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
+
+    # Run the visualization with the original embeddings
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings)
+
+    # Or, if you have reduced the original embeddings already:
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
+    fig.write_html("path/to/file.html")
+    ```
+
+    NOTE:
+    This visualization was inspired by the scatter plot representation of Doc2Map:
+    https://github.com/louisgeisler/Doc2Map
+
+    <iframe src="../../getting_started/visualization/hierarchical_documents.html"
+    style="width:1000px; height: 770px; border: 0px;"></iframe>
+    """
+    topic_per_doc = topic_model.topics_
+
+    # Sample the data to optimize for visualization and dimensionality reduction
+    if sample is None or sample > 1:
+        sample = 1
+
+    indices = []
+    for topic in set(topic_per_doc):
+        s = np.where(np.array(topic_per_doc) == topic)[0]
+        size = len(s) if len(s) < 100 else int(len(s) * sample)
+        indices.extend(np.random.choice(s, size=size, replace=False))
+    indices = np.array(indices)
+
+    df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
+    df["doc"] = [docs[index] for index in indices]
+    df["topic"] = [topic_per_doc[index] for index in indices]
+
+    # Extract embeddings if not already done
+    if sample is None:
+        if embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+        else:
+            embeddings_to_reduce = embeddings
+    else:
+        if embeddings is not None:
+            embeddings_to_reduce = embeddings[indices]
+        elif embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+
+    # Reduce input embeddings
+    if reduced_embeddings is None:
+        umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
+        embeddings_2d = umap_model.embedding_
+    elif sample is not None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings[indices]
+    elif sample is None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings
+
+    # Combine data
+    df["x"] = embeddings_2d[:, 0]
+    df["y"] = embeddings_2d[:, 1]
+
+    # Create topic list for each level, levels are created by calculating the distance
+    distances = hierarchical_topics.Distance.to_list()
+    if level_scale == 'log' or level_scale == 'logarithmic':
+        log_indices = np.round(np.logspace(start=math.log(1, 10), stop=math.log(len(distances) - 1, 10), num=nr_levels)).astype(int).tolist()
+        log_indices.reverse()
+        max_distances = [distances[i] for i in log_indices]
+    elif level_scale == 'lin' or level_scale == 'linear':
+        max_distances = [distances[indices[-1]] for indices in np.array_split(range(len(hierarchical_topics)), nr_levels)][::-1]
+    else:
+        raise ValueError("level_scale needs to be one of 'log' or 'linear'")
+
+    for index, max_distance in enumerate(max_distances):
+
+        # Get topics below `max_distance`
+        mapping = {topic: topic for topic in df.topic.unique()}
+        selection = hierarchical_topics.loc[hierarchical_topics.Distance <= max_distance, :]
+        selection.Parent_ID = selection.Parent_ID.astype(int)
+        selection = selection.sort_values("Parent_ID")
+
+        for row in selection.iterrows():
+            for topic in row[1].Topics:
+                mapping[topic] = row[1].Parent_ID
+
+        # Make sure the mappings are mapped 1:1
+        mappings = [True for _ in mapping]
+        while any(mappings):
+            for i, (key, value) in enumerate(mapping.items()):
+                if value in mapping.keys() and key != value:
+                    mapping[key] = mapping[value]
+                else:
+                    mappings[i] = False
+
+        # Create new column
+        df[f"level_{index+1}"] = df.topic.map(mapping)
+        df[f"level_{index+1}"] = df[f"level_{index+1}"].astype(int)
+
+    # Prepare topic names of original and merged topics
+    trace_names = []
+    topic_names = {}
+    for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
+        if topic < hierarchical_topics.Parent_ID.astype(int).min():
+            if topic_model.get_topic(topic):
+                if isinstance(custom_labels, str):
+                    trace_name = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3])
+                elif topic_model.custom_labels_ is not None and custom_labels:
+                    trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
+                else:
+                    trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3])
+                topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]}
+                trace_names.append(trace_name)
+        else:
+            trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
+            plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]])
+            topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]}
+            trace_names.append(trace_name)
+
+    # Prepare traces
+    all_traces = []
+    for level in range(len(max_distances)):
+        traces = []
+
+        # Outliers
+        if topic_model._outliers:
+            traces.append(
+                go.Scattergl(
+                    x=df.loc[(df[f"level_{level+1}"] == -1), "x"],
+                    y=df.loc[df[f"level_{level+1}"] == -1, "y"],
+                    mode='markers+text',
+                    name="other",
+                    hoverinfo="text",
+                    hovertext=df.loc[(df[f"level_{level+1}"] == -1), "doc"] if not hide_document_hover else None,
+                    showlegend=False,
+                    marker=dict(color='#CFD8DC', size=5, opacity=0.5)
+                )
+            )
+
+        # Selected topics
+        if topics:
+            selection = df.loc[(df.topic.isin(topics)), :]
+            unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()])
+        else:
+            unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()])
+
+        for topic in unique_topics:
+            if topic != -1:
+                if topics:
+                    selection = df.loc[(df[f"level_{level+1}"] == topic) &
+                                       (df.topic.isin(topics)), :]
+                else:
+                    selection = df.loc[df[f"level_{level+1}"] == topic, :]
+
+                if not hide_annotations:
+                    selection.loc[len(selection), :] = None
+                    selection["text"] = ""
+                    selection.loc[len(selection) - 1, "x"] = selection.x.mean()
+                    selection.loc[len(selection) - 1, "y"] = selection.y.mean()
+                    selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)]["plot_text"]
+
+                traces.append(
+                    go.Scattergl(
+                        x=selection.x,
+                        y=selection.y,
+                        text=selection.text if not hide_annotations else None,
+                        hovertext=selection.doc if not hide_document_hover else None,
+                        hoverinfo="text",
+                        name=topic_names[int(topic)]["trace_name"],
+                        mode='markers+text',
+                        marker=dict(size=5, opacity=0.5)
+                    )
+                )
+
+        all_traces.append(traces)
+
+    # Track and count traces
+    nr_traces_per_set = [len(traces) for traces in all_traces]
+    trace_indices = [(0, nr_traces_per_set[0])]
+    for index, nr_traces in enumerate(nr_traces_per_set[1:]):
+        start = trace_indices[index][1]
+        end = nr_traces + start
+        trace_indices.append((start, end))
+
+    # Visualization
+    fig = go.Figure()
+    for traces in all_traces:
+        for trace in traces:
+            fig.add_trace(trace)
+
+    for index in range(len(fig.data)):
+        if index >= nr_traces_per_set[0]:
+            fig.data[index].visible = False
+
+    # Create and add slider
+    steps = []
+    for index, indices in enumerate(trace_indices):
+        step = dict(
+            method="update",
+            label=str(index),
+            args=[{"visible": [False] * len(fig.data)}]
+        )
+        for index in range(indices[1] - indices[0]):
+            step["args"][0]["visible"][index + indices[0]] = True
+        steps.append(step)
+
+    sliders = [dict(
+        currentvalue={"prefix": "Level: "},
+        pad={"t": 20},
+        steps=steps
+    )]
+
+    # Add grid in a 'plus' shape
+    x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
+    y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
+    fig.add_shape(type="line",
+                  x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
+                  line=dict(color="#CFD8DC", width=2))
+    fig.add_shape(type="line",
+                  x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
+                  line=dict(color="#9E9E9E", width=2))
+    fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
+    fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
+
+    # Stylize layout
+    fig.update_layout(
+        sliders=sliders,
+        template="simple_white",
+        title={
+            'text': f"{title}",
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        width=width,
+        height=height,
+    )
+
+    fig.update_xaxes(visible=False)
+    fig.update_yaxes(visible=False)
+    return fig
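A sketch of the level controls described above; a log `level_scale` spends more slider levels on the earliest merges (all parameter values illustrative):

```python
from sklearn.datasets import fetch_20newsgroups
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic

docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))["data"]
embeddings = SentenceTransformer("all-MiniLM-L6-v2").encode(docs, show_progress_bar=False)
topic_model = BERTopic().fit(docs, embeddings)
hierarchical_topics = topic_model.hierarchical_topics(docs)

fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics,
                                                   embeddings=embeddings,
                                                   nr_levels=7, level_scale="log")
fig.write_html("hierarchical_documents.html")  # illustrative output path
```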
BERTopic/bertopic/plotting/_hierarchy.py ADDED
@@ -0,0 +1,308 @@
+import numpy as np
+import pandas as pd
+from typing import Callable, List, Union
+from scipy.sparse import csr_matrix
+from scipy.cluster import hierarchy as sch
+from scipy.spatial.distance import squareform
+from sklearn.metrics.pairwise import cosine_similarity
+
+import plotly.graph_objects as go
+import plotly.figure_factory as ff
+
+from bertopic._utils import validate_distance_matrix
+
+
+def visualize_hierarchy(topic_model,
+                        orientation: str = "left",
+                        topics: List[int] = None,
+                        top_n_topics: int = None,
+                        custom_labels: Union[bool, str] = False,
+                        title: str = "<b>Hierarchical Clustering</b>",
+                        width: int = 1000,
+                        height: int = 600,
+                        hierarchical_topics: pd.DataFrame = None,
+                        linkage_function: Callable[[csr_matrix], np.ndarray] = None,
+                        distance_function: Callable[[csr_matrix], csr_matrix] = None,
+                        color_threshold: int = 1) -> go.Figure:
+    """ Visualize a hierarchical structure of the topics
+
+    A ward linkage function is used to perform the
+    hierarchical clustering based on the cosine distance
+    matrix between topic embeddings.
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        orientation: The orientation of the figure.
+                     Either 'left' or 'bottom'
+        topics: A selection of topics to visualize
+        top_n_topics: Only select the top n most frequent topics
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+                       NOTE: Custom labels are only generated for the original
+                       un-merged topics.
+        title: Title of the plot.
+        width: The width of the figure. Only works if orientation is set to 'left'
+        height: The height of the figure. Only works if orientation is set to 'bottom'
+        hierarchical_topics: A dataframe that contains a hierarchy of topics
+                             represented by their parents and their children.
+                             NOTE: The hierarchical topic names are only visualized
+                             if both `topics` and `top_n_topics` are not set.
+        linkage_function: The linkage function to use. Default is:
+                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
+                          NOTE: Make sure to use the same `linkage_function` as used
+                          in `topic_model.hierarchical_topics`.
+        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
+                           `lambda x: 1 - cosine_similarity(x)`.
+                           You can pass any function that returns either a square matrix of
+                           shape (n_samples, n_samples) with zeros on the diagonal and
+                           non-negative values or a condensed distance matrix of shape
+                           (n_samples * (n_samples - 1) / 2,) containing the upper
+                           triangular part of the distance matrix.
+                           NOTE: Make sure to use the same `distance_function` as used
+                           in `topic_model.hierarchical_topics`.
+        color_threshold: Value at which the separation of clusters will be made which
+                         will result in different colors for different clusters.
+                         A higher value will typically lead to fewer colored clusters.
+
+    Returns:
+        fig: A plotly figure
+
+    Examples:
+
+    To visualize the hierarchical structure of
+    topics simply run:
+
+    ```python
+    topic_model.visualize_hierarchy()
+    ```
+
+    If you also want the labels of hierarchical topics visualized,
+    run the following:
+
+    ```python
+    # Extract hierarchical topics and their representations
+    hierarchical_topics = topic_model.hierarchical_topics(docs)
+
+    # Visualize these representations
+    topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
+    ```
+
+    If you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_hierarchy()
+    fig.write_html("path/to/file.html")
+    ```
+    <iframe src="../../getting_started/visualization/hierarchy.html"
+    style="width:1000px; height: 680px; border: 0px;"></iframe>
+    """
+    if distance_function is None:
+        distance_function = lambda x: 1 - cosine_similarity(x)
+
+    if linkage_function is None:
+        linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)
+
+    # Select topics based on top_n and topics args
+    freq_df = topic_model.get_topic_freq()
+    freq_df = freq_df.loc[freq_df.Topic != -1, :]
+    if topics is not None:
+        topics = list(topics)
+    elif top_n_topics is not None:
+        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
+    else:
+        topics = sorted(freq_df.Topic.to_list())
+
+    # Select embeddings
+    all_topics = sorted(list(topic_model.get_topics().keys()))
+    indices = np.array([all_topics.index(topic) for topic in topics])
+
+    # Select topic embeddings
+    if topic_model.c_tf_idf_ is not None:
+        embeddings = topic_model.c_tf_idf_[indices]
+    else:
+        embeddings = np.array(topic_model.topic_embeddings_)[indices]
+
+    # Annotations
+    if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()):
+        annotations = _get_annotations(topic_model=topic_model,
+                                       hierarchical_topics=hierarchical_topics,
+                                       embeddings=embeddings,
+                                       distance_function=distance_function,
+                                       linkage_function=linkage_function,
+                                       orientation=orientation,
+                                       custom_labels=custom_labels)
+    else:
+        annotations = None
+
+    # Wrap the distance function to validate its input and return a condensed distance matrix
+    distance_function_viz = lambda x: validate_distance_matrix(
+        distance_function(x), embeddings.shape[0])
+    # Create dendrogram
+    fig = ff.create_dendrogram(embeddings,
+                               orientation=orientation,
+                               distfun=distance_function_viz,
+                               linkagefun=linkage_function,
+                               hovertext=annotations,
+                               color_threshold=color_threshold)
+
+    # Create nicer labels
+    axis = "yaxis" if orientation == "left" else "xaxis"
+    if isinstance(custom_labels, str):
+        new_labels = [[[str(topics[int(x)]), None]] + topic_model.topic_aspects_[custom_labels][topics[int(x)]]
+                      for x in fig.layout[axis]["ticktext"]]
+        new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
+        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
+    elif topic_model.custom_labels_ is not None and custom_labels:
+        new_labels = [topic_model.custom_labels_[topics[int(x)] + topic_model._outliers] for x in fig.layout[axis]["ticktext"]]
+    else:
+        new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)])
+                      for x in fig.layout[axis]["ticktext"]]
+        new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
+        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
+
+    # Stylize layout
+    fig.update_layout(
+        plot_bgcolor='#ECEFF1',
+        template="plotly_white",
+        title={
+            'text': f"{title}",
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        hoverlabel=dict(
+            bgcolor="white",
+            font_size=16,
+            font_family="Rockwell"
+        ),
+    )
+
+    # Stylize orientation
+    if orientation == "left":
+        fig.update_layout(height=200 + (15 * len(topics)),
+                          width=width,
+                          yaxis=dict(tickmode="array",
+                                     ticktext=new_labels))
+
+        # Fix empty space on the bottom of the graph
+        y_max = max([trace['y'].max() + 5 for trace in fig['data']])
+        y_min = min([trace['y'].min() - 5 for trace in fig['data']])
+        fig.update_layout(yaxis=dict(range=[y_min, y_max]))
+
+    else:
+        fig.update_layout(width=200 + (15 * len(topics)),
+                          height=height,
+                          xaxis=dict(tickmode="array",
+                                     ticktext=new_labels))
+
+    if hierarchical_topics is not None:
+        for index in [0, 3]:
+            axis = "x" if orientation == "left" else "y"
+            xs = [data["x"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
+            ys = [data["y"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
+            hovertext = [data["text"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
+
+            fig.add_trace(go.Scatter(x=xs, y=ys, marker_color='black',
+                                     hovertext=hovertext, hoverinfo="text",
+                                     mode='markers', showlegend=False))
+    return fig
+
+
+def _get_annotations(topic_model,
+                     hierarchical_topics: pd.DataFrame,
+                     embeddings: csr_matrix,
+                     linkage_function: Callable[[csr_matrix], np.ndarray],
+                     distance_function: Callable[[csr_matrix], csr_matrix],
+                     orientation: str,
+                     custom_labels: bool = False) -> List[List[str]]:
+    """ Get annotations by replicating the linkage function calculation in scipy
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        hierarchical_topics: A dataframe that contains a hierarchy of topics
+                             represented by their parents and their children.
+                             NOTE: The hierarchical topic names are only visualized
+                             if both `topics` and `top_n_topics` are not set.
+        embeddings: The c-TF-IDF matrix on which to model the hierarchy
+        linkage_function: The linkage function to use. Default is:
+                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
+                          NOTE: Make sure to use the same `linkage_function` as used
+                          in `topic_model.hierarchical_topics`.
+        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
+                           `lambda x: 1 - cosine_similarity(x)`.
+                           You can pass any function that returns either a square matrix of
+                           shape (n_samples, n_samples) with zeros on the diagonal and
+                           non-negative values or a condensed distance matrix of shape
+                           (n_samples * (n_samples - 1) / 2,) containing the upper
+                           triangular part of the distance matrix.
+                           NOTE: Make sure to use the same `distance_function` as used
+                           in `topic_model.hierarchical_topics`.
+        orientation: The orientation of the figure.
+                     Either 'left' or 'bottom'
+        custom_labels: Whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       NOTE: Custom labels are only generated for the original
+                       un-merged topics.
+
+    Returns:
+        text_annotations: Annotations to be used within Plotly's `ff.create_dendrogram`
+    """
+    df = hierarchical_topics.loc[hierarchical_topics.Parent_Name != "Top", :]
+
+    # Calculate distance
+    X = distance_function(embeddings)
+    X = validate_distance_matrix(X, embeddings.shape[0])
+
+    # Calculate linkage and generate dendrogram
+    Z = linkage_function(X)
+    P = sch.dendrogram(Z, orientation=orientation, no_plot=True)
+
+    # Store topic numbers (leaves) corresponding to the x-ticks in the dendrogram
+    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)
+    x_topic = dict(zip(P['leaves'], x_ticks))
+
+    topic_vals = dict()
+    for key, val in x_topic.items():
+        topic_vals[val] = [key]
+
+    parent_topic = dict(zip(df.Parent_ID, df.Topics))
+
+    # Loop through every trace (scatter plot) in the dendrogram
+    text_annotations = []
+    for index, trace in enumerate(P['icoord']):
+        fst_topic = topic_vals[trace[0]]
+        scnd_topic = topic_vals[trace[2]]
+
+        if len(fst_topic) == 1:
+            if isinstance(custom_labels, str):
+                fst_name = f"{fst_topic[0]}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][fst_topic[0]]))[0][:3])
+            elif topic_model.custom_labels_ is not None and custom_labels:
+                fst_name = topic_model.custom_labels_[fst_topic[0] + topic_model._outliers]
+            else:
+                fst_name = "_".join([word for word, _ in topic_model.get_topic(fst_topic[0])][:5])
+        else:
+            for key, value in parent_topic.items():
+                if set(value) == set(fst_topic):
+                    fst_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0]
+
+        if len(scnd_topic) == 1:
+            if isinstance(custom_labels, str):
+                scnd_name = f"{scnd_topic[0]}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][scnd_topic[0]]))[0][:3])
+            elif topic_model.custom_labels_ is not None and custom_labels:
+                scnd_name = topic_model.custom_labels_[scnd_topic[0] + topic_model._outliers]
+            else:
+                scnd_name = "_".join([word for word, _ in topic_model.get_topic(scnd_topic[0])][:5])
+        else:
+            for key, value in parent_topic.items():
+                if set(value) == set(scnd_topic):
+                    scnd_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0]
+
+        text_annotations.append([fst_name, "", "", scnd_name])
+
+        center = (trace[0] + trace[2]) / 2
+        topic_vals[center] = fst_topic + scnd_topic
+
+    return text_annotations
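Since the docstring stresses reusing the same linkage and distance functions as in `topic_model.hierarchical_topics`, a sketch of that pattern (the 'single' linkage choice is illustrative):

```python
from scipy.cluster import hierarchy as sch
from sklearn.datasets import fetch_20newsgroups
from bertopic import BERTopic

docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))["data"]
topic_model = BERTopic().fit(docs)

# Use one linkage function for both the hierarchy and its visualization
linkage_function = lambda x: sch.linkage(x, "single", optimal_ordering=True)
hierarchical_topics = topic_model.hierarchical_topics(docs, linkage_function=linkage_function)
fig = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics,
                                      linkage_function=linkage_function)
fig.write_html("hierarchy.html")  # illustrative output path
```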
BERTopic/bertopic/plotting/_term_rank.py ADDED
@@ -0,0 +1,135 @@
+import numpy as np
+from typing import List, Union
+import plotly.graph_objects as go
+
+
+def visualize_term_rank(topic_model,
+                        topics: List[int] = None,
+                        log_scale: bool = False,
+                        custom_labels: Union[bool, str] = False,
+                        title: str = "<b>Term score decline per Topic</b>",
+                        width: int = 800,
+                        height: int = 500) -> go.Figure:
+    """ Visualize the ranks of all terms across all topics
+
+    Each topic is represented by a set of words. These words, however,
+    do not all equally represent the topic. This visualization shows
+    how many words are needed to represent a topic and at which point
+    the beneficial effect of adding words starts to decline.
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        topics: A selection of topics to visualize. These will be colored
+                red while all others will be colored black.
+        log_scale: Whether to represent the ranking on a log scale
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+        title: Title of the plot.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Returns:
+        fig: A plotly figure
+
+    Examples:
+
+    To visualize the ranks of all words across
+    all topics simply run:
+
+    ```python
+    topic_model.visualize_term_rank()
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_term_rank()
+    fig.write_html("path/to/file.html")
+    ```
+
+    <iframe src="../../getting_started/visualization/term_rank.html"
+    style="width:1000px; height: 530px; border: 0px;"></iframe>
+
+    <iframe src="../../getting_started/visualization/term_rank_log.html"
+    style="width:1000px; height: 530px; border: 0px;"></iframe>
+
+    Reference:
+
+    This visualization was heavily inspired by the
+    "Term Probability Decline" visualization found in an
+    analysis by the amazing [tmtoolkit](https://tmtoolkit.readthedocs.io/).
+    Reference to that specific analysis can be found
+    [here](https://wzbsocialsciencecenter.github.io/tm_corona/tm_analysis.html).
+    """
+
+    topics = [] if topics is None else topics
+
+    topic_ids = topic_model.get_topic_info().Topic.unique().tolist()
+    topic_words = [topic_model.get_topic(topic) for topic in topic_ids]
+
+    values = np.array([[value[1] for value in values] for values in topic_words])
+    indices = np.array([[value + 1 for value in range(len(values))] for values in topic_words])
+
+    # Create figure
+    lines = []
+    for topic, x, y in zip(topic_ids, indices, values):
+        if not any(y > 1.5):
+
+            # Labels
+            if isinstance(custom_labels, str):
+                label = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3])
+            elif topic_model.custom_labels_ is not None and custom_labels:
+                label = topic_model.custom_labels_[topic + topic_model._outliers]
+            else:
+                label = f"<b>Topic {topic}</b>:" + "_".join([word[0] for word in topic_model.get_topic(topic)])
+                label = label[:50]
+
+            # Line parameters
+            color = "red" if topic in topics else "black"
+            opacity = 1 if topic in topics else .1
+            if any(y == 0):
+                y[y == 0] = min(values[values > 0])
+            y = np.log10(y, out=y, where=y > 0) if log_scale else y
+
+            line = go.Scatter(x=x, y=y,
+                              name="",
+                              hovertext=label,
+                              mode="lines",
+                              opacity=opacity,
+                              line=dict(color=color, width=1.5))
+            lines.append(line)
+
+    fig = go.Figure(data=lines)
+
+    # Stylize layout
+    fig.update_xaxes(range=[0, len(indices[0])], tick0=1, dtick=2)
+    fig.update_layout(
+        showlegend=False,
+        template="plotly_white",
+        title={
+            'text': f"{title}",
+            'y': .9,
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        width=width,
+        height=height,
+        hoverlabel=dict(
+            bgcolor="white",
+            font_size=16,
+            font_family="Rockwell"
+        ),
+    )
+
+    fig.update_xaxes(title_text='Term Rank')
+    if log_scale:
+        fig.update_yaxes(title_text='c-TF-IDF score (log scale)')
+    else:
+        fig.update_yaxes(title_text='c-TF-IDF score')
+
+    return fig
BERTopic/bertopic/plotting/_topics.py ADDED
@@ -0,0 +1,162 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from umap import UMAP
4
+ from typing import List, Union
5
+ from sklearn.preprocessing import MinMaxScaler
6
+
7
+ import plotly.express as px
8
+ import plotly.graph_objects as go
9
+
10
+
11
+ def visualize_topics(topic_model,
12
+ topics: List[int] = None,
13
+ top_n_topics: int = None,
14
+ custom_labels: Union[bool, str] = False,
15
+ title: str = "<b>Intertopic Distance Map</b>",
16
+ width: int = 650,
17
+ height: int = 650) -> go.Figure:
18
+ """ Visualize topics, their sizes, and their corresponding words
19
+
20
+ This visualization is highly inspired by LDAvis, a great visualization
21
+ technique typically reserved for LDA.
22
+
23
+ Arguments:
24
+ topic_model: A fitted BERTopic instance.
25
+ topics: A selection of topics to visualize
26
+ top_n_topics: Only select the top n most frequent topics
27
+ custom_labels: If bool, whether to use custom topic labels that were defined using
28
+ `topic_model.set_topic_labels`.
29
+ If `str`, it uses labels from other aspects, e.g., "Aspect1".
30
+ title: Title of the plot.
31
+ width: The width of the figure.
32
+ height: The height of the figure.
33
+
34
+ Examples:
35
+
36
+ To visualize the topics simply run:
37
+
38
+ ```python
39
+ topic_model.visualize_topics()
40
+ ```
41
+
42
+ Or if you want to save the resulting figure:
43
+
44
+ ```python
45
+ fig = topic_model.visualize_topics()
46
+ fig.write_html("path/to/file.html")
47
+ ```
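+
+ If custom topic labels were set through `topic_model.set_topic_labels`,
+ they can be shown instead of the default keyword-based labels:
+
+ ```python
+ topic_model.visualize_topics(custom_labels=True)
+ ```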
48
+ <iframe src="../../getting_started/visualization/viz.html"
49
+ style="width:1000px; height: 680px; border: 0px;""></iframe>
50
+ """
51
+ # Select topics based on top_n and topics args
52
+ freq_df = topic_model.get_topic_freq()
53
+ freq_df = freq_df.loc[freq_df.Topic != -1, :]
54
+ if topics is not None:
55
+ topics = list(topics)
56
+ elif top_n_topics is not None:
57
+ topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
58
+ else:
59
+ topics = sorted(freq_df.Topic.to_list())
60
+
61
+ # Extract topic words and their frequencies
62
+ topic_list = sorted(topics)
63
+ frequencies = [topic_model.topic_sizes_[topic] for topic in topic_list]
64
+ if isinstance(custom_labels, str):
65
+ words = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topic_list]
66
+ words = ["_".join([label[0] for label in labels[:4]]) for labels in words]
67
+ words = [label if len(label) < 30 else label[:27] + "..." for label in words]
68
+ elif custom_labels and topic_model.custom_labels_ is not None:
69
+ words = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topic_list]
70
+ else:
71
+ words = [" | ".join([word[0] for word in topic_model.get_topic(topic)[:5]]) for topic in topic_list]
72
+
73
+ # Embed c-TF-IDF into 2D
74
+ all_topics = sorted(list(topic_model.get_topics().keys()))
75
+ indices = np.array([all_topics.index(topic) for topic in topics])
76
+
77
+ if topic_model.topic_embeddings_ is not None:
78
+ embeddings = topic_model.topic_embeddings_[indices]
79
+ embeddings = UMAP(n_neighbors=2, n_components=2, metric='cosine', random_state=42).fit_transform(embeddings)
80
+ else:
81
+ embeddings = topic_model.c_tf_idf_.toarray()[indices]
82
+ embeddings = MinMaxScaler().fit_transform(embeddings)
83
+ embeddings = UMAP(n_neighbors=2, n_components=2, metric='hellinger', random_state=42).fit_transform(embeddings)
84
+
85
+ # Visualize with plotly
86
+ df = pd.DataFrame({"x": embeddings[:, 0], "y": embeddings[:, 1],
87
+ "Topic": topic_list, "Words": words, "Size": frequencies})
88
+ return _plotly_topic_visualization(df, topic_list, title, width, height)
89
+
90
+
91
+ def _plotly_topic_visualization(df: pd.DataFrame,
92
+ topic_list: List[str],
93
+ title: str,
94
+ width: int,
95
+ height: int):
96
+ """ Create plotly-based visualization of topics with a slider for topic selection """
97
+
98
+ def get_color(topic_selected):
99
+ if topic_selected == -1:
100
+ marker_color = ["#B0BEC5" for _ in topic_list]
101
+ else:
102
+ marker_color = ["red" if topic == topic_selected else "#B0BEC5" for topic in topic_list]
103
+ return [{'marker.color': [marker_color]}]
104
+
105
+ # Prepare figure range
106
+ x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
107
+ y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
108
+
109
+ # Plot topics
110
+ fig = px.scatter(df, x="x", y="y", size="Size", size_max=40, template="simple_white", labels={"x": "", "y": ""},
111
+ hover_data={"Topic": True, "Words": True, "Size": True, "x": False, "y": False})
112
+ fig.update_traces(marker=dict(color="#B0BEC5", line=dict(width=2, color='DarkSlateGrey')))
113
+
114
+ # Update hover order
115
+ fig.update_traces(hovertemplate="<br>".join(["<b>Topic %{customdata[0]}</b>",
116
+ "%{customdata[1]}",
117
+ "Size: %{customdata[2]}"]))
118
+
119
+ # Create a slider for topic selection
120
+ steps = [dict(label=f"Topic {topic}", method="update", args=get_color(topic)) for topic in topic_list]
121
+ sliders = [dict(active=0, pad={"t": 50}, steps=steps)]
122
+
123
+ # Stylize layout
124
+ fig.update_layout(
125
+ title={
126
+ 'text': f"{title}",
127
+ 'y': .95,
128
+ 'x': 0.5,
129
+ 'xanchor': 'center',
130
+ 'yanchor': 'top',
131
+ 'font': dict(
132
+ size=22,
133
+ color="Black")
134
+ },
135
+ width=width,
136
+ height=height,
137
+ hoverlabel=dict(
138
+ bgcolor="white",
139
+ font_size=16,
140
+ font_family="Rockwell"
141
+ ),
142
+ xaxis={"visible": False},
143
+ yaxis={"visible": False},
144
+ sliders=sliders
145
+ )
146
+
147
+ # Update axes ranges
148
+ fig.update_xaxes(range=x_range)
149
+ fig.update_yaxes(range=y_range)
150
+
151
+ # Add grid in a 'plus' shape
152
+ fig.add_shape(type="line",
153
+ x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
154
+ line=dict(color="#CFD8DC", width=2))
155
+ fig.add_shape(type="line",
156
+ x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
157
+ line=dict(color="#9E9E9E", width=2))
158
+ fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
159
+ fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
160
+ fig.data = fig.data[::-1]
161
+
162
+ return fig
BERTopic/bertopic/plotting/_topics_over_time.py ADDED
@@ -0,0 +1,123 @@
1
+ import pandas as pd
2
+ from typing import List, Union
3
+ import plotly.graph_objects as go
4
+ from sklearn.preprocessing import normalize
5
+
6
+
7
+ def visualize_topics_over_time(topic_model,
8
+ topics_over_time: pd.DataFrame,
9
+ top_n_topics: int = None,
10
+ topics: List[int] = None,
11
+ normalize_frequency: bool = False,
12
+ custom_labels: Union[bool, str] = False,
13
+ title: str = "<b>Topics over Time</b>",
14
+ width: int = 1250,
15
+ height: int = 450) -> go.Figure:
16
+ """ Visualize topics over time
17
+
18
+ Arguments:
19
+ topic_model: A fitted BERTopic instance.
20
+ topics_over_time: The topics you would like to be visualized with the
21
+ corresponding topic representation
22
+ top_n_topics: To visualize the most frequent topics instead of all
23
+ topics: Select which topics you would like to be visualized
24
+ normalize_frequency: Whether to normalize each topic's frequency individually
25
+ custom_labels: If bool, whether to use custom topic labels that were defined using
26
+ `topic_model.set_topic_labels`.
27
+ If `str`, it uses labels from other aspects, e.g., "Aspect1".
28
+ title: Title of the plot.
29
+ width: The width of the figure.
30
+ height: The height of the figure.
31
+
32
+ Returns:
33
+ A plotly.graph_objects.Figure including all traces
34
+
35
+ Examples:
36
+
37
+ To visualize the topics over time, simply run:
38
+
39
+ ```python
40
+ topics_over_time = topic_model.topics_over_time(docs, timestamps)
41
+ topic_model.visualize_topics_over_time(topics_over_time)
42
+ ```
43
+
44
+ Or if you want to save the resulting figure:
45
+
46
+ ```python
47
+ fig = topic_model.visualize_topics_over_time(topics_over_time)
48
+ fig.write_html("path/to/file.html")
49
+ ```
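+
+ To keep the figure readable, you can restrict it to the most
+ frequent topics and normalize each topic's frequency:
+
+ ```python
+ topic_model.visualize_topics_over_time(topics_over_time, top_n_topics=10, normalize_frequency=True)
+ ```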
50
+ <iframe src="../../getting_started/visualization/trump.html"
51
+ style="width:1000px; height: 680px; border: 0px;""></iframe>
52
+ """
53
+ colors = ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#D55E00", "#0072B2", "#CC79A7"]
54
+
55
+ # Select topics based on top_n and topics args
56
+ freq_df = topic_model.get_topic_freq()
57
+ freq_df = freq_df.loc[freq_df.Topic != -1, :]
58
+ if topics is not None:
59
+ selected_topics = list(topics)
60
+ elif top_n_topics is not None:
61
+ selected_topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
62
+ else:
63
+ selected_topics = sorted(freq_df.Topic.to_list())
64
+
65
+ # Prepare data
66
+ if isinstance(custom_labels, str):
67
+ topic_names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in selected_topics]
68
+ topic_names = ["_".join([label[0] for label in labels[:4]]) for labels in topic_names]
69
+ topic_names = [label if len(label) < 30 else label[:27] + "..." for label in topic_names]
70
+ topic_names = {key: topic_names[index] for index, key in enumerate(selected_topics)}
71
+ elif topic_model.custom_labels_ is not None and custom_labels:
72
+ topic_names = {key: topic_model.custom_labels_[key + topic_model._outliers] for key, _ in topic_model.topic_labels_.items()}
73
+ else:
74
+ topic_names = {key: value[:40] + "..." if len(value) > 40 else value
75
+ for key, value in topic_model.topic_labels_.items()}
76
+ topics_over_time["Name"] = topics_over_time.Topic.map(topic_names)
77
+ data = topics_over_time.loc[topics_over_time.Topic.isin(selected_topics), :].sort_values(["Topic", "Timestamp"])
78
+
79
+ # Add traces
80
+ fig = go.Figure()
81
+ for index, topic in enumerate(data.Topic.unique()):
82
+ trace_data = data.loc[data.Topic == topic, :]
83
+ topic_name = trace_data.Name.values[0]
84
+ words = trace_data.Words.values
85
+ if normalize_frequency:
86
+ y = normalize(trace_data.Frequency.values.reshape(1, -1))[0]
87
+ else:
88
+ y = trace_data.Frequency
89
+ fig.add_trace(go.Scatter(x=trace_data.Timestamp, y=y,
90
+ mode='lines',
91
+ marker_color=colors[index % 7],
92
+ hoverinfo="text",
93
+ name=topic_name,
94
+ hovertext=[f'<b>Topic {topic}</b><br>Words: {word}' for word in words]))
95
+
96
+ # Styling of the visualization
97
+ fig.update_xaxes(showgrid=True)
98
+ fig.update_yaxes(showgrid=True)
99
+ fig.update_layout(
100
+ yaxis_title="Normalized Frequency" if normalize_frequency else "Frequency",
101
+ title={
102
+ 'text': f"{title}",
103
+ 'y': .95,
104
+ 'x': 0.40,
105
+ 'xanchor': 'center',
106
+ 'yanchor': 'top',
107
+ 'font': dict(
108
+ size=22,
109
+ color="Black")
110
+ },
111
+ template="simple_white",
112
+ width=width,
113
+ height=height,
114
+ hoverlabel=dict(
115
+ bgcolor="white",
116
+ font_size=16,
117
+ font_family="Rockwell"
118
+ ),
119
+ legend=dict(
120
+ title="<b>Global Topic Representation",
121
+ )
122
+ )
123
+ return fig
BERTopic/bertopic/plotting/_topics_per_class.py ADDED
@@ -0,0 +1,130 @@
1
+ import pandas as pd
2
+ from typing import List, Union
3
+ import plotly.graph_objects as go
4
+ from sklearn.preprocessing import normalize
5
+
6
+
7
+ def visualize_topics_per_class(topic_model,
8
+ topics_per_class: pd.DataFrame,
9
+ top_n_topics: int = 10,
10
+ topics: List[int] = None,
11
+ normalize_frequency: bool = False,
12
+ custom_labels: Union[bool, str] = False,
13
+ title: str = "<b>Topics per Class</b>",
14
+ width: int = 1250,
15
+ height: int = 900) -> go.Figure:
16
+ """ Visualize topics per class
17
+
18
+ Arguments:
19
+ topic_model: A fitted BERTopic instance.
20
+ topics_per_class: The topics you would like to be visualized with the
21
+ corresponding topic representation
22
+ top_n_topics: To visualize the most frequent topics instead of all
23
+ topics: Select which topics you would like to be visualized
24
+ normalize_frequency: Whether to normalize each topic's frequency individually
25
+ custom_labels: If bool, whether to use custom topic labels that were defined using
26
+ `topic_model.set_topic_labels`.
27
+ If `str`, it uses labels from other aspects, e.g., "Aspect1".
28
+ title: Title of the plot.
29
+ width: The width of the figure.
30
+ height: The height of the figure.
31
+
32
+ Returns:
33
+ A plotly.graph_objects.Figure including all traces
34
+
35
+ Examples:
36
+
37
+ To visualize the topics per class, simply run:
38
+
39
+ ```python
40
+ topics_per_class = topic_model.topics_per_class(docs, classes)
41
+ topic_model.visualize_topics_per_class(topics_per_class)
42
+ ```
43
+
44
+ Or if you want to save the resulting figure:
45
+
46
+ ```python
47
+ fig = topic_model.visualize_topics_per_class(topics_per_class)
48
+ fig.write_html("path/to/file.html")
49
+ ```
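+
+ To compare a hand-picked selection of topics across classes on a
+ normalized scale:
+
+ ```python
+ topic_model.visualize_topics_per_class(topics_per_class, topics=[1, 2, 3], normalize_frequency=True)
+ ```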
50
+ <iframe src="../../getting_started/visualization/topics_per_class.html"
51
+ style="width:1400px; height: 1000px; border: 0px;""></iframe>
52
+ """
53
+ colors = ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#D55E00", "#0072B2", "#CC79A7"]
54
+
55
+ # Select topics based on top_n and topics args
56
+ freq_df = topic_model.get_topic_freq()
57
+ freq_df = freq_df.loc[freq_df.Topic != -1, :]
58
+ if topics is not None:
59
+ selected_topics = list(topics)
60
+ elif top_n_topics is not None:
61
+ selected_topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
62
+ else:
63
+ selected_topics = sorted(freq_df.Topic.to_list())
64
+
65
+ # Prepare data
66
+ if isinstance(custom_labels, str):
67
+ topic_names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in selected_topics]
68
+ topic_names = ["_".join([label[0] for label in labels[:4]]) for labels in topic_names]
69
+ topic_names = [label if len(label) < 30 else label[:27] + "..." for label in topic_names]
70
+ topic_names = {key: topic_names[index] for index, key in enumerate(selected_topics)}
71
+ elif topic_model.custom_labels_ is not None and custom_labels:
72
+ topic_names = {key: topic_model.custom_labels_[key + topic_model._outliers] for key, _ in topic_model.topic_labels_.items()}
73
+ else:
74
+ topic_names = {key: value[:40] + "..." if len(value) > 40 else value
75
+ for key, value in topic_model.topic_labels_.items()}
76
+ topics_per_class["Name"] = topics_per_class.Topic.map(topic_names)
77
+ data = topics_per_class.loc[topics_per_class.Topic.isin(selected_topics), :]
78
+
79
+ # Add traces
80
+ fig = go.Figure()
81
+ for index, topic in enumerate(selected_topics):
82
+ if index == 0:
83
+ visible = True
84
+ else:
85
+ visible = "legendonly"
86
+ trace_data = data.loc[data.Topic == topic, :]
87
+ topic_name = trace_data.Name.values[0]
88
+ words = trace_data.Words.values
89
+ if normalize_frequency:
90
+ x = normalize(trace_data.Frequency.values.reshape(1, -1))[0]
91
+ else:
92
+ x = trace_data.Frequency
93
+ fig.add_trace(go.Bar(y=trace_data.Class,
94
+ x=x,
95
+ visible=visible,
96
+ marker_color=colors[index % 7],
97
+ hoverinfo="text",
98
+ name=topic_name,
99
+ orientation="h",
100
+ hovertext=[f'<b>Topic {topic}</b><br>Words: {word}' for word in words]))
101
+
102
+ # Styling of the visualization
103
+ fig.update_xaxes(showgrid=True)
104
+ fig.update_yaxes(showgrid=True)
105
+ fig.update_layout(
106
+ xaxis_title="Normalized Frequency" if normalize_frequency else "Frequency",
107
+ yaxis_title="Class",
108
+ title={
109
+ 'text': f"{title}",
110
+ 'y': .95,
111
+ 'x': 0.40,
112
+ 'xanchor': 'center',
113
+ 'yanchor': 'top',
114
+ 'font': dict(
115
+ size=22,
116
+ color="Black")
117
+ },
118
+ template="simple_white",
119
+ width=width,
120
+ height=height,
121
+ hoverlabel=dict(
122
+ bgcolor="white",
123
+ font_size=16,
124
+ font_family="Rockwell"
125
+ ),
126
+ legend=dict(
127
+ title="<b>Global Topic Representation",
128
+ )
129
+ )
130
+ return fig
BERTopic/bertopic/representation/__init__.py ADDED
@@ -0,0 +1,68 @@
1
+ from bertopic._utils import NotInstalled
2
+ from bertopic.representation._cohere import Cohere
3
+ from bertopic.representation._base import BaseRepresentation
4
+ from bertopic.representation._keybert import KeyBERTInspired
5
+ from bertopic.representation._mmr import MaximalMarginalRelevance
6
+
7
+
8
+ # Llama CPP Generator
9
+ try:
10
+ from bertopic.representation._llamacpp import LlamaCPP
11
+ except ModuleNotFoundError:
12
+ msg = "`pip install llama-cpp-python` \n\n"
13
+ LlamaCPP = NotInstalled("llama.cpp", "llama-cpp-python", custom_msg=msg)
14
+
15
+ # Text Generation using transformers
16
+ try:
17
+ from bertopic.representation._textgeneration import TextGeneration
18
+ except ModuleNotFoundError:
19
+ msg = "`pip install bertopic` without `--no-deps` \n\n"
20
+ TextGeneration = NotInstalled("TextGeneration", "transformers", custom_msg=msg)
21
+
22
+ # Zero-shot classification using transformers
23
+ try:
24
+ from bertopic.representation._zeroshot import ZeroShotClassification
25
+ except ModuleNotFoundError:
26
+ msg = "`pip install bertopic` without `--no-deps` \n\n"
27
+ ZeroShotClassification = NotInstalled("ZeroShotClassification", "transformers", custom_msg=msg)
28
+
29
+ # OpenAI Generator
30
+ try:
31
+ from bertopic.representation._openai import OpenAI
32
+ except ModuleNotFoundError:
33
+ msg = "`pip install openai` \n\n"
34
+ OpenAI = NotInstalled("OpenAI", "openai", custom_msg=msg)
35
+
36
+ # LangChain Generator
37
+ try:
38
+ from bertopic.representation._langchain import LangChain
39
+ except ModuleNotFoundError:
40
+ msg = "`pip install langchain` \n\n"
41
+ LangChain = NotInstalled("langchain", "langchain", custom_msg=msg)
42
+
43
+ # POS using Spacy
44
+ try:
45
+ from bertopic.representation._pos import PartOfSpeech
46
+ except ModuleNotFoundError:
47
+ PartOfSpeech = NotInstalled("Part of Speech with Spacy", "spacy")
48
+
49
+ # Multimodal
50
+ try:
51
+ from bertopic.representation._visual import VisualRepresentation
52
+ except ModuleNotFoundError:
53
+ VisualRepresentation = NotInstalled("a visual representation model", "vision")
54
+
55
+
56
+ __all__ = [
57
+ "BaseRepresentation",
58
+ "TextGeneration",
59
+ "ZeroShotClassification",
60
+ "KeyBERTInspired",
61
+ "PartOfSpeech",
62
+ "MaximalMarginalRelevance",
63
+ "Cohere",
64
+ "OpenAI",
65
+ "LangChain",
66
+ "LlamaCPP",
67
+ "VisualRepresentation"
68
+ ]
BERTopic/bertopic/representation/_base.py ADDED
@@ -0,0 +1,38 @@
1
+ import pandas as pd
2
+ from scipy.sparse import csr_matrix
3
+ from sklearn.base import BaseEstimator
4
+ from typing import Mapping, List, Tuple
5
+
6
+
7
+ class BaseRepresentation(BaseEstimator):
8
+ """ The base representation model for fine-tuning topic representations """
9
+ def extract_topics(self,
10
+ topic_model,
11
+ documents: pd.DataFrame,
12
+ c_tf_idf: csr_matrix,
13
+ topics: Mapping[str, List[Tuple[str, float]]]
14
+ ) -> Mapping[str, List[Tuple[str, float]]]:
15
+ """ Extract topics
16
+
17
+ Each representation model that inherits this class will have
18
+ its arguments (topic_model, documents, c_tf_idf, topics)
19
+ automatically passed. Therefore, the representation model
20
+ will only have access to the information about topics related
21
+ to those arguments.
22
+
23
+ Arguments:
24
+ topic_model: The BERTopic model that is fitted until topic
25
+ representations are calculated.
26
+ documents: A dataframe with columns "Document" and "Topic"
27
+ that contains all documents with each corresponding
28
+ topic.
29
+ c_tf_idf: A c-TF-IDF representation that is typically
30
+ identical to `topic_model.c_tf_idf_` except for
31
+ dynamic, class-based, and hierarchical topic modeling
32
+ where it is calculated on a subset of the documents.
33
+ topics: A dictionary with topic (key) and tuple of word and
34
+ weight (value) as calculated by c-TF-IDF. These are the
35
+ default topics returned if no representation
36
+ model is used.
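+
+ Usage:
+
+ A minimal sketch of a custom representation model built on this
+ base class; the uppercasing step is purely illustrative:
+
+ ```python
+ from bertopic.representation import BaseRepresentation
+
+ class UpperCaseRepresentation(BaseRepresentation):
+     def extract_topics(self, topic_model, documents, c_tf_idf, topics):
+         # Return the candidate topics with every word uppercased
+         return {topic: [(word.upper(), value) for word, value in values]
+                 for topic, values in topics.items()}
+ ```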
37
+ """
38
+ return topic_model.topic_representations_
BERTopic/bertopic/representation/_cohere.py ADDED
@@ -0,0 +1,193 @@
1
+ import time
2
+ import pandas as pd
3
+ from tqdm import tqdm
4
+ from scipy.sparse import csr_matrix
5
+ from typing import Mapping, List, Tuple, Union, Callable
6
+ from bertopic.representation._base import BaseRepresentation
7
+ from bertopic.representation._utils import truncate_document
8
+
9
+
10
+ DEFAULT_PROMPT = """
11
+ This is a list of texts where each collection of texts describes a topic. After each collection of texts, the name of the topic they represent is mentioned as a short, highly descriptive title
12
+ ---
13
+ Topic:
14
+ Sample texts from this topic:
15
+ - Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
16
+ - Meat, but especially beef, is the worst food in terms of emissions.
17
+ - Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
18
+
19
+ Keywords: meat beef eat eating emissions steak food health processed chicken
20
+ Topic name: Environmental impacts of eating meat
21
+ ---
22
+ Topic:
23
+ Sample texts from this topic:
24
+ - I have ordered the product weeks ago but it still has not arrived!
25
+ - The website mentions that it only takes a couple of days to deliver but I still have not received mine.
26
+ - I got a message stating that I received the monitor but that is not true!
27
+ - It took a month longer to deliver than was advised...
28
+
29
+ Keywords: deliver weeks product shipping long delivery received arrived arrive week
30
+ Topic name: Shipping and delivery issues
31
+ ---
32
+ Topic:
33
+ Sample texts from this topic:
34
+ [DOCUMENTS]
35
+ Keywords: [KEYWORDS]
36
+ Topic name:"""
37
+
38
+
39
+ class Cohere(BaseRepresentation):
40
+ """ Use the Cohere API to generate topic labels based on their
41
+ generative model.
42
+
43
+ Find more about their models here:
44
+ https://docs.cohere.ai/docs
45
+
46
+ Arguments:
47
+ client: A `cohere.Client`
48
+ model: Model to use within Cohere, defaults to `"xlarge"`.
49
+ prompt: The prompt to be used in the model. If no prompt is given,
50
+ `self.default_prompt_` is used instead.
51
+ NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
52
+ to decide where the keywords and documents need to be
53
+ inserted.
54
+ delay_in_seconds: The delay in seconds between consecutive prompts
55
+ in order to prevent RateLimitErrors.
56
+ nr_docs: The number of documents to pass to Cohere if a prompt
57
+ with the `["DOCUMENTS"]` tag is used.
58
+ diversity: The diversity of documents to pass to Cohere.
59
+ Accepts values between 0 and 1. A higher
60
+ value results in passing more diverse documents,
61
+ whereas a lower value passes more similar documents.
62
+ doc_length: The maximum length of each document. If a document is longer,
63
+ it will be truncated. If None, the entire document is passed.
64
+ tokenizer: The tokenizer used to split the document into segments
65
+ that are counted to determine a document's length.
66
+ * If tokenizer is 'char', then the document is split up
67
+ into characters which are counted to adhere to `doc_length`
68
+ * If tokenizer is 'whitespace', the document is split up
69
+ into words separated by whitespaces. These words are counted
70
+ and truncated depending on `doc_length`
71
+ * If tokenizer is 'vectorizer', then the internal CountVectorizer
72
+ is used to tokenize the document. These tokens are counted
73
+ and truncated depending on `doc_length`
74
+ * If tokenizer is a callable, then that callable is used to tokenize
75
+ the document. These tokens are counted and truncated depending
76
+ on `doc_length`
77
+
78
+ Usage:
79
+
80
+ To use this, you will need to install cohere first:
81
+
82
+ `pip install cohere`
83
+
84
+ Then, get yourself an API key and use Cohere's API as follows:
85
+
86
+ ```python
87
+ import cohere
88
+ from bertopic.representation import Cohere
89
+ from bertopic import BERTopic
90
+
91
+ # Create your representation model
92
+ co = cohere.Client(my_api_key)
93
+ representation_model = Cohere(co)
94
+
95
+ # Use the representation model in BERTopic on top of the default pipeline
96
+ topic_model = BERTopic(representation_model=representation_model)
97
+ ```
98
+
99
+ You can also use a custom prompt:
100
+
101
+ ```python
102
+ prompt = "I have the following documents: [DOCUMENTS]. What topic do they contain?"
103
+ representation_model = Cohere(co, prompt=prompt)
104
+ ```
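+
+ Long documents can blow up the prompt; `doc_length` and `tokenizer`
+ control how each document is truncated before it is inserted:
+
+ ```python
+ # Truncate every document to 100 whitespace-separated words
+ representation_model = Cohere(co, doc_length=100, tokenizer="whitespace")
+ ```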
105
+ """
106
+ def __init__(self,
107
+ client,
108
+ model: str = "xlarge",
109
+ prompt: str = None,
110
+ delay_in_seconds: float = None,
111
+ nr_docs: int = 4,
112
+ diversity: float = None,
113
+ doc_length: int = None,
114
+ tokenizer: Union[str, Callable] = None
115
+ ):
116
+ self.client = client
117
+ self.model = model
118
+ self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
119
+ self.default_prompt_ = DEFAULT_PROMPT
120
+ self.delay_in_seconds = delay_in_seconds
121
+ self.nr_docs = nr_docs
122
+ self.diversity = diversity
123
+ self.doc_length = doc_length
124
+ self.tokenizer = tokenizer
125
+ self.prompts_ = []
126
+
127
+ def extract_topics(self,
128
+ topic_model,
129
+ documents: pd.DataFrame,
130
+ c_tf_idf: csr_matrix,
131
+ topics: Mapping[str, List[Tuple[str, float]]]
132
+ ) -> Mapping[str, List[Tuple[str, float]]]:
133
+ """ Extract topics
134
+
135
+ Arguments:
136
+ topic_model: Not used
137
+ documents: Not used
138
+ c_tf_idf: Not used
139
+ topics: The candidate topics as calculated with c-TF-IDF
140
+
141
+ Returns:
142
+ updated_topics: Updated topic representations
143
+ """
144
+ # Extract the top n representative documents per topic
145
+ repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity)
146
+
147
+ # Generate using Cohere's Language Model
148
+ updated_topics = {}
149
+ for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
150
+ truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
151
+ prompt = self._create_prompt(truncated_docs, topic, topics)
152
+ self.prompts_.append(prompt)
153
+
154
+ # Delay
155
+ if self.delay_in_seconds:
156
+ time.sleep(self.delay_in_seconds)
157
+
158
+ request = self.client.generate(model=self.model,
159
+ prompt=prompt,
160
+ max_tokens=50,
161
+ num_generations=1,
162
+ stop_sequences=["\n"])
163
+ label = request.generations[0].text.strip()
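+ # Pad with empty words so the label fits BERTopic's 10-word topic format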
164
+ updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]
165
+
166
+ return updated_topics
167
+
168
+ def _create_prompt(self, docs, topic, topics):
169
+ keywords = list(zip(*topics[topic]))[0]
170
+
171
+ # Use the Default Chat Prompt
172
+ if self.prompt == DEFAULT_PROMPT:
173
+ prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
174
+ prompt = self._replace_documents(prompt, docs)
175
+
176
+ # Use a custom prompt that leverages keywords, documents or both using
177
+ # custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
178
+ else:
179
+ prompt = self.prompt
180
+ if "[KEYWORDS]" in prompt:
181
+ prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
182
+ if "[DOCUMENTS]" in prompt:
183
+ prompt = self._replace_documents(prompt, docs)
184
+
185
+ return prompt
186
+
187
+ @staticmethod
188
+ def _replace_documents(prompt, docs):
189
+ to_replace = ""
190
+ for doc in docs:
191
+ to_replace += f"- {doc}\n"
192
+ prompt = prompt.replace("[DOCUMENTS]", to_replace)
193
+ return prompt
BERTopic/bertopic/representation/_keybert.py ADDED
@@ -0,0 +1,198 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ from packaging import version
5
+ from scipy.sparse import csr_matrix
6
+ from typing import Mapping, List, Tuple, Union
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ from bertopic.representation._base import BaseRepresentation
9
+ from sklearn import __version__ as sklearn_version
10
+
11
+
12
+ class KeyBERTInspired(BaseRepresentation):
13
+ def __init__(self,
14
+ top_n_words: int = 10,
15
+ nr_repr_docs: int = 5,
16
+ nr_samples: int = 500,
17
+ nr_candidate_words: int = 100,
18
+ random_state: int = 42):
19
+ """ Use a KeyBERT-like model to fine-tune the topic representations
20
+
21
+ The algorithm follows KeyBERT but does some optimization in
22
+ order to speed up inference.
23
+
24
+ The steps are as follows. First, we extract the top n representative
25
+ documents per topic. To extract the representative documents, we
26
+ randomly sample a number of candidate documents per cluster
27
+ which is controlled by the `nr_samples` parameter. Then,
28
+ the top n representative documents are extracted by calculating
29
+ the c-TF-IDF representation for the candidate documents and finding,
30
+ through cosine similarity, which are closest to the topic c-TF-IDF representation.
31
+ Next, the top n words per topic are extracted based on their
32
+ c-TF-IDF representation, which is controlled by the `nr_candidate_words`
33
+ parameter.
34
+
35
+ Then, we extract the embeddings for words and representative documents
36
+ and create topic embeddings by averaging the representative documents.
37
+ Finally, the most similar words to each topic are extracted by
38
+ calculating the cosine similarity between word and topic embeddings.
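+
+ A minimal sketch of that final re-ranking step (illustrative only;
+ the random arrays stand in for the topic and word embeddings
+ described above):
+
+ ```python
+ import numpy as np
+ from sklearn.metrics.pairwise import cosine_similarity
+
+ topic_embeddings = np.random.rand(2, 384)   # one embedding per topic
+ word_embeddings = np.random.rand(100, 384)  # one embedding per candidate word
+
+ # Rank candidate words by similarity to each topic and keep the top 10
+ sim = cosine_similarity(topic_embeddings, word_embeddings)
+ top_words_per_topic = np.argsort(sim, axis=1)[:, -10:][:, ::-1]
+ ```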
39
+
40
+ Arguments:
41
+ top_n_words: The top n words to extract per topic.
42
+ nr_repr_docs: The number of representative documents to extract per cluster.
43
+ nr_samples: The number of candidate documents to extract per cluster.
44
+ nr_candidate_words: The number of candidate words per cluster.
45
+ random_state: The random state for randomly sampling candidate documents.
46
+
47
+ Usage:
48
+
49
+ ```python
50
+ from bertopic.representation import KeyBERTInspired
51
+ from bertopic import BERTopic
52
+
53
+ # Create your representation model
54
+ representation_model = KeyBERTInspired()
55
+
56
+ # Use the representation model in BERTopic on top of the default pipeline
57
+ topic_model = BERTopic(representation_model=representation_model)
58
+ ```
59
+ """
60
+ self.top_n_words = top_n_words
61
+ self.nr_repr_docs = nr_repr_docs
62
+ self.nr_samples = nr_samples
63
+ self.nr_candidate_words = nr_candidate_words
64
+ self.random_state = random_state
65
+
66
+ def extract_topics(self,
67
+ topic_model,
68
+ documents: pd.DataFrame,
69
+ c_tf_idf: csr_matrix,
70
+ topics: Mapping[str, List[Tuple[str, float]]]
71
+ ) -> Mapping[str, List[Tuple[str, float]]]:
72
+ """ Extract topics
73
+
74
+ Arguments:
75
+ topic_model: A BERTopic model
76
+ documents: All input documents
77
+ c_tf_idf: The topic c-TF-IDF representation
78
+ topics: The candidate topics as calculated with c-TF-IDF
79
+
80
+ Returns:
81
+ updated_topics: Updated topic representations
82
+ """
83
+ # We extract the top n representative documents per class
84
+ _, representative_docs, repr_doc_indices, _ = topic_model._extract_representative_docs(c_tf_idf, documents, topics, self.nr_samples, self.nr_repr_docs)
85
+
86
+ # We extract the top n words per class
87
+ topics = self._extract_candidate_words(topic_model, c_tf_idf, topics)
88
+
89
+ # We calculate the similarity between word and document embeddings and create
90
+ # topic embeddings from the representative document embeddings
91
+ sim_matrix, words = self._extract_embeddings(topic_model, topics, representative_docs, repr_doc_indices)
92
+
93
+ # Find the best matching words based on the similarity matrix for each topic
94
+ updated_topics = self._extract_top_words(words, topics, sim_matrix)
95
+
96
+ return updated_topics
97
+
98
+ def _extract_candidate_words(self,
99
+ topic_model,
100
+ c_tf_idf: csr_matrix,
101
+ topics: Mapping[str, List[Tuple[str, float]]]
102
+ ) -> Mapping[str, List[Tuple[str, float]]]:
103
+ """ For each topic, extract candidate words based on the c-TF-IDF
104
+ representation.
105
+
106
+ Arguments:
107
+ topic_model: A BERTopic model
108
+ c_tf_idf: The topic c-TF-IDF representation
109
+ topics: The top words per topic
110
+
111
+ Returns:
112
+ topics: The top `self.nr_candidate_words` words per topic
113
+ """
114
+ labels = [int(label) for label in sorted(list(topics.keys()))]
115
+
116
+ # Scikit-Learn Deprecation: get_feature_names is deprecated in 1.0
117
+ # and will be removed in 1.2. Please use get_feature_names_out instead.
118
+ if version.parse(sklearn_version) >= version.parse("1.0.0"):
119
+ words = topic_model.vectorizer_model.get_feature_names_out()
120
+ else:
121
+ words = topic_model.vectorizer_model.get_feature_names()
122
+
123
+ indices = topic_model._top_n_idx_sparse(c_tf_idf, self.nr_candidate_words)
124
+ scores = topic_model._top_n_values_sparse(c_tf_idf, indices)
125
+ sorted_indices = np.argsort(scores, 1)
126
+ indices = np.take_along_axis(indices, sorted_indices, axis=1)
127
+ scores = np.take_along_axis(scores, sorted_indices, axis=1)
128
+
129
+ # Get the top candidate words per topic based on c-TF-IDF score
130
+ topics = {label: [(words[word_index], score)
131
+ if word_index is not None and score > 0
132
+ else ("", 0.00001)
133
+ for word_index, score in zip(indices[index][::-1], scores[index][::-1])
134
+ ]
135
+ for index, label in enumerate(labels)}
136
+ topics = {label: list(zip(*values[:self.nr_candidate_words]))[0] for label, values in topics.items()}
137
+
138
+ return topics
139
+
140
+ def _extract_embeddings(self,
141
+ topic_model,
142
+ topics: Mapping[str, List[Tuple[str, float]]],
143
+ representative_docs: List[str],
144
+ repr_doc_indices: List[List[int]]
145
+ ) -> Tuple[np.ndarray, List[str]]:
146
+ """ Extract the representative document embeddings and create topic embeddings.
147
+ Then extract word embeddings and calculate the cosine similarity between topic
148
+ embeddings and the word embeddings. Topic embeddings are the average of
149
+ representative document embeddings.
150
+
151
+ Arguments:
152
+ topic_model: A BERTopic model
153
+ topics: The top words per topic
154
+ representative_docs: A flat list of representative documents
155
+ repr_doc_indices: The indices of representative documents
156
+ that belong to each topic
157
+
158
+ Returns:
159
+ sim: The similarity matrix between word and topic embeddings
160
+ vocab: The complete vocabulary of input documents
161
+ """
162
+ # Calculate representative docs embeddings and create topic embeddings
163
+ repr_embeddings = topic_model._extract_embeddings(representative_docs, method="document", verbose=False)
164
+ topic_embeddings = [np.mean(repr_embeddings[i[0]:i[-1]+1], axis=0) for i in repr_doc_indices]
165
+
166
+ # Calculate word embeddings and extract best matching with updated topic_embeddings
167
+ vocab = list(set([word for words in topics.values() for word in words]))
168
+ word_embeddings = topic_model._extract_embeddings(vocab, method="document", verbose=False)
169
+ sim = cosine_similarity(topic_embeddings, word_embeddings)
170
+
171
+ return sim, vocab
172
+
173
+ def _extract_top_words(self,
174
+ vocab: List[str],
175
+ topics: Mapping[str, List[Tuple[str, float]]],
176
+ sim: np.ndarray
177
+ ) -> Mapping[str, List[Tuple[str, float]]]:
178
+ """ Extract the top n words per topic based on the
179
+ similarity matrix between topics and words.
180
+
181
+ Arguments:
182
+ vocab: The complete vocabulary of input documents
184
+ topics: The top words per topic
185
+ sim: The similarity matrix between word and topic embeddings
186
+
187
+ Returns:
188
+ updated_topics: The updated topic representations
189
+ """
190
+ labels = [int(label) for label in sorted(list(topics.keys()))]
191
+ updated_topics = {}
192
+ for i, topic in enumerate(labels):
193
+ indices = [vocab.index(word) for word in topics[topic]]
194
+ values = sim[:, indices][i]
195
+ word_indices = [indices[index] for index in np.argsort(values)[-self.top_n_words:]]
196
+ updated_topics[topic] = [(vocab[index], val) for val, index in zip(np.sort(values)[-self.top_n_words:], word_indices)][::-1]
197
+
198
+ return updated_topics
BERTopic/bertopic/representation/_langchain.py ADDED
@@ -0,0 +1,203 @@
1
+ import pandas as pd
2
+ from langchain.docstore.document import Document
3
+ from scipy.sparse import csr_matrix
4
+ from typing import Callable, Dict, Mapping, List, Tuple, Union
5
+
6
+ from bertopic.representation._base import BaseRepresentation
7
+ from bertopic.representation._utils import truncate_document
8
+
9
+ DEFAULT_PROMPT = "What are these documents about? Please give a single label."
10
+
11
+
12
+ class LangChain(BaseRepresentation):
13
+ """ Using chains in langchain to generate topic labels.
14
+
15
+ The classic example uses `langchain.chains.question_answering.load_qa_chain`.
16
+ This returns a chain that takes a list of documents and a question as input.
17
+
18
+ You can also use Runnables such as those composed using the LangChain Expression Language.
19
+
20
+ Arguments:
21
+ chain: The langchain chain or Runnable with a `batch` method.
22
+ Input keys must be `input_documents` and `question`.
23
+ Output key must be `output_text`.
24
+ prompt: The prompt to be used in the model. If no prompt is given,
25
+ `self.default_prompt_` is used instead.
26
+ nr_docs: The number of documents to pass to LangChain if a prompt
27
+ with the `["DOCUMENTS"]` tag is used.
28
+ diversity: The diversity of documents to pass to LangChain.
29
+ Accepts values between 0 and 1. A higher
30
+ value results in passing more diverse documents,
31
+ whereas a lower value passes more similar documents.
32
+ doc_length: The maximum length of each document. If a document is longer,
33
+ it will be truncated. If None, the entire document is passed.
34
+ tokenizer: The tokenizer used to split the document into segments
35
+ that are counted to determine a document's length.
36
+ * If tokenizer is 'char', then the document is split up
37
+ into characters which are counted to adhere to `doc_length`
38
+ * If tokenizer is 'whitespace', the document is split up
39
+ into words separated by whitespaces. These words are counted
40
+ and truncated depending on `doc_length`
41
+ * If tokenizer is 'vectorizer', then the internal CountVectorizer
42
+ is used to tokenize the document. These tokens are counted
43
+ and truncated depending on `doc_length`. They are decoded with
44
+ whitespaces.
45
+ * If tokenizer is a callable, then that callable is used to tokenize
46
+ the document. These tokens are counted and truncated depending
47
+ on `doc_length`
48
+ chain_config: The configuration for the langchain chain. Can be used to set options
49
+ like max_concurrency to avoid rate limiting errors.
+
50
+ Usage:
51
+
52
+ To use this, you will need to install the langchain package first.
53
+ Additionally, you will need an underlying LLM to support langchain,
54
+ like openai:
55
+
56
+ `pip install langchain`
57
+ `pip install openai`
58
+
59
+ Then, you can create your chain as follows:
60
+
61
+ ```python
62
+ from langchain.chains.question_answering import load_qa_chain
63
+ from langchain.llms import OpenAI
64
+ chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff")
65
+ ```
66
+
67
+ Finally, you can pass the chain to BERTopic as follows:
68
+
69
+ ```python
70
+ from bertopic.representation import LangChain
71
+
72
+ # Create your representation model
73
+ representation_model = LangChain(chain)
74
+
75
+ # Use the representation model in BERTopic on top of the default pipeline
76
+ topic_model = BERTopic(representation_model=representation_model)
77
+ ```
78
+
79
+ You can also use a custom prompt:
80
+
81
+ ```python
82
+ prompt = "What are these documents about? Please give a single label."
83
+ representation_model = LangChain(chain, prompt=prompt)
84
+ ```
85
+
86
+ You can also use a Runnable instead of a chain.
87
+ The example below uses the LangChain Expression Language:
88
+
89
+ ```python
90
+ from bertopic.representation import LangChain
91
+ from langchain.chains.question_answering import load_qa_chain
92
+ from langchain.chat_models import ChatAnthropic
93
+ from langchain.schema.document import Document
94
+ from langchain.schema.runnable import RunnablePassthrough
95
+ from langchain_experimental.data_anonymizer.presidio import PresidioReversibleAnonymizer
96
+
97
+ prompt = ...
98
+ llm = ...
99
+
100
+ # We will construct a special privacy-preserving chain using Microsoft Presidio
101
+
102
+ pii_handler = PresidioReversibleAnonymizer(analyzed_fields=["PERSON"])
103
+
104
+ chain = (
105
+ {
106
+ "input_documents": (
107
+ lambda inp: [
108
+ Document(
109
+ page_content=pii_handler.anonymize(
110
+ d.page_content,
111
+ language="en",
112
+ ),
113
+ )
114
+ for d in inp["input_documents"]
115
+ ]
116
+ ),
117
+ "question": RunnablePassthrough(),
118
+ }
119
+ | load_qa_chain(llm, chain_type="stuff")
120
+ | (lambda output: {"output_text": pii_handler.deanonymize(output["output_text"])})
121
+ )
122
+
123
+ representation_model = LangChain(chain, prompt=prompt)
124
+ ```
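+
+ The `chain_config` argument is forwarded to `chain.batch`; for
+ example, it can cap concurrency to avoid rate-limit errors:
+
+ ```python
+ representation_model = LangChain(chain, chain_config={"max_concurrency": 5})
+ ```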
125
+ """
126
+ def __init__(self,
127
+ chain,
128
+ prompt: str = None,
129
+ nr_docs: int = 4,
130
+ diversity: float = None,
131
+ doc_length: int = None,
132
+ tokenizer: Union[str, Callable] = None,
133
+ chain_config=None,
134
+ ):
135
+ self.chain = chain
136
+ self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
137
+ self.default_prompt_ = DEFAULT_PROMPT
138
+ self.chain_config = chain_config
139
+ self.nr_docs = nr_docs
140
+ self.diversity = diversity
141
+ self.doc_length = doc_length
142
+ self.tokenizer = tokenizer
143
+
144
+ def extract_topics(self,
145
+ topic_model,
146
+ documents: pd.DataFrame,
147
+ c_tf_idf: csr_matrix,
148
+ topics: Mapping[str, List[Tuple[str, float]]]
149
+ ) -> Mapping[str, List[Tuple[str, float]]]:
150
+ """ Extract topics
151
+
152
+ Arguments:
153
+ topic_model: A BERTopic model
154
+ documents: All input documents
155
+ c_tf_idf: The topic c-TF-IDF representation
156
+ topics: The candidate topics as calculated with c-TF-IDF
157
+
158
+ Returns:
159
+ updated_topics: Updated topic representations
160
+ """
161
+ # Extract the top n representative documents per topic
162
+ repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
163
+ c_tf_idf=c_tf_idf,
164
+ documents=documents,
165
+ topics=topics,
166
+ nr_samples=500,
167
+ nr_repr_docs=self.nr_docs,
168
+ diversity=self.diversity
169
+ )
170
+
171
+ # Generate label using langchain's batch functionality
172
+ chain_docs: List[List[Document]] = [
173
+ [
174
+ Document(
175
+ page_content=truncate_document(
176
+ topic_model,
177
+ self.doc_length,
178
+ self.tokenizer,
179
+ doc
180
+ )
181
+ )
182
+ for doc in docs
183
+ ]
184
+ for docs in repr_docs_mappings.values()
185
+ ]
186
+
187
+ # `self.chain` must take `input_documents` and `question` as input keys
188
+ inputs = [
189
+ {"input_documents": docs, "question": self.prompt}
190
+ for docs in chain_docs
191
+ ]
192
+
193
+ # `self.chain` must return a dict with an `output_text` key
194
+ # same output key as the `StuffDocumentsChain` returned by `load_qa_chain`
195
+ outputs = self.chain.batch(inputs=inputs, config=self.chain_config)
196
+ labels = [output["output_text"].strip() for output in outputs]
197
+
198
+ updated_topics = {
199
+ topic: [(label, 1)] + [("", 0) for _ in range(9)]
200
+ for topic, label in zip(repr_docs_mappings.keys(), labels)
201
+ }
202
+
203
+ return updated_topics