JohnSmith9982 committed
Commit 44846b2
Parent: 8ad9e26

Upload 7 files

Files changed (2):
  1. modules/chat_func.py +2 -2
  2. modules/llama_func.py +42 -38
modules/chat_func.py CHANGED
@@ -155,7 +155,7 @@ def stream_predict(
     yield get_return_value()
     error_json_str = ""
 
-    for chunk in response.iter_lines():
+    for chunk in tqdm(response.iter_lines()):
         if counter == 0:
             counter += 1
             continue
@@ -272,7 +272,7 @@ def predict(
     if reply_language == "跟随问题语言(不稳定)":
        reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
     if files:
-        msg = "构建索引中……(这可能需要比较久的时间)"
+        msg = "加载索引中……(这可能需要几分钟)"
         logging.info(msg)
         yield chatbot+[(inputs, "")], history, msg, all_token_counts
         index = construct_index(openai_api_key, file_src=files)
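
The one functional change in chat_func.py wraps the streaming response iterator in tqdm, so the console shows a running chunk count and rate while the reply streams in; the second hunk only rewords the status message from "构建索引中……(这可能需要比较久的时间)" ("building the index… this may take quite a while") to "加载索引中……(这可能需要几分钟)" ("loading the index… this may take a few minutes"). A minimal sketch of the tqdm-over-stream pattern; the URL and payload below are placeholders, not the project's real endpoint:

    import requests
    from tqdm import tqdm

    # Hypothetical streaming endpoint; the real code POSTs to the OpenAI
    # chat completions API with stream=True and iterates server-sent events.
    response = requests.post(
        "https://example.com/v1/stream",  # placeholder URL
        json={"stream": True},
        stream=True,
    )

    # No total is known for an open-ended stream, so tqdm falls back to a
    # simple running counter instead of a percentage bar.
    for chunk in tqdm(response.iter_lines()):
        if not chunk:  # iter_lines() yields b"" for keep-alive newlines
            continue
        print(chunk.decode("utf-8"))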
modules/llama_func.py CHANGED
@@ -13,54 +13,57 @@ from llama_index import (
 from langchain.llms import OpenAI
 import colorama
 
-
 from modules.presets import *
 from modules.utils import *
 
+def get_index_name(file_src):
+    index_name = ""
+    for file in file_src:
+        index_name += os.path.basename(file.name)
+    index_name = sha1sum(index_name)
+    return index_name
 
 def get_documents(file_src):
     documents = []
-    index_name = ""
     logging.debug("Loading documents...")
     logging.debug(f"file_src: {file_src}")
     for file in file_src:
-        logging.debug(f"file: {file.name}")
-        index_name += file.name
+        logging.info(f"loading file: {file.name}")
         if os.path.splitext(file.name)[1] == ".pdf":
             logging.debug("Loading PDF...")
             CJKPDFReader = download_loader("CJKPDFReader")
             loader = CJKPDFReader()
-            documents += loader.load_data(file=file.name)
+            text_raw = loader.load_data(file=file.name)[0].text
         elif os.path.splitext(file.name)[1] == ".docx":
             logging.debug("Loading DOCX...")
             DocxReader = download_loader("DocxReader")
             loader = DocxReader()
-            documents += loader.load_data(file=file.name)
+            text_raw = loader.load_data(file=file.name)[0].text
         elif os.path.splitext(file.name)[1] == ".epub":
             logging.debug("Loading EPUB...")
             EpubReader = download_loader("EpubReader")
             loader = EpubReader()
-            documents += loader.load_data(file=file.name)
+            text_raw = loader.load_data(file=file.name)[0].text
         else:
             logging.debug("Loading text file...")
             with open(file.name, "r", encoding="utf-8") as f:
-                text = add_space(f.read())
-            documents += [Document(text)]
-    index_name = sha1sum(index_name)
-    return documents, index_name
+                text_raw = f.read()
+        text = add_space(text_raw)
+        documents += [Document(text)]
+    return documents
 
 
 def construct_index(
-    api_key,
-    file_src,
-    max_input_size=4096,
-    num_outputs=1,
-    max_chunk_overlap=20,
-    chunk_size_limit=600,
-    embedding_limit=None,
-    separator=" ",
-    num_children=10,
-    max_keywords_per_chunk=10,
+    api_key,
+    file_src,
+    max_input_size=4096,
+    num_outputs=1,
+    max_chunk_overlap=20,
+    chunk_size_limit=600,
+    embedding_limit=None,
+    separator=" ",
+    num_children=10,
+    max_keywords_per_chunk=10,
 ):
     os.environ["OPENAI_API_KEY"] = api_key
     chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
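
This hunk splits the cache key out of document loading: get_index_name hashes the concatenated base filenames so construct_index can probe for a cached index before parsing anything, while get_documents now normalizes every format to raw text (loader.load_data(...)[0].text or f.read()) before add_space and the Document wrapper are applied once, at the end of the loop. sha1sum is pulled in by the star import from modules.utils; assuming it is a thin hashlib wrapper, an equivalent sketch:

    import hashlib
    import os

    def sha1sum(text):
        # Assumed equivalent of the sha1sum helper in modules/utils.py.
        return hashlib.sha1(text.encode("utf-8")).hexdigest()

    def get_index_name(file_src):
        # Hash base names, not full paths: Gradio typically hands uploads
        # over as temp files whose directories change between sessions, so
        # hashing file.name verbatim (as the old code did) would rarely
        # produce the same key for the same document twice.
        index_name = ""
        for file in file_src:
            index_name += os.path.basename(file.name)
        return sha1sum(index_name)

Keying on os.path.basename rather than the full file.name is what lets the on-disk index cache survive across sessions.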
@@ -78,12 +81,13 @@ def construct_index(
         chunk_size_limit,
         separator=separator,
     )
-    documents, index_name = get_documents(file_src)
+    index_name = get_index_name(file_src)
     if os.path.exists(f"./index/{index_name}.json"):
         logging.info("找到了缓存的索引文件,加载中……")
         return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
     else:
         try:
+            documents = get_documents(file_src)
             logging.debug("构建索引中……")
             index = GPTSimpleVectorIndex(
                 documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
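
This reordering is the payoff of the refactor: documents are parsed only on a cache miss, so re-uploading the same files loads the saved index directly. (The log lines read "找到了缓存的索引文件,加载中……", "found a cached index file, loading…", and "构建索引中……", "building the index…".) The load-or-build pattern in isolation, with load_fn and build_fn as hypothetical stand-ins for GPTSimpleVectorIndex.load_from_disk and index construction:

    import os

    def load_or_build(index_name, load_fn, build_fn, index_dir="./index"):
        # Generic sketch of the pattern above: probe the disk cache first,
        # and run the expensive parse + embed step only on a miss.
        path = os.path.join(index_dir, f"{index_name}.json")
        if os.path.exists(path):
            return load_fn(path)  # cache hit: skip document parsing
        index = build_fn()        # cache miss: parse and embed documents
        index.save_to_disk(path)  # persist for the next call
        return index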
@@ -97,12 +101,12 @@ def construct_index(
 
 
 def chat_ai(
-    api_key,
-    index,
-    question,
-    context,
-    chatbot,
-    reply_language,
+    api_key,
+    index,
+    question,
+    context,
+    chatbot,
+    reply_language,
 ):
     os.environ["OPENAI_API_KEY"] = api_key
 
@@ -133,15 +137,15 @@ def chat_ai(
 
 
 def ask_ai(
-    api_key,
-    index,
-    question,
-    prompt_tmpl,
-    refine_tmpl,
-    sim_k=1,
-    temprature=0,
-    prefix_messages=[],
-    reply_language="中文",
+    api_key,
+    index,
+    question,
+    prompt_tmpl,
+    refine_tmpl,
+    sim_k=1,
+    temprature=0,
+    prefix_messages=[],
+    reply_language="中文",
 ):
     os.environ["OPENAI_API_KEY"] = api_key
 
@@ -174,7 +178,7 @@ def ask_ai(
     for index, node in enumerate(response.source_nodes):
         brief = node.source_text[:25].replace("\n", "")
         nodes.append(
-            f"<details><summary>[{index+1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
+            f"<details><summary>[{index + 1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
         )
     new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
     logging.info(
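
The remaining hunks are cosmetic: the parameter lists of construct_index, chat_ai, and ask_ai are re-emitted with no visible change (presumably whitespace-only reformatting), and this last one adds PEP 8 spacing inside the f-string ({index+1} becomes {index + 1}). The surrounding loop is what renders each retrieved source node as a collapsible HTML block under the answer; a self-contained illustration of the markup it emits, with SimpleNamespace standing in for llama_index's real node objects:

    from types import SimpleNamespace

    # Fake response.source_nodes; the real nodes come from llama_index.
    source_nodes = [
        SimpleNamespace(source_text="First retrieved passage.\nMore text follows.")
    ]

    nodes = []
    for index, node in enumerate(source_nodes):
        brief = node.source_text[:25].replace("\n", "")
        nodes.append(
            f"<details><summary>[{index + 1}]\t{brief}...</summary>"
            f"<p>{node.source_text}</p></details>"
        )
    print("\n\n".join(nodes))
    # -> <details><summary>[1]	First retrieved passage....</summary>
    #    <p>First retrieved passage. / More text follows.</p></details>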
 