terapyon committed on
Commit 8034497
1 Parent(s): c959977

dev/modify-load-llm (#9)


- modify load timing of model (ff920d1f88f3f44ef23ace448c056fa1a8d226e3)

Files changed (1)
  1. app.py +26 -20
app.py CHANGED
@@ -1,4 +1,4 @@
-# from time import time
+from time import time
 import gradio as gr
 from langchain.chains import RetrievalQA
 from langchain.embeddings import OpenAIEmbeddings
@@ -16,17 +16,29 @@ from qdrant_client import QdrantClient
 from config import DB_CONFIG, DB_E5_CONFIG
 
 
+E5_MODEL_NAME = "intfloat/multilingual-e5-large"
+E5_MODEL_KWARGS = {"device": "cuda:0" if torch.cuda.is_available() else "cpu"}
+E5_ENCODE_KWARGS = {"normalize_embeddings": False}
+E5_EMBEDDINGS = HuggingFaceEmbeddings(
+    model_name=E5_MODEL_NAME,
+    model_kwargs=E5_MODEL_KWARGS,
+    encode_kwargs=E5_ENCODE_KWARGS,
+)
+
+RINNA_MODEL_NAME = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
+RINNA_TOKENIZER = AutoTokenizer.from_pretrained(RINNA_MODEL_NAME, use_fast=False)
+RINNA_MODEL = AutoModelForCausalLM.from_pretrained(
+    RINNA_MODEL_NAME,
+    load_in_8bit=True,
+    torch_dtype=torch.float16,
+    device_map="auto",
+)
+
+
 def _get_config_and_embeddings(collection_name: str | None) -> tuple:
     if collection_name is None or collection_name == "E5":
         db_config = DB_E5_CONFIG
-        model_name = "intfloat/multilingual-e5-large"
-        model_kwargs = {"device": "cpu"}
-        encode_kwargs = {"normalize_embeddings": False}
-        embeddings = HuggingFaceEmbeddings(
-            model_name=model_name,
-            model_kwargs=model_kwargs,
-            encode_kwargs=encode_kwargs,
-        )
+        embeddings = E5_EMBEDDINGS
     elif collection_name == "OpenAI":
         db_config = DB_CONFIG
         embeddings = OpenAIEmbeddings()
@@ -36,18 +48,10 @@ def _get_config_and_embeddings(collection_name: str | None) -> tuple:
 
 
 def _get_rinna_llm(temperature: float):
-    model = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
-    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
-    model = AutoModelForCausalLM.from_pretrained(
-        model,
-        load_in_8bit=True,
-        torch_dtype=torch.float16,
-        device_map="auto",
-    )
     pipe = pipeline(
         "text-generation",
-        model=model,
-        tokenizer=tokenizer,
+        model=RINNA_MODEL,
+        tokenizer=RINNA_TOKENIZER,
         max_new_tokens=1024,
         temperature=temperature,
     )
@@ -139,6 +143,7 @@ def get_related_url(metadata):
 def main(
     query: str, collection_name: str, model_name: str, option: str, temperature: float
 ):
+    now = time()
     qa = get_retrieval_qa(collection_name, model_name, temperature, option)
     try:
         result = qa(query)
@@ -146,7 +151,8 @@ def main(
         return "回答が見つかりませんでした。別な質問をしてみてください", str(e)
     else:
         metadata = [s.metadata for s in result["source_documents"]]
-        html = "<div>" + "\n".join(get_related_url(metadata)) + "</div>"
+        sec_html = f"<p>実行時間: {(time() - now):.2f}秒</p>"
+        html = "<div>" + sec_html + "\n".join(get_related_url(metadata)) + "</div>"
 
     return result["result"], html
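
Note: the change moves the E5 embeddings and the rinna model/tokenizer from per-call construction inside _get_config_and_embeddings and _get_rinna_llm to module-level constants, so the weights are loaded once at import rather than on every request, and main now reports the per-query elapsed time. Below is a minimal sketch of an alternative load-once pattern that defers the heavy load to the first call via functools.lru_cache; it reuses the model name from the diff, but the load_rinna_once helper is illustrative and not part of the commit.

from functools import lru_cache

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


@lru_cache(maxsize=1)
def load_rinna_once():
    # First call loads the 4B model; later calls return the cached pair.
    name = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
    tokenizer = AutoTokenizer.from_pretrained(name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        name,
        load_in_8bit=True,          # needs bitsandbytes, same as the commit
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return model, tokenizer

Compared with module-level constants, this variant keeps start-up fast when the model is never used, at the cost of a slower first query; either way the reload-on-every-call overhead that motivated the commit is avoided.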