Upload 261 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- .gitattributes +2 -0
- BERTopic/.flake8 +2 -0
- BERTopic/.gitattributes +1 -0
- BERTopic/.github/CONTRIBUTING.md +53 -0
- BERTopic/.github/workflows/testing.yml +31 -0
- BERTopic/.gitignore +83 -0
- BERTopic/LICENSE +21 -0
- BERTopic/Makefile +24 -0
- BERTopic/README.md +297 -0
- BERTopic/bertopic/__init__.py +7 -0
- BERTopic/bertopic/_bertopic.py +0 -0
- BERTopic/bertopic/_save_utils.py +492 -0
- BERTopic/bertopic/_utils.py +149 -0
- BERTopic/bertopic/backend/__init__.py +35 -0
- BERTopic/bertopic/backend/_base.py +69 -0
- BERTopic/bertopic/backend/_cohere.py +94 -0
- BERTopic/bertopic/backend/_flair.py +78 -0
- BERTopic/bertopic/backend/_gensim.py +66 -0
- BERTopic/bertopic/backend/_hftransformers.py +96 -0
- BERTopic/bertopic/backend/_multimodal.py +194 -0
- BERTopic/bertopic/backend/_openai.py +88 -0
- BERTopic/bertopic/backend/_sentencetransformers.py +66 -0
- BERTopic/bertopic/backend/_sklearn.py +68 -0
- BERTopic/bertopic/backend/_spacy.py +94 -0
- BERTopic/bertopic/backend/_use.py +58 -0
- BERTopic/bertopic/backend/_utils.py +135 -0
- BERTopic/bertopic/backend/_word_doc.py +49 -0
- BERTopic/bertopic/cluster/__init__.py +5 -0
- BERTopic/bertopic/cluster/_base.py +41 -0
- BERTopic/bertopic/cluster/_utils.py +70 -0
- BERTopic/bertopic/dimensionality/__init__.py +5 -0
- BERTopic/bertopic/dimensionality/_base.py +26 -0
- BERTopic/bertopic/plotting/__init__.py +28 -0
- BERTopic/bertopic/plotting/_approximate_distribution.py +99 -0
- BERTopic/bertopic/plotting/_barchart.py +127 -0
- BERTopic/bertopic/plotting/_datamap.py +152 -0
- BERTopic/bertopic/plotting/_distribution.py +110 -0
- BERTopic/bertopic/plotting/_documents.py +227 -0
- BERTopic/bertopic/plotting/_heatmap.py +138 -0
- BERTopic/bertopic/plotting/_hierarchical_documents.py +336 -0
- BERTopic/bertopic/plotting/_hierarchy.py +308 -0
- BERTopic/bertopic/plotting/_term_rank.py +135 -0
- BERTopic/bertopic/plotting/_topics.py +162 -0
- BERTopic/bertopic/plotting/_topics_over_time.py +123 -0
- BERTopic/bertopic/plotting/_topics_per_class.py +130 -0
- BERTopic/bertopic/representation/__init__.py +68 -0
- BERTopic/bertopic/representation/_base.py +38 -0
- BERTopic/bertopic/representation/_cohere.py +193 -0
- BERTopic/bertopic/representation/_keybert.py +198 -0
- BERTopic/bertopic/representation/_langchain.py +203 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/df_en_review_vntourism.csv filter=lfs diff=lfs merge=lfs -text
+data/merge_final.csv filter=lfs diff=lfs merge=lfs -text
BERTopic/.flake8
ADDED
@@ -0,0 +1,2 @@
+[flake8]
+max-line-length = 160
BERTopic/.gitattributes
ADDED
@@ -0,0 +1 @@
+*.ipynb linguist-documentation
BERTopic/.github/CONTRIBUTING.md
ADDED
@@ -0,0 +1,53 @@
+# Contributing to BERTopic
+
+Hi! Thank you for considering contributing to BERTopic. With the modular nature of BERTopic, many new add-ons, backends, representation models, sub-models, and LLMs can quickly be added to keep up with this incredibly fast-paced field.
+
+Whether contributions are new features, better documentation, bug fixes, or improvements to the repository itself, anything is appreciated!
+
+## 📚 Guidelines
+
+### 🤖 Contributing Code
+
+To contribute to this project, we follow an `issue -> pull request` approach for main features and bug fixes. This means that any new feature, bug fix, or anything else that touches on code directly needs to start from an issue first. That way, the main discussion about what needs to be added/fixed can be done in the issue before creating a pull request. This makes sure that we are on the same page before you start coding your pull request. If you start working on an issue, please assign it to yourself, but only after reaching agreement with the maintainer, [@MaartenGr](https://github.com/MaartenGr).
+
+When there is agreement on the approach, a pull request can be created in which the fix/feature is added. This follows a ["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
+Please do not try to push directly to this repo unless you are a maintainer.
+
+There are exceptions to the `issue -> pull request` approach; these are typically small changes that do not need prior agreement, such as:
+* Documentation
+* Spelling/grammar issues
+* Docstrings
+* etc.
+
+There is a large focus on documentation in this repository, so please make sure to add extensive descriptions of features when creating the pull request.
+
+Note that the main focus of pull requests and code should be:
+* Easy readability
+* Clear communication
+* Sufficient documentation
+
+## 🚀 Quick Start
+
+To start contributing, make sure to first start from a fresh environment. Using an environment manager, such as `conda` or `pyenv`, helps ensure that your code is reproducible and tracks the versions you have in your environment.
+
+If you are using conda, you can approach it as follows:
+
+1. Create and activate a new conda environment (e.g., `conda create -n bertopic python=3.9`)
+2. Install the requirements (e.g., `pip install .[dev]`)
+    * This also installs the documentation and testing packages
+3. (Optional) Run `make docs` to build the documentation
+4. (Optional) Run `make test` to run the unit tests and `make coverage` to check their coverage
+
+❗Note: Unit testing the package can take quite some time since it needs to run several variants of the BERTopic pipeline.
+
+## 🤓 Collaborative Efforts
+
+When you run into any issue with the above or need help starting a pull request, feel free to reach out in the issues! As with all repositories, this one has its particularities as a result of the maintainer's view. Each repository is quite different, and so are its processes.
+
+## 🏆 Recognition
+
+If your contribution has made its way into a new release of BERTopic, you will be given credit in the changelog of the new release! Regardless of the size of the contribution, any help is greatly appreciated.
+
+## 🎈 Release
+
+BERTopic tries to mostly follow [semantic versioning](https://semver.org/) for its new releases. Even though BERTopic has been around for a few years now, it is still pre-1.0 software. With the rapid changes in the field and as a way to keep up, this versioning choice is deliberate. Backwards compatibility is taken into account, but integrating new features and thereby keeping up with the field takes priority. Especially since BERTopic focuses on modularity, flexibility is necessary.
BERTopic/.github/workflows/testing.yml
ADDED
@@ -0,0 +1,31 @@
+name: Code Checks
+
+on:
+  push:
+    branches:
+      - master
+      - dev
+  pull_request:
+    branches:
+      - master
+      - dev
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.8, 3.9]
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v1
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install -e ".[test]"
+    - name: Run Checking Mechanisms
+      run: make check
BERTopic/.gitignore
ADDED
@@ -0,0 +1,83 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+model_dir
+model_dir/
+test
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Sphinx documentation
+docs/_build/
+
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Artifacts
+.idea
+.idea/
+.vscode
+.DS_Store
BERTopic/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023, Maarten P. Grootendorst
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
BERTopic/Makefile
ADDED
@@ -0,0 +1,24 @@
+test:
+	pytest
+
+coverage:
+	pytest --cov
+
+install:
+	python -m pip install -e .
+
+install-test:
+	python -m pip install -e ".[dev]"
+
+docs:
+	mkdocs serve
+
+pypi:
+	python setup.py sdist
+	python setup.py bdist_wheel --universal
+	twine upload dist/*
+
+clean:
+	rm -rf **/.ipynb_checkpoints **/.pytest_cache **/__pycache__ **/**/__pycache__ .ipynb_checkpoints .pytest_cache
+
+check: test clean
BERTopic/README.md
ADDED
@@ -0,0 +1,297 @@
+[](https://pypi.org/project/bertopic/)
+[](https://github.com/MaartenGr/BERTopic/actions)
+[](https://maartengr.github.io/BERTopic/)
+[](https://pypi.org/project/bertopic/)
+[](https://github.com/MaartenGr/VLAC/blob/master/LICENSE)
+[](https://arxiv.org/abs/2203.05794)
+
+
+# BERTopic
+
+<img src="images/logo.png" width="35%" height="35%" align="right" />
+
+BERTopic is a topic modeling technique that leverages 🤗 transformers and c-TF-IDF to create dense clusters,
+allowing for easily interpretable topics whilst keeping important words in the topic descriptions.
+
+BERTopic supports all kinds of topic modeling techniques:
+<table>
+  <tr>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/guided/guided.html">Guided</a></td>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/supervised/supervised.html">Supervised</a></td>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/semisupervised/semisupervised.html">Semi-supervised</a></td>
+  </tr>
+  <tr>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/manual/manual.html">Manual</a></td>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/distribution/distribution.html">Multi-topic distributions</a></td>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/hierarchicaltopics/hierarchicaltopics.html">Hierarchical</a></td>
+  </tr>
+  <tr>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/topicsperclass/topicsperclass.html">Class-based</a></td>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/topicsovertime/topicsovertime.html">Dynamic</a></td>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/online/online.html">Online/Incremental</a></td>
+  </tr>
+  <tr>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/multimodal/multimodal.html">Multimodal</a></td>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html">Multi-aspect</a></td>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/representation/llm.html">Text Generation/LLM</a></td>
+  </tr>
+  <tr>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/zeroshot/zeroshot.html">Zero-shot <b>(new!)</b></a></td>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/merge/merge.html">Merge Models <b>(new!)</b></a></td>
+    <td><a href="https://maartengr.github.io/BERTopic/getting_started/seed_words/seed_words.html">Seed Words <b>(new!)</b></a></td>
+  </tr>
+</table>
+
+Corresponding medium posts can be found [here](https://towardsdatascience.com/topic-modeling-with-bert-779f7db187e6?source=friends_link&sk=0b5a470c006d1842ad4c8a3057063a99), [here](https://towardsdatascience.com/interactive-topic-modeling-with-bertopic-1ea55e7d73d8?sk=03c2168e9e74b6bda2a1f3ed953427e4) and [here](https://towardsdatascience.com/using-whisper-and-bertopic-to-model-kurzgesagts-videos-7d8a63139bdf?sk=b1e0fd46f70cb15e8422b4794a81161d). For a more detailed overview, you can read the [paper](https://arxiv.org/abs/2203.05794) or see a [brief overview](https://maartengr.github.io/BERTopic/algorithm/algorithm.html).
+
+## Installation
+
+Installation, with sentence-transformers, can be done using [pypi](https://pypi.org/project/bertopic/):
+
+```bash
+pip install bertopic
+```
+
+If you want to install BERTopic with other embedding models, you can choose one of the following:
+
+```bash
+# Choose an embedding backend
+pip install bertopic[flair,gensim,spacy,use]
+
+# Topic modeling with images
+pip install bertopic[vision]
+```
+
+## Getting Started
+For an in-depth overview of the features of BERTopic,
+you can check the [**full documentation**](https://maartengr.github.io/BERTopic/) or you can follow along
+with one of the examples below:
+
+| Name | Link |
+|---|---|
+| Start Here - **Best Practices in BERTopic** | [](https://colab.research.google.com/drive/1BoQ_vakEVtojsd2x_U6-_x52OOuqruj2?usp=sharing) |
+| **🆕 New!** - Topic Modeling on Large Data (GPU Acceleration) | [](https://colab.research.google.com/drive/1W7aEdDPxC29jP99GGZphUlqjMFFVKtBC?usp=sharing) |
+| **🆕 New!** - Topic Modeling with Llama 2 🦙 | [](https://colab.research.google.com/drive/1QCERSMUjqGetGGujdrvv_6_EeoIcd_9M?usp=sharing) |
+| **🆕 New!** - Topic Modeling with Quantized LLMs | [](https://colab.research.google.com/drive/1DdSHvVPJA3rmNfBWjCo2P1E9686xfxFx?usp=sharing) |
+| Topic Modeling with BERTopic | [](https://colab.research.google.com/drive/1FieRA9fLdkQEGDIMYl0I3MCjSUKVF8C-?usp=sharing) |
+| (Custom) Embedding Models in BERTopic | [](https://colab.research.google.com/drive/18arPPe50szvcCp_Y6xS56H2tY0m-RLqv?usp=sharing) |
+| Advanced Customization in BERTopic | [](https://colab.research.google.com/drive/1ClTYut039t-LDtlcd-oQAdXWgcsSGTw9?usp=sharing) |
+| (semi-)Supervised Topic Modeling with BERTopic | [](https://colab.research.google.com/drive/1bxizKzv5vfxJEB29sntU__ZC7PBSIPaQ?usp=sharing) |
+| Dynamic Topic Modeling with Trump's Tweets | [](https://colab.research.google.com/drive/1un8ooI-7ZNlRoK0maVkYhmNRl0XGK88f?usp=sharing) |
+| Topic Modeling arXiv Abstracts | [](https://www.kaggle.com/maartengr/topic-modeling-arxiv-abstract-with-bertopic) |
+
+
+## Quick Start
+We start by extracting topics from the well-known 20 newsgroups dataset containing English documents:
+
+```python
+from bertopic import BERTopic
+from sklearn.datasets import fetch_20newsgroups
+
+docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+
+topic_model = BERTopic()
+topics, probs = topic_model.fit_transform(docs)
+```
+
+After generating topics and their probabilities, we can access all of the topics together with their topic representations:
+
+```python
+>>> topic_model.get_topic_info()
+
+Topic   Count   Name
+-1      4630    -1_can_your_will_any
+0       693     49_windows_drive_dos_file
+1       466     32_jesus_bible_christian_faith
+2       441     2_space_launch_orbit_lunar
+3       381     22_key_encryption_keys_encrypted
+...
+```
+
+The `-1` topic refers to all outlier documents and is typically ignored. Each word in a topic describes the underlying theme of that topic and can be used
+for interpreting that topic. Next, let's take a look at the most frequent topic that was generated:
+
+```python
+>>> topic_model.get_topic(0)
+
+[('windows', 0.006152228076250982),
+ ('drive', 0.004982897610645755),
+ ('dos', 0.004845038866360651),
+ ('file', 0.004140142872194834),
+ ('disk', 0.004131678774810884),
+ ('mac', 0.003624848635985097),
+ ('memory', 0.0034840976976789903),
+ ('software', 0.0034415334250699077),
+ ('email', 0.0034239554442333257),
+ ('pc', 0.003047105930670237)]
+```
+
+Using `.get_document_info`, we can also extract information on a document level, such as their corresponding topics, probabilities, whether they are representative documents for a topic, etc.:
+
+```python
+>>> topic_model.get_document_info(docs)
+
+Document                            Topic  Name                       Top_n_words                 Probability  ...
+I am sure some bashers of Pens...   0      0_game_team_games_season   game - team - games...      0.200010     ...
+My brother is in the market for...  -1     -1_can_your_will_any       can - your - will...        0.420668     ...
+Finally you said what you dream...  -1     -1_can_your_will_any       can - your - will...        0.807259     ...
+Think! It's the SCSI card doing...  49     49_windows_drive_dos_file  windows - drive - docs...   0.071746     ...
+1) I have an old Jasmine drive...   49     49_windows_drive_dos_file  windows - drive - docs...   0.038983     ...
+```
+
+**`🔥 Tip`**: Use `BERTopic(language="multilingual")` to select a model that supports 50+ languages.
+
+## Fine-tune Topic Representations
+
+In BERTopic, there are a number of different [topic representations](https://maartengr.github.io/BERTopic/getting_started/representation/representation.html) that we can choose from. They are all quite different from one another and give interesting perspectives and variations of topic representations. A great start is `KeyBERTInspired`, which for many users increases the coherence and reduces stopwords from the resulting topic representations:
+
+```python
+from bertopic.representation import KeyBERTInspired
+
+# Fine-tune your topic representations
+representation_model = KeyBERTInspired()
+topic_model = BERTopic(representation_model=representation_model)
+```
+
+However, you might want to use something more powerful to describe your clusters. You can even use ChatGPT or other models from OpenAI to generate labels, summaries, phrases, keywords, and more:
+
+```python
+import openai
+from bertopic.representation import OpenAI
+
+# Fine-tune topic representations with GPT
+client = openai.OpenAI(api_key="sk-...")
+representation_model = OpenAI(client, model="gpt-3.5-turbo", chat=True)
+topic_model = BERTopic(representation_model=representation_model)
+```
+
+**`🔥 Tip`**: Instead of iterating over all of these different topic representations, you can model them simultaneously with [multi-aspect topic representations](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) in BERTopic.
+
+
+## Visualizations
+After having trained our BERTopic model, we can iteratively go through hundreds of topics to get a good
+understanding of the topics that were extracted. However, that takes quite some time and lacks a global representation. Instead, we can use one of the [many visualization options](https://maartengr.github.io/BERTopic/getting_started/visualization/visualization.html) in BERTopic.
+For example, we can visualize the topics that were generated in a way very similar to
+[LDAvis](https://github.com/cpsievert/LDAvis):
+
+```python
+topic_model.visualize_topics()
+```
+
+<img src="images/topic_visualization.gif" width="60%" height="60%" align="center" />
+
+## Modularity
+By default, the [main steps](https://maartengr.github.io/BERTopic/algorithm/algorithm.html) for topic modeling with BERTopic are sentence-transformers, UMAP, HDBSCAN, and c-TF-IDF run in sequence. However, it assumes some independence between these steps, which makes BERTopic quite modular. In other words, BERTopic not only allows you to build your own topic model but to explore several topic modeling techniques on top of your customized topic model:
+
+https://user-images.githubusercontent.com/25746895/218420473-4b2bb539-9dbe-407a-9674-a8317c7fb3bf.mp4
+
+You can swap out any of these models or even remove them entirely. The following steps are completely modular:
+
+1. [Embedding](https://maartengr.github.io/BERTopic/getting_started/embeddings/embeddings.html) documents
+2. [Reducing dimensionality](https://maartengr.github.io/BERTopic/getting_started/dim_reduction/dim_reduction.html) of embeddings
+3. [Clustering](https://maartengr.github.io/BERTopic/getting_started/clustering/clustering.html) reduced embeddings into topics
+4. [Tokenization](https://maartengr.github.io/BERTopic/getting_started/vectorizers/vectorizers.html) of topics
+5. [Weight](https://maartengr.github.io/BERTopic/getting_started/ctfidf/ctfidf.html) tokens
+6. [Represent topics](https://maartengr.github.io/BERTopic/getting_started/representation/representation.html) with one or [multiple](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) representations
+
+
+## Functionality
+BERTopic has many functions that can quickly become overwhelming. To alleviate this issue, you will find below an overview
+of all methods with a short description of their purpose.
+
+### Common
+Below, you will find an overview of common functions in BERTopic.
+
+| Method | Code |
+|-----------------------|---|
+| Fit the model | `.fit(docs)` |
+| Fit the model and predict documents | `.fit_transform(docs)` |
+| Predict new documents | `.transform([new_doc])` |
+| Access single topic | `.get_topic(topic=12)` |
+| Access all topics | `.get_topics()` |
+| Get topic freq | `.get_topic_freq()` |
+| Get all topic information | `.get_topic_info()` |
+| Get all document information | `.get_document_info(docs)` |
+| Get representative docs per topic | `.get_representative_docs()` |
+| Update topic representation | `.update_topics(docs, n_gram_range=(1, 3))` |
+| Generate topic labels | `.generate_topic_labels()` |
+| Set topic labels | `.set_topic_labels(my_custom_labels)` |
+| Merge topics | `.merge_topics(docs, topics_to_merge)` |
+| Reduce nr of topics | `.reduce_topics(docs, nr_topics=30)` |
+| Reduce outliers | `.reduce_outliers(docs, topics)` |
+| Find topics | `.find_topics("vehicle")` |
+| Save model | `.save("my_model", serialization="safetensors")` |
+| Load model | `BERTopic.load("my_model")` |
+| Get parameters | `.get_params()` |
+
+
+### Attributes
+After having trained your BERTopic model, several attributes are saved within your model. These attributes, in part,
+refer to how model information is stored on an estimator during fitting. The attributes that you see below all end in `_` and are
+public attributes that can be used to access model information.
+
+| Attribute | Description |
+|------------------------|---------------------------------------------------------------------------------------------|
+| `.topics_` | The topics that are generated for each document after training or updating the topic model. |
+| `.probabilities_` | The probabilities that are generated for each document if HDBSCAN is used. |
+| `.topic_sizes_` | The size of each topic. |
+| `.topic_mapper_` | A class for tracking topics and their mappings anytime they are merged/reduced. |
+| `.topic_representations_` | The top *n* terms per topic and their respective c-TF-IDF values. |
+| `.c_tf_idf_` | The topic-term matrix as calculated through c-TF-IDF. |
+| `.topic_aspects_` | The different aspects, or representations, of each topic. |
+| `.topic_labels_` | The default labels for each topic. |
+| `.custom_labels_` | Custom labels for each topic as generated through `.set_topic_labels`. |
+| `.topic_embeddings_` | The embeddings for each topic if `embedding_model` was used. |
+| `.representative_docs_` | The representative documents for each topic if HDBSCAN is used. |
+
+
+### Variations
+There are many different use cases in which topic modeling can be used. As such, several variations of BERTopic have been developed such that one package can be used across many use cases.
+
+| Method | Code |
+|-----------------------|---|
+| [Topic Distribution Approximation](https://maartengr.github.io/BERTopic/getting_started/distribution/distribution.html) | `.approximate_distribution(docs)` |
+| [Online Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/online/online.html) | `.partial_fit(doc)` |
+| [Semi-supervised Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/semisupervised/semisupervised.html) | `.fit(docs, y=y)` |
+| [Supervised Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/supervised/supervised.html) | `.fit(docs, y=y)` |
+| [Manual Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/manual/manual.html) | `.fit(docs, y=y)` |
+| [Multimodal Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/multimodal/multimodal.html) | `.fit(docs, images=images)` |
+| [Topic Modeling per Class](https://maartengr.github.io/BERTopic/getting_started/topicsperclass/topicsperclass.html) | `.topics_per_class(docs, classes)` |
+| [Dynamic Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/topicsovertime/topicsovertime.html) | `.topics_over_time(docs, timestamps)` |
+| [Hierarchical Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/hierarchicaltopics/hierarchicaltopics.html) | `.hierarchical_topics(docs)` |
+| [Guided Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/guided/guided.html) | `BERTopic(seed_topic_list=seed_topic_list)` |
+| [Zero-shot Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/zeroshot/zeroshot.html) | `BERTopic(zeroshot_topic_list=zeroshot_topic_list)` |
+| [Merge Multiple Models](https://maartengr.github.io/BERTopic/getting_started/merge/merge.html) | `BERTopic.merge_models([topic_model_1, topic_model_2])` |
+
+
+### Visualizations
+Evaluating topic models can be rather difficult due to the somewhat subjective nature of evaluation.
+Visualizing different aspects of the topic model helps in understanding the model and makes it easier
+to tweak the model to your liking.
+
+| Method | Code |
+|-----------------------|---|
+| Visualize Topics | `.visualize_topics()` |
+| Visualize Documents | `.visualize_documents()` |
+| Visualize Document Hierarchy | `.visualize_hierarchical_documents()` |
+| Visualize Topic Hierarchy | `.visualize_hierarchy()` |
+| Visualize Topic Tree | `.get_topic_tree(hierarchical_topics)` |
+| Visualize Topic Terms | `.visualize_barchart()` |
+| Visualize Topic Similarity | `.visualize_heatmap()` |
+| Visualize Term Score Decline | `.visualize_term_rank()` |
+| Visualize Topic Probability Distribution | `.visualize_distribution(probs[0])` |
+| Visualize Topics over Time | `.visualize_topics_over_time(topics_over_time)` |
+| Visualize Topics per Class | `.visualize_topics_per_class(topics_per_class)` |
+
+
+## Citation
+To cite the [BERTopic paper](https://arxiv.org/abs/2203.05794), please use the following bibtex reference:
+
+```bibtex
+@article{grootendorst2022bertopic,
+  title={BERTopic: Neural topic modeling with a class-based TF-IDF procedure},
+  author={Grootendorst, Maarten},
+  journal={arXiv preprint arXiv:2203.05794},
+  year={2022}
+}
+```
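The save and load rows in the README's functionality table can be combined into a single round trip. Below is a minimal sketch, assuming a model fitted as in the Quick Start above; the directory name `my_model` is illustrative:

```python
from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups

# Fit a model as in the Quick Start
docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
topic_model = BERTopic().fit(docs)

# Save with safetensors serialization, as listed in the functionality table
topic_model.save("my_model", serialization="safetensors")

# Reload and inspect the recovered topics
loaded_model = BERTopic.load("my_model")
print(loaded_model.get_topic_info().head())
```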
BERTopic/bertopic/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from bertopic._bertopic import BERTopic
+
+__version__ = "0.16.0"
+
+__all__ = [
+    "BERTopic",
+]
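As a quick check of the public surface this `__init__.py` defines, a minimal sketch (the printed version assumes the `0.16.0` pin above):

```python
import bertopic
from bertopic import BERTopic  # the only name exported via __all__

print(bertopic.__version__)    # "0.16.0", as pinned above
topic_model = BERTopic()       # instantiate with default components
```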
BERTopic/bertopic/_bertopic.py
ADDED
The diff for this file is too large to render. See the raw diff.
BERTopic/bertopic/_save_utils.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from tempfile import TemporaryDirectory
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# HuggingFace Hub
|
| 11 |
+
try:
|
| 12 |
+
from huggingface_hub import (
|
| 13 |
+
create_repo, get_hf_file_metadata,
|
| 14 |
+
hf_hub_download, hf_hub_url,
|
| 15 |
+
repo_type_and_id_from_hf_id, upload_folder)
|
| 16 |
+
_has_hf_hub = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
_has_hf_hub = False
|
| 19 |
+
|
| 20 |
+
# Typing
|
| 21 |
+
if sys.version_info >= (3, 8):
|
| 22 |
+
from typing import Literal
|
| 23 |
+
else:
|
| 24 |
+
from typing_extensions import Literal
|
| 25 |
+
from typing import Union, Mapping, Any
|
| 26 |
+
|
| 27 |
+
# Pytorch check
|
| 28 |
+
try:
|
| 29 |
+
import torch
|
| 30 |
+
_has_torch = True
|
| 31 |
+
except ImportError:
|
| 32 |
+
_has_torch = False
|
| 33 |
+
|
| 34 |
+
# Image check
|
| 35 |
+
try:
|
| 36 |
+
from PIL import Image
|
| 37 |
+
_has_vision = True
|
| 38 |
+
except:
|
| 39 |
+
_has_vision = False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
TOPICS_NAME = "topics.json"
|
| 43 |
+
CONFIG_NAME = "config.json"
|
| 44 |
+
|
| 45 |
+
HF_WEIGHTS_NAME = "topic_embeddings.bin" # default pytorch pkl
|
| 46 |
+
HF_SAFE_WEIGHTS_NAME = "topic_embeddings.safetensors" # safetensors version
|
| 47 |
+
|
| 48 |
+
CTFIDF_WEIGHTS_NAME = "ctfidf.bin" # default pytorch pkl
|
| 49 |
+
CTFIDF_SAFE_WEIGHTS_NAME = "ctfidf.safetensors" # safetensors version
|
| 50 |
+
CTFIDF_CFG_NAME = "ctfidf_config.json"
|
| 51 |
+
|
| 52 |
+
MODEL_CARD_TEMPLATE = """
|
| 53 |
+
---
|
| 54 |
+
tags:
|
| 55 |
+
- bertopic
|
| 56 |
+
library_name: bertopic
|
| 57 |
+
pipeline_tag: {PIPELINE_TAG}
|
| 58 |
+
---
|
| 59 |
+
|
| 60 |
+
# {MODEL_NAME}
|
| 61 |
+
|
| 62 |
+
This is a [BERTopic](https://github.com/MaartenGr/BERTopic) model.
|
| 63 |
+
BERTopic is a flexible and modular topic modeling framework that allows for the generation of easily interpretable topics from large datasets.
|
| 64 |
+
|
| 65 |
+
## Usage
|
| 66 |
+
|
| 67 |
+
To use this model, please install BERTopic:
|
| 68 |
+
|
| 69 |
+
```
|
| 70 |
+
pip install -U bertopic
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
You can use the model as follows:
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
from bertopic import BERTopic
|
| 77 |
+
topic_model = BERTopic.load("{PATH}")
|
| 78 |
+
|
| 79 |
+
topic_model.get_topic_info()
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
## Topic overview
|
| 83 |
+
|
| 84 |
+
* Number of topics: {NR_TOPICS}
|
| 85 |
+
* Number of training documents: {NR_DOCUMENTS}
|
| 86 |
+
|
| 87 |
+
<details>
|
| 88 |
+
<summary>Click here for an overview of all topics.</summary>
|
| 89 |
+
|
| 90 |
+
{TOPICS}
|
| 91 |
+
|
| 92 |
+
</details>
|
| 93 |
+
|
| 94 |
+
## Training hyperparameters
|
| 95 |
+
|
| 96 |
+
{HYPERPARAMS}
|
| 97 |
+
|
| 98 |
+
## Framework versions
|
| 99 |
+
|
| 100 |
+
{FRAMEWORKS}
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def push_to_hf_hub(
|
| 106 |
+
model,
|
| 107 |
+
repo_id: str,
|
| 108 |
+
commit_message: str = 'Add BERTopic model',
|
| 109 |
+
token: str = None,
|
| 110 |
+
revision: str = None,
|
| 111 |
+
private: bool = False,
|
| 112 |
+
create_pr: bool = False,
|
| 113 |
+
model_card: bool = True,
|
| 114 |
+
serialization: str = "safetensors",
|
| 115 |
+
save_embedding_model: Union[str, bool] = True,
|
| 116 |
+
save_ctfidf: bool = False,
|
| 117 |
+
):
|
| 118 |
+
""" Push your BERTopic model to a HuggingFace Hub
|
| 119 |
+
|
| 120 |
+
Arguments:
|
| 121 |
+
repo_id: The name of your HuggingFace repository
|
| 122 |
+
commit_message: A commit message
|
| 123 |
+
token: Token to add if not already logged in
|
| 124 |
+
revision: Repository revision
|
| 125 |
+
private: Whether to create a private repository
|
| 126 |
+
create_pr: Whether to upload the model as a Pull Request
|
| 127 |
+
model_card: Whether to automatically create a modelcard
|
| 128 |
+
serialization: The type of serialization.
|
| 129 |
+
Either `safetensors` or `pytorch`
|
| 130 |
+
save_embedding_model: A pointer towards a HuggingFace model to be loaded in with
|
| 131 |
+
SentenceTransformers. E.g.,
|
| 132 |
+
`sentence-transformers/all-MiniLM-L6-v2`
|
| 133 |
+
save_ctfidf: Whether to save c-TF-IDF information
|
| 134 |
+
"""
|
| 135 |
+
if not _has_hf_hub:
|
| 136 |
+
raise ValueError("Make sure you have the huggingface hub installed via `pip install --upgrade huggingface_hub`")
|
| 137 |
+
|
| 138 |
+
# Create repo if it doesn't exist yet and infer complete repo_id
|
| 139 |
+
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
|
| 140 |
+
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
| 141 |
+
repo_id = f"{repo_owner}/{repo_name}"
|
| 142 |
+
|
| 143 |
+
# Temporarily save model and push to HF
|
| 144 |
+
with TemporaryDirectory() as tmpdir:
|
| 145 |
+
|
| 146 |
+
# Save model weights and config.
|
| 147 |
+
model.save(tmpdir, serialization=serialization, save_embedding_model=save_embedding_model, save_ctfidf=save_ctfidf)
|
| 148 |
+
|
| 149 |
+
# Add README if it does not exist
|
| 150 |
+
try:
|
| 151 |
+
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
| 152 |
+
except:
|
| 153 |
+
if model_card:
|
| 154 |
+
readme_text = generate_readme(model, repo_id)
|
| 155 |
+
readme_path = Path(tmpdir) / "README.md"
|
| 156 |
+
readme_path.write_text(readme_text, encoding='utf8')
|
| 157 |
+
|
| 158 |
+
# Upload model
|
| 159 |
+
return upload_folder(repo_id=repo_id, folder_path=tmpdir, revision=revision,
|
| 160 |
+
create_pr=create_pr, commit_message=commit_message)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def load_local_files(path):
|
| 164 |
+
""" Load local BERTopic files """
|
| 165 |
+
# Load json configs
|
| 166 |
+
topics = load_cfg_from_json(path / TOPICS_NAME)
|
| 167 |
+
params = load_cfg_from_json(path / CONFIG_NAME)
|
| 168 |
+
|
| 169 |
+
# Load Topic Embeddings
|
| 170 |
+
safetensor_path = path / HF_SAFE_WEIGHTS_NAME
|
| 171 |
+
if safetensor_path.is_file():
|
| 172 |
+
tensors = load_safetensors(safetensor_path)
|
| 173 |
+
else:
|
| 174 |
+
torch_path = path / HF_WEIGHTS_NAME
|
| 175 |
+
if torch_path.is_file():
|
| 176 |
+
tensors = torch.load(torch_path, map_location="cpu")
|
| 177 |
+
|
| 178 |
+
# c-TF-IDF
|
| 179 |
+
try:
|
| 180 |
+
ctfidf_tensors = None
|
| 181 |
+
safetensor_path = path / CTFIDF_SAFE_WEIGHTS_NAME
|
| 182 |
+
if safetensor_path.is_file():
|
| 183 |
+
ctfidf_tensors = load_safetensors(safetensor_path)
|
| 184 |
+
else:
|
| 185 |
+
torch_path = path / CTFIDF_WEIGHTS_NAME
|
| 186 |
+
if torch_path.is_file():
|
| 187 |
+
ctfidf_tensors = torch.load(torch_path, map_location="cpu")
|
| 188 |
+
ctfidf_config = load_cfg_from_json(path / CTFIDF_CFG_NAME)
|
| 189 |
+
except:
|
| 190 |
+
ctfidf_config, ctfidf_tensors = None, None
|
| 191 |
+
|
| 192 |
+
# Load images
|
| 193 |
+
images = None
|
| 194 |
+
if _has_vision:
|
| 195 |
+
try:
|
| 196 |
+
Image.open(path / "images/0.jpg")
|
| 197 |
+
_has_images = True
|
| 198 |
+
except:
|
| 199 |
+
_has_images = False
|
| 200 |
+
|
| 201 |
+
if _has_images:
|
| 202 |
+
topic_list = list(topics["topic_representations"].keys())
|
| 203 |
+
images = {}
|
| 204 |
+
for topic in topic_list:
|
| 205 |
+
image = Image.open(path / f"images/{topic}.jpg")
|
| 206 |
+
images[int(topic)] = image
|
| 207 |
+
|
| 208 |
+
return topics, params, tensors, ctfidf_tensors, ctfidf_config, images
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def load_files_from_hf(path):
|
| 212 |
+
""" Load files from HuggingFace. """
|
| 213 |
+
path = str(path)
|
| 214 |
+
|
| 215 |
+
# Configs
|
| 216 |
+
topics = load_cfg_from_json(hf_hub_download(path, TOPICS_NAME, revision=None))
|
| 217 |
+
params = load_cfg_from_json(hf_hub_download(path, CONFIG_NAME, revision=None))
|
| 218 |
+
|
| 219 |
+
# Topic Embeddings
|
| 220 |
+
try:
|
| 221 |
+
tensors = hf_hub_download(path, HF_SAFE_WEIGHTS_NAME, revision=None)
|
| 222 |
+
tensors = load_safetensors(tensors)
|
| 223 |
+
except:
|
| 224 |
+
tensors = hf_hub_download(path, HF_WEIGHTS_NAME, revision=None)
|
| 225 |
+
tensors = torch.load(tensors, map_location="cpu")
|
| 226 |
+
|
| 227 |
+
# c-TF-IDF
|
| 228 |
+
try:
|
| 229 |
+
ctfidf_config = load_cfg_from_json(hf_hub_download(path, CTFIDF_CFG_NAME, revision=None))
|
| 230 |
+
try:
|
| 231 |
+
ctfidf_tensors = hf_hub_download(path, CTFIDF_SAFE_WEIGHTS_NAME, revision=None)
|
| 232 |
+
ctfidf_tensors = load_safetensors(ctfidf_tensors)
|
| 233 |
+
except:
|
| 234 |
+
ctfidf_tensors = hf_hub_download(path, CTFIDF_WEIGHTS_NAME, revision=None)
|
| 235 |
+
ctfidf_tensors = torch.load(ctfidf_tensors, map_location="cpu")
|
| 236 |
+
except:
|
| 237 |
+
ctfidf_config, ctfidf_tensors = None, None
|
| 238 |
+
|
| 239 |
+
# Load images if they exist
|
| 240 |
+
images = None
|
| 241 |
+
if _has_vision:
|
| 242 |
+
try:
|
| 243 |
+
hf_hub_download(path, "images/0.jpg", revision=None)
|
| 244 |
+
_has_images = True
|
| 245 |
+
except:
|
| 246 |
+
_has_images = False
|
| 247 |
+
|
| 248 |
+
if _has_images:
|
| 249 |
+
topic_list = list(topics["topic_representations"].keys())
|
| 250 |
+
images = {}
|
| 251 |
+
for topic in topic_list:
|
| 252 |
+
image = Image.open(hf_hub_download(path, f"images/{topic}.jpg", revision=None))
|
| 253 |
+
images[int(topic)] = image
|
| 254 |
+
|
| 255 |
+
return topics, params, tensors, ctfidf_tensors, ctfidf_config, images
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def generate_readme(model, repo_id: str):
|
| 259 |
+
""" Generate README for HuggingFace model card """
|
| 260 |
+
model_card = MODEL_CARD_TEMPLATE
|
| 261 |
+
topic_table_head = "| Topic ID | Topic Keywords | Topic Frequency | Label | \n|----------|----------------|-----------------|-------| \n"
|
| 262 |
+
|
| 263 |
+
# Get Statistics
|
| 264 |
+
model_name = repo_id.split("/")[-1]
|
| 265 |
+
params = {param: value for param, value in model.get_params().items() if "model" not in param}
|
| 266 |
+
params = "\n".join([f"* {param}: {value}" for param, value in params.items()])
|
| 267 |
+
topics = sorted(list(set(model.topics_)))
|
| 268 |
+
nr_topics = str(len(set(model.topics_)))
|
| 269 |
+
|
| 270 |
+
if model.topic_sizes_ is not None:
|
| 271 |
+
nr_documents = str(sum(model.topic_sizes_.values()))
|
| 272 |
+
else:
|
| 273 |
+
nr_documents = ""
|
| 274 |
+
|
| 275 |
+
# Topic information
|
| 276 |
+
topic_keywords = [" - ".join(list(zip(*model.get_topic(topic)))[0][:5]) for topic in topics]
|
| 277 |
+
topic_freq = [model.get_topic_freq(topic) for topic in topics]
|
| 278 |
+
topic_labels = model.custom_labels_ if model.custom_labels_ else [model.topic_labels_[topic] for topic in topics]
|
| 279 |
+
topics = [f"| {topic} | {topic_keywords[index]} | {topic_freq[topic]} | {topic_labels[index]} | \n" for index, topic in enumerate(topics)]
|
| 280 |
+
topics = topic_table_head + "".join(topics)
|
| 281 |
+
frameworks = "\n".join([f"* {param}: {value}" for param, value in get_package_versions().items()])
|
| 282 |
+
|
| 283 |
+
# Fill Statistics into model card
|
| 284 |
+
model_card = model_card.replace("{MODEL_NAME}", model_name)
|
| 285 |
+
model_card = model_card.replace("{PATH}", repo_id)
|
| 286 |
+
model_card = model_card.replace("{NR_TOPICS}", nr_topics)
|
| 287 |
+
model_card = model_card.replace("{TOPICS}", topics.strip())
|
| 288 |
+
model_card = model_card.replace("{NR_DOCUMENTS}", nr_documents)
|
| 289 |
+
model_card = model_card.replace("{HYPERPARAMS}", params)
|
| 290 |
+
model_card = model_card.replace("{FRAMEWORKS}", frameworks)
|
| 291 |
+
|
| 292 |
+
# Fill Pipeline tag
|
| 293 |
+
has_visual_aspect = check_has_visual_aspect(model)
|
| 294 |
+
if not has_visual_aspect:
|
| 295 |
+
model_card = model_card.replace("{PIPELINE_TAG}", "text-classification")
|
| 296 |
+
else:
|
| 297 |
+
model_card = model_card.replace("pipeline_tag: {PIPELINE_TAG}\n","") # TODO add proper tag for this instance
|
| 298 |
+
|
| 299 |
+
return model_card
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def save_hf(model, save_directory, serialization: str):
|
| 303 |
+
""" Save topic embeddings, either safely (using safetensors) or using legacy pytorch """
|
| 304 |
+
tensors = torch.from_numpy(np.array(model.topic_embeddings_, dtype=np.float32))
|
| 305 |
+
tensors = {"topic_embeddings": tensors}
|
| 306 |
+
|
| 307 |
+
if serialization == "safetensors":
|
| 308 |
+
save_safetensors(save_directory / HF_SAFE_WEIGHTS_NAME, tensors)
|
| 309 |
+
if serialization == "pytorch":
|
| 310 |
+
assert _has_torch, "`pip install pytorch` to save as bin"
|
| 311 |
+
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def save_ctfidf(model,
|
| 315 |
+
save_directory: str,
|
| 316 |
+
serialization: str):
|
| 317 |
+
""" Save c-TF-IDF sparse matrix """
|
| 318 |
+
indptr = torch.from_numpy(model.c_tf_idf_.indptr)
|
| 319 |
+
indices = torch.from_numpy(model.c_tf_idf_.indices)
|
| 320 |
+
data = torch.from_numpy(model.c_tf_idf_.data)
|
| 321 |
+
shape = torch.from_numpy(np.array(model.c_tf_idf_.shape))
|
| 322 |
+
diag = torch.from_numpy(np.array(model.ctfidf_model._idf_diag.data))
|
| 323 |
+
tensors = {
|
| 324 |
+
"indptr": indptr,
|
| 325 |
+
"indices": indices,
|
| 326 |
+
"data": data,
|
| 327 |
+
"shape": shape,
|
| 328 |
+
"diag": diag
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
if serialization == "safetensors":
|
| 332 |
+
save_safetensors(save_directory / CTFIDF_SAFE_WEIGHTS_NAME, tensors)
|
| 333 |
+
if serialization == "pytorch":
|
| 334 |
+
assert _has_torch, "`pip install pytorch` to save as .bin"
|
| 335 |
+
torch.save(tensors, save_directory / CTFIDF_WEIGHTS_NAME)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def save_ctfidf_config(model, path):
|
| 339 |
+
""" Save parameters to recreate CountVectorizer and c-TF-IDF """
|
| 340 |
+
config = {}
|
| 341 |
+
|
| 342 |
+
# Recreate ClassTfidfTransformer
|
| 343 |
+
config["ctfidf_model"] = {
|
| 344 |
+
"bm25_weighting": model.ctfidf_model.bm25_weighting,
|
| 345 |
+
"reduce_frequent_words": model.ctfidf_model.reduce_frequent_words
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
# Recreate CountVectorizer
|
| 349 |
+
cv_params = model.vectorizer_model.get_params()
|
| 350 |
+
del cv_params["tokenizer"], cv_params["preprocessor"], cv_params["dtype"]
|
| 351 |
+
if not isinstance(cv_params["analyzer"], str):
|
| 352 |
+
del cv_params["analyzer"]
|
| 353 |
+
|
| 354 |
+
config["vectorizer_model"] = {
|
| 355 |
+
"params": cv_params,
|
| 356 |
+
"vocab": model.vectorizer_model.vocabulary_
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
with path.open('w') as f:
|
| 360 |
+
json.dump(config, f, indent=2)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def save_config(model, path: str, embedding_model):
|
| 364 |
+
""" Save BERTopic configuration """
|
| 365 |
+
path = Path(path)
|
| 366 |
+
params = model.get_params()
|
| 367 |
+
config = {param: value for param, value in params.items() if "model" not in param}
|
| 368 |
+
|
| 369 |
+
# Embedding model tag to be used in sentence-transformers
|
| 370 |
+
if isinstance(embedding_model, str):
|
| 371 |
+
config["embedding_model"] = embedding_model
|
| 372 |
+
|
| 373 |
+
with path.open('w') as f:
|
| 374 |
+
json.dump(config, f, indent=2)
|
| 375 |
+
|
| 376 |
+
return config
|
| 377 |
+
|
| 378 |
+
def check_has_visual_aspect(model):
|
| 379 |
+
"""Check if model has visual aspect"""
|
| 380 |
+
if _has_vision:
|
| 381 |
+
for aspect, value in model.topic_aspects_.items():
|
| 382 |
+
if isinstance(value[0], Image.Image):
|
| 383 |
+
visual_aspects = model.topic_aspects_[aspect]
|
| 384 |
+
return True
|
| 385 |
+
|
| 386 |
+
def save_images(model, path: str):
|
| 387 |
+
""" Save topic images """
|
| 388 |
+
if _has_vision:
|
| 389 |
+
visual_aspects = None
|
| 390 |
+
for aspect, value in model.topic_aspects_.items():
|
| 391 |
+
if isinstance(value[0], Image.Image):
|
| 392 |
+
visual_aspects = model.topic_aspects_[aspect]
|
| 393 |
+
break
|
| 394 |
+
|
| 395 |
+
if visual_aspects is not None:
|
| 396 |
+
path.mkdir(exist_ok=True, parents=True)
|
| 397 |
+
for topic, image in visual_aspects.items():
|
| 398 |
+
image.save(path / f"{topic}.jpg")
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def save_topics(model, path: str):
|
| 402 |
+
""" Save Topic-specific information """
|
| 403 |
+
path = Path(path)
|
| 404 |
+
|
| 405 |
+
if _has_vision:
|
| 406 |
+
selected_topic_aspects = {}
|
| 407 |
+
for aspect, value in model.topic_aspects_.items():
|
| 408 |
+
if not isinstance(value[0], Image.Image):
|
| 409 |
+
selected_topic_aspects[aspect] = value
|
| 410 |
+
else:
|
| 411 |
+
selected_topic_aspects["Visual_Aspect"] = True
|
| 412 |
+
else:
|
| 413 |
+
selected_topic_aspects = model.topic_aspects_
|
| 414 |
+
|
| 415 |
+
topics = {
|
| 416 |
+
"topic_representations": model.topic_representations_,
|
| 417 |
+
"topics": [int(topic) for topic in model.topics_],
|
| 418 |
+
"topic_sizes": model.topic_sizes_,
|
| 419 |
+
"topic_mapper": np.array(model.topic_mapper_.mappings_, dtype=int).tolist(),
|
| 420 |
+
"topic_labels": model.topic_labels_,
|
| 421 |
+
"custom_labels": model.custom_labels_,
|
| 422 |
+
"_outliers": int(model._outliers),
|
| 423 |
+
"topic_aspects": selected_topic_aspects
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
with path.open('w') as f:
|
| 427 |
+
json.dump(topics, f, indent=2, cls=NumpyEncoder)
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
|
| 431 |
+
""" Load configuration from json """
|
| 432 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
| 433 |
+
text = reader.read()
|
| 434 |
+
return json.loads(text)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
class NumpyEncoder(json.JSONEncoder):
|
| 438 |
+
def default(self, obj):
|
| 439 |
+
if isinstance(obj, np.integer):
|
| 440 |
+
return int(obj)
|
| 441 |
+
if isinstance(obj, np.floating):
|
| 442 |
+
return float(obj)
|
| 443 |
+
return super(NumpyEncoder, self).default(obj)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def get_package_versions():
|
| 448 |
+
""" Get versions of main dependencies of BERTopic """
|
| 449 |
+
try:
|
| 450 |
+
import platform
|
| 451 |
+
from numpy import __version__ as np_version
|
| 452 |
+
|
| 453 |
+
try:
|
| 454 |
+
from importlib.metadata import version
|
| 455 |
+
hdbscan_version = version('hdbscan')
|
| 456 |
+
except:
|
| 457 |
+
hdbscan_version = None
|
| 458 |
+
|
| 459 |
+
from umap import __version__ as umap_version
|
| 460 |
+
from pandas import __version__ as pandas_version
|
| 461 |
+
from sklearn import __version__ as sklearn_version
|
| 462 |
+
from sentence_transformers import __version__ as sbert_version
|
| 463 |
+
from numba import __version__ as numba_version
|
| 464 |
+
from transformers import __version__ as transformers_version
|
| 465 |
+
|
| 466 |
+
from plotly import __version__ as plotly_version
|
| 467 |
+
return {"Numpy": np_version, "HDBSCAN": hdbscan_version, "UMAP": umap_version,
|
| 468 |
+
"Pandas": pandas_version, "Scikit-Learn": sklearn_version,
|
| 469 |
+
"Sentence-transformers": sbert_version, "Transformers": transformers_version,
|
| 470 |
+
"Numba": numba_version, "Plotly": plotly_version, "Python": platform.python_version()}
|
| 471 |
+
except Exception as e:
|
| 472 |
+
return e
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def load_safetensors(path):
|
| 476 |
+
""" Load safetensors and check whether it is installed """
|
| 477 |
+
try:
|
| 478 |
+
import safetensors.torch
|
| 479 |
+
import safetensors
|
| 480 |
+
return safetensors.torch.load_file(path, device="cpu")
|
| 481 |
+
except ImportError:
|
| 482 |
+
raise ValueError("`pip install safetensors` to load .safetensors")
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def save_safetensors(path, tensors):
|
| 486 |
+
""" Save safetensors and check whether it is installed """
|
| 487 |
+
try:
|
| 488 |
+
import safetensors.torch
|
| 489 |
+
import safetensors
|
| 490 |
+
safetensors.torch.save_file(tensors, path)
|
| 491 |
+
except ImportError:
|
| 492 |
+
raise ValueError("`pip install safetensors` to save as .safetensors")
|
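As a quick illustration of why `save_topics` passes `cls=NumpyEncoder` above: topic sizes and mappings often contain NumPy scalars, which the standard `json` encoder rejects. A minimal sketch with a hypothetical payload:

```python
import json
import numpy as np
from bertopic._save_utils import NumpyEncoder

# Hypothetical payload mixing Python and NumPy scalar types
topics = {"topic_sizes": {"0": np.int64(52), "1": np.int64(17)},
          "score": np.float32(0.83)}

# json.dumps(topics) alone would raise TypeError; NumpyEncoder converts
# np.integer/np.floating values to plain int/float first
print(json.dumps(topics, cls=NumpyEncoder, indent=2))
```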
BERTopic/bertopic/_utils.py
ADDED
@@ -0,0 +1,149 @@
+import numpy as np
+import pandas as pd
+import logging
+from collections.abc import Iterable
+from scipy.sparse import csr_matrix
+from scipy.spatial.distance import squareform
+
+
+class MyLogger:
+    def __init__(self, level):
+        self.logger = logging.getLogger('BERTopic')
+        self.set_level(level)
+        self._add_handler()
+        self.logger.propagate = False
+
+    def info(self, message):
+        self.logger.info(f"{message}")
+
+    def warning(self, message):
+        self.logger.warning(f"WARNING: {message}")
+
+    def set_level(self, level):
+        levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
+        if level in levels:
+            self.logger.setLevel(level)
+
+    def _add_handler(self):
+        sh = logging.StreamHandler()
+        sh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(message)s'))
+        self.logger.addHandler(sh)
+
+        # Remove duplicate handlers
+        if len(self.logger.handlers) > 1:
+            self.logger.handlers = [self.logger.handlers[0]]
+
+
+def check_documents_type(documents):
+    """ Check whether the input documents are indeed a list of strings """
+    if isinstance(documents, pd.DataFrame):
+        raise TypeError("Make sure to supply a list of strings, not a dataframe.")
+    elif isinstance(documents, Iterable) and not isinstance(documents, str):
+        if not any([isinstance(doc, str) for doc in documents]):
+            raise TypeError("Make sure that the iterable only contains strings.")
+    else:
+        raise TypeError("Make sure that the documents variable is an iterable containing strings only.")
+
+
+def check_embeddings_shape(embeddings, docs):
+    """ Check if the embeddings have the correct shape """
+    if embeddings is not None:
+        if not any([isinstance(embeddings, np.ndarray), isinstance(embeddings, csr_matrix)]):
+            raise ValueError("Make sure to input embeddings as a numpy array or scipy.sparse.csr.csr_matrix. ")
+        else:
+            if embeddings.shape[0] != len(docs):
+                raise ValueError("Make sure that the embeddings are a numpy array with shape: "
+                                 "(len(docs), vector_dim) where vector_dim is the dimensionality "
+                                 "of the vector embeddings. ")
+
+
+def check_is_fitted(topic_model):
+    """ Checks if the model was fitted by verifying the presence of self.matches
+
+    Arguments:
+        topic_model: BERTopic instance for which the check is performed.
+
+    Returns:
+        None
+
+    Raises:
+        ValueError: If the matches were not found.
+    """
+    msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
+           "appropriate arguments before using this estimator.")
+
+    if topic_model.topics_ is None:
+        raise ValueError(msg % {'name': type(topic_model).__name__})
+
+
+class NotInstalled:
+    """
+    This object is used to notify the user that additional dependencies need to be
+    installed in order to use the string matching model.
+    """
+
+    def __init__(self, tool, dep, custom_msg=None):
+        self.tool = tool
+        self.dep = dep
+
+        msg = f"In order to use {self.tool} you will need to install it via:\n\n"
+        if custom_msg is not None:
+            msg += custom_msg
+        else:
+            msg += f"pip install bertopic[{self.dep}]\n\n"
+        self.msg = msg
+
+    def __getattr__(self, *args, **kwargs):
+        raise ModuleNotFoundError(self.msg)
+
+    def __call__(self, *args, **kwargs):
+        raise ModuleNotFoundError(self.msg)
+
+
+def validate_distance_matrix(X, n_samples):
+    """ Validate the distance matrix and convert it to a condensed distance matrix
+    if necessary.
+
+    A valid distance matrix is either a square matrix of shape (n_samples, n_samples)
+    with zeros on the diagonal and non-negative values, or a condensed distance matrix
+    of shape (n_samples * (n_samples - 1) / 2,) containing the upper triangular of the
+    distance matrix.
+
+    Arguments:
+        X: Distance matrix to validate.
+        n_samples: Number of samples in the dataset.
+
+    Returns:
+        X: Validated distance matrix.
+
+    Raises:
+        ValueError: If the distance matrix is not valid.
+    """
+    # Make sure it is the 1-D condensed distance matrix with zeros on the diagonal
+    s = X.shape
+    if len(s) == 1:
+        # check it has the correct size
+        n = s[0]
+        if n != (n_samples * (n_samples - 1) / 2):
+            raise ValueError("The condensed distance matrix must have "
+                             "shape (n*(n-1)/2,).")
+    elif len(s) == 2:
+        # check it has the correct size
+        if (s[0] != n_samples) or (s[1] != n_samples):
+            raise ValueError("The distance matrix must be of shape "
+                             "(n, n) where n is the number of samples.")
+        # force zero diagonal and convert to condensed
+        np.fill_diagonal(X, 0)
+        X = squareform(X)
+    else:
+        raise ValueError("The distance matrix must be either a 1-D condensed "
+                         "distance matrix of shape (n*(n-1)/2,) or a "
+                         "2-D square distance matrix of shape (n, n), "
+                         "where n is the number of documents. "
+                         "Got a distance matrix of shape %s" % str(s))
+
+    # Make sure its entries are non-negative
+    if np.any(X < 0):
+        raise ValueError("Distance matrix cannot contain negative values.")
+
+    return X
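To see the square-to-condensed conversion that `validate_distance_matrix` performs, here is a minimal sketch (the matrix values are made up):

```python
import numpy as np
from bertopic._utils import validate_distance_matrix

# A 3x3 symmetric distance matrix; the function zeroes the diagonal and
# returns the condensed upper triangle of length n*(n-1)/2 = 3
X = np.array([[0.0, 0.5, 0.9],
              [0.5, 0.0, 0.2],
              [0.9, 0.2, 0.0]])
print(validate_distance_matrix(X, n_samples=3))  # [0.5 0.9 0.2]
```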
BERTopic/bertopic/backend/__init__.py
ADDED
@@ -0,0 +1,35 @@
+from ._base import BaseEmbedder
+from ._word_doc import WordDocEmbedder
+from ._utils import languages
+from bertopic._utils import NotInstalled
+
+# OpenAI Embeddings
+try:
+    from bertopic.backend._openai import OpenAIBackend
+except ModuleNotFoundError:
+    msg = "`pip install openai` \n\n"
+    OpenAIBackend = NotInstalled("OpenAI", "OpenAI", custom_msg=msg)
+
+# Cohere Embeddings
+try:
+    from bertopic.backend._cohere import CohereBackend
+except ModuleNotFoundError:
+    msg = "`pip install cohere` \n\n"
+    CohereBackend = NotInstalled("Cohere", "Cohere", custom_msg=msg)
+
+# Multimodal Embeddings
+try:
+    from bertopic.backend._multimodal import MultiModalBackend
+except ModuleNotFoundError:
+    msg = "`pip install bertopic[vision]` \n\n"
+    MultiModalBackend = NotInstalled("Vision", "Vision", custom_msg=msg)
+
+
+__all__ = [
+    "BaseEmbedder",
+    "WordDocEmbedder",
+    "OpenAIBackend",
+    "CohereBackend",
+    "MultiModalBackend",
+    "languages"
+]
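The try/except pattern above degrades gracefully: the optional backends only fail when actually used, not at import time. A sketch, assuming the `openai` package is absent so `OpenAIBackend` is bound to a `NotInstalled` stub:

```python
from bertopic.backend import OpenAIBackend

# With openai missing, instantiating the stub raises a helpful error
try:
    backend = OpenAIBackend(client=None)
except ModuleNotFoundError as e:
    print(e)  # suggests: `pip install openai`
```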
BERTopic/bertopic/backend/_base.py
ADDED
@@ -0,0 +1,69 @@
+import numpy as np
+from typing import List
+
+
+class BaseEmbedder:
+    """ The Base Embedder used for creating embedding models
+
+    Arguments:
+        embedding_model: The main embedding model to be used for extracting
+                         document and word embeddings
+        word_embedding_model: The embedding model used for extracting word
+                              embeddings only. If this model is selected,
+                              then the `embedding_model` is purely used for
+                              creating document embeddings.
+    """
+    def __init__(self,
+                 embedding_model=None,
+                 word_embedding_model=None):
+        self.embedding_model = embedding_model
+        self.word_embedding_model = word_embedding_model
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        pass
+
+    def embed_words(self,
+                    words: List[str],
+                    verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            words: A list of words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Word embeddings with shape (n, m) with `n` words
+            that each have an embedding size of `m`
+        """
+        return self.embed(words, verbose)
+
+    def embed_documents(self,
+                        document: List[str],
+                        verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            document: A list of documents to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document embeddings with shape (n, m) with `n` documents
+            that each have an embedding size of `m`
+        """
+        return self.embed(document, verbose)
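`BaseEmbedder` is the extension point for custom backends: only `embed` has to be overridden, since `embed_words` and `embed_documents` delegate to it. A minimal sketch with a toy, non-semantic embedder (the class is hypothetical, for illustration only):

```python
import numpy as np
from bertopic.backend import BaseEmbedder

class RandomEmbedder(BaseEmbedder):
    """ Toy backend that maps every document to a random vector """
    def __init__(self, dim: int = 8, seed: int = 42):
        super().__init__()
        self.dim = dim
        self.rng = np.random.default_rng(seed)

    def embed(self, documents, verbose=False):
        # A real backend would call a model here; the contract is an
        # array of shape (n_documents, dim)
        return self.rng.random((len(documents), self.dim))

embedder = RandomEmbedder()
print(embedder.embed_documents(["doc one", "doc two"]).shape)  # (2, 8)
```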
BERTopic/bertopic/backend/_cohere.py
ADDED
@@ -0,0 +1,94 @@
+import time
+import numpy as np
+from tqdm import tqdm
+from typing import Any, List, Mapping
+from bertopic.backend import BaseEmbedder
+
+
+class CohereBackend(BaseEmbedder):
+    """ Cohere Embedding Model
+
+    Arguments:
+        client: A `cohere` client.
+        embedding_model: A Cohere model. Default is "large".
+                         For an overview of models see:
+                         https://docs.cohere.ai/docs/generation-card
+        delay_in_seconds: If a `batch_size` is given, use this to set
+                          the delay in seconds between batches.
+        batch_size: The size of each batch.
+        embed_kwargs: Kwargs passed to `cohere.Client.embed`.
+                      Can be used to define additional parameters
+                      such as `input_type`.
+
+    Examples:
+
+    ```python
+    import cohere
+    from bertopic.backend import CohereBackend
+
+    client = cohere.Client("APIKEY")
+    cohere_model = CohereBackend(client)
+    ```
+
+    If you want to specify `input_type`:
+
+    ```python
+    cohere_model = CohereBackend(
+        client,
+        embedding_model="embed-english-v3.0",
+        embed_kwargs={"input_type": "clustering"}
+    )
+    ```
+    """
+    def __init__(self,
+                 client,
+                 embedding_model: str = "large",
+                 delay_in_seconds: float = None,
+                 batch_size: int = None,
+                 embed_kwargs: Mapping[str, Any] = {}):
+        super().__init__()
+        self.client = client
+        self.embedding_model = embedding_model
+        self.delay_in_seconds = delay_in_seconds
+        self.batch_size = batch_size
+        self.embed_kwargs = embed_kwargs
+
+        if self.embed_kwargs.get("model"):
+            self.embedding_model = embed_kwargs.get("model")
+        else:
+            self.embed_kwargs["model"] = self.embedding_model
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        # Batch-wise embedding extraction
+        if self.batch_size is not None:
+            embeddings = []
+            for batch in tqdm(self._chunks(documents), disable=not verbose):
+                response = self.client.embed(batch, **self.embed_kwargs)
+                embeddings.extend(response.embeddings)
+
+                # Delay subsequent calls
+                if self.delay_in_seconds:
+                    time.sleep(self.delay_in_seconds)
+
+        # Extract embeddings all at once
+        else:
+            response = self.client.embed(documents, **self.embed_kwargs)
+            embeddings = response.embeddings
+        return np.array(embeddings)
+
+    def _chunks(self, documents):
+        for i in range(0, len(documents), self.batch_size):
+            yield documents[i:i + self.batch_size]
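The `_chunks` helper is plain stride slicing; the last batch simply ends up shorter when the corpus size is not a multiple of `batch_size`:

```python
# Standalone version of the slicing that _chunks performs
docs = ["a", "b", "c", "d", "e"]
batch_size = 2
print([docs[i:i + batch_size] for i in range(0, len(docs), batch_size)])
# [['a', 'b'], ['c', 'd'], ['e']]
```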
BERTopic/bertopic/backend/_flair.py
ADDED
@@ -0,0 +1,78 @@
+import numpy as np
+from tqdm import tqdm
+from typing import Union, List
+from flair.data import Sentence
+from flair.embeddings import DocumentEmbeddings, TokenEmbeddings, DocumentPoolEmbeddings
+
+from bertopic.backend import BaseEmbedder
+
+
+class FlairBackend(BaseEmbedder):
+    """ Flair Embedding Model
+
+    The Flair embedding model used for generating document and
+    word embeddings.
+
+    Arguments:
+        embedding_model: A Flair embedding model
+
+    Examples:
+
+    ```python
+    from bertopic.backend import FlairBackend
+    from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings
+
+    # Create a Flair Embedding model
+    glove_embedding = WordEmbeddings('crawl')
+    document_glove_embeddings = DocumentPoolEmbeddings([glove_embedding])
+
+    # Pass the Flair model to create a new backend
+    flair_embedder = FlairBackend(document_glove_embeddings)
+    ```
+    """
+    def __init__(self, embedding_model: Union[TokenEmbeddings, DocumentEmbeddings]):
+        super().__init__()
+
+        # Flair word embeddings
+        if isinstance(embedding_model, TokenEmbeddings):
+            self.embedding_model = DocumentPoolEmbeddings([embedding_model])
+
+        # Flair document embeddings + disable fine-tuning to prevent CUDA OOM
+        # https://github.com/flairNLP/flair/issues/1719
+        elif isinstance(embedding_model, DocumentEmbeddings):
+            if "fine_tune" in embedding_model.__dict__:
+                embedding_model.fine_tune = False
+            self.embedding_model = embedding_model
+
+        else:
+            raise ValueError("Please select a correct Flair model by preparing either a token or document "
+                             "embedding model: \n"
+                             "`from flair.embeddings import TransformerDocumentEmbeddings` \n"
+                             "`roberta = TransformerDocumentEmbeddings('roberta-base')`")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        embeddings = []
+        for document in tqdm(documents, disable=not verbose):
+            try:
+                sentence = Sentence(document) if document else Sentence("an empty document")
+                self.embedding_model.embed(sentence)
+            except RuntimeError:
+                sentence = Sentence("an empty document")
+                self.embedding_model.embed(sentence)
+            embedding = sentence.embedding.detach().cpu().numpy()
+            embeddings.append(embedding)
+        embeddings = np.asarray(embeddings)
+        return embeddings
BERTopic/bertopic/backend/_gensim.py
ADDED
@@ -0,0 +1,66 @@
+import numpy as np
+from tqdm import tqdm
+from typing import List
+from bertopic.backend import BaseEmbedder
+from gensim.models.keyedvectors import Word2VecKeyedVectors
+
+
+class GensimBackend(BaseEmbedder):
+    """ Gensim Embedding Model
+
+    The Gensim embedding model is typically used for word embeddings with
+    GloVe, Word2Vec or FastText.
+
+    Arguments:
+        embedding_model: A Gensim embedding model
+
+    Examples:
+
+    ```python
+    from bertopic.backend import GensimBackend
+    import gensim.downloader as api
+
+    ft = api.load('fasttext-wiki-news-subwords-300')
+    ft_embedder = GensimBackend(ft)
+    ```
+    """
+    def __init__(self, embedding_model: Word2VecKeyedVectors):
+        super().__init__()
+
+        if isinstance(embedding_model, Word2VecKeyedVectors):
+            self.embedding_model = embedding_model
+        else:
+            raise ValueError("Please select a correct Gensim model: \n"
+                             "`import gensim.downloader as api` \n"
+                             "`ft = api.load('fasttext-wiki-news-subwords-300')`")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        vector_shape = self.embedding_model.get_vector(list(self.embedding_model.index_to_key)[0]).shape[0]
+        empty_vector = np.zeros(vector_shape)
+
+        # Extract word embeddings and pool to document-level
+        embeddings = []
+        for doc in tqdm(documents, disable=not verbose, position=0, leave=True):
+            embedding = [self.embedding_model.get_vector(word) for word in doc.split()
+                         if word in self.embedding_model.key_to_index]
+
+            if len(embedding) > 0:
+                embeddings.append(np.mean(embedding, axis=0))
+            else:
+                embeddings.append(empty_vector)
+
+        embeddings = np.array(embeddings)
+        return embeddings
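The document embedding here is just the mean of the in-vocabulary word vectors, with a zero vector as the fallback for documents containing no known words. The same pooling rule, standalone with made-up vectors:

```python
import numpy as np

# Two known words; "unknownword" is skipped, exactly as in embed() above
word_vectors = {"topic": np.array([1.0, 0.0]), "model": np.array([0.0, 1.0])}
doc = "topic model unknownword"
vectors = [word_vectors[w] for w in doc.split() if w in word_vectors]
embedding = np.mean(vectors, axis=0) if vectors else np.zeros(2)
print(embedding)  # [0.5 0.5]
```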
BERTopic/bertopic/backend/_hftransformers.py
ADDED
@@ -0,0 +1,96 @@
+import numpy as np
+
+from tqdm import tqdm
+from typing import List
+from torch.utils.data import Dataset
+from sklearn.preprocessing import normalize
+from transformers.pipelines import Pipeline
+
+from bertopic.backend import BaseEmbedder
+
+
+class HFTransformerBackend(BaseEmbedder):
+    """ Hugging Face transformers model
+
+    This uses `transformers.pipelines.pipeline` to define and create
+    a feature generation pipeline from which embeddings can be extracted.
+
+    Arguments:
+        embedding_model: A Hugging Face feature extraction pipeline
+
+    Examples:
+
+    To use a Hugging Face transformers model, load in a pipeline and point
+    to any model found on their model hub (https://huggingface.co/models):
+
+    ```python
+    from bertopic.backend import HFTransformerBackend
+    from transformers.pipelines import pipeline
+
+    hf_model = pipeline("feature-extraction", model="distilbert-base-cased")
+    embedding_model = HFTransformerBackend(hf_model)
+    ```
+    """
+    def __init__(self, embedding_model: Pipeline):
+        super().__init__()
+
+        if isinstance(embedding_model, Pipeline):
+            self.embedding_model = embedding_model
+        else:
+            raise ValueError("Please select a correct transformers pipeline. For example: "
+                             "pipeline('feature-extraction', model='distilbert-base-cased', device=0)")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        dataset = MyDataset(documents)
+
+        embeddings = []
+        for document, features in tqdm(zip(documents, self.embedding_model(dataset, truncation=True, padding=True)),
+                                       total=len(dataset), disable=not verbose):
+            embeddings.append(self._embed(document, features))
+
+        return np.array(embeddings)
+
+    def _embed(self,
+               document: str,
+               features: np.ndarray) -> np.ndarray:
+        """ Mean pooling
+
+        Arguments:
+            document: The document for which to extract the attention mask
+            features: The embeddings for each token
+
+        Adapted from:
+        https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2#usage-huggingface-transformers
+        """
+        token_embeddings = np.array(features)
+        attention_mask = self.embedding_model.tokenizer(document, truncation=True, padding=True, return_tensors="np")["attention_mask"]
+        input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), token_embeddings.shape)
+        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, 1)
+        sum_mask = np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=input_mask_expanded.sum(1).max())
+        embedding = normalize(sum_embeddings / sum_mask)[0]
+        return embedding
+
+
+class MyDataset(Dataset):
+    """ Dataset to pass to `transformers.pipelines.pipeline` """
+    def __init__(self, docs):
+        self.docs = docs
+
+    def __len__(self):
+        return len(self.docs)
+
+    def __getitem__(self, idx):
+        return self.docs[idx]
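The `_embed` method implements attention-masked mean pooling: padding tokens are zeroed out before summing, the sum is divided by the number of real tokens, and the result is L2-normalized. A worked toy example of the same arithmetic:

```python
import numpy as np
from sklearn.preprocessing import normalize

# One document, three tokens, 2-dim embeddings; the last token is padding
token_embeddings = np.array([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
attention_mask = np.array([[1, 1, 0]])

mask = np.broadcast_to(np.expand_dims(attention_mask, -1), token_embeddings.shape)
summed = np.sum(token_embeddings * mask, 1)            # [[4. 6.]] - padding ignored
counts = np.clip(mask.sum(1), a_min=1e-9, a_max=None)  # [[2. 2.]] - real token count
print(normalize(summed / counts)[0])                   # unit-length mean, ~[0.5547 0.8321]
```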
BERTopic/bertopic/backend/_multimodal.py
ADDED
@@ -0,0 +1,194 @@
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from typing import List, Union
+from sentence_transformers import SentenceTransformer
+
+from bertopic.backend import BaseEmbedder
+
+
+class MultiModalBackend(BaseEmbedder):
+    """ Multimodal backend using Sentence-transformers
+
+    The sentence-transformers embedding model used for
+    generating word, document, and image embeddings.
+
+    Arguments:
+        embedding_model: A sentence-transformers embedding model that
+                         can either embed both images and text or only text.
+                         If it only embeds text, then `image_model` needs
+                         to be used to embed the images.
+        image_model: A sentence-transformers embedding model that is used
+                     to embed only images.
+        batch_size: The size of the image batches to pass
+
+    Examples:
+
+    To create a model, you can load in a string pointing to a
+    sentence-transformers model:
+
+    ```python
+    from bertopic.backend import MultiModalBackend
+
+    sentence_model = MultiModalBackend("clip-ViT-B-32")
+    ```
+
+    or you can instantiate a model yourself:
+
+    ```python
+    from bertopic.backend import MultiModalBackend
+    from sentence_transformers import SentenceTransformer
+
+    embedding_model = SentenceTransformer("clip-ViT-B-32")
+    sentence_model = MultiModalBackend(embedding_model)
+    ```
+    """
+    def __init__(self,
+                 embedding_model: Union[str, SentenceTransformer],
+                 image_model: Union[str, SentenceTransformer] = None,
+                 batch_size: int = 32):
+        super().__init__()
+        self.batch_size = batch_size
+
+        # Text or Text+Image model
+        if isinstance(embedding_model, SentenceTransformer):
+            self.embedding_model = embedding_model
+        elif isinstance(embedding_model, str):
+            self.embedding_model = SentenceTransformer(embedding_model)
+        else:
+            raise ValueError("Please select a correct SentenceTransformers model: \n"
+                             "`from sentence_transformers import SentenceTransformer` \n"
+                             "`model = SentenceTransformer('clip-ViT-B-32')`")
+
+        # Image model
+        self.image_model = None
+        if image_model is not None:
+            if isinstance(image_model, SentenceTransformer):
+                self.image_model = image_model
+            elif isinstance(image_model, str):
+                self.image_model = SentenceTransformer(image_model)
+            else:
+                raise ValueError("Please select a correct SentenceTransformers model: \n"
+                                 "`from sentence_transformers import SentenceTransformer` \n"
+                                 "`model = SentenceTransformer('clip-ViT-B-32')`")
+
+        try:
+            self.tokenizer = self.embedding_model._first_module().processor.tokenizer
+        except AttributeError:
+            self.tokenizer = self.embedding_model.tokenizer
+        except Exception:
+            self.tokenizer = None
+
+    def embed(self,
+              documents: List[str],
+              images: List[str] = None,
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            images: A list of paths to the images to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        # Embed documents
+        doc_embeddings = None
+        if documents[0] is not None:
+            doc_embeddings = self.embed_documents(documents)
+
+        # Embed images
+        image_embeddings = None
+        if isinstance(images, list):
+            image_embeddings = self.embed_images(images, verbose)
+
+        # Average embeddings
+        averaged_embeddings = None
+        if doc_embeddings is not None and image_embeddings is not None:
+            averaged_embeddings = np.mean([doc_embeddings, image_embeddings], axis=0)
+
+        if averaged_embeddings is not None:
+            return averaged_embeddings
+        elif doc_embeddings is not None:
+            return doc_embeddings
+        elif image_embeddings is not None:
+            return image_embeddings
+
+    def embed_documents(self,
+                        documents: List[str],
+                        verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        truncated_docs = [self._truncate_document(doc) for doc in documents]
+        embeddings = self.embedding_model.encode(truncated_docs, show_progress_bar=verbose)
+        return embeddings
+
+    def embed_words(self, words: List[str], verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            words: A list of words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Word embeddings with shape (n, m) with `n` words
+            that each have an embedding size of `m`
+        """
+        embeddings = self.embedding_model.encode(words, show_progress_bar=verbose)
+        return embeddings
+
+    def embed_images(self, images, verbose):
+        if self.batch_size:
+            nr_iterations = int(np.ceil(len(images) / self.batch_size))
+
+            # Embed images per batch
+            embeddings = []
+            for i in tqdm(range(nr_iterations), disable=not verbose):
+                start_index = i * self.batch_size
+                end_index = (i * self.batch_size) + self.batch_size
+
+                images_to_embed = [Image.open(image) if isinstance(image, str) else image for image in images[start_index:end_index]]
+                if self.image_model is not None:
+                    img_emb = self.image_model.encode(images_to_embed)
+                else:
+                    img_emb = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
+                embeddings.extend(img_emb.tolist())
+
+                # Close images
+                if isinstance(images[0], str):
+                    for image in images_to_embed:
+                        image.close()
+            embeddings = np.array(embeddings)
+        else:
+            images_to_embed = [Image.open(filepath) for filepath in images]
+            if self.image_model is not None:
+                embeddings = self.image_model.encode(images_to_embed)
+            else:
+                embeddings = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
+        return embeddings
+
+    def _truncate_document(self, document):
+        if self.tokenizer:
+            tokens = self.tokenizer.encode(document)
+
+            if len(tokens) > 77:
+                # Skip the starting token, only include 75 tokens
+                truncated_tokens = tokens[1:76]
+                document = self.tokenizer.decode(truncated_tokens)
+
+                # Recursive call here, because encode(decode()) can give a different result
+                return self._truncate_document(document)
+
+        return document
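When both texts and images are given, `embed` fuses the two modalities with an element-wise mean of the embedding matrices; this makes sense because a CLIP-style model places both in the same vector space. The fusion step, standalone with dummy matrices:

```python
import numpy as np

# Dummy document and image embeddings with matching shapes
doc_embeddings = np.array([[1.0, 0.0], [0.0, 1.0]])
image_embeddings = np.array([[0.0, 1.0], [1.0, 0.0]])
print(np.mean([doc_embeddings, image_embeddings], axis=0))
# [[0.5 0.5]
#  [0.5 0.5]]
```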
BERTopic/bertopic/backend/_openai.py
ADDED
@@ -0,0 +1,88 @@
+import time
+import openai
+import numpy as np
+from tqdm import tqdm
+from typing import List, Mapping, Any
+from bertopic.backend import BaseEmbedder
+
+
+class OpenAIBackend(BaseEmbedder):
+    """ OpenAI Embedding Model
+
+    Arguments:
+        client: An `openai.OpenAI` client.
+        embedding_model: An OpenAI model. Default is "text-embedding-ada-002".
+                         For an overview of models see:
+                         https://platform.openai.com/docs/models/embeddings
+        delay_in_seconds: If a `batch_size` is given, use this to set
+                          the delay in seconds between batches.
+        batch_size: The size of each batch.
+        generator_kwargs: Kwargs passed to `openai.Embedding.create`.
+                          Can be used to define custom engines or
+                          deployment_ids.
+
+    Examples:
+
+    ```python
+    import openai
+    from bertopic.backend import OpenAIBackend
+
+    client = openai.OpenAI(api_key="sk-...")
+    openai_embedder = OpenAIBackend(client, "text-embedding-ada-002")
+    ```
+    """
+    def __init__(self,
+                 client: openai.OpenAI,
+                 embedding_model: str = "text-embedding-ada-002",
+                 delay_in_seconds: float = None,
+                 batch_size: int = None,
+                 generator_kwargs: Mapping[str, Any] = {}):
+        super().__init__()
+        self.client = client
+        self.embedding_model = embedding_model
+        self.delay_in_seconds = delay_in_seconds
+        self.batch_size = batch_size
+        self.generator_kwargs = generator_kwargs
+
+        if self.generator_kwargs.get("model"):
+            self.embedding_model = generator_kwargs.get("model")
+        elif not self.generator_kwargs.get("engine"):
+            self.generator_kwargs["model"] = self.embedding_model
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        # Prepare documents, replacing empty strings with a single space
+        prepared_documents = [" " if doc == "" else doc for doc in documents]
+
+        # Batch-wise embedding extraction
+        if self.batch_size is not None:
+            embeddings = []
+            for batch in tqdm(self._chunks(prepared_documents), disable=not verbose):
+                response = self.client.embeddings.create(input=batch, **self.generator_kwargs)
+                embeddings.extend([r.embedding for r in response.data])
+
+                # Delay subsequent calls
+                if self.delay_in_seconds:
+                    time.sleep(self.delay_in_seconds)
+
+        # Extract embeddings all at once
+        else:
+            response = self.client.embeddings.create(input=prepared_documents, **self.generator_kwargs)
+            embeddings = [r.embedding for r in response.data]
+        return np.array(embeddings)
+
+    def _chunks(self, documents):
+        for i in range(0, len(documents), self.batch_size):
+            yield documents[i:i + self.batch_size]
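For large corpora, the batching parameters above can be combined to stay under API rate limits; a usage sketch (the API key is a placeholder):

```python
import openai
from bertopic.backend import OpenAIBackend

client = openai.OpenAI(api_key="sk-...")  # placeholder key

# Send 32 documents per request and pause 10 seconds between requests
openai_embedder = OpenAIBackend(client, "text-embedding-ada-002",
                                batch_size=32, delay_in_seconds=10)
```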
BERTopic/bertopic/backend/_sentencetransformers.py
ADDED
@@ -0,0 +1,66 @@
+import numpy as np
+from typing import List, Union
+from sentence_transformers import SentenceTransformer
+
+from bertopic.backend import BaseEmbedder
+
+
+class SentenceTransformerBackend(BaseEmbedder):
+    """ Sentence-transformers embedding model
+
+    The sentence-transformers embedding model used for generating document and
+    word embeddings.
+
+    Arguments:
+        embedding_model: A sentence-transformers embedding model
+
+    Examples:
+
+    To create a model, you can load in a string pointing to a
+    sentence-transformers model:
+
+    ```python
+    from bertopic.backend import SentenceTransformerBackend
+
+    sentence_model = SentenceTransformerBackend("all-MiniLM-L6-v2")
+    ```
+
+    or you can instantiate a model yourself:
+    ```python
+    from bertopic.backend import SentenceTransformerBackend
+    from sentence_transformers import SentenceTransformer
+
+    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+    sentence_model = SentenceTransformerBackend(embedding_model)
+    ```
+    """
+    def __init__(self, embedding_model: Union[str, SentenceTransformer]):
+        super().__init__()
+
+        self._hf_model = None
+        if isinstance(embedding_model, SentenceTransformer):
+            self.embedding_model = embedding_model
+        elif isinstance(embedding_model, str):
+            self.embedding_model = SentenceTransformer(embedding_model)
+            self._hf_model = embedding_model
+        else:
+            raise ValueError("Please select a correct SentenceTransformers model: \n"
+                             "`from sentence_transformers import SentenceTransformer` \n"
+                             "`model = SentenceTransformer('all-MiniLM-L6-v2')`")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
+        return embeddings
BERTopic/bertopic/backend/_sklearn.py
ADDED
@@ -0,0 +1,68 @@
+from bertopic.backend import BaseEmbedder
+from sklearn.utils.validation import check_is_fitted, NotFittedError
+
+
+class SklearnEmbedder(BaseEmbedder):
+    """ Scikit-learn based embedding model
+
+    This component allows the usage of scikit-learn pipelines for generating document and
+    word embeddings.
+
+    Arguments:
+        pipe: A scikit-learn pipeline that can `.transform()` text.
+
+    Examples:
+
+    Scikit-learn is very flexible and allows for many representations.
+    A relatively simple pipeline is shown below.
+
+    ```python
+    from sklearn.pipeline import make_pipeline
+    from sklearn.decomposition import TruncatedSVD
+    from sklearn.feature_extraction.text import TfidfVectorizer
+
+    from bertopic.backend import SklearnEmbedder
+
+    pipe = make_pipeline(
+        TfidfVectorizer(),
+        TruncatedSVD(100)
+    )
+
+    sklearn_embedder = SklearnEmbedder(pipe)
+    topic_model = BERTopic(embedding_model=sklearn_embedder)
+    ```
+
+    This pipeline first constructs a sparse representation based on TF-IDF and then
+    makes it dense by applying SVD. Alternatively, you might also construct something
+    more elaborate. As long as you construct a scikit-learn compatible pipeline, you
+    should be able to pass it to BERTopic.
+
+    !!! Warning
+        One caveat to be aware of is that scikit-learn's base `Pipeline` class does not
+        support the `.partial_fit()` API. If you have a pipeline that theoretically should
+        be able to support online learning, then you might want to explore
+        the [scikit-partial](https://github.com/koaning/scikit-partial) project.
+    """
+    def __init__(self, pipe):
+        super().__init__()
+        self.pipe = pipe
+
+    def embed(self, documents, verbose=False):
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: No-op variable that is kept around to keep the API consistent.
+                     If you want feedback on training times, use the sklearn API directly.
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        try:
+            check_is_fitted(self.pipe)
+            embeddings = self.pipe.transform(documents)
+        except NotFittedError:
+            embeddings = self.pipe.fit_transform(documents)
+
+        return embeddings
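Because of the `check_is_fitted` guard, the first call to `embed` fits the pipeline and later calls only transform, so repeated embedding keeps a stable vector space. A small sketch:

```python
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from bertopic.backend._sklearn import SklearnEmbedder

pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(2))
embedder = SklearnEmbedder(pipe)

docs = ["topic modeling with sklearn", "dense vectors from sparse tf-idf",
        "pipelines are flexible", "svd makes sparse matrices dense"]
first = embedder.embed(docs)   # pipeline unfitted -> fit_transform
second = embedder.embed(docs)  # already fitted -> transform only
print(first.shape, second.shape)  # (4, 2) (4, 2)
```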
BERTopic/bertopic/backend/_spacy.py
ADDED
@@ -0,0 +1,94 @@
+import numpy as np
+from tqdm import tqdm
+from typing import List
+from bertopic.backend import BaseEmbedder
+
+
+class SpacyBackend(BaseEmbedder):
+    """ Spacy embedding model
+
+    The Spacy embedding model used for generating document and
+    word embeddings.
+
+    Arguments:
+        embedding_model: A spacy embedding model
+
+    Examples:
+
+    To create a Spacy backend, you need to create an nlp object and
+    pass it through this backend:
+
+    ```python
+    import spacy
+    from bertopic.backend import SpacyBackend
+
+    nlp = spacy.load("en_core_web_md", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer'])
+    spacy_model = SpacyBackend(nlp)
+    ```
+
+    To load in a transformer model use the following:
+
+    ```python
+    import spacy
+    from thinc.api import set_gpu_allocator, require_gpu
+    from bertopic.backend import SpacyBackend
+
+    nlp = spacy.load("en_core_web_trf", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer'])
+    set_gpu_allocator("pytorch")
+    require_gpu(0)
+    spacy_model = SpacyBackend(nlp)
+    ```
+
+    If you run into gpu/memory issues, please use:
+
+    ```python
+    import spacy
+    from bertopic.backend import SpacyBackend
+
+    spacy.prefer_gpu()
+    nlp = spacy.load("en_core_web_trf", exclude=['tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer'])
+    spacy_model = SpacyBackend(nlp)
+    ```
+    """
+    def __init__(self, embedding_model):
+        super().__init__()
+
+        if "spacy" in str(type(embedding_model)):
+            self.embedding_model = embedding_model
+        else:
+            raise ValueError("Please select a correct Spacy model by either using a string such as 'en_core_web_md' "
+                             "or create an nlp model using: `nlp = spacy.load('en_core_web_md')`")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        # Handle empty documents; spaCy models automatically map
+        # empty strings to the zero vector
+        empty_document = " "
+
+        # Extract embeddings
+        embeddings = []
+        for doc in tqdm(documents, position=0, leave=True, disable=not verbose):
+            embedding = self.embedding_model(doc or empty_document)
+            if embedding.has_vector:
+                embedding = embedding.vector
+            else:
+                embedding = embedding._.trf_data.tensors[-1][0]
+
+            if not isinstance(embedding, np.ndarray) and hasattr(embedding, 'get'):
+                # Convert cupy array to numpy array
+                embedding = embedding.get()
+            embeddings.append(embedding)
+
+        return np.array(embeddings)
BERTopic/bertopic/backend/_use.py
ADDED
@@ -0,0 +1,58 @@
+import numpy as np
+from tqdm import tqdm
+from typing import List
+
+from bertopic.backend import BaseEmbedder
+
+
+class USEBackend(BaseEmbedder):
+    """ Universal Sentence Encoder
+
+    USE encodes text into high-dimensional vectors that
+    are used for semantic similarity in BERTopic.
+
+    Arguments:
+        embedding_model: A USE embedding model
+
+    Examples:
+
+    ```python
+    import tensorflow_hub
+    from bertopic.backend import USEBackend
+
+    embedding_model = tensorflow_hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
+    use_embedder = USEBackend(embedding_model)
+    ```
+    """
+    def __init__(self, embedding_model):
+        super().__init__()
+
+        try:
+            embedding_model(["test sentence"])
+            self.embedding_model = embedding_model
+        except TypeError:
+            raise ValueError("Please select a correct USE model: \n"
+                             "`import tensorflow_hub` \n"
+                             "`embedding_model = tensorflow_hub.load(path_to_model)`")
+
+    def embed(self,
+              documents: List[str],
+              verbose: bool = False) -> np.ndarray:
+        """ Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/word embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        embeddings = np.array(
+            [
+                self.embedding_model([doc]).cpu().numpy()[0]
+                for doc in tqdm(documents, disable=not verbose)
+            ]
+        )
+        return embeddings
BERTopic/bertopic/backend/_utils.py
ADDED
@@ -0,0 +1,135 @@
from ._base import BaseEmbedder

# Imports for light-weight variant of BERTopic
from bertopic.backend._sklearn import SklearnEmbedder
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline as ScikitPipeline


languages = [
    "arabic",
    "bulgarian",
    "catalan",
    "czech",
    "danish",
    "german",
    "greek",
    "english",
    "spanish",
    "estonian",
    "persian",
    "finnish",
    "french",
    "canadian french",
    "galician",
    "gujarati",
    "hebrew",
    "hindi",
    "croatian",
    "hungarian",
    "armenian",
    "indonesian",
    "italian",
    "japanese",
    "georgian",
    "korean",
    "kurdish",
    "lithuanian",
    "latvian",
    "macedonian",
    "mongolian",
    "marathi",
    "malay",
    "burmese",
    "norwegian bokmal",
    "dutch",
    "polish",
    "portuguese",
    "brazilian portuguese",
    "romanian",
    "russian",
    "slovak",
    "slovenian",
    "albanian",
    "serbian",
    "swedish",
    "thai",
    "turkish",
    "ukrainian",
    "urdu",
    "vietnamese",
    "chinese (simplified)",
    "chinese (traditional)",
]


def select_backend(embedding_model,
                   language: str = None) -> BaseEmbedder:
    """ Select an embedding model based on language or a specific sentence transformers model.

    When selecting a language, we choose all-MiniLM-L6-v2 for English and
    paraphrase-multilingual-MiniLM-L12-v2 for all other languages as it supports 100+ languages.

    Returns:
        model: Either a Sentence-Transformer or Flair model
    """
    # BERTopic language backend
    if isinstance(embedding_model, BaseEmbedder):
        return embedding_model

    # Scikit-learn backend
    if isinstance(embedding_model, ScikitPipeline):
        return SklearnEmbedder(embedding_model)

    # Flair word embeddings
    if "flair" in str(type(embedding_model)):
        from bertopic.backend._flair import FlairBackend
        return FlairBackend(embedding_model)

    # Spacy embeddings
    if "spacy" in str(type(embedding_model)):
        from bertopic.backend._spacy import SpacyBackend
        return SpacyBackend(embedding_model)

    # Gensim embeddings
    if "gensim" in str(type(embedding_model)):
        from bertopic.backend._gensim import GensimBackend
        return GensimBackend(embedding_model)

    # USE embeddings
    # Note: both substrings need an explicit `in` check; a bare
    # `"tensorflow" and "saved_model" in ...` would only test "saved_model"
    if "tensorflow" in str(type(embedding_model)) and "saved_model" in str(type(embedding_model)):
        from bertopic.backend._use import USEBackend
        return USEBackend(embedding_model)

    # Sentence Transformer embeddings
    if "sentence_transformers" in str(type(embedding_model)) or isinstance(embedding_model, str):
        from ._sentencetransformers import SentenceTransformerBackend
        return SentenceTransformerBackend(embedding_model)

    # Hugging Face embeddings
    if "transformers" in str(type(embedding_model)) and "pipeline" in str(type(embedding_model)):
        from ._hftransformers import HFTransformerBackend
        return HFTransformerBackend(embedding_model)

    # Select embedding model based on language
    if language:
        try:
            from ._sentencetransformers import SentenceTransformerBackend
            if language.lower() in ["english", "en"]:
                return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")
            elif language.lower() in languages or language == "multilingual":
                return SentenceTransformerBackend("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
            else:
                raise ValueError(f"{language} is currently not supported. However, you can "
                                 f"create any embeddings yourself and pass it through fit_transform(docs, embeddings)\n"
                                 "Else, please select a language from the following list:\n"
                                 f"{languages}")

        # Only for light-weight installation
        except ModuleNotFoundError:
            pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(100))
            return SklearnEmbedder(pipe)

    from ._sentencetransformers import SentenceTransformerBackend
    return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2")
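To make the dispatch order above concrete, a small sketch of two routes through `select_backend`; the pipeline components mirror the light-weight fallback used in the `ModuleNotFoundError` branch:

```python
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from bertopic.backend._utils import select_backend

# Scikit-learn route: any Pipeline instance is wrapped in SklearnEmbedder
pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(100))
backend = select_backend(pipe)

# Language route: "english" maps to all-MiniLM-L6-v2 when
# sentence-transformers is installed, otherwise to the TF-IDF fallback
backend = select_backend(None, language="english")
```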
BERTopic/bertopic/backend/_word_doc.py
ADDED
@@ -0,0 +1,49 @@
import numpy as np
from typing import List
from bertopic.backend._base import BaseEmbedder
from bertopic.backend._utils import select_backend


class WordDocEmbedder(BaseEmbedder):
    """ Combine a document- and word-level embedder
    """
    def __init__(self,
                 embedding_model,
                 word_embedding_model):
        super().__init__()

        self.embedding_model = select_backend(embedding_model)
        self.word_embedding_model = select_backend(word_embedding_model)

    def embed_words(self,
                    words: List[str],
                    verbose: bool = False) -> np.ndarray:
        """ Embed a list of n words into an n-dimensional
        matrix of embeddings

        Arguments:
            words: A list of words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Word embeddings with shape (n, m) with `n` words
            that each have an embeddings size of `m`

        """
        return self.word_embedding_model.embed(words, verbose)

    def embed_documents(self,
                        document: List[str],
                        verbose: bool = False) -> np.ndarray:
        """ Embed a list of n documents into an n-dimensional
        matrix of embeddings

        Arguments:
            document: A list of documents to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document embeddings with shape (n, m) with `n` documents
            that each have an embeddings size of `m`
        """
        return self.embedding_model.embed(document, verbose)
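A hedged sketch of `WordDocEmbedder` in use; the model name is illustrative, and anything `select_backend` accepts (a string, pipeline, or backend instance) works for either argument:

```python
from bertopic.backend._word_doc import WordDocEmbedder

word_doc_embedder = WordDocEmbedder(
    embedding_model="all-MiniLM-L6-v2",       # document-level embedder
    word_embedding_model="all-MiniLM-L6-v2",  # word-level embedder
)
doc_embeddings = word_doc_embedder.embed_documents(["an example document"])
word_embeddings = word_doc_embedder.embed_words(["example", "document"])
```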
BERTopic/bertopic/cluster/__init__.py
ADDED
@@ -0,0 +1,5 @@
from ._base import BaseCluster

__all__ = [
    "BaseCluster",
]
BERTopic/bertopic/cluster/_base.py
ADDED
@@ -0,0 +1,41 @@
import numpy as np


class BaseCluster:
    """ The Base Cluster class

    Using this class directly in BERTopic will make it skip
    over the cluster step. As a result, topics need to be passed
    to BERTopic in the form of its `y` parameter in order to create
    topic representations.

    Examples:

    This will skip over the cluster step in BERTopic:

    ```python
    from bertopic import BERTopic
    from bertopic.cluster import BaseCluster

    empty_cluster_model = BaseCluster()

    topic_model = BERTopic(hdbscan_model=empty_cluster_model)
    ```

    Then, this class can be used to perform manual topic modeling.
    That is, topic modeling on topics that were already generated before
    without the need to learn them:

    ```python
    topic_model.fit(docs, y=y)
    ```
    """
    def fit(self, X, y=None):
        if y is not None:
            self.labels_ = y
        else:
            self.labels_ = None
        return self

    def transform(self, X: np.ndarray) -> np.ndarray:
        return X
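Expanding on the docstring, a minimal end-to-end sketch of manual topic modeling with `BaseCluster`; the documents and labels below are illustrative stand-ins:

```python
from bertopic import BERTopic
from bertopic.cluster import BaseCluster

# Illustrative corpus with a known, pre-assigned topic per document
docs = [f"sports article number {i}" for i in range(50)] + \
       [f"politics article number {i}" for i in range(50)]
y = [0] * 50 + [1] * 50

# BaseCluster skips clustering, so the labels in `y` are taken over as-is
topic_model = BERTopic(hdbscan_model=BaseCluster())
topics, probs = topic_model.fit_transform(docs, y=y)
```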
BERTopic/bertopic/cluster/_utils.py
ADDED
@@ -0,0 +1,70 @@
import hdbscan
import numpy as np


def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
    """ Function used to select the HDBSCAN-like model for generating
    predictions and probabilities.

    Arguments:
        model: The cluster model.
        func: The function to use. Options:
                - "approximate_predict"
                - "all_points_membership_vectors"
                - "membership_vector"
        embeddings: Input embeddings for "approximate_predict"
                    and "membership_vector"
    """

    # Approximate predict
    if func == "approximate_predict":
        if isinstance(model, hdbscan.HDBSCAN):
            predictions, probabilities = hdbscan.approximate_predict(model, embeddings)
            return predictions, probabilities

        str_type_model = str(type(model)).lower()
        if "cuml" in str_type_model and "hdbscan" in str_type_model:
            from cuml.cluster import hdbscan as cuml_hdbscan
            predictions, probabilities = cuml_hdbscan.approximate_predict(model, embeddings)
            return predictions, probabilities

        predictions = model.predict(embeddings)
        return predictions, None

    # All points membership
    if func == "all_points_membership_vectors":
        if isinstance(model, hdbscan.HDBSCAN):
            return hdbscan.all_points_membership_vectors(model)

        str_type_model = str(type(model)).lower()
        if "cuml" in str_type_model and "hdbscan" in str_type_model:
            from cuml.cluster import hdbscan as cuml_hdbscan
            return cuml_hdbscan.all_points_membership_vectors(model)

        return None

    # membership_vector
    if func == "membership_vector":
        if isinstance(model, hdbscan.HDBSCAN):
            probabilities = hdbscan.membership_vector(model, embeddings)
            return probabilities

        str_type_model = str(type(model)).lower()
        if "cuml" in str_type_model and "hdbscan" in str_type_model:
            from cuml.cluster.hdbscan.prediction import approximate_predict
            probabilities = approximate_predict(model, embeddings)
            return probabilities

        return None


def is_supported_hdbscan(model):
    """ Check whether the input model is a supported HDBSCAN-like model """
    if isinstance(model, hdbscan.HDBSCAN):
        return True

    str_type_model = str(type(model)).lower()
    if "cuml" in str_type_model and "hdbscan" in str_type_model:
        return True

    return False
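A short sketch of the delegator with the reference `hdbscan` package on random data; note that `approximate_predict` requires a model fitted with `prediction_data=True`:

```python
import hdbscan
import numpy as np
from bertopic.cluster._utils import hdbscan_delegator, is_supported_hdbscan

X = np.random.rand(100, 5)
model = hdbscan.HDBSCAN(min_cluster_size=5, prediction_data=True).fit(X)

if is_supported_hdbscan(model):
    # Routed to hdbscan.approximate_predict; a cuML model would be
    # routed to cuml.cluster.hdbscan instead
    predictions, probabilities = hdbscan_delegator(model, "approximate_predict", X)
```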
BERTopic/bertopic/dimensionality/__init__.py
ADDED
@@ -0,0 +1,5 @@
from ._base import BaseDimensionalityReduction

__all__ = [
    "BaseDimensionalityReduction",
]
BERTopic/bertopic/dimensionality/_base.py
ADDED
@@ -0,0 +1,26 @@
import numpy as np


class BaseDimensionalityReduction:
    """ The Base Dimensionality Reduction class

    You can use this to skip over the dimensionality reduction step in BERTopic.

    Examples:

    This will skip over the reduction step in BERTopic:

    ```python
    from bertopic import BERTopic
    from bertopic.dimensionality import BaseDimensionalityReduction

    empty_reduction_model = BaseDimensionalityReduction()

    topic_model = BERTopic(umap_model=empty_reduction_model)
    ```
    """
    def fit(self, X: np.ndarray = None):
        return self

    def transform(self, X: np.ndarray) -> np.ndarray:
        return X
|
BERTopic/bertopic/plotting/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._topics import visualize_topics
|
| 2 |
+
from ._heatmap import visualize_heatmap
|
| 3 |
+
from ._barchart import visualize_barchart
|
| 4 |
+
from ._documents import visualize_documents
|
| 5 |
+
from ._term_rank import visualize_term_rank
|
| 6 |
+
from ._hierarchy import visualize_hierarchy
|
| 7 |
+
from ._datamap import visualize_document_datamap
|
| 8 |
+
from ._distribution import visualize_distribution
|
| 9 |
+
from ._topics_over_time import visualize_topics_over_time
|
| 10 |
+
from ._topics_per_class import visualize_topics_per_class
|
| 11 |
+
from ._hierarchical_documents import visualize_hierarchical_documents
|
| 12 |
+
from ._approximate_distribution import visualize_approximate_distribution
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"visualize_topics",
|
| 17 |
+
"visualize_heatmap",
|
| 18 |
+
"visualize_barchart",
|
| 19 |
+
"visualize_documents",
|
| 20 |
+
"visualize_term_rank",
|
| 21 |
+
"visualize_hierarchy",
|
| 22 |
+
"visualize_distribution",
|
| 23 |
+
"visualize_document_datamap",
|
| 24 |
+
"visualize_topics_over_time",
|
| 25 |
+
"visualize_topics_per_class",
|
| 26 |
+
"visualize_hierarchical_documents",
|
| 27 |
+
"visualize_approximate_distribution"
|
| 28 |
+
]
|
BERTopic/bertopic/plotting/_approximate_distribution.py
ADDED
@@ -0,0 +1,99 @@
import numpy as np
import pandas as pd

try:
    from pandas.io.formats.style import Styler
    HAS_JINJA = True
except (ModuleNotFoundError, ImportError):
    HAS_JINJA = False


def visualize_approximate_distribution(topic_model,
                                       document: str,
                                       topic_token_distribution: np.ndarray,
                                       normalize: bool = False):
    """ Visualize the topic distribution calculated by `.approximate_topic_distribution`
    on a token level. Thereby indicating the extent to which a certain word or phrase belongs
    to a specific topic. The assumption here is that a single word can belong to multiple
    similar topics and as such give information about the broader set of topics within
    a single document.

    NOTE:
        This function will return a stylized pandas dataframe if Jinja2 is installed. If not,
        it will only return a pandas dataframe without color highlighting. To install jinja:

        `pip install jinja2`

    Arguments:
        topic_model: A fitted BERTopic instance.
        document: The document for which you want to visualize
                  the approximated topic distribution.
        topic_token_distribution: The topic-token distribution of the document as
                                  extracted by `.approximate_topic_distribution`
        normalize: Whether to normalize, between 0 and 1 (summing to 1), the
                   topic distribution values.

    Returns:
        df: A stylized dataframe indicating the best fitting topics
            for each token.

    Examples:

    ```python
    # Calculate the topic distributions on a token level
    # Note that we need to have `calculate_token_level=True`
    topic_distr, topic_token_distr = topic_model.approximate_distribution(
        docs, calculate_token_level=True
    )

    # Visualize the approximated topic distributions
    df = topic_model.visualize_approximate_distribution(docs[0], topic_token_distr[0])
    df
    ```

    To revert this stylized dataframe back to a regular dataframe,
    you can run the following:

    ```python
    df.data.columns = [column.strip() for column in df.data.columns]
    df = df.data
    ```
    """
    # Tokenize document
    analyzer = topic_model.vectorizer_model.build_tokenizer()
    tokens = analyzer(document)

    if len(tokens) == 0:
        raise ValueError("Make sure that your document contains at least 1 token.")

    # Prepare dataframe with results
    if normalize:
        df = pd.DataFrame(topic_token_distribution / topic_token_distribution.sum()).T
    else:
        df = pd.DataFrame(topic_token_distribution).T

    # Pad duplicate tokens with an increasing number of spaces so that
    # every column name is unique while still displaying the same token
    df.columns = [f"{token}{' '*i}" for i, token in enumerate(tokens)]
    df.index = list(topic_model.topic_labels_.values())[topic_model._outliers:]
    df = df.loc[(df.sum(axis=1) != 0), :]

    # Style the resulting dataframe
    def text_color(val):
        color = 'white' if val == 0 else 'black'
        return 'color: %s' % color

    def highlight_color(data, color='white'):
        attr = 'background-color: {}'.format(color)
        return pd.DataFrame(np.where(data == 0, attr, ''), index=data.index, columns=data.columns)

    if len(df) == 0:
        return df
    elif HAS_JINJA:
        df = (
            df.style
            .format("{:.3f}")
            .background_gradient(cmap='Blues', axis=None)
            .applymap(lambda x: text_color(x))
            .apply(highlight_color, axis=None)
        )
    return df
BERTopic/bertopic/plotting/_barchart.py
ADDED
@@ -0,0 +1,127 @@
import itertools
import numpy as np
from typing import List, Union

import plotly.graph_objects as go
from plotly.subplots import make_subplots


def visualize_barchart(topic_model,
                       topics: List[int] = None,
                       top_n_topics: int = 8,
                       n_words: int = 5,
                       custom_labels: Union[bool, str] = False,
                       title: str = "<b>Topic Word Scores</b>",
                       width: int = 250,
                       height: int = 250) -> go.Figure:
    """ Visualize a barchart of selected topics

    Arguments:
        topic_model: A fitted BERTopic instance.
        topics: A selection of topics to visualize.
        top_n_topics: Only select the top n most frequent topics.
        n_words: Number of words to show in a topic
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
        title: Title of the plot.
        width: The width of each figure.
        height: The height of each figure.

    Returns:
        fig: A plotly figure

    Examples:

    To visualize the barchart of selected topics
    simply run:

    ```python
    topic_model.visualize_barchart()
    ```

    Or if you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_barchart()
    fig.write_html("path/to/file.html")
    ```
    <iframe src="../../getting_started/visualization/bar_chart.html"
    style="width:1100px; height: 660px; border: 0px;"></iframe>
    """
    colors = itertools.cycle(["#D55E00", "#0072B2", "#CC79A7", "#E69F00", "#56B4E9", "#009E73", "#F0E442"])

    # Select topics based on top_n and topics args
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        topics = list(topics)
    elif top_n_topics is not None:
        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        topics = sorted(freq_df.Topic.to_list()[0:6])

    # Initialize figure
    if isinstance(custom_labels, str):
        subplot_titles = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topics]
        subplot_titles = ["_".join([label[0] for label in labels[:4]]) for labels in subplot_titles]
        subplot_titles = [label if len(label) < 30 else label[:27] + "..." for label in subplot_titles]
    elif topic_model.custom_labels_ is not None and custom_labels:
        subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topics]
    else:
        subplot_titles = [f"Topic {topic}" for topic in topics]
    columns = 4
    rows = int(np.ceil(len(topics) / columns))
    fig = make_subplots(rows=rows,
                        cols=columns,
                        shared_xaxes=False,
                        horizontal_spacing=.1,
                        vertical_spacing=.4 / rows if rows > 1 else 0,
                        subplot_titles=subplot_titles)

    # Add barchart for each topic
    row = 1
    column = 1
    for topic in topics:
        words = [word + " " for word, _ in topic_model.get_topic(topic)][:n_words][::-1]
        scores = [score for _, score in topic_model.get_topic(topic)][:n_words][::-1]

        fig.add_trace(
            go.Bar(x=scores,
                   y=words,
                   orientation='h',
                   marker_color=next(colors)),
            row=row, col=column)

        if column == columns:
            column = 1
            row += 1
        else:
            column += 1

    # Stylize graph
    fig.update_layout(
        template="plotly_white",
        showlegend=False,
        title={
            'text': f"{title}",
            'x': .5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        width=width*4,
        height=height*rows if rows > 1 else height * 1.3,
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            font_family="Rockwell"
        ),
    )

    fig.update_xaxes(showgrid=True)
    fig.update_yaxes(showgrid=True)

    return fig
BERTopic/bertopic/plotting/_datamap.py
ADDED
@@ -0,0 +1,152 @@
import numpy as np
import pandas as pd
from typing import List, Union
from umap import UMAP
from warnings import warn

try:
    import datamapplot
    from matplotlib.figure import Figure
except ImportError:
    warn("Data map plotting is unavailable unless datamapplot is installed.")

    # Create a dummy figure type for typing
    class Figure(object):
        pass


def visualize_document_datamap(topic_model,
                               docs: List[str],
                               topics: List[int] = None,
                               embeddings: np.ndarray = None,
                               reduced_embeddings: np.ndarray = None,
                               custom_labels: Union[bool, str] = False,
                               title: str = "Documents and Topics",
                               sub_title: Union[str, None] = None,
                               width: int = 1200,
                               height: int = 1200,
                               **datamap_kwds) -> Figure:
    """ Visualize documents and their topics in 2D as a static plot for publication using
    DataMapPlot.

    Arguments:
        topic_model: A fitted BERTopic instance.
        docs: The documents you used when calling either `fit` or `fit_transform`
        topics: A selection of topics to visualize.
                Not to be confused with the topics that you get from `.fit_transform`.
                For example, if you want to visualize only topics 1 through 5:
                `topics = [1, 2, 3, 4, 5]`. Documents not in these topics will be shown
                as noise points.
        embeddings: The embeddings of all documents in `docs`.
        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
        title: Title of the plot.
        sub_title: Sub-title of the plot.
        width: The width of the figure.
        height: The height of the figure.
        **datamap_kwds: All further keyword args will be passed on to DataMapPlot's
                        `create_plot` function. See the DataMapPlot documentation
                        for more details.

    Returns:
        figure: A Matplotlib Figure object.

    Examples:

    To visualize the topics simply run:

    ```python
    topic_model.visualize_document_datamap(docs)
    ```

    Do note that this re-calculates the embeddings and reduces them to 2D.
    The advised and preferred pipeline for using this function is as follows:

    ```python
    from sklearn.datasets import fetch_20newsgroups
    from sentence_transformers import SentenceTransformer
    from bertopic import BERTopic
    from umap import UMAP

    # Prepare embeddings
    docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = sentence_model.encode(docs, show_progress_bar=False)

    # Train BERTopic
    topic_model = BERTopic().fit(docs, embeddings)

    # Reduce dimensionality of embeddings, this step is optional
    # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)

    # Run the visualization with the original embeddings
    topic_model.visualize_document_datamap(docs, embeddings=embeddings)

    # Or, if you have reduced the original embeddings already:
    topic_model.visualize_document_datamap(docs, reduced_embeddings=reduced_embeddings)
    ```

    Or if you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_document_datamap(docs, reduced_embeddings=reduced_embeddings)
    fig.savefig("path/to/file.png", bbox_inches="tight")
    ```
    <img src="../../getting_started/visualization/datamapplot.png"
         alt="DataMapPlot of 20-Newsgroups" width=800 height=800>
    """

    topic_per_doc = topic_model.topics_

    df = pd.DataFrame({"topic": np.array(topic_per_doc)})
    df["doc"] = docs
    df["topic"] = topic_per_doc

    # Extract embeddings if not already done
    if embeddings is None and reduced_embeddings is None:
        embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
    else:
        embeddings_to_reduce = embeddings

    # Reduce input embeddings
    if reduced_embeddings is None:
        umap_model = UMAP(n_neighbors=15, n_components=2, min_dist=0.15, metric='cosine').fit(embeddings_to_reduce)
        embeddings_2d = umap_model.embedding_
    else:
        embeddings_2d = reduced_embeddings

    unique_topics = set(topic_per_doc)

    # Prepare text and names
    if isinstance(custom_labels, str):
        names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in unique_topics]
        names = [" ".join([label[0] for label in labels[:4]]) for labels in names]
        names = [label if len(label) < 30 else label[:27] + "..." for label in names]
    elif topic_model.custom_labels_ is not None and custom_labels:
        names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
    else:
        names = [f"Topic-{topic}: " + " ".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]

    topic_name_mapping = {topic_num: topic_name for topic_num, topic_name in zip(unique_topics, names)}
    topic_name_mapping[-1] = "Unlabelled"

    # If a set of topics is chosen, set everything else to "Unlabelled"
    if topics is not None:
        selected_topics = set(topics)
        for topic_num in topic_name_mapping:
            if topic_num not in selected_topics:
                topic_name_mapping[topic_num] = "Unlabelled"

    # Map in topic names and plot
    named_topic_per_doc = pd.Series(topic_per_doc).map(topic_name_mapping).values

    figure, axes = datamapplot.create_plot(
        embeddings_2d,
        named_topic_per_doc,
        figsize=(width/100, height/100),
        dpi=100,
        title=title,
        sub_title=sub_title,
        **datamap_kwds,
    )

    return figure
BERTopic/bertopic/plotting/_distribution.py
ADDED
@@ -0,0 +1,110 @@
import numpy as np
from typing import Union
import plotly.graph_objects as go


def visualize_distribution(topic_model,
                           probabilities: np.ndarray,
                           min_probability: float = 0.015,
                           custom_labels: Union[bool, str] = False,
                           title: str = "<b>Topic Probability Distribution</b>",
                           width: int = 800,
                           height: int = 600) -> go.Figure:
    """ Visualize the distribution of topic probabilities

    Arguments:
        topic_model: A fitted BERTopic instance.
        probabilities: An array of probability scores
        min_probability: The minimum probability score to visualize.
                         All others are ignored.
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
        title: Title of the plot.
        width: The width of the figure.
        height: The height of the figure.

    Examples:

    Make sure to fit the model beforehand and only input the
    probabilities of a single document:

    ```python
    topic_model.visualize_distribution(probabilities[0])
    ```

    Or if you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_distribution(probabilities[0])
    fig.write_html("path/to/file.html")
    ```
    <iframe src="../../getting_started/visualization/probabilities.html"
    style="width:1000px; height: 500px; border: 0px;"></iframe>
    """
    if len(probabilities.shape) != 1:
        raise ValueError("This visualization cannot be used if you have set `calculate_probabilities` to False "
                         "as it uses the topic probabilities of all topics.")
    if len(probabilities[probabilities > min_probability]) == 0:
        raise ValueError("There are no values where `min_probability` is higher than the "
                         "probabilities that were supplied. Lower `min_probability` to prevent this error.")

    # Get values and indices equal to or exceeding the minimum probability
    labels_idx = np.argwhere(probabilities >= min_probability).flatten()
    vals = probabilities[labels_idx].tolist()

    # Create labels
    if isinstance(custom_labels, str):
        labels = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in labels_idx]
        labels = ["_".join([label[0] for label in lbl[:4]]) for lbl in labels]
        labels = [label if len(label) < 30 else label[:27] + "..." for label in labels]
    elif topic_model.custom_labels_ is not None and custom_labels:
        labels = [topic_model.custom_labels_[idx + topic_model._outliers] for idx in labels_idx]
    else:
        labels = []
        for idx in labels_idx:
            words = topic_model.get_topic(idx)
            if words:
                label = [word[0] for word in words[:5]]
                label = f"<b>Topic {idx}</b>: {'_'.join(label)}"
                label = label[:40] + "..." if len(label) > 40 else label
                labels.append(label)
            else:
                vals.remove(probabilities[idx])

    # Create Figure
    fig = go.Figure(go.Bar(
        x=vals,
        y=labels,
        marker=dict(
            color='#C8D2D7',
            line=dict(
                color='#6E8484',
                width=1),
        ),
        orientation='h')
    )

    fig.update_layout(
        xaxis_title="Probability",
        title={
            'text': f"{title}",
            'y': .95,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        template="simple_white",
        width=width,
        height=height,
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            font_family="Rockwell"
        ),
    )

    return fig
BERTopic/bertopic/plotting/_documents.py
ADDED
@@ -0,0 +1,227 @@
import numpy as np
import pandas as pd
import plotly.graph_objects as go

from umap import UMAP
from typing import List, Union


def visualize_documents(topic_model,
                        docs: List[str],
                        topics: List[int] = None,
                        embeddings: np.ndarray = None,
                        reduced_embeddings: np.ndarray = None,
                        sample: float = None,
                        hide_annotations: bool = False,
                        hide_document_hover: bool = False,
                        custom_labels: Union[bool, str] = False,
                        title: str = "<b>Documents and Topics</b>",
                        width: int = 1200,
                        height: int = 750):
    """ Visualize documents and their topics in 2D

    Arguments:
        topic_model: A fitted BERTopic instance.
        docs: The documents you used when calling either `fit` or `fit_transform`
        topics: A selection of topics to visualize.
                Not to be confused with the topics that you get from `.fit_transform`.
                For example, if you want to visualize only topics 1 through 5:
                `topics = [1, 2, 3, 4, 5]`.
        embeddings: The embeddings of all documents in `docs`.
        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
        sample: The percentage of documents in each topic that you would like to keep.
                Value can be between 0 and 1. Setting this value to, for example,
                0.1 (10% of documents in each topic) makes it easier to visualize
                millions of documents as a subset is chosen.
        hide_annotations: Hide the names of the traces on top of each cluster.
        hide_document_hover: Hide the content of the documents when hovering over
                             specific points. Helps to speed up generation of visualization.
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
        title: Title of the plot.
        width: The width of the figure.
        height: The height of the figure.

    Examples:

    To visualize the topics simply run:

    ```python
    topic_model.visualize_documents(docs)
    ```

    Do note that this re-calculates the embeddings and reduces them to 2D.
    The advised and preferred pipeline for using this function is as follows:

    ```python
    from sklearn.datasets import fetch_20newsgroups
    from sentence_transformers import SentenceTransformer
    from bertopic import BERTopic
    from umap import UMAP

    # Prepare embeddings
    docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = sentence_model.encode(docs, show_progress_bar=False)

    # Train BERTopic
    topic_model = BERTopic().fit(docs, embeddings)

    # Reduce dimensionality of embeddings, this step is optional
    # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)

    # Run the visualization with the original embeddings
    topic_model.visualize_documents(docs, embeddings=embeddings)

    # Or, if you have reduced the original embeddings already:
    topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
    ```

    Or if you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
    fig.write_html("path/to/file.html")
    ```

    <iframe src="../../getting_started/visualization/documents.html"
    style="width:1000px; height: 800px; border: 0px;"></iframe>
    """
    topic_per_doc = topic_model.topics_

    # Sample the data to optimize for visualization and dimensionality reduction
    if sample is None or sample > 1:
        sample = 1

    indices = []
    for topic in set(topic_per_doc):
        s = np.where(np.array(topic_per_doc) == topic)[0]
        size = len(s) if len(s) < 100 else int(len(s) * sample)
        indices.extend(np.random.choice(s, size=size, replace=False))
    indices = np.array(indices)

    df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
    df["doc"] = [docs[index] for index in indices]
    df["topic"] = [topic_per_doc[index] for index in indices]

    # Extract embeddings if not already done
    if sample is None:
        if embeddings is None and reduced_embeddings is None:
            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
        else:
            embeddings_to_reduce = embeddings
    else:
        if embeddings is not None:
            embeddings_to_reduce = embeddings[indices]
        elif embeddings is None and reduced_embeddings is None:
            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")

    # Reduce input embeddings
    if reduced_embeddings is None:
        umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
        embeddings_2d = umap_model.embedding_
    elif sample is not None and reduced_embeddings is not None:
        embeddings_2d = reduced_embeddings[indices]
    elif sample is None and reduced_embeddings is not None:
        embeddings_2d = reduced_embeddings

    unique_topics = set(topic_per_doc)
    if topics is None:
        topics = unique_topics

    # Combine data
    df["x"] = embeddings_2d[:, 0]
    df["y"] = embeddings_2d[:, 1]

    # Prepare text and names
    if isinstance(custom_labels, str):
        names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in unique_topics]
        names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
        names = [label if len(label) < 30 else label[:27] + "..." for label in names]
    elif topic_model.custom_labels_ is not None and custom_labels:
        names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
    else:
        names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]

    # Visualize
    fig = go.Figure()

    # Outliers and non-selected topics
    non_selected_topics = set(unique_topics).difference(topics)
    if len(non_selected_topics) == 0:
        non_selected_topics = [-1]

    selection = df.loc[df.topic.isin(non_selected_topics), :]
    selection["text"] = ""
    selection.loc[len(selection), :] = [None, None, selection.x.mean(), selection.y.mean(), "Other documents"]

    fig.add_trace(
        go.Scattergl(
            x=selection.x,
            y=selection.y,
            hovertext=selection.doc if not hide_document_hover else None,
            hoverinfo="text",
            mode='markers+text',
            name="other",
            showlegend=False,
            marker=dict(color='#CFD8DC', size=5, opacity=0.5)
        )
    )

    # Selected topics
    for name, topic in zip(names, unique_topics):
        if topic in topics and topic != -1:
            selection = df.loc[df.topic == topic, :]
            selection["text"] = ""

            if not hide_annotations:
                selection.loc[len(selection), :] = [None, None, selection.x.mean(), selection.y.mean(), name]

            fig.add_trace(
                go.Scattergl(
                    x=selection.x,
                    y=selection.y,
                    hovertext=selection.doc if not hide_document_hover else None,
                    hoverinfo="text",
                    text=selection.text,
                    mode='markers+text',
                    name=name,
                    textfont=dict(
                        size=12,
                    ),
                    marker=dict(size=5, opacity=0.5)
                )
            )

    # Add grid in a 'plus' shape
    x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
    y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
    fig.add_shape(type="line",
                  x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
                  line=dict(color="#CFD8DC", width=2))
    fig.add_shape(type="line",
                  x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
                  line=dict(color="#9E9E9E", width=2))
    fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
    fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)

    # Stylize layout
    fig.update_layout(
        template="simple_white",
        title={
            'text': f"{title}",
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        width=width,
        height=height
    )

    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    return fig
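The `sample` argument above is described but not demonstrated in the docstring; a hedged sketch for large corpora, assuming a fitted `topic_model` and precomputed `reduced_embeddings` as in the pipeline example:

```python
# Keep ~10% of documents per topic and skip hover text to speed up rendering
fig = topic_model.visualize_documents(
    docs,
    reduced_embeddings=reduced_embeddings,
    sample=0.1,
    hide_document_hover=True,
)
fig.write_html("documents_sampled.html")
```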
BERTopic/bertopic/plotting/_heatmap.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
from scipy.cluster.hierarchy import fcluster, linkage
|
| 4 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 5 |
+
|
| 6 |
+
import plotly.express as px
|
| 7 |
+
import plotly.graph_objects as go
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def visualize_heatmap(topic_model,
|
| 11 |
+
topics: List[int] = None,
|
| 12 |
+
top_n_topics: int = None,
|
| 13 |
+
n_clusters: int = None,
|
| 14 |
+
custom_labels: Union[bool, str] = False,
|
| 15 |
+
title: str = "<b>Similarity Matrix</b>",
|
| 16 |
+
width: int = 800,
|
| 17 |
+
height: int = 800) -> go.Figure:
|
| 18 |
+
""" Visualize a heatmap of the topic's similarity matrix
|
| 19 |
+
|
| 20 |
+
Based on the cosine similarity matrix between topic embeddings,
|
| 21 |
+
a heatmap is created showing the similarity between topics.
|
| 22 |
+
|
| 23 |
+
Arguments:
|
| 24 |
+
topic_model: A fitted BERTopic instance.
|
| 25 |
+
topics: A selection of topics to visualize.
|
| 26 |
+
top_n_topics: Only select the top n most frequent topics.
|
| 27 |
+
n_clusters: Create n clusters and order the similarity
|
| 28 |
+
matrix by those clusters.
|
| 29 |
+
custom_labels: If bool, whether to use custom topic labels that were defined using
|
| 30 |
+
`topic_model.set_topic_labels`.
|
| 31 |
+
If `str`, it uses labels from other aspects, e.g., "Aspect1".
|
| 32 |
+
title: Title of the plot.
|
| 33 |
+
width: The width of the figure.
|
| 34 |
+
height: The height of the figure.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
fig: A plotly figure
|
| 38 |
+
|
| 39 |
+
Examples:
|
| 40 |
+
|
| 41 |
+
To visualize the similarity matrix of
|
| 42 |
+
topics simply run:
|
| 43 |
+
|
| 44 |
+
```python
|
| 45 |
+
topic_model.visualize_heatmap()
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Or if you want to save the resulting figure:
|
| 49 |
+
|
| 50 |
+
```python
|
| 51 |
+
fig = topic_model.visualize_heatmap()
|
| 52 |
+
fig.write_html("path/to/file.html")
|
| 53 |
+
```
|
| 54 |
+
<iframe src="../../getting_started/visualization/heatmap.html"
|
| 55 |
+
style="width:1000px; height: 720px; border: 0px;""></iframe>
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
# Select topic embeddings
|
| 59 |
+
if topic_model.topic_embeddings_ is not None:
|
| 60 |
+
embeddings = np.array(topic_model.topic_embeddings_)[topic_model._outliers:]
|
| 61 |
+
else:
|
| 62 |
+
embeddings = topic_model.c_tf_idf_[topic_model._outliers:]
|
| 63 |
+
|
| 64 |
+
# Select topics based on top_n and topics args
|
| 65 |
+
freq_df = topic_model.get_topic_freq()
|
| 66 |
+
freq_df = freq_df.loc[freq_df.Topic != -1, :]
|
| 67 |
+
if topics is not None:
|
| 68 |
+
topics = list(topics)
|
| 69 |
+
elif top_n_topics is not None:
|
| 70 |
+
topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
|
| 71 |
+
else:
|
| 72 |
+
topics = sorted(freq_df.Topic.to_list())
|
| 73 |
+
|
| 74 |
+
# Order heatmap by similar clusters of topics
|
| 75 |
+
sorted_topics = topics
|
| 76 |
+
if n_clusters:
|
| 77 |
+
if n_clusters >= len(set(topics)):
|
| 78 |
+
raise ValueError("Make sure to set `n_clusters` lower than "
|
| 79 |
+
"the total number of unique topics.")
|
| 80 |
+
|
| 81 |
+
distance_matrix = cosine_similarity(embeddings[topics])
|
| 82 |
+
Z = linkage(distance_matrix, 'ward')
|
| 83 |
+
clusters = fcluster(Z, t=n_clusters, criterion='maxclust')
|
| 84 |
+
|
| 85 |
+
# Extract new order of topics
|
| 86 |
+
mapping = {cluster: [] for cluster in clusters}
|
| 87 |
+
for topic, cluster in zip(topics, clusters):
|
| 88 |
+
mapping[cluster].append(topic)
|
| 89 |
+
mapping = [cluster for cluster in mapping.values()]
|
| 90 |
+
sorted_topics = [topic for cluster in mapping for topic in cluster]
|
| 91 |
+
|
| 92 |
+
# Select embeddings
|
| 93 |
+
indices = np.array([topics.index(topic) for topic in sorted_topics])
|
| 94 |
+
embeddings = embeddings[indices]
|
| 95 |
+
distance_matrix = cosine_similarity(embeddings)
|
| 96 |
+
|
| 97 |
+
# Create labels
|
| 98 |
+
if isinstance(custom_labels, str):
|
| 99 |
+
new_labels = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in sorted_topics]
|
| 100 |
+
new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
|
| 101 |
+
new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
|
| 102 |
+
elif topic_model.custom_labels_ is not None and custom_labels:
|
| 103 |
+
new_labels = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in sorted_topics]
|
| 104 |
+
else:
|
| 105 |
+
new_labels = [[[str(topic), None]] + topic_model.get_topic(topic) for topic in sorted_topics]
|
| 106 |
+
new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
|
| 107 |
+
new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
|
| 108 |
+
|
| 109 |
+
fig = px.imshow(distance_matrix,
|
| 110 |
+
labels=dict(color="Similarity Score"),
|
| 111 |
+
x=new_labels,
|
| 112 |
+
y=new_labels,
|
| 113 |
+
color_continuous_scale='GnBu'
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
fig.update_layout(
|
| 117 |
+
title={
|
| 118 |
+
'text': f"{title}",
|
| 119 |
+
'y': .95,
|
| 120 |
+
'x': 0.55,
|
| 121 |
+
'xanchor': 'center',
|
| 122 |
+
'yanchor': 'top',
|
| 123 |
+
'font': dict(
|
| 124 |
+
size=22,
|
| 125 |
+
color="Black")
|
| 126 |
+
},
|
| 127 |
+
width=width,
|
| 128 |
+
height=height,
|
| 129 |
+
hoverlabel=dict(
|
| 130 |
+
bgcolor="white",
|
| 131 |
+
font_size=16,
|
| 132 |
+
font_family="Rockwell"
|
| 133 |
+
),
|
| 134 |
+
)
|
| 135 |
+
fig.update_layout(showlegend=True)
|
| 136 |
+
fig.update_layout(legend_title_text='Trend')
|
| 137 |
+
|
| 138 |
+
return fig
|
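The block above orders the heatmap axes by Ward-clustering the topic similarity matrix and then flattening the clusters so similar topics sit next to each other. A minimal standalone sketch of that reordering, with synthetic embeddings and a made-up cluster count standing in for a fitted model:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(42)
embeddings = rng.normal(size=(6, 8))   # 6 hypothetical topics, 8-dim embeddings
topics = list(range(6))

distance_matrix = cosine_similarity(embeddings)
Z = linkage(distance_matrix, 'ward')
clusters = fcluster(Z, t=3, criterion='maxclust')

# Group topics by cluster, then flatten so members of one cluster
# become adjacent rows/columns of the heatmap
mapping = {cluster: [] for cluster in clusters}
for topic, cluster in zip(topics, clusters):
    mapping[cluster].append(topic)
sorted_topics = [topic for cluster in mapping.values() for topic in cluster]
print(sorted_topics)
```

Note that, as in the source, the square cosine-similarity matrix is handed to `linkage` directly, so scipy treats its rows as plain observation vectors rather than as a precomputed distance matrix.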
BERTopic/bertopic/plotting/_hierarchical_documents.py
ADDED
@@ -0,0 +1,336 @@
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+import math
+
+from umap import UMAP
+from typing import List, Union
+
+
+def visualize_hierarchical_documents(topic_model,
+                                     docs: List[str],
+                                     hierarchical_topics: pd.DataFrame,
+                                     topics: List[int] = None,
+                                     embeddings: np.ndarray = None,
+                                     reduced_embeddings: np.ndarray = None,
+                                     sample: Union[float, int] = None,
+                                     hide_annotations: bool = False,
+                                     hide_document_hover: bool = True,
+                                     nr_levels: int = 10,
+                                     level_scale: str = 'linear',
+                                     custom_labels: Union[bool, str] = False,
+                                     title: str = "<b>Hierarchical Documents and Topics</b>",
+                                     width: int = 1200,
+                                     height: int = 750) -> go.Figure:
+    """ Visualize documents and their topics in 2D at different levels of hierarchy
+
+    Arguments:
+        docs: The documents you used when calling either `fit` or `fit_transform`
+        hierarchical_topics: A dataframe that contains a hierarchy of topics
+                             represented by their parents and their children
+        topics: A selection of topics to visualize.
+                Not to be confused with the topics that you get from `.fit_transform`.
+                For example, if you want to visualize only topics 1 through 5:
+                `topics = [1, 2, 3, 4, 5]`.
+        embeddings: The embeddings of all documents in `docs`.
+        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
+        sample: The percentage of documents in each topic that you would like to keep.
+                Value can be between 0 and 1. Setting this value to, for example,
+                0.1 (10% of documents in each topic) makes it easier to visualize
+                millions of documents as a subset is chosen.
+        hide_annotations: Hide the names of the traces on top of each cluster.
+        hide_document_hover: Hide the content of the documents when hovering over
+                             specific points. Helps to speed up generation of visualizations.
+        nr_levels: The number of levels to be visualized in the hierarchy. First, the distances
+                   in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances.
+                   Then, for each list of distances, the merged topics are selected that have a
+                   distance less or equal to the maximum distance of the selected list of distances.
+                   NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to
+                   the length of `hierarchical_topics`.
+        level_scale: Whether to apply a linear or logarithmic (log) scale to the levels of the
+                     distance vector. Linear scaling will perform an equal number of merges at
+                     each level while logarithmic scaling will perform more mergers in earlier
+                     levels to provide more resolution at higher levels (this can be used for
+                     when the number of topics is large).
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+                       NOTE: Custom labels are only generated for the original
+                       un-merged topics.
+        title: Title of the plot.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Examples:
+
+    To visualize the topics simply run:
+
+    ```python
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics)
+    ```
+
+    Do note that this re-calculates the embeddings and reduces them to 2D.
+    The advised and preferred pipeline for using this function is as follows:
+
+    ```python
+    from sklearn.datasets import fetch_20newsgroups
+    from sentence_transformers import SentenceTransformer
+    from bertopic import BERTopic
+    from umap import UMAP
+
+    # Prepare embeddings
+    docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
+    embeddings = sentence_model.encode(docs, show_progress_bar=False)
+
+    # Train BERTopic and extract hierarchical topics
+    topic_model = BERTopic().fit(docs, embeddings)
+    hierarchical_topics = topic_model.hierarchical_topics(docs)
+
+    # Reduce dimensionality of embeddings, this step is optional
+    # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
+
+    # Run the visualization with the original embeddings
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings)
+
+    # Or, if you have reduced the original embeddings already:
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
+    fig.write_html("path/to/file.html")
+    ```
+
+    NOTE:
+        This visualization was inspired by the scatter plot representation of Doc2Map:
+        https://github.com/louisgeisler/Doc2Map
+
+    <iframe src="../../getting_started/visualization/hierarchical_documents.html"
+    style="width:1000px; height: 770px; border: 0px;""></iframe>
+    """
+    topic_per_doc = topic_model.topics_
+
+    # Sample the data to optimize for visualization and dimensionality reduction
+    if sample is None or sample > 1:
+        sample = 1
+
+    indices = []
+    for topic in set(topic_per_doc):
+        s = np.where(np.array(topic_per_doc) == topic)[0]
+        size = len(s) if len(s) < 100 else int(len(s)*sample)
+        indices.extend(np.random.choice(s, size=size, replace=False))
+    indices = np.array(indices)
+
+    df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
+    df["doc"] = [docs[index] for index in indices]
+    df["topic"] = [topic_per_doc[index] for index in indices]
+
+    # Extract embeddings if not already done
+    if sample is None:
+        if embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+        else:
+            embeddings_to_reduce = embeddings
+    else:
+        if embeddings is not None:
+            embeddings_to_reduce = embeddings[indices]
+        elif embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+
+    # Reduce input embeddings
+    if reduced_embeddings is None:
+        umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
+        embeddings_2d = umap_model.embedding_
+    elif sample is not None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings[indices]
+    elif sample is None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings
+
+    # Combine data
+    df["x"] = embeddings_2d[:, 0]
+    df["y"] = embeddings_2d[:, 1]
+
+    # Create topic list for each level, levels are created by calculating the distance
+    distances = hierarchical_topics.Distance.to_list()
+    if level_scale == 'log' or level_scale == 'logarithmic':
+        log_indices = np.round(np.logspace(start=math.log(1,10), stop=math.log(len(distances)-1,10), num=nr_levels)).astype(int).tolist()
+        log_indices.reverse()
+        max_distances = [distances[i] for i in log_indices]
+    elif level_scale == 'lin' or level_scale == 'linear':
+        max_distances = [distances[indices[-1]] for indices in np.array_split(range(len(hierarchical_topics)), nr_levels)][::-1]
+    else:
+        raise ValueError("level_scale needs to be one of 'log' or 'linear'")
+
+    for index, max_distance in enumerate(max_distances):
+
+        # Get topics below `max_distance`
+        mapping = {topic: topic for topic in df.topic.unique()}
+        selection = hierarchical_topics.loc[hierarchical_topics.Distance <= max_distance, :]
+        selection.Parent_ID = selection.Parent_ID.astype(int)
+        selection = selection.sort_values("Parent_ID")
+
+        for row in selection.iterrows():
+            for topic in row[1].Topics:
+                mapping[topic] = row[1].Parent_ID
+
+        # Make sure the mappings are mapped 1:1
+        mappings = [True for _ in mapping]
+        while any(mappings):
+            for i, (key, value) in enumerate(mapping.items()):
+                if value in mapping.keys() and key != value:
+                    mapping[key] = mapping[value]
+                else:
+                    mappings[i] = False
+
+        # Create new column
+        df[f"level_{index+1}"] = df.topic.map(mapping)
+        df[f"level_{index+1}"] = df[f"level_{index+1}"].astype(int)
+
+    # Prepare topic names of original and merged topics
+    trace_names = []
+    topic_names = {}
+    for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
+        if topic < hierarchical_topics.Parent_ID.astype(int).min():
+            if topic_model.get_topic(topic):
+                if isinstance(custom_labels, str):
+                    trace_name = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3])
+                elif topic_model.custom_labels_ is not None and custom_labels:
+                    trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
+                else:
+                    trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3])
+                topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]}
+                trace_names.append(trace_name)
+        else:
+            trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
+            plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]])
+            topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]}
+            trace_names.append(trace_name)
+
+    # Prepare traces
+    all_traces = []
+    for level in range(len(max_distances)):
+        traces = []
+
+        # Outliers
+        if topic_model._outliers:
+            traces.append(
+                go.Scattergl(
+                    x=df.loc[(df[f"level_{level+1}"] == -1), "x"],
+                    y=df.loc[df[f"level_{level+1}"] == -1, "y"],
+                    mode='markers+text',
+                    name="other",
+                    hoverinfo="text",
+                    hovertext=df.loc[(df[f"level_{level+1}"] == -1), "doc"] if not hide_document_hover else None,
+                    showlegend=False,
+                    marker=dict(color='#CFD8DC', size=5, opacity=0.5)
+                )
+            )
+
+        # Selected topics
+        if topics:
+            selection = df.loc[(df.topic.isin(topics)), :]
+            unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()])
+        else:
+            unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()])
+
+        for topic in unique_topics:
+            if topic != -1:
+                if topics:
+                    selection = df.loc[(df[f"level_{level+1}"] == topic) &
+                                       (df.topic.isin(topics)), :]
+                else:
+                    selection = df.loc[df[f"level_{level+1}"] == topic, :]
+
+                if not hide_annotations:
+                    selection.loc[len(selection), :] = None
+                    selection["text"] = ""
+                    selection.loc[len(selection) - 1, "x"] = selection.x.mean()
+                    selection.loc[len(selection) - 1, "y"] = selection.y.mean()
+                    selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)]["plot_text"]
+
+                traces.append(
+                    go.Scattergl(
+                        x=selection.x,
+                        y=selection.y,
+                        text=selection.text if not hide_annotations else None,
+                        hovertext=selection.doc if not hide_document_hover else None,
+                        hoverinfo="text",
+                        name=topic_names[int(topic)]["trace_name"],
+                        mode='markers+text',
+                        marker=dict(size=5, opacity=0.5)
+                    )
+                )
+
+        all_traces.append(traces)
+
+    # Track and count traces
+    nr_traces_per_set = [len(traces) for traces in all_traces]
+    trace_indices = [(0, nr_traces_per_set[0])]
+    for index, nr_traces in enumerate(nr_traces_per_set[1:]):
+        start = trace_indices[index][1]
+        end = nr_traces + start
+        trace_indices.append((start, end))
+
+    # Visualization
+    fig = go.Figure()
+    for traces in all_traces:
+        for trace in traces:
+            fig.add_trace(trace)
+
+    for index in range(len(fig.data)):
+        if index >= nr_traces_per_set[0]:
+            fig.data[index].visible = False
+
+    # Create and add slider
+    steps = []
+    for index, indices in enumerate(trace_indices):
+        step = dict(
+            method="update",
+            label=str(index),
+            args=[{"visible": [False] * len(fig.data)}]
+        )
+        for index in range(indices[1]-indices[0]):
+            step["args"][0]["visible"][index+indices[0]] = True
+        steps.append(step)
+
+    sliders = [dict(
+        currentvalue={"prefix": "Level: "},
+        pad={"t": 20},
+        steps=steps
+    )]
+
+    # Add grid in a 'plus' shape
+    x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
+    y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
+    fig.add_shape(type="line",
+                  x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
+                  line=dict(color="#CFD8DC", width=2))
+    fig.add_shape(type="line",
+                  x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
+                  line=dict(color="#9E9E9E", width=2))
+    fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
+    fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
+
+    # Stylize layout
+    fig.update_layout(
+        sliders=sliders,
+        template="simple_white",
+        title={
+            'text': f"{title}",
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        width=width,
+        height=height,
+    )
+
+    fig.update_xaxes(visible=False)
+    fig.update_yaxes(visible=False)
+    return fig
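For the `level_scale='linear'` branch above, the slider levels come from splitting the ordered merge distances into `nr_levels` chunks and keeping each chunk's maximum as that level's cut-off. A small standalone sketch with made-up distances, using `idx` in place of the source's comprehension variable `indices`:

```python
import numpy as np

# Hypothetical merge distances, ordered as in hierarchical_topics.Distance
distances = [0.5, 0.8, 1.1, 1.7, 2.3, 3.0]
nr_levels = 3

# Take the last (largest) distance of each chunk, then reverse so the
# coarsest level comes first
max_distances = [distances[idx[-1]]
                 for idx in np.array_split(range(len(distances)), nr_levels)][::-1]
print(max_distances)  # [3.0, 1.7, 0.8]
```

Each threshold then determines which merged topics are considered active at that slider level.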
BERTopic/bertopic/plotting/_hierarchy.py
ADDED
@@ -0,0 +1,308 @@
+import numpy as np
+import pandas as pd
+from typing import Callable, List, Union
+from scipy.sparse import csr_matrix
+from scipy.cluster import hierarchy as sch
+from scipy.spatial.distance import squareform
+from sklearn.metrics.pairwise import cosine_similarity
+
+import plotly.graph_objects as go
+import plotly.figure_factory as ff
+
+from bertopic._utils import validate_distance_matrix
+
+def visualize_hierarchy(topic_model,
+                        orientation: str = "left",
+                        topics: List[int] = None,
+                        top_n_topics: int = None,
+                        custom_labels: Union[bool, str] = False,
+                        title: str = "<b>Hierarchical Clustering</b>",
+                        width: int = 1000,
+                        height: int = 600,
+                        hierarchical_topics: pd.DataFrame = None,
+                        linkage_function: Callable[[csr_matrix], np.ndarray] = None,
+                        distance_function: Callable[[csr_matrix], csr_matrix] = None,
+                        color_threshold: int = 1) -> go.Figure:
+    """ Visualize a hierarchical structure of the topics
+
+    A ward linkage function is used to perform the
+    hierarchical clustering based on the cosine distance
+    matrix between topic embeddings.
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        orientation: The orientation of the figure.
+                     Either 'left' or 'bottom'
+        topics: A selection of topics to visualize
+        top_n_topics: Only select the top n most frequent topics
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+                       NOTE: Custom labels are only generated for the original
+                       un-merged topics.
+        title: Title of the plot.
+        width: The width of the figure. Only works if orientation is set to 'left'
+        height: The height of the figure. Only works if orientation is set to 'bottom'
+        hierarchical_topics: A dataframe that contains a hierarchy of topics
+                             represented by their parents and their children.
+                             NOTE: The hierarchical topic names are only visualized
+                             if both `topics` and `top_n_topics` are not set.
+        linkage_function: The linkage function to use. Default is:
+                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
+                          NOTE: Make sure to use the same `linkage_function` as used
+                          in `topic_model.hierarchical_topics`.
+        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
+                           `lambda x: 1 - cosine_similarity(x)`.
+                           You can pass any function that returns either a square matrix of
+                           shape (n_samples, n_samples) with zeros on the diagonal and
+                           non-negative values or condensed distance matrix of shape
+                           (n_samples * (n_samples - 1) / 2,) containing the upper
+                           triangular of the distance matrix.
+                           NOTE: Make sure to use the same `distance_function` as used
+                           in `topic_model.hierarchical_topics`.
+        color_threshold: Value at which the separation of clusters will be made which
+                         will result in different colors for different clusters.
+                         A higher value will typically lead to fewer colored clusters.
+
+    Returns:
+        fig: A plotly figure
+
+    Examples:
+
+    To visualize the hierarchical structure of
+    topics simply run:
+
+    ```python
+    topic_model.visualize_hierarchy()
+    ```
+
+    If you also want the labels of hierarchical topics visualized,
+    run the following:
+
+    ```python
+    # Extract hierarchical topics and their representations
+    hierarchical_topics = topic_model.hierarchical_topics(docs)
+
+    # Visualize these representations
+    topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
+    ```
+
+    If you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_hierarchy()
+    fig.write_html("path/to/file.html")
+    ```
+    <iframe src="../../getting_started/visualization/hierarchy.html"
+    style="width:1000px; height: 680px; border: 0px;""></iframe>
+    """
+    if distance_function is None:
+        distance_function = lambda x: 1 - cosine_similarity(x)
+
+    if linkage_function is None:
+        linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)
+
+    # Select topics based on top_n and topics args
+    freq_df = topic_model.get_topic_freq()
+    freq_df = freq_df.loc[freq_df.Topic != -1, :]
+    if topics is not None:
+        topics = list(topics)
+    elif top_n_topics is not None:
+        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
+    else:
+        topics = sorted(freq_df.Topic.to_list())
+
+    # Select embeddings
+    all_topics = sorted(list(topic_model.get_topics().keys()))
+    indices = np.array([all_topics.index(topic) for topic in topics])
+
+    # Select topic embeddings
+    if topic_model.c_tf_idf_ is not None:
+        embeddings = topic_model.c_tf_idf_[indices]
+    else:
+        embeddings = np.array(topic_model.topic_embeddings_)[indices]
+
+    # Annotations
+    if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()):
+        annotations = _get_annotations(topic_model=topic_model,
+                                       hierarchical_topics=hierarchical_topics,
+                                       embeddings=embeddings,
+                                       distance_function=distance_function,
+                                       linkage_function=linkage_function,
+                                       orientation=orientation,
+                                       custom_labels=custom_labels)
+    else:
+        annotations = None
+
+    # wrap distance function to validate input and return a condensed distance matrix
+    distance_function_viz = lambda x: validate_distance_matrix(
+        distance_function(x), embeddings.shape[0])
+    # Create dendrogram
+    fig = ff.create_dendrogram(embeddings,
+                               orientation=orientation,
+                               distfun=distance_function_viz,
+                               linkagefun=linkage_function,
+                               hovertext=annotations,
+                               color_threshold=color_threshold)
+
+    # Create nicer labels
+    axis = "yaxis" if orientation == "left" else "xaxis"
+    if isinstance(custom_labels, str):
+        new_labels = [[[str(x), None]] + topic_model.topic_aspects_[custom_labels][x] for x in fig.layout[axis]["ticktext"]]
+        new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
+        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
+    elif topic_model.custom_labels_ is not None and custom_labels:
+        new_labels = [topic_model.custom_labels_[topics[int(x)] + topic_model._outliers] for x in fig.layout[axis]["ticktext"]]
+    else:
+        new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)])
+                      for x in fig.layout[axis]["ticktext"]]
+        new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
+        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
+
+    # Stylize layout
+    fig.update_layout(
+        plot_bgcolor='#ECEFF1',
+        template="plotly_white",
+        title={
+            'text': f"{title}",
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        hoverlabel=dict(
+            bgcolor="white",
+            font_size=16,
+            font_family="Rockwell"
+        ),
+    )
+
+    # Stylize orientation
+    if orientation == "left":
+        fig.update_layout(height=200 + (15 * len(topics)),
+                          width=width,
+                          yaxis=dict(tickmode="array",
+                                     ticktext=new_labels))
+
+        # Fix empty space on the bottom of the graph
+        y_max = max([trace['y'].max() + 5 for trace in fig['data']])
+        y_min = min([trace['y'].min() - 5 for trace in fig['data']])
+        fig.update_layout(yaxis=dict(range=[y_min, y_max]))
+
+    else:
+        fig.update_layout(width=200 + (15 * len(topics)),
+                          height=height,
+                          xaxis=dict(tickmode="array",
+                                     ticktext=new_labels))
+
+    if hierarchical_topics is not None:
+        for index in [0, 3]:
+            axis = "x" if orientation == "left" else "y"
+            xs = [data["x"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
+            ys = [data["y"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
+            hovertext = [data["text"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
+
+            fig.add_trace(go.Scatter(x=xs, y=ys, marker_color='black',
+                                     hovertext=hovertext, hoverinfo="text",
+                                     mode='markers', showlegend=False))
+    return fig
+
+
+def _get_annotations(topic_model,
+                     hierarchical_topics: pd.DataFrame,
+                     embeddings: csr_matrix,
+                     linkage_function: Callable[[csr_matrix], np.ndarray],
+                     distance_function: Callable[[csr_matrix], csr_matrix],
+                     orientation: str,
+                     custom_labels: bool = False) -> List[List[str]]:
+
+    """ Get annotations by replicating the linkage function calculation in scipy
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        hierarchical_topics: A dataframe that contains a hierarchy of topics
+                             represented by their parents and their children.
+                             NOTE: The hierarchical topic names are only visualized
+                             if both `topics` and `top_n_topics` are not set.
+        embeddings: The c-TF-IDF matrix on which to model the hierarchy
+        linkage_function: The linkage function to use. Default is:
+                          `lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
+                          NOTE: Make sure to use the same `linkage_function` as used
+                          in `topic_model.hierarchical_topics`.
+        distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
+                           `lambda x: 1 - cosine_similarity(x)`.
+                           You can pass any function that returns either a square matrix of
+                           shape (n_samples, n_samples) with zeros on the diagonal and
+                           non-negative values or condensed distance matrix of shape
+                           (n_samples * (n_samples - 1) / 2,) containing the upper
+                           triangular of the distance matrix.
+                           NOTE: Make sure to use the same `distance_function` as used
+                           in `topic_model.hierarchical_topics`.
+        orientation: The orientation of the figure.
+                     Either 'left' or 'bottom'
+        custom_labels: Whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       NOTE: Custom labels are only generated for the original
+                       un-merged topics.
+
+    Returns:
+        text_annotations: Annotations to be used within Plotly's `ff.create_dendrogram`
+    """
+    df = hierarchical_topics.loc[hierarchical_topics.Parent_Name != "Top", :]
+
+    # Calculate distance
+    X = distance_function(embeddings)
+    X = validate_distance_matrix(X, embeddings.shape[0])
+
+    # Calculate linkage and generate dendrogram
+    Z = linkage_function(X)
+    P = sch.dendrogram(Z, orientation=orientation, no_plot=True)
+
+    # store topic no.(leaves) corresponding to the x-ticks in dendrogram
+    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)
+    x_topic = dict(zip(P['leaves'], x_ticks))
+
+    topic_vals = dict()
+    for key, val in x_topic.items():
+        topic_vals[val] = [key]
+
+    parent_topic = dict(zip(df.Parent_ID, df.Topics))
+
+    # loop through every trace (scatter plot) in dendrogram
+    text_annotations = []
+    for index, trace in enumerate(P['icoord']):
+        fst_topic = topic_vals[trace[0]]
+        scnd_topic = topic_vals[trace[2]]
+
+        if len(fst_topic) == 1:
+            if isinstance(custom_labels, str):
+                fst_name = f"{fst_topic[0]}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][fst_topic[0]]))[0][:3])
+            elif topic_model.custom_labels_ is not None and custom_labels:
+                fst_name = topic_model.custom_labels_[fst_topic[0] + topic_model._outliers]
+            else:
+                fst_name = "_".join([word for word, _ in topic_model.get_topic(fst_topic[0])][:5])
+        else:
+            for key, value in parent_topic.items():
+                if set(value) == set(fst_topic):
+                    fst_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0]
+
+        if len(scnd_topic) == 1:
+            if isinstance(custom_labels, str):
+                scnd_name = f"{scnd_topic[0]}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][scnd_topic[0]]))[0][:3])
+            elif topic_model.custom_labels_ is not None and custom_labels:
+                scnd_name = topic_model.custom_labels_[scnd_topic[0] + topic_model._outliers]
+            else:
+                scnd_name = "_".join([word for word, _ in topic_model.get_topic(scnd_topic[0])][:5])
+        else:
+            for key, value in parent_topic.items():
+                if set(value) == set(scnd_topic):
+                    scnd_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0]
+
+        text_annotations.append([fst_name, "", "", scnd_name])
+
+        center = (trace[0] + trace[2]) / 2
+        topic_vals[center] = fst_topic + scnd_topic
+
+    return text_annotations
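The distance-function contract described in the docstring (a square matrix with a zero diagonal, or the condensed upper-triangular vector) corresponds to the two forms `scipy.spatial.distance.squareform` converts between; per the inline comment above, `validate_distance_matrix` normalizes either input to the condensed form. A sketch of the conversion with synthetic vectors:

```python
import numpy as np
from scipy.spatial.distance import squareform
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 6))            # 4 hypothetical topic vectors

square = 1 - cosine_similarity(X)      # square form, shape (4, 4)
square = (square + square.T) / 2       # make symmetry exact
np.fill_diagonal(square, 0)            # make the diagonal exactly zero

condensed = squareform(square)         # condensed form, shape (4*3/2,) == (6,)
assert condensed.shape == (6,)
assert np.allclose(squareform(condensed), square)
```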
BERTopic/bertopic/plotting/_term_rank.py
ADDED
@@ -0,0 +1,135 @@
+import numpy as np
+from typing import List, Union
+import plotly.graph_objects as go
+
+
+def visualize_term_rank(topic_model,
+                        topics: List[int] = None,
+                        log_scale: bool = False,
+                        custom_labels: Union[bool, str] = False,
+                        title: str = "<b>Term score decline per Topic</b>",
+                        width: int = 800,
+                        height: int = 500) -> go.Figure:
+    """ Visualize the ranks of all terms across all topics
+
+    Each topic is represented by a set of words. These words, however,
+    do not all equally represent the topic. This visualization shows
+    how many words are needed to represent a topic and at which point
+    the beneficial effect of adding words starts to decline.
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        topics: A selection of topics to visualize. These will be colored
+                red where all others will be colored black.
+        log_scale: Whether to represent the ranking on a log scale
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+        title: Title of the plot.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Returns:
+        fig: A plotly figure
+
+    Examples:
+
+    To visualize the ranks of all words across
+    all topics simply run:
+
+    ```python
+    topic_model.visualize_term_rank()
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_term_rank()
+    fig.write_html("path/to/file.html")
+    ```
+
+    <iframe src="../../getting_started/visualization/term_rank.html"
+    style="width:1000px; height: 530px; border: 0px;""></iframe>
+
+    <iframe src="../../getting_started/visualization/term_rank_log.html"
+    style="width:1000px; height: 530px; border: 0px;""></iframe>
+
+    Reference:
+
+    This visualization was heavily inspired by the
+    "Term Probability Decline" visualization found in an
+    analysis by the amazing [tmtoolkit](https://tmtoolkit.readthedocs.io/).
+    Reference to that specific analysis can be found
+    [here](https://wzbsocialsciencecenter.github.io/tm_corona/tm_analysis.html).
+    """
+
+    topics = [] if topics is None else topics
+
+    topic_ids = topic_model.get_topic_info().Topic.unique().tolist()
+    topic_words = [topic_model.get_topic(topic) for topic in topic_ids]
+
+    values = np.array([[value[1] for value in values] for values in topic_words])
+    indices = np.array([[value + 1 for value in range(len(values))] for values in topic_words])
+
+    # Create figure
+    lines = []
+    for topic, x, y in zip(topic_ids, indices, values):
+        if not any(y > 1.5):
+
+            # labels
+            if isinstance(custom_labels, str):
+                label = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3])
+            elif topic_model.custom_labels_ is not None and custom_labels:
+                label = topic_model.custom_labels_[topic + topic_model._outliers]
+            else:
+                label = f"<b>Topic {topic}</b>:" + "_".join([word[0] for word in topic_model.get_topic(topic)])
+                label = label[:50]
+
+            # line parameters
+            color = "red" if topic in topics else "black"
+            opacity = 1 if topic in topics else .1
+            if any(y == 0):
+                y[y == 0] = min(values[values > 0])
+            y = np.log10(y, out=y, where=y > 0) if log_scale else y
+
+            line = go.Scatter(x=x, y=y,
+                              name="",
+                              hovertext=label,
+                              mode="lines+lines",
+                              opacity=opacity,
+                              line=dict(color=color, width=1.5))
+            lines.append(line)
+
+    fig = go.Figure(data=lines)
+
+    # Stylize layout
+    fig.update_xaxes(range=[0, len(indices[0])], tick0=1, dtick=2)
+    fig.update_layout(
+        showlegend=False,
+        template="plotly_white",
+        title={
+            'text': f"{title}",
+            'y': .9,
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        width=width,
+        height=height,
+        hoverlabel=dict(
+            bgcolor="white",
+            font_size=16,
+            font_family="Rockwell"
+        ),
+    )
+
+    fig.update_xaxes(title_text='Term Rank')
+    if log_scale:
+        fig.update_yaxes(title_text='c-TF-IDF score (log scale)')
+    else:
+        fig.update_yaxes(title_text='c-TF-IDF score')
+
+    return fig
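The log-scale branch above first replaces zero scores, which have no logarithm, with the smallest positive score across all topics before applying `np.log10`. The same guard in isolation, with made-up c-TF-IDF scores:

```python
import numpy as np

# Two hypothetical topics, four ranked term scores each
values = np.array([[0.30, 0.12, 0.05, 0.00],
                   [0.25, 0.10, 0.04, 0.01]])

y = values[0].copy()
if any(y == 0):
    y[y == 0] = min(values[values > 0])   # smallest positive score overall
y = np.log10(y, out=y, where=y > 0)
print(y)  # approximately [-0.52, -0.92, -1.30, -2.00]
```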
BERTopic/bertopic/plotting/_topics.py
ADDED
@@ -0,0 +1,162 @@
+import numpy as np
+import pandas as pd
+from umap import UMAP
+from typing import List, Union
+from sklearn.preprocessing import MinMaxScaler
+
+import plotly.express as px
+import plotly.graph_objects as go
+
+
+def visualize_topics(topic_model,
+                     topics: List[int] = None,
+                     top_n_topics: int = None,
+                     custom_labels: Union[bool, str] = False,
+                     title: str = "<b>Intertopic Distance Map</b>",
+                     width: int = 650,
+                     height: int = 650) -> go.Figure:
+    """ Visualize topics, their sizes, and their corresponding words
+
+    This visualization is highly inspired by LDAvis, a great visualization
+    technique typically reserved for LDA.
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        topics: A selection of topics to visualize
+        top_n_topics: Only select the top n most frequent topics
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+        title: Title of the plot.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Examples:
+
+    To visualize the topics simply run:
+
+    ```python
+    topic_model.visualize_topics()
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_topics()
+    fig.write_html("path/to/file.html")
+    ```
+    <iframe src="../../getting_started/visualization/viz.html"
+    style="width:1000px; height: 680px; border: 0px;""></iframe>
+    """
+    # Select topics based on top_n and topics args
+    freq_df = topic_model.get_topic_freq()
+    freq_df = freq_df.loc[freq_df.Topic != -1, :]
+    if topics is not None:
+        topics = list(topics)
+    elif top_n_topics is not None:
+        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
+    else:
+        topics = sorted(freq_df.Topic.to_list())
+
+    # Extract topic words and their frequencies
+    topic_list = sorted(topics)
+    frequencies = [topic_model.topic_sizes_[topic] for topic in topic_list]
+    if isinstance(custom_labels, str):
+        words = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topic_list]
+        words = ["_".join([label[0] for label in labels[:4]]) for labels in words]
+        words = [label if len(label) < 30 else label[:27] + "..." for label in words]
+    elif custom_labels and topic_model.custom_labels_ is not None:
+        words = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topic_list]
+    else:
+        words = [" | ".join([word[0] for word in topic_model.get_topic(topic)[:5]]) for topic in topic_list]
+
+    # Embed c-TF-IDF into 2D
+    all_topics = sorted(list(topic_model.get_topics().keys()))
+    indices = np.array([all_topics.index(topic) for topic in topics])
+
+    if topic_model.topic_embeddings_ is not None:
+        embeddings = topic_model.topic_embeddings_[indices]
+        embeddings = UMAP(n_neighbors=2, n_components=2, metric='cosine', random_state=42).fit_transform(embeddings)
+    else:
+        embeddings = topic_model.c_tf_idf_.toarray()[indices]
+        embeddings = MinMaxScaler().fit_transform(embeddings)
+        embeddings = UMAP(n_neighbors=2, n_components=2, metric='hellinger', random_state=42).fit_transform(embeddings)
+
+    # Visualize with plotly
+    df = pd.DataFrame({"x": embeddings[:, 0], "y": embeddings[:, 1],
+                       "Topic": topic_list, "Words": words, "Size": frequencies})
+    return _plotly_topic_visualization(df, topic_list, title, width, height)
+
+
+def _plotly_topic_visualization(df: pd.DataFrame,
+                                topic_list: List[str],
+                                title: str,
+                                width: int,
+                                height: int):
+    """ Create plotly-based visualization of topics with a slider for topic selection """
+
+    def get_color(topic_selected):
+        if topic_selected == -1:
+            marker_color = ["#B0BEC5" for _ in topic_list]
+        else:
+            marker_color = ["red" if topic == topic_selected else "#B0BEC5" for topic in topic_list]
+        return [{'marker.color': [marker_color]}]
+
+    # Prepare figure range
+    x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
+    y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
+
+    # Plot topics
+    fig = px.scatter(df, x="x", y="y", size="Size", size_max=40, template="simple_white", labels={"x": "", "y": ""},
+                     hover_data={"Topic": True, "Words": True, "Size": True, "x": False, "y": False})
+    fig.update_traces(marker=dict(color="#B0BEC5", line=dict(width=2, color='DarkSlateGrey')))
+
+    # Update hover order
+    fig.update_traces(hovertemplate="<br>".join(["<b>Topic %{customdata[0]}</b>",
+                                                 "%{customdata[1]}",
+                                                 "Size: %{customdata[2]}"]))
+
+    # Create a slider for topic selection
+    steps = [dict(label=f"Topic {topic}", method="update", args=get_color(topic)) for topic in topic_list]
+    sliders = [dict(active=0, pad={"t": 50}, steps=steps)]
+
+    # Stylize layout
+    fig.update_layout(
+        title={
+            'text': f"{title}",
+            'y': .95,
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        width=width,
+        height=height,
+        hoverlabel=dict(
+            bgcolor="white",
+            font_size=16,
+            font_family="Rockwell"
+        ),
+        xaxis={"visible": False},
+        yaxis={"visible": False},
+        sliders=sliders
+    )
+
+    # Update axes ranges
+    fig.update_xaxes(range=x_range)
+    fig.update_yaxes(range=y_range)
+
+    # Add grid in a 'plus' shape
+    fig.add_shape(type="line",
+                  x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
+                  line=dict(color="#CFD8DC", width=2))
+    fig.add_shape(type="line",
+                  x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
+                  line=dict(color="#9E9E9E", width=2))
+    fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
+    fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
+    fig.data = fig.data[::-1]
+
+    return fig
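The 2D coordinates of the intertopic distance map come from running UMAP directly on the topic embeddings with the parameters shown above. A standalone sketch, with random vectors standing in for real topic embeddings (the shapes are illustrative only):

```python
import numpy as np
from umap import UMAP

rng = np.random.default_rng(42)
topic_embeddings = rng.normal(size=(10, 384))  # 10 hypothetical topics

xy = UMAP(n_neighbors=2, n_components=2, metric='cosine',
          random_state=42).fit_transform(topic_embeddings)
print(xy.shape)  # (10, 2)
```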
BERTopic/bertopic/plotting/_topics_over_time.py
ADDED
@@ -0,0 +1,123 @@
+import pandas as pd
+from typing import List, Union
+import plotly.graph_objects as go
+from sklearn.preprocessing import normalize
+
+
+def visualize_topics_over_time(topic_model,
+                               topics_over_time: pd.DataFrame,
+                               top_n_topics: int = None,
+                               topics: List[int] = None,
+                               normalize_frequency: bool = False,
+                               custom_labels: Union[bool, str] = False,
+                               title: str = "<b>Topics over Time</b>",
+                               width: int = 1250,
+                               height: int = 450) -> go.Figure:
+    """ Visualize topics over time
+
+    Arguments:
+        topic_model: A fitted BERTopic instance.
+        topics_over_time: The topics you would like to be visualized with the
+                          corresponding topic representation
+        top_n_topics: To visualize the most frequent topics instead of all
+        topics: Select which topics you would like to be visualized
+        normalize_frequency: Whether to normalize each topic's frequency individually
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+        title: Title of the plot.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Returns:
+        A plotly.graph_objects.Figure including all traces
+
+    Examples:
+
+    To visualize the topics over time, simply run:
+
+    ```python
+    topics_over_time = topic_model.topics_over_time(docs, timestamps)
+    topic_model.visualize_topics_over_time(topics_over_time)
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_topics_over_time(topics_over_time)
+    fig.write_html("path/to/file.html")
+    ```
+    <iframe src="../../getting_started/visualization/trump.html"
+    style="width:1000px; height: 680px; border: 0px;""></iframe>
+    """
+    colors = ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#D55E00", "#0072B2", "#CC79A7"]
+
+    # Select topics based on top_n and topics args
+    freq_df = topic_model.get_topic_freq()
+    freq_df = freq_df.loc[freq_df.Topic != -1, :]
+    if topics is not None:
+        selected_topics = list(topics)
+    elif top_n_topics is not None:
+        selected_topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
+    else:
+        selected_topics = sorted(freq_df.Topic.to_list())
+
+    # Prepare data
+    if isinstance(custom_labels, str):
+        topic_names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topics]
+        topic_names = ["_".join([label[0] for label in labels[:4]]) for labels in topic_names]
+        topic_names = [label if len(label) < 30 else label[:27] + "..." for label in topic_names]
+        topic_names = {key: topic_names[index] for index, key in enumerate(topic_model.topic_labels_.keys())}
+    elif topic_model.custom_labels_ is not None and custom_labels:
+        topic_names = {key: topic_model.custom_labels_[key + topic_model._outliers] for key, _ in topic_model.topic_labels_.items()}
+    else:
+        topic_names = {key: value[:40] + "..." if len(value) > 40 else value
+                       for key, value in topic_model.topic_labels_.items()}
+    topics_over_time["Name"] = topics_over_time.Topic.map(topic_names)
+    data = topics_over_time.loc[topics_over_time.Topic.isin(selected_topics), :].sort_values(["Topic", "Timestamp"])
+
+    # Add traces
+    fig = go.Figure()
+    for index, topic in enumerate(data.Topic.unique()):
+        trace_data = data.loc[data.Topic == topic, :]
+        topic_name = trace_data.Name.values[0]
+        words = trace_data.Words.values
+        if normalize_frequency:
+            y = normalize(trace_data.Frequency.values.reshape(1, -1))[0]
+        else:
+            y = trace_data.Frequency
+        fig.add_trace(go.Scatter(x=trace_data.Timestamp, y=y,
+                                 mode='lines',
+                                 marker_color=colors[index % 7],
+                                 hoverinfo="text",
+                                 name=topic_name,
+                                 hovertext=[f'<b>Topic {topic}</b><br>Words: {word}' for word in words]))
+
+    # Styling of the visualization
+    fig.update_xaxes(showgrid=True)
+    fig.update_yaxes(showgrid=True)
+    fig.update_layout(
+        yaxis_title="Normalized Frequency" if normalize_frequency else "Frequency",
+        title={
+            'text': f"{title}",
+            'y': .95,
+            'x': 0.40,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        template="simple_white",
+        width=width,
+        height=height,
+        hoverlabel=dict(
+            bgcolor="white",
+            font_size=16,
+            font_family="Rockwell"
+        ),
+        legend=dict(
+            title="<b>Global Topic Representation",
+        )
+    )
+    return fig
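With `normalize_frequency=True`, each topic's frequency series is reshaped into a single row and L2-normalized by sklearn's `normalize`, so topics of very different absolute sizes become comparable on one axis. The same step in isolation, with made-up counts:

```python
import numpy as np
from sklearn.preprocessing import normalize

frequencies = np.array([40, 10, 30, 20])       # one topic's counts over time
y = normalize(frequencies.reshape(1, -1))[0]   # unit L2 norm (the default)
print(y, np.linalg.norm(y))                    # norm is 1.0
```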
BERTopic/bertopic/plotting/_topics_per_class.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pandas as pd
from typing import List, Union
import plotly.graph_objects as go
from sklearn.preprocessing import normalize


def visualize_topics_per_class(topic_model,
                               topics_per_class: pd.DataFrame,
                               top_n_topics: int = 10,
                               topics: List[int] = None,
                               normalize_frequency: bool = False,
                               custom_labels: Union[bool, str] = False,
                               title: str = "<b>Topics per Class</b>",
                               width: int = 1250,
                               height: int = 900) -> go.Figure:
    """ Visualize topics per class

    Arguments:
        topic_model: A fitted BERTopic instance.
        topics_per_class: The topics you would like to be visualized with the
                          corresponding topic representation.
        top_n_topics: To visualize the most frequent topics instead of all.
        topics: Select which topics you would like to be visualized.
        normalize_frequency: Whether to normalize each topic's frequency individually.
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
        title: Title of the plot.
        width: The width of the figure.
        height: The height of the figure.

    Returns:
        A plotly.graph_objects.Figure including all traces

    Examples:

    To visualize the topics per class, simply run:

    ```python
    topics_per_class = topic_model.topics_per_class(docs, classes)
    topic_model.visualize_topics_per_class(topics_per_class)
    ```

    Or if you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_topics_per_class(topics_per_class)
    fig.write_html("path/to/file.html")
    ```
    <iframe src="../../getting_started/visualization/topics_per_class.html"
    style="width:1400px; height: 1000px; border: 0px;"></iframe>
    """
    colors = ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#D55E00", "#0072B2", "#CC79A7"]

    # Select topics based on top_n and topics args
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        selected_topics = list(topics)
    elif top_n_topics is not None:
        selected_topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        selected_topics = sorted(freq_df.Topic.to_list())

    # Prepare data
    if isinstance(custom_labels, str):
        topic_names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topics]
        topic_names = ["_".join([label[0] for label in labels[:4]]) for labels in topic_names]
        topic_names = [label if len(label) < 30 else label[:27] + "..." for label in topic_names]
        topic_names = {key: topic_names[index] for index, key in enumerate(topic_model.topic_labels_.keys())}
    elif topic_model.custom_labels_ is not None and custom_labels:
        topic_names = {key: topic_model.custom_labels_[key + topic_model._outliers] for key, _ in topic_model.topic_labels_.items()}
    else:
        topic_names = {key: value[:40] + "..." if len(value) > 40 else value
                       for key, value in topic_model.topic_labels_.items()}
    topics_per_class["Name"] = topics_per_class.Topic.map(topic_names)
    data = topics_per_class.loc[topics_per_class.Topic.isin(selected_topics), :]

    # Add traces
    fig = go.Figure()
    for index, topic in enumerate(selected_topics):
        if index == 0:
            visible = True
        else:
            visible = "legendonly"
        trace_data = data.loc[data.Topic == topic, :]
        topic_name = trace_data.Name.values[0]
        words = trace_data.Words.values
        if normalize_frequency:
            x = normalize(trace_data.Frequency.values.reshape(1, -1))[0]
        else:
            x = trace_data.Frequency
        fig.add_trace(go.Bar(y=trace_data.Class,
                             x=x,
                             visible=visible,
                             marker_color=colors[index % 7],
                             hoverinfo="text",
                             name=topic_name,
                             orientation="h",
                             hovertext=[f'<b>Topic {topic}</b><br>Words: {word}' for word in words]))

    # Styling of the visualization
    fig.update_xaxes(showgrid=True)
    fig.update_yaxes(showgrid=True)
    fig.update_layout(
        xaxis_title="Normalized Frequency" if normalize_frequency else "Frequency",
        yaxis_title="Class",
        title={
            'text': f"{title}",
            'y': .95,
            'x': 0.40,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        template="simple_white",
        width=width,
        height=height,
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            font_family="Rockwell"
        ),
        legend=dict(
            title="<b>Global Topic Representation",
        )
    )
    return fig
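A matching sketch for this class-based variant (again assuming a fitted `topic_model`, plus an aligned `classes` list with one label per document):

```python
# One class label per document, e.g. the product category of a review
topics_per_class = topic_model.topics_per_class(docs, classes=classes)

# Normalize frequencies so classes of different sizes stay comparable
fig = topic_model.visualize_topics_per_class(topics_per_class, normalize_frequency=True)
fig.write_html("topics_per_class.html")
```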
BERTopic/bertopic/representation/__init__.py
ADDED
@@ -0,0 +1,68 @@
from bertopic._utils import NotInstalled
from bertopic.representation._cohere import Cohere
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._keybert import KeyBERTInspired
from bertopic.representation._mmr import MaximalMarginalRelevance


# Llama CPP Generator
try:
    from bertopic.representation._llamacpp import LlamaCPP
except ModuleNotFoundError:
    msg = "`pip install llama-cpp-python` \n\n"
    LlamaCPP = NotInstalled("llama.cpp", "llama-cpp-python", custom_msg=msg)

# Text Generation using transformers
try:
    from bertopic.representation._textgeneration import TextGeneration
except ModuleNotFoundError:
    msg = "`pip install bertopic` without `--no-deps` \n\n"
    TextGeneration = NotInstalled("TextGeneration", "transformers", custom_msg=msg)

# Zero-shot classification using transformers
try:
    from bertopic.representation._zeroshot import ZeroShotClassification
except ModuleNotFoundError:
    msg = "`pip install bertopic` without `--no-deps` \n\n"
    ZeroShotClassification = NotInstalled("ZeroShotClassification", "transformers", custom_msg=msg)

# OpenAI Generator
try:
    from bertopic.representation._openai import OpenAI
except ModuleNotFoundError:
    msg = "`pip install openai` \n\n"
    OpenAI = NotInstalled("OpenAI", "openai", custom_msg=msg)

# LangChain Generator
try:
    from bertopic.representation._langchain import LangChain
except ModuleNotFoundError:
    msg = "`pip install langchain` \n\n"
    LangChain = NotInstalled("langchain", "langchain", custom_msg=msg)

# POS using Spacy
try:
    from bertopic.representation._pos import PartOfSpeech
except ModuleNotFoundError:
    PartOfSpeech = NotInstalled("Part of Speech with Spacy", "spacy")

# Multimodal
try:
    from bertopic.representation._visual import VisualRepresentation
except ModuleNotFoundError:
    VisualRepresentation = NotInstalled("a visual representation model", "vision")


__all__ = [
    "BaseRepresentation",
    "TextGeneration",
    "ZeroShotClassification",
    "KeyBERTInspired",
    "PartOfSpeech",
    "MaximalMarginalRelevance",
    "Cohere",
    "OpenAI",
    "LangChain",
    "LlamaCPP",
    "VisualRepresentation"
]
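The `NotInstalled` placeholders above keep this module importable even when optional dependencies are missing; the error is deferred until the placeholder is actually used. A sketch of the expected behavior, assuming `openai` is not installed:

```python
from bertopic.representation import OpenAI  # the import itself always succeeds

try:
    representation_model = OpenAI()  # using the placeholder raises instead
except ModuleNotFoundError as e:
    print(e)  # the message points the user to `pip install openai`
```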
BERTopic/bertopic/representation/_base.py
ADDED
@@ -0,0 +1,38 @@
import pandas as pd
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator
from typing import Mapping, List, Tuple


class BaseRepresentation(BaseEstimator):
    """ The base representation model for fine-tuning topic representations """
    def extract_topics(self,
                       topic_model,
                       documents: pd.DataFrame,
                       c_tf_idf: csr_matrix,
                       topics: Mapping[str, List[Tuple[str, float]]]
                       ) -> Mapping[str, List[Tuple[str, float]]]:
        """ Extract topics

        Each representation model that inherits this class will have
        its arguments (topic_model, documents, c_tf_idf, topics)
        automatically passed. Therefore, the representation model
        will only have access to the information about topics related
        to those arguments.

        Arguments:
            topic_model: The BERTopic model that is fitted until topic
                         representations are calculated.
            documents: A dataframe with columns "Document" and "Topic"
                       that contains all documents with each corresponding
                       topic.
            c_tf_idf: A c-TF-IDF representation that is typically
                      identical to `topic_model.c_tf_idf_` except for
                      dynamic, class-based, and hierarchical topic modeling
                      where it is calculated on a subset of the documents.
            topics: A dictionary with topic (key) and tuple of word and
                    weight (value) as calculated by c-TF-IDF. These are the
                    default topics that are returned if no representation
                    model is used.
        """
        return topic_model.topic_representations_
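Since the default `extract_topics` simply returns the existing representations, a custom representation model only needs to override that one method. A toy sketch (the uppercasing is purely illustrative, not a real BERTopic model):

```python
class UppercaseRepresentation(BaseRepresentation):
    """ Toy model: uppercase the default c-TF-IDF keywords """
    def extract_topics(self, topic_model, documents, c_tf_idf, topics):
        # Keep the c-TF-IDF weights, transform only the words
        return {topic: [(word.upper(), value) for word, value in values]
                for topic, values in topics.items()}
```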
BERTopic/bertopic/representation/_cohere.py
ADDED
@@ -0,0 +1,193 @@
import time
import pandas as pd
from tqdm import tqdm
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document


DEFAULT_PROMPT = """
This is a list of texts where each collection of texts describes a topic. After each collection of texts, the name of the topic they represent is mentioned as a short, highly descriptive title.
---
Topic:
Sample texts from this topic:
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial-style meat production and factory farming, meat has become a staple food.
- Meat, but especially beef, is the worst food in terms of emissions.
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.

Keywords: meat beef eat eating emissions steak food health processed chicken
Topic name: Environmental impacts of eating meat
---
Topic:
Sample texts from this topic:
- I have ordered the product weeks ago but it still has not arrived!
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
- I got a message stating that I received the monitor but that is not true!
- It took a month longer to deliver than was advised...

Keywords: deliver weeks product shipping long delivery received arrived arrive week
Topic name: Shipping and delivery issues
---
Topic:
Sample texts from this topic:
[DOCUMENTS]
Keywords: [KEYWORDS]
Topic name:"""


class Cohere(BaseRepresentation):
    """ Use the Cohere API to generate topic labels based on their
    generative model.

    Find more about their models here:
    https://docs.cohere.ai/docs

    Arguments:
        client: A `cohere.Client`
        model: Model to use within Cohere, defaults to `"xlarge"`.
        prompt: The prompt to be used in the model. If no prompt is given,
                `self.default_prompt_` is used instead.
                NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
                to decide where the keywords and documents need to be
                inserted.
        delay_in_seconds: The delay in seconds between consecutive prompts
                          in order to prevent RateLimitErrors.
        nr_docs: The number of documents to pass to Cohere if a prompt
                 with the `"[DOCUMENTS]"` tag is used.
        diversity: The diversity of documents to pass to Cohere.
                   Accepts values between 0 and 1. Higher
                   values result in passing more diverse documents
                   whereas lower values pass more similar documents.
        doc_length: The maximum length of each document. If a document is longer,
                    it will be truncated. If None, the entire document is passed.
        tokenizer: The tokenizer used to split the document into segments
                   that are counted to determine a document's length.
                   * If tokenizer is 'char', then the document is split up
                     into characters which are counted to adhere to `doc_length`
                   * If tokenizer is 'whitespace', the document is split up
                     into words separated by whitespaces. These words are counted
                     and truncated depending on `doc_length`
                   * If tokenizer is 'vectorizer', then the internal CountVectorizer
                     is used to tokenize the document. These tokens are counted
                     and truncated depending on `doc_length`
                   * If tokenizer is a callable, then that callable is used to tokenize
                     the document. These tokens are counted and truncated depending
                     on `doc_length`

    Usage:

    To use this, you will need to install cohere first:

    `pip install cohere`

    Then, get yourself an API key and use Cohere's API as follows:

    ```python
    import cohere
    from bertopic.representation import Cohere
    from bertopic import BERTopic

    # Create your representation model
    co = cohere.Client(my_api_key)
    representation_model = Cohere(co)

    # Use the representation model in BERTopic on top of the default pipeline
    topic_model = BERTopic(representation_model=representation_model)
    ```

    You can also use a custom prompt:

    ```python
    prompt = "I have the following documents: [DOCUMENTS]. What topic do they contain?"
    representation_model = Cohere(co, prompt=prompt)
    ```
    """
    def __init__(self,
                 client,
                 model: str = "xlarge",
                 prompt: str = None,
                 delay_in_seconds: float = None,
                 nr_docs: int = 4,
                 diversity: float = None,
                 doc_length: int = None,
                 tokenizer: Union[str, Callable] = None
                 ):
        self.client = client
        self.model = model
        self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
        self.default_prompt_ = DEFAULT_PROMPT
        self.delay_in_seconds = delay_in_seconds
        self.nr_docs = nr_docs
        self.diversity = diversity
        self.doc_length = doc_length
        self.tokenizer = tokenizer
        self.prompts_ = []

    def extract_topics(self,
                       topic_model,
                       documents: pd.DataFrame,
                       c_tf_idf: csr_matrix,
                       topics: Mapping[str, List[Tuple[str, float]]]
                       ) -> Mapping[str, List[Tuple[str, float]]]:
        """ Extract topics

        Arguments:
            topic_model: Not used
            documents: Not used
            c_tf_idf: Not used
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations
        """
        # Extract the top n representative documents per topic
        repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity)

        # Generate using Cohere's Language Model
        updated_topics = {}
        for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
            truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
            prompt = self._create_prompt(truncated_docs, topic, topics)
            self.prompts_.append(prompt)

            # Delay between prompts to prevent rate limit errors
            if self.delay_in_seconds:
                time.sleep(self.delay_in_seconds)

            request = self.client.generate(model=self.model,
                                           prompt=prompt,
                                           max_tokens=50,
                                           num_generations=1,
                                           stop_sequences=["\n"])
            label = request.generations[0].text.strip()
            updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]

        return updated_topics

    def _create_prompt(self, docs, topic, topics):
        keywords = list(zip(*topics[topic]))[0]

        # Use the default prompt
        if self.prompt == DEFAULT_PROMPT:
            prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
            prompt = self._replace_documents(prompt, docs)

        # Use a custom prompt that leverages keywords, documents or both using
        # custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
        else:
            prompt = self.prompt
            if "[KEYWORDS]" in prompt:
                prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
            if "[DOCUMENTS]" in prompt:
                prompt = self._replace_documents(prompt, docs)

        return prompt

    @staticmethod
    def _replace_documents(prompt, docs):
        to_replace = ""
        for doc in docs:
            to_replace += f"- {doc}\n"
        prompt = prompt.replace("[DOCUMENTS]", to_replace)
        return prompt
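For reference, the `[KEYWORDS]` and `[DOCUMENTS]` tags are substituted independently, so a custom prompt may use either or both. A standalone illustration of the substitution performed by `_create_prompt` and `_replace_documents` (toy inputs):

```python
# Toy prompt and inputs mirroring the tag-replacement logic above
prompt = "Keywords: [KEYWORDS]\nSample texts:\n[DOCUMENTS]Topic name:"
keywords = ["meat", "beef", "emissions"]
docs = ["Beef has the highest footprint.", "Meat has become a staple food."]

prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompt = prompt.replace("[DOCUMENTS]", "".join(f"- {doc}\n" for doc in docs))
print(prompt)
```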
BERTopic/bertopic/representation/_keybert.py
ADDED
@@ -0,0 +1,198 @@
import numpy as np
import pandas as pd

from packaging import version
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Union
from sklearn.metrics.pairwise import cosine_similarity
from bertopic.representation._base import BaseRepresentation
from sklearn import __version__ as sklearn_version


class KeyBERTInspired(BaseRepresentation):
    def __init__(self,
                 top_n_words: int = 10,
                 nr_repr_docs: int = 5,
                 nr_samples: int = 500,
                 nr_candidate_words: int = 100,
                 random_state: int = 42):
        """ Use a KeyBERT-like model to fine-tune the topic representations

        The algorithm follows KeyBERT but does some optimization in
        order to speed up inference.

        The steps are as follows. First, we extract the top n representative
        documents per topic. To extract the representative documents, we
        randomly sample a number of candidate documents per cluster,
        which is controlled by the `nr_samples` parameter. Then,
        the top n representative documents are extracted by calculating
        the c-TF-IDF representation for the candidate documents and finding,
        through cosine similarity, which are closest to the topic c-TF-IDF
        representation; their number is controlled by the `nr_repr_docs` parameter.
        Next, candidate words per topic are extracted based on their
        c-TF-IDF representation, the number of which is controlled by the
        `nr_candidate_words` parameter.

        Then, we extract the embeddings for words and representative documents
        and create topic embeddings by averaging the representative documents.
        Finally, the most similar words to each topic are extracted by
        calculating the cosine similarity between word and topic embeddings.

        Arguments:
            top_n_words: The top n words to extract per topic.
            nr_repr_docs: The number of representative documents to extract per cluster.
            nr_samples: The number of candidate documents to extract per cluster.
            nr_candidate_words: The number of candidate words per cluster.
            random_state: The random state for randomly sampling candidate documents.

        Usage:

        ```python
        from bertopic.representation import KeyBERTInspired
        from bertopic import BERTopic

        # Create your representation model
        representation_model = KeyBERTInspired()

        # Use the representation model in BERTopic on top of the default pipeline
        topic_model = BERTopic(representation_model=representation_model)
        ```
        """
        self.top_n_words = top_n_words
        self.nr_repr_docs = nr_repr_docs
        self.nr_samples = nr_samples
        self.nr_candidate_words = nr_candidate_words
        self.random_state = random_state

    def extract_topics(self,
                       topic_model,
                       documents: pd.DataFrame,
                       c_tf_idf: csr_matrix,
                       topics: Mapping[str, List[Tuple[str, float]]]
                       ) -> Mapping[str, List[Tuple[str, float]]]:
        """ Extract topics

        Arguments:
            topic_model: A BERTopic model
            documents: All input documents
            c_tf_idf: The topic c-TF-IDF representation
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations
        """
        # We extract the top n representative documents per class
        _, representative_docs, repr_doc_indices, _ = topic_model._extract_representative_docs(c_tf_idf, documents, topics, self.nr_samples, self.nr_repr_docs)

        # We extract the candidate words per class
        topics = self._extract_candidate_words(topic_model, c_tf_idf, topics)

        # We calculate the similarity between word and document embeddings and create
        # topic embeddings from the representative document embeddings
        sim_matrix, words = self._extract_embeddings(topic_model, topics, representative_docs, repr_doc_indices)

        # Find the best matching words based on the similarity matrix for each topic
        updated_topics = self._extract_top_words(words, topics, sim_matrix)

        return updated_topics

    def _extract_candidate_words(self,
                                 topic_model,
                                 c_tf_idf: csr_matrix,
                                 topics: Mapping[str, List[Tuple[str, float]]]
                                 ) -> Mapping[str, List[Tuple[str, float]]]:
        """ For each topic, extract candidate words based on the c-TF-IDF
        representation.

        Arguments:
            topic_model: A BERTopic model
            c_tf_idf: The topic c-TF-IDF representation
            topics: The top words per topic

        Returns:
            topics: The top `self.nr_candidate_words` words per topic
        """
        labels = [int(label) for label in sorted(list(topics.keys()))]

        # Scikit-Learn Deprecation: get_feature_names is deprecated in 1.0
        # and will be removed in 1.2. Please use get_feature_names_out instead.
        if version.parse(sklearn_version) >= version.parse("1.0.0"):
            words = topic_model.vectorizer_model.get_feature_names_out()
        else:
            words = topic_model.vectorizer_model.get_feature_names()

        indices = topic_model._top_n_idx_sparse(c_tf_idf, self.nr_candidate_words)
        scores = topic_model._top_n_values_sparse(c_tf_idf, indices)
        sorted_indices = np.argsort(scores, 1)
        indices = np.take_along_axis(indices, sorted_indices, axis=1)
        scores = np.take_along_axis(scores, sorted_indices, axis=1)

        # Get the top candidate words per topic based on c-TF-IDF score
        topics = {label: [(words[word_index], score)
                          if word_index is not None and score > 0
                          else ("", 0.00001)
                          for word_index, score in zip(indices[index][::-1], scores[index][::-1])
                          ]
                  for index, label in enumerate(labels)}
        topics = {label: list(zip(*values[:self.nr_candidate_words]))[0] for label, values in topics.items()}

        return topics

    def _extract_embeddings(self,
                            topic_model,
                            topics: Mapping[str, List[Tuple[str, float]]],
                            representative_docs: List[str],
                            repr_doc_indices: List[List[int]]
                            ) -> Union[np.ndarray, List[str]]:
        """ Extract the representative document embeddings and create topic embeddings.
        Then extract word embeddings and calculate the cosine similarity between topic
        embeddings and the word embeddings. Topic embeddings are the average of
        representative document embeddings.

        Arguments:
            topic_model: A BERTopic model
            topics: The top words per topic
            representative_docs: A flat list of representative documents
            repr_doc_indices: The indices of representative documents
                              that belong to each topic

        Returns:
            sim: The similarity matrix between word and topic embeddings
            vocab: The complete vocabulary of input documents
        """
        # Calculate representative docs embeddings and create topic embeddings
        repr_embeddings = topic_model._extract_embeddings(representative_docs, method="document", verbose=False)
        topic_embeddings = [np.mean(repr_embeddings[i[0]:i[-1]+1], axis=0) for i in repr_doc_indices]

        # Calculate word embeddings and extract the best matching with the updated topic_embeddings
        vocab = list(set([word for words in topics.values() for word in words]))
        word_embeddings = topic_model._extract_embeddings(vocab, method="document", verbose=False)
        sim = cosine_similarity(topic_embeddings, word_embeddings)

        return sim, vocab

    def _extract_top_words(self,
                           vocab: List[str],
                           topics: Mapping[str, List[Tuple[str, float]]],
                           sim: np.ndarray
                           ) -> Mapping[str, List[Tuple[str, float]]]:
        """ Extract the top n words per topic based on the
        similarity matrix between topics and words.

        Arguments:
            vocab: The complete vocabulary of input documents
            topics: The top words per topic
            sim: The similarity matrix between word and topic embeddings

        Returns:
            updated_topics: The updated topic representations
        """
        labels = [int(label) for label in sorted(list(topics.keys()))]
        updated_topics = {}
        for i, topic in enumerate(labels):
            indices = [vocab.index(word) for word in topics[topic]]
            values = sim[:, indices][i]
            word_indices = [indices[index] for index in np.argsort(values)[-self.top_n_words:]]
            updated_topics[topic] = [(vocab[index], val) for val, index in zip(np.sort(values)[-self.top_n_words:], word_indices)][::-1]

        return updated_topics
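The final two steps of `KeyBERTInspired` boil down to a cosine-similarity lookup between topic and word embeddings. A self-contained sketch with toy 2-D embeddings (all values illustrative):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

vocab = ["meat", "beef", "shipping", "delivery"]
word_embeddings = np.array([[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.9]])
topic_embedding = np.array([[0.95, 0.15]])  # mean of the topic's representative-doc embeddings

# Rank the vocabulary by similarity to the topic and keep the top 2 words
sim = cosine_similarity(topic_embedding, word_embeddings)[0]
top_words = [vocab[i] for i in np.argsort(sim)[::-1][:2]]
print(top_words)  # ['meat', 'beef']
```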
BERTopic/bertopic/representation/_langchain.py
ADDED
@@ -0,0 +1,203 @@
import pandas as pd
from langchain.docstore.document import Document
from scipy.sparse import csr_matrix
from typing import Callable, Dict, Mapping, List, Tuple, Union

from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document

DEFAULT_PROMPT = "What are these documents about? Please give a single label."


class LangChain(BaseRepresentation):
    """ Using chains in langchain to generate topic labels.

    The classic example uses `langchain.chains.question_answering.load_qa_chain`.
    This returns a chain that takes a list of documents and a question as input.

    You can also use Runnables such as those composed using the LangChain Expression Language.

    Arguments:
        chain: The langchain chain or Runnable with a `batch` method.
               Input keys must be `input_documents` and `question`.
               Output key must be `output_text`.
        prompt: The prompt to be used in the model. If no prompt is given,
                `self.default_prompt_` is used instead.
        nr_docs: The number of documents to pass to LangChain if a prompt
                 with the `"[DOCUMENTS]"` tag is used.
        diversity: The diversity of documents to pass to LangChain.
                   Accepts values between 0 and 1. Higher
                   values result in passing more diverse documents
                   whereas lower values pass more similar documents.
        doc_length: The maximum length of each document. If a document is longer,
                    it will be truncated. If None, the entire document is passed.
        tokenizer: The tokenizer used to split the document into segments
                   that are counted to determine a document's length.
                   * If tokenizer is 'char', then the document is split up
                     into characters which are counted to adhere to `doc_length`
                   * If tokenizer is 'whitespace', the document is split up
                     into words separated by whitespaces. These words are counted
                     and truncated depending on `doc_length`
                   * If tokenizer is 'vectorizer', then the internal CountVectorizer
                     is used to tokenize the document. These tokens are counted
                     and truncated depending on `doc_length`. They are decoded with
                     whitespaces.
                   * If tokenizer is a callable, then that callable is used to tokenize
                     the document. These tokens are counted and truncated depending
                     on `doc_length`
        chain_config: The configuration for the langchain chain. Can be used to set options
                      like max_concurrency to avoid rate limiting errors.
    Usage:

    To use this, you will need to install the langchain package first.
    Additionally, you will need an underlying LLM to support langchain,
    like openai:

    `pip install langchain`
    `pip install openai`

    Then, you can create your chain as follows:

    ```python
    from langchain.chains.question_answering import load_qa_chain
    from langchain.llms import OpenAI
    chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff")
    ```

    Finally, you can pass the chain to BERTopic as follows:

    ```python
    from bertopic.representation import LangChain

    # Create your representation model
    representation_model = LangChain(chain)

    # Use the representation model in BERTopic on top of the default pipeline
    topic_model = BERTopic(representation_model=representation_model)
    ```

    You can also use a custom prompt:

    ```python
    prompt = "What are these documents about? Please give a single label."
    representation_model = LangChain(chain, prompt=prompt)
    ```

    You can also use a Runnable instead of a chain.
    The example below uses the LangChain Expression Language:

    ```python
    from bertopic.representation import LangChain
    from langchain.chains.question_answering import load_qa_chain
    from langchain.chat_models import ChatAnthropic
    from langchain.schema.document import Document
    from langchain.schema.runnable import RunnablePassthrough
    from langchain_experimental.data_anonymizer.presidio import PresidioReversibleAnonymizer

    prompt = ...
    llm = ...

    # We will construct a special privacy-preserving chain using Microsoft Presidio

    pii_handler = PresidioReversibleAnonymizer(analyzed_fields=["PERSON"])

    chain = (
        {
            "input_documents": (
                lambda inp: [
                    Document(
                        page_content=pii_handler.anonymize(
                            d.page_content,
                            language="en",
                        ),
                    )
                    for d in inp["input_documents"]
                ]
            ),
            "question": RunnablePassthrough(),
        }
        | load_qa_chain(llm, chain_type="stuff")
        | (lambda output: {"output_text": pii_handler.deanonymize(output["output_text"])})
    )

    representation_model = LangChain(chain, prompt=prompt)
    ```
    """
    def __init__(self,
                 chain,
                 prompt: str = None,
                 nr_docs: int = 4,
                 diversity: float = None,
                 doc_length: int = None,
                 tokenizer: Union[str, Callable] = None,
                 chain_config=None,
                 ):
        self.chain = chain
        self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
        self.default_prompt_ = DEFAULT_PROMPT
        self.chain_config = chain_config
        self.nr_docs = nr_docs
        self.diversity = diversity
        self.doc_length = doc_length
        self.tokenizer = tokenizer

    def extract_topics(self,
                       topic_model,
                       documents: pd.DataFrame,
                       c_tf_idf: csr_matrix,
                       topics: Mapping[str, List[Tuple[str, float]]]
                       ) -> Mapping[str, List[Tuple[str, int]]]:
        """ Extract topics

        Arguments:
            topic_model: A BERTopic model
            documents: All input documents
            c_tf_idf: The topic c-TF-IDF representation
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations
        """
        # Extract the top n representative documents per topic
        repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
            c_tf_idf=c_tf_idf,
            documents=documents,
            topics=topics,
            nr_samples=500,
            nr_repr_docs=self.nr_docs,
            diversity=self.diversity
        )

        # Generate labels using langchain's batch functionality
        chain_docs: List[List[Document]] = [
            [
                Document(
                    page_content=truncate_document(
                        topic_model,
                        self.doc_length,
                        self.tokenizer,
                        doc
                    )
                )
                for doc in docs
            ]
            for docs in repr_docs_mappings.values()
        ]

        # `self.chain` must take `input_documents` and `question` as input keys
        inputs = [
            {"input_documents": docs, "question": self.prompt}
            for docs in chain_docs
        ]

        # `self.chain` must return a dict with an `output_text` key,
        # the same output key as the `StuffDocumentsChain` returned by `load_qa_chain`
        outputs = self.chain.batch(inputs=inputs, config=self.chain_config)
        labels = [output["output_text"].strip() for output in outputs]

        updated_topics = {
            topic: [(label, 1)] + [("", 0) for _ in range(9)]
            for topic, label in zip(repr_docs_mappings.keys(), labels)
        }

        return updated_topics
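Because the contract is just `input_documents`/`question` in and `output_text` out via `batch`, any object satisfying it works, which is convenient for testing the wiring without an LLM. A hypothetical stub:

```python
class EchoChain:
    """ Hypothetical stand-in satisfying the chain contract used above """
    def batch(self, inputs, config=None):
        # One dict per topic; echo the start of the first document as the 'label'
        return [{"output_text": inp["input_documents"][0].page_content[:30]}
                for inp in inputs]

representation_model = LangChain(EchoChain())
```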