XiaoYun Zhang committed on
Commit
6abb254
0 Parent(s):
Files changed (17) hide show
  1. .docker-compose +0 -0
  2. .gitattributes +35 -0
  3. .gitignore +18 -0
  4. .local_storage/user.json +14 -0
  5. LICENSE +21 -0
  6. README.md +13 -0
  7. app.py +302 -0
  8. di.py +51 -0
  9. embedding.py +56 -0
  10. index.py +207 -0
  11. model/document.py +98 -0
  12. model/record.py +8 -0
  13. model/user.py +13 -0
  14. requirements.txt +4 -0
  15. setting.py +18 -0
  16. setup.py +13 -0
  17. storage.py +48 -0
.docker-compose ADDED
File without changes
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # .gitignore file for Python projects
2
+ # Covers most common project files and folders
3
+
4
+ .DS_Store
5
+ *.pyc
6
+ *.pyo
7
+ *.pyd
8
+ __pycache__
9
+ *.so
10
+ *.egg
11
+ *.egg-info
12
+ dist
13
+ build
14
+ docs/_build
15
+ .idea
16
+ venv
17
+ test
18
+ .env
.local_storage/user.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "user_name": "bigmiao",
3
+ "email": "g2260578356@gmail.com",
4
+ "full_name": "g2260578356",
5
+ "disabled": false,
6
+ "documents": [
7
+ {
8
+ "name": "mlnet_notebook_examples_v1.json.json",
9
+ "description": null,
10
+ "status": "done",
11
+ "url": "bigmiao-mlnet_examples.json"
12
+ }
13
+ ]
14
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Xiaoyun Zhang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Mlnet Samples
3
+ emoji: 😻
4
+ colorFrom: yellow
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 3.47.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import fastapi as api
from typing import Annotated
from fastapi.security import OAuth2PasswordBearer, OAuth2AuthorizationCodeBearer, OAuth2PasswordRequestForm
from model.document import Document, PlainTextDocument, JsonDocument
import sys
from model.user import User
from fastapi import FastAPI, File, UploadFile
from di import initialize_di_for_app
import gradio as gr
import os
import json

# Wire up the app-wide singletons (settings, storage backend, embedding
# provider, vector index) through the DI helper in di.py.
SETTINGS, STORAGE, EMBEDDING, INDEX = initialize_di_for_app()

# The single known user is loaded from local storage at startup.
# NOTE(review): parse_raw is the pydantic v1-style API while exit_event uses
# v2's model_dump_json -- presumably running on pydantic v2 where parse_raw
# still works but is deprecated; confirm the installed pydantic version.
user_json_str = STORAGE.load('user.json')
USER = User.parse_raw(user_json_str)

# OAuth2 password flow; the "token" handed out by /api/v1/auth/token is just
# the user name (no real authentication -- see get_current_user below).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/token")
app = api.FastAPI()
app.openapi_version = "3.0.0"
users = [USER]  # in-memory user registry (single user)
20
async def get_current_user(token: str = api.Depends(oauth2_scheme)):
    '''
    Resolve the bearer token to a known user.

    The token is simply the user name handed out by login(); a 401 is
    raised when it does not match any registered user.
    '''
    matched = next((u for u in users if u.user_name == token), None)
    if matched is None:
        raise api.HTTPException(status_code=401, detail="Invalid authentication credentials")
    return matched
29
+
30
+
31
@app.post("/api/v1/auth/token")
async def login(form_data: Annotated[OAuth2PasswordRequestForm, api.Depends()]):
    '''
    Login to get a token.

    No password check is performed: the access token is simply the supplied
    user name, which get_current_user later matches against the registry.
    '''
    # Fix: the OAuth2 password flow requires "token_type": "bearer" in the
    # token response so Swagger UI / clients send it back as
    # "Authorization: Bearer <token>". Adding the key is backward compatible.
    return {"access_token": form_data.username, "token_type": "bearer"}
