TimurZav committed
Commit
3fca975
1 Parent(s): bef6da0

Create app.py

Files changed (1): app.py (+412, -0)
app.py ADDED
@@ -0,0 +1,412 @@
import tempfile
import gradio as gr
from __init__ import *
from llama_cpp import Llama
from chromadb.config import Settings
from typing import List, Optional, Tuple
from langchain.vectorstores import Chroma
from langchain.docstore.document import Document
from huggingface_hub.file_download import http_get
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
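
# NOTE: `from __init__ import *` must provide `os` and the project constants
# used throughout this file: MODELS_DIR, DICT_REPO_AND_MODELS, EMBEDDER_NAME,
# LOADER_MAPPING, SYSTEM_PROMPT, ROLE_TOKENS, LINEBREAK_TOKEN, BOT_TOKEN,
# MAX_NEW_TOKENS, FAVICON_PATH and BLOCK_CSS.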


class LocalChatGPT:
    def __init__(self):
        self.llama_model: Optional[Llama] = None
        self.embeddings: HuggingFaceEmbeddings = self.initialize_app()

    def initialize_app(self) -> HuggingFaceEmbeddings:
        """
        Download (if necessary) the first model in the list, load it,
        and initialize the embedder.
        :return: embeddings object used to build the vector index.
        """
        os.makedirs(MODELS_DIR, exist_ok=True)
        model_url, model_name = list(DICT_REPO_AND_MODELS.items())[0]
        final_model_path = os.path.join(MODELS_DIR, model_name)
        os.makedirs(os.path.dirname(final_model_path), exist_ok=True)

        if not os.path.exists(final_model_path):
            with open(final_model_path, "wb") as f:
                http_get(model_url, f)

        self.llama_model = Llama(
            model_path=final_model_path,
            n_ctx=2000,
            n_parts=1,
        )

        return HuggingFaceEmbeddings(model_name=EMBEDDER_NAME, cache_folder=MODELS_DIR)

    def load_model(self, model_name):
        """
        Download (if necessary) and load the model selected in the dropdown.
        :param model_name: file name of the model.
        :return: the name of the loaded model.
        """
        final_model_path = os.path.join(MODELS_DIR, model_name)
        os.makedirs(os.path.dirname(final_model_path), exist_ok=True)

        if not os.path.exists(final_model_path):
            with open(final_model_path, "wb") as f:
                if model_url := [url for url, name in DICT_REPO_AND_MODELS.items() if name == model_name]:
                    http_get(model_url[0], f)

        self.llama_model = Llama(
            model_path=final_model_path,
            n_ctx=2000,
            n_parts=1,
        )
        return model_name

    @staticmethod
    def load_single_document(file_path: str) -> Document:
        """
        Load a single document.
        :param file_path: path to the document.
        :return: the loaded Document.
        """
        ext: str = "." + file_path.rsplit(".", 1)[-1]
        assert ext in LOADER_MAPPING, f"Unsupported file extension: {ext}"
        loader_class, loader_args = LOADER_MAPPING[ext]
        loader = loader_class(file_path, **loader_args)
        return loader.load()[0]
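
    # The two helpers below frame every message with special role tokens and a
    # line-break token, in the style of Saiga-like chat models. They assume
    # model.tokenize() places a BOS token at position 0, which is why the role
    # token is inserted right after it (an inference from the code, not a fact
    # stated in this commit).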

    @staticmethod
    def get_message_tokens(model: Llama, role: str, content: str) -> list:
        """
        Tokenize a chat message and frame it with role and line-break tokens.
        :param model: the loaded Llama model.
        :param role: message role ("system", "user" or "bot").
        :param content: message text.
        :return: token ids of the framed message.
        """
        message_tokens: list = model.tokenize(content.encode("utf-8"))
        message_tokens.insert(1, ROLE_TOKENS[role])
        message_tokens.insert(2, LINEBREAK_TOKEN)
        message_tokens.append(model.token_eos())
        return message_tokens

    def get_system_tokens(self, model: Llama) -> list:
        """
        Build the token sequence for the system prompt.
        :param model: the loaded Llama model.
        :return: token ids of the system message.
        """
        system_message: dict = {"role": "system", "content": SYSTEM_PROMPT}
        return self.get_message_tokens(model, **system_message)

    @staticmethod
    def upload_files(files: List[tempfile.TemporaryFile]) -> List[str]:
        """
        Collect the paths of the files uploaded through the UI.
        :param files: uploaded file objects.
        :return: list of file paths.
        """
        return [f.name for f in files]
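
    # The cleaning heuristic below drops lines with fewer than three non-space
    # characters and rejects chunks shorter than ten characters, so page
    # numbers and similar stray artifacts do not pollute the index.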

    @staticmethod
    def process_text(text: str) -> Optional[str]:
        """
        Clean a text chunk before indexing.
        :param text: raw chunk text.
        :return: the cleaned text, or None if too little remains.
        """
        lines: list = text.split("\n")
        lines = [line for line in lines if len(line.strip()) > 2]
        text = "\n".join(lines).strip()
        return None if len(text) < 10 else text

    @staticmethod
    def update_text_db(
        db: Optional[Chroma],
        fixed_documents: List[Document],
        ids: List[str]
    ) -> Optional[Tuple[Chroma, str]]:
        """
        If the database already holds the same set of files, replace the
        stored fragments with the new ones.
        :param db: the current Chroma database, if any.
        :param fixed_documents: cleaned document chunks.
        :param ids: ids of the new chunks.
        :return: the updated database and a status message, or None.
        """
        if db:
            data: dict = db.get()
            files_db = {dict_data["source"].split("/")[-1] for dict_data in data["metadatas"]}
            files_load = {doc.metadata["source"].split("/")[-1] for doc in fixed_documents}
            if files_load == files_db:
                # db.delete([item for item in data['ids'] if item not in ids])
                # db.update_documents(ids, fixed_documents)

                db.delete(data["ids"])
                db.add_texts(
                    texts=[doc.page_content for doc in fixed_documents],
                    metadatas=[doc.metadata for doc in fixed_documents],
                    ids=ids
                )
                file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
                return db, file_warning
        return None

    def build_index(
        self,
        file_paths: List[str],
        db: Optional[Chroma],
        chunk_size: int,
        chunk_overlap: int
    ):
        """
        Split the uploaded documents into chunks and build the vector index.
        :param file_paths: paths of the uploaded files.
        :param db: the current Chroma database, if any.
        :param chunk_size: chunk size in characters.
        :param chunk_overlap: overlap between neighbouring chunks.
        :return: the database and a status message.
        """
        documents: List[Document] = [self.load_single_document(path) for path in file_paths]
        text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        documents = text_splitter.split_documents(documents)
        fixed_documents: List[Document] = []
        for doc in documents:
            doc.page_content = self.process_text(doc.page_content)
            if not doc.page_content:
                continue
            fixed_documents.append(doc)

        # One id per chunk: the source file name (without the .txt extension)
        # plus a running index.
        ids: List[str] = [
            f"{doc.metadata['source'].split('/')[-1].replace('.txt', '')}{i}"
            for i, doc in enumerate(fixed_documents, start=1)
        ]

        # If the same files are already indexed, update the fragments in place.
        if result := self.update_text_db(db, fixed_documents, ids):
            return result

        db = Chroma.from_documents(
            documents=fixed_documents,
            embedding=self.embeddings,
            ids=ids,
            client_settings=Settings(
                anonymized_telemetry=False,
                persist_directory="db"
            )
        )
        file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
        return db, file_warning

    @staticmethod
    def user(message, history):
        """
        Append the new user message to the chat history.
        :param message: text typed by the user.
        :param history: current chat history.
        :return: an empty textbox value and the updated history.
        """
        new_history = history + [[message, None]]
        return "", new_history

    @staticmethod
    def regenerate_response(history):
        """
        Clear the message box and re-submit the current history; the chained
        callbacks then retrieve context and regenerate the last answer.
        :param history: current chat history.
        :return: an empty message and the unchanged history.
        """
        return "", history

    @staticmethod
    def retrieve(history, db: Optional[Chroma], retrieved_docs):
        """
        Find the chunks most similar to the last user message.
        :param history: current chat history.
        :param db: the Chroma database, if any.
        :param retrieved_docs: previous value of the fragments textbox.
        :return: the retrieved fragments as a single string.
        """
        if db:
            last_user_message = history[-1][0]
            try:
                docs = db.similarity_search(last_user_message, k=4)
                # retriever = db.as_retriever(search_kwargs={"k": k_documents})
                # docs = retriever.get_relevant_documents(last_user_message)
            except RuntimeError:
                docs = db.similarity_search(last_user_message, k=1)
                # retriever = db.as_retriever(search_kwargs={"k": 1})
                # docs = retriever.get_relevant_documents(last_user_message)
            source_docs = set()
            for doc in docs:
                for content in doc.metadata.values():
                    source_docs.add(content.split("/")[-1])
            retrieved_docs = "\n\n".join(doc.page_content for doc in docs)
            retrieved_docs = f"Документ - {', '.join(source_docs)}.\n\n{retrieved_docs}"
        return retrieved_docs

    def bot(self, history, retrieved_docs):
        """
        Stream the model's answer to the last user message.
        :param history: current chat history.
        :param retrieved_docs: retrieved context fragments, if any.
        :return: yields the history with the partial answer filled in.
        """
        if not history:
            return
        tokens = self.get_system_tokens(self.llama_model)[:]
        tokens.append(LINEBREAK_TOKEN)

        # Only the user side of earlier turns is replayed; previous bot
        # replies are not fed back into the prompt.
        for user_message, _ in history[:-1]:
            message_tokens = self.get_message_tokens(model=self.llama_model, role="user", content=user_message)
            tokens.extend(message_tokens)

        last_user_message = history[-1][0]
        if retrieved_docs:
            last_user_message = f"Контекст: {retrieved_docs}\n\nИспользуя контекст, ответь на вопрос: " \
                                f"{last_user_message}"
        message_tokens = self.get_message_tokens(model=self.llama_model, role="user", content=last_user_message)
        tokens.extend(message_tokens)

        role_tokens = [self.llama_model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
        tokens.extend(role_tokens)
        generator = self.llama_model.generate(
            tokens,
            top_k=30,
            top_p=0.9,
            temp=0.1
        )

        partial_text = ""
        for i, token in enumerate(generator):
            if token == self.llama_model.token_eos() or (MAX_NEW_TOKENS is not None and i >= MAX_NEW_TOKENS):
                break
            partial_text += self.llama_model.detokenize([token]).decode("utf-8", "ignore")
            history[-1][1] = partial_text
            yield history
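
    # run() wires the UI: submitting a message first appends it to the history
    # (user), then pulls context from the index (retrieve), then streams the
    # answer (bot); the steps are chained with .success() so each one runs
    # only if the previous step succeeded.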

    def run(self):
        """
        Build the Gradio interface and launch the application.
        :return: nothing; blocks while the server is running.
        """
        with gr.Blocks(theme=gr.themes.Soft(), css=BLOCK_CSS) as demo:
            db: Optional[Chroma] = gr.State(None)
            favicon = f'<img src="{FAVICON_PATH}" width="48px" style="display: inline">'
            gr.Markdown(
                f"""<h1><center>{favicon} Я, Макар - текстовый ассистент на основе GPT</center></h1>"""
            )

            with gr.Row(elem_id="model_selector_row"):
                models: list = list(DICT_REPO_AND_MODELS.values())
                model_selector = gr.Dropdown(
                    choices=models,
                    value=models[0] if models else "",
                    interactive=True,
                    show_label=False,
                    container=False,
                )

            with gr.Row():
                with gr.Column(scale=5):
                    chatbot = gr.Chatbot(label="Диалог", height=400)
                with gr.Column(min_width=200, scale=4):
                    retrieved_docs = gr.Textbox(
                        label="Извлеченные фрагменты",
                        placeholder="Появятся после задавания вопросов",
                        interactive=False
                    )

            with gr.Row():
                with gr.Column(scale=20):
                    msg = gr.Textbox(
                        label="Отправить сообщение",
                        show_label=False,
                        placeholder="Отправить сообщение",
                        container=False
                    )
                with gr.Column(scale=3, min_width=100):
                    submit = gr.Button("📤 Отправить", variant="primary")

            with gr.Row():
                # gr.Button(value="👍 Понравилось")
                # gr.Button(value="👎 Не понравилось")
                stop = gr.Button(value="⛔ Остановить")
                regenerate = gr.Button(value="🔄 Повторить")
                clear = gr.Button(value="🗑️ Очистить")

            # # Upload files
            # file_output.upload(
            #     fn=self.upload_files,
            #     inputs=[file_output],
            #     outputs=[file_paths],
            #     queue=True,
            # ).success(
            #     fn=self.build_index,
            #     inputs=[file_paths, db, chunk_size, chunk_overlap],
            #     outputs=[db, file_warning],
            #     queue=True
            # )

            model_selector.change(
                fn=self.load_model,
                inputs=[model_selector],
                outputs=[model_selector]
            )

            # Pressing Enter
            submit_event = msg.submit(
                fn=self.user,
                inputs=[msg, chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            # Pressing the button
            submit_click_event = submit.click(
                fn=self.user,
                inputs=[msg, chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            # Stop generation
            stop.click(
                fn=None,
                inputs=None,
                outputs=None,
                cancels=[submit_event, submit_click_event],
                queue=False,
            )

            # Regenerate
            regenerate.click(
                fn=self.regenerate_response,
                inputs=[chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            # Clear history
            clear.click(lambda: None, None, chatbot, queue=False)

        demo.queue(max_size=128, default_concurrency_limit=10, api_open=False)
        demo.launch(server_name="0.0.0.0", max_threads=200)


if __name__ == "__main__":
    local_chat_gpt = LocalChatGPT()
    local_chat_gpt.run()
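
# Usage note: `python app.py` downloads the default model on first start and
# serves the UI on all interfaces (server_name="0.0.0.0"); Gradio's default
# port 7860 applies, since launch() does not override it.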