# webui / zy_local_test.py
# Uploaded by zhangyi617 via huggingface_hub (commit 129cd69).
import os
import time

import openai
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from pypinyin import lazy_pinyin

# SECURITY: API keys must never be hard-coded in source. The key previously
# committed in this file is exposed and should be revoked; supply a key via
# the OPENAI_API_KEY environment variable instead.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it in the environment before running this script."
    )

# Shared OpenAI embedding model used to open the Chroma vector stores.
embedding = OpenAIEmbeddings()
def list_files(directory):
    """Recursively collect the paths of every file below *directory*.

    The tree is walked top-down with ``os.walk``; each result is
    ``os.path.join(dirpath, filename)`` in walk order. A directory that does
    not exist yields an empty list (``os.walk`` ignores the error by default).
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(directory)
        for filename in filenames
    ]
def get_path(target_string, folder_path="./vector_data"):
    """Return the path of the first entry in *folder_path* whose name starts with *target_string*.

    Entries are examined in sorted name order so that the result is
    deterministic when several entries share the prefix (``os.listdir``
    order is arbitrary).

    Args:
        target_string: Name prefix to look for (the caller passes a pinyin
            domain key).
        folder_path: Directory to search; defaults to ``./vector_data``.

    Returns:
        ``os.path.join(folder_path, name)`` for the first matching entry, or
        ``""`` when nothing matches or *target_string* is empty.

    Raises:
        FileNotFoundError: If *folder_path* does not exist (raised by
            ``os.listdir``).
    """
    # An empty prefix would match every entry and silently select an
    # unrelated store, so treat it as "no match".
    if not target_string:
        return ""
    for name in sorted(os.listdir(folder_path)):
        if name.startswith(target_string):
            return os.path.join(folder_path, name)
    return ""
def _chat(system_prompt, user_message, *, temperature=1.0):
    """Send one non-streaming GPT-4 chat request and return the reply text.

    Args:
        system_prompt: Content of the ``system`` message.
        user_message: Content of the ``user`` message.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The ``content`` of the first choice's message.
    """
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        temperature=temperature,
        stream=False,
    )
    return response["choices"][0]["message"]["content"]


def _stream_chat(system_prompt, user_message, *, temperature=1.0):
    """Stream a GPT-4 reply to stdout as it arrives and return the full text.

    Takes the same arguments as ``_chat``. Each content delta is printed on
    the same line as it arrives, followed by one final newline.
    """
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        temperature=temperature,
        stream=True,
    )
    pieces = []
    for chunk in response:
        # A delta may lack "content" (e.g. a role-only chunk or the empty final
        # chunk) or carry None, so indexing delta["content"] could raise.
        piece = chunk["choices"][0]["delta"].get("content") or ""
        if piece:
            # Print only the new text; re-printing the accumulated message on
            # every chunk produces quadratic console output.
            print(piece, end="", flush=True)
            pieces.append(piece)
    print()
    return "".join(pieces)


def _retrieve(db, query, k):
    """Return the top-*k* documents for *query* from Chroma store *db*, dropping the scores."""
    return [doc for doc, _score in db.similarity_search_with_score(query=query, k=k)]


def _normalize_domain(raw):
    """Trim whitespace and wrapping brackets/punctuation from the model's category reply."""
    return raw.strip().strip("[]【】\"'“”。.").strip()


if __name__ == "__main__":
    question = "新能源汽车有什么影响?"
    # Query text used for the vector searches (the question without "?").
    search_query = "新能源汽车有什么影响"

    # Step 1: classify the question into one of the fixed domains.
    domain_cls_prompt = """
帮我根据用户的问题划分到以下几个类别,输出最匹配的一个类别:[宗教与文化, 农业, 建筑业与制造业, 医疗卫生保健, 国家治理, 法律法规, 财政税收, 教育, 金融, 贸易, 宏观经济, 社会发展, 科学技术, 能源环保, 国际关系, 国防安全]
"""
    # Temperature 0: the reply is used as a directory lookup key, so a
    # deterministic, minimally-decorated answer is wanted.
    domain = _normalize_domain(_chat(domain_cls_prompt, question, temperature=0))
    print("匹配领域:", domain)

    persist_vector_path = get_path("".join(lazy_pinyin(domain)))
    print("vector_path: ", persist_vector_path)
    if not persist_vector_path:
        raise SystemExit(f"No vector store under ./vector_data matches domain {domain!r}")

    # Step 2: search the domain's persisted store for relevant passages/PDFs.
    start_time = time.perf_counter()
    db = Chroma(persist_directory=persist_vector_path, embedding_function=embedding)
    contents = _retrieve(db, search_query, k=5)
    relevance = " ".join(doc.page_content for doc in contents)
    target_files = {doc.metadata["source"] for doc in contents}
    elapsed_time = time.perf_counter() - start_time
    print("搜索结果:", target_files)
    print("time cost:", elapsed_time)

    # Step 3: summarise the retrieved passages.
    sys_prompt = """
你是一个研报助手,根据这段文字:
{}
来回复用户的问题生成总结,你需要严格按照这种格式回复:以上文章总结了*,主要观点是*。
你只能回复中文。
""".format(relevance)
    _stream_chat(sys_prompt, question)

    # Chroma.from_documents below cannot build a store from zero documents.
    if not target_files:
        raise SystemExit("Vector search returned no source files; nothing to re-index.")

    # Step 4: re-index the full source PDFs into a new Chroma store (no
    # persist_directory given) and answer from the best-matching chunks.
    qa_pages = []
    for pdf_path in sorted(target_files):  # sorted for a reproducible load order
        qa_pages.extend(PyPDFLoader(pdf_path).load_and_split())
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    documents = text_splitter.split_documents(qa_pages)
    qa_db = Chroma.from_documents(documents, embedding)
    # Retrieve with the same question that is being answered so the context
    # handed to the model is actually about it.
    qa_contents = _retrieve(qa_db, search_query, k=3)
    relevance = " ".join(doc.page_content for doc in qa_contents)
    sys_prompt = """
你是一个研报助手,根据这段文字:
{}
来回复用户的问题,如果这段文字无法回答用户的问题,你可以根据你的知识面来进行专业的回答,你只能回复中文。
""".format(relevance)
    _stream_chat(sys_prompt, question)

    # Step 5: free-form generation of a report outline.
    outline_prompt = """
你是一个研报助手,请根据用户的要求回复问题
"""
    outline = _chat(outline_prompt, "给我生成一个研报的提纲?")
    print("研报的提纲", outline)