37
+
38
@app.post("/api/v1/uploadfile/", include_in_schema=False)
def create_upload_file(file: UploadFile = api.File(...)) -> Document:
    '''
    Upload a file and persist it to storage under "<user>-<filename>".

    Returns a Document describing the stored file; raises 400 for
    unsupported file types.
    '''
    fileUrl = f'{USER.user_name}-{file.filename}'
    # BUG FIX: UploadFile.read() is a coroutine; calling it from this sync
    # endpoint stored a coroutine object instead of the file content. Read
    # the underlying file object instead, and decode bytes because
    # LocalStorage.save writes text.
    content = file.file.read()
    if isinstance(content, bytes):
        content = content.decode('utf-8')
    STORAGE.save(fileUrl, content)

    # create plainTextDocument if the file is a text file
    if file.filename.endswith('.txt'):
        return PlainTextDocument(
            name=file.filename,
            status='uploading',
            url=fileUrl,
            embedding=EMBEDDING,
            storage=STORAGE,
        )
    elif file.filename.endswith('.json'):
        # Accept .json too, consistent with the gradio upload path.
        return JsonDocument(
            name=file.filename,
            status='uploading',
            url=fileUrl,
            embedding=EMBEDDING,
            storage=STORAGE,
        )
    else:
        raise api.HTTPException(status_code=400, detail="File type not supported")
57
+
58
+
59
### /api/v1/.well-known
#### Get /openapi.json
# Get the openapi json file
@app.get("/api/v1/.well-known/openapi.json")
async def get_openapi():
    '''
    otherwise return 401
    '''
    # NOTE(review): the docstring above does not describe this endpoint; it
    # actually returns a pruned OpenAPI schema (search endpoint only) with
    # the user's document list injected into the descriptions -- suitable as
    # a plugin manifest for an LLM tool.

    # get a list of document names + description
    document_list = [[doc.name, doc.description] for doc in USER.documents]

    # get openapi json from api
    # NOTE(review): .copy() is shallow, so the += below mutates the nested
    # dicts of the cached app.openapi() schema -- the appended document list
    # may accumulate across calls; confirm this is intended.
    openapi = app.openapi().copy()

    openapi['info']['title'] = 'DocumentSearch'
    description = f'''Search documents with a query.
## Documents
{document_list}
'''

    openapi['info']['description'] = description

    # update description in /api/v1/search
    openapi['paths']['/api/v1/search']['get']['description'] += f'''
Available documents:
{document_list}
'''

    # filter out unnecessary endpoints
    openapi['paths'] = {
        '/api/v1/search': openapi['paths']['/api/v1/search'],
    }

    # remove components
    openapi['components'] = {}

    # return the openapi json
    return openapi
98
+
99
+
100
+
101
### /api/v1/document
#### Get /list
# Get the list of documents
@app.get("/api/v1/document/list")
# async def get_document_list(user: Annotated[User, api.Depends(get_current_user)]) -> list[Document]:
async def get_document_list() -> list[Document]:
    '''
    Return every document registered for the current (single) user.
    '''
    return USER.documents
111
+
112
#### Post /upload
# Upload a document
@app.post("/api/v1/document/upload")
# def upload_document(user: Annotated[User, api.Depends(get_current_user)], document: Annotated[Document, api.Depends(create_upload_file)]):
def upload_document(document: Annotated[Document, api.Depends(create_upload_file)]):
    '''
    Upload a document: index its records, then register it on the user.
    '''
    document.status = 'processing'
    # BUG FIX: the original passed `progress`, a name that is never defined
    # anywhere in the module, so this endpoint raised NameError at request
    # time. The progress callback is optional in the Index API, so omit it;
    # also use the module-level USER directly instead of the later-defined
    # alias `user`.
    INDEX.load_or_update_document(USER, document)
    document.status = 'done'
    USER.documents.append(document)
124
+
125
#### Get /delete
# Delete a document
@app.get("/api/v1/document/delete")
# async def delete_document(user: Annotated[User, api.Depends(get_current_user)], document_name: str):
async def delete_document(document_name: str):
    '''
    Remove a document by name from storage, the index, and the user.

    Raises a 404 when no document with that name exists.
    '''
    target = next((d for d in USER.documents if d.name == document_name), None)
    if target is None:
        raise api.HTTPException(status_code=404, detail="Document not found")
    STORAGE.delete(target.url)
    INDEX.remove_document(USER, target)
    USER.documents.remove(target)
