Spaces:
Runtime error
Runtime error
File size: 4,977 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
import time
# Never hard-code credentials in source: any key that has ever been committed
# must be treated as leaked and revoked. The key is expected to be provided via
# the environment; fail fast with a clear message when it is missing.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY environment variable must be set")
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import Chroma
import openai
from pypinyin import lazy_pinyin
# Module-level embedding client, used by the __main__ section (e.g. as the
# embedding_function when opening the persisted Chroma store). It is built at
# import time, so presumably OPENAI_API_KEY must already be available in the
# environment before this module is imported — confirm against langchain docs.
embedding = OpenAIEmbeddings()
def list_files(directory):
    """Collect every file beneath *directory*, descending into subfolders.

    Each entry is ``os.path.join(dirpath, filename)``, produced in the order
    ``os.walk`` visits them. A directory that does not exist yields ``[]``.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(directory)
        for filename in filenames
    ]
def get_path(target_string, folder_path="./vector_data"):
    """Return the path of the vector store whose name starts with *target_string*.

    Args:
        target_string: Name prefix to match (the caller passes the pinyin of a
            domain label).
        folder_path: Directory holding one entry per persisted vector store.
            Defaults to ``"./vector_data"``, the previously hard-coded location.

    Returns:
        ``os.path.join(folder_path, name)`` for the alphabetically first entry
        whose name starts with *target_string*, or ``""`` when no entry matches
        or *target_string* is empty.

    Raises:
        FileNotFoundError: if *folder_path* does not exist.
    """
    # An empty prefix would match every entry and silently pick an arbitrary store.
    if not target_string:
        return ""
    # os.listdir returns entries in arbitrary order; sort so that, when several
    # stores share the prefix, repeated runs always pick the same one.
    matches = sorted(
        name for name in os.listdir(folder_path) if name.startswith(target_string)
    )
    return os.path.join(folder_path, matches[0]) if matches else ""
# The user question that drives the demo pipeline below.
QUESTION = "新能源汽车有什么影响?"
# Text used for vector similarity search (the question without its trailing "?").
SEARCH_QUERY = "新能源汽车有什么影响"
# Follow-up request asking GPT-4 for a research-report outline.
OUTLINE_REQUEST = "给我生成一个研报的提纲?"

# System prompt: classify the question into exactly one of the listed domains.
DOMAIN_CLS_PROMPT = """
帮我根据用户的问题划分到以下几个类别,输出最匹配的一个类别:[宗教与文化, 农业, 建筑业与制造业, 医疗卫生保健, 国家治理, 法律法规, 财政税收, 教育, 金融, 贸易, 宏观经济, 社会发展, 科学技术, 能源环保, 国际关系, 国防安全]
"""

# System prompt template ("{}" receives retrieved text): fixed-format summary.
SUMMARY_PROMPT = """
你是一个研报助手,根据这段文字:
{}
来回复用户的问题生成总结,你需要严格按照这种格式回复:以上文章总结了*,主要观点是*。
你只能回复中文。
"""

# System prompt template ("{}" receives retrieved text): grounded answer that
# may fall back on the model's own knowledge.
ANSWER_PROMPT = """
你是一个研报助手,根据这段文字:
{}
来回复用户的问题,如果这段文字无法回答用户的问题,你可以根据你的知识面来进行专业的回答,你只能回复中文。
"""

# System prompt for the free-form outline request.
OUTLINE_PROMPT = """
你是一个研报助手,请根据用户的要求回复问题
"""


def _chat(system_prompt, user_prompt, *, temperature=1.0, stream=False):
    """Send one system + user exchange to GPT-4 and return the raw response.

    Uses the legacy ``openai.ChatCompletion`` API, as the rest of this script does.
    """
    return openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=temperature,
        stream=stream,
    )


def _reply_text(response):
    """Return the message text of a non-streamed chat response."""
    return response["choices"][0]["message"]["content"]


def _stream_text(response):
    """Concatenate the ``content`` pieces of a streamed chat response.

    A chunk's delta does not always carry ``content`` (e.g. a delta holding only
    ``role``), so it is read with ``.get`` — indexing ``delta["content"]``
    raised ``KeyError`` on such chunks.
    """
    pieces = []
    for chunk in response:
        content = chunk["choices"][0]["delta"].get("content")
        if content:
            pieces.append(content)
    return "".join(pieces)


def _search(db, query, k):
    """Run a scored similarity search and return ``(documents, joined_text)``.

    ``joined_text`` is the documents' ``page_content`` joined with single
    spaces, ready to be substituted into a prompt template.
    """
    documents = [doc for doc, _score in db.similarity_search_with_score(query=query, k=k)]
    return documents, " ".join(doc.page_content for doc in documents)


def main():
    """Route QUESTION to a domain vector store, then summarise, answer and outline.

    1. Ask GPT-4 which domain QUESTION belongs to; the pinyin of that domain
       selects a persisted Chroma store via ``get_path``.
    2. Retrieve the top-5 chunks, print their source files and the search
       time, then stream a fixed-format summary.
    3. Load those source files with PyPDFLoader, index them in memory, retrieve
       the top-3 chunks for the same question, and stream an answer.
    4. Ask GPT-4 for a research-report outline.
    """
    # Classification should be repeatable: temperature 0 keeps the same question
    # routed to the same store, and strip() removes whitespace/newlines that
    # would otherwise end up in the pinyin prefix.
    domain = _reply_text(_chat(DOMAIN_CLS_PROMPT, QUESTION, temperature=0)).strip()
    print("匹配领域:", domain)
    persist_vector_path = get_path("".join(lazy_pinyin(domain)))
    print("vector_path: ", persist_vector_path)
    if not persist_vector_path:
        # get_path returns "" when no store matches; stop here instead of
        # opening Chroma on an empty path.
        raise SystemExit(f"No vector store in ./vector_data matches domain {domain!r}")

    start_time = time.time()
    db = Chroma(persist_directory=persist_vector_path, embedding_function=embedding)
    hits, relevance = _search(db, SEARCH_QUERY, k=5)
    target_files = {doc.metadata["source"] for doc in hits}
    elapsed_time = time.time() - start_time
    print("搜索结果:", target_files)
    print("time cost:", elapsed_time)
    if not target_files:
        raise SystemExit("Vector search returned no source documents")

    print(_stream_text(_chat(SUMMARY_PROMPT.format(relevance), QUESTION, stream=True)))

    # Second, finer-grained retrieval over the source files themselves; sorted
    # so pages are loaded in a deterministic order.
    pages = []
    for path in sorted(target_files):
        pages.extend(PyPDFLoader(path).load_and_split())
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    page_db = Chroma.from_documents(splitter.split_documents(pages), embedding)
    # Search with the user's actual question (this previously searched for
    # "宏观经济有什么影响", grounding the answer in unrelated context).
    _, relevance = _search(page_db, SEARCH_QUERY, k=3)
    print(_stream_text(_chat(ANSWER_PROMPT.format(relevance), QUESTION, stream=True)))

    outline = _reply_text(_chat(OUTLINE_PROMPT, OUTLINE_REQUEST))
    print("研报的提纲", outline)


if __name__ == "__main__":
    main()
|