terapyon committed on
Commit 143e47c
1 Parent(s): 6125df0

Enable use of the multilingual-e5-large embeddings and the rinna LLM

Files changed (4)
  1. app.py +103 -33
  2. config.py +12 -6
  3. requirements.txt +2 -0
  4. store.py +40 -29
app.py CHANGED
@@ -1,45 +1,105 @@
-from time import time
+# from time import time
 import gradio as gr
 from langchain.chains import RetrievalQA
-# from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings import OpenAIEmbeddings
 from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.embeddings import GPT4AllEmbeddings
-from langchain.llms import OpenAI
+from langchain.prompts import PromptTemplate
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from langchain.llms import HuggingFacePipeline
+
+# from langchain.llms import OpenAI
 from langchain.chat_models import ChatOpenAI
 from langchain.vectorstores import Qdrant
 from openai.error import InvalidRequestError
 from qdrant_client import QdrantClient
-from config import DB_CONFIG
+from config import DB_CONFIG, DB_E5_CONFIG
 
 
-PERSIST_DIR_NAME = "nvdajp-book"
-# MODEL_NAME = "text-davinci-003"
-# MODEL_NAME = "gpt-3.5-turbo"
-# MODEL_NAME = "gpt-4"
+def _get_config_and_embeddings(collection_name: str | None) -> tuple:
+    if collection_name is None or collection_name == "E5":
+        db_config = DB_E5_CONFIG
+        model_name = "intfloat/multilingual-e5-large"
+        model_kwargs = {"device": "cpu"}
+        encode_kwargs = {"normalize_embeddings": False}
+        embeddings = HuggingFaceEmbeddings(
+            model_name=model_name,
+            model_kwargs=model_kwargs,
+            encode_kwargs=encode_kwargs,
+        )
+    elif collection_name == "OpenAI":
+        db_config = DB_CONFIG
+        embeddings = OpenAIEmbeddings()
+    else:
+        raise ValueError("Unknown collection name")
+    return db_config, embeddings
 
 
-def get_retrieval_qa(model_name: str | None, temperature: int, option: str | None) -> RetrievalQA:
-    # embeddings = OpenAIEmbeddings()
-    model_name = "sentence-transformers/all-mpnet-base-v2"
-    model_kwargs = {'device': 'cpu'}
-    encode_kwargs = {'normalize_embeddings': False}
-    embeddings = HuggingFaceEmbeddings(
-        model_name=model_name,
-        model_kwargs=model_kwargs,
-        encode_kwargs=encode_kwargs,
+def _get_rinna_llm(temperature: float):
+    model = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
+    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+    model = AutoModelForCausalLM.from_pretrained(
+        model,
+        load_in_8bit=True,
+        torch_dtype=torch.float16,
+        device_map="auto",
     )
-    # embeddings = GPT4AllEmbeddings()
-    db_url, db_api_key, db_collection_name = DB_CONFIG
-    client = QdrantClient(url=db_url, api_key=db_api_key)
-    db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=1024,
+        temperature=temperature,
+    )
+    llm = HuggingFacePipeline(pipeline=pipe)
+    return llm
+
+
+def _get_llm_model(
+    model_name: str | None,
+    temperature: float,
+):
     if model_name is None:
-        model = "gpt-3.5-turbo"
+        model = "rinna"
+    elif model_name == "rinna":
+        model = "rinna"
     elif model_name == "GPT-3.5":
         model = "gpt-3.5-turbo"
     elif model_name == "GPT-4":
         model = "gpt-4"
     else:
-        model = "gpt-3.5-turbo"
+        raise ValueError("Unknown model name")
+    if model.startswith("gpt"):
+        llm = ChatOpenAI(model=model, temperature=temperature)
+    elif model == "rinna":
+        llm = _get_rinna_llm(temperature)
+    return llm
+
+
+# prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+# {context}
+
+# Question: {question}
+# Answer in Japanese:"""
+# PROMPT = PromptTemplate(
+#     template=prompt_template, input_variables=["context", "question"]
+# )
+
+
+def get_retrieval_qa(
+    collection_name: str | None,
+    model_name: str | None,
+    temperature: float,
+    option: str | None,
+) -> RetrievalQA:
+    db_config, embeddings = _get_config_and_embeddings(collection_name)
+    db_url, db_api_key, db_collection_name = db_config
+    client = QdrantClient(url=db_url, api_key=db_api_key)
+    db = Qdrant(
+        client=client, collection_name=db_collection_name, embeddings=embeddings
+    )
+
     if option is None or option == "All":
         retriever = db.as_retriever()
     else:
@@ -48,14 +108,17 @@ def get_retrieval_qa(model_name: str | None, temperature: int, option: str | Non
                 "filter": {"category": option},
             }
         )
+
+    llm = _get_llm_model(model_name, temperature)
+
+    # chain_type_kwargs = {"prompt": PROMPT}
+
     result = RetrievalQA.from_chain_type(
-        llm=ChatOpenAI(
-            model=model,
-            temperature=temperature
-        ),
+        llm=llm,
         chain_type="stuff",
         retriever=retriever,
         return_source_documents=True,
+        # chain_type_kwargs=chain_type_kwargs,
     )
     return result
 
@@ -73,8 +136,10 @@ def get_related_url(metadata):
         yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
 
 
-def main(query: str, model_name: str, option: str, temperature: int):
-    qa = get_retrieval_qa(model_name, temperature, option)
+def main(
+    query: str, collection_name: str, model_name: str, option: str, temperature: float
+):
+    qa = get_retrieval_qa(collection_name, model_name, temperature, option)
     try:
         result = qa(query)
     except InvalidRequestError as e:
@@ -90,9 +155,14 @@ nvdajp_book_qa = gr.Interface(
     fn=main,
     inputs=[
        gr.Textbox(label="query"),
-        gr.Radio(["GPT-3.5", "GPT-4"], label="Model", info="選択なしで「3.5」を使用"),
-        gr.Radio(["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"], label="絞り込み", info="ドキュメント制限する?"),
-        gr.Slider(0, 2)
+        gr.Radio(["E5", "OpenAI"], label="Embedding", info="選択なしで「E5」を使用"),
+        gr.Radio(["rinna", "GPT-3.5", "GPT-4"], label="Model", info="選択なしで「rinna」を使用"),
+        gr.Radio(
+            ["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
+            label="絞り込み",
+            info="ドキュメント制限する?",
+        ),
+        gr.Slider(0, 2),
     ],
     outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
 )
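
For reference, a minimal sketch of exercising the new code path outside Gradio (the query string, temperature, and category filter are illustrative; it assumes the Qdrant collections exist and, for the GPT models, that OPENAI_API_KEY is set):

    from app import get_retrieval_qa

    # "E5" selects the multilingual-e5-large collection, "rinna" the local LLM
    qa = get_retrieval_qa("E5", "rinna", 0.7, "ja-book")
    result = qa("query text")
    print(result["result"])  # answer text
    # source documents come back because return_source_documents=True
    for doc in result["source_documents"]:
        print(doc.metadata["url"])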
config.py CHANGED
@@ -1,21 +1,27 @@
 import os
 
 
-SAAS = False
+SAAS = True
 
 
-def get_db_config():
-    url = os.environ["QDRANT_URL"]
+def get_db_config(cname="nvdajp-book"):
+    if cname == "nvdajp-book":
+        url = os.environ["QDRANT_URL"]
+    elif cname == "nvdajp-book-e5":
+        url = os.environ["QDRANT_E5_URL"]
     api_key = os.environ["QDRANT_API_KEY"]
-    collection_name = "nvdajp-book"
+    collection_name = cname
     return url, api_key, collection_name
 
 
-def get_local_db_congin():
+def get_local_db_congin(cname="nvdajp-book"):
     url = "localhost"
     # api_key = os.environ["QDRANT_API_KEY"]
-    collection_name = "nvdajp-book"
+    collection_name = cname
     return url, None, collection_name
 
 
 DB_CONFIG = get_db_config() if SAAS else get_local_db_congin()
+DB_E5_CONFIG = (
+    get_db_config("nvdajp-book-e5") if SAAS else get_local_db_congin("nvdajp-book-e5")
+)
requirements.txt CHANGED
@@ -4,3 +4,5 @@ tiktoken
 gradio
 qdrant-client
 beautifulsoup4
+accelerate
+bitsandbytes
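
The two new packages back the rinna path in app.py: load_in_8bit=True relies on bitsandbytes and device_map="auto" on accelerate. A quick sanity check that both installed (versions are whatever pip resolved):

    from importlib.metadata import version

    # raises PackageNotFoundError if either dependency is missing
    print(version("accelerate"), version("bitsandbytes"))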
store.py CHANGED
@@ -1,12 +1,12 @@
 from langchain.document_loaders import ReadTheDocsLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-# from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings import OpenAIEmbeddings
 from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.embeddings import GPT4AllEmbeddings
 from langchain.vectorstores import Qdrant
+
 # from qdrant_client import QdrantClient
 from nvda_ug_loader import NVDAUserGuideLoader
-from config import DB_CONFIG
+from config import DB_CONFIG, DB_E5_CONFIG
 
 
 CHUNK_SIZE = 500
@@ -25,46 +25,55 @@ def get_documents(path: str):
     for doc in docs:
         org_metadata = doc.metadata
         source = _remove_prefix_path(org_metadata["source"])
-        add_meta = {"category": category, "source": source, "url": f"{base_url}{source}"}
+        add_meta = {
+            "category": category,
+            "source": source,
+            "url": f"{base_url}{source}",
+        }
         doc.metadata = org_metadata | add_meta
         yield doc
 
 
 def get_text_chunk(docs):
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=0)
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=CHUNK_SIZE, chunk_overlap=0
+    )
     texts = text_splitter.split_documents(docs)
     return texts
 
 
-def store(texts):
-    # embeddings = OpenAIEmbeddings()
-    model_name = "sentence-transformers/all-mpnet-base-v2"
-    model_kwargs = {'device': 'cuda'}
-    encode_kwargs = {'normalize_embeddings': False}
-    embeddings = HuggingFaceEmbeddings(
-        model_name=model_name,
-        model_kwargs=model_kwargs,
-        encode_kwargs=encode_kwargs,
-    )
-    # embeddings = GPT4AllEmbeddings()
-    db_url, db_api_key, db_collection_name = DB_CONFIG
-    # client = QdrantClient(url=db_url, api_key=db_api_key, prefer_grpc=True)
+def store(texts, mname):
+    if mname == "openai":
+        embeddings = OpenAIEmbeddings()
+        db_url, db_api_key, db_collection_name = DB_CONFIG
+    elif mname == "e5":
+        model_name = "intfloat/multilingual-e5-large"
+        model_kwargs = {"device": "cuda"}
+        encode_kwargs = {"normalize_embeddings": False}
+        embeddings = HuggingFaceEmbeddings(
+            model_name=model_name,
+            model_kwargs=model_kwargs,
+            encode_kwargs=encode_kwargs,
+        )
+        db_url, db_api_key, db_collection_name = DB_E5_CONFIG
+    else:
+        raise ValueError("Invalid mname")
     _ = Qdrant.from_documents(
         texts,
         embeddings,
         url=db_url,
         api_key=db_api_key,
-        collection_name=db_collection_name
+        collection_name=db_collection_name,
     )
 
 
-def rtd_main(path: str):
+def rtd_main(path: str, mname: str):
     docs = get_documents(path)
     texts = get_text_chunk(docs)
-    store(texts)
+    store(texts, mname)
 
 
-def nul_main(url: str):
+def nul_main(url: str, mname: str):
     if "www.nvda.jp" in url:
         category = "ja-nvda-user-guide"
     else:
@@ -72,25 +81,27 @@ def nul_main(url: str):
     loader = NVDAUserGuideLoader(url, category)
     docs = loader.load()
     texts = get_text_chunk(docs)
-    store(texts)
+    store(texts, mname)
 
 
 if __name__ == "__main__":
     """
-    $ python store.py rtd "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest"
-    $ python store.py nul "https://www.nvaccess.org/files/nvda/documentation/userGuide.html"
-    $ python store.py nul "https://www.nvda.jp/nvda2023.1jp/ja/userGuide.html"
+    $ python store.py rtd "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest" openai
+    $ python store.py nul "https://www.nvaccess.org/files/nvda/documentation/userGuide.html" e5
+    $ python store.py nul "https://www.nvda.jp/nvda2023.1jp/ja/userGuide.html" e5
     """
    import sys
+
     args = sys.argv
-    if len(args) != 3:
+    if len(args) != 4:
         print("No args, you need two args for type, html_path")
     else:
         type_ = args[1]
         path = args[2]
+        mname = args[3]
         if type_ == "rtd":
-            rtd_main(path)
+            rtd_main(path, mname)
         elif type_ == "nul":
-            nul_main(path)
+            nul_main(path, mname)
         else:
             print("No type for store")