Upload folder using huggingface_hub
- .devcontainer/.DS_Store +0 -0
- .devcontainer/dev/Dockerfile +9 -0
- .devcontainer/dev/bin/svc +23 -0
- .devcontainer/devcontainer.json +51 -0
- .devcontainer/docker-compose.yml +19 -0
- .devcontainer/scripts/post-create.sh +25 -0
- .devcontainer/scripts/post-start.sh +2 -0
- .gitattributes +5 -35
- .github/workflows/update_space.yml +32 -0
- .gitignore +7 -0
- .pre-commit-config.yaml +19 -0
- .vscode/launch.json +59 -0
- .vscode/settings.json +26 -0
- README.md +141 -6
- article_embedding/__init__.py +0 -0
- article_embedding/__main__.py +4 -0
- article_embedding/app/__init__.py +3 -0
- article_embedding/app/gather.py +71 -0
- article_embedding/app/lookup.py +11 -0
- article_embedding/app/model.py +71 -0
- article_embedding/app/rag.py +134 -0
- article_embedding/app/search.py +70 -0
- article_embedding/app/ui.py +31 -0
- article_embedding/benchmark.py +76 -0
- article_embedding/chunk.py +25 -0
- article_embedding/cli.py +24 -0
- article_embedding/constants.py +2 -0
- article_embedding/couchdb.py +203 -0
- article_embedding/embed.py +86 -0
- article_embedding/loader.py +146 -0
- article_embedding/modal_app.py +153 -0
- article_embedding/query.py +117 -0
- article_embedding/replicate.py +21 -0
- article_embedding/retrieval.py +330 -0
- article_embedding/sheets.py +7 -0
- article_embedding/utils.py +76 -0
- notebooks/couchdb.http +39 -0
- notebooks/models.ipynb +44 -0
- notebooks/query.ipynb +168 -0
- notebooks/stella.ipynb +67 -0
- poetry.lock +0 -0
- poetry.toml +2 -0
- pyproject.toml +86 -0
- requirements.txt +0 -0
- tests/__init__.py +0 -0
.devcontainer/.DS_Store
ADDED
Binary file (6.15 kB)
.devcontainer/dev/Dockerfile
ADDED
@@ -0,0 +1,9 @@
FROM mcr.microsoft.com/devcontainers/python:1-3.12-bookworm

RUN <<EOF
apt update
apt install -y \
    silversearcher-ag \
    iputils-ping

EOF
.devcontainer/dev/bin/svc
ADDED
@@ -0,0 +1,23 @@
#!/bin/bash
# Usage: svc service [command]
# Example: svc app ps aux

# Validate arguments
if [[ $# -lt 1 ]]; then
    echo "Usage: svc service [command]"
    exit 1
fi

# Service
case $1 in
    *)
        SVC=$1
        CMD=
        ;;
esac

CMD_ARGS="${@:2}"
if [[ -z "$CMD_ARGS" ]]; then
    CMD_ARGS="$CMD"
fi
docker compose exec $SVC $CMD_ARGS
.devcontainer/devcontainer.json
ADDED
@@ -0,0 +1,51 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/python
{
    "name": "Article Embedding",
    "dockerComposeFile": "docker-compose.yml",
    "service": "dev",
    "runServices": [
        "dev",
        "qdrant"
    ],
    "workspaceFolder": "/workspaces/article-embedding",
    "features": {
        "ghcr.io/devcontainers/features/github-cli:1": {},
        "ghcr.io/nikobockerman/devcontainer-features/poetry-persistent-cache:1": {},
        "ghcr.io/devcontainers-extra/features/poetry:2": {},
        "ghcr.io/devcontainers/features/docker-outside-of-docker:1": {}
    },
    "containerEnv": {
        "DOCKER_CLI_HINTS": "false",
        "POETRY_VIRTUALENVS_IN_PROJECT": "false",
        "VIRTUAL_ENV": "/mnt/poetry-persistent-cache/virtualenvs/article-embedding-p7H3l83p-py3.12",
        "COMPOSE_PROJECT_NAME": "article-embedding_devcontainer",
        "COMPOSE_FILE": "${containerWorkspaceFolder}/.devcontainer/docker-compose.yml"
    },
    "remoteEnv": {
        "PATH": "${containerEnv:VIRTUAL_ENV}/bin:${containerWorkspaceFolder}/.devcontainer/dev/bin:${containerEnv:PATH}"
    },
    "mounts": [
        "type=bind,source=${localEnv:HOME}${localEnv:USERPROFILE}/.gitconfig,target=/home/vscode/.gitconfig,readonly",
        "type=bind,source=${localEnv:HOME}${localEnv:USERPROFILE}/.ssh,target=/home/vscode/.ssh,readonly"
    ],
    "postCreateCommand": ".devcontainer/scripts/post-create.sh ${containerWorkspaceFolder}",
    "postStartCommand": ".devcontainer/scripts/post-start.sh ${containerWorkspaceFolder}",
    "customizations": {
        "vscode": {
            "extensions": [
                "ms-python.debugpy",
                "github.vscode-github-actions",
                "matangover.mypy",
                "tamasfe.even-better-toml",
                "humao.rest-client",
                "charliermarsh.ruff",
                "ms-toolsai.jupyter",
                "redhat.vscode-yaml"
            ],
            "settings": {
                "python.defaultInterpreterPath": "${containerEnv:VIRTUAL_ENV}/bin/python"
            }
        }
    }
}
.devcontainer/docker-compose.yml
ADDED
@@ -0,0 +1,19 @@
services:
  dev:
    profiles:
      - devcontainer
    build: dev
    volumes:
      - ../..:/workspaces:cached
      - ..:/workspaces/article-embedding:cached
    command: sleep infinity

  qdrant:
    image: qdrant/qdrant
    ports:
      - "6333:6333"
    volumes:
      - qdrant:/qdrant/storage:z

volumes:
  qdrant:
.devcontainer/scripts/post-create.sh
ADDED
@@ -0,0 +1,25 @@
#!/bin/bash
# Fail on any error
set -e

# Ensure the lock file is up to date, otherwise the install step may fail later
poetry lock --no-update

# Remove the old-style .venv symlink if it exists
if [[ -L .venv ]]; then
    rm .venv
fi

# Force the virtual environment to the path set in the .devcontainer.json file
python -m venv $VIRTUAL_ENV
PATH="$VIRTUAL_ENV/bin:$PATH"
poetry install

# Activate pre-commit hooks
pre-commit install

# Configuration files
cat >> ~/.inputrc <<EOF
set completion-ignore-case on
set editing-mode vi
EOF
.devcontainer/scripts/post-start.sh
ADDED
@@ -0,0 +1,2 @@
#!/bin/bash
set -e
.gitattributes
CHANGED
@@ -1,35 +1,5 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+poetry.lock binary
+
+*.ipynb filter=nbstripout
+*.zpln filter=nbstripout
+*.ipynb diff=ipynb
.github/workflows/update_space.yml
ADDED
@@ -0,0 +1,32 @@
name: Run Python script

on:
  push:
    branches:
      - main

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Install Poetry
        run: pipx install poetry

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          cache: "poetry"

      - name: Deploy
        run: |-
          pipx inject poetry poetry-plugin-export
          poetry export --only=main >requirements.txt
          poetry install --only=ci --no-root -q
          PATH=$(poetry env info --path)/bin:$PATH
          huggingface-cli login --token ${{ secrets.HF_TOKEN }}
          gradio deploy
.gitignore
ADDED
@@ -0,0 +1,7 @@
.*checkpoint
.keys/
__pycache__/
.gradio/
.env
data/
.venv/
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,19 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.7.4
    hooks:
      # Run the linter.
      - id: ruff
        args: [--fix]
      # Run the formatter.
      - id: ruff-format
.vscode/launch.json
ADDED
@@ -0,0 +1,59 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "App",
            "type": "debugpy",
            "request": "launch",
            "module": "article_embedding",
            "args": [
                "app"
            ],
            "console": "integratedTerminal",
            "justMyCode": false
        },
        {
            "name": "Scanner",
            "type": "debugpy",
            "request": "launch",
            "program": "article_embedding/couchdb.py",
            "console": "integratedTerminal",
            "justMyCode": false
        },
        {
            "name": "Loader",
            "type": "debugpy",
            "request": "launch",
            "program": "article_embedding/loader.py",
            "console": "integratedTerminal",
            "justMyCode": false
        },
        {
            "name": "Embedder",
            "type": "debugpy",
            "request": "launch",
            "program": "article_embedding/embed.py",
            "console": "integratedTerminal",
            "justMyCode": false
        },
        {
            "name": "Query",
            "type": "debugpy",
            "request": "launch",
            "program": "article_embedding/query.py",
            "console": "integratedTerminal",
            "justMyCode": false
        },
        {
            "name": "Benchmark",
            "type": "debugpy",
            "request": "launch",
            "program": "article_embedding/benchmark.py",
            "console": "integratedTerminal",
            "justMyCode": false
        }
    ]
}
.vscode/settings.json
ADDED
@@ -0,0 +1,26 @@
{
    "editor.codeActionsOnSave": {
        "source.organizeImports": "always",
        "source.unusedImports": "always"
    },
    "editor.formatOnPaste": true,
    "editor.formatOnSave": true,
    "editor.formatOnSaveMode": "file",
    "editor.formatOnType": true,
    "editor.renderWhitespace": "all",
    "editor.rulers": [
        132
    ],
    "files.trimTrailingWhitespace": true,
    "mypy.runUsingActiveInterpreter": true,
    "python.analysis.importFormat": "absolute",
    "python.analysis.autoFormatStrings": true,
    "python.analysis.autoImportCompletions": true,
    "ruff.organizeImports": true,
    "[jsonc]": {
        "editor.defaultFormatter": "vscode.json-language-features"
    },
    "[python]": {
        "editor.defaultFormatter": "charliermarsh.ruff"
    }
}
README.md
CHANGED
@@ -1,13 +1,148 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: wsws-chatbot
+emoji: 📊
+colorFrom: green
+colorTo: indigo
 sdk: gradio
 sdk_version: 5.12.0
-app_file: app.py
+app_file: article_embedding/app/ui.py
 pinned: false
 short_description: Ask the WSWS chatbot
 ---
 
-
+## Filter
+
+Identify the top 60 most important articles of the year (a usage sketch follows the topic table below).
+
+| Variable             | Description                           |
+| -------------------- | ------------------------------------- |
+| articlesPerYear      | Number of articles per year           |
+| perspectivePercent   | Percentage of perspectives            |
+| usPercent            | Percentage of US articles             |
+| maxArticlesPerAuthor | Maximum number of articles per author |
+
+### Important topics
+
+| Topic ID                             | Topic Name               |
+| ------------------------------------ | ------------------------ |
+| 3f1db5b4-7915-4cef-a132-3b2e51f0728d | US Politics              |
+| 105b9ec2-0cdd-480d-9c55-6f806fa7b9bc | Perspectives             |
+| 04d2a008-8cef-4928-9a38-86b205a7590f | The coronavirus pandemic |
+| ea69254c-a976-4e67-a79c-843202a10e55 | United States            |
+| b0e983c5-6e85-4829-a9fe-3c25f02c211f | North America            |
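A minimal usage sketch of how these variables map onto the retrieval call, based on `article_embedding/app/gather.py` from this commit; the year-to-date-range conversion and the percent-to-ratio scaling are taken directly from that file, and the single query string is only a placeholder:

```python
from article_embedding.retrieval import retrieve_documents

# Filter settings from the table above.
year = 2021
articles_per_year = 60       # articlesPerYear
perspective_percent = 50     # perspectivePercent
us_percent = 30              # usPercent
max_articles_per_author = 6  # maxArticlesPerAuthor

# retrieve_documents takes the year as a [start, end) date range and the
# percentages as 0..1 ratios (same call shape as gather.py).
docs = retrieve_documents(
    queries=["The development of the COVID-19 pandemic globally."],
    start_date=f"{year}-01-01",
    end_date=f"{year + 1}-01-01",
    number_of_articles=articles_per_year,
    op_ed_ratio=perspective_percent / 100.0,
    us_ratio=us_percent / 100.0,
    max_per_author=max_articles_per_author,
)
```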
+
+## Analysis
+
+1. Document metadata structure
+
+   - The metadata is stored as JSON.
+   - Yes, there are structured attributes in the metadata.
+   - There are many other structured attributes, but none that are relevant to the system design.
+
+2. Vector database setup
+
+   - Qdrant.
+   - Documents are embedded using the Stella 400M v5 embedding model.
+
+3. Integration of filters and constraints
+
+   - Excellent question. The constraints need not be strict. They should be approximately met, since relevance to the provided queries is also important.
+   - The documents should be selected first based on relevance, and the final selection then adjusted to meet the constraints. It's OK to select less relevant articles to get closer to the constraints, but only up to a point.
+
+4. Date range filtering
+
+   - The date range is strict.
+
+5. Result volume and ranking logic
+
+   - Great question too. The articles should be ranked by relevance, with all queries adding documents to the result set and the most relevant ones chosen from this pool. For example, it would be OK if all articles ended up being derived from only one query.
+
+6. Combining multiple queries
+
+   - The former. Each query is run separately and the results are then combined based on the document relevance score (see the sketch after this list). Since the same document may be returned for different queries, each time with a different relevance score, the highest relevance score should be used.
+
+7. User preferences and constraint handling
+
+   - Exactly, the result set should be post-processed to meet the constraints.
+
+8. Performance considerations
+
+   - The document database has about 100,000 documents.
+   - There are no latency requirements. The accuracy of the system is the only requirement.
+
+9. Evaluation and benchmarks
+
+   - The system's success will be measured based on relevance judgments and coverage of constraints, in that order.
+   - A feedback loop to improve the retrieval strategies or embedding approach is definitely on the table, if needed.
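A minimal sketch of the merging rule from points 5–7, assuming a `run_query` callable that stands in for one vector-database query returning `(doc_id, score)` pairs; it is not the actual `article_embedding/retrieval.py` implementation, whose diff is not shown here:

```python
from typing import Callable

Result = tuple[str, float]  # (doc_id, relevance score), a hypothetical record


def merge_queries(queries: list[str], run_query: Callable[[str], list[Result]], limit: int) -> list[Result]:
    # Run each query separately; per document, keep the highest relevance
    # score seen across all queries (point 6).
    best: dict[str, float] = {}
    for q in queries:
        for doc_id, score in run_query(q):
            best[doc_id] = max(score, best.get(doc_id, 0.0))
    # Rank the pooled documents by relevance and take the top `limit`
    # (point 5). Constraint handling (points 3 and 7) would post-process
    # this list, trading a little relevance to approach the quotas.
    ranked = sorted(best.items(), key=lambda kv: kv[1], reverse=True)
    return ranked[:limit]
```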
+
+self.embedding_model = SentenceTransformerModel("dunzhang/stella_en_400M_v5")
+
+## Performance benchmarks
+
+The script at `article_embedding/modal_app.py` was used to measure the performance of the embedding model. The results are as follows (a condensed sketch of the measurement follows the tables):
+
+| GPU    | Speed   | Cost  |
+| :----: | ------: | ----: |
+| T4     | 130.0ms | 2.13¢ |
+| L4     | 82.7ms  | 1.84¢ |
+| M3 Max | 60.0ms  | N/A   |
+| A10G   | 64.2ms  | 1.96¢ |
+| A100   | 42.2ms  | 3.26¢ |
+| H100   | 24.9ms  | 3.15¢ |
+
+These are benchmarks on the M3 Max 128GB:
+
+| Model          | f32   | f16   | Context | Size  |
+| -------------- | ----: | ----: | ------: | ----: |
+| jasper sdpa    | 373ms | 314ms | 1,024   | 1,024 |
+| jasper         | 354ms | 330ms | 1,024   | 1,024 |
+| stella-en-400M | 65ms  | 62ms  | 131,072 | 1,024 |
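The per-document latencies above are wall-clock time for an embedding batch divided by the batch size, after a warm-up call; a condensed sketch of that measurement, following `process2` in `article_embedding/benchmark.py` later in this commit (whether `modal_app.py` times the GPU runs identically is not shown here, and the document batch is a placeholder):

```python
import time

from article_embedding.embed import StellaEmbedder

model = StellaEmbedder()
model.embed(["Hello, world!"])  # Warm-up so model load is not timed

documents = ["Some article text."] * 64  # placeholder; the real run embeds CouchDB articles
ts0 = time.time()
embeddings = model.embed(documents)
elapsed = time.time() - ts0
print(f"Latency: {elapsed / len(documents) * 1000:.2f} ms/doc")
```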
+
+## Models
+
+[Jasper](https://huggingface.co/infgrad/jasper_en_vision_language_v1)
+
+## Issues
+
+```sh
+$ cd /Users/jlopez/git-repos/icfi/article-embedding ; /usr/bin/env /Users/jlopez/git-repos/icfi/article-embedding/.venv/bin/python /Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher 63791 -- article_embedding/loader.py
+  0%|          | 0/314855 [00:00<?, ?it/s]E+00379.644: /handling disconnect from Adapter/
+Failed to kill Debuggee[PID=93387]
+
+Traceback (most recent call last):
+  File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/launcher/debuggee.py", line 190, in kill
+    os.killpg(process.pid, signal.SIGKILL)
+PermissionError: [Errno 1] Operation not permitted
+
+Stack where logged:
+  File "/Users/jlopez/.pyenv/versions/3.12.5/lib/python3.12/threading.py", line 1032, in _bootstrap
+    self._bootstrap_inner()
+  File "/Users/jlopez/.pyenv/versions/3.12.5/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
+    self.run()
+  File "/Users/jlopez/.pyenv/versions/3.12.5/lib/python3.12/threading.py", line 1012, in run
+    self._target(*self._args, **self._kwargs)
+  File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/common/messaging.py", line 1458, in _run_handlers
+    handler()
+  File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/common/messaging.py", line 1488, in _handle_disconnect
+    handler()
+  File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/launcher/handlers.py", line 152, in disconnect
+    debuggee.kill()
+  File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/launcher/debuggee.py", line 192, in kill
+    log.swallow_exception("Failed to kill {0}", describe())
+  File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/common/log.py", line 215, in swallow_exception
+    _exception(format_string, *args, **kwargs)
+
+
+E+00379.648: Failed to kill Debuggee[PID=93387]
+
+Traceback (most recent call last):
+  File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/launcher/debuggee.py", line 190, in kill
+    os.killpg(process.pid, signal.SIGKILL)
+PermissionError: [Errno 1] Operation not permitted
+
+Stack where logged:
+  File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/launcher/debuggee.py", line 192, in kill
+    log.swallow_exception("Failed to kill {0}", describe())
+  File "/Users/jlopez/.vscode/extensions/ms-python.debugpy-2024.10.0-darwin-arm64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/common/log.py", line 215, in swallow_exception
+    _exception(format_string, *args, **kwargs)
+```
article_embedding/__init__.py
ADDED
File without changes
article_embedding/__main__.py
ADDED
@@ -0,0 +1,4 @@
from article_embedding.cli import cli

if __name__ == "__main__":
    cli()
article_embedding/app/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .ui import run_app

__all__ = ["run_app"]
article_embedding/app/gather.py
ADDED
@@ -0,0 +1,71 @@
from typing import Any

import gradio as gr

from article_embedding.retrieval import retrieve_documents

DEFAULT_QUERY = """
The development of the COVID-19 pandemic globally.
The “herd immunity” and “mitigationist” response of capitalist governments globally to the COVID-19 pandemic.
The only viable solution of global elimination of the COVID-19 pandemic.
The science of the COVID-19 pandemic: Long COVID, airborne transmission, viral evolution.
""".strip()


def gather_articles(
    query: str,
    year: int,
    articles_per_year: int,
    pct_perspectives: int,
    pct_us: int,
    max_per_author: int,
) -> list[list[Any]]:
    queries = [q.strip() for q in query.split("\n") if q.strip()]
    docs = retrieve_documents(
        queries=queries,
        start_date=f"{year}-01-01",
        end_date=f"{year + 1}-01-01",
        number_of_articles=articles_per_year,
        op_ed_ratio=pct_perspectives / 100.0,
        us_ratio=pct_us / 100.0,
        max_per_author=max_per_author,
    )
    return [
        [
            ix + 1,
            d.query_id,
            f"{d.score * 100:.1f}%",
            d.meta.published.strftime("%Y/%m/%d") if d.meta.published else "N/A",
            ", ".join(d.meta.authors)[:32] if d.meta.authors else "N/A",
            f'[{d.meta.title or "N/A"}]({d.meta.url})',
        ]
        for ix, d in enumerate(docs)
    ]


def gather_tab() -> None:
    with gr.Tab("Gather"):
        with gr.Row():
            with gr.Column(scale=2):
                query = gr.Textbox(label="Query", placeholder="Enter multiple queries, one per line", value=DEFAULT_QUERY)
                gather = gr.Button(value="Gather articles")
            with gr.Column():
                year = gr.Slider(minimum=2021, maximum=2025, value=2021, label="Year", step=1)
                articles_per_year = gr.Slider(label="Articles per year", minimum=1, maximum=100, value=60, step=1)
                pct_perspectives = gr.Slider(label="Perspectives (%)", minimum=0, maximum=100, value=50)
                pct_us = gr.Slider(label="US (%)", minimum=0, maximum=100, value=30)
                max_per_author = gr.Slider(label="Max articles per author", minimum=1, maximum=20, value=6, step=1)

        results = gr.DataFrame(
            label="Results",
            type="array",
            headers=["#", "Q", "%", "Date", "Authors", "Title"],
            datatype=["number", "number", "number", "date", "str", "markdown"],
            column_widths=["3.5%", "3.5%", "4%", "7%", "20%", "62%"],
        )

        gather.click(
            fn=gather_articles,
            inputs=[query, year, articles_per_year, pct_perspectives, pct_us, max_per_author],
            outputs=[results],
        )
article_embedding/app/lookup.py
ADDED
@@ -0,0 +1,11 @@
import gradio as gr

from article_embedding.couchdb import CouchDB


def lookup_tab() -> None:
    with gr.Tab("Lookup"):
        doc_id = gr.Textbox(label="Document ID", placeholder="Enter a document ID")
        button = gr.Button("Lookup")
        doc_output = gr.JSON()
        button.click(fn=CouchDB().get_doc, inputs=[doc_id], outputs=[doc_output])
article_embedding/app/model.py
ADDED
@@ -0,0 +1,71 @@
import datetime
import logging
from typing import Any

from pydantic import BaseModel

log = logging.getLogger(__name__)


class WorkflowData(BaseModel):
    status: str | None = None
    writers: list[str] | None = None
    editors: list[str] | None = None
    proofers: list[str] | None = None
    reviewers: list[str] | None = None
    proofingDeadline: str | None = None


class Document(BaseModel):
    _id: str
    _rev: str
    type: str | None = None
    mimetype: str | None = None
    title: str | None = None
    language: str | None = None
    workflowData: WorkflowData | None = None
    path: str | None = None
    name: str | None = None
    created: int | None = None
    creator: str | None = None
    lastPublished: int | None = None
    firstPublished: int | None = None
    modified: datetime.datetime | None = None
    modifier: str | None = None
    published: datetime.datetime | None = None
    authors: list[str] | None = None
    content: str | None = None
    contentAssets: list[str] | None = None
    featuredImages: list[str] | None = None
    keywords: list[str] | None = None
    topics: list[str] | None = None
    relatedAssets: list[str] | None = None
    comments: bool | None = None
    campaignConfigs: list[Any] | None = None
    order: int | None = None
    overline: str | None = None
    translatedFrom: str | None = None
    socialTitles: list[Any] | None = None
    socialDescriptions: list[Any] | None = None
    socialFeaturedImages: list[Any] | None = None
    underline: str | None = None
    template: str | None = None
    description: str | None = None
    suggestedImages: list[str] | None = None
    publisher: str | None = None

    @property
    def url(self) -> str:
        if not self.path:
            log.warning(f"URL requested for pathless document: {self}.")
        return public_url(self.path)

    def has_topic(self, topic_id: str) -> bool:
        return topic_id in self.topics if self.topics else False

    def __str__(self) -> str:
        return f"Document[id={self._id!r}, path={self.path!r}]"


def public_url(path: str | None) -> str:
    return f"https://www.wsws.org{path or ''}"
article_embedding/app/rag.py
ADDED
@@ -0,0 +1,134 @@
import asyncio
from typing import AsyncGenerator, Literal

import gradio as gr
from gradio import MessageDict
from openai import AsyncOpenAI, BaseModel
from openai.types.chat import ChatCompletionMessageParam

from article_embedding.chunk import html_to_markdown
from article_embedding.couchdb import CouchDB
from article_embedding.query import Query
from article_embedding.utils import env_str

BackendId = Literal["OpenAI", "Ollama"]


class BackendConfig(BaseModel):
    base_url: str | None = None
    api_key: str | None = None

    @property
    def async_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)


BACKEND_CONFIG: dict[BackendId, BackendConfig] = {
    "OpenAI": BackendConfig(),
    "Ollama": BackendConfig(base_url=env_str("OLLAMA_URL"), api_key="Ollama"),
}


class RagModel(BaseModel):
    backend_id: BackendId
    model: str

    @property
    def backend(self) -> BackendConfig:
        return BACKEND_CONFIG[self.backend_id]


RAG_MODELS: dict[str, RagModel] = {
    m.model: m
    for m in [
        RagModel(backend_id="OpenAI", model="gpt-4o"),
        RagModel(backend_id="OpenAI", model="gpt-4o-mini"),
        RagModel(backend_id="Ollama", model="qwen2.5:72b"),
        RagModel(backend_id="Ollama", model="llama3.3:70b"),
        RagModel(backend_id="Ollama", model="phi4"),
    ]
}

QUERY_PROMPT = """
Generate a short paragraph whose vector embedding will be used to find related articles using a vector database indexing the WSWS archive.
Only output the paragraph without any additional text.

{question}
""".strip()  # noqa: E501

CONTEXT_PROMPT = """
Using the following sources, answer the user's question as precisely as possible.
Only use the information provided in the sources.
If not enough information is available, say you don't know.
Include references to the sources in the answer by using just the article index and url, e.g. [3](URL).
The answers should be exhaustive and cover as much as possible from the sources.

{docs}
""".strip()

_couchdb = CouchDB()


# From Gradio chat_interface.py:ChatInterface:_stream_fn
async def chat_function(
    message: str,  # | MultimodalPostprocess,
    history_with_input: list[MessageDict],
    model: str,
) -> AsyncGenerator[str | list[MessageDict], None]:
    _openai = RAG_MODELS[model].backend.async_client
    _chat = _openai.chat.completions.create
    messages: list[ChatCompletionMessageParam] = [{"role": h["role"], "content": h["content"]} for h in history_with_input]

    async def generate_system_prompt(question: str) -> str:
        msgs: list[ChatCompletionMessageParam] = messages + [{"role": "system", "content": QUERY_PROMPT.format(question=question)}]
        response = await _chat(messages=msgs, model=model, stream=False)
        query = response.choices[0].message.content or ""

        result = Query(index="wsws-2").query(query, limit=40)
        paths = []
        for point in result.points:
            doc = point.payload
            if doc is None:
                continue
            path = doc["path"]
            if path not in paths:
                paths.append(path)
            if len(paths) >= 8:
                break
        tasks = [_couchdb.get_doc(path) for path in paths]
        docs: list[str] = []
        for rix, task_result in enumerate(await asyncio.gather(*tasks, return_exceptions=True), start=1):
            if isinstance(task_result, dict):
                data = [
                    f"Index: {rix}",
                    f"URL: https://www.wsws.org{task_result["path"]}",
                    task_result["title"] or "",
                    task_result["description"] or "",
                    html_to_markdown(task_result["content"]),
                    "---",
                ]
                docs.append("\n".join(data))
        return CONTEXT_PROMPT.format(docs="\n\n".join(docs))

    system_prompt = await generate_system_prompt(message)
    messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": message})
    generator = await _chat(messages=messages, model=model, stream=True)
    result: MessageDict = {"role": "assistant", "content": ""}
    async for chunk in generator:
        delta = chunk.choices[0].delta
        if delta.content:
            result["content"] += delta.content
        yield result


def rag_tab() -> None:
    with gr.Tab("RAG"):
        model = gr.Dropdown(value="gpt-4o", choices=list(RAG_MODELS), label="Model")
        gr.ChatInterface(
            fn=chat_function,
            multimodal=False,
            type="messages",
            additional_inputs=[model],
            chatbot=gr.Chatbot(height=800),
        )
article_embedding/app/search.py
ADDED
@@ -0,0 +1,70 @@
from datetime import datetime
from typing import Any

import gradio as gr
from qdrant_client.models import Condition, Filter

from article_embedding.query import Query, make_date_condition, make_topic_condition


def search(
    text: str,
    date_from: datetime | None,
    date_to: datetime | None,
    topics: str,
    limit: int,
) -> tuple[list[list[str]], list[Any]]:
    rv: list[list[str]] = []

    topic_lines = [t for t in topics.strip().split("\n") if t]
    should: list[Condition] = [make_topic_condition(t) for t in topic_lines if t[0].isalnum()]
    must: list[Condition] = [make_topic_condition(t[1:]) for t in topic_lines if t.startswith("+")]
    must_not: list[Condition] = [make_topic_condition(t[1:]) for t in topic_lines if t.startswith("!")]
    if date_from or date_to:
        must.append(make_date_condition(date_from=date_from, date_to=date_to))  # type: ignore
    result = Query().query(
        text,
        query_filter=Filter(should=should, must=must, must_not=must_not),
        limit=limit,
    )
    docs = []
    for point in result.points:
        doc = point.payload
        assert doc is not None
        docs.append(doc)
        rv.append(
            [
                f"{point.score * 100:.1f}%",
                datetime.fromtimestamp(doc["published"]).strftime("%Y/%m/%d"),
                ", ".join(doc["authors"]),
                f'[{doc["title"] or "N/A"}](https://www.wsws.org{doc["path"]})',
            ]
        )
    return rv, docs


def search_tab() -> None:
    default_query = "The COVID winter wave, the emergence of the Delta variant and the January 6th coup"
    with gr.Tab("Search"):
        with gr.Row():
            with gr.Column(scale=2):
                query = gr.Textbox(
                    label="Query",
                    placeholder="Enter a query",
                    value=default_query,
                )
            with gr.Column():
                from_date = gr.DateTime(include_time=False, type="datetime", label="From", value="2021-01-01")
                to_date = gr.DateTime(include_time=False, type="datetime", label="To", value="2021-05-01")
                topics = gr.Textbox(label="Topics", placeholder="Enter topics")
                limit = gr.Number(label="Limit", minimum=1, value=10)
        button = gr.Button("Search")
        results = gr.DataFrame(
            label="Results",
            type="array",
            headers=["Score", "Published", "Authors", "Title"],
            datatype=["number", "date", "str", "markdown"],
        )
        with gr.Accordion("Debug"):
            json_output = gr.JSON()
        button.click(fn=search, inputs=[query, from_date, to_date, topics, limit], outputs=[results, json_output])
article_embedding/app/ui.py
ADDED
@@ -0,0 +1,31 @@
import asyncio
import os

import gradio as gr
from dotenv import load_dotenv

from article_embedding.app.gather import gather_tab
from article_embedding.app.lookup import lookup_tab
from article_embedding.app.rag import rag_tab
from article_embedding.app.search import search_tab

load_dotenv()


async def amain(share: bool = False) -> None:
    with gr.Blocks(fill_height=True) as app:  # noqa: SIM117
        rag_tab()
        gather_tab()
        search_tab()
        lookup_tab()

    auth = (os.environ["GRADIO_USERNAME"], os.environ["GRADIO_PASSWORD"]) if share else None
    app.queue().launch(inbrowser=True, share=share, auth=auth)


def run_app(share: bool = False) -> None:
    asyncio.run(amain(share))


if __name__ == "__main__":
    run_app()
article_embedding/benchmark.py
ADDED
@@ -0,0 +1,76 @@
import json
import logging
import time
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt
import torch

from article_embedding.embed import SentenceTransformerModel, StellaEmbedder

log = logging.getLogger(__name__)


async def fetch_documents() -> list[str]:
    from article_embedding.couchdb import CouchDB

    async for sorted_batch in CouchDB().changes(batch_size=128):
        sorted_batch = sorted(sorted_batch, key=lambda x: x.get("_id"))
        return [a["content"] for a in sorted_batch if a.get("type") == "article" and a.get("language") == "en"]
    return []


def process2(model: SentenceTransformerModel, documents: list[str], name: str) -> None:
    model.embed(["Hello, world!"])  # Warmup
    ts0 = time.time()
    embeddings = model.embed(documents)
    benchmark = time.time() - ts0
    output_path = Path("data/embeddings.json")
    save_embeddings(embeddings, output_path)
    golden_path = Path(f"data/embeddings.{name}-golden.json")
    if golden_path.exists():
        golden_embeddings: Any = load_embeddings(golden_path)
        similarities = model.model.similarity_pairwise(embeddings, golden_embeddings)
        rms = torch.sqrt(torch.mean(similarities**2)).item()
    else:
        save_embeddings(embeddings, golden_path)
        rms = 0.0
    log.info(
        "%s - RMS: %.2f. Latency: %.2f ms. Size: %d",
        name,
        rms,
        benchmark / len(embeddings) * 1000,
        len(embeddings[0]),
    )


def load_embeddings(path: Path) -> list[npt.NDArray[np.float64]]:
    with path.open() as f:
        return [np.array(json.loads(line)) for line in f.readlines()]


def save_embeddings(embeddings: list[npt.NDArray[np.float64]], path: Path) -> None:
    with path.open("w") as f:
        for e in embeddings:
            f.write(json.dumps(e.tolist()) + "\n")


async def amain() -> None:
    model = StellaEmbedder()
    # model = NvEmbedder()
    # model = JasperEmbedder()
    # model.model.half()
    documents = await fetch_documents()
    process2(model, documents, "stella")
    # documents = await fetch_documents()
    # process2(model, documents, "stella")


if __name__ == "__main__":
    import asyncio

    logging.basicConfig(level=logging.WARN)
    log.setLevel(logging.DEBUG)
    asyncio.run(amain())
article_embedding/chunk.py
ADDED
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from typing import cast

import pandoc
from langchain_text_splitters import RecursiveCharacterTextSplitter


class Chunker(ABC):
    @abstractmethod
    def chunk(self, text: str) -> list[str]: ...


class LangchainChunker(Chunker):
    def __init__(self, chunk_size: int = 576, chunk_overlap: int = 0) -> None:
        self.splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    def chunk(self, text: str) -> list[str]:
        md = html_to_markdown(text)
        return self.splitter.split_text(md)


def html_to_markdown(html: str) -> str:
    doc = pandoc.read(html, format="html")
    md = pandoc.write(doc, format="markdown-smart", options=["--wrap=none"])
    return cast(str, md)
article_embedding/cli.py
ADDED
@@ -0,0 +1,24 @@
import logging

import click
from click_help_colors import HelpColorsGroup

from article_embedding.app import run_app


@click.group(
    cls=HelpColorsGroup,
    help_headers_color="yellow",
    help_options_color="green",
)
def cli() -> None:
    pass


@cli.command()
@click.option("--share", is_flag=True, help="Share the app at Gradio.")
def app(share: bool) -> None:
    """Run the Gradio app."""
    logging.basicConfig(level=logging.WARN, format="%(asctime)s %(levelname)s %(name)s - %(message)s")
    logging.getLogger("article_embedding").setLevel(logging.INFO)
    run_app(share)
article_embedding/constants.py
ADDED
@@ -0,0 +1,2 @@
PERSPECTIVE_TOPIC_ID = "105b9ec2-0cdd-480d-9c55-6f806fa7b9bc"
US_TOPIC_ID = "ea69254c-a976-4e67-a79c-843202a10e55"
article_embedding/couchdb.py
ADDED
@@ -0,0 +1,203 @@
import contextlib
import json
import logging
import os
import re
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator

import backoff
import httpx
import tqdm

log = logging.getLogger(__name__)


class Checkpoint(ABC):
    @abstractmethod
    def get(self) -> str | None: ...

    @abstractmethod
    def set(self, value: str) -> None: ...

    @abstractmethod
    def reset(self) -> None: ...


class NullCheckpoint(Checkpoint):
    def get(self) -> str | None:
        return None

    def set(self, value: str) -> None:
        pass

    def reset(self) -> None:
        pass


_NULL_CHECKPOINT = NullCheckpoint()


class FileCheckpoint(Checkpoint):
    def __init__(self, path: str) -> None:
        self.path = path

    def get(self) -> str | None:
        try:
            with open(self.path) as file:
                return file.read().strip()
        except FileNotFoundError:
            return None

    def set(self, value: str) -> None:
        with open(self.path, "w") as file:
            file.write(value)

    def reset(self) -> None:
        with contextlib.suppress(FileNotFoundError):
            os.remove(self.path)


class CouchDB:
    def __init__(self) -> None:
        self.client = self.make_client()
        self.database = os.environ["COUCHDB_DB"]
        self.path_view = f"/{self.database}/{os.environ["DOCS_PATH_VIEW"]}"

    def __new__(cls) -> "CouchDB":
        if not hasattr(cls, "_instance"):
            cls._instance = super().__new__(cls)
        return cls._instance

    def make_client(self) -> httpx.AsyncClient:
        url = os.environ["COUCHDB_URL"]
        user = os.environ["COUCHDB_USER"]
        password = os.environ["COUCHDB_PASSWORD"]
        auth = {"name": user, "password": password}

        async def on_backoff(details: Any) -> None:
            response = await self.client.post("/_session", json=auth)
            response.raise_for_status()

        client = httpx.AsyncClient(base_url=url)
        decorator = backoff.on_predicate(
            backoff.expo,
            predicate=lambda r: r.status_code == 401,
            on_backoff=on_backoff,
            max_tries=2,
            factor=0,
        )
        client.get = decorator(client.get)  # type: ignore[method-assign]
        return client

    async def changes(self, *, batch_size: int, checkpoint: Checkpoint = _NULL_CHECKPOINT) -> AsyncGenerator[list[Any], None]:
        since = checkpoint.get() or 0
        params = {"since": since, "limit": batch_size, "include_docs": True}
        while True:
            response = await self.client.get(f"/{self.database}/_changes", params=params)
            response.raise_for_status()
            data = response.json()
            yield [change["doc"] for change in data["results"]]
            since = data["last_seq"]
            assert isinstance(since, str)
            params["since"] = since
            checkpoint.set(since)
            if data["pending"] == 0:
                break

    async def estimate_total_changes(self, *, checkpoint: Checkpoint = _NULL_CHECKPOINT) -> int:
        since = checkpoint.get() or 0
        params = {"since": since, "limit": 0}
        response = await self.client.get(f"/{self.database}/_changes", params=params)
        response.raise_for_status()
        data = response.json()
        return int(data["pending"]) + 1

    async def get_doc_by_id(self, doc_id: str) -> Any:
        try:
            response = await self.client.get(f"/{self.database}/{doc_id}")
            if response.status_code == 404:
                return None
            response.raise_for_status()
            return response.json()
        except Exception as e:
            log.error("Error fetching document by ID", exc_info=e)
            return None

    async def get_doc_by_path(self, path: str) -> Any:
        try:
            params = {
                "limit": "1",
                "key": json.dumps(path),
                "include_docs": "true",
            }
            response = await self.client.get(self.path_view, params=params)
            response.raise_for_status()
            data = response.json()
            rows = data["rows"]
            if not rows:
                return None
            return rows[0]["doc"]
        except Exception as e:
            logging.error("Error fetching document by path", exc_info=e)
            return None

    async def get_doc(self, id_or_path: str) -> Any:
        uuids = extract_doc_ids(id_or_path)
        for uuid in uuids:
            doc = await self.get_doc_by_id(uuid)
            if doc:
                return doc

        path = extract_doc_path(id_or_path)
        if path:
            return await self.get_doc_by_path(path)

        return None


UUID_PATTERN = re.compile(r"[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}")


def extract_doc_ids(s: str) -> list[str]:
    return UUID_PATTERN.findall(s)


def extract_doc_path(s: str) -> str | None:
    if not s.endswith(".html"):
        return None
    if s.startswith("/"):
        return s
    if "://" in s:
        s = s.split("://", 1)[1]
    if "/" in s:
        return "/" + s.split("/", 1)[1]
    return None


if __name__ == "__main__":

    async def main() -> None:
        db = CouchDB()
        checkpoint = FileCheckpoint(".checkpoint")
        total = await db.estimate_total_changes(checkpoint=checkpoint)
        with tqdm.tqdm(total=total) as pbar:
            async for docs in db.changes(batch_size=40, checkpoint=checkpoint):
                for doc in docs:
                    kind = doc.get("type")
                    if kind == "article":
                        _id = doc["_id"]
                        language = doc["language"]
                        path = doc["path"]
                        path = os.path.basename(path)
                        pbar.desc = f"{_id}: {kind} {language} {path}"
                    else:
                        pbar.desc = f"{kind}"
                    pbar.update(1)

    import asyncio

    from dotenv import load_dotenv

    load_dotenv()
    asyncio.run(main())
article_embedding/embed.py
ADDED
@@ -0,0 +1,86 @@
from abc import ABC, abstractmethod
from typing import Any

from fastembed import TextEmbedding  # type: ignore
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from sentence_transformers import SentenceTransformer


class EmbeddingModel(ABC):
    @abstractmethod
    def embed(self, documents: list[str]) -> Any: ...

    @abstractmethod
    def get_vectors_config(self) -> VectorParams: ...


class SentenceTransformerModel(EmbeddingModel):
    model: SentenceTransformer

    def embed(self, documents: list[str]) -> Any:
        return self.model.encode(documents, normalize_embeddings=True)

    def get_vectors_config(self) -> VectorParams:
        dimensions = self.model.get_sentence_embedding_dimension()
        assert dimensions is not None
        return VectorParams(
            size=dimensions,
            distance=Distance.COSINE,
        )


class StellaEmbedder(SentenceTransformerModel):
    def __init__(self) -> None:
        self.model = SentenceTransformer(
            "dunzhang/stella_en_400M_v5",
            trust_remote_code=True,
            config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False},
        ).to("mps")


class NvEmbedder(SentenceTransformerModel):
    def __init__(self) -> None:
        self.model = SentenceTransformer(
            "nvidia/NV-Embed-V2",
            trust_remote_code=True,
        ).to("mps")
        self.model.max_seq_length = 32768
        self.model.tokenizer.padding_side = "right"

    def embed(self, documents: list[str]) -> Any:
        processed = [d + self.model.tokenizer.eos_token for d in documents]
        return self.model.encode(processed, normalize_embeddings=True)


class JasperEmbedder(SentenceTransformerModel):
    def __init__(self, use_sdpa: bool = True) -> None:
        model_kwargs = {"attn_implementation": "sdpa"} if use_sdpa else {}
        self.model = SentenceTransformer(
            "infgrad/jasper_en_vision_language_v1",
            trust_remote_code=True,
            model_kwargs=model_kwargs,
            config_kwargs={
                "is_text_encoder": True,
                "vector_dim": 1024,
            },
        ).to("mps")
        self.model.max_seq_length = 1024

    def embed(self, documents: list[str]) -> Any:
        processed = [d + self.model.tokenizer.eos_token for d in documents]
        return self.model.encode(processed, normalize_embeddings=True)


class FastEmbedModel(EmbeddingModel):
    def __init__(self, model_name: str = "BAAI/bge-small-en-v1.5") -> None:
        self.model = TextEmbedding(model_name=model_name)

    def embed(self, documents: list[str]) -> Any:
        return self.model.embed(documents)

    def get_vectors_config(self) -> VectorParams:
        size, distance = QdrantClient._get_model_params(model_name=self.model.model_name)
        return VectorParams(
            size=size,
            distance=distance,
        )
article_embedding/loader.py
ADDED
@@ -0,0 +1,146 @@
import contextlib
import os
import uuid
from dataclasses import dataclass
from typing import Any, Generator

import tqdm
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.models import Batch

from article_embedding.chunk import Chunker, LangchainChunker
from article_embedding.couchdb import Checkpoint, CouchDB, FileCheckpoint
from article_embedding.embed import EmbeddingModel, StellaEmbedder


class Loader:
    def __init__(
        self,
        *,
        collection_name: str,
        embedding_model: EmbeddingModel,
        checkpoint: Checkpoint,
        chunker: Chunker | None = None,
    ) -> None:
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.checkpoint = checkpoint
        self.chunker = chunker
        self.qdrant = QdrantClient(os.getenv("QDRANT_URL"))
        self.couchdb = CouchDB()

    @property
    def starting_from_scratch(self) -> bool:
        return self.checkpoint.get() is None

    def recreate_collection(self) -> None:
        with contextlib.suppress(UnexpectedResponse):
            self.qdrant.delete_collection(self.collection_name)
        self.qdrant.create_collection(
            collection_name=self.collection_name,
            vectors_config=self.embedding_model.get_vectors_config(),
        )

    async def load(self, batch_size: int = 64) -> None:
        if self.starting_from_scratch:
            self.recreate_collection()
        total = await self.couchdb.estimate_total_changes(checkpoint=self.checkpoint)
        with tqdm.tqdm(total=total) as pbar:
            async for docs in self.couchdb.changes(batch_size=batch_size, checkpoint=self.checkpoint):
                points = self.embed(docs)
                if points:
                    self.qdrant.upsert(
                        collection_name=self.collection_name,
                        points=points,
                        wait=False,
                    )
                pbar.update(len(docs))

    def embed(self, docs: list[Any]) -> Batch | None:
        def article_filter(doc: Any) -> bool:
            return doc.get("type") == "article" and doc.get("language") == "en" and doc.get("content") is not None

        def to_payload(doc: Any) -> Any:
            doc.pop("content")
            return doc

        def chunk(doc: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
            assert self.chunker is not None
            content = doc["content"]
            for ix, chunk in enumerate(self.chunker.chunk(content)):
                copy = doc.copy()
                copy["docId"] = doc["_id"]
                copy["_id"] = str(uuid.uuid3(_NS_DOCUMENT, chunk))
                copy["content"] = chunk
                copy["chunkIndex"] = ix
                copy["chunk"] = chunk
                yield copy

        filtered = [d for d in docs if article_filter(d)]
        chunked = [c for d in filtered for c in chunk(d)] if self.chunker else filtered
        if not chunked:
            return None
        ids = [d["_id"] for d in chunked]
        contents = [d["content"] for d in chunked]
        payloads = [to_payload(d) for d in chunked]
        return Batch(
            ids=ids,
            vectors=self.embedding_model.embed(contents),
            payloads=payloads,
        )


_NS_DOCUMENT = uuid.UUID("a07b4024-c09b-11ef-9e49-36d763a413b8")


def uuid_for_content(*content: Any) -> str:
    import hashlib

    hash = hashlib.md5("".join(map(str, content)).encode()).hexdigest()
    return str(uuid.UUID(hash[:32]))


if __name__ == "__main__":
    import asyncio

    stella = StellaEmbedder()

    @dataclass
    class Store:
        collection_name: str
        embedding_model: EmbeddingModel = stella
        chunker: Chunker | None = None
        batch_size: int = 64

        @property
        def checkpoint(self) -> str:
            return f".{self.collection_name}.checkpoint"

    class StoreManager:
        def __init__(self) -> None:
            self.stores: dict[str, Store] = {}

        def add(self, store: Store) -> None:
            self.stores[store.collection_name] = store

        async def load(self, name: str) -> None:
            store = self.stores[name]
            checkpoint = FileCheckpoint(store.checkpoint)
            loader = Loader(
                collection_name=store.collection_name,
                embedding_model=store.embedding_model,
                checkpoint=checkpoint,
                chunker=store.chunker,
            )
            await loader.load(batch_size=store.batch_size)

        async def load_all(self) -> None:
            for name in self.stores:
                await self.load(name)

    store_manager = StoreManager()
    store_manager.add(Store(collection_name="wsws"))
    store_manager.add(Store(collection_name="wsws-2", chunker=LangchainChunker(), batch_size=8))
    asyncio.run(store_manager.load_all())
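The Loader is deliberately generic over two small interfaces that live in other modules. As a rough sketch of the contracts it relies on, inferred purely from the call sites above; the real definitions live in article_embedding.couchdb and article_embedding.chunk, so treat this as an illustration, not the actual code:

# Illustrative only: the shapes Loader assumes, inferred from its call sites.
from typing import Iterable, Protocol


class Checkpoint(Protocol):
    def get(self) -> str | None: ...  # None => starting_from_scratch is True


class Chunker(Protocol):
    def chunk(self, content: str) -> Iterable[str]: ...  # yields chunk texts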
article_embedding/modal_app.py
ADDED
@@ -0,0 +1,153 @@
import json
import logging
import time
from pathlib import Path
from typing import Any, Callable

import modal
import numpy as np
import numpy.typing as npt
import torch
from sentence_transformers import SentenceTransformer

from article_embedding.embed import SentenceTransformerModel, StellaEmbedder

log = logging.getLogger(__name__)


def load_model() -> SentenceTransformer:
    return SentenceTransformer(
        "dunzhang/stella_en_400M_v5",
        trust_remote_code=True,
        config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False},
    )


image = modal.Image.debian_slim(python_version="3.12").pip_install(["sentence-transformers"]).run_function(load_model)

app = modal.App("embedding", image=image)


@app.cls(gpu="A10G")
class ModalEmbedder:
    @modal.enter()
    def setup(self) -> None:
        logging.basicConfig(level=logging.WARN)
        log.setLevel(logging.DEBUG)
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        self.model = load_model().to(device)
        log.info("Model loaded on %s", device)

    @modal.method()
    def embed(self, documents: list[str]) -> Any:
        return self.model.encode(documents)


async def fetch_documents() -> list[str]:
    from article_embedding.couchdb import CouchDB

    # Only the first batch of changes is used; the early return is intentional.
    async for batch in CouchDB().changes(batch_size=256):
        batch = sorted(batch, key=lambda x: x.get("_id"))
        return [a["content"] for a in batch if a.get("type") == "article" and a.get("language") == "en"]
    return []


def process(func: Callable[[list[str]], Any], documents: list[str], name: str) -> None:
    func(["Hello, world!"])  # Warmup
    ts0 = time.time()
    embeddings = func(documents)
    benchmark = time.time() - ts0
    output_path = Path("data/embeddings.json")
    golden_path = Path(f"data/embeddings.{name}-golden.json")
    save_embeddings(embeddings, output_path)
    mean_cosine_similarity, rms = compare_embeddings(embeddings, golden_path)
    log.info(
        "%s - MCS: %.2f. RMS: %.2f. Latency: %.2f ms. Size: %d",
        name,
        mean_cosine_similarity,
        rms,
        benchmark / len(embeddings) * 1000,
        len(embeddings[0]),
    )


def process2(model: SentenceTransformerModel, documents: list[str], name: str) -> None:
    model.embed(["Hello, world!"])  # Warmup
    ts0 = time.time()
    embeddings = model.embed(documents)
    benchmark = time.time() - ts0
    output_path = Path("data/embeddings.json")
    save_embeddings(embeddings, output_path)
    golden_path = Path(f"data/embeddings.{name}-golden.json")
    if golden_path.exists():
        golden_embeddings: Any = load_embeddings(golden_path)
        similarities = model.model.similarity_pairwise(embeddings, golden_embeddings)
        rms = torch.sqrt(torch.mean(similarities**2))
    else:
        save_embeddings(embeddings, golden_path)
        rms = torch.zeros([0])
    log.info(
        "%s - RMS: %.2f. Latency: %.2f ms. Size: %d",
        name,
        rms,
        benchmark / len(embeddings) * 1000,
        len(embeddings[0]),
    )


def load_embeddings(path: Path) -> list[npt.NDArray[np.float64]]:
    with path.open() as f:
        return [np.array(json.loads(line)) for line in f.readlines()]


def save_embeddings(embeddings: list[npt.NDArray[np.float64]], path: Path) -> None:
    with path.open("w") as f:
        for e in embeddings:
            f.write(json.dumps(e.tolist()) + "\n")


def compare_embeddings(embeddings: list[npt.NDArray[np.float64]], golden_path: Path) -> tuple[float, float]:
    if not golden_path.exists():
        save_embeddings(embeddings, golden_path)
        return 0.0, 0.0
    with golden_path.open() as f:
        golden_embeddings = [np.array(json.loads(line)) for line in f.readlines()]
    np_embeddings = np.array(embeddings)
    np_golden_embeddings = np.array(golden_embeddings)
    rms = np.sqrt(np.mean((np_embeddings - np_golden_embeddings) ** 2))
    dot_products = np.einsum("ij,ij->i", np_embeddings, np_golden_embeddings)
    norms = np.linalg.norm(np_embeddings, axis=1) * np.linalg.norm(np_golden_embeddings, axis=1)
    cosine_similarities = dot_products / norms
    return float(np.mean(cosine_similarities)), float(rms)


@app.local_entrypoint()
async def modal_amain() -> None:
    logging.basicConfig(level=logging.WARN)
    log.setLevel(logging.DEBUG)

    embedder = ModalEmbedder()
    documents = await fetch_documents()
    process(embedder.embed.remote, documents, "modal")


async def amain() -> None:
    model = StellaEmbedder()
    # model = NvEmbedder()
    # model = JasperEmbedder()
    # model.model.half()
    documents = await fetch_documents()
    process2(model, documents, "stella")


if __name__ == "__main__":
    import asyncio

    logging.basicConfig(level=logging.WARN)
    log.setLevel(logging.DEBUG)
    asyncio.run(amain())
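The comparison in compare_embeddings reduces to a row-wise cosine similarity. A tiny self-contained example of the same einsum pattern, useful for sanity-checking the benchmark math (the vectors here are made up):

import numpy as np

a = np.array([[1.0, 0.0], [0.0, 2.0]])  # two "current" embeddings
b = np.array([[1.0, 0.0], [2.0, 0.0]])  # two "golden" embeddings

dots = np.einsum("ij,ij->i", a, b)       # row-wise dot products: [1.0, 0.0]
norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
print(dots / norms)                      # [1.0, 0.0] -> identical vs. orthogonal rows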
article_embedding/query.py
ADDED
@@ -0,0 +1,117 @@
import os
import warnings
from datetime import datetime
from typing import Any, Optional

from qdrant_client import QdrantClient
from qdrant_client.http.models import QueryResponse
from qdrant_client.models import FieldCondition, Filter, MatchValue, Range

from article_embedding.embed import StellaEmbedder

warnings.simplefilter(action="ignore", category=FutureWarning)


def as_timestamp(date: datetime | str) -> float:
    if isinstance(date, datetime):
        return date.timestamp()
    return datetime.strptime(date, "%Y-%m-%d").timestamp()


def make_date_condition(
    *, field: str = "published", date_from: datetime | str | None = None, date_to: datetime | str | None = None
) -> FieldCondition | None:
    kwargs = {}
    if date_from:
        kwargs["gte"] = as_timestamp(date_from)
    if date_to:
        kwargs["lt"] = as_timestamp(date_to)
    if kwargs:
        return FieldCondition(key=field, range=Range(**kwargs))
    return None


def make_topic_condition(topic_id: str) -> FieldCondition:
    return FieldCondition(key="topics[]", match=MatchValue(value=topic_id))


class Query:
    _instance: Optional["Query"] = None
    _embedding_model_instance: Optional[StellaEmbedder] = None

    def __init__(self, index: str = "wsws") -> None:
        self.embedding_model = Query.embedding_model_singleton()
        self.qdrant = QdrantClient(os.environ["QDRANT_URL"])
        self.index = index

    @staticmethod
    def embedding_model_singleton() -> StellaEmbedder:
        if Query._embedding_model_instance is None:
            Query._embedding_model_instance = StellaEmbedder()
        return Query._embedding_model_instance

    @staticmethod
    def singleton() -> "Query":
        if Query._instance is None:
            Query._instance = Query()
        return Query._instance

    def embed(self, query: str) -> Any:
        return self.embedding_model.embed([query])[0]

    def query(
        self,
        query: str,
        query_filter: Filter | None = None,
        limit: int = 10,
    ) -> QueryResponse:
        vector = self.embedding_model.embed([query])[0]
        return self.qdrant.query_points(self.index, query=vector, query_filter=query_filter, limit=limit)


if __name__ == "__main__":
    import gspread
    from dotenv import load_dotenv
    from gspread.utils import ValueInputOption

    data = [
        ("2021-01-01", "2021-05-01", "The COVID winter wave, the emergence of the Delta variant and the January 6th coup"),
        (
            "2021-05-01",
            "2021-09-01",
            "The COVID vaccine rollout, Biden declaring independence from COVID while the Delta wave continues",
        ),
        (
            "2021-09-01",
            "2022-01-01",
            "The emergence of the COVID Omicron variant and the embrace of herd immunity by the ruling class",
        ),
    ]

    load_dotenv()
    query = Query()
    rows: list[list[str]] = []
    for date_from, date_to, sentence in data:
        result = query.query(
            sentence,
            query_filter=Filter(should=make_date_condition(date_from=date_from, date_to=date_to)),
        )
        rows.append([sentence])
        for point in result.points:
            doc = point.payload
            assert doc is not None
            print(f'{point.score * 100:.1f}% https://www.wsws.org{doc["path"]} - {doc["title"]}')
            rows.append(
                [
                    f"{point.score * 100:.1f}%",
                    datetime.fromtimestamp(doc["published"]).strftime("%Y/%m/%d"),
                    ", ".join(doc["authors"]),
                    f'=hyperlink("https://www.wsws.org{doc["path"]}", "{doc["title"]}")',
                ]
            )
        rows.append([])

    gc = gspread.auth.oauth(credentials_filename=os.environ["GOOGLE_CREDENTIALS"])
    sh = gc.open("COVID-19 Compilation")
    ws = sh.get_worksheet(0)
    ws.append_rows(rows, value_input_option=ValueInputOption.user_entered)
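For context, a small sketch of how the two condition helpers compose into a Qdrant Filter. The topic id below is a made-up placeholder, not a real value from the corpus:

# Sketch: combine a date range with a topic match; "topic-123" is hypothetical.
from qdrant_client.models import Filter

conditions = [make_topic_condition("topic-123")]
date_condition = make_date_condition(date_from="2021-01-01", date_to="2021-05-01")
if date_condition is not None:
    conditions.append(date_condition)

result = Query.singleton().query("pandemic policy in early 2021", query_filter=Filter(must=conditions), limit=5)
for point in result.points:
    print(point.score, point.payload and point.payload.get("title"))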
article_embedding/replicate.py
ADDED
@@ -0,0 +1,21 @@
from qdrant_client import QdrantClient

from article_embedding.utils import env_str

remote_url = env_str("CLUSTER_1_ENDPOINT")
remote_key = env_str("CLUSTER_1_API_KEY")
collection_name = "wsws-2"
local_url = env_str("QDRANT_URL")

# model = StellaEmbedder()

cloud_client = QdrantClient(url=remote_url, api_key=remote_key)

# cloud_client.recreate_collection(
#     collection_name=collection_name,
#     vectors_config=model.get_vectors_config(),
# )

local_client = QdrantClient(local_url)

# migrate() copies from the client it is called on to the client passed in,
# i.e. this replicates the local collection up to the cloud cluster.
local_client.migrate(cloud_client, collection_names=[collection_name])
article_embedding/retrieval.py
ADDED
@@ -0,0 +1,330 @@
import logging
import math
from typing import Generator, Literal

from pydantic import BaseModel
from qdrant_client.models import Filter, ScoredPoint

from article_embedding.app.model import Document
from article_embedding.constants import PERSPECTIVE_TOPIC_ID, US_TOPIC_ID
from article_embedding.query import Query, make_date_condition

log = logging.getLogger(__name__)


class Candidate(BaseModel):
    doc_id: str
    score: float
    query_id: int
    meta: Document

    @classmethod
    def from_point(cls, query_id: int, point: ScoredPoint) -> "Candidate":
        assert isinstance(point.id, str)
        return cls(doc_id=point.id, score=point.score, query_id=query_id, meta=Document.model_validate(point.payload))


CandidatePool = dict[str, Candidate]
Candidates = list[Candidate]
Fractions = tuple[float, float]
AuthorCounts = dict[str, int]


def qdrant_search(query_id: int, query: str, start_date: str, end_date: str, top_k: int) -> list[Candidate]:
    # Returns a list of candidates (doc_id, relevance score, metadata).
    # The metadata includes: date, authors, topics, etc.
    query_filter = Filter(must=make_date_condition(date_from=start_date, date_to=end_date))
    result = Query.singleton().query(query, query_filter=query_filter, limit=top_k)
    return [Candidate.from_point(query_id, point) for point in result.points]


def compute_count(docs: Candidates, *topic_ids: str) -> int:
    return sum(1 if any(d.meta.has_topic(t) for t in topic_ids) else 0 for d in docs)


def compute_fractions(docs: Candidates) -> Fractions:
    oped_count = sum(d.meta.has_topic(PERSPECTIVE_TOPIC_ID) for d in docs)
    us_count = sum(d.meta.has_topic(US_TOPIC_ID) for d in docs)
    total = len(docs)
    return (oped_count / total, us_count / total)


def compute_author_counts(docs: Candidates) -> AuthorCounts:
    counts: AuthorCounts = {}
    for d in docs:
        for author in d.meta.authors or []:
            counts[author] = counts.get(author, 0) + 1
    return counts


Category = tuple[bool, bool]
SwapIndices = tuple[int, int]
SwapCandidates = tuple[Candidate, Candidate]
Sign = Literal[-1, 0, 1]
Evaluation = Literal["rejected", "accepted", "desired"]


def sign(n: int) -> Sign:
    return -1 if n < 0 else 0 if n == 0 else 1


class Categories(BaseModel):
    accepted: list[Sign] = []
    desired: list[Sign] = []

    def evaluate_delta(self, delta: int) -> Evaluation:
        if sign(delta) in self.accepted:
            return "accepted"
        if sign(delta) in self.desired:
            return "desired"
        return "rejected"


def seek_constraints(
    selected: Candidates,
    available: Candidates,
    oped_target: int,
    us_target: int,
    max_per_author: int,
    tolerance: int,
) -> bool:
    # Swaps candidates from `selected` with candidates from `available` with the goal of
    # meeting the target counts for the OpEd and US topics, while respecting the max-per-author
    # constraint and favoring a more balanced author count distribution (i.e. closer to the mean).

    # It is a precondition that the author counts in `selected` do not already exceed the limit.
    # The function will attempt to adjust the selection to meet the target counts approximately.

    # The return value indicates whether the algorithm was able to meet the constraints.

    # This is a heuristic approach, i.e. it may not always find the optimal solution.
    # It swaps documents from the tail of `selected` with appropriate replacements from
    # the head of `available`.

    # The max-per-author constraint is a hard constraint, i.e. it may not be exceeded.
    # The algorithm also attempts to move the author counts closer to their mean.

    # The count constraints are soft constraints, i.e. they allow some tolerance.
    # The algorithm finishes when the counts meet the target within the tolerance.

    # A band of ±2 marks a hard violation, ±1 the soft edge of the band, and 0 is on target.
    def compute_band(current: int, target: int, tolerance: int) -> int:
        if current < target - tolerance:
            return -2
        if current == target - tolerance:
            return -1
        if current < target + tolerance:
            return 0
        if current == target + tolerance:
            return +1
        return +2

    # Determine current counts and bands
    oped_current = compute_count(selected, PERSPECTIVE_TOPIC_ID)
    us_current = compute_count(selected, US_TOPIC_ID)
    author_counts = compute_author_counts(selected)
    author_mean_numerator = sum(author_counts.values())
    author_mean_denominator = len(author_counts) or 1

    # It is a precondition that the author counts in `selected` do not exceed the limit
    if sum(1 for count in author_counts.values() if count > max_per_author):
        raise ValueError("Precondition violated: author counts exceed the limit")

    def iterate_swap_indices() -> Generator[SwapIndices, None, None]:
        for dropped_index in range(len(selected) - 1, -1, -1):
            for picked_index in range(len(available)):
                yield (picked_index, dropped_index)

    def compute_topic_delta(swap_candidates: SwapCandidates, topic_id: str) -> int:
        return sum(d if c.meta.has_topic(topic_id) else 0 for d, c in zip([1, -1], swap_candidates))

    def evaluate_author_change(author: str, delta: int, mean: float) -> int:
        count = author_counts.get(author, 0)
        old_distance = abs(count - mean)
        new_distance = abs(count + delta - mean)
        return int(math.copysign(1, old_distance - new_distance))

    def evaluate_author_exchange(swap_candidates: SwapCandidates, mean: float) -> Evaluation:
        points = 0
        # An exchange gets points for moving author counts closer to the mean
        picked, dropped = swap_candidates
        for author_picked in picked.meta.authors or []:
            if author_counts.get(author_picked, 0) >= max_per_author:
                return "rejected"
            points += evaluate_author_change(author_picked, +1, mean)
        for author_dropped in dropped.meta.authors or []:
            points += evaluate_author_change(author_dropped, -1, mean)
        return "desired" if points > 0 else "accepted"

    def update_author_counts(swap_candidates: SwapCandidates) -> int:
        author_mean_numerator_delta = 0

        picked, dropped = swap_candidates
        for author in picked.meta.authors or []:
            author_mean_numerator_delta += 1
            author_counts[author] = author_counts.get(author, 0) + 1
        for author in dropped.meta.authors or []:
            author_mean_numerator_delta -= 1
            new_count = author_counts.pop(author) - 1
            if new_count:
                author_counts[author] = new_count

        return author_mean_numerator_delta

    # Bands
    # -2 - pick +, drop -
    # -1 - pick +, drop -; pick +, drop +; pick -, drop -
    #  0 - all
    # +1 - pick -, drop +; pick +, drop +; pick -, drop -
    # +2 - pick -, drop +
    CATEGORIES = [
        Categories(desired=[1]),
        Categories(desired=[1], accepted=[0]),
        Categories(accepted=[1, 0, -1]),
        Categories(desired=[-1], accepted=[0]),
        Categories(desired=[-1]),
    ]
    # We do multiple passes; in each successive pass we try to improve fewer constraints
    swap_attempt = 0
    swap_count = 0
    for desired_target in range(3, 0, -1):
        current_pass = 4 - desired_target
        # Iterate all possible swaps
        for picked_index, dropped_index in iterate_swap_indices():
            swap_attempt += 1
            swap_candidates = (available[picked_index], selected[dropped_index])
            oped_band = compute_band(oped_current, oped_target, tolerance)
            us_band = compute_band(us_current, us_target, tolerance)
            author_mean_denominator = len(author_counts) or 1
            author_mean = author_mean_numerator / author_mean_denominator

            # Computed unconditionally: it is interpolated into the INFO-level log below,
            # which would otherwise raise a NameError when INFO logging is disabled.
            author_variance = sum((count - author_mean) ** 2 for count in author_counts.values()) / author_mean_denominator

            if log.isEnabledFor(logging.DEBUG):
                # fmt: off
                log.debug("%d/%d - %d Eval %d <-> %d: oped=%d (%d %d), us=%d (%d %d), author µ=%.2f σ=%.2f",
                    swap_count, swap_attempt, current_pass, picked_index, dropped_index,
                    oped_current, oped_target, oped_band, us_current, us_target, us_band, author_mean, author_variance,
                )  # fmt: on

            # Check if constraints are met
            if abs(oped_band) <= 1 and abs(us_band) <= 1:
                # fmt: off
                log.info("%d/%d - %d Constraints met: oped=%d (%d %d), us=%d (%d %d), author µ=%.2f σ=%.2f",
                    swap_count, swap_attempt, current_pass,
                    oped_current, oped_target, oped_band, us_current, us_target, us_band, author_mean, author_variance,
                )  # fmt: on
                return True

            oped_delta = compute_topic_delta(swap_candidates, PERSPECTIVE_TOPIC_ID)
            category = CATEGORIES[oped_band + 2]
            oped_evaluation = category.evaluate_delta(oped_delta)
            if oped_evaluation == "rejected":
                log.debug("%d/%d - %d Swap rejected by OpEd constraints", swap_count, swap_attempt, current_pass)
                continue

            us_delta = compute_topic_delta(swap_candidates, US_TOPIC_ID)
            category = CATEGORIES[us_band + 2]
            us_evaluation = category.evaluate_delta(us_delta)
            if us_evaluation == "rejected":
                log.debug("%d/%d - %d Swap rejected by US constraints", swap_count, swap_attempt, current_pass)
                continue

            author_evaluation = evaluate_author_exchange(swap_candidates, author_mean)
            if author_evaluation == "rejected":
                log.debug("%d/%d - %d Swap rejected by author constraints", swap_count, swap_attempt, current_pass)
                continue

            desired_count = sum(1 for f in [oped_evaluation, us_evaluation, author_evaluation] if f == "desired")
            if desired_count < desired_target:
                log.debug("%d/%d - %d Swap rejected by desired count", swap_count, swap_attempt, current_pass)
                continue

            swap_count += 1
            author_mean_numerator_delta = update_author_counts(swap_candidates)
            oped_current += oped_delta
            us_current += us_delta
            author_mean_numerator += author_mean_numerator_delta
            picked, dropped = swap_candidates
            selected.pop(dropped_index)
            available.pop(picked_index)
            selected.append(picked)
            available.append(dropped)

            if log.isEnabledFor(logging.INFO):
                author_variance = sum((count - author_mean) ** 2 for count in author_counts.values()) / author_mean_denominator
                # fmt: off
                log.info("%d/%d - %d Swapped %d <-> %d: oped=%d (%d %d), us=%d (%d %d), author µ=%.2f σ=%.2f",
                    swap_count, swap_attempt, current_pass, picked_index, dropped_index,
                    oped_current, oped_target, oped_band, us_current, us_target, us_band, author_mean, author_variance,
                )  # fmt: on

    log.warning("No more improvements possible")
    return False


def retrieve_documents(
    *,
    queries: list[str],
    start_date: str,
    end_date: str,
    number_of_articles: int,
    op_ed_ratio: float,
    us_ratio: float,
    max_per_author: int,
) -> Candidates:
    K = 5 * number_of_articles  # retrieve 5x the needed documents from each query

    # 1. Retrieve candidates for each query
    candidate_pool: CandidatePool = {}
    for query_id, q in enumerate(queries):
        results = qdrant_search(query_id, q, start_date, end_date, K)
        for candidate in results:
            if candidate.doc_id not in candidate_pool:
                candidate_pool[candidate.doc_id] = candidate
            # If the doc appears multiple times, keep the higher score
            elif candidate.score > candidate_pool[candidate.doc_id].score:
                candidate_pool[candidate.doc_id] = candidate

    # 2. Sort candidates by relevance
    sorted_candidates = sorted(candidate_pool.values(), key=lambda c: c.score, reverse=True)

    # 3. Initial selection: take the top documents by relevance, respecting author limits.
    #    The limit is checked before counting, so a skipped document does not inflate
    #    the counts of its co-authors.
    available: Candidates = []
    skipped: Candidates = []
    author_counts: dict[str, int] = {}
    for c in sorted_candidates:
        authors = c.meta.authors or []
        if any(author_counts.get(author, 0) + 1 > max_per_author for author in authors):
            skipped.append(c)
        else:
            for author in authors:
                author_counts[author] = author_counts.get(author, 0) + 1
            available.append(c)
    if len(available) < number_of_articles:
        raise ValueError("Not enough documents to meet the max_per_author constraint")
    available += skipped

    # 4. Attempt adjustments to meet constraints approximately
    selected = available[:number_of_articles]
    available = available[number_of_articles:]
    seek_constraints(
        selected=selected,
        available=available,
        oped_target=int(op_ed_ratio * number_of_articles),
        us_target=int(us_ratio * number_of_articles),
        max_per_author=max_per_author,
        tolerance=int(0.05 * number_of_articles),
    )

    # `selected` now contains the chosen documents: within the date range, maximizing relevance,
    # approximately meeting the OpEd and US fractions, and respecting the max_per_author constraint.
    return selected
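A minimal sketch of driving the retrieval pipeline end to end. The queries and parameter values are illustrative placeholders; it assumes QDRANT_URL points at a "wsws" collection populated by loader.py, and that the Document model exposes the payload's title field:

# Illustrative parameters only; requires a populated "wsws" collection.
from article_embedding.retrieval import retrieve_documents

docs = retrieve_documents(
    queries=["COVID-19 vaccine rollout", "Delta variant wave"],
    start_date="2021-01-01",
    end_date="2022-01-01",
    number_of_articles=40,  # final selection size
    op_ed_ratio=0.25,       # aim for ~10 Perspective pieces
    us_ratio=0.5,           # aim for ~20 US-topic pieces
    max_per_author=3,       # hard cap enforced throughout
)
for d in docs:
    print(f"{d.score:.3f} {d.meta.title}")  # assumes a `title` field on Document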
article_embedding/sheets.py
ADDED
@@ -0,0 +1,7 @@
import os

import gspread

gc = gspread.auth.oauth(credentials_filename=os.environ["GOOGLE_CREDENTIALS"])
sh = gc.open("COVID-19 Compilation")
print(sh.sheet1.get("A1"))
article_embedding/utils.py
ADDED
@@ -0,0 +1,76 @@
"""Assorted utils."""

import json
import os
from types import EllipsisType
from typing import Any

import yaml

_FALSEY_STRINGS = ["", "0", "false", "no", "off", "n"]


def is_truthy(value: Any) -> bool:
    """Check if a value is truthy."""
    if isinstance(value, str):
        return value.lower() not in _FALSEY_STRINGS
    return bool(value)


def env_data(env_name: str, default: Any = ...) -> Any:
    """Load YAML/JSON data from a file named by an environment variable."""
    value = os.environ.get(env_name)
    if value is None:
        if default is ...:
            raise ValueError(f"Missing environment variable {env_name}")
        return default
    extension = os.path.splitext(value)[1]
    if extension not in [".yaml", ".yml", ".json"]:
        raise ValueError(f"Environment variable {env_name} must be a YAML/JSON file")
    try:
        with open(value, encoding="utf-8") as fd:
            if extension == ".json":
                return json.load(fd)
            return yaml.load(fd, Loader=yaml.CLoader)
    except Exception as e:
        raise ValueError(f"Environment variable {env_name}") from e


def env_str(env_name: str, default: str | EllipsisType = ...) -> str:
    """Get a string value from an environment variable."""
    value = os.environ.get(env_name)
    if value is None:
        if default is ...:
            raise ValueError(f"Missing environment variable {env_name}")
        return default
    return value


def env_bool(env_name: str, default: bool | EllipsisType = ...) -> bool:
    """Get a boolean value from an environment variable."""
    value = os.environ.get(env_name)
    if value is None:
        if default is ...:
            raise ValueError(f"Missing environment variable {env_name}")
        return default
    return is_truthy(value)


def env_int(env_name: str, default: int | EllipsisType = ...) -> int:
    """Get an int value from an environment variable."""
    value = os.environ.get(env_name)
    if value is None:
        if default is ...:
            raise ValueError(f"Missing environment variable {env_name}")
        return default
    return int(value)


def env_float(env_name: str, default: float | EllipsisType = ...) -> float:
    """Get a float value from an environment variable."""
    value = os.environ.get(env_name)
    if value is None:
        if default is ...:
            raise ValueError(f"Missing environment variable {env_name}")
        return default
    return float(value)
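A quick sketch of how these helpers behave; the variable names below are placeholders:

import os

os.environ["DEMO_FLAG"] = "No"    # hypothetical variable, for illustration
print(env_bool("DEMO_FLAG"))       # False: "no" is in _FALSEY_STRINGS
print(env_int("DEMO_PORT", 8080))  # 8080: falls back to the default
# env_int("DEMO_PORT") with no default would raise ValueError instead.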
notebooks/couchdb.http
ADDED
@@ -0,0 +1,39 @@
@url = {{$dotenv COUCHDB_URL}}
@user = {{$dotenv COUCHDB_USER}}
@password = {{$dotenv COUCHDB_PASSWORD}}
@db = {{$dotenv COUCHDB_DB}}
@origin = 0
@origin = 316723-g1AAAACReJzLYWBgYMpgTmHgzcvPy09JdcjLz8gvLskBCScyJNX___8_K4M5iYEpWyEXKMaekpyabGBuiK4ehwl5LECSoQFI_YcblCUMNsgsMS010dQSXVsWAD8FLCA
@batchSize = 10
@includeDocs = true

###
# @name session
POST {{url}}/_session
Content-Type: application/json

{
  "name": "{{user}}",
  "password": "{{password}}"
}

###
# @name changes
GET {{url}}/{{db}}/_changes
    ?limit={{batchSize}}
    &include_docs={{includeDocs}}
    &since={{origin}}

###
# @name changes
GET {{url}}/{{db}}/_changes
    ?limit={{batchSize}}
    &include_docs={{includeDocs}}
    &since={{changes.response.body.last_seq}}

###
# @name broken
GET {{url}}/{{db}}/_changes
    ?limit=10
    &include_docs=true
    &since=316731-g1AAAACReJzLYWBgYMpgTmHgzcvPy09JdcjLz8gvLskBCScyJNX___8_K4M5iYEpWyEXKMaekpyabGBuiK4ehwl5LECSoQFI_YcblCUNNsgsMS010dQSXVsWAD_1LCg
notebooks/models.ipynb
ADDED
@@ -0,0 +1,44 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from fastembed import TextEmbedding\n",
    "\n",
    "# pd.set_option('display.max_colwidth', 100)\n",
    "supported_models = (\n",
    "    pd.DataFrame(TextEmbedding.list_supported_models())\n",
    "    .sort_values(\"size_in_GB\")\n",
    "    .drop(columns=[\"sources\", \"model_file\", \"additional_files\"])\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "supported_models.style.set_properties(**{\"width\": \"300px\"}, subset=[\"description\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "article-embedding-p7H3l83p-py3.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
notebooks/query.ipynb
ADDED
@@ -0,0 +1,168 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "from qdrant_client import QdrantClient\n",
    "\n",
    "# StellaEmbedder wraps dunzhang/stella_en_400M_v5 with the project's standard settings\n",
    "from article_embedding.embed import StellaEmbedder\n",
    "\n",
    "warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n",
    "embedding_model = StellaEmbedder()\n",
    "qdrant = QdrantClient()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector = embedding_model.embed(\n",
    "    [\n",
    "        \"\"\"\n",
    "The COVID-19 pandemic is the worst public health crisis since the 1918 influenza pandemic over a century ago,\n",
    "when the science of virology was still in its infancy. By mid-July 2022, the official death toll since the\n",
    "start of the pandemic stood at over 1 million in the US and nearly 6.4 million globally. Estimates of excess\n",
    "deaths place the number of people killed directly or indirectly by the pandemic at over 1.2 million in the US\n",
    "and 21.5 million internationally. In just two-and-a-half years, total deaths from the pandemic have reached the\n",
    "estimated 20 million military and civilian deaths during the four years of World War I (1914-1918).\n",
    "\"\"\"\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector = embedding_model.embed(\n",
    "    [\n",
    "        \"\"\"\n",
    "Let us borrow Trotsky’s brilliant use of metaphor for the purpose of explaining the essential significance\n",
    "of the petty-bourgeois pseudo-left’s call for a “Party of the 99 percent.” On the cruise ship America, one\n",
    "percent of the passengers occupy state rooms that look out on the ocean. They dine with captain Trump in an\n",
    "exclusive five-star restaurant, where they wash down their succulent meals with wine that costs $10,000 per\n",
    "bottle. The \"next nine percent\" of the passengers, depending on what they can afford, make do with cabins\n",
    "whose quality reflects their lower price. The cheapest of the rooms in the top ten percent category lack a\n",
    "view out onto the ocean and have shabby rugs and uncomfortable mattresses. And, unless they are able and\n",
    "willing to pay a substantial surcharge, the occupants of these rooms are not permitted to use the pool and\n",
    "spa reserved for the richest passengers.\n",
    "\"\"\"\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector = embedding_model.embed(\n",
    "    [\n",
    "        \"\"\"\n",
    "Israel’s slaughter of tens if not hundreds of thousands of men, women and children, the destruction of entire\n",
    "cities, hospitals, schools, mosques, and the imposition of mass starvation on Gaza’s 2.3 million residents—all\n",
    "this has produced profound shock and anger towards the Netanyahu regime and its imperialist backers in the US,\n",
    "Britain, Germany and elsewhere. It has fueled the political radicalisation of millions of workers and young people,\n",
    "for whom Zionism is now synonymous with some of the worst atrocities committed anywhere since the Second World War.\n",
    "\"\"\"\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector = embedding_model.embed([\"The COVID winter wave, the emergence of the Delta variant and the January 6th coup\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate epoch seconds for 2023-01-01\n",
    "query_filter = None\n",
    "result = qdrant.query_points(\"wsws\", query=vector[0], query_filter=query_filter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for point in result.points:\n",
    "    doc = point.payload\n",
    "    print(f\"{point.score * 100:.1f}% https://www.wsws.org{doc[\"path\"]} - {doc[\"title\"]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(result.model_dump_json(indent=2))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
notebooks/stella.ipynb
ADDED
@@ -0,0 +1,67 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# This model supports two prompts: \"s2p_query\" and \"s2s_query\" for sentence-to-passage and sentence-to-sentence tasks, respectively.\n",
    "# They are defined in `config_sentence_transformers.json`\n",
    "query_prompt_name = \"s2p_query\"\n",
    "queries = [\n",
    "    \"What are some ways to reduce stress?\",\n",
    "    \"What are the benefits of drinking green tea?\",\n",
    "]\n",
    "# docs do not need any prompts\n",
    "docs = [\n",
    "    \"There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.\",  # noqa: E501\n",
    "    \"Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.\",  # noqa: E501\n",
    "]\n",
    "\n",
    "# !The default dimension is 1024; if you need other dimensions, please clone the model and modify `modules.json` to replace `2_Dense_1024` with another dimension, e.g. `2_Dense_256` or `2_Dense_8192`!  # noqa: E501\n",
    "# on gpu\n",
    "# model = SentenceTransformer(\"dunzhang/stella_en_400M_v5\", trust_remote_code=True).cuda()\n",
    "# you can also use this model without the `use_memory_efficient_attention` and `unpad_inputs` features, so it can run on CPU.\n",
    "model = SentenceTransformer(\n",
    "    \"dunzhang/stella_en_400M_v5\",\n",
    "    trust_remote_code=True,\n",
    "    device=\"mps\",\n",
    "    config_kwargs={\"use_memory_efficient_attention\": False, \"unpad_inputs\": False},\n",
    ")\n",
    "query_embeddings = model.encode(queries, prompt_name=query_prompt_name)\n",
    "doc_embeddings = model.encode(docs)\n",
    "print(query_embeddings.shape, doc_embeddings.shape)\n",
    "# (2, 1024) (2, 1024)\n",
    "\n",
    "similarities = model.similarity(query_embeddings, doc_embeddings)\n",
    "print(similarities)\n",
    "# tensor([[0.8398, 0.2990],\n",
    "#         [0.3282, 0.8095]])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
poetry.lock
ADDED
The diff for this file is too large to render.
poetry.toml
ADDED
@@ -0,0 +1,2 @@
[virtualenvs]
in-project = true
pyproject.toml
ADDED
@@ -0,0 +1,86 @@
[tool.poetry]
name = "article-embedding"
version = "0.1.0"
description = ""
authors = ["Jesus Lopez <jesus@jesusla.com>"]
readme = "README.md"

[tool.poetry.scripts]
article-embedding = "article_embedding.cli:cli"

[tool.poetry.dependencies]
python = "~3.12"
httpx = "^0.27.2"
backoff = "^2.2.1"
python-dotenv = "^1.0.1"
tqdm = "^4.67.0"
qdrant-client = { extras = ["fastembed"], version = "^1.12.1" }
sentence-transformers = "^3.3.1"
gspread = "^6.1.4"
gradio = "^5.8.0"
langchain-text-splitters = "^0.3.2"
click = "^8.1.7"
click-help-colors = "^0.9.4"
modal = "^0.68.23"
datasets = "^3.2.0"
einops = "^0.8.0"
torch = "^2.5.1"
pandoc = "^2.4"
openai = "^1.58.1"
pyyaml = "^6.0.2"
pydantic = "^2.10.4"


[tool.poetry.group.dev.dependencies]
pre-commit = "^4.0.1"
ipykernel = "^6.29.5"
ruff = "^0.7.4"
nbstripout = "^0.8.1"


[tool.poetry.group.typing.dependencies]
mypy = "^1.13.0"
types-tqdm = "^4.66.0.20240417"
types-pyyaml = "^6.0.12.20241221"


[tool.poetry.group.jupyter.dependencies]
pandas = "^2.2.3"
ipywidgets = "^8.1.5"
jinja2 = "^3.1.4"


[tool.poetry.group.ci.dependencies]
huggingface-hub = { extras = ["cli"], version = "^0.27.1" }
gradio = "^5.12.0"

[tool.ruff]
line-length = 132

[tool.ruff.lint]
select = [
    # pycodestyle
    "E",
    # Pyflakes
    "F",
    # pyupgrade
    "UP",
    # flake8-bugbear
    "B",
    # flake8-simplify
    "SIM",
    # isort
    "I",
]

[tool.mypy]
strict = true

[[tool.mypy.overrides]]
module = ["gradio", "pandoc"]
ignore_missing_imports = true


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
requirements.txt
ADDED
The diff for this file is too large to render.
tests/__init__.py
ADDED
File without changes