Upload folder using huggingface_hub
- .flake8 +6 -0
- .gitignore +180 -0
- .isort.cfg +2 -0
- .pre-commit-config.yaml +14 -0
- README.md +57 -1
- app/README.md +6 -0
- app/__init__.py +9 -0
- app/__main__.py +7 -0
- app/_config.py +62 -0
- app/components/__init__.py +0 -0
- app/components/embedding/__init__.py +0 -0
- app/components/embedding/component.py +38 -0
- app/components/ingest/__init__.py +0 -0
- app/components/ingest/component.py +143 -0
- app/components/ingest/helpers.py +61 -0
- app/components/llm/__init__.py +0 -0
- app/components/llm/component.py +50 -0
- app/components/node_store/__init__.py +0 -0
- app/components/node_store/component.py +31 -0
- app/components/vector_store/__init__.py +0 -0
- app/components/vector_store/component.py +51 -0
- app/enums.py +39 -0
- app/main.py +38 -0
- app/paths.py +15 -0
- app/server/__init__.py +0 -0
- app/server/chat/__init__.py +0 -0
- app/server/chat/router.py +70 -0
- app/server/chat/schemas.py +45 -0
- app/server/chat/service.py +122 -0
- app/server/chat/utils.py +68 -0
- app/server/embedding/__init__.py +0 -0
- app/server/embedding/router.py +18 -0
- app/server/embedding/schemas.py +19 -0
- app/server/embedding/service.py +18 -0
- app/server/ingest/__init__.py +0 -0
- app/server/ingest/schemas.py +32 -0
- app/server/ingest/service.py +123 -0
- app/ui/__init__.py +1 -0
- app/ui/dodge_ava.jpg +0 -0
- app/ui/schemas.py +27 -0
- app/ui/ui.py +228 -0
- docker-compose.yml +26 -0
- prestart.sh +5 -0
- pyproject.toml +31 -0
- setup.py +31 -0
.flake8
ADDED
@@ -0,0 +1,6 @@
+[flake8]
+ignore = E203, E501, W503
+max-line-length = 88
+max-complexity = 10
+per-file-ignores =
+    */__init__.py: F401, F403
.gitignore
ADDED
@@ -0,0 +1,180 @@
+# Created by https://www.toptal.com/developers/gitignore/api/python,pycharm
+# Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm
+
+### PyCharm ###
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+.idea/
+
+# AWS User-
+# CMake
+cmake-build-*/
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+poetry.lock
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# End of https://www.toptal.com/developers/gitignore/api/python,pycharm
+
+coverage_report/
+
+local_data/
+models/
+.DS_Store
.isort.cfg
ADDED
@@ -0,0 +1,2 @@
+[settings]
+profile = black
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,14 @@
+repos:
+  - repo: https://github.com/psf/black
+    rev: 23.3.0
+    hooks:
+      - id: black
+  - repo: https://github.com/PyCQA/flake8
+    rev: 6.0.0
+    hooks:
+      - id: flake8
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+
README.md
CHANGED
@@ -1,6 +1,62 @@
 ---
 title: discord-bot
-app_file: __main__.py
+app_file: app/__main__.py
 sdk: gradio
 sdk_version: 4.33.0
 ---
+# Capstone Project
+
+## Requirements
+
+* [Docker](https://www.docker.com/).
+* [Docker Compose](https://docs.docker.com/compose/install/).
+* [Poetry](https://python-poetry.org/) for Python package and environment management.
+
+
+## Installation
+
+### Set up virtual environment
+
+```shell
+python -m venv venv
+. venv/bin/activate
+```
+
+### Install dependencies
+
+- Install poetry: https://python-poetry.org/
+- Install dependencies
+
+```shell
+poetry install
+```
+
+#### Local Requirements
+```shell
+poetry install --with local
+```
+
+Download embedding and(or) LLM models
+```shell
+bash prestart.sh
+```
+
+### Install `pre-commit` hooks
+
+```shell
+pre-commit install
+```
+
+## Running
+```shell
+docker compose up -d
+poetry run python -m app
+```
+
+## URLs
+### Development URLs
+#### Gradio UI
+http://localhost:8000
+
+#### API Documentation
+http://localhost:8000/docs
app/README.md
ADDED
@@ -0,0 +1,6 @@
+---
+title: discord-bot
+app_file: __main__.py
+sdk: gradio
+sdk_version: 4.33.0
+---
app/__init__.py
ADDED
@@ -0,0 +1,9 @@
+import logging
+
+ROOT_LOG_LEVEL = "INFO"
+
+PRETTY_LOG_FORMAT = (
+    "%(asctime)s.%(msecs)03d [%(levelname)-8s] %(name)+25s - %(message)s"
+)
+logging.basicConfig(level=ROOT_LOG_LEVEL, format=PRETTY_LOG_FORMAT, datefmt="%H:%M:%S")
+logging.captureWarnings(True)
app/__main__.py
ADDED
@@ -0,0 +1,7 @@
+import uvicorn
+
+from app._config import settings
+from app.main import app
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=settings.PORT)
app/_config.py
ADDED
@@ -0,0 +1,62 @@
+import os
+from typing import Literal, Optional
+
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+
+class Settings(BaseSettings):
+    ENVIRONMENT: str
+    PORT: int = 8000
+    VECTOR_DATABASE: Literal["weaviate"] = "weaviate"
+
+    OPENAI_API_KEY: Optional[str] = None
+    OPENAI_MODEL: str = "gpt-3.5-turbo"
+
+    WEAVIATE_CLIENT_URL: str = "http://localhost:8080"
+
+    LLM_MODE: Literal["openai", "mock", "local"] = "mock"
+    EMBEDDING_MODE: Literal["openai", "mock", "local"] = "mock"
+
+    LOCAL_DATA_FOLDER: str = "local_data/test"
+
+    DEFAULT_QUERY_SYSTEM_PROMPT: str = "You can only answer questions about the provided context. If you know the answer but it is not based in the provided context, don't provide the answer, just state the answer is not in the context provided."
+
+    LOCAL_HF_EMBEDDING_MODEL_NAME: str = "BAAI/bge-small-en-v1.5"
+
+    LOCAL_HF_LLM_REPO_ID: str = "TheBloke/Llama-2-7B-Chat-GGUF"
+    LOCAL_HF_LLM_MODEL_FILE: str = "llama-2-7b-chat.Q4_K_M.gguf"
+
+    # LLM config
+    LLM_TEMPERATURE: float = Field(
+        default=0.1, description="The temperature to use for sampling."
+    )
+    LLM_MAX_NEW_TOKENS: int = Field(
+        default=256,
+        description="The maximum number of tokens to generate.",
+    )
+    LLM_CONTEXT_WINDOW: int = Field(
+        default=3900,
+        description="The maximum number of context tokens for the model.",
+    )
+
+    # UI
+    IS_UI_ENABLED: bool = True
+    UI_PATH: str = "/"
+
+    # Rerank
+    IS_RERANK_ENABLED: bool = True
+    RERANK_TOP_N: int = 3
+    RERANK_MODEL_NAME: str = "cross-encoder/ms-marco-MiniLM-L-2-v2"
+
+    class Config:
+        case_sensitive = True
+        env_file_encoding = "utf-8"
+
+
+environment = os.environ.get("ENVIRONMENT", "local")
+settings = Settings(
+    ENVIRONMENT=environment,
+    # ".env.{environment}" takes priority over ".env"
+    _env_file=[".env", f".env.{environment}"],
+)
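For reference, a minimal sketch of how the settings object defined above is consumed; the values and the hypothetical `.env.local` contents in the comments are illustrative only and not part of this commit:

```python
# Minimal usage sketch for the Settings object above (illustrative, not part of the commit).
# Values come from the process environment and from ".env" / ".env.<ENVIRONMENT>" files,
# e.g. a hypothetical ".env.local" could contain:
#   LLM_MODE=local
#   EMBEDDING_MODE=local
#   WEAVIATE_CLIENT_URL=http://localhost:8080
from app._config import settings

print(settings.ENVIRONMENT)  # "local" unless the ENVIRONMENT variable overrides it
print(settings.LLM_MODE)     # "mock" by default
print(settings.PORT)         # 8000 by default
```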
app/components/__init__.py
ADDED
File without changes
app/components/embedding/__init__.py
ADDED
File without changes
|
app/components/embedding/component.py
ADDED
@@ -0,0 +1,38 @@
+import logging
+
+from llama_index import MockEmbedding
+from llama_index.embeddings.base import BaseEmbedding
+
+from app._config import settings
+from app.enums import EmbeddingMode
+from app.paths import models_cache_path
+
+logger = logging.getLogger(__name__)
+
+MOCK_EMBEDDING_DIM = 1536
+
+
+class EmbeddingComponent:
+    embedding_model: BaseEmbedding
+
+    def __init__(self) -> None:
+        embedding_mode = settings.EMBEDDING_MODE
+        logger.info("Initializing the embedding model in mode=%s", embedding_mode)
+        match embedding_mode:
+            case EmbeddingMode.OPENAI:
+                from llama_index import OpenAIEmbedding
+
+                self.embedding_model = OpenAIEmbedding(api_key=settings.OPENAI_API_KEY)
+
+            case EmbeddingMode.MOCK:
+                # Not a random number, is the dimensionality used by
+                # the default embedding model
+                self.embedding_model = MockEmbedding(MOCK_EMBEDDING_DIM)
+
+            case EmbeddingMode.LOCAL:
+                from llama_index.embeddings import HuggingFaceEmbedding
+
+                self.embedding_model = HuggingFaceEmbedding(
+                    model_name=settings.LOCAL_HF_EMBEDDING_MODEL_NAME,
+                    cache_folder=str(models_cache_path),
+                )
app/components/ingest/__init__.py
ADDED
File without changes
app/components/ingest/component.py
ADDED
@@ -0,0 +1,143 @@
+import abc
+import logging
+import threading
+from pathlib import Path
+from typing import Any
+
+from llama_index import (
+    Document,
+    ServiceContext,
+    StorageContext,
+    VectorStoreIndex,
+    load_index_from_storage,
+)
+from llama_index.data_structs import IndexDict
+from llama_index.indices.base import BaseIndex
+
+from app.components.ingest.helpers import IngestionHelper
+from app.paths import local_data_path
+
+logger = logging.getLogger(__name__)
+
+
+class BaseIngestComponent(abc.ABC):
+    def __init__(
+        self,
+        storage_context: StorageContext,
+        service_context: ServiceContext,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        logger.debug(f"Initializing base ingest component type={type(self).__name__}")
+        self.storage_context = storage_context
+        self.service_context = service_context
+
+    @abc.abstractmethod
+    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+        pass
+
+    @abc.abstractmethod
+    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+        pass
+
+    @abc.abstractmethod
+    def delete(self, doc_id: str) -> None:
+        pass
+
+
+class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
+    def __init__(
+        self,
+        storage_context: StorageContext,
+        service_context: ServiceContext,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(storage_context, service_context, *args, **kwargs)
+
+        self.show_progress = True
+        self._index_thread_lock = (
+            threading.Lock()
+        )  # Thread lock! Not Multiprocessing lock
+        self._index = self._initialize_index()
+
+    def _initialize_index(self) -> BaseIndex[IndexDict]:
+        """Initialize the index from the storage context."""
+        try:
+            # Load the index with store_nodes_override=True to be able to delete them
+            index = load_index_from_storage(
+                storage_context=self.storage_context,
+                service_context=self.service_context,
+                store_nodes_override=True,  # Force store nodes in index and document stores
+                show_progress=self.show_progress,
+            )
+        except ValueError:
+            # There are no index in the storage context, creating a new one
+            logger.info("Creating a new vector store index")
+            index = VectorStoreIndex.from_documents(
+                [],
+                storage_context=self.storage_context,
+                service_context=self.service_context,
+                store_nodes_override=True,  # Force store nodes in index and document stores
+                show_progress=self.show_progress,
+            )
+            index.storage_context.persist(persist_dir=local_data_path)
+        return index
+
+    def _save_index(self) -> None:
+        self._index.storage_context.persist(persist_dir=local_data_path)
+
+    def delete(self, doc_id: str) -> None:
+        with self._index_thread_lock:
+            # Delete the document from the index
+            self._index.delete_ref_doc(doc_id, delete_from_docstore=True)
+
+            # Save the index
+            self._save_index()
+
+
+class SimpleIngestComponent(BaseIngestComponentWithIndex):
+    def __init__(
+        self,
+        storage_context: StorageContext,
+        service_context: ServiceContext,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(storage_context, service_context, *args, **kwargs)
+
+    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+        logger.info("Ingesting file_name=%s", file_name)
+        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+        logger.info(
+            "Transformed file=%s into count=%s documents", file_name, len(documents)
+        )
+        logger.debug("Saving the documents in the index and doc store")
+        return self._save_docs(documents)
+
+    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+        saved_documents = []
+        for file_name, file_data in files:
+            documents = IngestionHelper.transform_file_into_documents(
+                file_name, file_data
+            )
+            saved_documents.extend(self._save_docs(documents))
+        return saved_documents
+
+    def _save_docs(self, documents: list[Document]) -> list[Document]:
+        logger.debug("Transforming count=%s documents into nodes", len(documents))
+        with self._index_thread_lock:
+            for document in documents:
+                self._index.insert(document, show_progress=True)
+            logger.debug("Persisting the index and nodes")
+            # persist the index and nodes
+            self._save_index()
+            logger.debug("Persisted the index and nodes")
+        return documents
+
+
+def get_ingestion_component(
+    storage_context: StorageContext,
+    service_context: ServiceContext,
+) -> BaseIngestComponent:
+    return SimpleIngestComponent(storage_context, service_context)
app/components/ingest/helpers.py
ADDED
@@ -0,0 +1,61 @@
+import logging
+from pathlib import Path
+
+from llama_index import Document
+from llama_index.readers import JSONReader, StringIterableReader
+from llama_index.readers.file.base import DEFAULT_FILE_READER_CLS
+
+logger = logging.getLogger(__name__)
+
+# Patching the default file reader to support other file types
+FILE_READER_CLS = DEFAULT_FILE_READER_CLS.copy()
+FILE_READER_CLS.update(
+    {
+        ".json": JSONReader,
+    }
+)
+
+
+class IngestionHelper:
+    """Helper class to transform a file into a list of documents.
+
+    This class should be used to transform a file into a list of documents.
+    These methods are thread-safe (and multiprocessing-safe).
+    """
+
+    @staticmethod
+    def transform_file_into_documents(
+        file_name: str, file_data: Path
+    ) -> list[Document]:
+        documents = IngestionHelper._load_file_to_documents(file_name, file_data)
+        for document in documents:
+            document.metadata["file_name"] = file_name
+        IngestionHelper._exclude_metadata(documents)
+        return documents
+
+    @staticmethod
+    def _load_file_to_documents(file_name: str, file_data: Path) -> list[Document]:
+        logger.debug("Transforming file_name=%s into documents", file_name)
+        extension = Path(file_name).suffix
+        reader_cls = FILE_READER_CLS.get(extension)
+        if reader_cls is None:
+            logger.debug(
+                "No reader found for extension=%s, using default string reader",
+                extension,
+            )
+            # Read as a plain text
+            string_reader = StringIterableReader()
+            return string_reader.load_data([file_data.read_text()])
+
+        logger.debug("Specific reader found for extension=%s", extension)
+        return reader_cls().load_data(file_data)
+
+    @staticmethod
+    def _exclude_metadata(documents: list[Document]) -> None:
+        logger.debug("Excluding metadata from count=%s documents", len(documents))
+        for document in documents:
+            document.metadata["doc_id"] = document.doc_id
+            # We don't want the Embeddings search to receive this metadata
+            document.excluded_embed_metadata_keys = ["doc_id"]
+            # We don't want the LLM to receive these metadata in the context
+            document.excluded_llm_metadata_keys = ["file_name", "doc_id", "page_label"]
app/components/llm/__init__.py
ADDED
File without changes
app/components/llm/component.py
ADDED
@@ -0,0 +1,50 @@
+import logging
+
+from llama_index.llms import LLM, MockLLM
+
+from app._config import settings
+from app.enums import LLMMode
+from app.paths import models_path
+
+logger = logging.getLogger(__name__)
+
+
+class LLMComponent:
+    llm: LLM
+
+    def __init__(self) -> None:
+        llm_mode = settings.LLM_MODE
+        logger.info(f"Initializing the LLM in mode={llm_mode}")
+        match settings.LLM_MODE:
+            case LLMMode.OPENAI:
+                from llama_index.llms import OpenAI
+
+                self.llm = OpenAI(
+                    api_key=settings.OPENAI_API_KEY,
+                    model=settings.OPENAI_MODEL,
+                )
+            case LLMMode.MOCK:
+                self.llm = MockLLM()
+
+            case LLMMode.LOCAL:
+                from llama_index.llms import LlamaCPP
+                from llama_index.llms.llama_utils import (
+                    completion_to_prompt,
+                    messages_to_prompt,
+                )
+
+                self.llm = LlamaCPP(
+                    model_path=str(models_path / settings.LOCAL_HF_LLM_MODEL_FILE),
+                    temperature=settings.LLM_TEMPERATURE,
+                    max_new_tokens=settings.LLM_MAX_NEW_TOKENS,
+                    context_window=settings.LLM_CONTEXT_WINDOW,
+                    generate_kwargs={},
+                    # set to at least 1 to use GPU
+                    # set to -1 for all gpu
+                    # set to 0 for cpu
+                    model_kwargs={"n_gpu_layers": 0},
+                    # transform inputs into Llama2 format
+                    messages_to_prompt=messages_to_prompt,
+                    completion_to_prompt=completion_to_prompt,
+                    verbose=True,
+                )
app/components/node_store/__init__.py
ADDED
File without changes
app/components/node_store/component.py
ADDED
@@ -0,0 +1,31 @@
+import logging
+
+from llama_index.storage.docstore import BaseDocumentStore, SimpleDocumentStore
+from llama_index.storage.index_store import SimpleIndexStore
+from llama_index.storage.index_store.types import BaseIndexStore
+
+from app.paths import local_data_path
+
+logger = logging.getLogger(__name__)
+
+
+class NodeStoreComponent:
+    index_store: BaseIndexStore
+    doc_store: BaseDocumentStore
+
+    def __init__(self) -> None:
+        try:
+            self.index_store = SimpleIndexStore.from_persist_dir(
+                persist_dir=str(local_data_path)
+            )
+        except FileNotFoundError:
+            logger.debug("Local index store not found, creating a new one")
+            self.index_store = SimpleIndexStore()
+
+        try:
+            self.doc_store = SimpleDocumentStore.from_persist_dir(
+                persist_dir=str(local_data_path)
+            )
+        except FileNotFoundError:
+            logger.debug("Local document store not found, creating a new one")
+            self.doc_store = SimpleDocumentStore()
app/components/vector_store/__init__.py
ADDED
File without changes
app/components/vector_store/component.py
ADDED
@@ -0,0 +1,51 @@
+import logging
+import typing
+
+from llama_index import VectorStoreIndex
+from llama_index.indices.vector_store import VectorIndexRetriever
+from llama_index.vector_stores.types import VectorStore
+
+from app._config import settings
+from app.enums import WEAVIATE_INDEX_NAME, VectorDatabase
+
+logger = logging.getLogger(__name__)
+
+
+class VectorStoreComponent:
+    vector_store: VectorStore
+
+    def __init__(self) -> None:
+        match settings.VECTOR_DATABASE:
+            case VectorDatabase.WEAVIATE:
+                import weaviate
+                from llama_index.vector_stores import WeaviateVectorStore
+
+                client = weaviate.Client(settings.WEAVIATE_CLIENT_URL)
+                self.vector_store = typing.cast(
+                    VectorStore,
+                    WeaviateVectorStore(
+                        weaviate_client=client, index_name=WEAVIATE_INDEX_NAME
+                    ),
+                )
+            case _:
+                # Should be unreachable
+                # The settings validator should have caught this
+                raise ValueError(
+                    f"Vectorstore database {settings.VECTOR_DATABASE} not supported"
+                )
+
+    @staticmethod
+    def get_retriever(
+        index: VectorStoreIndex,
+        doc_ids: list[str] | None = None,
+        similarity_top_k: int = 2,
+    ) -> VectorIndexRetriever:
+        return VectorIndexRetriever(
+            index=index,
+            similarity_top_k=similarity_top_k,
+            doc_ids=doc_ids,
+        )
+
+    def close(self) -> None:
+        if hasattr(self.vector_store.client, "close"):
+            self.vector_store.client.close()
app/enums.py
ADDED
@@ -0,0 +1,39 @@
+from enum import Enum, auto, unique
+from pathlib import Path
+
+PROJECT_ROOT_PATH: Path = Path(__file__).parents[1]
+
+
+@unique
+class BaseEnum(str, Enum):
+    @staticmethod
+    def _generate_next_value_(name: str, *_):
+        """
+        Automatically generate values for enum.
+        Enum values are lower-cased enum member names.
+        """
+        return name.lower()
+
+    @classmethod
+    def get_values(cls) -> list[str]:
+        # noinspection PyUnresolvedReferences
+        return [m.value for m in cls]
+
+
+class LLMMode(BaseEnum):
+    MOCK = auto()
+    OPENAI = auto()
+    LOCAL = auto()
+
+
+class EmbeddingMode(BaseEnum):
+    MOCK = auto()
+    OPENAI = auto()
+    LOCAL = auto()
+
+
+class VectorDatabase(BaseEnum):
+    WEAVIATE = auto()
+
+
+WEAVIATE_INDEX_NAME = "LlamaIndex"
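Because `BaseEnum` lower-cases member names via `_generate_next_value_`, the enum values line up with the string literals used in `app/_config.py`. A small illustrative check, not part of the commit:

```python
# Illustrative check of the enum behaviour defined above (not part of the commit).
from app.enums import LLMMode, VectorDatabase

assert LLMMode.OPENAI.value == "openai"          # auto() -> lower-cased member name
assert LLMMode.get_values() == ["mock", "openai", "local"]
assert VectorDatabase.WEAVIATE == "weaviate"     # str-based enum compares equal to its value
```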
app/main.py
ADDED
@@ -0,0 +1,38 @@
+import logging
+
+from fastapi import FastAPI
+
+from app._config import settings
+from app.components.embedding.component import EmbeddingComponent
+from app.components.llm.component import LLMComponent
+from app.components.node_store.component import NodeStoreComponent
+from app.components.vector_store.component import VectorStoreComponent
+from app.server.chat.router import chat_router
+from app.server.chat.service import ChatService
+from app.server.embedding.router import embedding_router
+from app.server.ingest.service import IngestService
+
+logger = logging.getLogger(__name__)
+
+app = FastAPI()
+app.include_router(chat_router)
+app.include_router(embedding_router)
+
+if settings.IS_UI_ENABLED:
+    logger.debug("Importing the UI module")
+    from app.ui.ui import PrivateGptUi
+
+    llm_component = LLMComponent()
+    vector_store_component = VectorStoreComponent()
+    embedding_component = EmbeddingComponent()
+    node_store_component = NodeStoreComponent()
+
+    ingest_service = IngestService(
+        llm_component, vector_store_component, embedding_component, node_store_component
+    )
+    chat_service = ChatService(
+        llm_component, vector_store_component, embedding_component, node_store_component
+    )
+
+    ui = PrivateGptUi(ingest_service, chat_service)
+    ui.mount_in_app(app, settings.UI_PATH)
app/paths.py
ADDED
@@ -0,0 +1,15 @@
+from pathlib import Path
+
+from app._config import settings
+from app.enums import PROJECT_ROOT_PATH
+
+
+def _absolute_or_from_project_root(path: str) -> Path:
+    if path.startswith("/"):
+        return Path(path)
+    return PROJECT_ROOT_PATH / path
+
+
+local_data_path: Path = _absolute_or_from_project_root(settings.LOCAL_DATA_FOLDER)
+models_path: Path = PROJECT_ROOT_PATH / "models"
+models_cache_path: Path = models_path / "cache"
app/server/__init__.py
ADDED
File without changes
app/server/chat/__init__.py
ADDED
File without changes
app/server/chat/router.py
ADDED
@@ -0,0 +1,70 @@
+from fastapi import APIRouter
+from llama_index.llms import ChatMessage, MessageRole
+from pydantic import BaseModel
+
+from app.components.embedding.component import EmbeddingComponent
+from app.components.llm.component import LLMComponent
+from app.components.node_store.component import NodeStoreComponent
+from app.components.vector_store.component import VectorStoreComponent
+from app.server.chat.service import ChatService
+from app.server.chat.utils import OpenAICompletion, OpenAIMessage, to_openai_response
+
+chat_router = APIRouter()
+
+
+class ChatBody(BaseModel):
+    messages: list[OpenAIMessage]
+    include_sources: bool = True
+
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "messages": [
+                        {
+                            "role": "system",
+                            "content": "You are a rapper. Always answer with a rap.",
+                        },
+                        {
+                            "role": "user",
+                            "content": "How do you fry an egg?",
+                        },
+                    ],
+                    "include_sources": True,
+                }
+            ]
+        }
+    }
+
+
+@chat_router.post(
+    "/chat",
+    response_model=None,
+    responses={200: {"model": OpenAICompletion}},
+    tags=["Contextual Completions"],
+)
+def chat_completion(body: ChatBody) -> OpenAICompletion:
+    """Given a list of messages comprising a conversation, return a response.
+
+    Optionally include an initial `role: system` message to influence the way
+    the LLM answers.
+
+    When using `'include_sources': true`, the API will return the source Chunks used
+    to create the response, which come from the context provided.
+    """
+    llm_component = LLMComponent()
+    vector_store_component = VectorStoreComponent()
+    embedding_component = EmbeddingComponent()
+    node_store_component = NodeStoreComponent()
+
+    chat_service = ChatService(
+        llm_component, vector_store_component, embedding_component, node_store_component
+    )
+    all_messages = [
+        ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
+    ]
+
+    completion = chat_service.chat(messages=all_messages)
+    return to_openai_response(
+        completion.response, completion.sources if body.include_sources else None
+    )
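A minimal sketch of calling the `/chat` endpoint above once the server is running; the URL assumes the default port from `settings.PORT`, the payload mirrors the example in `ChatBody.model_config`, and `requests` is an assumed helper dependency not declared by this commit:

```python
import requests  # assumed to be available; not declared by this commit

payload = {
    "messages": [
        {"role": "system", "content": "You are a rapper. Always answer with a rap."},
        {"role": "user", "content": "How do you fry an egg?"},
    ],
    "include_sources": True,
}
# POST to the chat router mounted on the FastAPI app.
response = requests.post("http://localhost:8000/chat", json=payload, timeout=60)
completion = response.json()
print(completion["choices"][0]["message"]["content"])
```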
app/server/chat/schemas.py
ADDED
@@ -0,0 +1,45 @@
+from typing import Literal
+
+from llama_index.schema import NodeWithScore
+from pydantic import BaseModel, Field
+
+from app.server.ingest.schemas import IngestedDoc
+
+
+class Chunk(BaseModel):
+    object: Literal["context.chunk"]
+    score: float = Field(examples=[0.023])
+    document: IngestedDoc
+    text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
+    previous_texts: list[str] | None = Field(
+        default=None,
+        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
+    )
+    next_texts: list[str] | None = Field(
+        default=None,
+        examples=[
+            [
+                "New leads came from Google Ads campaign.",
+                "The campaign was run by the Marketing Department",
+            ]
+        ],
+    )
+
+    @classmethod
+    def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
+        doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
+        return cls(
+            object="context.chunk",
+            score=node.score or 0.0,
+            document=IngestedDoc(
+                object="ingest.document",
+                doc_id=doc_id,
+                doc_metadata=node.metadata,
+            ),
+            text=node.get_content(),
+        )
+
+
+class Completion(BaseModel):
+    response: str
+    sources: list[Chunk] | None = None
app/server/chat/service.py
ADDED
@@ -0,0 +1,122 @@
+from dataclasses import dataclass
+
+from llama_index import ServiceContext, StorageContext, VectorStoreIndex
+from llama_index.chat_engine import ContextChatEngine
+from llama_index.chat_engine.types import BaseChatEngine
+from llama_index.core.postprocessor import SentenceTransformerRerank
+from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
+from llama_index.llms import ChatMessage, MessageRole
+
+from app._config import settings
+from app.components.embedding.component import EmbeddingComponent
+from app.components.llm.component import LLMComponent
+from app.components.node_store.component import NodeStoreComponent
+from app.components.vector_store.component import VectorStoreComponent
+from app.server.chat.schemas import Chunk, Completion
+
+
+@dataclass
+class ChatEngineInput:
+    system_message: ChatMessage | None = None
+    last_message: ChatMessage | None = None
+    chat_history: list[ChatMessage] | None = None
+
+    @classmethod
+    def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
+        # Detect if there is a system message, extract the last message and chat history
+        system_message = (
+            messages[0]
+            if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
+            else None
+        )
+        last_message = (
+            messages[-1]
+            if len(messages) > 0 and messages[-1].role == MessageRole.USER
+            else None
+        )
+        # Remove from messages list the system message and last message,
+        # if they exist. The rest is the chat history.
+        if system_message:
+            messages.pop(0)
+        if last_message:
+            messages.pop(-1)
+        chat_history = messages if len(messages) > 0 else None
+
+        return cls(
+            system_message=system_message,
+            last_message=last_message,
+            chat_history=chat_history,
+        )
+
+
+class ChatService:
+    def __init__(
+        self,
+        llm_component: LLMComponent,
+        vector_store_component: VectorStoreComponent,
+        embedding_component: EmbeddingComponent,
+        node_store_component: NodeStoreComponent,
+    ) -> None:
+        self.llm_service = llm_component
+        self.vector_store_component = vector_store_component
+        self.storage_context = StorageContext.from_defaults(
+            vector_store=vector_store_component.vector_store,
+            docstore=node_store_component.doc_store,
+            index_store=node_store_component.index_store,
+        )
+        self.service_context = ServiceContext.from_defaults(
+            llm=llm_component.llm, embed_model=embedding_component.embedding_model
+        )
+        self.index = VectorStoreIndex.from_vector_store(
+            vector_store_component.vector_store,
+            storage_context=self.storage_context,
+            service_context=self.service_context,
+            show_progress=True,
+        )
+
+    def _chat_engine(self, system_prompt: str | None = None) -> BaseChatEngine:
+        vector_index_retriever = self.vector_store_component.get_retriever(
+            index=self.index
+        )
+
+        node_postprocessors = [
+            MetadataReplacementPostProcessor(target_metadata_key="window")
+        ]
+        if settings.IS_RERANK_ENABLED:
+            rerank = SentenceTransformerRerank(
+                top_n=settings.RERANK_TOP_N, model=settings.RERANK_MODEL_NAME
+            )
+            node_postprocessors.append(rerank)
+
+        return ContextChatEngine.from_defaults(
+            system_prompt=system_prompt,
+            retriever=vector_index_retriever,
+            service_context=self.service_context,
+            node_postprocessors=node_postprocessors,
+        )
+
+    def chat(self, messages: list[ChatMessage]):
+        chat_engine_input = ChatEngineInput.from_messages(messages)
+        last_message = (
+            chat_engine_input.last_message.content
+            if chat_engine_input.last_message
+            else None
+        )
+        system_prompt = (
+            chat_engine_input.system_message.content
+            if chat_engine_input.system_message
+            else None
+        )
+        chat_history = (
+            chat_engine_input.chat_history if chat_engine_input.chat_history else None
+        )
+
+        chat_engine = self._chat_engine(system_prompt=system_prompt)
+        wrapped_response = chat_engine.chat(
+            message=last_message if last_message is not None else "",
+            chat_history=chat_history,
+        )
+
+        sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
+        completion = Completion(response=wrapped_response.response, sources=sources)
+        return completion
app/server/chat/utils.py
ADDED
@@ -0,0 +1,68 @@
+import time
+import uuid
+from typing import Literal
+
+from llama_index.llms import ChatResponse
+from pydantic import BaseModel, Field
+
+from app.server.chat.schemas import Chunk
+
+
+class OpenAIMessage(BaseModel):
+    """Inference result, with the source of the message.
+
+    Role could be the assistant or system
+    (providing a default response, not AI generated).
+    """
+
+    role: Literal["assistant", "system", "user"] = Field(default="user")
+    content: str | None
+
+
+class OpenAIChoice(BaseModel):
+    """Response from AI."""
+
+    finish_reason: str | None = Field(examples=["stop"])
+    message: OpenAIMessage | None = None
+    sources: list[Chunk] | None = None
+    index: int = 0
+
+
+class OpenAICompletion(BaseModel):
+    """Clone of OpenAI Completion model.
+
+    For more information see: https://platform.openai.com/docs/api-reference/chat/object
+    """
+
+    id: str
+    object: Literal["completion", "completion.chunk"] = Field(default="completion")
+    created: int = Field(..., examples=[1623340000])
+    model: Literal["llm-agriculture"]
+    choices: list[OpenAIChoice]
+
+    @classmethod
+    def from_text(
+        cls,
+        text: str | None,
+        finish_reason: str | None = None,
+        sources: list[Chunk] | None = None,
+    ) -> "OpenAICompletion":
+        return OpenAICompletion(
+            id=str(uuid.uuid4()),
+            object="completion",
+            created=int(time.time()),
+            model="llm-agriculture",
+            choices=[
+                OpenAIChoice(
+                    message=OpenAIMessage(role="assistant", content=text),
+                    finish_reason=finish_reason,
+                    sources=sources,
+                )
+            ],
+        )
+
+
+def to_openai_response(
+    response: str | ChatResponse, sources: list[Chunk] | None = None
+) -> OpenAICompletion:
+    return OpenAICompletion.from_text(response, finish_reason="stop", sources=sources)
app/server/embedding/__init__.py
ADDED
File without changes
app/server/embedding/router.py
ADDED
@@ -0,0 +1,18 @@
+from fastapi import APIRouter
+
+from app.components.embedding.component import EmbeddingComponent
+from app.server.embedding.schemas import EmbeddingsBody, EmbeddingsResponse
+from app.server.embedding.service import EmbeddingsService
+
+embedding_router = APIRouter()
+
+
+@embedding_router.post("/embedding", tags=["Embeddings"])
+def generate_embeddings(body: EmbeddingsBody) -> EmbeddingsResponse:
+    embedding_component = EmbeddingComponent()
+    service = EmbeddingsService(embedding_component)
+    input_texts = body.input if isinstance(body.input, list) else [body.input]
+    embeddings = service.embed_texts(input_texts)
+    return EmbeddingsResponse(
+        object="list", model=service.embedding_model.model_name, data=embeddings
+    )
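Similarly, a minimal sketch of calling the `/embedding` endpoint above (default port assumed; `requests` is an assumed helper dependency, and the sample text is made up):

```python
import requests  # assumed to be available; not declared by this commit

resp = requests.post(
    "http://localhost:8000/embedding",
    json={"input": ["What fertilizer suits sandy soil?"]},  # illustrative sample text
    timeout=60,
)
data = resp.json()
# EmbeddingsResponse shape: {"object": "list", "model": "...", "data": [{"index": 0, ...}]}
print(len(data["data"][0]["embedding"]))
```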
app/server/embedding/schemas.py
ADDED
@@ -0,0 +1,19 @@
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class Embedding(BaseModel):
+    index: int
+    object: Literal["embedding"]
+    embedding: list[float] = Field(examples=[[0.1, -0.2]])
+
+
+class EmbeddingsBody(BaseModel):
+    input: str | list[str]
+
+
+class EmbeddingsResponse(BaseModel):
+    object: Literal["list"]
+    model: str
+    data: list[Embedding]
app/server/embedding/service.py
ADDED
@@ -0,0 +1,18 @@
+from app.components.embedding.component import EmbeddingComponent
+from app.server.embedding.schemas import Embedding
+
+
+class EmbeddingsService:
+    def __init__(self, embedding_component: EmbeddingComponent) -> None:
+        self.embedding_model = embedding_component.embedding_model
+
+    def embed_texts(self, texts: list[str]) -> list[Embedding]:
+        texts_embeddings = self.embedding_model.get_text_embedding_batch(texts)
+        return [
+            Embedding(
+                index=texts_embeddings.index(embedding),
+                object="embedding",
+                embedding=embedding,
+            )
+            for embedding in texts_embeddings
+        ]
app/server/ingest/__init__.py
ADDED
File without changes
app/server/ingest/schemas.py
ADDED
@@ -0,0 +1,32 @@
+from typing import Any, Literal
+
+from llama_index import Document
+from pydantic import BaseModel, Field
+
+
+class IngestedDoc(BaseModel):
+    object: Literal["ingest.document"]
+    doc_id: str = Field(examples=["c202d5e6-7b69-4869-81cc-dd574ee8ee11"])
+    doc_metadata: dict[str, Any] | None = Field(
+        examples=[
+            {
+                "page_label": "2",
+                "file_name": "agriculture.pdf",
+            }
+        ]
+    )
+
+    @staticmethod
+    def curate_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
+        """Remove unwanted metadata keys."""
+        for key in ["doc_id", "window", "original_text"]:
+            metadata.pop(key, None)
+        return metadata
+
+    @staticmethod
+    def from_document(document: Document) -> "IngestedDoc":
+        return IngestedDoc(
+            object="ingest.document",
+            doc_id=document.doc_id,
+            doc_metadata=IngestedDoc.curate_metadata(document.metadata),
+        )
app/server/ingest/service.py
ADDED
@@ -0,0 +1,123 @@
+import logging
+import tempfile
+from pathlib import Path
+from typing import AnyStr, BinaryIO
+
+from llama_index import ServiceContext, StorageContext
+from llama_index.node_parser import SentenceWindowNodeParser
+
+from app.components.embedding.component import EmbeddingComponent
+from app.components.ingest.component import get_ingestion_component
+from app.components.llm.component import LLMComponent
+from app.components.node_store.component import NodeStoreComponent
+from app.components.vector_store.component import VectorStoreComponent
+from app.server.ingest.schemas import IngestedDoc
+
+logger = logging.getLogger(__name__)
+
+
+class IngestService:
+    def __init__(
+        self,
+        llm_component: LLMComponent,
+        vector_store_component: VectorStoreComponent,
+        embedding_component: EmbeddingComponent,
+        node_store_component: NodeStoreComponent,
+    ) -> None:
+        self.llm_service = llm_component
+        self.storage_context = StorageContext.from_defaults(
+            vector_store=vector_store_component.vector_store,
+            docstore=node_store_component.doc_store,
+            index_store=node_store_component.index_store,
+        )
+        node_parser = SentenceWindowNodeParser.from_defaults()
+        self.ingest_service_context = ServiceContext.from_defaults(
+            llm=self.llm_service.llm,
+            embed_model=embedding_component.embedding_model,
+            node_parser=node_parser,
+            # Embeddings done early in the pipeline of node transformations, right
+            # after the node parsing
+            transformations=[node_parser, embedding_component.embedding_model],
+        )
+
+        self.ingest_component = get_ingestion_component(
+            self.storage_context, self.ingest_service_context
+        )
+
+    def _ingest_data(self, file_name: str, file_data: AnyStr) -> list[IngestedDoc]:
+        logger.debug(f"Got file data of size={len(file_data)} to ingest")
+        # llama-index mainly supports reading from files, so
+        # we have to create a tmp file to read for it to work
+        # delete=False to avoid a Windows 11 permission error.
+        with tempfile.NamedTemporaryFile(delete=False) as tmp:
+            try:
+                path_to_tmp = Path(tmp.name)
+                if isinstance(file_data, bytes):
+                    path_to_tmp.write_bytes(file_data)
+                else:
+                    path_to_tmp.write_text(str(file_data))
+                return self.ingest_file(file_name, path_to_tmp)
+            finally:
+                tmp.close()
+                path_to_tmp.unlink()
+
+    def ingest_file(self, file_name: str, file_data: Path) -> list[IngestedDoc]:
+        logger.info(f"Ingesting file_name={file_name}")
+        documents = self.ingest_component.ingest(file_name, file_data)
+        logger.info(f"Finished ingestion file_name={file_name}")
+        return [IngestedDoc.from_document(document) for document in documents]
+
+    def ingest_text(self, file_name: str, text: str) -> list[IngestedDoc]:
+        logger.debug(f"Ingesting text data with file_name={file_name}")
+        return self._ingest_data(file_name, text)
+
+    def ingest_bin_data(
+        self, file_name: str, raw_file_data: BinaryIO
+    ) -> list[IngestedDoc]:
+        logger.debug(f"Ingesting binary data with file_name={file_name}")
+        file_data = raw_file_data.read()
+        return self._ingest_data(file_name, file_data)
+
+    def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[IngestedDoc]:
+        logger.info(f"Ingesting file_names={[f[0] for f in files]}")
+        documents = self.ingest_component.bulk_ingest(files)
+        logger.info(f"Finished ingestion file_name={[f[0] for f in files]}")
+        return [IngestedDoc.from_document(document) for document in documents]
+
+    def list_ingested(self) -> list[IngestedDoc]:
+        ingested_docs = []
+        try:
+            docstore = self.storage_context.docstore
+            ingested_docs_ids: set[str] = set()
+
+            for node in docstore.docs.values():
+                if node.ref_doc_id is not None:
+                    ingested_docs_ids.add(node.ref_doc_id)
+
+            for doc_id in ingested_docs_ids:
+                ref_doc_info = docstore.get_ref_doc_info(ref_doc_id=doc_id)
+                doc_metadata = None
+                if ref_doc_info is not None and ref_doc_info.metadata is not None:
+                    doc_metadata = IngestedDoc.curate_metadata(ref_doc_info.metadata)
+                ingested_docs.append(
+                    IngestedDoc(
+                        object="ingest.document",
+                        doc_id=doc_id,
+                        doc_metadata=doc_metadata,
+                    )
+                )
+        except ValueError:
+            logger.warning("Got an exception when getting list of docs", exc_info=True)
+            pass
+        logger.debug(f"Found count={len(ingested_docs)} ingested documents")
+        return ingested_docs
+
+    def delete(self, doc_id: str) -> None:
+        """Delete an ingested document.
+
+        :raises ValueError: if the document does not exist
+        """
+        logger.info(
+            "Deleting the ingested document=%s in the doc and index store", doc_id
+        )
+        self.ingest_component.delete(doc_id)
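A sketch of how `IngestService` above could be wired up and used directly, mirroring the construction in `app/main.py`; the file name and sample text are made up for illustration:

```python
# Illustrative wiring only; mirrors app/main.py. Sample file name and text are made up.
from app.components.embedding.component import EmbeddingComponent
from app.components.llm.component import LLMComponent
from app.components.node_store.component import NodeStoreComponent
from app.components.vector_store.component import VectorStoreComponent
from app.server.ingest.service import IngestService

service = IngestService(
    LLMComponent(), VectorStoreComponent(), EmbeddingComponent(), NodeStoreComponent()
)
docs = service.ingest_text("notes.txt", "Crop rotation improves soil health.")
print([d.doc_id for d in docs])
print(f"{len(service.list_ingested())} documents ingested so far")
```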
app/ui/__init__.py
ADDED
@@ -0,0 +1 @@
+"""Gradio based UI."""
app/ui/dodge_ava.jpg
ADDED
app/ui/schemas.py
ADDED
@@ -0,0 +1,27 @@
+from pydantic import BaseModel
+
+from app.server.chat.schemas import Chunk
+
+
+class Source(BaseModel):
+    file: str
+    page: str
+    text: str
+
+    class Config:
+        frozen = True
+
+    @staticmethod
+    def curate_sources(sources: list[Chunk]) -> set["Source"]:
+        curated_sources = set()
+
+        for chunk in sources:
+            doc_metadata = chunk.document.doc_metadata
+
+            file_name = doc_metadata.get("file_name", "-") if doc_metadata else "-"
+            page_label = doc_metadata.get("page_label", "-") if doc_metadata else "-"
+
+            source = Source(file=file_name, page=page_label, text=chunk.text)
+            curated_sources.add(source)
+
+        return curated_sources
app/ui/ui.py
ADDED
@@ -0,0 +1,228 @@
"""This file should be imported only if you want to run the UI locally."""
import itertools
import logging
from pathlib import Path
from typing import Any

import gradio as gr
from fastapi import FastAPI
from gradio.themes.utils.colors import slate
from llama_index.llms import ChatMessage, MessageRole

from app._config import settings
from app.components.embedding.component import EmbeddingComponent
from app.components.llm.component import LLMComponent
from app.components.node_store.component import NodeStoreComponent
from app.components.vector_store.component import VectorStoreComponent
from app.enums import PROJECT_ROOT_PATH
from app.server.chat.service import ChatService
from app.server.ingest.service import IngestService
from app.ui.schemas import Source

logger = logging.getLogger(__name__)

THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "dodge_ava.jpg"

UI_TAB_TITLE = "Agriculture Chatbot"

SOURCES_SEPARATOR = "\n\n Sources: \n"


class PrivateGptUi:
    def __init__(
        self,
        ingest_service: IngestService,
        chat_service: ChatService,
    ) -> None:
        self._ingest_service = ingest_service
        self._chat_service = chat_service

        # Cache the UI blocks
        self._ui_block = None

        # Initialize system prompt
        self._system_prompt = self._get_default_system_prompt()

    def _chat(self, message: str, history: list[list[str]], *_: Any) -> Any:
        def build_history() -> list[ChatMessage]:
            history_messages: list[ChatMessage] = list(
                itertools.chain(
                    *[
                        [
                            ChatMessage(content=interaction[0], role=MessageRole.USER),
                            ChatMessage(
                                # Remove the Sources information from the history content
                                content=interaction[1].split(SOURCES_SEPARATOR)[0],
                                role=MessageRole.ASSISTANT,
                            ),
                        ]
                        for interaction in history
                    ]
                )
            )

            # max 20 messages to try to avoid context overflow
            return history_messages[:20]

        new_message = ChatMessage(content=message, role=MessageRole.USER)
        all_messages = [*build_history(), new_message]
        # If a system prompt is set, add it as a system message
        if self._system_prompt:
            all_messages.insert(
                0,
                ChatMessage(
                    content=self._system_prompt,
                    role=MessageRole.SYSTEM,
                ),
            )

        completion = self._chat_service.chat(messages=all_messages)
        full_response = completion.response

        if completion.sources:
            full_response += SOURCES_SEPARATOR
            curated_sources = Source.curate_sources(completion.sources)
            sources_text = "\n\n\n".join(
                f"{index}. {source.file} (page {source.page})"
                for index, source in enumerate(curated_sources, start=1)
            )
            full_response += sources_text

        return full_response

    # On initialization this function sets the system prompt
    # to the default prompt based on settings.
    @staticmethod
    def _get_default_system_prompt() -> str:
        return settings.DEFAULT_QUERY_SYSTEM_PROMPT

    def _set_system_prompt(self, system_prompt_input: str) -> None:
        logger.info(f"Setting system prompt to: {system_prompt_input}")
        self._system_prompt = system_prompt_input

    def _list_ingested_files(self) -> list[list[str]]:
        files = set()
        for ingested_document in self._ingest_service.list_ingested():
            if ingested_document.doc_metadata is None:
                # Skipping documents without metadata
                continue
            file_name = ingested_document.doc_metadata.get(
                "file_name", "[FILE NAME MISSING]"
            )
            files.add(file_name)
        return [[row] for row in files]

    def _upload_file(self, files: list[str]) -> None:
        logger.debug("Loading count=%s files", len(files))
        paths = [Path(file) for file in files]
        self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])

    def _build_ui_blocks(self) -> gr.Blocks:
        logger.debug("Creating the UI blocks")
        with gr.Blocks(
            title=UI_TAB_TITLE,
            theme=gr.themes.Soft(primary_hue=slate),
            css=".logo { "
            "display:flex;"
            "height: 80px;"
            "border-radius: 8px;"
            "align-content: center;"
            "justify-content: center;"
            "align-items: center;"
            "}"
            ".logo img { height: 25% }"
            ".contain { display: flex !important; flex-direction: column !important; }"
            "#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
            "#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
            "#col { height: calc(100vh - 112px - 16px) !important; }",
        ) as blocks:
            with gr.Row():
                gr.HTML(f"<div class='logo'><h1>{UI_TAB_TITLE}</h1></div>")

            with gr.Row(equal_height=False):
                with gr.Column(scale=3):
                    upload_button = gr.components.UploadButton(
                        "Upload File(s)",
                        type="filepath",
                        file_count="multiple",
                        size="sm",
                    )
                    ingested_dataset = gr.List(
                        self._list_ingested_files,
                        headers=["File name"],
                        label="Ingested Files",
                        interactive=False,
                        render=False,  # Rendered under the button
                    )
                    upload_button.upload(
                        self._upload_file,
                        inputs=upload_button,
                        outputs=ingested_dataset,
                    )
                    ingested_dataset.change(
                        self._list_ingested_files,
                        outputs=ingested_dataset,
                    )
                    ingested_dataset.render()
                    system_prompt_input = gr.Textbox(
                        placeholder=self._system_prompt,
                        label="System Prompt",
                        lines=2,
                        interactive=True,
                        render=False,
                    )

                    # On blur, set system prompt to use in queries
                    system_prompt_input.blur(
                        self._set_system_prompt,
                        inputs=system_prompt_input,
                    )

                with gr.Column(scale=7, elem_id="col"):
                    _ = gr.ChatInterface(
                        self._chat,
                        chatbot=gr.Chatbot(
                            label=f"LLM: {settings.LLM_MODE}",
                            show_copy_button=True,
                            elem_id="chatbot",
                            render=False,
                            avatar_images=(
                                None,
                                AVATAR_BOT,
                            ),
                        ),
                        additional_inputs=[upload_button, system_prompt_input],
                    )
        return blocks

    def get_ui_blocks(self) -> gr.Blocks:
        if self._ui_block is None:
            self._ui_block = self._build_ui_blocks()
        return self._ui_block

    def mount_in_app(self, app: FastAPI, path: str) -> None:
        blocks = self.get_ui_blocks()
        blocks.queue()
        logger.info("Mounting the gradio UI, at path=%s", path)
        gr.mount_gradio_app(app, blocks, path=path)


if __name__ == "__main__":
    llm_component = LLMComponent()
    vector_store_component = VectorStoreComponent()
    embedding_component = EmbeddingComponent()
    node_store_component = NodeStoreComponent()

    ingest_service = IngestService(
        llm_component, vector_store_component, embedding_component, node_store_component
    )
    chat_service = ChatService(
        llm_component, vector_store_component, embedding_component, node_store_component
    )

    ui = PrivateGptUi(ingest_service, chat_service)

    _blocks = ui.get_ui_blocks()
    _blocks.queue()
    _blocks.launch(debug=False, show_api=False)
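Besides the __main__ launcher above, PrivateGptUi.mount_in_app lets the Gradio blocks be served from an existing FastAPI application, which is presumably how app/main.py exposes the UI. A hedged sketch, not part of the commit (the mount path and wiring are assumptions):

from fastapi import FastAPI

from app.components.embedding.component import EmbeddingComponent
from app.components.llm.component import LLMComponent
from app.components.node_store.component import NodeStoreComponent
from app.components.vector_store.component import VectorStoreComponent
from app.server.chat.service import ChatService
from app.server.ingest.service import IngestService
from app.ui.ui import PrivateGptUi

llm = LLMComponent()
vectors = VectorStoreComponent()
embeddings = EmbeddingComponent()
nodes = NodeStoreComponent()

app = FastAPI()
ui = PrivateGptUi(
    IngestService(llm, vectors, embeddings, nodes),
    ChatService(llm, vectors, embeddings, nodes),
)
ui.mount_in_app(app, path="/ui")  # then serve the FastAPI app with uvicorn as usual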
docker-compose.yml
ADDED
@@ -0,0 +1,26 @@
version: '3.4'
services:
  weaviate:
    command:
    - --host
    - 0.0.0.0
    - --port
    - '8080'
    - --scheme
    - http
    image: semitechnologies/weaviate:1.23.0
    ports:
    - 8080:8080
    - 50051:50051
    volumes:
    - weaviate_data:/var/lib/weaviate
    restart: on-failure:0
    environment:
      QUERY_DEFAULTS_LIMIT: 25
      AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
      PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
      DEFAULT_VECTORIZER_MODULE: 'none'
      CLUSTER_HOSTNAME: 'node1'

volumes:
  weaviate_data:
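The compose file runs a single Weaviate 1.23 node with anonymous access enabled and no server-side vectorizer, since embeddings come from the application itself; data persists in the weaviate_data volume. A quick readiness check, not part of the commit and assuming the default localhost:8080 mapping, using the v3 weaviate-client pinned in pyproject.toml:

import weaviate

# After `docker compose up -d weaviate`, the instance should report ready.
client = weaviate.Client("http://localhost:8080")
print(client.is_ready())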
prestart.sh
ADDED
@@ -0,0 +1,5 @@
#! /usr/bin/env bash

# Download Embedding model and LLM model
python setup.py
pyproject.toml
ADDED
@@ -0,0 +1,31 @@
[tool.poetry]
name = "capstone"
version = "0.1.0"
description = ""
authors = ["PhucVu <vuhongphuc24601@gmail.com>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
llama-index = "^0.9.22"
weaviate-client = "^3.26.0"
pydantic-settings = "^2.1.0"
fastapi = "^0.108.0"
uvicorn = "^0.25.0"
pydantic = "^2.5.3"
gradio = "^4.12.0"

# reranker
torch = {version="^2.3.0", optional=true}
sentence-transformers = {version="^2.7.0", optional=true}

[tool.poetry.group.local]
optional = true
[tool.poetry.group.local.dependencies]
transformers = "^4.36.2"
torch = "^2.1.2"
llama-cpp-python = "^0.2.29"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
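A note on the dependency layout: torch and sentence-transformers are marked optional in the main group (they back the reranker), while the optional local group adds transformers, torch, and llama-cpp-python for fully local inference, presumably installed with poetry install --with local. As written there is no [tool.poetry.extras] table referencing the optional reranker packages, so they would need an extras entry (or a manual install) to be pulled in.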
setup.py
ADDED
@@ -0,0 +1,31 @@
import logging
import os

from huggingface_hub import hf_hub_download, snapshot_download

from app._config import settings
from app.paths import models_cache_path, models_path

os.makedirs(models_path, exist_ok=True)

logger = logging.getLogger(__name__)

# Download embedding model
embedding_path = models_path / "embedding"
logger.info(f"Downloading embedding: {settings.LOCAL_HF_EMBEDDING_MODEL_NAME}")
snapshot_download(
    repo_id=settings.LOCAL_HF_EMBEDDING_MODEL_NAME,
    cache_dir=models_cache_path,
    local_dir=embedding_path,
)
logger.info("Embedding model downloaded")

# Download LLM model
logger.info(f"Downloading LLM: {settings.LOCAL_HF_LLM_MODEL_FILE}")
hf_hub_download(
    repo_id=settings.LOCAL_HF_LLM_REPO_ID,
    filename=settings.LOCAL_HF_LLM_MODEL_FILE,
    cache_dir=models_cache_path,
    local_dir=models_path,
)
logger.info("LLM model downloaded")
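setup.py is the script prestart.sh runs: it fetches the embedding model snapshot and the LLM weights into models_path before the server starts. Purely as an illustration, not part of the commit (the app itself presumably loads the model through LLMComponent, and the exact file location depends on settings), the downloaded model file could be opened directly with llama-cpp-python from the optional local group:

from llama_cpp import Llama

from app._config import settings
from app.paths import models_path

# Assumes hf_hub_download placed the model file directly under models_path.
llm = Llama(model_path=str(models_path / settings.LOCAL_HF_LLM_MODEL_FILE), n_ctx=2048)
out = llm("Q: Why rotate crops? A:", max_tokens=64)
print(out["choices"][0]["text"])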