# Hosting-page status captured along with this file: "Spaces: Runtime error" (scrape residue, not program text).
import os
import time

import openai
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from pypinyin import lazy_pinyin

# SECURITY: the API key must come from the process environment. A secret key
# was previously hard-coded on this line; any key that has been committed to
# source control should be treated as leaked and revoked.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it in the environment before running."
    )

# Embedding function handed to the Chroma vector store in the script below.
embedding = OpenAIEmbeddings()
def list_files(directory):
    """Return the paths of all files found under ``directory``, recursively.

    Args:
        directory: Root folder to walk.

    Returns:
        A list of file paths, each built by joining the folder that
        ``os.walk`` reports with the file name. The list is empty when the
        folder holds no files (``os.walk`` also yields nothing for a folder
        that does not exist).
    """
    return [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(directory)
        for name in names
    ]
def get_path(target_string, folder_path="./vector_data"):
    """Find the vector-store entry whose name starts with ``target_string``.

    Args:
        target_string: Name prefix to look for.
        folder_path: Folder whose direct entries are searched. Defaults to
            ``./vector_data``, the location this function always used.

    Returns:
        ``folder_path`` joined with the first matching entry name, or ``""``
        when nothing matches or ``folder_path`` does not exist.
    """
    try:
        # os.listdir returns entries in arbitrary order; sort so that the
        # "first match" is the same on every run and every platform.
        candidates = sorted(os.listdir(folder_path))
    except FileNotFoundError:
        # A missing store folder is reported like any other "not found".
        return ""
    for name in candidates:
        if name.startswith(target_string):
            return os.path.join(folder_path, name)
    return ""
# System prompts, sent to the model verbatim.
DOMAIN_CLS_PROMPT = """
帮我根据用户的问题划分到以下几个类别,输出最匹配的一个类别:[宗教与文化, 农业, 建筑业与制造业, 医疗卫生保健, 国家治理, 法律法规, 财政税收, 教育, 金融, 贸易, 宏观经济, 社会发展, 科学技术, 能源环保, 国际关系, 国防安全]
"""
SUMMARY_PROMPT = """
你是一个研报助手,根据这段文字:
{}
来回复用户的问题生成总结,你需要严格按照这种格式回复:以上文章总结了*,主要观点是*。
你只能回复中文。
"""
QA_PROMPT = """
你是一个研报助手,根据这段文字:
{}
来回复用户的问题,如果这段文字无法回答用户的问题,你可以根据你的知识面来进行专业的回答,你只能回复中文。
"""
OUTLINE_PROMPT = """
你是一个研报助手,请根据用户的要求回复问题
"""

# Demo inputs: the question put to the model and the text used for retrieval.
QUESTION = "新能源汽车有什么影响?"
SEARCH_QUERY = "新能源汽车有什么影响"
OUTLINE_REQUEST = "给我生成一个研报的提纲?"


def collect_stream_text(chunks):
    """Concatenate the text carried by a sequence of streamed chat chunks.

    Args:
        chunks: Iterable of mapping-like chunks, each exposing
            ``chunk["choices"][0]["delta"]``.

    Returns:
        All ``content`` pieces joined in arrival order. Deltas that carry no
        ``content`` key (for example a role-only or empty delta) or a
        ``None`` content are skipped instead of raising ``KeyError``.
    """
    pieces = []
    for chunk in chunks:
        content = chunk["choices"][0]["delta"].get("content")
        if content:
            pieces.append(content)
    return "".join(pieces)


def _chat(system_prompt, user_message, stream=False):
    """Send one system + user exchange to gpt-4 and return the reply text."""
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        temperature=1.0,
        stream=stream,
    )
    if stream:
        return collect_stream_text(response)
    return response["choices"][0]["message"]["content"]


def _top_documents(db, query, k):
    """Return the documents of the ``k`` best hits, dropping their scores."""
    hits = db.similarity_search_with_score(query=query, k=k)
    return [document for document, _score in hits]


def main():
    """Run the demo: classify, retrieve, summarise, answer, then outline."""
    # 1. Ask the model which domain the question belongs to.
    domain = _chat(DOMAIN_CLS_PROMPT, QUESTION).strip()
    print("匹配领域:", domain)

    # 2. Vector stores are looked up by the pinyin spelling of the domain.
    persist_vector_path = get_path("".join(lazy_pinyin(domain)))
    print("vector_path: ", persist_vector_path)
    if not persist_vector_path:
        # An empty path would otherwise be handed to Chroma as a directory.
        raise SystemExit("No vector store found for domain: " + repr(domain))

    # 3. Retrieve the closest chunks and the files they came from.
    start_time = time.perf_counter()
    db = Chroma(persist_directory=persist_vector_path, embedding_function=embedding)
    documents = _top_documents(db, SEARCH_QUERY, k=5)
    relevance = " ".join(doc.page_content for doc in documents)
    target_files = {doc.metadata["source"] for doc in documents}
    elapsed_time = time.perf_counter() - start_time
    print("搜索结果:", target_files)
    print("time cost:", elapsed_time)

    # 4. Summarise the retrieved text; the completed reply is printed once.
    print(_chat(SUMMARY_PROMPT.format(relevance), QUESTION, stream=True))

    # 5. Re-index the source PDFs and answer the question from them.
    QA_pages = []
    for item in target_files:
        QA_pages.extend(PyPDFLoader(item).load_and_split())
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    split_documents = text_splitter.split_documents(QA_pages)
    if split_documents:
        qa_db = Chroma.from_documents(split_documents, embedding)
        # NOTE(review): this search previously used "宏观经济有什么影响",
        # which does not match the question being answered; it looked like a
        # leftover from another test run, so the same query is used here.
        relevance = " ".join(
            doc.page_content for doc in _top_documents(qa_db, SEARCH_QUERY, k=3)
        )
    else:
        # Nothing to index; the prompt allows answering without context.
        relevance = ""
    print(_chat(QA_PROMPT.format(relevance), QUESTION, stream=True))

    # 6. Independent request: generate a report outline.
    outline = _chat(OUTLINE_PROMPT, OUTLINE_REQUEST)
    print("研报的提纲", outline)


if __name__ == "__main__":
    main()