jugarte00 committed · verified
Commit 04595e7 · Parent: 898ca37

Upload folder using huggingface_hub

Files changed (45)
  1. .devcontainer/.DS_Store +0 -0
  2. .devcontainer/dev/Dockerfile +9 -0
  3. .devcontainer/dev/bin/svc +23 -0
  4. .devcontainer/devcontainer.json +51 -0
  5. .devcontainer/docker-compose.yml +19 -0
  6. .devcontainer/scripts/post-create.sh +25 -0
  7. .devcontainer/scripts/post-start.sh +2 -0
  8. .gitattributes +5 -35
  9. .github/workflows/update_space.yml +32 -0
  10. .gitignore +7 -0
  11. .pre-commit-config.yaml +19 -0
  12. .vscode/launch.json +59 -0
  13. .vscode/settings.json +26 -0
  14. README.md +141 -6
  15. article_embedding/__init__.py +0 -0
  16. article_embedding/__main__.py +4 -0
  17. article_embedding/app/__init__.py +3 -0
  18. article_embedding/app/gather.py +71 -0
  19. article_embedding/app/lookup.py +11 -0
  20. article_embedding/app/model.py +71 -0
  21. article_embedding/app/rag.py +134 -0
  22. article_embedding/app/search.py +70 -0
  23. article_embedding/app/ui.py +31 -0
  24. article_embedding/benchmark.py +76 -0
  25. article_embedding/chunk.py +25 -0
  26. article_embedding/cli.py +24 -0
  27. article_embedding/constants.py +2 -0
  28. article_embedding/couchdb.py +203 -0
  29. article_embedding/embed.py +86 -0
  30. article_embedding/loader.py +146 -0
  31. article_embedding/modal_app.py +153 -0
  32. article_embedding/query.py +117 -0
  33. article_embedding/replicate.py +21 -0
  34. article_embedding/retrieval.py +330 -0
  35. article_embedding/sheets.py +7 -0
  36. article_embedding/utils.py +76 -0
  37. notebooks/couchdb.http +39 -0
  38. notebooks/models.ipynb +44 -0
  39. notebooks/query.ipynb +168 -0
  40. notebooks/stella.ipynb +67 -0
  41. poetry.lock +0 -0
  42. poetry.toml +2 -0
  43. pyproject.toml +86 -0
  44. requirements.txt +0 -0
  45. tests/__init__.py +0 -0
.devcontainer/.DS_Store ADDED
Binary file (6.15 kB).
 
.devcontainer/dev/Dockerfile ADDED
@@ -0,0 +1,9 @@
+FROM mcr.microsoft.com/devcontainers/python:1-3.12-bookworm
+
+RUN <<EOF
+apt update
+apt install -y \
+    silversearcher-ag \
+    iputils-ping
+
+EOF
.devcontainer/dev/bin/svc ADDED
@@ -0,0 +1,23 @@
+#!/bin/bash
+# Usage: svc service [command]
+# Example: svc app ps aux
+
+# Validate arguments
+if [[ $# -lt 1 ]]; then
+    echo "Usage: svc service [command]"
+    exit 1
+fi
+
+# Service
+case $1 in
+    *)
+        SVC=$1   # default: use the argument as the compose service name
+        CMD=bash # default command when none is given
+        ;;
+esac
+
+CMD_ARGS="${@:2}"
+if [[ -z "$CMD_ARGS" ]]; then
+    CMD_ARGS="$CMD"
+fi
+docker compose exec $SVC $CMD_ARGS
.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,51 @@
+// For format details, see https://aka.ms/devcontainer.json. For config options, see the
+// README at: https://github.com/devcontainers/templates/tree/main/src/python
+{
+    "name": "Article Embedding",
+    "dockerComposeFile": "docker-compose.yml",
+    "service": "dev",
+    "runServices": [
+        "dev",
+        "qdrant"
+    ],
+    "workspaceFolder": "/workspaces/article-embedding",
+    "features": {
+        "ghcr.io/devcontainers/features/github-cli:1": {},
+        "ghcr.io/nikobockerman/devcontainer-features/poetry-persistent-cache:1": {},
+        "ghcr.io/devcontainers-extra/features/poetry:2": {},
+        "ghcr.io/devcontainers/features/docker-outside-of-docker:1": {}
+    },
+    "containerEnv": {
+        "DOCKER_CLI_HINTS": "false",
+        "POETRY_VIRTUALENVS_IN_PROJECT": "false",
+        "VIRTUAL_ENV": "/mnt/poetry-persistent-cache/virtualenvs/article-embedding-p7H3l83p-py3.12",
+        "COMPOSE_PROJECT_NAME": "article-embedding_devcontainer",
+        "COMPOSE_FILE": "${containerWorkspaceFolder}/.devcontainer/docker-compose.yml"
+    },
+    "remoteEnv": {
+        "PATH": "${containerEnv:VIRTUAL_ENV}/bin:${containerWorkspaceFolder}/.devcontainer/dev/bin:${containerEnv:PATH}"
+    },
+    "mounts": [
+        "type=bind,source=${localEnv:HOME}${localEnv:USERPROFILE}/.gitconfig,target=/home/vscode/.gitconfig,readonly",
+        "type=bind,source=${localEnv:HOME}${localEnv:USERPROFILE}/.ssh,target=/home/vscode/.ssh,readonly"
+    ],
+    "postCreateCommand": ".devcontainer/scripts/post-create.sh ${containerWorkspaceFolder}",
+    "postStartCommand": ".devcontainer/scripts/post-start.sh ${containerWorkspaceFolder}",
+    "customizations": {
+        "vscode": {
+            "extensions": [
+                "ms-python.debugpy",
+                "github.vscode-github-actions",
+                "matangover.mypy",
+                "tamasfe.even-better-toml",
+                "humao.rest-client",
+                "charliermarsh.ruff",
+                "ms-toolsai.jupyter",
+                "redhat.vscode-yaml"
+            ],
+            "settings": {
+                "python.defaultInterpreterPath": "${containerEnv:VIRTUAL_ENV}/bin/python"
+            }
+        }
+    }
+}
.devcontainer/docker-compose.yml ADDED
@@ -0,0 +1,19 @@
+services:
+  dev:
+    profiles:
+      - devcontainer
+    build: dev
+    volumes:
+      - ../..:/workspaces:cached
+      - ..:/workspaces/article-embedding:cached
+    command: sleep infinity
+
+  qdrant:
+    image: qdrant/qdrant
+    ports:
+      - "6333:6333"
+    volumes:
+      - qdrant:/qdrant/storage:z
+
+volumes:
+  qdrant:
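
As a quick smoke test, the Qdrant service started by this compose file can be reached with the client the project already depends on. A minimal sketch, assuming the URL resolves to the service (`http://qdrant:6333` from the `dev` service, `http://localhost:6333` from the host via the port mapping above):

```python
from qdrant_client import QdrantClient

client = QdrantClient(url="http://localhost:6333")
print(client.get_collections())  # empty on a fresh volume
```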
.devcontainer/scripts/post-create.sh ADDED
@@ -0,0 +1,25 @@
+#!/bin/bash
+# Fail on any error
+set -e
+
+# Ensure the lock file is up to date, otherwise the install step may fail later
+poetry lock --no-update
+
+# Remove the old-style .venv symlink if it exists
+if [[ -L .venv ]]; then
+    rm .venv
+fi
+
+# Force the virtual environment to the path set in devcontainer.json
+python -m venv "$VIRTUAL_ENV"
+PATH="$VIRTUAL_ENV/bin:$PATH"
+poetry install
+
+# Activate pre-commit hooks
+pre-commit install
+
+# Configuration files
+cat >> ~/.inputrc <<EOF
+set completion-ignore-case on
+set editing-mode vi
+EOF
.devcontainer/scripts/post-start.sh ADDED
@@ -0,0 +1,2 @@
+#!/bin/bash
+set -e
.gitattributes CHANGED
@@ -1,35 +1,5 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+poetry.lock binary
+
+*.ipynb filter=nbstripout
+*.zpln filter=nbstripout
+*.ipynb diff=ipynb
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,32 @@
+name: Run Python script
+
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Install Poetry
+        run: pipx install poetry
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+          cache: "poetry"
+
+      - name: Deploy
+        run: |-
+          pipx inject poetry poetry-plugin-export
+          poetry export --only=main >requirements.txt
+          poetry install --only=ci --no-root -q
+          PATH=$(poetry env info --path)/bin:$PATH
+          huggingface-cli login --token ${{ secrets.HF_TOKEN }}
+          gradio deploy
.gitignore ADDED
@@ -0,0 +1,7 @@
+.*checkpoint
+.keys/
+__pycache__/
+.gradio/
+.env
+data/
+.venv/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,19 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.2.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.7.4
+    hooks:
+      # Run the linter.
+      - id: ruff
+        args: [--fix]
+      # Run the formatter.
+      - id: ruff-format
.vscode/launch.json ADDED
@@ -0,0 +1,59 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "App",
+            "type": "debugpy",
+            "request": "launch",
+            "module": "article_embedding",
+            "args": [
+                "app"
+            ],
+            "console": "integratedTerminal",
+            "justMyCode": false,
+        },
+        {
+            "name": "Scanner",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "article_embedding/couchdb.py",
+            "console": "integratedTerminal",
+            "justMyCode": false,
+        },
+        {
+            "name": "Loader",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "article_embedding/loader.py",
+            "console": "integratedTerminal",
+            "justMyCode": false,
+        },
+        {
+            "name": "Embedder",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "article_embedding/embed.py",
+            "console": "integratedTerminal",
+            "justMyCode": false,
+        },
+        {
+            "name": "Query",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "article_embedding/query.py",
+            "console": "integratedTerminal",
+            "justMyCode": false,
+        },
+        {
+            "name": "Benchmark",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "article_embedding/benchmark.py",
+            "console": "integratedTerminal",
+            "justMyCode": false,
+        },
+    ]
+}
.vscode/settings.json ADDED
@@ -0,0 +1,26 @@
+{
+    "editor.codeActionsOnSave": {
+        "source.organizeImports": "always",
+        "source.unusedImports": "always"
+    },
+    "editor.formatOnPaste": true,
+    "editor.formatOnSave": true,
+    "editor.formatOnSaveMode": "file",
+    "editor.formatOnType": true,
+    "editor.renderWhitespace": "all",
+    "editor.rulers": [
+        132
+    ],
+    "files.trimTrailingWhitespace": true,
+    "mypy.runUsingActiveInterpreter": true,
+    "python.analysis.importFormat": "absolute",
+    "python.analysis.autoFormatStrings": true,
+    "python.analysis.autoImportCompletions": true,
+    "ruff.organizeImports": true,
+    "[jsonc]": {
+        "editor.defaultFormatter": "vscode.json-language-features",
+    },
+    "[python]": {
+        "editor.defaultFormatter": "charliermarsh.ruff",
+    }
+}
README.md CHANGED
@@ -1,13 +1,148 @@
 ---
-title: Wsws Chatbot
-emoji: 💻
-colorFrom: red
-colorTo: green
+title: wsws-chatbot
+emoji: 📊
+colorFrom: green
+colorTo: indigo
 sdk: gradio
 sdk_version: 5.12.0
-app_file: app.py
+app_file: article_embedding/app/ui.py
 pinned: false
 short_description: Ask the WSWS chatbot
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+## Filter
+
+Identify the top 60 most important articles of the year.
+
+| Variable             | Description                           |
+| -------------------- | ------------------------------------- |
+| articlesPerYear      | Number of articles per year           |
+| perspectivePercent   | Percentage of perspectives            |
+| usPercent            | Percentage of US articles             |
+| maxArticlesPerAuthor | Maximum number of articles per author |
+
+### Important topics
+
+| Topic ID                             | Topic Name               |
+| ------------------------------------ | ------------------------ |
+| 3f1db5b4-7915-4cef-a132-3b2e51f0728d | US Politics              |
+| 105b9ec2-0cdd-480d-9c55-6f806fa7b9bc | Perspectives             |
+| 04d2a008-8cef-4928-9a38-86b205a7590f | The coronavirus pandemic |
+| ea69254c-a976-4e67-a79c-843202a10e55 | United States            |
+| b0e983c5-6e85-4829-a9fe-3c25f02c211f | North America            |
+
+## Analysis
+
+1. Document metadata structure
+
+   - The metadata is stored as JSON
+   - Yes, there are structured attributes in the metadata
+   - There are many other structured attributes, but none that are relevant to the system design
+
+2. Vector database setup
+
+   - Qdrant
+   - Documents are embedded using the Stella 400M v5 embedding model
+
+3. Integration of filters and constraints
+
+   - Excellent question. The constraints need not be strict. They should be approximately met, since relevance to the provided queries is also important
+   - The documents should be selected first based on relevance, and the final selection then adjusted to meet the constraints. It's OK to select less relevant articles to get closer to the constraints, but only up to a point.
+
+4. Date range filtering
+
+   - The date range is strict
+
+5. Result volume and ranking logic
+
+   - Great question too. The articles should be ranked by relevance, with all queries adding documents to the result set and the most relevant ones chosen from this pool. For example, it would be OK if all articles ended up being derived from only one query.
+
+6. Combining multiple queries
+
+   - The former. Each query is run separately and the results are then combined based on the document relevance score. Since the same document may be returned for different queries, each time with a different relevance score, the highest relevance score should be used (see the sketch after this list).
+
+7. User preferences and constraint handling
+
+   - Exactly, the result set should be post-processed to meet the constraints
+
+8. Performance considerations
+
+   - The document database has about 100,000 documents
+   - No latency requirements. The accuracy of the system is the only requirement
+
+9. Evaluation and benchmarks
+
+   - The system's success will be measured by relevance judgments and coverage of constraints, in that order
+   - A feedback loop to improve the retrieval strategies or embedding approach is definitely on the table, if needed.
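
A minimal sketch of the combination rule from points 5-7 (illustrative only; the actual implementation lives in `article_embedding/retrieval.py`, not reproduced here):

```python
# Hypothetical helper: each per-query result is a dict of doc_id -> relevance score.
def combine(results_per_query: list[dict[str, float]], k: int) -> list[str]:
    best: dict[str, float] = {}
    for scores in results_per_query:
        for doc_id, score in scores.items():
            # A document returned by several queries keeps its highest score.
            best[doc_id] = max(score, best.get(doc_id, 0.0))
    # Rank the pooled documents by relevance; constraints are post-processed afterwards.
    ranked = sorted(best, key=lambda d: best[d], reverse=True)
    return ranked[:k]
```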
+
+```python
+self.embedding_model = SentenceTransformerModel("dunzhang/stella_en_400M_v5")
+```
+
+## Performance benchmarks
+
+The script at `article_embedding/modal_app.py` was used to measure the performance of the embedding model (latency is per document embedded). The results are as follows:
+
+| GPU    | Latency | Cost  |
+| :----: | ------: | ----: |
+| T4     | 130.0ms | 2.13¢ |
+| L4     |  82.7ms | 1.84¢ |
+| M3 Max |  60.0ms | N/A   |
+| A10G   |  64.2ms | 1.96¢ |
+| A100   |  42.2ms | 3.26¢ |
+| H100   |  24.9ms | 3.15¢ |
+
+These are benchmarks on the M3 Max 128GB:
+
+| Model          |   f32 |   f16 | Context |  Size |
+| -------------- | ----: | ----: | ------: | ----: |
+| jasper sdpa    | 373ms | 314ms |   1,024 | 1,024 |
+| jasper         | 354ms | 330ms |   1,024 | 1,024 |
+| stella-en-400M |  65ms |  62ms | 131,072 | 1,024 |
+
+## Models
+
+[Jasper](https://huggingface.co/infgrad/jasper_en_vision_language_v1)
+
+## Issues
+
+```sh
+$ cd /Users/jlopez/git-repos/icfi/article-embedding ; /usr/bin/env /Users/jlopez/git-repos/icfi/article-embedding/.venv/bin/python /Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher 63791 -- article_embedding/loader.py
+  0%|          | 0/314855 [00:00<?, ?it/s]E+00379.644: /handling disconnect from Adapter/
+             Failed to kill Debuggee[PID=93387]
+
+             Traceback (most recent call last):
+               File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/launcher/debuggee.py", line 190, in kill
+                 os.killpg(process.pid, signal.SIGKILL)
+             PermissionError: [Errno 1] Operation not permitted
+
+             Stack where logged:
+               File "/Users/jlopez/.pyenv/versions/3.12.5/lib/python3.12/threading.py", line 1032, in _bootstrap
+                 self._bootstrap_inner()
+               File "/Users/jlopez/.pyenv/versions/3.12.5/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
+                 self.run()
+               File "/Users/jlopez/.pyenv/versions/3.12.5/lib/python3.12/threading.py", line 1012, in run
+                 self._target(*self._args, **self._kwargs)
+               File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/common/messaging.py", line 1458, in _run_handlers
+                 handler()
+               File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/common/messaging.py", line 1488, in _handle_disconnect
+                 handler()
+               File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/launcher/handlers.py", line 152, in disconnect
+                 debuggee.kill()
+               File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/launcher/debuggee.py", line 192, in kill
+                 log.swallow_exception("Failed to kill {0}", describe())
+               File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/common/log.py", line 215, in swallow_exception
+                 _exception(format_string, *args, **kwargs)
+
+
+E+00379.648: Failed to kill Debuggee[PID=93387]
+
+             Traceback (most recent call last):
+               File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/launcher/debuggee.py", line 190, in kill
+                 os.killpg(process.pid, signal.SIGKILL)
+             PermissionError: [Errno 1] Operation not permitted
+
+             Stack where logged:
+               File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/launcher/debuggee.py", line 192, in kill
+                 log.swallow_exception("Failed to kill {0}", describe())
+               File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/common/log.py", line 215, in swallow_exception
+                 _exception(format_string, *args, **kwargs)
+```
article_embedding/__init__.py ADDED
File without changes
article_embedding/__main__.py ADDED
@@ -0,0 +1,4 @@
+from article_embedding.cli import cli
+
+if __name__ == "__main__":
+    cli()
article_embedding/app/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .ui import run_app
+
+__all__ = ["run_app"]
article_embedding/app/gather.py ADDED
@@ -0,0 +1,71 @@
+from typing import Any
+
+import gradio as gr
+
+from article_embedding.retrieval import retrieve_documents
+
+DEFAULT_QUERY = """
+The development of the COVID-19 pandemic globally.
+The “herd immunity” and “mitigationist” response of capitalist governments globally to the COVID-19 pandemic.
+The only viable solution of global elimination of the COVID-19 pandemic.
+The science of the COVID-19 pandemic: Long COVID, airborne transmission, viral evolution.
+""".strip()
+
+
+def gather_articles(
+    query: str,
+    year: int,
+    articles_per_year: int,
+    pct_perspectives: int,
+    pct_us: int,
+    max_per_author: int,
+) -> list[list[Any]]:
+    queries = [q.strip() for q in query.split("\n") if q.strip()]
+    docs = retrieve_documents(
+        queries=queries,
+        start_date=f"{year}-01-01",
+        end_date=f"{year + 1}-01-01",
+        number_of_articles=articles_per_year,
+        op_ed_ratio=pct_perspectives / 100.0,
+        us_ratio=pct_us / 100.0,
+        max_per_author=max_per_author,
+    )
+    return [
+        [
+            ix + 1,
+            d.query_id,
+            f"{d.score * 100:.1f}%",
+            d.meta.published.strftime("%Y/%m/%d") if d.meta.published else "N/A",
+            ", ".join(d.meta.authors)[:32] if d.meta.authors else "N/A",
+            f'[{d.meta.title or "N/A"}]({d.meta.url})',
+        ]
+        for ix, d in enumerate(docs)
+    ]
+
+
+def gather_tab() -> None:
+    with gr.Tab("Gather"):
+        with gr.Row():
+            with gr.Column(scale=2):
+                query = gr.Textbox(label="Query", placeholder="Enter multiple queries, one per line", value=DEFAULT_QUERY)
+                gather = gr.Button(value="Gather articles")
+            with gr.Column():
+                year = gr.Slider(minimum=2021, maximum=2025, value=2021, label="Year", step=1)
+                articles_per_year = gr.Slider(label="Articles per year", minimum=1, maximum=100, value=60, step=1)
+                pct_perspectives = gr.Slider(label="Perspectives (%)", minimum=0, maximum=100, value=50)
+                pct_us = gr.Slider(label="US (%)", minimum=0, maximum=100, value=30)
+                max_per_author = gr.Slider(label="Max articles per author", minimum=1, maximum=20, value=6, step=1)
+
+        results = gr.DataFrame(
+            label="Results",
+            type="array",
+            headers=["#", "Q", "%", "Date", "Authors", "Title"],
+            datatype=["number", "number", "number", "date", "str", "markdown"],
+            column_widths=["3.5%", "3.5%", "4%", "7%", "20%", "62%"],
+        )
+
+        gather.click(
+            fn=gather_articles,
+            inputs=[query, year, articles_per_year, pct_perspectives, pct_us, max_per_author],
+            outputs=[results],
+        )
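
For reference, `gather_articles` can also be exercised without the UI. A minimal sketch, assuming the Qdrant collection and the environment variables used elsewhere in this repo are available:

```python
rows = gather_articles(
    query="The development of the COVID-19 pandemic globally.",
    year=2021,
    articles_per_year=60,
    pct_perspectives=50,
    pct_us=30,
    max_per_author=6,
)
print(rows[0])  # [rank, query index, score, date, authors, markdown link]
```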
article_embedding/app/lookup.py ADDED
@@ -0,0 +1,11 @@
+import gradio as gr
+
+from article_embedding.couchdb import CouchDB
+
+
+def lookup_tab() -> None:
+    with gr.Tab("Lookup"):
+        doc_id = gr.Textbox(label="Document ID", placeholder="Enter a document ID")
+        button = gr.Button("Lookup")
+        doc_output = gr.JSON()
+        button.click(fn=CouchDB().get_doc, inputs=[doc_id], outputs=[doc_output])
article_embedding/app/model.py ADDED
@@ -0,0 +1,71 @@
+import datetime
+import logging
+from typing import Any
+
+from pydantic import BaseModel
+
+log = logging.getLogger(__name__)
+
+
+class WorkflowData(BaseModel):
+    status: str | None = None
+    writers: list[str] | None = None
+    editors: list[str] | None = None
+    proofers: list[str] | None = None
+    reviewers: list[str] | None = None
+    proofingDeadline: str | None = None
+
+
+class Document(BaseModel):
+    # NOTE: Pydantic treats leading-underscore names as private attributes, so
+    # _id/_rev are not populated from CouchDB JSON as-is; Field(alias="_id")-style
+    # fields would be needed to round-trip them.
+    _id: str
+    _rev: str
+    type: str | None = None
+    mimetype: str | None = None
+    title: str | None = None
+    language: str | None = None
+    workflowData: WorkflowData | None = None
+    path: str | None = None
+    name: str | None = None
+    created: int | None = None
+    creator: str | None = None
+    lastPublished: int | None = None
+    firstPublished: int | None = None
+    modified: datetime.datetime | None = None
+    modifier: str | None = None
+    published: datetime.datetime | None = None
+    authors: list[str] | None = None
+    content: str | None = None
+    contentAssets: list[str] | None = None
+    featuredImages: list[str] | None = None
+    keywords: list[str] | None = None
+    topics: list[str] | None = None
+    relatedAssets: list[str] | None = None
+    comments: bool | None = None
+    campaignConfigs: list[Any] | None = None
+    order: int | None = None
+    overline: str | None = None
+    translatedFrom: str | None = None
+    socialTitles: list[Any] | None = None
+    socialDescriptions: list[Any] | None = None
+    socialFeaturedImages: list[Any] | None = None
+    underline: str | None = None
+    template: str | None = None
+    description: str | None = None
+    suggestedImages: list[str] | None = None
+    publisher: str | None = None
+
+    @property
+    def url(self) -> str:
+        if not self.path:
+            log.warning(f"URL requested for pathless document: {self}.")
+        return public_url(self.path)
+
+    def has_topic(self, topic_id: str) -> bool:
+        return topic_id in self.topics if self.topics else False
+
+    def __str__(self) -> str:
+        return f"Document[id={self._id!r}, path={self.path!r}]"
+
+
+def public_url(path: str | None) -> str:
+    return f"https://www.wsws.org{path or ''}"
article_embedding/app/rag.py ADDED
@@ -0,0 +1,134 @@
+import asyncio
+from typing import AsyncGenerator, Literal
+
+import gradio as gr
+from gradio import MessageDict
+from openai import AsyncOpenAI, BaseModel
+from openai.types.chat import ChatCompletionMessageParam
+
+from article_embedding.chunk import html_to_markdown
+from article_embedding.couchdb import CouchDB
+from article_embedding.query import Query
+from article_embedding.utils import env_str
+
+BackendId = Literal["OpenAI", "Ollama"]
+
+
+class BackendConfig(BaseModel):
+    base_url: str | None = None
+    api_key: str | None = None
+
+    @property
+    def async_client(self) -> AsyncOpenAI:
+        return AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
+
+
+BACKEND_CONFIG: dict[BackendId, BackendConfig] = {
+    "OpenAI": BackendConfig(),
+    "Ollama": BackendConfig(base_url=env_str("OLLAMA_URL"), api_key="Ollama"),
+}
+
+
+class RagModel(BaseModel):
+    backend_id: BackendId
+    model: str
+
+    @property
+    def backend(self) -> BackendConfig:
+        return BACKEND_CONFIG[self.backend_id]
+
+
+RAG_MODELS: dict[str, RagModel] = {
+    m.model: m
+    for m in [
+        RagModel(backend_id="OpenAI", model="gpt-4o"),
+        RagModel(backend_id="OpenAI", model="gpt-4o-mini"),
+        RagModel(backend_id="Ollama", model="qwen2.5:72b"),
+        RagModel(backend_id="Ollama", model="llama3.3:70b"),
+        RagModel(backend_id="Ollama", model="phi4"),
+    ]
+}
+
+QUERY_PROMPT = """
+Generate a short paragraph whose vector embedding will be used to find related articles using a vector database indexing the WSWS archive.
+Only output the paragraph without any additional text.
+
+{question}
+""".strip()  # noqa: E501
+
+CONTEXT_PROMPT = """
+Using the following sources, answer the user's question as precisely as possible.
+Only use the information provided in the sources.
+If not enough information is available, say you don't know.
+Include references to the sources in the answer by using just the article index and url, e.g. [3](URL).
+The answers should be exhaustive and cover as much as possible from the sources.
+
+{docs}
+""".strip()
+
+_couchdb = CouchDB()
+
+
+# From Gradio chat_interface.py:ChatInterface:_stream_fn
+async def chat_function(
+    message: str,  # | MultimodalPostprocess,
+    history_with_input: list[MessageDict],
+    model: str,
+) -> AsyncGenerator[str | list[MessageDict], None]:
+    _openai = RAG_MODELS[model].backend.async_client
+    _chat = _openai.chat.completions.create
+    messages: list[ChatCompletionMessageParam] = [{"role": h["role"], "content": h["content"]} for h in history_with_input]
+
+    async def generate_system_prompt(question: str) -> str:
+        msgs: list[ChatCompletionMessageParam] = messages + [{"role": "system", "content": QUERY_PROMPT.format(question=question)}]
+        response = await _chat(messages=msgs, model=model, stream=False)
+        query = response.choices[0].message.content or ""
+
+        result = Query(index="wsws-2").query(query, limit=40)
+        # Deduplicate chunk hits down to at most eight distinct articles.
+        paths = []
+        for point in result.points:
+            doc = point.payload
+            if doc is None:
+                continue
+            path = doc["path"]
+            if path not in paths:
+                paths.append(path)
+            if len(paths) >= 8:
+                break
+        tasks = [_couchdb.get_doc(path) for path in paths]
+        docs: list[str] = []
+        for rix, task_result in enumerate(await asyncio.gather(*tasks, return_exceptions=True), start=1):
+            if isinstance(task_result, dict):
+                data = [
+                    f"Index: {rix}",
+                    f"URL: https://www.wsws.org{task_result['path']}",
+                    task_result["title"] or "",
+                    task_result["description"] or "",
+                    html_to_markdown(task_result["content"]),
+                    "---",
+                ]
+                docs.append("\n".join(data))
+        return CONTEXT_PROMPT.format(docs="\n\n".join(docs))
+
+    system_prompt = await generate_system_prompt(message)
+    messages.append({"role": "system", "content": system_prompt})
+    messages.append({"role": "user", "content": message})
+    generator = await _chat(messages=messages, model=model, stream=True)
+    result: MessageDict = {"role": "assistant", "content": ""}
+    async for chunk in generator:
+        delta = chunk.choices[0].delta
+        if delta.content:
+            result["content"] += delta.content
+        yield result
+
+
+def rag_tab() -> None:
+    with gr.Tab("RAG"):
+        model = gr.Dropdown(value="gpt-4o", choices=list(RAG_MODELS), label="Model")
+        gr.ChatInterface(
+            fn=chat_function,
+            multimodal=False,
+            type="messages",
+            additional_inputs=[model],
+            chatbot=gr.Chatbot(height=800),
+        )
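
Since `chat_function` is an async generator, it can also be driven outside Gradio. A minimal sketch, assuming `OPENAI_API_KEY`, the Qdrant index `wsws-2`, and the CouchDB environment variables are configured (each yield is the assistant message accumulated so far):

```python
import asyncio

from article_embedding.app.rag import chat_function

async def demo() -> None:
    final = ""
    async for partial in chat_function("What is Long COVID?", [], "gpt-4o-mini"):
        final = partial["content"]  # partial is a MessageDict
    print(final)

asyncio.run(demo())
```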
article_embedding/app/search.py ADDED
@@ -0,0 +1,70 @@
+from datetime import datetime
+from typing import Any
+
+import gradio as gr
+from qdrant_client.models import Condition, Filter
+
+from article_embedding.query import Query, make_date_condition, make_topic_condition
+
+
+def search(
+    text: str,
+    date_from: datetime | None,
+    date_to: datetime | None,
+    topics: str,
+    limit: int,
+) -> tuple[list[list[str]], list[Any]]:
+    rv: list[list[str]] = []
+
+    # One topic ID per line: bare ID -> should, "+" prefix -> must, "!" prefix -> must_not.
+    topic_lines = [t for t in topics.strip().split("\n") if t]
+    should: list[Condition] = [make_topic_condition(t) for t in topic_lines if t[0].isalnum()]
+    must: list[Condition] = [make_topic_condition(t[1:]) for t in topic_lines if t.startswith("+")]
+    must_not: list[Condition] = [make_topic_condition(t[1:]) for t in topic_lines if t.startswith("!")]
+    if date_from or date_to:
+        must.append(make_date_condition(date_from=date_from, date_to=date_to))  # type: ignore
+    result = Query().query(
+        text,
+        query_filter=Filter(should=should, must=must, must_not=must_not),
+        limit=limit,
+    )
+    docs = []
+    for point in result.points:
+        doc = point.payload
+        assert doc is not None
+        docs.append(doc)
+        rv.append(
+            [
+                f"{point.score * 100:.1f}%",
+                datetime.fromtimestamp(doc["published"]).strftime("%Y/%m/%d"),
+                ", ".join(doc["authors"]),
+                f'[{doc["title"] or "N/A"}](https://www.wsws.org{doc["path"]})',
+            ]
+        )
+    return rv, docs
+
+
+def search_tab() -> None:
+    default_query = "The COVID winter wave, the emergence of the Delta variant and the January 6th coup"
+    with gr.Tab("Search"):
+        with gr.Row():
+            with gr.Column(scale=2):
+                query = gr.Textbox(
+                    label="Query",
+                    placeholder="Enter a query",
+                    value=default_query,
+                )
+            with gr.Column():
+                from_date = gr.DateTime(include_time=False, type="datetime", label="From", value="2021-01-01")
+                to_date = gr.DateTime(include_time=False, type="datetime", label="To", value="2021-05-01")
+                topics = gr.Textbox(label="Topics", placeholder="Enter topics")
+                limit = gr.Number(label="Limit", minimum=1, value=10)
+        button = gr.Button("Search")
+        results = gr.DataFrame(
+            label="Results",
+            type="array",
+            headers=["Score", "Published", "Authors", "Title"],
+            datatype=["number", "date", "str", "markdown"],
+        )
+        with gr.Accordion("Debug"):
+            json_output = gr.JSON()
+        button.click(fn=search, inputs=[query, from_date, to_date, topics, limit], outputs=[results, json_output])
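
Judging from the parsing in `search`, the Topics box takes one topic ID per line: a bare ID becomes an optional (`should`) condition, a `+` prefix a required (`must`) condition, and a `!` prefix an exclusion (`must_not`). For example, using IDs from the README's Important topics table:

```
105b9ec2-0cdd-480d-9c55-6f806fa7b9bc
+ea69254c-a976-4e67-a79c-843202a10e55
!04d2a008-8cef-4928-9a38-86b205a7590f
```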
article_embedding/app/ui.py ADDED
@@ -0,0 +1,31 @@
+import asyncio
+import os
+
+import gradio as gr
+from dotenv import load_dotenv
+
+from article_embedding.app.gather import gather_tab
+from article_embedding.app.lookup import lookup_tab
+from article_embedding.app.rag import rag_tab
+from article_embedding.app.search import search_tab
+
+load_dotenv()
+
+
+async def amain(share: bool = False) -> None:
+    with gr.Blocks(fill_height=True) as app:  # noqa: SIM117
+        rag_tab()
+        gather_tab()
+        search_tab()
+        lookup_tab()
+
+    auth = (os.environ["GRADIO_USERNAME"], os.environ["GRADIO_PASSWORD"]) if share else None
+    app.queue().launch(inbrowser=True, share=share, auth=auth)
+
+
+def run_app(share: bool = False) -> None:
+    asyncio.run(amain(share))
+
+
+if __name__ == "__main__":
+    run_app()
article_embedding/benchmark.py ADDED
@@ -0,0 +1,76 @@
+import json
+import logging
+import time
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+import torch
+
+from article_embedding.embed import SentenceTransformerModel, StellaEmbedder
+
+log = logging.getLogger(__name__)
+
+
+async def fetch_documents() -> list[str]:
+    from article_embedding.couchdb import CouchDB
+
+    # Only the first batch of the changes feed is used as the benchmark sample.
+    async for batch in CouchDB().changes(batch_size=128):
+        batch = sorted(batch, key=lambda x: x.get("_id"))
+        return [a["content"] for a in batch if a.get("type") == "article" and a.get("language") == "en"]
+    return []
+
+
+def process2(model: SentenceTransformerModel, documents: list[str], name: str) -> None:
+    model.embed(["Hello, world!"])  # Warmup
+    ts0 = time.time()
+    embeddings = model.embed(documents)
+    benchmark = time.time() - ts0
+    output_path = Path("data/embeddings.json")
+    save_embeddings(embeddings, output_path)
+    golden_path = Path(f"data/embeddings.{name}-golden.json")
+    if golden_path.exists():
+        golden_embeddings: Any = load_embeddings(golden_path)
+        similarities = model.model.similarity_pairwise(embeddings, golden_embeddings)
+        rms = torch.sqrt(torch.mean(similarities**2)).item()
+    else:
+        save_embeddings(embeddings, golden_path)
+        rms = 0.0
+    log.info(
+        "%s - RMS: %.2f. Latency: %.2f ms. Size: %d",
+        name,
+        rms,
+        benchmark / len(embeddings) * 1000,
+        len(embeddings[0]),
+    )
+
+
+def load_embeddings(path: Path) -> list[npt.NDArray[np.float64]]:
+    with path.open() as f:
+        return [np.array(json.loads(line)) for line in f.readlines()]
+
+
+def save_embeddings(embeddings: list[npt.NDArray[np.float64]], path: Path) -> None:
+    with path.open("w") as f:
+        for e in embeddings:
+            f.write(json.dumps(e.tolist()) + "\n")
+
+
+async def amain() -> None:
+    model = StellaEmbedder()
+    # model = NvEmbedder()
+    # model = JasperEmbedder()
+    # model.model.half()
+    documents = await fetch_documents()
+    process2(model, documents, "stella")
+    # documents = await fetch_documents()
+    # process2(model, documents, "stella")
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    logging.basicConfig(level=logging.WARN)
+    log.setLevel(logging.DEBUG)
+    asyncio.run(amain())
article_embedding/chunk.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import cast
3
+
4
+ import pandoc
5
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
6
+
7
+
8
+ class Chunker(ABC):
9
+ @abstractmethod
10
+ def chunk(self, text: str) -> list[str]: ...
11
+
12
+
13
+ class LangchainChunker(Chunker):
14
+ def __init__(self, chunk_size: int = 576, chunk_overlap: int = 0) -> None:
15
+ self.splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
16
+
17
+ def chunk(self, text: str) -> list[str]:
18
+ md = html_to_markdown(text)
19
+ return self.splitter.split_text(md)
20
+
21
+
22
+ def html_to_markdown(html: str) -> str:
23
+ doc = pandoc.read(html, format="html")
24
+ md = pandoc.write(doc, format="markdown-smart", options=["--wrap=none"])
25
+ return cast(str, md)
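
A minimal usage sketch (the `pandoc` binary must be installed on the system for the HTML conversion to work):

```python
from article_embedding.chunk import LangchainChunker

chunker = LangchainChunker(chunk_size=576)
chunks = chunker.chunk("<h1>Title</h1><p>Some long article body…</p>")
print(len(chunks), chunks[0][:80])
```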
article_embedding/cli.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import click
4
+ from click_help_colors import HelpColorsGroup
5
+
6
+ from article_embedding.app import run_app
7
+
8
+
9
+ @click.group(
10
+ cls=HelpColorsGroup,
11
+ help_headers_color="yellow",
12
+ help_options_color="green",
13
+ )
14
+ def cli() -> None:
15
+ pass
16
+
17
+
18
+ @cli.command()
19
+ @click.option("--share", is_flag=True, help="Share the app at Gradio.")
20
+ def app(share: bool) -> None:
21
+ """Run the Gradio app."""
22
+ logging.basicConfig(level=logging.WARN, format="%(asctime)s %(levelname)s %(name)s - %(message)s")
23
+ logging.getLogger("article_embedding").setLevel(logging.INFO)
24
+ run_app(share)
article_embedding/constants.py ADDED
@@ -0,0 +1,2 @@
+PERSPECTIVE_TOPIC_ID = "105b9ec2-0cdd-480d-9c55-6f806fa7b9bc"
+US_TOPIC_ID = "ea69254c-a976-4e67-a79c-843202a10e55"
article_embedding/couchdb.py ADDED
@@ -0,0 +1,203 @@
+import contextlib
+import json
+import logging
+import os
+import re
+from abc import ABC, abstractmethod
+from typing import Any, AsyncGenerator
+
+import backoff
+import httpx
+import tqdm
+
+log = logging.getLogger(__name__)
+
+
+class Checkpoint(ABC):
+    @abstractmethod
+    def get(self) -> str | None: ...
+
+    @abstractmethod
+    def set(self, value: str) -> None: ...
+
+    @abstractmethod
+    def reset(self) -> None: ...
+
+
+class NullCheckpoint(Checkpoint):
+    def get(self) -> str | None:
+        return None
+
+    def set(self, value: str) -> None:
+        pass
+
+    def reset(self) -> None:
+        pass
+
+
+_NULL_CHECKPOINT = NullCheckpoint()
+
+
+class FileCheckpoint(Checkpoint):
+    def __init__(self, path: str) -> None:
+        self.path = path
+
+    def get(self) -> str | None:
+        try:
+            with open(self.path) as file:
+                return file.read().strip()
+        except FileNotFoundError:
+            return None
+
+    def set(self, value: str) -> None:
+        with open(self.path, "w") as file:
+            file.write(value)
+
+    def reset(self) -> None:
+        with contextlib.suppress(FileNotFoundError):
+            os.remove(self.path)
+
+
+class CouchDB:
+    def __init__(self) -> None:
+        self.client = self.make_client()
+        self.database = os.environ["COUCHDB_DB"]
+        self.path_view = f"/{self.database}/{os.environ['DOCS_PATH_VIEW']}"
+
+    # Singleton: every CouchDB() call returns the same instance
+    # (note that __init__ still re-runs on each call).
+    def __new__(cls) -> "CouchDB":
+        if not hasattr(cls, "_instance"):
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def make_client(self) -> httpx.AsyncClient:
+        url = os.environ["COUCHDB_URL"]
+        user = os.environ["COUCHDB_USER"]
+        password = os.environ["COUCHDB_PASSWORD"]
+        auth = {"name": user, "password": password}
+
+        # Re-authenticate and retry once whenever a request comes back 401.
+        async def on_backoff(details: Any) -> None:
+            response = await self.client.post("/_session", json=auth)
+            response.raise_for_status()
+
+        client = httpx.AsyncClient(base_url=url)
+        decorator = backoff.on_predicate(
+            backoff.expo,
+            predicate=lambda r: r.status_code == 401,
+            on_backoff=on_backoff,
+            max_tries=2,
+            factor=0,
+        )
+        client.get = decorator(client.get)  # type: ignore[method-assign]
+        return client
+
+    async def changes(self, *, batch_size: int, checkpoint: Checkpoint = _NULL_CHECKPOINT) -> AsyncGenerator[list[Any], None]:
+        since = checkpoint.get() or 0
+        params = {"since": since, "limit": batch_size, "include_docs": True}
+        while True:
+            response = await self.client.get(f"/{self.database}/_changes", params=params)
+            response.raise_for_status()
+            data = response.json()
+            yield [change["doc"] for change in data["results"]]
+            since = data["last_seq"]
+            assert isinstance(since, str)
+            params["since"] = since
+            checkpoint.set(since)
+            if data["pending"] == 0:
+                break
+
+    async def estimate_total_changes(self, *, checkpoint: Checkpoint = _NULL_CHECKPOINT) -> int:
+        since = checkpoint.get() or 0
+        params = {"since": since, "limit": 0}
+        response = await self.client.get(f"/{self.database}/_changes", params=params)
+        response.raise_for_status()
+        data = response.json()
+        return int(data["pending"]) + 1
+
+    async def get_doc_by_id(self, doc_id: str) -> Any:
+        try:
+            response = await self.client.get(f"/{self.database}/{doc_id}")
+            if response.status_code == 404:
+                return None
+            response.raise_for_status()
+            return response.json()
+        except Exception as e:
+            log.error("Error fetching document by ID", exc_info=e)
+            return None
+
+    async def get_doc_by_path(self, path: str) -> Any:
+        try:
+            params = {
+                "limit": "1",
+                "key": json.dumps(path),
+                "include_docs": "true",
+            }
+            response = await self.client.get(self.path_view, params=params)
+            response.raise_for_status()
+            data = response.json()
+            rows = data["rows"]
+            if not rows:
+                return None
+            return rows[0]["doc"]
+        except Exception as e:
+            log.error("Error fetching document by path", exc_info=e)
+            return None
+
+    async def get_doc(self, id_or_path: str) -> Any:
+        uuids = extract_doc_ids(id_or_path)
+        for uuid in uuids:
+            doc = await self.get_doc_by_id(uuid)
+            if doc:
+                return doc
+
+        path = extract_doc_path(id_or_path)
+        if path:
+            return await self.get_doc_by_path(path)
+
+        return None
+
+
+UUID_PATTERN = re.compile(r"[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}")
+
+
+def extract_doc_ids(s: str) -> list[str]:
+    return UUID_PATTERN.findall(s)
+
+
+def extract_doc_path(s: str) -> str | None:
+    if not s.endswith(".html"):
+        return None
+    if s.startswith("/"):
+        return s
+    if "://" in s:
+        s = s.split("://", 1)[1]
+    if "/" in s:
+        return "/" + s.split("/", 1)[1]
+    return None
+
+
+if __name__ == "__main__":
+
+    async def main() -> None:
+        db = CouchDB()
+        checkpoint = FileCheckpoint(".checkpoint")
+        total = await db.estimate_total_changes(checkpoint=checkpoint)
+        with tqdm.tqdm(total=total) as pbar:
+            async for docs in db.changes(batch_size=40, checkpoint=checkpoint):
+                for doc in docs:
+                    kind = doc.get("type")
+                    if kind == "article":
+                        _id = doc["_id"]
+                        language = doc["language"]
+                        path = doc["path"]
+                        path = os.path.basename(path)
+                        pbar.desc = f"{_id}: {kind} {language} {path}"
+                    else:
+                        pbar.desc = f"{kind}"
+                    pbar.update(1)
+
+    import asyncio
+
+    from dotenv import load_dotenv
+
+    load_dotenv()
+    asyncio.run(main())
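
A minimal sketch of consuming the changes feed with a resumable checkpoint (assumes `COUCHDB_URL`/`COUCHDB_USER`/`COUCHDB_PASSWORD`/`COUCHDB_DB` and `DOCS_PATH_VIEW` are set, e.g. via `.env`):

```python
import asyncio

from dotenv import load_dotenv

from article_embedding.couchdb import CouchDB, FileCheckpoint

async def count_articles() -> None:
    load_dotenv()
    db = CouchDB()
    total = 0
    # Re-running resumes from the last sequence saved in the checkpoint file.
    async for batch in db.changes(batch_size=100, checkpoint=FileCheckpoint(".demo.checkpoint")):
        total += sum(1 for d in batch if d.get("type") == "article")
    print(total)

asyncio.run(count_articles())
```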
article_embedding/embed.py ADDED
@@ -0,0 +1,86 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+from fastembed import TextEmbedding  # type: ignore
+from qdrant_client import QdrantClient
+from qdrant_client.models import Distance, VectorParams
+from sentence_transformers import SentenceTransformer
+
+
+class EmbeddingModel(ABC):
+    @abstractmethod
+    def embed(self, documents: list[str]) -> Any: ...
+    @abstractmethod
+    def get_vectors_config(self) -> VectorParams: ...
+
+
+class SentenceTransformerModel(EmbeddingModel):
+    model: SentenceTransformer
+
+    def embed(self, documents: list[str]) -> Any:
+        return self.model.encode(documents, normalize_embeddings=True)
+
+    def get_vectors_config(self) -> VectorParams:
+        dimensions = self.model.get_sentence_embedding_dimension()
+        assert dimensions is not None
+        return VectorParams(
+            size=dimensions,
+            distance=Distance.COSINE,
+        )
+
+
+class StellaEmbedder(SentenceTransformerModel):
+    def __init__(self) -> None:
+        self.model = SentenceTransformer(
+            "dunzhang/stella_en_400M_v5",
+            trust_remote_code=True,
+            config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False},
+        ).to("mps")  # Apple Silicon GPU; adjust for CUDA/CPU hosts
+
+
+class NvEmbedder(SentenceTransformerModel):
+    def __init__(self) -> None:
+        self.model = SentenceTransformer(
+            "nvidia/NV-Embed-V2",
+            trust_remote_code=True,
+        ).to("mps")
+        self.model.max_seq_length = 32768
+        self.model.tokenizer.padding_side = "right"
+
+    def embed(self, documents: list[str]) -> Any:
+        processed = [d + self.model.tokenizer.eos_token for d in documents]
+        return self.model.encode(processed, normalize_embeddings=True)
+
+
+class JasperEmbedder(SentenceTransformerModel):
+    def __init__(self, use_sdpa: bool = True) -> None:
+        model_kwargs = {"attn_implementation": "sdpa"} if use_sdpa else {}
+        self.model = SentenceTransformer(
+            "infgrad/jasper_en_vision_language_v1",
+            trust_remote_code=True,
+            model_kwargs=model_kwargs,
+            config_kwargs={
+                "is_text_encoder": True,
+                "vector_dim": 1024,
+            },
+        ).to("mps")
+        self.model.max_seq_length = 1024
+
+    def embed(self, documents: list[str]) -> Any:
+        processed = [d + self.model.tokenizer.eos_token for d in documents]
+        return self.model.encode(processed, normalize_embeddings=True)
+
+
+class FastEmbedModel(EmbeddingModel):
+    def __init__(self, model_name: str = "BAAI/bge-small-en-v1.5") -> None:
+        self.model = TextEmbedding(model_name=model_name)
+
+    def embed(self, documents: list[str]) -> Any:
+        return self.model.embed(documents)
+
+    def get_vectors_config(self) -> VectorParams:
+        size, distance = QdrantClient._get_model_params(model_name=self.model.model_name)
+        return VectorParams(
+            size=size,
+            distance=distance,
+        )
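
A minimal usage sketch (the model weights are downloaded on first use; the hard-coded `.to("mps")` above assumes Apple Silicon):

```python
from article_embedding.embed import StellaEmbedder

embedder = StellaEmbedder()
vectors = embedder.embed(["global elimination of the COVID-19 pandemic"])
print(vectors.shape)  # (1, 1024) for stella_en_400M_v5
```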
article_embedding/loader.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import os
3
+ import uuid
4
+ from dataclasses import dataclass
5
+ from typing import Any, Generator
6
+
7
+ import tqdm
8
+ from qdrant_client import QdrantClient
9
+ from qdrant_client.http.exceptions import UnexpectedResponse
10
+ from qdrant_client.models import Batch
11
+
12
+ from article_embedding.chunk import Chunker, LangchainChunker
13
+ from article_embedding.couchdb import Checkpoint, CouchDB, FileCheckpoint
14
+ from article_embedding.embed import EmbeddingModel, StellaEmbedder
15
+
16
+
17
+ class Loader:
18
+ def __init__(
19
+ self,
20
+ *,
21
+ collection_name: str,
22
+ embedding_model: EmbeddingModel,
23
+ checkpoint: Checkpoint,
24
+ chunker: Chunker | None = None,
25
+ ) -> None:
26
+ self.collection_name = collection_name
27
+ self.embedding_model = embedding_model
28
+ self.checkpoint = checkpoint
29
+ self.chunker = chunker
30
+ self.qdrant = QdrantClient(os.getenv("QDRANT_URL"))
31
+ self.couchdb = CouchDB()
32
+
33
+ @property
34
+ def starting_from_scratch(self) -> bool:
35
+ return self.checkpoint.get() is None
36
+
37
+ def recreate_collection(self) -> None:
38
+ with contextlib.suppress(UnexpectedResponse):
39
+ self.qdrant.delete_collection(self.collection_name)
40
+ self.qdrant.create_collection(
41
+ collection_name=self.collection_name,
42
+ vectors_config=self.embedding_model.get_vectors_config(),
43
+ )
44
+
45
+ async def load(self, batch_size: int = 64) -> None:
46
+ if self.starting_from_scratch:
47
+ self.recreate_collection()
48
+ total = await self.couchdb.estimate_total_changes(checkpoint=self.checkpoint)
49
+ with tqdm.tqdm(total=total) as pbar:
50
+ async for docs in self.couchdb.changes(batch_size=batch_size, checkpoint=self.checkpoint):
51
+ points = self.embed(docs)
52
+ if points:
53
+ self.qdrant.upsert(
54
+ collection_name=self.collection_name,
55
+ points=points,
56
+ wait=False,
57
+ )
58
+ pbar.update(len(docs))
59
+
60
+ def embed(self, docs: list[Any]) -> Batch | None:
61
+ def article_filter(doc: Any) -> bool:
62
+ return doc.get("type") == "article" and doc.get("language") == "en" and doc.get("content") is not None
63
+
64
+ def to_payload(doc: Any) -> Any:
65
+ doc.pop("content")
66
+ return doc
67
+
68
+ def chunk(doc: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
69
+ assert self.chunker is not None
70
+ content = doc["content"]
71
+ for ix, chunk in enumerate(self.chunker.chunk(content)):
72
+ copy = doc.copy()
73
+ copy["docId"] = doc["_id"]
74
+ copy["_id"] = str(uuid.uuid3(_NS_DOCUMENT, chunk))
75
+ copy["content"] = chunk
76
+ copy["chunkIndex"] = ix
77
+ copy["chunk"] = chunk
78
+ yield copy
79
+
80
+ filtered = [d for d in docs if article_filter(d)]
81
+ chunked = [c for d in filtered for c in chunk(d)] if self.chunker else filtered
82
+ if not chunked:
83
+ return None
84
+ ids = [d["_id"] for d in chunked]
85
+ docs = [d["content"] for d in chunked]
86
+ payloads = [to_payload(d) for d in chunked]
87
+ return Batch(
88
+ ids=ids,
89
+ vectors=self.embedding_model.embed(docs),
90
+ payloads=payloads,
91
+ )
92
+
93
+
94
+ _NS_DOCUMENT = uuid.UUID("a07b4024-c09b-11ef-9e49-36d763a413b8")
95
+
96
+
97
+ def uuid_for_content(*content: Any) -> str:
98
+ import hashlib
99
+ import uuid
100
+
101
+ hash = hashlib.md5("".join(map(str, content)).encode()).hexdigest()
102
+ return str(uuid.UUID(hash[:32]))
103
+
104
+
105
+ if __name__ == "__main__":
106
+ stella = StellaEmbedder()
107
+
108
+ @dataclass
109
+ class Store:
110
+ collection_name: str
111
+ embedding_model: EmbeddingModel = stella
112
+ chunker: Chunker | None = None
113
+ batch_size: int = 64
114
+
115
+ @property
116
+ def checkpoint(self) -> str:
117
+ return f".{self.collection_name}.checkpoint"
118
+
119
+ class StoreManager:
120
+ def __init__(self) -> None:
121
+ self.stores: dict[str, Store] = {}
122
+
123
+ def add(self, store: Store) -> None:
124
+ self.stores[store.collection_name] = store
125
+
126
+ async def load(self, name: str) -> None:
127
+ store = self.stores[name]
128
+ checkpoint = FileCheckpoint(store.checkpoint)
129
+ loader = Loader(
130
+ collection_name=store.collection_name,
131
+ embedding_model=store.embedding_model,
132
+ checkpoint=checkpoint,
133
+ chunker=store.chunker,
134
+ )
135
+ await loader.load(batch_size=store.batch_size)
136
+
137
+ async def load_all(self) -> None:
138
+ for name in self.stores:
139
+ await self.load(name)
140
+
141
+ import asyncio
142
+
143
+ store_manager = StoreManager()
144
+ store_manager.add(Store(collection_name="wsws"))
145
+ store_manager.add(Store(collection_name="wsws-2", chunker=LangchainChunker(), batch_size=8))
146
+ asyncio.run(store_manager.load_all())
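
The chunk IDs are worth noting: `uuid.uuid3` over the chunk text makes point IDs deterministic, so re-running the loader upserts the same points rather than duplicating them. A small illustration:

```python
import uuid

_NS_DOCUMENT = uuid.UUID("a07b4024-c09b-11ef-9e49-36d763a413b8")

chunk_text = "Some chunk of an article."
print(uuid.uuid3(_NS_DOCUMENT, chunk_text))  # same text always yields the same point ID
```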
article_embedding/modal_app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import json
+ import logging
+ import time
+ from pathlib import Path
+ from typing import Any, Callable
+
+ import modal
+ import numpy as np
+ import numpy.typing as npt
+ import torch
+ from sentence_transformers import SentenceTransformer
+
+ from article_embedding.embed import SentenceTransformerModel, StellaEmbedder
+
+ log = logging.getLogger(__name__)
+
+
+ def load_model() -> SentenceTransformer:
+     return SentenceTransformer(
+         "dunzhang/stella_en_400M_v5",
+         trust_remote_code=True,
+         config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False},
+     )
+
+
+ image = modal.Image.debian_slim(python_version="3.12").pip_install(["sentence-transformers"]).run_function(load_model)
+
+ app = modal.App("embedding", image=image)
+
+
+ @app.cls(gpu="A10G")
+ class ModalEmbedder:
+     @modal.enter()
+     def setup(self) -> None:
+         logging.basicConfig(level=logging.WARN)
+         log.setLevel(logging.DEBUG)
+         if torch.cuda.is_available():
+             device = "cuda"
+         elif torch.backends.mps.is_available():
+             device = "mps"
+         else:
+             device = "cpu"
+         self.model = load_model().to(device)
+         log.info("Model loaded on %s", device)
+
+     @modal.method()
+     def embed(self, documents: list[str]) -> Any:
+         return self.model.encode(documents)
+
+
+ async def fetch_documents() -> list[str]:
+     from article_embedding.couchdb import CouchDB
+
+     # Only the first batch of changes is used for benchmarking
+     async for batch in CouchDB().changes(batch_size=256):
+         sorted_batch = sorted(batch, key=lambda x: x.get("_id"))
+         return [a["content"] for a in sorted_batch if a.get("type") == "article" and a.get("language") == "en"]
+     return []
+
+
+ def process(func: Callable[[list[str]], Any], documents: list[str], name: str) -> None:
+     func(["Hello, world!"])  # Warmup
+     ts0 = time.time()
+     embeddings = func(documents)
+     benchmark = time.time() - ts0
+     output_path = Path("data/embeddings.json")
+     golden_path = Path(f"data/embeddings.{name}-golden.json")
+     save_embeddings(embeddings, output_path)
+     mean_cosine_similarity, rms = compare_embeddings(embeddings, golden_path)
+     log.info(
+         "%s - MCS: %.2f. RMS: %.2f. Latency: %.2f ms. Size: %d",
+         name,
+         mean_cosine_similarity,
+         rms,
+         benchmark / len(embeddings) * 1000,
+         len(embeddings[0]),
+     )
+
+
+ def process2(model: SentenceTransformerModel, documents: list[str], name: str) -> None:
+     model.embed(["Hello, world!"])  # Warmup
+     ts0 = time.time()
+     embeddings = model.embed(documents)
+     benchmark = time.time() - ts0
+     output_path = Path("data/embeddings.json")
+     save_embeddings(embeddings, output_path)
+     golden_path = Path(f"data/embeddings.{name}-golden.json")
+     if golden_path.exists():
+         golden_embeddings: Any = load_embeddings(golden_path)
+         similarities = model.model.similarity_pairwise(embeddings, golden_embeddings)
+         rms = torch.sqrt(torch.mean(similarities**2))
+     else:
+         save_embeddings(embeddings, golden_path)
+         rms = torch.tensor(0.0)  # No golden file yet; report a zero RMS
+     log.info(
+         "%s - RMS: %.2f. Latency: %.2f ms. Size: %d",
+         name,
+         rms,
+         benchmark / len(embeddings) * 1000,
+         len(embeddings[0]),
+     )
+
+
+ def load_embeddings(path: Path) -> list[npt.NDArray[np.float64]]:
+     with path.open() as f:
+         return [np.array(json.loads(line)) for line in f]
+
+
+ def save_embeddings(embeddings: list[npt.NDArray[np.float64]], path: Path) -> None:
+     with path.open("w") as f:
+         for e in embeddings:
+             f.write(json.dumps(e.tolist()) + "\n")
+
+
+ def compare_embeddings(embeddings: list[npt.NDArray[np.float64]], golden_path: Path) -> tuple[float, float]:
+     if not golden_path.exists():
+         save_embeddings(embeddings, golden_path)
+         return 0.0, 0.0
+     with golden_path.open() as f:
+         golden_embeddings = [np.array(json.loads(line)) for line in f]
+     np_embeddings = np.array(embeddings)
+     np_golden_embeddings = np.array(golden_embeddings)
+     rms = np.sqrt(np.mean((np_embeddings - np_golden_embeddings) ** 2))
+     dot_products = np.einsum("ij,ij->i", np_embeddings, np_golden_embeddings)
+     norms = np.linalg.norm(np_embeddings, axis=1) * np.linalg.norm(np_golden_embeddings, axis=1)
+     cosine_similarities = dot_products / norms
+     return float(np.mean(cosine_similarities)), float(rms)
+
+
+ @app.local_entrypoint()
+ async def modal_amain() -> None:
+     logging.basicConfig(level=logging.WARN)
+     log.setLevel(logging.DEBUG)
+
+     embedder = ModalEmbedder()
+     documents = await fetch_documents()
+     process(embedder.embed.remote, documents, "modal")
+
+
+ async def amain() -> None:
+     model = StellaEmbedder()
+     # model = NvEmbedder()
+     # model = JasperEmbedder()
+     # model.model.half()
+     documents = await fetch_documents()
+     process2(model, documents, "stella")
+
+
+ if __name__ == "__main__":
+     import asyncio
+
+     logging.basicConfig(level=logging.WARN)
+     log.setLevel(logging.DEBUG)
+     asyncio.run(amain())
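
As an aside, the golden-file check in compare_embeddings boils down to a mean pairwise cosine similarity plus an element-wise RMS. A minimal, self-contained sketch of that arithmetic on toy vectors (numpy only; the model and file paths are not involved):

import numpy as np

new = np.array([[1.0, 0.0], [0.0, 1.0]])     # freshly computed embeddings
golden = np.array([[1.0, 0.0], [1.0, 0.0]])  # previously saved golden embeddings

# Row-wise cosine similarity, as in compare_embeddings
dots = np.einsum("ij,ij->i", new, golden)
norms = np.linalg.norm(new, axis=1) * np.linalg.norm(golden, axis=1)
print((dots / norms).mean())  # 0.5: first pair identical (1.0), second orthogonal (0.0)

# Element-wise RMS of the difference
print(np.sqrt(np.mean((new - golden) ** 2)))  # ~0.707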
article_embedding/query.py ADDED
@@ -0,0 +1,117 @@
+ import os
+ import warnings
+ from datetime import datetime
+ from typing import Any, Optional
+
+ from qdrant_client import QdrantClient
+ from qdrant_client.http.models import QueryResponse
+ from qdrant_client.models import FieldCondition, Filter, MatchValue, Range
+
+ from article_embedding.embed import StellaEmbedder
+
+ warnings.simplefilter(action="ignore", category=FutureWarning)
+
+
+ def as_timestamp(date: datetime | str) -> float:
+     if isinstance(date, datetime):
+         return date.timestamp()
+     return datetime.strptime(date, "%Y-%m-%d").timestamp()
+
+
+ def make_date_condition(
+     *, field: str = "published", date_from: datetime | str | None = None, date_to: datetime | str | None = None
+ ) -> FieldCondition | None:
+     kwargs = {}
+     if date_from:
+         kwargs["gte"] = as_timestamp(date_from)
+     if date_to:
+         kwargs["lt"] = as_timestamp(date_to)
+     if kwargs:
+         return FieldCondition(key=field, range=Range(**kwargs))
+     return None
+
+
+ def make_topic_condition(topic_id: str) -> FieldCondition:
+     return FieldCondition(key="topics[]", match=MatchValue(value=topic_id))
+
+
+ class Query:
+     _instance: Optional["Query"] = None
+     _embedding_model_instance: Optional[StellaEmbedder] = None
+
+     def __init__(self, index: str = "wsws") -> None:
+         self.embedding_model = Query.embedding_model_singleton()
+         self.qdrant = QdrantClient(os.environ["QDRANT_URL"])
+         self.index = index
+
+     @staticmethod
+     def embedding_model_singleton() -> StellaEmbedder:
+         if Query._embedding_model_instance is None:
+             Query._embedding_model_instance = StellaEmbedder()
+         return Query._embedding_model_instance
+
+     @staticmethod
+     def singleton() -> "Query":
+         if Query._instance is None:
+             Query._instance = Query()
+         return Query._instance
+
+     def embed(self, query: str) -> Any:
+         return self.embedding_model.embed([query])[0]
+
+     def query(
+         self,
+         query: str,
+         query_filter: Filter | None = None,
+         limit: int = 10,
+     ) -> QueryResponse:
+         vector = self.embed(query)
+         return self.qdrant.query_points(self.index, query=vector, query_filter=query_filter, limit=limit)
+
+
+ if __name__ == "__main__":
+     import gspread
+     from dotenv import load_dotenv
+     from gspread.utils import ValueInputOption
+
+     data = [
+         ("2021-01-01", "2021-05-01", "The COVID winter wave, the emergence of the Delta variant and the January 6th coup"),
+         (
+             "2021-05-01",
+             "2021-09-01",
+             "The COVID vaccine rollout, Biden declaring independence from COVID while the Delta wave continues",
+         ),
+         (
+             "2021-09-01",
+             "2022-01-01",
+             "The emergence of the COVID Omicron variant and the embrace of herd immunity by the ruling class",
+         ),
+     ]
+
+     load_dotenv()
+     query = Query()
+     rows: list[list[str]] = []
+     for date_from, date_to, sentence in data:
+         result = query.query(
+             sentence,
+             query_filter=Filter(should=make_date_condition(date_from=date_from, date_to=date_to)),
+         )
+         rows.append([sentence])
+         for point in result.points:
+             doc = point.payload
+             assert doc is not None
+             print(f'{point.score * 100:.1f}% https://www.wsws.org{doc["path"]} - {doc["title"]}')
+             rows.append(
+                 [
+                     f"{point.score * 100:.1f}%",
+                     datetime.fromtimestamp(doc["published"]).strftime("%Y/%m/%d"),
+                     ", ".join(doc["authors"]),
+                     f'=hyperlink("https://www.wsws.org{doc["path"]}", "{doc["title"]}")',
+                 ]
+             )
+         rows.append([])
+
+     gc = gspread.auth.oauth(credentials_filename=os.environ["GOOGLE_CREDENTIALS"])
+     sh = gc.open("COVID-19 Compilation")
+     ws = sh.get_worksheet(0)
+     ws.append_rows(rows, value_input_option=ValueInputOption.user_entered)
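
For reference, a minimal sketch of combining the filter helpers above in a single query. It assumes QDRANT_URL is set (loaded here via dotenv, as in the script above) and uses a hypothetical topic id:

from dotenv import load_dotenv
from qdrant_client.models import Filter

from article_embedding.query import Query, make_date_condition, make_topic_condition

load_dotenv()  # expects QDRANT_URL in the environment
q = Query(index="wsws")
conditions = [
    make_date_condition(date_from="2022-01-01", date_to="2022-06-01"),
    make_topic_condition("topic-id-placeholder"),  # hypothetical topic id
]
result = q.query("the Omicron wave and school reopenings", query_filter=Filter(must=conditions), limit=5)
for point in result.points:
    assert point.payload is not None
    print(f"{point.score:.3f} {point.payload['title']}")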
article_embedding/replicate.py ADDED
@@ -0,0 +1,21 @@
+ from qdrant_client import QdrantClient
+
+ from article_embedding.utils import env_str
+
+ remote_url = env_str("CLUSTER_1_ENDPOINT")
+ remote_key = env_str("CLUSTER_1_API_KEY")
+ collection_name = "wsws-2"
+ local_url = env_str("QDRANT_URL")
+
+ # model = StellaEmbedder()
+
+ cloud_client = QdrantClient(url=remote_url, api_key=remote_key)
+
+ # cloud_client.recreate_collection(
+ #     collection_name=collection_name,
+ #     vectors_config=model.get_vectors_config(),
+ # )
+
+ local_client = QdrantClient(local_url)
+
+ # Copy the collection from the local Qdrant instance into the cloud cluster
+ local_client.migrate(cloud_client, collection_names=[collection_name])
article_embedding/retrieval.py ADDED
@@ -0,0 +1,330 @@
+ import logging
+ import math
+ from typing import Generator, Literal
+
+ from pydantic import BaseModel
+ from qdrant_client.models import Filter, ScoredPoint
+
+ from article_embedding.app.model import Document
+ from article_embedding.constants import PERSPECTIVE_TOPIC_ID, US_TOPIC_ID
+ from article_embedding.query import Query, make_date_condition
+
+ log = logging.getLogger(__name__)
+
+
+ class Candidate(BaseModel):
+     doc_id: str
+     score: float
+     query_id: int
+     meta: Document
+
+     @classmethod
+     def from_point(cls, query_id: int, point: ScoredPoint) -> "Candidate":
+         assert isinstance(point.id, str)
+         return cls(doc_id=point.id, score=point.score, query_id=query_id, meta=Document.model_validate(point.payload))
+
+
+ CandidatePool = dict[str, Candidate]
+ Candidates = list[Candidate]
+ Fractions = tuple[float, float]
+ AuthorCounts = dict[str, int]
+
+
+ def qdrant_search(query_id: int, query: str, start_date: str, end_date: str, top_k: int) -> list[Candidate]:
+     # Returns a list of candidates (doc_id, relevance score, metadata);
+     # the metadata includes date, authors, topics, etc.
+     query_filter = Filter(must=make_date_condition(date_from=start_date, date_to=end_date))
+     result = Query.singleton().query(query, query_filter=query_filter, limit=top_k)
+     return [Candidate.from_point(query_id, point) for point in result.points]
+
+
+ def compute_count(docs: Candidates, *topic_ids: str) -> int:
+     return sum(1 if any(d.meta.has_topic(t) for t in topic_ids) else 0 for d in docs)
+
+
+ def compute_fractions(docs: Candidates) -> Fractions:
+     oped_count = sum(d.meta.has_topic(PERSPECTIVE_TOPIC_ID) for d in docs)
+     us_count = sum(d.meta.has_topic(US_TOPIC_ID) for d in docs)
+     total = len(docs)
+     return (oped_count / total, us_count / total)
+
+
+ def compute_author_counts(docs: Candidates) -> AuthorCounts:
+     counts: AuthorCounts = {}
+     for d in docs:
+         for author in d.meta.authors or []:
+             counts[author] = counts.get(author, 0) + 1
+     return counts
+
+
+ Category = tuple[bool, bool]
+ SwapIndices = tuple[int, int]
+ SwapCandidates = tuple[Candidate, Candidate]
+ Sign = Literal[-1, 0, 1]
+ Evaluation = Literal["rejected", "accepted", "desired"]
+
+
+ def sign(n: int) -> Sign:
+     return -1 if n < 0 else 0 if n == 0 else 1
+
+
+ class Categories(BaseModel):
+     accepted: list[Sign] = []
+     desired: list[Sign] = []
+
+     def evaluate_delta(self, delta: int) -> Evaluation:
+         if sign(delta) in self.accepted:
+             return "accepted"
+         if sign(delta) in self.desired:
+             return "desired"
+         return "rejected"
+
+
+ def seek_constraints(
+     selected: Candidates,
+     available: Candidates,
+     oped_target: int,
+     us_target: int,
+     max_per_author: int,
+     tolerance: int,
+ ) -> bool:
+     """Swap candidates between `selected` and `available` to approach the target counts.
+
+     The goal is to meet the target counts for the OpEd and US topics while respecting the
+     max-per-author constraint and favoring a more balanced author count distribution
+     (author counts closer to their mean).
+
+     It is a precondition that the author counts in `selected` do not already exceed the
+     limit. The max-per-author constraint is hard: it may never be exceeded. The count
+     constraints are soft: they allow some tolerance, and the function finishes as soon as
+     the counts meet the targets within that tolerance.
+
+     This is a heuristic, so it may not find the optimal solution. It swaps documents from
+     the tail of `selected` (least relevant) with replacements from the head of `available`
+     (most relevant). The return value indicates whether the constraints were met.
+     """
+
+     # Bands classify how far a count is from its target: ±2 is outside the tolerance
+     # (a hard violation), ±1 is exactly at the tolerance edge (soft), and 0 is on target.
+     def compute_band(current: int, target: int, tolerance: int) -> int:
+         if current < target - tolerance:
+             return -2
+         if current == target - tolerance:
+             return -1
+         if current < target + tolerance:
+             return 0
+         if current == target + tolerance:
+             return +1
+         return +2
+
+     # Determine current counts and bands
+     oped_current = compute_count(selected, PERSPECTIVE_TOPIC_ID)
+     us_current = compute_count(selected, US_TOPIC_ID)
+     author_counts = compute_author_counts(selected)
+     author_mean_numerator = sum(author_counts.values())
+     author_mean_denominator = len(author_counts) or 1
+
+     # It is a precondition that the author counts in `selected` do not exceed the limit
+     if sum(1 for count in author_counts.values() if count > max_per_author):
+         raise ValueError("Precondition violated: author counts exceed the limit")
+
+     def iterate_swap_indices() -> Generator[SwapIndices, None, None]:
+         for dropped_index in range(len(selected) - 1, -1, -1):
+             for picked_index in range(len(available)):
+                 yield (picked_index, dropped_index)
+
+     def compute_topic_delta(swap_candidates: SwapCandidates, topic_id: str) -> int:
+         # +1 if the picked candidate has the topic, -1 if the dropped candidate has it
+         return sum(d if c.meta.has_topic(topic_id) else 0 for d, c in zip([1, -1], swap_candidates))
+
+     def evaluate_author_change(author: str, delta: int, mean: float) -> int:
+         count = author_counts.get(author, 0)
+         old_distance = abs(count - mean)
+         new_distance = abs(count + delta - mean)
+         return int(math.copysign(1, old_distance - new_distance))  # ties count in favor of the swap
+
+     def evaluate_author_exchange(swap_candidates: SwapCandidates, mean: float) -> Evaluation:
+         points = 0
+         # An exchange gets points for moving author counts closer to the mean
+         picked, dropped = swap_candidates
+         for author_picked in picked.meta.authors or []:
+             if author_counts.get(author_picked, 0) >= max_per_author:
+                 return "rejected"
+             points += evaluate_author_change(author_picked, +1, mean)
+         for author_dropped in dropped.meta.authors or []:
+             points += evaluate_author_change(author_dropped, -1, mean)
+         return "desired" if points > 0 else "accepted"
+
+     def update_author_counts(swap_candidates: SwapCandidates) -> int:
+         author_mean_numerator_delta = 0
+
+         picked, dropped = swap_candidates
+         for author in picked.meta.authors or []:
+             author_mean_numerator_delta += 1
+             author_counts[author] = author_counts.get(author, 0) + 1
+         for author in dropped.meta.authors or []:
+             author_mean_numerator_delta -= 1
+             new_count = author_counts.pop(author) - 1
+             if new_count:
+                 author_counts[author] = new_count
+
+         return author_mean_numerator_delta
+
+     # Bands and the swaps they allow:
+     # -2 - pick +, drop -
+     # -1 - pick +, drop -; pick +, drop +; pick -, drop -
+     #  0 - all
+     # +1 - pick -, drop +; pick +, drop +; pick -, drop -
+     # +2 - pick -, drop +
+     CATEGORIES = [
+         Categories(desired=[1]),
+         Categories(desired=[1], accepted=[0]),
+         Categories(accepted=[1, 0, -1]),
+         Categories(desired=[-1], accepted=[0]),
+         Categories(desired=[-1]),
+     ]
+     # We do multiple passes; each successive pass tries to improve fewer constraints
+     swap_attempt = 0
+     swap_count = 0
+     for desired_target in range(3, 0, -1):
+         current_pass = 4 - desired_target
+         # Iterate all possible swaps
+         for picked_index, dropped_index in iterate_swap_indices():
+             swap_attempt += 1
+             swap_candidates = (available[picked_index], selected[dropped_index])
+             oped_band = compute_band(oped_current, oped_target, tolerance)
+             us_band = compute_band(us_current, us_target, tolerance)
+             author_mean_denominator = len(author_counts) or 1
+             author_mean = author_mean_numerator / author_mean_denominator
+             # Computed unconditionally: the log calls below evaluate their arguments eagerly
+             author_variance = sum((count - author_mean) ** 2 for count in author_counts.values()) / author_mean_denominator
+
+             if log.isEnabledFor(logging.DEBUG):
+                 # fmt: off
+                 log.debug("%d/%d - %d Eval %d <-> %d: oped=%d (%d %d), us=%d (%d %d), author µ=%.2f σ=%.2f",
+                     swap_count, swap_attempt, current_pass, picked_index, dropped_index,
+                     oped_current, oped_target, oped_band, us_current, us_target, us_band, author_mean, author_variance,
+                 )  # fmt: on
+
+             # Check if constraints are met
+             if abs(oped_band) <= 1 and abs(us_band) <= 1:
+                 # fmt: off
+                 log.info("%d/%d - %d Constraints met: oped=%d (%d %d), us=%d (%d %d), author µ=%.2f σ=%.2f",
+                     swap_count, swap_attempt, current_pass,
+                     oped_current, oped_target, oped_band, us_current, us_target, us_band, author_mean, author_variance,
+                 )  # fmt: on
+                 return True
+
+             oped_delta = compute_topic_delta(swap_candidates, PERSPECTIVE_TOPIC_ID)
+             category = CATEGORIES[oped_band + 2]
+             oped_evaluation = category.evaluate_delta(oped_delta)
+             if oped_evaluation == "rejected":
+                 log.debug("%d/%d - %d Swap rejected by OpEd constraints", swap_count, swap_attempt, current_pass)
+                 continue
+
+             us_delta = compute_topic_delta(swap_candidates, US_TOPIC_ID)
+             category = CATEGORIES[us_band + 2]
+             us_evaluation = category.evaluate_delta(us_delta)
+             if us_evaluation == "rejected":
+                 log.debug("%d/%d - %d Swap rejected by US constraints", swap_count, swap_attempt, current_pass)
+                 continue
+
+             author_evaluation = evaluate_author_exchange(swap_candidates, author_mean)
+             if author_evaluation == "rejected":
+                 log.debug("%d/%d - %d Swap rejected by author constraints", swap_count, swap_attempt, current_pass)
+                 continue
+
+             desired_count = sum(1 for f in [oped_evaluation, us_evaluation, author_evaluation] if f == "desired")
+             if desired_count < desired_target:
+                 log.debug("%d/%d - %d Swap rejected by desired count", swap_count, swap_attempt, current_pass)
+                 continue
+
+             swap_count += 1
+             author_mean_numerator_delta = update_author_counts(swap_candidates)
+             oped_current += oped_delta
+             us_current += us_delta
+             author_mean_numerator += author_mean_numerator_delta
+             picked, dropped = swap_candidates
+             selected.pop(dropped_index)
+             available.pop(picked_index)
+             selected.append(picked)
+             available.append(dropped)
+
+             if log.isEnabledFor(logging.INFO):
+                 author_variance = sum((count - author_mean) ** 2 for count in author_counts.values()) / author_mean_denominator
+                 # fmt: off
+                 log.info("%d/%d - %d Swapped %d <-> %d: oped=%d (%d %d), us=%d (%d %d), author µ=%.2f σ=%.2f",
+                     swap_count, swap_attempt, current_pass, picked_index, dropped_index,
+                     oped_current, oped_target, oped_band, us_current, us_target, us_band, author_mean, author_variance,
+                 )  # fmt: on
+
+     log.warning("No more improvements possible")
+     return False
+
+
+ def retrieve_documents(
+     *,
+     queries: list[str],
+     start_date: str,
+     end_date: str,
+     number_of_articles: int,
+     op_ed_ratio: float,
+     us_ratio: float,
+     max_per_author: int,
+ ) -> Candidates:
+     K = 5 * number_of_articles  # retrieve 5x the needed documents from each query
+
+     # 1. Retrieve candidates for each query
+     candidate_pool: CandidatePool = {}
+     for query_id, q in enumerate(queries):
+         results = qdrant_search(query_id, q, start_date, end_date, K)
+         for candidate in results:
+             # If a doc appears for multiple queries, keep the higher score
+             if candidate.doc_id not in candidate_pool or candidate.score > candidate_pool[candidate.doc_id].score:
+                 candidate_pool[candidate.doc_id] = candidate
+
+     # 2. Sort candidates by relevance
+     sorted_candidates = sorted(candidate_pool.values(), key=lambda c: c.score, reverse=True)
+
+     # 3. Initial selection: take the top documents by relevance, respecting author limits
+     available: Candidates = []
+     skipped: Candidates = []
+     author_counts: dict[str, int] = {}
+     for c in sorted_candidates:
+         authors = c.meta.authors or []
+         # Check the limit before counting, so skipped documents leave the counts untouched
+         if any(author_counts.get(author, 0) >= max_per_author for author in authors):
+             skipped.append(c)
+             continue
+         for author in authors:
+             author_counts[author] = author_counts.get(author, 0) + 1
+         available.append(c)
+     if len(available) < number_of_articles:
+         raise ValueError("Not enough documents to meet the max_per_author constraint")
+     available += skipped
+
+     # 4. Attempt adjustments to meet the constraints approximately
+     selected = available[:number_of_articles]
+     available = available[number_of_articles:]
+     seek_constraints(
+         selected=selected,
+         available=available,
+         oped_target=int(op_ed_ratio * number_of_articles),
+         us_target=int(us_ratio * number_of_articles),
+         max_per_author=max_per_author,
+         tolerance=int(0.05 * number_of_articles),
+     )
+
+     # `selected` now contains the chosen documents: within the date range, maximizing
+     # relevance, approximately meeting the OpEd/Non-OpEd and US/International fractions,
+     # and respecting the max_per_author constraint.
+     return selected
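
For orientation, a minimal usage sketch of retrieve_documents. The query strings and targets are illustrative, and it assumes a populated wsws collection, QDRANT_URL in the environment, and that the Document model exposes a title field:

from dotenv import load_dotenv

from article_embedding.retrieval import retrieve_documents

load_dotenv()
docs = retrieve_documents(
    queries=["the COVID winter wave", "mass layoffs in the auto industry"],  # illustrative queries
    start_date="2021-01-01",
    end_date="2021-05-01",
    number_of_articles=40,
    op_ed_ratio=0.25,  # aim for ~25% Perspective/OpEd articles
    us_ratio=0.5,      # aim for ~50% US-topic articles
    max_per_author=3,
)
for d in docs:
    print(f"{d.score:.3f} {d.meta.title}")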
article_embedding/sheets.py ADDED
@@ -0,0 +1,7 @@
+ import os
+
+ import gspread
+
+ gc = gspread.auth.oauth(credentials_filename=os.environ["GOOGLE_CREDENTIALS"])
+ sh = gc.open("COVID-19 Compilation")
+ print(sh.sheet1.get("A1"))
article_embedding/utils.py ADDED
@@ -0,0 +1,76 @@
+ """Assorted utils."""
+
+ import json
+ import os
+ from types import EllipsisType
+ from typing import Any
+
+ import yaml
+
+ _FALSEY_STRINGS = ["", "0", "false", "no", "off", "n"]
+
+
+ def is_truthy(value: Any) -> bool:
+     """Check if a value is truthy."""
+     if isinstance(value, str):
+         return value.lower() not in _FALSEY_STRINGS
+     return bool(value)
+
+
+ def env_data(env_name: str, default: Any = ...) -> Any:
+     """Load YAML/JSON data from a file named by an environment variable."""
+     value = os.environ.get(env_name)
+     if value is None:
+         if default is ...:
+             raise ValueError(f"Missing environment variable {env_name}")
+         return default
+     extension = os.path.splitext(value)[1]
+     if extension not in [".yaml", ".yml", ".json"]:
+         raise ValueError(f"Environment variable {env_name} must be a YAML/JSON file")
+     try:
+         with open(value, encoding="utf-8") as fd:
+             if extension == ".json":
+                 return json.load(fd)
+             return yaml.load(fd, Loader=yaml.CLoader)
+     except Exception as e:
+         raise ValueError(f"Environment variable {env_name}") from e
+
+
+ def env_str(env_name: str, default: str | EllipsisType = ...) -> str:
+     """Get a string value from an environment variable."""
+     value = os.environ.get(env_name)
+     if value is None:
+         if default is ...:
+             raise ValueError(f"Missing environment variable {env_name}")
+         return default
+     return value
+
+
+ def env_bool(env_name: str, default: bool | EllipsisType = ...) -> bool:
+     """Get a boolean value from an environment variable."""
+     value = os.environ.get(env_name)
+     if value is None:
+         if default is ...:
+             raise ValueError(f"Missing environment variable {env_name}")
+         return default
+     return is_truthy(value)  # Reuse the shared falsey-string list
+
+
+ def env_int(env_name: str, default: int | EllipsisType = ...) -> int:
+     """Get an int value from an environment variable."""
+     value = os.environ.get(env_name)
+     if value is None:
+         if default is ...:
+             raise ValueError(f"Missing environment variable {env_name}")
+         return default
+     return int(value)
+
+
+ def env_float(env_name: str, default: float | EllipsisType = ...) -> float:
+     """Get a float value from an environment variable."""
+     value = os.environ.get(env_name)
+     if value is None:
+         if default is ...:
+             raise ValueError(f"Missing environment variable {env_name}")
+         return default
+     return float(value)
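
A quick sketch of how these helpers behave; the variable names are hypothetical:

import os

from article_embedding.utils import env_bool, env_int, env_str

os.environ["BATCH_SIZE"] = "16"  # hypothetical variables, for illustration
os.environ["DRY_RUN"] = "no"

assert env_int("BATCH_SIZE") == 16
assert env_bool("DRY_RUN") is False  # "no" is in the falsey list
assert env_str("MISSING", "fallback") == "fallback"
# env_str("MISSING") with no default raises ValueError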
notebooks/couchdb.http ADDED
@@ -0,0 +1,39 @@
+ @url = {{$dotenv COUCHDB_URL}}
+ @user = {{$dotenv COUCHDB_USER}}
+ @password = {{$dotenv COUCHDB_PASSWORD}}
+ @db = {{$dotenv COUCHDB_DB}}
+ @origin = 0
+ @origin = 316723-g1AAAACReJzLYWBgYMpgTmHgzcvPy09JdcjLz8gvLskBCScyJNX___8_K4M5iYEpWyEXKMaekpyabGBuiK4ehwl5LECSoQFI_YcblCUMNsgsMS010dQSXVsWAD8FLCA
+ @batchSize = 10
+ @includeDocs = true
+
+ ###
+ # @name session
+ POST {{url}}/_session
+ Content-Type: application/json
+
+ {
+     "name": "{{user}}",
+     "password": "{{password}}"
+ }
+
+ ###
+ # @name changes
+ GET {{url}}/{{db}}/_changes
+     ?limit={{batchSize}}
+     &include_docs={{includeDocs}}
+     &since={{origin}}
+
+ ###
+ # @name changes
+ GET {{url}}/{{db}}/_changes
+     ?limit={{batchSize}}
+     &include_docs={{includeDocs}}
+     &since={{changes.response.body.last_seq}}
+
+ ###
+ # @name broken
+ GET {{url}}/{{db}}/_changes
+     ?limit=10
+     &include_docs=true
+     &since=316731-g1AAAACReJzLYWBgYMpgTmHgzcvPy09JdcjLz8gvLskBCScyJNX___8_K4M5iYEpWyEXKMaekpyabGBuiK4ehwl5LECSoQFI_YcblCUNNsgsMS010dQSXVsWAD_1LCg
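
The requests above page through CouchDB's _changes feed by carrying last_seq forward into the next request. A minimal sketch of the same loop in Python with httpx (a project dependency), assuming the same environment variables as the .http file:

import os

import httpx

url, db = os.environ["COUCHDB_URL"], os.environ["COUCHDB_DB"]
auth = (os.environ["COUCHDB_USER"], os.environ["COUCHDB_PASSWORD"])

since: str | int = 0
with httpx.Client(auth=auth) as client:
    for _ in range(3):  # fetch three batches
        response = client.get(
            f"{url}/{db}/_changes",
            params={"limit": 10, "include_docs": "true", "since": since},
        )
        body = response.json()
        since = body["last_seq"]  # resume token for the next request
        print(len(body["results"]), since)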
notebooks/models.ipynb ADDED
@@ -0,0 +1,44 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import pandas as pd\n",
+     "from fastembed import TextEmbedding\n",
+     "\n",
+     "# pd.set_option('display.max_colwidth', 100)\n",
+     "supported_models = (\n",
+     "    pd.DataFrame(TextEmbedding.list_supported_models())\n",
+     "    .sort_values(\"size_in_GB\")\n",
+     "    .drop(columns=[\"sources\", \"model_file\", \"additional_files\"])\n",
+     "    .reset_index(drop=True)\n",
+     ")\n",
+     "supported_models.style.set_properties(**{\"width\": \"300px\"}, subset=[\"description\"])"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "article-embedding-p7H3l83p-py3.12",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.12.7"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
notebooks/query.ipynb ADDED
@@ -0,0 +1,168 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import warnings\n",
+     "\n",
+     "from qdrant_client import QdrantClient\n",
+     "\n",
+     "from article_embedding.embed import SentenceTransformerModel\n",
+     "\n",
+     "warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n",
+     "embedding_model = SentenceTransformerModel(\"dunzhang/stella_en_400M_v5\")\n",
+     "qdrant = QdrantClient()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "vector = embedding_model.embed(\n",
+     "    [\n",
+     "        \"\"\"\n",
+     "The COVID-19 pandemic is the worst public health crisis since the 1918 influenza pandemic over a century ago,\n",
+     "when the science of virology was still in its infancy. By mid-July 2022, the official death toll since the\n",
+     "start of the pandemic stood at over 1 million in the US and nearly 6.4 million globally. Estimates of excess\n",
+     "deaths place the number of people killed directly or indirectly by the pandemic at over 1.2 million in the US\n",
+     "and 21.5 million internationally. In just two-and-a-half years, total deaths from the pandemic have reached the\n",
+     "estimated 20 million military and civilian deaths during the four years of World War I (1914-1918).\n",
+     "\"\"\"\n",
+     "    ]\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "vector = embedding_model.embed(\n",
+     "    [\n",
+     "        \"\"\"\n",
+     "Let us borrow Trotsky’s brilliant use of metaphor for the purpose of explaining the essential significance\n",
+     "of the petty-bourgeois pseudo-left’s call for a “Party of the 99 percent.” On the cruise ship America, one\n",
+     "percent of the passengers occupy state rooms that look out on the ocean. They dine with captain Trump in an\n",
+     "exclusive five-star restaurant, where they wash down their succulent meals with wine that costs $10,000 per\n",
+     "bottle. The \"next nine percent\" of the passengers, depending on what they can afford, make do with cabins\n",
+     "whose quality reflects their lower price. The cheapest of the rooms in the top ten percent category lack a\n",
+     "view out onto the ocean and have shabby rugs and uncomfortable mattresses. And, unless they are able and\n",
+     "willing to pay a substantial surcharge, the occupants of these rooms are not permitted to use the pool and\n",
+     "spa reserved for the richest passengers.\n",
+     "\"\"\"\n",
+     "    ]\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "vector = embedding_model.embed(\n",
+     "    [\n",
+     "        \"\"\"\n",
+     "Israel’s slaughter of tens if not hundreds of thousands of men, women and children, the destruction of entire\n",
+     "cities, hospitals, schools, mosques, and the imposition of mass starvation on Gaza’s 2.3 million residents—all\n",
+     "this has produced profound shock and anger towards the Netanyahu regime and its imperialist backers in the US,\n",
+     "Britain, Germany and elsewhere. It has fueled the political radicalisation of millions of workers and young people,\n",
+     "for whom Zionism is now synonymous with some of the worst atrocities committed anywhere since the Second World War.\n",
+     "\"\"\"\n",
+     "    ]\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "vector = embedding_model.embed([\"The COVID winter wave, the emergence of the Delta variant and the January 6th coup\"])"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "vector"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# A date filter (epoch seconds, e.g. for 2023-01-01) could go here\n",
+     "query_filter = None\n",
+     "result = qdrant.query_points(\"wsws\", query=vector[0], query_filter=query_filter)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "for point in result.points:\n",
+     "    doc = point.payload\n",
+     "    print(f\"{point.score * 100:.1f}% https://www.wsws.org{doc['path']} - {doc['title']}\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "print(result.model_dump_json(indent=2))"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": ".venv",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.12.5"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
notebooks/stella.ipynb ADDED
@@ -0,0 +1,67 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from sentence_transformers import SentenceTransformer\n",
+     "\n",
+     "# This model supports two prompts: \"s2p_query\" and \"s2s_query\" for sentence-to-passage and sentence-to-sentence tasks, respectively.\n",
+     "# They are defined in `config_sentence_transformers.json`\n",
+     "query_prompt_name = \"s2p_query\"\n",
+     "queries = [\n",
+     "    \"What are some ways to reduce stress?\",\n",
+     "    \"What are the benefits of drinking green tea?\",\n",
+     "]\n",
+     "# docs do not need any prompts\n",
+     "docs = [\n",
+     "    \"There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.\",  # noqa: E501\n",
+     "    \"Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.\",  # noqa: E501\n",
+     "]\n",
+     "\n",
+     "# !The default dimension is 1024; if you need other dimensions, clone the model and modify `modules.json` to replace `2_Dense_1024` with another dimension, e.g. `2_Dense_256` or `2_Dense_8192`!  # noqa: E501\n",
+     "# on gpu\n",
+     "# model = SentenceTransformer(\"dunzhang/stella_en_400M_v5\", trust_remote_code=True).cuda()\n",
+     "# You can also use this model without `use_memory_efficient_attention` and `unpad_inputs`; it can run on CPU as well.\n",
+     "model = SentenceTransformer(\n",
+     "    \"dunzhang/stella_en_400M_v5\",\n",
+     "    trust_remote_code=True,\n",
+     "    device=\"mps\",\n",
+     "    config_kwargs={\"use_memory_efficient_attention\": False, \"unpad_inputs\": False},\n",
+     ")\n",
+     "query_embeddings = model.encode(queries, prompt_name=query_prompt_name)\n",
+     "doc_embeddings = model.encode(docs)\n",
+     "print(query_embeddings.shape, doc_embeddings.shape)\n",
+     "# (2, 1024) (2, 1024)\n",
+     "\n",
+     "similarities = model.similarity(query_embeddings, doc_embeddings)\n",
+     "print(similarities)\n",
+     "# tensor([[0.8398, 0.2990],\n",
+     "#         [0.3282, 0.8095]])"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": ".venv",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.12.5"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
poetry.toml ADDED
@@ -0,0 +1,2 @@
+ [virtualenvs]
+ in-project = true
pyproject.toml ADDED
@@ -0,0 +1,86 @@
+ [tool.poetry]
+ name = "article-embedding"
+ version = "0.1.0"
+ description = ""
+ authors = ["Jesus Lopez <jesus@jesusla.com>"]
+ readme = "README.md"
+
+ [tool.poetry.scripts]
+ article-embedding = "article_embedding.cli:cli"
+
+ [tool.poetry.dependencies]
+ python = "~3.12"
+ httpx = "^0.27.2"
+ backoff = "^2.2.1"
+ python-dotenv = "^1.0.1"
+ tqdm = "^4.67.0"
+ qdrant-client = { extras = ["fastembed"], version = "^1.12.1" }
+ sentence-transformers = "^3.3.1"
+ gspread = "^6.1.4"
+ gradio = "^5.8.0"
+ langchain-text-splitters = "^0.3.2"
+ click = "^8.1.7"
+ click-help-colors = "^0.9.4"
+ modal = "^0.68.23"
+ datasets = "^3.2.0"
+ einops = "^0.8.0"
+ torch = "^2.5.1"
+ pandoc = "^2.4"
+ openai = "^1.58.1"
+ pyyaml = "^6.0.2"
+ pydantic = "^2.10.4"
+
+ [tool.poetry.group.dev.dependencies]
+ pre-commit = "^4.0.1"
+ ipykernel = "^6.29.5"
+ ruff = "^0.7.4"
+ nbstripout = "^0.8.1"
+
+ [tool.poetry.group.typing.dependencies]
+ mypy = "^1.13.0"
+ types-tqdm = "^4.66.0.20240417"
+ types-pyyaml = "^6.0.12.20241221"
+
+ [tool.poetry.group.jupyter.dependencies]
+ pandas = "^2.2.3"
+ ipywidgets = "^8.1.5"
+ jinja2 = "^3.1.4"
+
+ [tool.poetry.group.ci.dependencies]
+ huggingface-hub = { extras = ["cli"], version = "^0.27.1" }
+ gradio = "^5.12.0"
+
+ [tool.ruff]
+ line-length = 132
+
+ [tool.ruff.lint]
+ select = [
+     # pycodestyle
+     "E",
+     # Pyflakes
+     "F",
+     # pyupgrade
+     "UP",
+     # flake8-bugbear
+     "B",
+     # flake8-simplify
+     "SIM",
+     # isort
+     "I",
+ ]
+
+ [tool.mypy]
+ strict = true
+
+ [[tool.mypy.overrides]]
+ module = ["gradio", "pandoc"]
+ ignore_missing_imports = true
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
requirements.txt ADDED
The diff for this file is too large to render. See raw diff
 
tests/__init__.py ADDED
File without changes