khointn committed
Commit cca4857
1 Parent(s): 5a67683

Upload folder using huggingface_hub

.flake8 ADDED
@@ -0,0 +1,6 @@
+ [flake8]
+ ignore = E203, E501, W503
+ max-line-length = 88
+ max-complexity = 10
+ per-file-ignores =
+     */__init__.py: F401, F403
.gitignore ADDED
@@ -0,0 +1,180 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm
+
+ ### PyCharm ###
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+ .idea/
+
+ # AWS User-
+ # CMake
+ cmake-build-*/
+
+ # File-based project format
+ *.iws
+
+ # IntelliJ
+ out/
+
+ # mpeltonen/sbt-idea plugin
+ .idea_modules/
+
+ # JIRA plugin
+ atlassian-ide-plugin.xml
+
+ # Crashlytics plugin (for Android Studio and IntelliJ)
+ com_crashlytics_export_strings.xml
+ crashlytics.properties
+ crashlytics-build.properties
+ fabric.properties
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ poetry.lock
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # End of https://www.toptal.com/developers/gitignore/api/python,pycharm
+
+ coverage_report/
+
+ local_data/
+ models/
+ .DS_Store
.isort.cfg ADDED
@@ -0,0 +1,2 @@
+ [settings]
+ profile = black
.pre-commit-config.yaml ADDED
@@ -0,0 +1,14 @@
+ repos:
+   - repo: https://github.com/psf/black
+     rev: 23.3.0
+     hooks:
+       - id: black
+   - repo: https://github.com/PyCQA/flake8
+     rev: 6.0.0
+     hooks:
+       - id: flake8
+   - repo: https://github.com/PyCQA/isort
+     rev: 5.12.0
+     hooks:
+       - id: isort
+
README.md CHANGED
@@ -1,6 +1,62 @@
  ---
  title: discord-bot
- app_file: __main__.py
+ app_file: app/__main__.py
  sdk: gradio
  sdk_version: 4.33.0
  ---
+ # Capstone Project
+
+ ## Requirements
+
+ * [Docker](https://www.docker.com/).
+ * [Docker Compose](https://docs.docker.com/compose/install/).
+ * [Poetry](https://python-poetry.org/) for Python package and environment management.
+
+
+ ## Installation
+
+ ### Set up virtual environment
+
+ ```shell
+ python -m venv venv
+ . venv/bin/activate
+ ```
+
+ ### Install dependencies
+
+ - Install Poetry: https://python-poetry.org/
+ - Install dependencies
+
+ ```shell
+ poetry install
+ ```
+
+ #### Local Requirements
+ ```shell
+ poetry install --with local
+ ```
+
+ Download the embedding and/or LLM models
+ ```shell
+ bash prestart.sh
+ ```
+
+ ### Install `pre-commit` hooks
+
+ ```shell
+ pre-commit install
+ ```
+
+ ## Running
+ ```shell
+ docker compose up -d
+ poetry run python -m app
+ ```
+
+ ## URLs
+ ### Development URLs
+ #### Gradio UI
+ http://localhost:8000
+
+ #### API Documentation
+ http://localhost:8000/docs
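For reference, a minimal sketch of calling the `/chat` endpoint added in `app/server/chat/router.py`, assuming the server is running locally on the default port; the payload mirrors the example schema declared on `ChatBody`, and only the standard library is used:

```python
import json
import urllib.request

# Same shape as the ChatBody example in app/server/chat/router.py
payload = {
    "messages": [
        {"role": "system", "content": "You are a rapper. Always answer with a rap."},
        {"role": "user", "content": "How do you fry an egg?"},
    ],
    "include_sources": True,
}
req = urllib.request.Request(
    "http://localhost:8000/chat",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read()))
```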
app/README.md ADDED
@@ -0,0 +1,6 @@
+ ---
+ title: discord-bot
+ app_file: __main__.py
+ sdk: gradio
+ sdk_version: 4.33.0
+ ---
app/__init__.py ADDED
@@ -0,0 +1,9 @@
+ import logging
+
+ ROOT_LOG_LEVEL = "INFO"
+
+ PRETTY_LOG_FORMAT = (
+     "%(asctime)s.%(msecs)03d [%(levelname)-8s] %(name)+25s - %(message)s"
+ )
+ logging.basicConfig(level=ROOT_LOG_LEVEL, format=PRETTY_LOG_FORMAT, datefmt="%H:%M:%S")
+ logging.captureWarnings(True)
app/__main__.py ADDED
@@ -0,0 +1,7 @@
+ import uvicorn
+
+ from app._config import settings
+ from app.main import app
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=settings.PORT)
app/_config.py ADDED
@@ -0,0 +1,62 @@
+ import os
+ from typing import Literal, Optional
+
+ from pydantic import Field
+ from pydantic_settings import BaseSettings
+
+
+ class Settings(BaseSettings):
+     ENVIRONMENT: str
+     PORT: int = 8000
+     VECTOR_DATABASE: Literal["weaviate"] = "weaviate"
+
+     OPENAI_API_KEY: Optional[str] = None
+     OPENAI_MODEL: str = "gpt-3.5-turbo"
+
+     WEAVIATE_CLIENT_URL: str = "http://localhost:8080"
+
+     LLM_MODE: Literal["openai", "mock", "local"] = "mock"
+     EMBEDDING_MODE: Literal["openai", "mock", "local"] = "mock"
+
+     LOCAL_DATA_FOLDER: str = "local_data/test"
+
+     DEFAULT_QUERY_SYSTEM_PROMPT: str = "You can only answer questions about the provided context. If you know the answer but it is not based in the provided context, don't provide the answer, just state the answer is not in the context provided."
+
+     LOCAL_HF_EMBEDDING_MODEL_NAME: str = "BAAI/bge-small-en-v1.5"
+
+     LOCAL_HF_LLM_REPO_ID: str = "TheBloke/Llama-2-7B-Chat-GGUF"
+     LOCAL_HF_LLM_MODEL_FILE: str = "llama-2-7b-chat.Q4_K_M.gguf"
+
+     # LLM config
+     LLM_TEMPERATURE: float = Field(
+         default=0.1, description="The temperature to use for sampling."
+     )
+     LLM_MAX_NEW_TOKENS: int = Field(
+         default=256,
+         description="The maximum number of tokens to generate.",
+     )
+     LLM_CONTEXT_WINDOW: int = Field(
+         default=3900,
+         description="The maximum number of context tokens for the model.",
+     )
+
+     # UI
+     IS_UI_ENABLED: bool = True
+     UI_PATH: str = "/"
+
+     # Rerank
+     IS_RERANK_ENABLED: bool = True
+     RERANK_TOP_N: int = 3
+     RERANK_MODEL_NAME: str = "cross-encoder/ms-marco-MiniLM-L-2-v2"
+
+     class Config:
+         case_sensitive = True
+         env_file_encoding = "utf-8"
+
+
+ environment = os.environ.get("ENVIRONMENT", "local")
+ settings = Settings(
+     ENVIRONMENT=environment,
+     # ".env.{environment}" takes priority over ".env"
+     _env_file=[".env", f".env.{environment}"],
+ )
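A small sketch of how these settings resolve in practice (assuming the package is importable; the `.env` / `.env.local` files are optional, and declared defaults apply when nothing overrides them):

```python
import os

# ENVIRONMENT selects which extra env file (".env.<environment>") is layered on top of ".env"
os.environ.setdefault("ENVIRONMENT", "local")

from app._config import settings

print(settings.ENVIRONMENT, settings.LLM_MODE, settings.WEAVIATE_CLIENT_URL)
```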
app/components/__init__.py ADDED
File without changes
app/components/embedding/__init__.py ADDED
File without changes
app/components/embedding/component.py ADDED
@@ -0,0 +1,38 @@
+ import logging
+
+ from llama_index import MockEmbedding
+ from llama_index.embeddings.base import BaseEmbedding
+
+ from app._config import settings
+ from app.enums import EmbeddingMode
+ from app.paths import models_cache_path
+
+ logger = logging.getLogger(__name__)
+
+ MOCK_EMBEDDING_DIM = 1536
+
+
+ class EmbeddingComponent:
+     embedding_model: BaseEmbedding
+
+     def __init__(self) -> None:
+         embedding_mode = settings.EMBEDDING_MODE
+         logger.info("Initializing the embedding model in mode=%s", embedding_mode)
+         match embedding_mode:
+             case EmbeddingMode.OPENAI:
+                 from llama_index import OpenAIEmbedding
+
+                 self.embedding_model = OpenAIEmbedding(api_key=settings.OPENAI_API_KEY)
+
+             case EmbeddingMode.MOCK:
+                 # Not a random number, is the dimensionality used by
+                 # the default embedding model
+                 self.embedding_model = MockEmbedding(MOCK_EMBEDDING_DIM)
+
+             case EmbeddingMode.LOCAL:
+                 from llama_index.embeddings import HuggingFaceEmbedding
+
+                 self.embedding_model = HuggingFaceEmbedding(
+                     model_name=settings.LOCAL_HF_EMBEDDING_MODEL_NAME,
+                     cache_folder=str(models_cache_path),
+                 )
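A quick usage sketch, assuming `EMBEDDING_MODE` is left at its `"mock"` default; the component exposes a plain llama-index `BaseEmbedding`, so its standard methods apply:

```python
from app.components.embedding.component import EmbeddingComponent

embeddings = EmbeddingComponent()
vector = embeddings.embedding_model.get_text_embedding("hello world")
print(len(vector))  # 1536 for MockEmbedding, i.e. MOCK_EMBEDDING_DIM
```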
app/components/ingest/__init__.py ADDED
File without changes
app/components/ingest/component.py ADDED
@@ -0,0 +1,143 @@
+ import abc
+ import logging
+ import threading
+ from pathlib import Path
+ from typing import Any
+
+ from llama_index import (
+     Document,
+     ServiceContext,
+     StorageContext,
+     VectorStoreIndex,
+     load_index_from_storage,
+ )
+ from llama_index.data_structs import IndexDict
+ from llama_index.indices.base import BaseIndex
+
+ from app.components.ingest.helpers import IngestionHelper
+ from app.paths import local_data_path
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseIngestComponent(abc.ABC):
+     def __init__(
+         self,
+         storage_context: StorageContext,
+         service_context: ServiceContext,
+         *args: Any,
+         **kwargs: Any,
+     ) -> None:
+         logger.debug(f"Initializing base ingest component type={type(self).__name__}")
+         self.storage_context = storage_context
+         self.service_context = service_context
+
+     @abc.abstractmethod
+     def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+         pass
+
+     @abc.abstractmethod
+     def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+         pass
+
+     @abc.abstractmethod
+     def delete(self, doc_id: str) -> None:
+         pass
+
+
+ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
+     def __init__(
+         self,
+         storage_context: StorageContext,
+         service_context: ServiceContext,
+         *args: Any,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(storage_context, service_context, *args, **kwargs)
+
+         self.show_progress = True
+         self._index_thread_lock = (
+             threading.Lock()
+         )  # Thread lock! Not Multiprocessing lock
+         self._index = self._initialize_index()
+
+     def _initialize_index(self) -> BaseIndex[IndexDict]:
+         """Initialize the index from the storage context."""
+         try:
+             # Load the index with store_nodes_override=True to be able to delete them
+             index = load_index_from_storage(
+                 storage_context=self.storage_context,
+                 service_context=self.service_context,
+                 store_nodes_override=True,  # Force store nodes in index and document stores
+                 show_progress=self.show_progress,
+             )
+         except ValueError:
+             # There are no index in the storage context, creating a new one
+             logger.info("Creating a new vector store index")
+             index = VectorStoreIndex.from_documents(
+                 [],
+                 storage_context=self.storage_context,
+                 service_context=self.service_context,
+                 store_nodes_override=True,  # Force store nodes in index and document stores
+                 show_progress=self.show_progress,
+             )
+             index.storage_context.persist(persist_dir=local_data_path)
+         return index
+
+     def _save_index(self) -> None:
+         self._index.storage_context.persist(persist_dir=local_data_path)
+
+     def delete(self, doc_id: str) -> None:
+         with self._index_thread_lock:
+             # Delete the document from the index
+             self._index.delete_ref_doc(doc_id, delete_from_docstore=True)
+
+             # Save the index
+             self._save_index()
+
+
+ class SimpleIngestComponent(BaseIngestComponentWithIndex):
+     def __init__(
+         self,
+         storage_context: StorageContext,
+         service_context: ServiceContext,
+         *args: Any,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(storage_context, service_context, *args, **kwargs)
+
+     def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+         logger.info("Ingesting file_name=%s", file_name)
+         documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+         logger.info(
+             "Transformed file=%s into count=%s documents", file_name, len(documents)
+         )
+         logger.debug("Saving the documents in the index and doc store")
+         return self._save_docs(documents)
+
+     def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+         saved_documents = []
+         for file_name, file_data in files:
+             documents = IngestionHelper.transform_file_into_documents(
+                 file_name, file_data
+             )
+             saved_documents.extend(self._save_docs(documents))
+         return saved_documents
+
+     def _save_docs(self, documents: list[Document]) -> list[Document]:
+         logger.debug("Transforming count=%s documents into nodes", len(documents))
+         with self._index_thread_lock:
+             for document in documents:
+                 self._index.insert(document, show_progress=True)
+             logger.debug("Persisting the index and nodes")
+             # persist the index and nodes
+             self._save_index()
+             logger.debug("Persisted the index and nodes")
+         return documents
+
+
+ def get_ingestion_component(
+     storage_context: StorageContext,
+     service_context: ServiceContext,
+ ) -> BaseIngestComponent:
+     return SimpleIngestComponent(storage_context, service_context)
app/components/ingest/helpers.py ADDED
@@ -0,0 +1,61 @@
+ import logging
+ from pathlib import Path
+
+ from llama_index import Document
+ from llama_index.readers import JSONReader, StringIterableReader
+ from llama_index.readers.file.base import DEFAULT_FILE_READER_CLS
+
+ logger = logging.getLogger(__name__)
+
+ # Patching the default file reader to support other file types
+ FILE_READER_CLS = DEFAULT_FILE_READER_CLS.copy()
+ FILE_READER_CLS.update(
+     {
+         ".json": JSONReader,
+     }
+ )
+
+
+ class IngestionHelper:
+     """Helper class to transform a file into a list of documents.
+
+     This class should be used to transform a file into a list of documents.
+     These methods are thread-safe (and multiprocessing-safe).
+     """
+
+     @staticmethod
+     def transform_file_into_documents(
+         file_name: str, file_data: Path
+     ) -> list[Document]:
+         documents = IngestionHelper._load_file_to_documents(file_name, file_data)
+         for document in documents:
+             document.metadata["file_name"] = file_name
+         IngestionHelper._exclude_metadata(documents)
+         return documents
+
+     @staticmethod
+     def _load_file_to_documents(file_name: str, file_data: Path) -> list[Document]:
+         logger.debug("Transforming file_name=%s into documents", file_name)
+         extension = Path(file_name).suffix
+         reader_cls = FILE_READER_CLS.get(extension)
+         if reader_cls is None:
+             logger.debug(
+                 "No reader found for extension=%s, using default string reader",
+                 extension,
+             )
+             # Read as a plain text
+             string_reader = StringIterableReader()
+             return string_reader.load_data([file_data.read_text()])
+
+         logger.debug("Specific reader found for extension=%s", extension)
+         return reader_cls().load_data(file_data)
+
+     @staticmethod
+     def _exclude_metadata(documents: list[Document]) -> None:
+         logger.debug("Excluding metadata from count=%s documents", len(documents))
+         for document in documents:
+             document.metadata["doc_id"] = document.doc_id
+             # We don't want the Embeddings search to receive this metadata
+             document.excluded_embed_metadata_keys = ["doc_id"]
+             # We don't want the LLM to receive these metadata in the context
+             document.excluded_llm_metadata_keys = ["file_name", "doc_id", "page_label"]
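A sketch of the reader dispatch with a hypothetical text file: a `.json` extension is routed to `JSONReader`, any unregistered extension falls back to the plain-text `StringIterableReader`:

```python
from pathlib import Path

from app.components.ingest.helpers import IngestionHelper

# Hypothetical file path; .txt has no registered reader, so it is read as plain text
docs = IngestionHelper.transform_file_into_documents("notes.txt", Path("/tmp/notes.txt"))
print(len(docs), docs[0].metadata["file_name"] if docs else None)
```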
app/components/llm/__init__.py ADDED
File without changes
app/components/llm/component.py ADDED
@@ -0,0 +1,50 @@
+ import logging
+
+ from llama_index.llms import LLM, MockLLM
+
+ from app._config import settings
+ from app.enums import LLMMode
+ from app.paths import models_path
+
+ logger = logging.getLogger(__name__)
+
+
+ class LLMComponent:
+     llm: LLM
+
+     def __init__(self) -> None:
+         llm_mode = settings.LLM_MODE
+         logger.info(f"Initializing the LLM in mode={llm_mode}")
+         match settings.LLM_MODE:
+             case LLMMode.OPENAI:
+                 from llama_index.llms import OpenAI
+
+                 self.llm = OpenAI(
+                     api_key=settings.OPENAI_API_KEY,
+                     model=settings.OPENAI_MODEL,
+                 )
+             case LLMMode.MOCK:
+                 self.llm = MockLLM()
+
+             case LLMMode.LOCAL:
+                 from llama_index.llms import LlamaCPP
+                 from llama_index.llms.llama_utils import (
+                     completion_to_prompt,
+                     messages_to_prompt,
+                 )
+
+                 self.llm = LlamaCPP(
+                     model_path=str(models_path / settings.LOCAL_HF_LLM_MODEL_FILE),
+                     temperature=settings.LLM_TEMPERATURE,
+                     max_new_tokens=settings.LLM_MAX_NEW_TOKENS,
+                     context_window=settings.LLM_CONTEXT_WINDOW,
+                     generate_kwargs={},
+                     # set to at least 1 to use GPU
+                     # set to -1 for all gpu
+                     # set to 0 for cpu
+                     model_kwargs={"n_gpu_layers": 0},
+                     # transform inputs into Llama2 format
+                     messages_to_prompt=messages_to_prompt,
+                     completion_to_prompt=completion_to_prompt,
+                     verbose=True,
+                 )
app/components/node_store/__init__.py ADDED
File without changes
app/components/node_store/component.py ADDED
@@ -0,0 +1,31 @@
+ import logging
+
+ from llama_index.storage.docstore import BaseDocumentStore, SimpleDocumentStore
+ from llama_index.storage.index_store import SimpleIndexStore
+ from llama_index.storage.index_store.types import BaseIndexStore
+
+ from app.paths import local_data_path
+
+ logger = logging.getLogger(__name__)
+
+
+ class NodeStoreComponent:
+     index_store: BaseIndexStore
+     doc_store: BaseDocumentStore
+
+     def __init__(self) -> None:
+         try:
+             self.index_store = SimpleIndexStore.from_persist_dir(
+                 persist_dir=str(local_data_path)
+             )
+         except FileNotFoundError:
+             logger.debug("Local index store not found, creating a new one")
+             self.index_store = SimpleIndexStore()
+
+         try:
+             self.doc_store = SimpleDocumentStore.from_persist_dir(
+                 persist_dir=str(local_data_path)
+             )
+         except FileNotFoundError:
+             logger.debug("Local document store not found, creating a new one")
+             self.doc_store = SimpleDocumentStore()
app/components/vector_store/__init__.py ADDED
File without changes
app/components/vector_store/component.py ADDED
@@ -0,0 +1,51 @@
+ import logging
+ import typing
+
+ from llama_index import VectorStoreIndex
+ from llama_index.indices.vector_store import VectorIndexRetriever
+ from llama_index.vector_stores.types import VectorStore
+
+ from app._config import settings
+ from app.enums import WEAVIATE_INDEX_NAME, VectorDatabase
+
+ logger = logging.getLogger(__name__)
+
+
+ class VectorStoreComponent:
+     vector_store: VectorStore
+
+     def __init__(self) -> None:
+         match settings.VECTOR_DATABASE:
+             case VectorDatabase.WEAVIATE:
+                 import weaviate
+                 from llama_index.vector_stores import WeaviateVectorStore
+
+                 client = weaviate.Client(settings.WEAVIATE_CLIENT_URL)
+                 self.vector_store = typing.cast(
+                     VectorStore,
+                     WeaviateVectorStore(
+                         weaviate_client=client, index_name=WEAVIATE_INDEX_NAME
+                     ),
+                 )
+             case _:
+                 # Should be unreachable
+                 # The settings validator should have caught this
+                 raise ValueError(
+                     f"Vectorstore database {settings.VECTOR_DATABASE} not supported"
+                 )
+
+     @staticmethod
+     def get_retriever(
+         index: VectorStoreIndex,
+         doc_ids: list[str] | None = None,
+         similarity_top_k: int = 2,
+     ) -> VectorIndexRetriever:
+         return VectorIndexRetriever(
+             index=index,
+             similarity_top_k=similarity_top_k,
+             doc_ids=doc_ids,
+         )
+
+     def close(self) -> None:
+         if hasattr(self.vector_store.client, "close"):
+             self.vector_store.client.close()
app/enums.py ADDED
@@ -0,0 +1,39 @@
+ from enum import Enum, auto, unique
+ from pathlib import Path
+
+ PROJECT_ROOT_PATH: Path = Path(__file__).parents[1]
+
+
+ @unique
+ class BaseEnum(str, Enum):
+     @staticmethod
+     def _generate_next_value_(name: str, *_):
+         """
+         Automatically generate values for enum.
+         Enum values are lower-cased enum member names.
+         """
+         return name.lower()
+
+     @classmethod
+     def get_values(cls) -> list[str]:
+         # noinspection PyUnresolvedReferences
+         return [m.value for m in cls]
+
+
+ class LLMMode(BaseEnum):
+     MOCK = auto()
+     OPENAI = auto()
+     LOCAL = auto()
+
+
+ class EmbeddingMode(BaseEnum):
+     MOCK = auto()
+     OPENAI = auto()
+     LOCAL = auto()
+
+
+ class VectorDatabase(BaseEnum):
+     WEAVIATE = auto()
+
+
+ WEAVIATE_INDEX_NAME = "LlamaIndex"
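A quick illustration of what the `auto()` plus `_generate_next_value_` combination gives you: member values are the lower-cased member names, so they compare equal to the string literals used in `Settings`:

```python
from app.enums import LLMMode

assert LLMMode.OPENAI.value == "openai"          # auto() -> lower-cased name
assert LLMMode.get_values() == ["mock", "openai", "local"]
assert LLMMode("local") is LLMMode.LOCAL          # round-trips from the settings string
```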
app/main.py ADDED
@@ -0,0 +1,38 @@
+ import logging
+
+ from fastapi import FastAPI
+
+ from app._config import settings
+ from app.components.embedding.component import EmbeddingComponent
+ from app.components.llm.component import LLMComponent
+ from app.components.node_store.component import NodeStoreComponent
+ from app.components.vector_store.component import VectorStoreComponent
+ from app.server.chat.router import chat_router
+ from app.server.chat.service import ChatService
+ from app.server.embedding.router import embedding_router
+ from app.server.ingest.service import IngestService
+
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI()
+ app.include_router(chat_router)
+ app.include_router(embedding_router)
+
+ if settings.IS_UI_ENABLED:
+     logger.debug("Importing the UI module")
+     from app.ui.ui import PrivateGptUi
+
+     llm_component = LLMComponent()
+     vector_store_component = VectorStoreComponent()
+     embedding_component = EmbeddingComponent()
+     node_store_component = NodeStoreComponent()
+
+     ingest_service = IngestService(
+         llm_component, vector_store_component, embedding_component, node_store_component
+     )
+     chat_service = ChatService(
+         llm_component, vector_store_component, embedding_component, node_store_component
+     )
+
+     ui = PrivateGptUi(ingest_service, chat_service)
+     ui.mount_in_app(app, settings.UI_PATH)
app/paths.py ADDED
@@ -0,0 +1,15 @@
+ from pathlib import Path
+
+ from app._config import settings
+ from app.enums import PROJECT_ROOT_PATH
+
+
+ def _absolute_or_from_project_root(path: str) -> Path:
+     if path.startswith("/"):
+         return Path(path)
+     return PROJECT_ROOT_PATH / path
+
+
+ local_data_path: Path = _absolute_or_from_project_root(settings.LOCAL_DATA_FOLDER)
+ models_path: Path = PROJECT_ROOT_PATH / "models"
+ models_cache_path: Path = models_path / "cache"
app/server/__init__.py ADDED
File without changes
app/server/chat/__init__.py ADDED
File without changes
app/server/chat/router.py ADDED
@@ -0,0 +1,70 @@
+ from fastapi import APIRouter
+ from llama_index.llms import ChatMessage, MessageRole
+ from pydantic import BaseModel
+
+ from app.components.embedding.component import EmbeddingComponent
+ from app.components.llm.component import LLMComponent
+ from app.components.node_store.component import NodeStoreComponent
+ from app.components.vector_store.component import VectorStoreComponent
+ from app.server.chat.service import ChatService
+ from app.server.chat.utils import OpenAICompletion, OpenAIMessage, to_openai_response
+
+ chat_router = APIRouter()
+
+
+ class ChatBody(BaseModel):
+     messages: list[OpenAIMessage]
+     include_sources: bool = True
+
+     model_config = {
+         "json_schema_extra": {
+             "examples": [
+                 {
+                     "messages": [
+                         {
+                             "role": "system",
+                             "content": "You are a rapper. Always answer with a rap.",
+                         },
+                         {
+                             "role": "user",
+                             "content": "How do you fry an egg?",
+                         },
+                     ],
+                     "include_sources": True,
+                 }
+             ]
+         }
+     }
+
+
+ @chat_router.post(
+     "/chat",
+     response_model=None,
+     responses={200: {"model": OpenAICompletion}},
+     tags=["Contextual Completions"],
+ )
+ def chat_completion(body: ChatBody) -> OpenAICompletion:
+     """Given a list of messages comprising a conversation, return a response.
+
+     Optionally include an initial `role: system` message to influence the way
+     the LLM answers.
+
+     When using `'include_sources': true`, the API will return the source Chunks used
+     to create the response, which come from the context provided.
+     """
+     llm_component = LLMComponent()
+     vector_store_component = VectorStoreComponent()
+     embedding_component = EmbeddingComponent()
+     node_store_component = NodeStoreComponent()
+
+     chat_service = ChatService(
+         llm_component, vector_store_component, embedding_component, node_store_component
+     )
+     all_messages = [
+         ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
+     ]
+
+     completion = chat_service.chat(messages=all_messages)
+     return to_openai_response(
+         completion.response, completion.sources if body.include_sources else None
+     )
app/server/chat/schemas.py ADDED
@@ -0,0 +1,45 @@
+ from typing import Literal
+
+ from llama_index.schema import NodeWithScore
+ from pydantic import BaseModel, Field
+
+ from app.server.ingest.schemas import IngestedDoc
+
+
+ class Chunk(BaseModel):
+     object: Literal["context.chunk"]
+     score: float = Field(examples=[0.023])
+     document: IngestedDoc
+     text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
+     previous_texts: list[str] | None = Field(
+         default=None,
+         examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
+     )
+     next_texts: list[str] | None = Field(
+         default=None,
+         examples=[
+             [
+                 "New leads came from Google Ads campaign.",
+                 "The campaign was run by the Marketing Department",
+             ]
+         ],
+     )
+
+     @classmethod
+     def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
+         doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
+         return cls(
+             object="context.chunk",
+             score=node.score or 0.0,
+             document=IngestedDoc(
+                 object="ingest.document",
+                 doc_id=doc_id,
+                 doc_metadata=node.metadata,
+             ),
+             text=node.get_content(),
+         )
+
+
+ class Completion(BaseModel):
+     response: str
+     sources: list[Chunk] | None = None
app/server/chat/service.py ADDED
@@ -0,0 +1,122 @@
+ from dataclasses import dataclass
+
+ from llama_index import ServiceContext, StorageContext, VectorStoreIndex
+ from llama_index.chat_engine import ContextChatEngine
+ from llama_index.chat_engine.types import BaseChatEngine
+ from llama_index.core.postprocessor import SentenceTransformerRerank
+ from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
+ from llama_index.llms import ChatMessage, MessageRole
+
+ from app._config import settings
+ from app.components.embedding.component import EmbeddingComponent
+ from app.components.llm.component import LLMComponent
+ from app.components.node_store.component import NodeStoreComponent
+ from app.components.vector_store.component import VectorStoreComponent
+ from app.server.chat.schemas import Chunk, Completion
+
+
+ @dataclass
+ class ChatEngineInput:
+     system_message: ChatMessage | None = None
+     last_message: ChatMessage | None = None
+     chat_history: list[ChatMessage] | None = None
+
+     @classmethod
+     def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
+         # Detect if there is a system message, extract the last message and chat history
+         system_message = (
+             messages[0]
+             if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
+             else None
+         )
+         last_message = (
+             messages[-1]
+             if len(messages) > 0 and messages[-1].role == MessageRole.USER
+             else None
+         )
+         # Remove from messages list the system message and last message,
+         # if they exist. The rest is the chat history.
+         if system_message:
+             messages.pop(0)
+         if last_message:
+             messages.pop(-1)
+         chat_history = messages if len(messages) > 0 else None
+
+         return cls(
+             system_message=system_message,
+             last_message=last_message,
+             chat_history=chat_history,
+         )
+
+
+ class ChatService:
+     def __init__(
+         self,
+         llm_component: LLMComponent,
+         vector_store_component: VectorStoreComponent,
+         embedding_component: EmbeddingComponent,
+         node_store_component: NodeStoreComponent,
+     ) -> None:
+         self.llm_service = llm_component
+         self.vector_store_component = vector_store_component
+         self.storage_context = StorageContext.from_defaults(
+             vector_store=vector_store_component.vector_store,
+             docstore=node_store_component.doc_store,
+             index_store=node_store_component.index_store,
+         )
+         self.service_context = ServiceContext.from_defaults(
+             llm=llm_component.llm, embed_model=embedding_component.embedding_model
+         )
+         self.index = VectorStoreIndex.from_vector_store(
+             vector_store_component.vector_store,
+             storage_context=self.storage_context,
+             service_context=self.service_context,
+             show_progress=True,
+         )
+
+     def _chat_engine(self, system_prompt: str | None = None) -> BaseChatEngine:
+         vector_index_retriever = self.vector_store_component.get_retriever(
+             index=self.index
+         )
+
+         node_postprocessors = [
+             MetadataReplacementPostProcessor(target_metadata_key="window")
+         ]
+         if settings.IS_RERANK_ENABLED:
+             rerank = SentenceTransformerRerank(
+                 top_n=settings.RERANK_TOP_N, model=settings.RERANK_MODEL_NAME
+             )
+             node_postprocessors.append(rerank)
+
+         return ContextChatEngine.from_defaults(
+             system_prompt=system_prompt,
+             retriever=vector_index_retriever,
+             service_context=self.service_context,
+             node_postprocessors=node_postprocessors,
+         )
+
+     def chat(self, messages: list[ChatMessage]):
+         chat_engine_input = ChatEngineInput.from_messages(messages)
+         last_message = (
+             chat_engine_input.last_message.content
+             if chat_engine_input.last_message
+             else None
+         )
+         system_prompt = (
+             chat_engine_input.system_message.content
+             if chat_engine_input.system_message
+             else None
+         )
+         chat_history = (
+             chat_engine_input.chat_history if chat_engine_input.chat_history else None
+         )
+
+         chat_engine = self._chat_engine(system_prompt=system_prompt)
+         wrapped_response = chat_engine.chat(
+             message=last_message if last_message is not None else "",
+             chat_history=chat_history,
+         )
+
+         sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
+         completion = Completion(response=wrapped_response.response, sources=sources)
+         return completion
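A minimal sketch of driving `ChatService` directly, mirroring the wiring in `app/main.py`; it assumes Weaviate is reachable and the configured LLM/embedding modes can initialize:

```python
from llama_index.llms import ChatMessage, MessageRole

from app.components.embedding.component import EmbeddingComponent
from app.components.llm.component import LLMComponent
from app.components.node_store.component import NodeStoreComponent
from app.components.vector_store.component import VectorStoreComponent
from app.server.chat.service import ChatService

service = ChatService(
    LLMComponent(), VectorStoreComponent(), EmbeddingComponent(), NodeStoreComponent()
)
# Hypothetical question; sources come back as Chunk objects on the Completion
completion = service.chat(
    [ChatMessage(role=MessageRole.USER, content="What does the report say about yields?")]
)
print(completion.response)
```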
app/server/chat/utils.py ADDED
@@ -0,0 +1,68 @@
+ import time
+ import uuid
+ from typing import Literal
+
+ from llama_index.llms import ChatResponse
+ from pydantic import BaseModel, Field
+
+ from app.server.chat.schemas import Chunk
+
+
+ class OpenAIMessage(BaseModel):
+     """Inference result, with the source of the message.
+
+     Role could be the assistant or system
+     (providing a default response, not AI generated).
+     """
+
+     role: Literal["assistant", "system", "user"] = Field(default="user")
+     content: str | None
+
+
+ class OpenAIChoice(BaseModel):
+     """Response from AI."""
+
+     finish_reason: str | None = Field(examples=["stop"])
+     message: OpenAIMessage | None = None
+     sources: list[Chunk] | None = None
+     index: int = 0
+
+
+ class OpenAICompletion(BaseModel):
+     """Clone of OpenAI Completion model.
+
+     For more information see: https://platform.openai.com/docs/api-reference/chat/object
+     """
+
+     id: str
+     object: Literal["completion", "completion.chunk"] = Field(default="completion")
+     created: int = Field(..., examples=[1623340000])
+     model: Literal["llm-agriculture"]
+     choices: list[OpenAIChoice]
+
+     @classmethod
+     def from_text(
+         cls,
+         text: str | None,
+         finish_reason: str | None = None,
+         sources: list[Chunk] | None = None,
+     ) -> "OpenAICompletion":
+         return OpenAICompletion(
+             id=str(uuid.uuid4()),
+             object="completion",
+             created=int(time.time()),
+             model="llm-agriculture",
+             choices=[
+                 OpenAIChoice(
+                     message=OpenAIMessage(role="assistant", content=text),
+                     finish_reason=finish_reason,
+                     sources=sources,
+                 )
+             ],
+         )
+
+
+ def to_openai_response(
+     response: str | ChatResponse, sources: list[Chunk] | None = None
+ ) -> OpenAICompletion:
+     return OpenAICompletion.from_text(response, finish_reason="stop", sources=sources)
app/server/embedding/__init__.py ADDED
File without changes
app/server/embedding/router.py ADDED
@@ -0,0 +1,18 @@
+ from fastapi import APIRouter
+
+ from app.components.embedding.component import EmbeddingComponent
+ from app.server.embedding.schemas import EmbeddingsBody, EmbeddingsResponse
+ from app.server.embedding.service import EmbeddingsService
+
+ embedding_router = APIRouter()
+
+
+ @embedding_router.post("/embedding", tags=["Embeddings"])
+ def generate_embeddings(body: EmbeddingsBody) -> EmbeddingsResponse:
+     embedding_component = EmbeddingComponent()
+     service = EmbeddingsService(embedding_component)
+     input_texts = body.input if isinstance(body.input, list) else [body.input]
+     embeddings = service.embed_texts(input_texts)
+     return EmbeddingsResponse(
+         object="list", model=service.embedding_model.model_name, data=embeddings
+     )
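A sketch of calling this endpoint with the server running locally; per `EmbeddingsBody`, `input` may be a single string or a list of strings:

```python
import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:8000/embedding",
    data=json.dumps({"input": ["first text", "second text"]}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read())
print(body["model"], len(body["data"]))  # one Embedding entry per input text
```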
app/server/embedding/schemas.py ADDED
@@ -0,0 +1,19 @@
+ from typing import Literal
+
+ from pydantic import BaseModel, Field
+
+
+ class Embedding(BaseModel):
+     index: int
+     object: Literal["embedding"]
+     embedding: list[float] = Field(examples=[[0.1, -0.2]])
+
+
+ class EmbeddingsBody(BaseModel):
+     input: str | list[str]
+
+
+ class EmbeddingsResponse(BaseModel):
+     object: Literal["list"]
+     model: str
+     data: list[Embedding]
app/server/embedding/service.py ADDED
@@ -0,0 +1,18 @@
+ from app.components.embedding.component import EmbeddingComponent
+ from app.server.embedding.schemas import Embedding
+
+
+ class EmbeddingsService:
+     def __init__(self, embedding_component: EmbeddingComponent) -> None:
+         self.embedding_model = embedding_component.embedding_model
+
+     def embed_texts(self, texts: list[str]) -> list[Embedding]:
+         texts_embeddings = self.embedding_model.get_text_embedding_batch(texts)
+         # enumerate keeps positions stable even when two inputs share an embedding
+         return [
+             Embedding(
+                 index=index,
+                 object="embedding",
+                 embedding=embedding,
+             )
+             for index, embedding in enumerate(texts_embeddings)
+         ]
app/server/ingest/__init__.py ADDED
File without changes
app/server/ingest/schemas.py ADDED
@@ -0,0 +1,32 @@
+ from typing import Any, Literal
+
+ from llama_index import Document
+ from pydantic import BaseModel, Field
+
+
+ class IngestedDoc(BaseModel):
+     object: Literal["ingest.document"]
+     doc_id: str = Field(examples=["c202d5e6-7b69-4869-81cc-dd574ee8ee11"])
+     doc_metadata: dict[str, Any] | None = Field(
+         examples=[
+             {
+                 "page_label": "2",
+                 "file_name": "agriculture.pdf",
+             }
+         ]
+     )
+
+     @staticmethod
+     def curate_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
+         """Remove unwanted metadata keys."""
+         for key in ["doc_id", "window", "original_text"]:
+             metadata.pop(key, None)
+         return metadata
+
+     @staticmethod
+     def from_document(document: Document) -> "IngestedDoc":
+         return IngestedDoc(
+             object="ingest.document",
+             doc_id=document.doc_id,
+             doc_metadata=IngestedDoc.curate_metadata(document.metadata),
+         )
app/server/ingest/service.py ADDED
@@ -0,0 +1,123 @@
+ import logging
+ import tempfile
+ from pathlib import Path
+ from typing import AnyStr, BinaryIO
+
+ from llama_index import ServiceContext, StorageContext
+ from llama_index.node_parser import SentenceWindowNodeParser
+
+ from app.components.embedding.component import EmbeddingComponent
+ from app.components.ingest.component import get_ingestion_component
+ from app.components.llm.component import LLMComponent
+ from app.components.node_store.component import NodeStoreComponent
+ from app.components.vector_store.component import VectorStoreComponent
+ from app.server.ingest.schemas import IngestedDoc
+
+ logger = logging.getLogger(__name__)
+
+
+ class IngestService:
+     def __init__(
+         self,
+         llm_component: LLMComponent,
+         vector_store_component: VectorStoreComponent,
+         embedding_component: EmbeddingComponent,
+         node_store_component: NodeStoreComponent,
+     ) -> None:
+         self.llm_service = llm_component
+         self.storage_context = StorageContext.from_defaults(
+             vector_store=vector_store_component.vector_store,
+             docstore=node_store_component.doc_store,
+             index_store=node_store_component.index_store,
+         )
+         node_parser = SentenceWindowNodeParser.from_defaults()
+         self.ingest_service_context = ServiceContext.from_defaults(
+             llm=self.llm_service.llm,
+             embed_model=embedding_component.embedding_model,
+             node_parser=node_parser,
+             # Embeddings done early in the pipeline of node transformations, right
+             # after the node parsing
+             transformations=[node_parser, embedding_component.embedding_model],
+         )
+
+         self.ingest_component = get_ingestion_component(
+             self.storage_context, self.ingest_service_context
+         )
+
+     def _ingest_data(self, file_name: str, file_data: AnyStr) -> list[IngestedDoc]:
+         logger.debug(f"Got file data of size={len(file_data)} to ingest")
+         # llama-index mainly supports reading from files, so
+         # we have to create a tmp file to read for it to work
+         # delete=False to avoid a Windows 11 permission error.
+         with tempfile.NamedTemporaryFile(delete=False) as tmp:
+             try:
+                 path_to_tmp = Path(tmp.name)
+                 if isinstance(file_data, bytes):
+                     path_to_tmp.write_bytes(file_data)
+                 else:
+                     path_to_tmp.write_text(str(file_data))
+                 return self.ingest_file(file_name, path_to_tmp)
+             finally:
+                 tmp.close()
+                 path_to_tmp.unlink()
+
+     def ingest_file(self, file_name: str, file_data: Path) -> list[IngestedDoc]:
+         logger.info(f"Ingesting file_name={file_name}")
+         documents = self.ingest_component.ingest(file_name, file_data)
+         logger.info(f"Finished ingestion file_name={file_name}")
+         return [IngestedDoc.from_document(document) for document in documents]
+
+     def ingest_text(self, file_name: str, text: str) -> list[IngestedDoc]:
+         logger.debug(f"Ingesting text data with file_name={file_name}")
+         return self._ingest_data(file_name, text)
+
+     def ingest_bin_data(
+         self, file_name: str, raw_file_data: BinaryIO
+     ) -> list[IngestedDoc]:
+         logger.debug(f"Ingesting binary data with file_name={file_name}")
+         file_data = raw_file_data.read()
+         return self._ingest_data(file_name, file_data)
+
+     def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[IngestedDoc]:
+         logger.info(f"Ingesting file_names={[f[0] for f in files]}")
+         documents = self.ingest_component.bulk_ingest(files)
+         logger.info(f"Finished ingestion file_name={[f[0] for f in files]}")
+         return [IngestedDoc.from_document(document) for document in documents]
+
+     def list_ingested(self) -> list[IngestedDoc]:
+         ingested_docs = []
+         try:
+             docstore = self.storage_context.docstore
+             ingested_docs_ids: set[str] = set()
+
+             for node in docstore.docs.values():
+                 if node.ref_doc_id is not None:
+                     ingested_docs_ids.add(node.ref_doc_id)
+
+             for doc_id in ingested_docs_ids:
+                 ref_doc_info = docstore.get_ref_doc_info(ref_doc_id=doc_id)
+                 doc_metadata = None
+                 if ref_doc_info is not None and ref_doc_info.metadata is not None:
+                     doc_metadata = IngestedDoc.curate_metadata(ref_doc_info.metadata)
+                 ingested_docs.append(
+                     IngestedDoc(
+                         object="ingest.document",
+                         doc_id=doc_id,
+                         doc_metadata=doc_metadata,
+                     )
+                 )
+         except ValueError:
+             logger.warning("Got an exception when getting list of docs", exc_info=True)
+         logger.debug(f"Found count={len(ingested_docs)} ingested documents")
+         return ingested_docs
+
+     def delete(self, doc_id: str) -> None:
+         """Delete an ingested document.
+
+         :raises ValueError: if the document does not exist
+         """
+         logger.info(
+             "Deleting the ingested document=%s in the doc and index store", doc_id
+         )
+         self.ingest_component.delete(doc_id)
app/ui/__init__.py ADDED
@@ -0,0 +1 @@
+ """Gradio based UI."""
app/ui/dodge_ava.jpg ADDED
app/ui/schemas.py ADDED
@@ -0,0 +1,27 @@
+ from pydantic import BaseModel
+
+ from app.server.chat.schemas import Chunk
+
+
+ class Source(BaseModel):
+     file: str
+     page: str
+     text: str
+
+     class Config:
+         frozen = True
+
+     @staticmethod
+     def curate_sources(sources: list[Chunk]) -> set["Source"]:
+         curated_sources = set()
+
+         for chunk in sources:
+             doc_metadata = chunk.document.doc_metadata
+
+             file_name = doc_metadata.get("file_name", "-") if doc_metadata else "-"
+             page_label = doc_metadata.get("page_label", "-") if doc_metadata else "-"
+
+             source = Source(file=file_name, page=page_label, text=chunk.text)
+             curated_sources.add(source)
+
+         return curated_sources
app/ui/ui.py ADDED
@@ -0,0 +1,228 @@
+ """This file should be imported only and only if you want to run the UI locally."""
+ import itertools
+ import logging
+ from pathlib import Path
+ from typing import Any
+
+ import gradio as gr
+ from fastapi import FastAPI
+ from gradio.themes.utils.colors import slate
+ from llama_index.llms import ChatMessage, MessageRole
+
+ from app._config import settings
+ from app.components.embedding.component import EmbeddingComponent
+ from app.components.llm.component import LLMComponent
+ from app.components.node_store.component import NodeStoreComponent
+ from app.components.vector_store.component import VectorStoreComponent
+ from app.enums import PROJECT_ROOT_PATH
+ from app.server.chat.service import ChatService
+ from app.server.ingest.service import IngestService
+ from app.ui.schemas import Source
+
+ logger = logging.getLogger(__name__)
+
+ THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
+ AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "dodge_ava.jpg"
+
+ UI_TAB_TITLE = "Agriculture Chatbot"
+
+ SOURCES_SEPARATOR = "\n\n Sources: \n"
+
+
+ class PrivateGptUi:
+     def __init__(
+         self,
+         ingest_service: IngestService,
+         chat_service: ChatService,
+     ) -> None:
+         self._ingest_service = ingest_service
+         self._chat_service = chat_service
+
+         # Cache the UI blocks
+         self._ui_block = None
+
+         # Initialize system prompt
+         self._system_prompt = self._get_default_system_prompt()
+
+     def _chat(self, message: str, history: list[list[str]], *_: Any) -> Any:
+         def build_history() -> list[ChatMessage]:
+             history_messages: list[ChatMessage] = list(
+                 itertools.chain(
+                     *[
+                         [
+                             ChatMessage(content=interaction[0], role=MessageRole.USER),
+                             ChatMessage(
+                                 # Remove from history content the Sources information
+                                 content=interaction[1].split(SOURCES_SEPARATOR)[0],
+                                 role=MessageRole.ASSISTANT,
+                             ),
+                         ]
+                         for interaction in history
+                     ]
+                 )
+             )
+
+             # max 20 messages to try to avoid context overflow
+             return history_messages[:20]
+
+         new_message = ChatMessage(content=message, role=MessageRole.USER)
+         all_messages = [*build_history(), new_message]
+         # If a system prompt is set, add it as a system message
+         if self._system_prompt:
+             all_messages.insert(
+                 0,
+                 ChatMessage(
+                     content=self._system_prompt,
+                     role=MessageRole.SYSTEM,
+                 ),
+             )
+
+         completion = self._chat_service.chat(messages=all_messages)
+         full_response = completion.response
+
+         if completion.sources:
+             full_response += SOURCES_SEPARATOR
+             curated_sources = Source.curate_sources(completion.sources)
+             sources_text = "\n\n\n".join(
+                 f"{index}. {source.file} (page {source.page})"
+                 for index, source in enumerate(curated_sources, start=1)
+             )
+             full_response += sources_text
+
+         return full_response
+
+     # On initialization this function set the system prompt
+     # to the default prompt based on settings.
+     @staticmethod
+     def _get_default_system_prompt() -> str:
+         return settings.DEFAULT_QUERY_SYSTEM_PROMPT
+
+     def _set_system_prompt(self, system_prompt_input: str) -> None:
+         logger.info(f"Setting system prompt to: {system_prompt_input}")
+         self._system_prompt = system_prompt_input
+
+     def _list_ingested_files(self) -> list[list[str]]:
+         files = set()
+         for ingested_document in self._ingest_service.list_ingested():
+             if ingested_document.doc_metadata is None:
+                 # Skipping documents without metadata
+                 continue
+             file_name = ingested_document.doc_metadata.get(
+                 "file_name", "[FILE NAME MISSING]"
+             )
+             files.add(file_name)
+         return [[row] for row in files]
+
+     def _upload_file(self, files: list[str]) -> None:
+         logger.debug("Loading count=%s files", len(files))
+         paths = [Path(file) for file in files]
+         self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])
+
+     def _build_ui_blocks(self) -> gr.Blocks:
+         logger.debug("Creating the UI blocks")
+         with gr.Blocks(
+             title=UI_TAB_TITLE,
+             theme=gr.themes.Soft(primary_hue=slate),
+             css=".logo { "
+             "display:flex;"
+             "height: 80px;"
+             "border-radius: 8px;"
+             "align-content: center;"
+             "justify-content: center;"
+             "align-items: center;"
+             "}"
+             ".logo img { height: 25% }"
+             ".contain { display: flex !important; flex-direction: column !important; }"
+             "#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
+             "#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
+             "#col { height: calc(100vh - 112px - 16px) !important; }",
+         ) as blocks:
+             with gr.Row():
+                 gr.HTML(f"<div class='logo'><h1>{UI_TAB_TITLE}</h1></div>")
+
+             with gr.Row(equal_height=False):
+                 with gr.Column(scale=3):
+                     upload_button = gr.components.UploadButton(
+                         "Upload File(s)",
+                         type="filepath",
+                         file_count="multiple",
+                         size="sm",
+                     )
+                     ingested_dataset = gr.List(
+                         self._list_ingested_files,
+                         headers=["File name"],
+                         label="Ingested Files",
+                         interactive=False,
+                         render=False,  # Rendered under the button
+                     )
+                     upload_button.upload(
+                         self._upload_file,
+                         inputs=upload_button,
+                         outputs=ingested_dataset,
+                     )
+                     ingested_dataset.change(
+                         self._list_ingested_files,
+                         outputs=ingested_dataset,
+                     )
+                     ingested_dataset.render()
+                     system_prompt_input = gr.Textbox(
+                         placeholder=self._system_prompt,
+                         label="System Prompt",
+                         lines=2,
+                         interactive=True,
+                         render=False,
+                     )
+
+                     # On blur, set system prompt to use in queries
+                     system_prompt_input.blur(
+                         self._set_system_prompt,
+                         inputs=system_prompt_input,
+                     )
+
+                 with gr.Column(scale=7, elem_id="col"):
+                     _ = gr.ChatInterface(
+                         self._chat,
+                         chatbot=gr.Chatbot(
+                             label=f"LLM: {settings.LLM_MODE}",
+                             show_copy_button=True,
+                             elem_id="chatbot",
+                             render=False,
+                             avatar_images=(
+                                 None,
+                                 AVATAR_BOT,
+                             ),
+                         ),
+                         additional_inputs=[upload_button, system_prompt_input],
+                     )
+         return blocks
+
+     def get_ui_blocks(self) -> gr.Blocks:
+         if self._ui_block is None:
+             self._ui_block = self._build_ui_blocks()
+         return self._ui_block
+
+     def mount_in_app(self, app: FastAPI, path: str) -> None:
+         blocks = self.get_ui_blocks()
+         blocks.queue()
+         logger.info("Mounting the gradio UI, at path=%s", path)
+         gr.mount_gradio_app(app, blocks, path=path)
+
+
+ if __name__ == "__main__":
+     llm_component = LLMComponent()
+     vector_store_component = VectorStoreComponent()
+     embedding_component = EmbeddingComponent()
+     node_store_component = NodeStoreComponent()
+
+     ingest_service = IngestService(
+         llm_component, vector_store_component, embedding_component, node_store_component
+     )
+     chat_service = ChatService(
+         llm_component, vector_store_component, embedding_component, node_store_component
+     )
+
+     ui = PrivateGptUi(ingest_service, chat_service)
+
+     _blocks = ui.get_ui_blocks()
+     _blocks.queue()
+     _blocks.launch(debug=False, show_api=False)
docker-compose.yml ADDED
@@ -0,0 +1,26 @@
+ version: '3.4'
+ services:
+   weaviate:
+     command:
+     - --host
+     - 0.0.0.0
+     - --port
+     - '8080'
+     - --scheme
+     - http
+     image: semitechnologies/weaviate:1.23.0
+     ports:
+     - 8080:8080
+     - 50051:50051
+     volumes:
+     - weaviate_data:/var/lib/weaviate
+     restart: on-failure:0
+     environment:
+       QUERY_DEFAULTS_LIMIT: 25
+       AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
+       PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
+       DEFAULT_VECTORIZER_MODULE: 'none'
+       CLUSTER_HOSTNAME: 'node1'
+
+ volumes:
+   weaviate_data:
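A small readiness check against the Weaviate container this compose file starts, assuming the v3 `weaviate-client` pinned in `pyproject.toml`:

```python
import weaviate

# Same URL as the WEAVIATE_CLIENT_URL default in app/_config.py
client = weaviate.Client("http://localhost:8080")
print(client.is_ready())  # True once the service accepts requests
```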
prestart.sh ADDED
@@ -0,0 +1,5 @@
+ #! /usr/bin/env bash
+
+ # Download Embedding model and LLM model
+ python setup.py
+
pyproject.toml ADDED
@@ -0,0 +1,31 @@
+ [tool.poetry]
+ name = "capstone"
+ version = "0.1.0"
+ description = ""
+ authors = ["PhucVu <vuhongphuc24601@gmail.com>"]
+ readme = "README.md"
+
+ [tool.poetry.dependencies]
+ python = "^3.10"
+ llama-index = "^0.9.22"
+ weaviate-client = "^3.26.0"
+ pydantic-settings = "^2.1.0"
+ fastapi = "^0.108.0"
+ uvicorn = "^0.25.0"
+ pydantic = "^2.5.3"
+ gradio = "^4.12.0"
+
+ # reranker
+ torch = {version="^2.3.0", optional=true}
+ sentence-transformers = {version="^2.7.0", optional=true}
+
+ [tool.poetry.group.local]
+ optional = true
+ [tool.poetry.group.local.dependencies]
+ transformers = "^4.36.2"
+ torch = "^2.1.2"
+ llama-cpp-python = "^0.2.29"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
setup.py ADDED
@@ -0,0 +1,31 @@
+ import logging
+ import os
+
+ from huggingface_hub import hf_hub_download, snapshot_download
+
+ from app._config import settings
+ from app.paths import models_cache_path, models_path
+
+ os.makedirs(models_path, exist_ok=True)
+
+ logger = logging.getLogger(__name__)
+
+ # Download embedding model
+ embedding_path = models_path / "embedding"
+ logger.info(f"Downloading embedding: {settings.LOCAL_HF_EMBEDDING_MODEL_NAME}")
+ snapshot_download(
+     repo_id=settings.LOCAL_HF_EMBEDDING_MODEL_NAME,
+     cache_dir=models_cache_path,
+     local_dir=embedding_path,
+ )
+ logger.info("Embedding model downloaded")
+
+ # Download LLM model
+ logger.info(f"Downloading LLM: {settings.LOCAL_HF_LLM_MODEL_FILE}")
+ hf_hub_download(
+     repo_id=settings.LOCAL_HF_LLM_REPO_ID,
+     filename=settings.LOCAL_HF_LLM_MODEL_FILE,
+     cache_dir=models_cache_path,
+     local_dir=models_path,
+ )
+ logger.info("LLM model downloaded")