141
+
142
# Query the index
@app.get("/api/v1/search", operation_id=None,)
def search(
    # user: Annotated[User, api.Depends(get_current_user)],
    query: str,
    document_name: str = None,
    top_k: int = 10,
    threshold: float = 0.5):
    '''
    Search documents with a query. It will return [top_k] results with a score higher than [threshold].
    query: the query string, required
    document_name: the document name, optional. You can provide this parameter to search in a specific document.
    top_k: the number of results to return, optional. Default to 10.
    threshold: the threshold of the results, optional. Default to 0.5.
    '''
    # No document given: search across the whole user index.
    if not document_name:
        return INDEX.query_index(USER, query, top_k, threshold)
    # Otherwise restrict the search to the first document with that name.
    matching = [d for d in USER.documents if d.name == document_name]
    if not matching:
        raise api.HTTPException(status_code=404, detail="Document not found")
    return INDEX.query_document(USER, matching[0], query, top_k, threshold)
165
+
166
def receive_signal(signalNumber, frame):
    """Signal handler: report which signal arrived, then terminate."""
    print('Received:', signalNumber)
    sys.exit()
169
+
170
+
171
@app.on_event("startup")
async def startup_event():
    """Install a SIGINT handler so Ctrl-C exits via receive_signal."""
    import signal
    signal.signal(signal.SIGINT, receive_signal)
    # startup tasks
176
+
177
@app.on_event("shutdown")
def exit_event():
    """Persist the in-memory user (and its document list) on shutdown."""
    STORAGE.save('user.json', USER.model_dump_json())
    print('exit')
182
+
183
# Module-level alias; kept because other code resolves `user` at call time.
user = USER

def gradio_upload_document(file: File):
    """Gradio callback: persist an uploaded .txt/.json file and index it.

    Returns a short status string; raises a 400 HTTPException for any other
    file type.
    """
    file_temp_path = file.name
    # load file
    file_name = os.path.basename(file_temp_path)
    fileUrl = f'{USER.user_name}-{file_name}'
    with open(file_temp_path, 'r', encoding='utf-8') as f:
        STORAGE.save(fileUrl, f.read())

    # Pick the document type from the file extension.
    common_kwargs = dict(
        name=file_name,
        status='uploading',
        url=fileUrl,
        embedding=EMBEDDING,
        storage=STORAGE,
    )
    if file_name.endswith('.txt'):
        doc = PlainTextDocument(**common_kwargs)
    elif file_name.endswith('.json'):
        doc = JsonDocument(**common_kwargs)
    else:
        raise api.HTTPException(status_code=400, detail="File type not supported")

    doc.status = 'processing'
    INDEX.load_or_update_document(user, doc)
    doc.status = 'done'
    USER.documents.append(doc)

    return f'uploaded {file_name}'
219
+
220
def gradio_query(query: str, document_name: str = None, top_k: int = 10, threshold: float = 0.5):
    """Gradio callback: run a search and render the records as a JSON string."""
    res_or_exception = search(query, document_name, top_k, threshold)
    # Defensive guard kept from the original: search() raises on error, so
    # this branch is normally dead.
    if isinstance(res_or_exception, api.HTTPException):
        raise res_or_exception

    # convert to json string
    return json.dumps(
        [record.model_dump(mode='json') for record in res_or_exception],
        indent=4,
    )
