AllenYkl committed on
Commit
f2c9c98
1 Parent(s): 2652cfb

Update bin_public/app/llama_func.py

Browse files
Files changed (1) hide show
  1. bin_public/app/llama_func.py +70 -45
bin_public/app/llama_func.py CHANGED
@@ -1,4 +1,7 @@
1
- from llama_index import GPTSimpleVectorIndex
 
 
 
2
  from llama_index import download_loader
3
  from llama_index import (
4
  Document,
@@ -8,10 +11,33 @@ from llama_index import (
8
  RefinePrompt,
9
  )
10
  from langchain.llms import OpenAI
 
11
  import colorama
 
 
 
12
 
 
13
  from bin_public.utils.utils import *
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def get_documents(file_src):
17
  documents = []
@@ -50,16 +76,14 @@ def get_documents(file_src):
50
 
51
 
52
  def construct_index(
53
- api_key,
54
- file_src,
55
- max_input_size=4096,
56
- num_outputs=1,
57
- max_chunk_overlap=20,
58
- chunk_size_limit=600,
59
- embedding_limit=None,
60
- separator=" ",
61
- num_children=10,
62
- max_keywords_per_chunk=10,
63
  ):
64
  os.environ["OPENAI_API_KEY"] = api_key
65
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
@@ -67,40 +91,40 @@ def construct_index(
67
  separator = " " if separator == "" else separator
68
 
69
  llm_predictor = LLMPredictor(
70
- llm=OpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
71
- )
72
- prompt_helper = PromptHelper(
73
- max_input_size,
74
- num_outputs,
75
- max_chunk_overlap,
76
- embedding_limit,
77
- chunk_size_limit,
78
- separator=separator,
79
  )
80
- documents, index_name = get_documents(file_src)
 
81
  if os.path.exists(f"./index/{index_name}.json"):
82
  logging.info("找到了缓存的索引文件,加载中……")
83
  return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
84
  else:
85
  try:
86
- logging.debug("构建索引中……")
87
- index = GPTSimpleVectorIndex(
88
- documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
 
 
89
  )
90
- # os.makedirs("./index", exist_ok=True)
91
- # index.save_to_disk(f"./index/{index_name}.json")
 
 
92
  return index
 
93
  except Exception as e:
 
94
  print(e)
95
  return None
96
 
97
 
98
  def chat_ai(
99
- api_key,
100
- index,
101
- question,
102
- context,
103
- chatbot,
 
104
  ):
105
  os.environ["OPENAI_API_KEY"] = api_key
106
 
@@ -113,8 +137,9 @@ def chat_ai(
113
  replace_today(PROMPT_TEMPLATE),
114
  REFINE_TEMPLATE,
115
  SIM_K,
116
- INDEX_QUERY_TEMPERATURE,
117
  context,
 
118
  )
119
  if response is None:
120
  status_text = "查询失败,请换个问法试试"
@@ -130,21 +155,22 @@ def chat_ai(
130
 
131
 
132
  def ask_ai(
133
- api_key,
134
- index,
135
- question,
136
- prompt_tmpl,
137
- refine_tmpl,
138
- sim_k=1,
139
- temprature=0,
140
- prefix_messages=[],
 
141
  ):
142
  os.environ["OPENAI_API_KEY"] = api_key
143
 
144
  logging.debug("Index file found")
145
  logging.debug("Querying index...")
146
  llm_predictor = LLMPredictor(
147
- llm=OpenAI(
148
  temperature=temprature,
149
  model_name="gpt-3.5-turbo-0301",
150
  prefix_messages=prefix_messages,
@@ -152,11 +178,10 @@ def ask_ai(
152
  )
153
 
154
  response = None # Initialize response variable to avoid UnboundLocalError
155
- qa_prompt = QuestionAnswerPrompt(prompt_tmpl)
156
- rf_prompt = RefinePrompt(refine_tmpl)
157
  response = index.query(
158
  question,
159
- llm_predictor=llm_predictor,
160
  similarity_top_k=sim_k,
161
  text_qa_template=qa_prompt,
162
  refine_template=rf_prompt,
@@ -170,7 +195,7 @@ def ask_ai(
170
  for index, node in enumerate(response.source_nodes):
171
  brief = node.source_text[:25].replace("\n", "")
172
  nodes.append(
173
- f"<details><summary>[{index+1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
174
  )
175
  new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
176
  logging.info(
 
1
+ import os
2
+ import logging
3
+
4
+ from llama_index import GPTSimpleVectorIndex, ServiceContext
5
  from llama_index import download_loader
6
  from llama_index import (
7
  Document,
 
11
  RefinePrompt,
12
  )
13
  from langchain.llms import OpenAI
14
+ from langchain.chat_models import ChatOpenAI
15
  import colorama
16
+ import PyPDF2
17
+ from tqdm import tqdm
18
+ import hashlib
19
 
20
+ from bin_public.config.presets import *
21
  from bin_public.utils.utils import *
22
 
23
+ def get_index_name(file_src):
24
+ file_paths = [x.name for x in file_src]
25
+ file_paths.sort(key=lambda x: os.path.basename(x))
26
+
27
+ md5_hash = hashlib.md5()
28
+ for file_path in file_paths:
29
+ with open(file_path, "rb") as f:
30
+ while chunk := f.read(8192):
31
+ md5_hash.update(chunk)
32
+
33
+ return md5_hash.hexdigest()
34
+
35
+ def block_split(text):
36
+ blocks = []
37
+ while len(text) > 0:
38
+ blocks.append(Document(text[:1000]))
39
+ text = text[1000:]
40
+ return blocks
41
 
42
  def get_documents(file_src):
43
  documents = []
 
76
 
77
 
78
  def construct_index(
79
+ api_key,
80
+ file_src,
81
+ max_input_size=4096,
82
+ num_outputs=5,
83
+ max_chunk_overlap=20,
84
+ chunk_size_limit=600,
85
+ embedding_limit=None,
86
+ separator=" "
 
 
87
  ):
88
  os.environ["OPENAI_API_KEY"] = api_key
89
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
 
91
  separator = " " if separator == "" else separator
92
 
93
  llm_predictor = LLMPredictor(
94
+ llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
 
 
 
 
 
 
 
 
95
  )
96
+ prompt_helper = PromptHelper(max_input_size = max_input_size, num_output = num_outputs, max_chunk_overlap = max_chunk_overlap, embedding_limit=embedding_limit, chunk_size_limit=600, separator=separator)
97
+ index_name = get_index_name(file_src)
98
  if os.path.exists(f"./index/{index_name}.json"):
99
  logging.info("找到了缓存的索引文件,加载中……")
100
  return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
101
  else:
102
  try:
103
+ documents = get_documents(file_src)
104
+ logging.info("构建索引中……")
105
+ service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
106
+ index = GPTSimpleVectorIndex.from_documents(
107
+ documents, service_context=service_context
108
  )
109
+ logging.debug("索引构建完成!")
110
+ os.makedirs("./index", exist_ok=True)
111
+ index.save_to_disk(f"./index/{index_name}.json")
112
+ logging.debug("索引已保存至本地!")
113
  return index
114
+
115
  except Exception as e:
116
+ logging.error("索引构建失败!", e)
117
  print(e)
118
  return None
119
 
120
 
121
  def chat_ai(
122
+ api_key,
123
+ index,
124
+ question,
125
+ context,
126
+ chatbot,
127
+ reply_language,
128
  ):
129
  os.environ["OPENAI_API_KEY"] = api_key
130
 
 
137
  replace_today(PROMPT_TEMPLATE),
138
  REFINE_TEMPLATE,
139
  SIM_K,
140
+ 1.0,
141
  context,
142
+ reply_language,
143
  )
144
  if response is None:
145
  status_text = "查询失败,请换个问法试试"
 
155
 
156
 
157
  def ask_ai(
158
+ api_key,
159
+ index,
160
+ question,
161
+ prompt_tmpl,
162
+ refine_tmpl,
163
+ sim_k=5,
164
+ temprature=0,
165
+ prefix_messages=[],
166
+ reply_language="中文",
167
  ):
168
  os.environ["OPENAI_API_KEY"] = api_key
169
 
170
  logging.debug("Index file found")
171
  logging.debug("Querying index...")
172
  llm_predictor = LLMPredictor(
173
+ llm=ChatOpenAI(
174
  temperature=temprature,
175
  model_name="gpt-3.5-turbo-0301",
176
  prefix_messages=prefix_messages,
 
178
  )
179
 
180
  response = None # Initialize response variable to avoid UnboundLocalError
181
+ qa_prompt = QuestionAnswerPrompt(prompt_tmpl.replace("{reply_language}", reply_language))
182
+ rf_prompt = RefinePrompt(refine_tmpl.replace("{reply_language}", reply_language))
183
  response = index.query(
184
  question,
 
185
  similarity_top_k=sim_k,
186
  text_qa_template=qa_prompt,
187
  refine_template=rf_prompt,
 
195
  for index, node in enumerate(response.source_nodes):
196
  brief = node.source_text[:25].replace("\n", "")
197
  nodes.append(
198
+ f"<details><summary>[{index + 1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
199
  )
200
  new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
201
  logging.info(