cawacci committed on
Commit
ff614e3
1 Parent(s): 6b39381

ver 0.7 for test

Files changed (2)
  1. app.py +869 -0
  2. requirements.txt +18 -0
app.py ADDED
@@ -0,0 +1,869 @@
# --------------------------------------
# Libraries
# --------------------------------------
import os
import time
import gc  # memory release
import re  # regex for text cleanup
import requests  # HTTP client (used by the DeepL translation helper below)

# HuggingFace
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# OpenAI
import openai
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI

# LangChain
from langchain.llms import HuggingFacePipeline
from transformers import pipeline

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import VectorDBQA
from langchain.vectorstores import Chroma

from langchain import PromptTemplate, ConversationChain
from langchain.chains.question_answering import load_qa_chain  # QA chat
from langchain.document_loaders import SeleniumURLLoader  # URL fetching
from langchain.docstore.document import Document  # wrap raw text as documents
# from langchain.memory import ConversationBufferWindowMemory  # chat history
from langchain.memory import ConversationSummaryBufferMemory  # chat history

from typing import Any
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Gradio
import gradio as gr

# PyPdf
from pypdf import PdfReader

# test
import langchain  # (needed to set langchain.debug = True)

# --------------------------------------
# Class that stores per-user session state
# (Reference) https://blog.shikoan.com/gradio-state/
# --------------------------------------
class SessionState:
    def __init__(self):
        # Hugging Face
        self.tokenizer = None
        self.pipe = None
        self.model = None

        # LangChain
        self.llm = None
        self.embeddings = None
        self.current_model = ""
        self.current_embedding = ""
        self.db = None      # Vector DB
        self.memory = None  # LangChain chat memory
        self.qa_chain = None            # load_qa_chain
        self.conversation_chain = None  # ConversationChain
        self.embedded_urls = []

        # Apps
        self.dialogue = []  # recent chat history for display

    # --------------------------------------
    # Empty Cache
    # --------------------------------------
    def cache_clear(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # GPU memory clear

        gc.collect()  # CPU memory clear

    # --------------------------------------
    # Clear Models (llm: llm model, embd: embeddings, db: vectordb)
    # --------------------------------------
    def clear_memory(self, llm=False, embd=False, db=False):
        # DB
        if db and self.db:
            self.db.delete_collection()
            self.db = None
            self.embedded_urls = []

        # Embeddings model
        if llm or embd:
            self.embeddings = None
            self.current_embedding = ""
            self.qa_chain = None

        # LLM model
        if llm:
            self.llm = None
            self.pipe = None
            self.model = None
            self.current_model = ""
            self.tokenizer = None
            self.memory = None
            self.chat_history = []  # TODO: verify whether this attribute is still needed

        self.cache_clear()

    # --------------------------------------
    # Load Chat History as a list
    # --------------------------------------
    def load_chat_history(self) -> list:
        chat_history = []
        try:
            chat_memory = self.memory.load_memory_variables({})['chat_history']
        except KeyError:
            return chat_history

        # Read the chat history as user/AI pairs
        for i in range(0, len(chat_memory), 2):
            user_message = chat_memory[i].content
            ai_message = ""
            if i + 1 < len(chat_memory):
                ai_message = chat_memory[i + 1].content
            chat_history.append([user_message, ai_message])
        return chat_history

# --------------------------------------
# Custom TextSplitter (splits text to fit within the LLM's token limit)
# (Reference) https://www.sato-susumu.com/entry/2023/04/30/131338
#  -> Added separators such as "!", "?", ")", ".", "!", "?", ","
# --------------------------------------
class JPTextSplitter(RecursiveCharacterTextSplitter):
    def __init__(self, **kwargs: Any):
        separators = ["\n\n", "\n", "。", "!", "?", ")", "、", ".", "!", "?", ",", " ", ""]
        super().__init__(separators=separators, **kwargs)

# Chunking settings
chunk_size = 512
chunk_overlap = 35

text_splitter = JPTextSplitter(
    chunk_size=chunk_size,        # maximum characters per chunk
    chunk_overlap=chunk_overlap,  # maximum overlapping characters between chunks
)

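# A minimal usage sketch (with a hypothetical input string): RecursiveCharacterTextSplitter
# tries the separators in order, so Japanese sentence enders ("。", "!", "?") are preferred
# split points before it falls back to commas, spaces, and finally single characters.
#   docs = text_splitter.create_documents(["長い日本語のテキスト..."])
#   -> each chunk holds at most 512 characters and overlaps the previous one by up to 35
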
# --------------------------------------
# Translate memory with DeepL to reduce the token count (when using an OpenAI model)
# --------------------------------------
DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
DEEPL_API_KEY = "YOUR_DEEPL_API_KEY"

def deepl_memory(ss: SessionState) -> SessionState:
    if ss.current_model == "gpt-3.5-turbo":
        # Fetch the latest user/AI exchange from memory
        # (chat_memory.messages is a flat list, so the last pair is [-2] and [-1])
        user_message = ss.memory.chat_memory.messages[-2].content
        ai_message = ss.memory.chat_memory.messages[-1].content
        text = [user_message, ai_message]

        # DeepL request parameters
        params = {
            "auth_key": DEEPL_API_KEY,
            "text": text,
            "target_lang": "EN",
            "source_lang": "JA"
        }
        request = requests.post(DEEPL_API_ENDPOINT, data=params)
        request.raise_for_status()  # raise an exception if the response status code is an error
        response = request.json()

        # Extract the translations from the JSON response
        user_message = response["translations"][0]["text"]
        ai_message = response["translations"][1]["text"]

        # Drop the last exchange from memory and append the translated pair
        ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-2]
        ss.memory.chat_memory.add_user_message(user_message)
        ss.memory.chat_memory.add_ai_message(ai_message)

    return ss

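# Note: the DeepL v2 API accepts multiple "text" values in a single request and returns
# the translations in the same order, which is why translations[0] and translations[1]
# map back onto the user and AI messages above. DEEPL_API_KEY is a placeholder; supply
# a real key (ideally via an environment variable) before calling this helper.
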
# --------------------------------------
# Assorted LangChain custom prompts
# llama tokenizer
# https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/

# OpenAI tokenizer
# https://platform.openai.com/tokenizer
# --------------------------------------

# --------------------------------------
# Conversation Chain Template
# --------------------------------------

# Tokens: OpenAI 104/ Llama 105 <- In Japanese: Tokens: OpenAI 191/ Llama 162
sys_chat_message = """
The following is a conversation between an AI concierge and a customer.
The AI understands what the customer wants to know from the conversation history and the latest question,
and gives many specific details in Japanese. If the AI does not know the answer to a question, it does not
make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

chat_common_format = """
===
Question: {query}
===
Conversation History:
{chat_history}
===
日本語の回答:"""

chat_template_std = f"{sys_chat_message}{chat_common_format}"
chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common_format}[/INST]"

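# The llama2 variant wraps the same content in Llama-2's chat markup: the system prompt
# sits between <<SYS>> and <</SYS>>, and the whole turn is enclosed in [INST] ... [/INST],
# matching the format the ELYZA Llama-2 instruct models were fine-tuned on.
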
# --------------------------------------
# QA Chain Template
# --------------------------------------
# Tokens: OpenAI 113/ Llama 111 <- In Japanese: Tokens: OpenAI 256/ Llama 225
sys_qa_message = """
You are an AI concierge who carefully answers questions from customers based on references.
You understand what the customer wants to know from the "Conversation History" and "Question",
and give a specific answer in Japanese using sentences extracted from the following references.
If you do not know the answer, do not make up an answer and reply,
"誠に申し訳ございませんが、その点についてはわかりかねます".
""".replace("\n", "")

qa_common_format = """
===
Question:
{query}
===
References:
{context}
===
Conversation History:
{chat_history}
===
日本語の回答:"""

qa_template_std = f"{sys_qa_message}{qa_common_format}"
qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"

# --------------------------------------
# Summarization prompt for ConversationSummaryBufferMemory
# Source: https://github.com/langchain-ai/langchain/blob/894c272a562471aadc1eb48e4a2992923533dea0/langchain/memory/prompt.py#L26-L49
# --------------------------------------
# Tokens: OpenAI 212/ Llama 214 <- In Japanese: Tokens: OpenAI 397/ Llama 297
conversation_summary_template = """
Using the example as a guide, compose a summary in English that gives an overview of the conversation by summarizing the "current summary" and the "new conversation".
===
Example
[Current Summary] Customer asks AI what it thinks about Artificial Intelligence, AI says Artificial Intelligence is a good tool.

[New Conversation]
Human: なぜ人工知能が良いツールだと思いますか?
AI: 人工知能は「人間の可能性を最大限に引き出すことを助ける」からです。

[New Summary] Customer asks what you think about Artificial Intelligence, and AI responds that it is a good force that helps humans reach their full potential.
===
[Current Summary] {summary}

[New Conversation]
{new_lines}

[New Summary]
""".strip()

# Load models
def load_models(
    ss: SessionState,
    model_id: str,
    embedding_id: str,
    openai_api_key: str,
    load_in_8bit: bool,
    verbose: bool,
    temperature: float,
    min_length: int,
    max_new_tokens: int,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    num_return_sequences: int,
) -> (SessionState, str):

    # --------------------------------------
    # Check the OpenAI API KEY
    # --------------------------------------
    if model_id == "gpt-3.5-turbo" or embedding_id == "text-embedding-ada-002":
        # os.environ.get avoids a KeyError when the variable has never been set
        if not os.environ.get("OPENAI_API_KEY"):
            status_message = "❌ OpenAI API KEY を設定してください"
            return ss, status_message

    # --------------------------------------
    # LLM setup
    # --------------------------------------
    # OpenAI Model
    if model_id == "gpt-3.5-turbo":
        ss.clear_memory(llm=True, db=True)
        ss.llm = ChatOpenAI(
            model_name=model_id,
            temperature=temperature,
            verbose=verbose,
            max_tokens=max_new_tokens,
        )

    # Hugging Face GPT Model
    else:
        ss.clear_memory(llm=True, db=True)

        if model_id == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
        else:
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id)

        ss.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        ss.pipe = pipeline(
            "text-generation",
            model=ss.model,
            tokenizer=ss.tokenizer,
            min_length=min_length,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_return_sequences,
            temperature=temperature,
        )
        ss.llm = HuggingFacePipeline(pipeline=ss.pipe)

    # --------------------------------------
    # Embedding model setup
    # (skip re-initialization when the embedding model is unchanged, so that the
    #  model names and status message below are always returned)
    # --------------------------------------
    if ss.current_embedding != embedding_id:

        # Reset embeddings and vectordb
        ss.clear_memory(embd=True, db=True)

        if embedding_id == "None":
            pass

        # OpenAI
        elif embedding_id == "text-embedding-ada-002":
            ss.embeddings = OpenAIEmbeddings()

        # Hugging Face
        else:
            ss.embeddings = HuggingFaceEmbeddings(model_name=embedding_id)

    # --------------------------------------
    # Save the current model names to the SessionState object
    # --------------------------------------
    ss.current_model = model_id
    ss.current_embedding = embedding_id

    # Status Message
    status_message = "✅ LLM: " + ss.current_model + ", embeddings: " + ss.current_embedding

    return ss, status_message

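# Note: load_in_8bit relies on the bitsandbytes package, and device_map="auto" on
# accelerate (both are listed in requirements.txt). 8-bit quantization roughly halves
# the GPU memory footprint compared with float16, at a small cost in output quality.
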
def conversation_prep(ss: SessionState) -> SessionState:
    if ss.conversation_chain is None:

        human_prefix = "Human: "
        ai_prefix = "AI: "
        chat_template = chat_template_std

        if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            # Settings for the Rinna model: newline-token fix and memory prefixes (see the official model page)
            chat_template = chat_template.replace("\n", "<NL>")
            human_prefix = "ユーザー: "
            ai_prefix = "システム: "

        elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"):
            chat_template = chat_template_llama2

        chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)

        if ss.memory is None:
            conversation_summary_prompt = PromptTemplate(input_variables=['summary', 'new_lines'], template=conversation_summary_template)
            ss.memory = ConversationSummaryBufferMemory(
                llm=ss.llm,
                memory_key="chat_history",
                input_key="query",
                output_key="output_text",
                return_messages=True,
                human_prefix=human_prefix,
                ai_prefix=ai_prefix,
                max_token_limit=512,
                prompt=conversation_summary_prompt,
            )

        ss.conversation_chain = ConversationChain(
            llm=ss.llm,
            prompt=chat_prompt,
            memory=ss.memory,
            input_key="query",         # the prompt uses {query} rather than the default "input"
            output_key="output_text",  # match the shared memory's output_key (set for load_qa_chain)
        )

    return ss

def initialize_db(ss: SessionState) -> SessionState:

    # client = chromadb.PersistentClient(path="./db")
    ss.db = Chroma(
        collection_name="user_reference",
        embedding_function=ss.embeddings,
        # client = client
    )

    return ss

def embedding_process(ss: SessionState, ref_documents: list) -> SessionState:

    # --------------------------------------
    # Restructure text and remove unwanted strings
    # --------------------------------------
    for i in range(len(ref_documents)):
        content = ref_documents[i].page_content.strip()

        # --------------------------------------
        # For PDFs, apply stronger repairs to work around text-extraction errors
        # --------------------------------------
        if ".pdf" in ref_documents[i].metadata['source']:
            pdf_replacement_sets = [
                ('\n ', '**PLACEHOLDER+SPACE**'),
                ('\n\u3000', '**PLACEHOLDER+SPACE**'),
                ('.\n', '。**PLACEHOLDER**'),
                (',\n', '。**PLACEHOLDER**'),
                ('?\n', '。**PLACEHOLDER**'),
                ('!\n', '。**PLACEHOLDER**'),
                ('!\n', '。**PLACEHOLDER**'),
                ('。\n', '。**PLACEHOLDER**'),
                ('!\n', '!**PLACEHOLDER**'),
                (')\n', '!**PLACEHOLDER**'),
                (']\n', '!**PLACEHOLDER**'),
                ('?\n', '?**PLACEHOLDER**'),
                (')\n', '?**PLACEHOLDER**'),
                ('】\n', '?**PLACEHOLDER**'),
            ]
            for original, replacement in pdf_replacement_sets:
                content = content.replace(original, replacement)
            content = content.replace(" ", "")
        # --------------------------------------

        # Remove unwanted strings and whitespace
        remove_texts = ["\n", "\r", " "]
        for remove_text in remove_texts:
            content = content.replace(remove_text, "")

        # Convert tabs and full-width spaces to single spaces
        replace_texts = ["\t", "\u3000"]
        for replace_text in replace_texts:
            content = content.replace(replace_text, " ")

        # Restore the legitimate PDF line breaks
        if ".pdf" in ref_documents[i].metadata['source']:
            content = content.replace('**PLACEHOLDER**', '\n').replace('**PLACEHOLDER+SPACE**', '\n ')

        ref_documents[i].page_content = content

    # --------------------------------------
    # Split into chunks
    texts = text_splitter.split_documents(ref_documents)

    # --------------------------------------
    # Add the prefix used when the multilingual-e5 model was trained
    # https://hironsan.hatenablog.com/entry/2023/07/05/073150
    # --------------------------------------
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        for i in range(len(texts)):
            texts[i].page_content = "passage:" + texts[i].page_content

    # Initialize the vectordb
    if ss.db is None:
        ss = initialize_db(ss)

    # Embed into the db
    # ss.db = Chroma.from_documents(texts, ss.embeddings)
    ss.db.add_documents(documents=texts, embedding=ss.embeddings)

    # --------------------------------------
    # QA chain setup
    # --------------------------------------
    if ss.qa_chain is None:

        # QA memory
        human_prefix = "Human: "
        ai_prefix = "AI: "
        qa_template = qa_template_std

        if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            # Settings for the Rinna model: newline-token fix and memory prefixes (see the official model page)
            qa_template = qa_template.replace("\n", "<NL>")
            human_prefix = "ユーザー: "
            ai_prefix = "システム: "

        elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"):
            qa_template = qa_template_llama2

        qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template)

        if ss.memory is None:
            conversation_summary_prompt = PromptTemplate(input_variables=['summary', 'new_lines'], template=conversation_summary_template)
            ss.memory = ConversationSummaryBufferMemory(
                llm=ss.llm,
                memory_key="chat_history",
                input_key="query",
                output_key="output_text",
                return_messages=True,
                human_prefix=human_prefix,
                ai_prefix=ai_prefix,
                max_token_limit=512,
                prompt=conversation_summary_prompt,
            )

        ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)

    return ss

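# Note: the "passage:" prefix added above pairs with the "query: " prefix applied at
# search time in qa_predict. multilingual-e5 was trained with these markers, and
# omitting them degrades its retrieval quality.
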
def embed_ref(ss: SessionState, urls: str, fileobj: list, header_lim: int, footer_lim: int) -> (SessionState, str):

    url_flag = "-"
    pdf_flag = "-"

    # --------------------------------------
    # Load URLs and register them in the vectordb
    # --------------------------------------

    # Preprocess the URL list (split into a list, deduplicate, drop non-URLs)
    urls = list({url for url in urls.split("\n") if url and "://" in url})

    if urls:
        # Drop URLs already registered (ss.embedded_urls), then add the rest to the registered list
        urls = [url for url in urls if url not in ss.embedded_urls]
        ss.embedded_urls.extend(urls)

        # Load the web pages
        loader = SeleniumURLLoader(urls=urls)
        ref_documents = loader.load()

        # Run the embedding process
        ss = embedding_process(ss, ref_documents)

        url_flag = "✅ 登録済"

    # --------------------------------------
    # Strip PDF headers and footers, then register in the vectordb
    #  https://pypdf.readthedocs.io/en/stable/user/extract-text.html
    # --------------------------------------

    if fileobj is None:
        pass

    else:
        # Collect the uploaded file paths
        pdf_paths = []
        for path in fileobj:
            pdf_paths.append(path.name)

        # Initialize the document list
        ref_documents = []

        # Read each PDF file
        for pdf_path in pdf_paths:
            pdf = PdfReader(pdf_path)
            body = []

            def visitor_body(text, cm, tm, font_dict, font_size):
                y = tm[5]
                if y > footer_lim and y < header_lim:  # check whether the y coordinate lies between the header and footer
                    parts.append(text)

            for page in pdf.pages:
                parts = []
                page.extract_text(visitor_text=visitor_body)
                body.append("".join(parts))

            body = "\n".join(body)

            # Keep only the file name from the path
            filename = os.path.basename(pdf_path)
            # Convert the extracted text into a LangChain Document
            ref_documents.append(Document(page_content=body, metadata={"source": filename}))

        # Run the embedding process
        ss = embedding_process(ss, ref_documents)

        pdf_flag = "✅ 登録済"

    langchain.debug = True

    status_message = "URL: " + url_flag + " / PDF: " + pdf_flag
    return ss, status_message

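# Note on the visitor above: pypdf passes each text fragment's transformation matrix to
# visitor_text, and tm[5] is the fragment's y position in PDF user space, whose origin is
# the bottom-left corner of the page, so filtering on footer_lim < y < header_lim drops
# running headers and footers. The UI default of 0-792 pt corresponds to US Letter's
# height (A4 pages are about 842 pt tall), so adjust the limits to the paper size.
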
def clear_db(ss: SessionState) -> (SessionState, str):
    try:
        ss.db.delete_collection()
        status_message = "✅ 参照データを削除しました。"

    except AttributeError:  # ss.db is None when no reference data has been registered
        status_message = "❌ 参照データが登録されていません。"

    return ss, status_message

# ----------------------------------------------------------------------------
# query input ▶ [def user] ▶ [ def bot ] ▶ [def show_response] ▶ chatbot screen
#     ⬇             ⬇                             ⬆
# chatbot screen            [qa_predict / chat_predict]
# ----------------------------------------------------------------------------

def user(ss: SessionState, query) -> (SessionState, list):
    # Drop the oldest entry once the displayed history grows past a fixed size
    if len(ss.dialogue) > 10:
        ss.dialogue.pop(0)

    ss.dialogue = ss.dialogue + [(query, None)]  # chat history (None is the bot's still-empty reply slot)
    chat_history = ss.dialogue

    # The chatbot screen displays chat_history
    return ss, chat_history

def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
    if qa_flag is True:
        ss = qa_predict(ss, query)  # generate the answer with the LLM

    else:
        ss = conversation_prep(ss)
        ss = chat_predict(ss, query)

    return ss, ""  # ss and the (cleared) query box

def chat_predict(ss: SessionState, query) -> SessionState:
    response = ss.conversation_chain.predict(query=query)  # keyword matches the chain's input_key
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)
    return ss

def qa_predict(ss: SessionState, query) -> SessionState:

    # Setting for the Rinna model (fix the newline token in the query)
    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        query = query.strip().replace("\n", "<NL>")
    else:
        query = query.strip()

    # Query prefix for multilingual-e5
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        db_query_str = "query: " + query
    else:
        db_query_str = query

    # Retrieve related documents and their sources from the DB
    docs = ss.db.similarity_search(db_query_str, k=2)
    sources = "\n\n[Sources]\n" + '\n - '.join(list(set(doc.metadata['source'] for doc in docs if 'source' in doc.metadata)))

    # Setting for the Rinna model (fix the newline token in the retrieved documents)
    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        for i in range(len(docs)):
            docs[i].page_content = docs[i].page_content.strip().replace("\n", "<NL>")

    # Generate the answer (up to 3 attempts)
    for _ in range(3):
        result = ss.qa_chain({"input_documents": docs, "query": query})
        result["output_text"] = result["output_text"].replace("<NL>", "\n").strip("...").strip("回答:").strip()

        # If result["output_text"] is not empty, update the memory and return
        if result["output_text"] != "":
            response = result["output_text"] + sources
            ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-1]  # drop the last memory entry
            ss.memory.chat_memory.add_user_message(query)
            ss.memory.chat_memory.add_ai_message(response)
            ss.dialogue[-1] = (ss.dialogue[-1][0], response)
            return ss
        else:
            # If empty, drop the most recent history entry and retry
            ss.memory.chat_memory.messages = ss.memory.chat_memory.messages[:-1]

    # Still empty after three attempts
    response = "3回試行しましたが、回答を生成できませんでした。"
    if sources != "":
        response += "参考文献の抽出には成功していますので、言語モデルを変えてお試しください。"

    # Add the user message and the AI message
    ss.memory.chat_memory.add_user_message(query.replace("<NL>", "\n"))
    ss.memory.chat_memory.add_ai_message(response)
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)  # chat history
    return ss

# Display the answer one character at a time on the chat screen
def show_response(ss: SessionState) -> str:
    # chat_history = ss.load_chat_history()  # fetch the chat history from memory as a list
    # response = chat_history[-1][1]         # take the bot's answer [1] from the latest exchange [-1]
    # chat_history[-1][1] = ""               # blank the bot's answer [1] so it can be replayed gradually

    chat_history = [list(item) for item in ss.dialogue]  # convert the tuples to lists to get a mutable history
    response = chat_history[-1][1]  # take the bot's answer [1] from the latest exchange [-1]
    chat_history[-1][1] = ""        # blank the bot's answer [1] so it can be replayed gradually

    for character in response:
        chat_history[-1][1] += character
        time.sleep(0.05)
        yield chat_history

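# Because show_response is a generator, Gradio streams it: every yield re-renders the
# Chatbot component, so the reply appears to be typed out (0.05 s per character).
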
with gr.Blocks() as demo:

    # Instantiate per-user session memory (reset on page reload)
    ss = gr.State(SessionState())

    # --------------------------------------
    # Functions to set / clear the API KEY
    # --------------------------------------
    def openai_api_setfn(openai_api_key) -> str:
        if not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50:
            os.environ["OPENAI_API_KEY"] = ""
            status_message = "❌ 有効なAPIキーを入力してください"
            return status_message
        else:
            os.environ["OPENAI_API_KEY"] = openai_api_key
            status_message = "✅ APIキーを設定しました"
            return status_message

    def openai_api_clsfn(ss) -> (str, str):
        openai_api_key = ""
        os.environ["OPENAI_API_KEY"] = ""
        status_message = "✅ APIキーの削除が完了しました"
        return status_message, ""

    # --------------------------------------
    # Button to continue the answer
    # --------------------------------------
    def continue_pred():
        query = "回答を続けてください"
        return query

    with gr.Tabs():
        # --------------------------------------
        # Setting Tab
        # --------------------------------------
        with gr.TabItem("1. LLM設定"):
            with gr.Row():
                model_id = gr.Dropdown(
                    choices=[
                        'elyza/ELYZA-japanese-Llama-2-7b-fast-instruct',
                        'rinna/bilingual-gpt-neox-4b-instruction-sft',
                        'gpt-3.5-turbo',
                    ],
                    value="elyza/ELYZA-japanese-Llama-2-7b-fast-instruct",
                    label='LLM model',
                    interactive=True,
                )
            with gr.Row():
                embedding_id = gr.Dropdown(
                    choices=[
                        'intfloat/multilingual-e5-large',
                        'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
                        'oshizo/sbert-jsnli-luke-japanese-base-lite',
                        'text-embedding-ada-002',
                        "None"
                    ],
                    value="sonoisa/sentence-bert-base-ja-mean-tokens-v2",
                    label='Embedding model',
                    interactive=True,
                )
            with gr.Row():
                with gr.Column(scale=19):
                    openai_api_key = gr.Textbox(label="OpenAI API Key (Optional)", interactive=True, type="password", value="", placeholder="Your OpenAI API Key for OpenAI models.", max_lines=1)
                with gr.Column(scale=1):
                    openai_api_set = gr.Button(value="Set API KEY", size="sm")
                    openai_api_cls = gr.Button(value="Delete API KEY", size="sm")

            # Advanced settings (collapsible)
            with gr.Accordion(label="Advanced Setting", open=False):
                with gr.Row():
                    with gr.Column():
                        load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
                        verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=False)
                    with gr.Column():
                        temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
                    with gr.Column():
                        min_length = gr.Slider(label="min_length (HF)", minimum=1, maximum=100, step=1, value=10, interactive=True)
                    with gr.Column():
                        max_new_tokens = gr.Slider(label="max_tokens(OpenAI), max_new_tokens(HF)", minimum=1, maximum=1024, step=1, value=256, interactive=True)
                    with gr.Column():
                        top_k = gr.Slider(label='top_k (HF)', minimum=1, maximum=100, step=1, value=40, interactive=True)
                    with gr.Column():
                        top_p = gr.Slider(label='top_p (HF)', minimum=0.01, maximum=0.99, step=0.01, value=0.92, interactive=True)
                    with gr.Column():
                        repetition_penalty = gr.Slider(label='repetition_penalty (HF)', minimum=0.5, maximum=2, step=0.1, value=1.2, interactive=True)
                    with gr.Column():
                        num_return_sequences = gr.Slider(label='num_return_sequences (HF)', minimum=1, maximum=20, step=1, value=3, interactive=True)

            with gr.Row():
                with gr.Column(scale=2):
                    config_btn = gr.Button(value="Configure")
                with gr.Column(scale=13):
                    status_cfg = gr.Textbox(show_label=False, interactive=False, value="モデルを設定してください", container=False, max_lines=1)

            # Wire up the button actions
            openai_api_set.click(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            openai_api_cls.click(openai_api_clsfn, inputs=[openai_api_key], outputs=[status_cfg, openai_api_key], show_progress="full")
            openai_api_key.submit(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            config_btn.click(
                fn=load_models,
                inputs=[ss, model_id, embedding_id, openai_api_key, load_in_8bit, verbose, temperature,
                        min_length, max_new_tokens, top_k, top_p, repetition_penalty, num_return_sequences],
                outputs=[ss, status_cfg],
                queue=True,
                show_progress="full"
            )

        # --------------------------------------
        # Reference Tab
        # --------------------------------------
        with gr.TabItem("2. References"):
            urls = gr.TextArea(
                max_lines=60,
                show_label=False,
                info="List any reference URLs for Q&A retrieval.",
                placeholder="https://blog.kikagaku.co.jp/deep-learning-transformer\nhttps://note.com/elyza/n/na405acaca130",
                interactive=True,
            )

            with gr.Row():
                pdf_paths = gr.File(label="PDFs", height=150, min_width=60, scale=7, file_types=[".pdf"], file_count="multiple", interactive=True)
                header_lim = gr.Number(label="Header (pt)", step=1, value=792, precision=0, min_width=70, scale=1, interactive=True)
                footer_lim = gr.Number(label="Footer (pt)", step=1, value=0, precision=0, min_width=70, scale=1, interactive=True)
                pdf_ref = gr.Textbox(show_label=False, value="A4 Size:\n(下)0-792pt(上)\n *28.35pt/cm", container=False, scale=1, interactive=False)

            with gr.Row():
                ref_set_btn = gr.Button(value="コンテンツ登録", scale=1)
                ref_clear_btn = gr.Button(value="登録データ削除", scale=1)
                status_ref = gr.Textbox(show_label=False, interactive=False, value="参照データ未登録", container=False, max_lines=1, scale=18)

            ref_set_btn.click(fn=embed_ref, inputs=[ss, urls, pdf_paths, header_lim, footer_lim], outputs=[ss, status_ref], queue=True, show_progress="full")
            ref_clear_btn.click(fn=clear_db, inputs=[ss], outputs=[ss, status_ref], show_progress="full")

        # --------------------------------------
        # Chatbot Tab
        # --------------------------------------
        with gr.TabItem("3. Q&A Chat"):
            chat_history = gr.Chatbot([], elem_id="chatbot").style(height=600, color_map=('green', 'gray'))
            with gr.Row():
                with gr.Column(scale=95):
                    query = gr.Textbox(
                        show_label=False,
                        placeholder="Send a message with [Shift]+[Enter] key.",
                        lines=4,
                        container=False,
                        autofocus=True,
                        interactive=True,
                    )
                with gr.Column(scale=5):
                    qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=True)
                    query_send_btn = gr.Button(value="▶")

            # gr.Examples(["機械学習について説明してください"], inputs=[query])
            query.submit(user, [ss, query], [ss, chat_history]).then(bot, [ss, query, qa_flag], [ss, query]).then(show_response, [ss], [chat_history])
            query_send_btn.click(user, [ss, query], [ss, chat_history]).then(bot, [ss, query, qa_flag], [ss, query]).then(show_response, [ss], [chat_history])

if __name__ == "__main__":
    demo.queue(concurrency_count=5)
    demo.launch(debug=True, inbrowser=True)
requirements.txt ADDED
@@ -0,0 +1,18 @@
accelerate
bitsandbytes
transformers
sentence_transformers
sentencepiece
langchain
xformers
chromadb
gradio
openai
tiktoken
requests
fugashi
ipadic
unstructured
selenium
pypdf