230
+
231
# Build the Gradio front-end; it is mounted onto the FastAPI app below so
# both the JSON API and the UI are served by one server.
with gr.Blocks() as ui:
    gr.Markdown("#llm-memory")

    with gr.Column():
        gr.Markdown(
            """
            ## LLM Memory
            """)
        with gr.Row():
            user_name = gr.Label(label="User name", value=USER.user_name)

        # url to .well-known/openapi.json
        gr.Label(label=".wellknown/openapi.json", value=f"/api/v1/.well-known/openapi.json")

        # with gr.Tab("avaiable documents"):
        #     available_documents = gr.Label(label="avaiable documents", value="avaiable documents")
        #     refresh_btn = gr.Button(label="refresh", type="button")
        #     refresh_btn.click(lambda: '\r\n'.join([doc.name for doc in USER.documents]), None, available_documents)
        #     documents = USER.documents
        #     for document in documents:
        #         gr.Label(label=document.name, value=document.name)
        # with gr.Tab("upload document"):
        #     with gr.Tab("upload .txt document"):
        #         file = gr.File(label="upload document", type="file", file_types=[".txt"])
        #         output = gr.Label(label="output", value="output")
        #         upload_btn = gr.Button("upload document", type="button")
        #         upload_btn.click(gradio_upload_document, file, output)
        #     with gr.Tab("upload .json document"):
        #         gr.Markdown(
        #             """
        #             The json document should be a list of objects, each object should have a `content` field. If you want to add more fields, you can add them in the `meta_data` field.
        #             For example:
        #             ```json
        #             [
        #                 {
        #                     "content": "hello world",
        #                     "meta_data": {
        #                         "title": "hello world",
        #                         "author": "llm-memory"
        #                     }
        #                 },
        #                 {
        #                     "content": "hello world"
        #                     "meta_data": {
        #                         "title": "hello world",
        #                         "author": "llm-memory"
        #                     }
        #                 }
        #             ]
        #             ```

        #             ## Note
        #             - The `meta_data` should be a dict which both keys and values are strings.
        #             """)
        #         file = gr.File(label="upload document", type="file", file_types=[".json"])
        #         output = gr.Label(label="output", value="output")
        #         upload_btn = gr.Button("upload document", type="button")
        #         upload_btn.click(gradio_upload_document, file, output)
        with gr.Tab("search"):
            query = gr.Textbox(label="search", placeholder="Query")
            document = gr.Dropdown(label="document", choices=[None] + [doc.name for doc in USER.documents], placeholder="document, optional")
            top_k = gr.Number(label="top_k", placeholder="top_k, optional", value=10)
            threshold = gr.Number(label="threshold", placeholder="threshold, optional", value=0.5)
            output = gr.Code(label="output", language="json", value="output")
            query_btn = gr.Button("Query")
            query_btn.click(gradio_query, [query, document, top_k, threshold], output, api_name="search")


# Serve the Gradio UI at the root path of the FastAPI app, then launch.
gradio_app = gr.routes.App.create_app(ui)
app.mount("/", gradio_app)

ui.launch()
di.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from storage import LocalStorage, Storage
2
+ from setting import Settings
3
+ from embedding import AzureOpenAITextAda002, Embedding, OpenAITextAda002
4
+ from index import Index, QDrantVectorStore
5
+ from model.user import User
6
+ from qdrant_client import QdrantClient
7
+
8
def initialize_di_for_test() -> tuple[Settings, Storage,Embedding,Index]:
    """Build the dependency graph against the test environment.

    Uses ./test/.env.test, ./test/test_storage and the 'test_collection'
    qdrant collection; ensures the collection exists before returning.
    Returns (settings, storage, embedding, index).
    """
    settings = Settings(_env_file='./test/.env.test')
    storage = LocalStorage('./test/test_storage')
    # Choose the embedding backend from configuration.
    if settings.embedding_use_azure:
        embedding = AzureOpenAITextAda002(
            api_base=settings.embedding_azure_openai_api_base,
            model_name=settings.embedding_azure_openai_model_name,
            api_key=settings.embedding_azure_openai_api_key,
        )
    else:
        embedding = OpenAITextAda002(settings.openai_api_key)
    index = QDrantVectorStore(
        embedding=embedding,
        client=QdrantClient(
            url=settings.qdrant_url,
            api_key=settings.qdrant_api_key,
        ),
        collection_name='test_collection',
    )
    index.create_collection_if_not_exists()

    return settings, storage, embedding, index
