# util/vector_base.py
import sys
import time
from pathlib import Path

from langchain_chroma import Chroma
from langchain_core.documents import Document

# sys.path.append('C://Users//Admin//Desktop//PDPO//NLL_LLM//util')
sys.path.append('/home/user/app/util')
from Embeddings import TextEmb3LargeEmbedding

class EmbeddingFunction:
    """Adapter exposing the embed_query/embed_documents interface that LangChain expects."""

    def __init__(self, embeddingmodel):
        self.embeddingmodel = embeddingmodel

    def embed_query(self, query):
        return list(self.embeddingmodel.get_embedding(query))

    def embed_documents(self, documents):
        return [self.embeddingmodel.get_embedding(document) for document in documents]
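
# Minimal sketch of the adapter in use (it assumes, as the script below does, that
# TextEmb3LargeEmbedding.get_embedding(text) returns a sequence of floats):
#
#   embedding = EmbeddingFunction(TextEmb3LargeEmbedding(max_qpm=58))
#   embedding.embed_query("hello")          # -> list[float]
#   embedding.embed_documents(["a", "b"])   # -> one embedding per document
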
def get_or_create_vector_base(collection_name: str, embedding, documents=None) -> Chroma:
    """
    Check whether the vector store has already been built; if not, initialize it
    first. Documents are added one by one in a loop with a sleep between calls,
    rather than in bulk via embed_documents, so that initialization does not fail
    by exceeding the rate limit of the OpenAI embedding API.
    """
    persist_directory = "/home/user/app/store/" + collection_name
    persist_path = Path(persist_directory)
    if not persist_path.exists() and not documents:
        raise ValueError("vector store does not exist and documents is empty")
    elif persist_path.exists():
        print("vector store already exists")
        vector_store = Chroma(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory
        )
    else:
        print("start creating vector store")
        vector_store = Chroma(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory
        )
        # Insert one document at a time, pausing between calls to stay under the
        # embedding API's rate limit.
        for document in documents:
            vector_store.add_documents(documents=[document])
            time.sleep(1)
    return vector_store
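
# Usage sketch (the collection name and documents are illustrative; the persist
# directory is derived inside the function):
#
#   store = get_or_create_vector_base("my_collection", embedding, documents=docs)
#   store.similarity_search("some query", k=3)   # standard Chroma retrieval call
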
if __name__ == "__main__":
    import pandas as pd

    requirements_data = pd.read_csv("/root/PTR-LLM/tasks/pcf/reference/NLL_DATA_NEW_Test.csv")
    # Group rows by requirement text, collecting every privacy objective and
    # safeguard attached to the same requirement.
    requirements_dict_v2 = {}
    for index, row in requirements_data.iterrows():
        # Drop the leading "<id>- " prefix, append the details column, and
        # normalize whitespace.
        requirement = row['Requirement'].split("- ")[1]
        requirement = requirement + ": " + row['Details']
        requirement = requirement.replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
        if requirement not in requirements_dict_v2:
            requirements_dict_v2[requirement] = {
                'PO': set(),
                'safeguard': set()
            }
        requirements_dict_v2[requirement]['PO'].add(
            row['PCF-Privacy Objective'].lower().rstrip()
            if isinstance(row['PCF-Privacy Objective'], str) else None
        )
        requirements_dict_v2[requirement]['safeguard'].add(row['Safeguard'].lower().rstrip())

    # Build one Document per unique requirement. Chroma metadata values must be
    # scalars, so the PO and safeguard sets are serialized as strings.
    index = 0
    documents = []
    for key, value in requirements_dict_v2.items():
        page_content = key
        metadata = {
            "index": index,
            "version": 2,
            "PO": str([po for po in value['PO'] if po]),
            "safeguard": str([safeguard for safeguard in value['safeguard']])
        }
        index += 1
        document = Document(
            page_content=page_content,
            metadata=metadata
        )
        documents.append(document)

    embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
    embedding = EmbeddingFunction(embeddingmodel)
    requirement_v2_vector_store = get_or_create_vector_base('requirement_v2_database', embedding, documents)
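
    # Smoke test (the query string is illustrative): retrieve the stored
    # requirements closest to a sample query via Chroma's similarity_search.
    hits = requirement_v2_vector_store.similarity_search("personal data retention", k=3)
    for hit in hits:
        print(hit.metadata["index"], hit.page_content[:80])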