Spaces:
Runtime error
Runtime error
import os | |
import logging | |
from llama_index import GPTSimpleVectorIndex | |
from llama_index import download_loader | |
from llama_index import ( | |
Document, | |
LLMPredictor, | |
PromptHelper, | |
QuestionAnswerPrompt, | |
RefinePrompt, | |
) | |
from langchain.llms import OpenAI | |
import colorama | |
from presets import * | |
from utils import * | |
def get_documents(file_src): | |
documents = [] | |
index_name = "" | |
logging.debug("Loading documents...") | |
logging.debug(f"file_src: {file_src}") | |
for file in file_src: | |
logging.debug(f"file: {file.name}") | |
index_name += file.name | |
if os.path.splitext(file.name)[1] == ".pdf": | |
logging.debug("Loading PDF...") | |
CJKPDFReader = download_loader("CJKPDFReader") | |
loader = CJKPDFReader() | |
documents += loader.load_data(file=file.name) | |
elif os.path.splitext(file.name)[1] == ".docx": | |
logging.debug("Loading DOCX...") | |
DocxReader = download_loader("DocxReader") | |
loader = DocxReader() | |
documents += loader.load_data(file=file.name) | |
elif os.path.splitext(file.name)[1] == ".epub": | |
logging.debug("Loading EPUB...") | |
EpubReader = download_loader("EpubReader") | |
loader = EpubReader() | |
documents += loader.load_data(file=file.name) | |
else: | |
logging.debug("Loading text file...") | |
with open(file.name, "r", encoding="utf-8") as f: | |
text = add_space(f.read()) | |
documents += [Document(text)] | |
index_name = sha1sum(index_name) | |
return documents, index_name | |
def construct_index( | |
api_key, | |
file_src, | |
max_input_size=4096, | |
num_outputs=1, | |
max_chunk_overlap=20, | |
chunk_size_limit=600, | |
embedding_limit=None, | |
separator=" ", | |
num_children=10, | |
max_keywords_per_chunk=10, | |
): | |
os.environ["OPENAI_API_KEY"] = api_key | |
chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit | |
embedding_limit = None if embedding_limit == 0 else embedding_limit | |
separator = " " if separator == "" else separator | |
llm_predictor = LLMPredictor( | |
llm=OpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key) | |
) | |
prompt_helper = PromptHelper( | |
max_input_size, | |
num_outputs, | |
max_chunk_overlap, | |
embedding_limit, | |
chunk_size_limit, | |
separator=separator, | |
) | |
documents, index_name = get_documents(file_src) | |
if os.path.exists(f"./index/{index_name}.json"): | |
logging.info("找到了缓存的索引文件,加载中……") | |
return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json") | |
else: | |
try: | |
logging.debug("构建索引中……") | |
index = GPTSimpleVectorIndex( | |
documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper | |
) | |
os.makedirs("./index", exist_ok=True) | |
index.save_to_disk(f"./index/{index_name}.json") | |
return index | |
except Exception as e: | |
print(e) | |
return None | |
def chat_ai( | |
api_key, | |
index, | |
question, | |
context, | |
chatbot, | |
): | |
os.environ["OPENAI_API_KEY"] = api_key | |
logging.info(f"Question: {question}") | |
response, chatbot_display, status_text = ask_ai( | |
api_key, | |
index, | |
question, | |
replace_today(PROMPT_TEMPLATE), | |
REFINE_TEMPLATE, | |
SIM_K, | |
INDEX_QUERY_TEMPRATURE, | |
context, | |
) | |
if response is None: | |
status_text = "查询失败,请换个问法试试" | |
return context, chatbot | |
response = response | |
context.append({"role": "user", "content": question}) | |
context.append({"role": "assistant", "content": response}) | |
chatbot.append((question, chatbot_display)) | |
os.environ["OPENAI_API_KEY"] = "" | |
return context, chatbot, status_text | |
def ask_ai( | |
api_key, | |
index, | |
question, | |
prompt_tmpl, | |
refine_tmpl, | |
sim_k=1, | |
temprature=0, | |
prefix_messages=[], | |
): | |
os.environ["OPENAI_API_KEY"] = api_key | |
logging.debug("Index file found") | |
logging.debug("Querying index...") | |
llm_predictor = LLMPredictor( | |
llm=OpenAI( | |
temperature=temprature, | |
model_name="gpt-3.5-turbo-0301", | |
prefix_messages=prefix_messages, | |
) | |
) | |
response = None # Initialize response variable to avoid UnboundLocalError | |
qa_prompt = QuestionAnswerPrompt(prompt_tmpl) | |
rf_prompt = RefinePrompt(refine_tmpl) | |
response = index.query( | |
question, | |
llm_predictor=llm_predictor, | |
similarity_top_k=sim_k, | |
text_qa_template=qa_prompt, | |
refine_template=rf_prompt, | |
response_mode="compact", | |
) | |
if response is not None: | |
logging.info(f"Response: {response}") | |
ret_text = response.response | |
nodes = [] | |
for index, node in enumerate(response.source_nodes): | |
brief = node.source_text[:25].replace("\n", "") | |
nodes.append( | |
f"<details><summary>[{index+1}]\t{brief}...</summary><p>{node.source_text}</p></details>" | |
) | |
new_response = ret_text + "\n----------\n" + "\n\n".join(nodes) | |
logging.info( | |
f"Response: {colorama.Fore.BLUE}{ret_text}{colorama.Style.RESET_ALL}" | |
) | |
os.environ["OPENAI_API_KEY"] = "" | |
return ret_text, new_response, f"查询消耗了{llm_predictor.last_token_usage} tokens" | |
else: | |
logging.warning("No response found, returning None") | |
os.environ["OPENAI_API_KEY"] = "" | |
return None | |
def add_space(text): | |
punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "} | |
for cn_punc, en_punc in punctuations.items(): | |
text = text.replace(cn_punc, en_punc) | |
return text | |