30
+
31
def initialize_di_for_app() -> tuple[Settings, Storage,Embedding,Index]:
    """Build the dependency graph for the running app.

    Uses .env, the .local_storage directory and the 'collection' qdrant
    collection. Unlike the test initializer, collection creation is left to
    the index's lazy create_collection_if_not_exists on first upload.
    Returns (settings, storage, embedding, index).
    """
    settings = Settings(_env_file='.env')
    storage = LocalStorage('.local_storage')
    # Choose the embedding backend from configuration.
    if settings.embedding_use_azure:
        embedding = AzureOpenAITextAda002(
            api_base=settings.embedding_azure_openai_api_base,
            model_name=settings.embedding_azure_openai_model_name,
            api_key=settings.embedding_azure_openai_api_key,
        )
    else:
        embedding = OpenAITextAda002(settings.openai_api_key)
    index = QDrantVectorStore(
        embedding=embedding,
        client=QdrantClient(
            url=settings.qdrant_url,
            api_key=settings.qdrant_api_key,
        ),
        collection_name='collection',
    )

    return settings, storage, embedding, index
embedding.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+
3
+ class Embedding:
4
+ type: str|None = None
5
+ vector_size: int|None = None
6
+ def generate_embedding(self, content: str) -> list[float]:
7
+ pass
8
+
9
class OpenAITextAda002(Embedding):
    """text-embedding-ada-002 via the public OpenAI API (1536-dim vectors)."""
    type: str = 'text-ada-002'
    vector_size: int = 1536
    api_key: str = None  # OpenAI API key supplied at construction

    def __init__(self, api_key: str):
        self.api_key = api_key

    def generate_embedding(self, content: str) -> list[float]:
        """Return the ada-002 embedding for `content` (one blocking API call)."""
        # replace newline with space
        content = content.replace('\n', ' ')
        # NOTE(review): the original comment said "limit to 8192 characters",
        # but the code keeps only the first 6000 characters.
        content = content[:6000]
        return openai.Embedding.create(
            api_key=self.api_key,
            api_type='openai',
            input = content,
            model="text-embedding-ada-002"
        )["data"][0]["embedding"]
28
+
29
class AzureOpenAITextAda002(Embedding):
    """text-embedding-ada-002 served from an Azure OpenAI deployment."""
    type: str = 'text-ada-002'
    vector_size: int = 1536
    api_key: str = None  # Azure OpenAI API key

    def __init__(
            self,
            api_base: str,
            model_name: str,
            api_key: str):
        """api_base: Azure endpoint URL; model_name: deployment (engine) name."""
        # Fix: the original assigned self.api_key twice.
        self.api_key = api_key
        self.model_name = model_name
        self.api_base = api_base

    def generate_embedding(self, content: str) -> list[float]:
        """Return the ada-002 embedding for `content` via the Azure endpoint."""
        # replace newline with space
        content = content.replace('\n', ' ')
        # keep only the first 6000 characters (same truncation as the
        # non-Azure backend)
        content = content[:6000]
        return openai.Embedding.create(
            api_key=self.api_key,
            api_type='azure',
            api_base=self.api_base,
            input = content,
            engine=self.model_name,
            api_version="2023-07-01-preview"
        )["data"][0]["embedding"]
index.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from qdrant_client import QdrantClient
2
+ from qdrant_client.http.models import ScoredPoint
3
+
4
+ from embedding import Embedding
5
+ from model.document import Document
6
+ from model.record import Record
7
+ from model.user import User
8
+ from qdrant_client.http import models
9
+ import uuid
10
+ import tqdm
11
+
12
+ class Index:
13
+ type: str
14
+
15
+ def load_or_update_document(self, user: User, document: Document, progress: tqdm.tqdm = None):
16
+ pass
17
+
18
+ def remove_document(self, user: User, document: Document):
19
+ pass
20
+
21
+ def query_index(self, user: User, query: str, top_k: int = 10, threshold: float = 0.5) -> list[Record]:
22
+ pass
23
+
24
+ def query_document(self, user: User, document: Document, query: str, top_k: int = 10, threshold: float = 0.5) -> list[Record]:
25
+ pass
26
+
27
+ def contains(self, user: User, document: Document) -> bool:
28
+ pass
29
+
30
class QDrantVectorStore(Index):
    """Index implementation backed by one shared Qdrant collection.

    Records are scoped by two payload fields: `group_id` (the user name)
    and `document_id` (the document name).
    """
    _client: QdrantClient
    _embedding: Embedding
    collection_name: str
    batch_size: int = 10  # number of points per upsert request
    type: str = 'qdrant'

    def __init__(
            self,
            client: QdrantClient,
            embedding: Embedding,
            collection_name: str):
        self._embedding = embedding
        self.collection_name = collection_name
        self._client = client

    def _response_to_records(self, response: list[ScoredPoint]) -> list[Record]:
        # NOTE(review): this is a generator despite the list annotation;
        # callers must iterate or materialize it.
        for point in response:
            meta_data = point.payload['meta_data']
            yield Record(
                embedding=point.vector,
                meta_data= meta_data,
                content=point.payload['content'],
                document_id=point.payload['document_id'],
                timestamp=point.payload['timestamp'],
            )

    def create_collection(self):
        """Create (or wipe and recreate) the backing collection.

        recreate_collection drops an existing collection of the same name.
        """
        self._client.recreate_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=self._embedding.vector_size,
                distance=models.Distance.COSINE),
        )

    def if_collection_exists(self) -> bool:
        """Probe for the collection; any client error is treated as 'absent'."""
        try:
            self._client.get_collection(self.collection_name)
            return True
        except Exception:
            return False

    def create_collection_if_not_exists(self):
        # Idempotent: only creates when the probe above fails.
        if not self.if_collection_exists():
            self.create_collection()

    def load_or_update_document(self, user: User, document: Document, progress: tqdm.tqdm = None):
        """(Re)index `document`: drop existing records, then upsert in batches.

        `progress`, when given, is a tqdm-style callable wrapping the batch
        range.
        """
        self.create_collection_if_not_exists()

        # Replace, don't append: drop any previously indexed version first.
        if self.contains(user, document):
            self.remove_document(user, document)

        group_id = user.user_name
        # upsert records in batch
        records = document.load_records()
        records = list(records)  # materialize the record generator once

        batch_range = range(0, len(records), self.batch_size)
        if progress is not None:
            batch_range = progress(batch_range)
        for i in batch_range:
            batch = records[i:i+self.batch_size]
            # Fresh random point ids; identity is tracked via payload fields.
            uuids = [str(uuid.uuid4()) for _ in batch]
            payloads = [{
                'content': record.content,
                'meta_data': record.meta_data,
                'document_id': record.document_id,
                'group_id': group_id,
                'timestamp': record.timestamp,
            } for record in batch]
            vectors = [record.embedding for record in batch]
            self._client.upsert(
                collection_name=self.collection_name,
                points=models.Batch(
                    payloads=payloads,
                    ids=uuids,
                    vectors=vectors,
                ),
            )

    def remove_document(self, user: User, document: Document):
        """Delete all points matching this document AND this user."""
        if not self.if_collection_exists():
            return

        document_id = document.name
        self._client.delete(
            collection_name=self.collection_name,
            points_selector=models.FilterSelector(
                filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id",
                            match=models.MatchValue(value=document_id)
                        ),
                        models.FieldCondition(
                            key="group_id",
                            match=models.MatchValue(
                                value=user.user_name,
                            ),
                        )
                    ]
                )
            )
        )

    def contains(self, user: User, document: Document) -> bool:
        """Exact-count probe: True when the document has any points for the user."""
        document_id = document.name
        group_id = user.user_name

        count = self._client.count(
            collection_name=self.collection_name,
            count_filter=models.Filter(
                must=[
                    models.FieldCondition(
                        key="document_id",
                        match=models.MatchValue(value=document_id)
                    ),
                    models.FieldCondition(
                        key="group_id",
                        match=models.MatchValue(
                            value=group_id,
                        ),
                    )
                ]
            ),
            exact=True,
        )

        return count.count > 0

    def query_index(self, user: User, query: str, top_k: int = 10, threshold: float = 0.5) -> list[Record]:
        """Vector search across all of `user`'s documents; [] if no collection."""
        if not self.if_collection_exists():
            return []

        response = self._client.search(
            collection_name=self.collection_name,
            query_vector=self._embedding.generate_embedding(query),
            limit=top_k,
            query_filter= models.Filter(
                must=[
                    models.FieldCondition(
                        key="group_id",
                        match=models.MatchValue(
                            value=user.user_name,
                        ),
                    )
                ]
            ),
            score_threshold=threshold,
        )

        return self._response_to_records(response)

    def query_document(self, user: User, document: Document, query: str, top_k: int = 10, threshold: float = 0.5) -> list[Record]:
        """Vector search restricted to one document of `user`; [] if no collection."""
        if not self.if_collection_exists():
            return []

        response = self._client.search(
            collection_name=self.collection_name,
            query_vector=self._embedding.generate_embedding(query),
            limit=top_k,
            query_filter= models.Filter(
                must=[
                    models.FieldCondition(
                        key="document_id",
                        match=models.MatchValue(value=document.name)
                    ),
                    models.FieldCondition(
                        key="group_id",
                        match=models.MatchValue(value=user.user_name),
                    )
                ]
            ),
            score_threshold=threshold,
        )

        return self._response_to_records(response)
model/document.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from .record import Record
3
+ from storage import Storage
4
+ from embedding import Embedding
5
+ import time
6
+ import json
7
+
8
class Document(BaseModel):
    # A user-owned source file that can be split into embeddable Records.
    name: str                       # also used as the index's document_id
    description: str | None = None
    status: str = 'uploading' # uploading, processing, done, failed
    url: str | None = None          # storage key where the raw content lives

    # Injected collaborators (pydantic private attrs, not serialized).
    _embedding: Embedding
    _storage: Storage

    def load_records(self) -> list[Record]:
        """Produce Record objects for this document; implemented by subclasses."""
        pass
19
+
20
class PlainTextDocument(Document):
    """Document whose stored content is plain text, one record per line."""

    def __init__(
            self,
            embedding: Embedding,
            storage: Storage,
            **kwargs):
        super().__init__(**kwargs)
        self._embedding = embedding
        self._storage = storage

    def _enhance_line(self, line: str) -> str:
        """Hook for subclasses to rewrite a line before it is embedded."""
        return line

    def load_records(self) -> list[Record]:
        """Yield one embedded Record per non-empty line of the stored text.

        NOTE: this is a generator despite the list annotation; callers
        materialize it with list().
        """
        # Fix: the original bound the file content to the name `str`,
        # shadowing the builtin.
        text = self._storage.load(self.url)
        lines = text.split('\n')

        for i, line in enumerate(lines):
            # remove empty lines
            if len(line.strip()) == 0:
                continue
            enhance_line = self._enhance_line(line)
            embedding = self._embedding.generate_embedding(enhance_line)
            meta_data = {
                'embedding_type': self._embedding.type,
                'document_id': self.name,
                'line_number': i,
                'source': line,
            }

            yield Record(
                embedding=embedding,
                meta_data=meta_data,
                content=line,
                document_id=self.name,
                timestamp=int(time.time()))
57
+
58
class JsonDocument(Document):
    """Document whose stored content is a JSON array of record objects."""

    def __init__(
            self,
            embedding: Embedding,
            storage: Storage,
            **kwargs):
        super().__init__(**kwargs)
        self._embedding = embedding
        self._storage = storage

    def load_records(self) -> list[Record]:
        '''
        Yield one embedded Record per item of the stored JSON array.

        json format:
        {
            'content': str // the content of the record
            'meta_data': dict // the meta data of the record
        }

        NOTE: this is a generator despite the list annotation; callers
        materialize it with list().
        '''
        # Fix: the original bound the file content to the name `str`,
        # shadowing the builtin.
        raw = self._storage.load(self.url)
        records = json.loads(raw)
        for i, item in enumerate(records):
            # crude rate limiting for the embedding API: sleep 300ms per item
            time.sleep(0.3)
            embedding = self._embedding.generate_embedding(item['content'])
            meta_data = {
                'embedding_type': self._embedding.type,
                'document_id': self.name,
                'line_number': i,
                'source': item['content'],
            }
            if 'meta_data' in item:
                # merge item-supplied metadata; generated keys win on clash
                meta_data = {**item['meta_data'], **meta_data}

            yield Record(
                embedding=embedding,
                meta_data=meta_data,
                content=item['content'],
                document_id=self.name,
                timestamp=int(time.time()))
model/record.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
class Record(BaseModel):
    # One embeddable chunk of a document, as stored in the vector index.
    content: str                          # raw text of the chunk
    embedding: list[float] | None = None  # embedding vector (dim set by provider)
    document_id: str | None = None        # owning document's name
    meta_data: dict | None = None         # arbitrary metadata payload
    timestamp: int | None = None          # unix time the record was created
model/user.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from .document import Document
3
+
4
class User(BaseModel):
    # Account record persisted to user.json by the app.
    user_name: str   # doubles as the auth token and the index group_id
    email: str
    full_name: str
    # Annotation fix: the defaults are None, so these fields are Optional.
    disabled: bool | None = None

    documents: list[Document] | None = None  # documents the user has uploaded
11
+
12
+
13
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastapi[all]==0.103.1
2
+ openai==0.28.0
3
+ python-dotenv==1.0.0
4
+ qdrant-client==1.5.2
setting.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic_settings import BaseSettings, SettingsConfigDict
2
+
3
class Settings(BaseSettings):
    """Environment-driven configuration (loaded from .env by default)."""
    # API credentials
    openai_api_key: str | None = None
    azure_openai_api_key: str | None = None
    # Qdrant connection; di.py uses url + api_key
    qdrant_api_key: str | None = None
    qdrant_url: str | None = None
    qdrant_host: str | None = None
    qdrant_port: int | None = None


    # embedding setting
    embedding_use_azure: bool = False  # True selects the Azure backend
    embedding_azure_openai_api_base: str | None = None
    embedding_azure_openai_model_name: str | None = None
    embedding_azure_openai_api_key: str | None = None

    model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8')
setup.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# setup
# Packaging script; the pinned runtime dependencies mirror requirements.txt.
from setuptools import setup
setup(
    name='llm_memory',
    version='1.0',
    author='LittleLittleCloud',
    # NOTE(review): the codebase uses `X | Y` union annotations (3.10+), so
    # the >=3.7 floor here looks too low -- confirm the intended minimum.
    python_requires='>=3.7, <4',
    install_requires=[
        'fastapi[all]==0.103.1',
        'openai==0.28.0',
        'python-dotenv==1.0.0',
        'qdrant-client==1.5.2',
    ])
storage.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
class Storage:
    """Abstract key/value file store."""

    def save(self, filename, data):
        '''
        Save or update a file
        '''
        pass

    def delete(self, filename):
        '''
        Delete a file
        '''
        pass

    def load(self, filename)->str:
        '''
        Load a file
        '''
        pass

    def list(self)->list[str]:
        '''
        List all files
        '''
        pass


class LocalStorage(Storage):
    """Storage backed by a single local directory (created on demand)."""

    def __init__(self, root):
        # Make sure the backing directory exists before any save/load.
        if not os.path.exists(root):
            os.makedirs(root)
        self.root = root

    def save(self, filename, data):
        '''
        Save or update a file; accepts str or bytes.
        '''
        # Robustness fix: callers may hand us bytes (e.g. raw uploaded file
        # content); the original text-mode write raised TypeError for bytes.
        path = os.path.join(self.root, filename)
        if isinstance(data, (bytes, bytearray)):
            with open(path, 'wb') as f:
                f.write(data)
        else:
            with open(path, 'w', encoding='utf-8') as f:
                f.write(data)

    def delete(self, filename):
        '''
        Delete a file; raises FileNotFoundError when absent.
        '''
        os.remove(os.path.join(self.root, filename))

    def load(self, filename):
        '''
        Load a file as UTF-8 text.
        '''
        with open(os.path.join(self.root, filename), 'r', encoding='utf-8') as f:
            return f.read()

    def list(self):
        '''
        List all stored file names.
        '''
        return os.listdir(self.root)