# NOTE: removed "Spaces: / Sleeping / Sleeping" — Hugging Face Spaces UI residue
# from the page scrape, not part of the source.
import sys
from langchain_chroma import Chroma
from langchain_core.documents import Document
# sys.path.append('C://Users//Admin//Desktop//PDPO//NLL_LLM//util')
sys.path.append('/home/user/app/util')
from Embeddings import TextEmb3LargeEmbedding
from pathlib import Path
import time
class EmbeddingFunction:
    """Adapter exposing the LangChain embedding interface on top of a model
    that provides a single ``get_embedding(text)`` method.

    Implements ``embed_query`` and ``embed_documents`` as expected by
    ``Chroma(embedding_function=...)``.
    """

    def __init__(self, embeddingmodel):
        # Underlying model; must expose get_embedding(text) -> sequence of floats.
        self.embeddingmodel = embeddingmodel

    def embed_query(self, query):
        """Embed a single query string; return the vector as a plain list."""
        return list(self.embeddingmodel.get_embedding(query))

    def embed_documents(self, documents):
        """Embed each document; return a list of vectors.

        Each vector is wrapped in ``list(...)`` so the element type is
        consistent with ``embed_query`` (the original returned whatever raw
        sequence type ``get_embedding`` produced).
        """
        return [list(self.embeddingmodel.get_embedding(document)) for document in documents]
def get_or_create_vector_base(collection_name: str, embedding, documents=None) -> Chroma:
    """Open an existing Chroma vector store, or create and populate a new one.

    Checks whether the store has already been persisted on disk. If it has
    not, it is initialised document-by-document (instead of one bulk
    ``embed_documents`` call), sleeping between inserts so the OpenAI
    embedding API rate limit is not hit and initialisation does not fail.

    Args:
        collection_name: Chroma collection name; also the directory name
            under the store root.
        embedding: Embedding object exposing embed_query/embed_documents.
        documents: Documents to index when the store must be created.
            May be None when the store already exists on disk.

    Returns:
        The opened or newly populated Chroma vector store.

    Raises:
        ValueError: If the store does not exist on disk and ``documents``
            is empty/None, so there is nothing to build from.
    """
    persist_path = Path("/home/user/app/store") / collection_name
    persist_directory = str(persist_path)
    # BUG FIX: the original tested `persist_path.exists` (the bound method,
    # which is always truthy) instead of calling it, so this guard could
    # never trigger and a missing store with no documents crashed later.
    if not persist_path.exists() and not documents:
        raise ValueError("vector store does not exist and documents is empty")
    if persist_path.exists():
        print("vector store already exists")
        return Chroma(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory,
        )
    print("start creating vector store")
    vector_store = Chroma(
        collection_name=collection_name,
        embedding_function=embedding,
        persist_directory=persist_directory,
    )
    for document in documents:
        # One insert at a time with a pause to stay under the API rate limit.
        vector_store.add_documents(documents=[document])
        time.sleep(1)
    return vector_store
if __name__=="__main__":
    import pandas as pd
    # Source CSV: one row per (requirement, privacy objective, safeguard) tuple.
    requirements_data = pd.read_csv("/root/PTR-LLM/tasks/pcf/reference/NLL_DATA_NEW_Test.csv")
    # Maps the normalized requirement text -> {'PO': set, 'safeguard': set},
    # deduplicating objectives/safeguards across rows of the same requirement.
    requirements_dict_v2 = {}
    for index, row in requirements_data.iterrows():
        # Strip the leading "<prefix>- " from the requirement title, then append details.
        # NOTE(review): split("- ")[1] assumes exactly one "- " separator — raises
        # IndexError if absent and drops text if the title itself contains "- "; confirm
        # against the CSV format.
        requirement = row['Requirement'].split("- ")[1]
        requirement = requirement + ": " + row['Details']
        # Flatten all line breaks/tabs so each requirement is a single-line key.
        requirement = requirement.replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
        if requirement not in requirements_dict_v2:
            requirements_dict_v2[requirement] = {
                'PO': set(),
                'safeguard': set()
            }
        # Non-string PO cells (e.g. NaN) add None here; filtered out below when
        # building the metadata.
        requirements_dict_v2[requirement]['PO'].add(row['PCF-Privacy Objective'].lower().rstrip() if isinstance(row['PCF-Privacy Objective'], str) else None)
        # NOTE(review): unlike PO there is no isinstance guard — assumes the
        # Safeguard column is always a string; verify no NaN values occur.
        requirements_dict_v2[requirement]['safeguard'].add(row['Safeguard'].lower().rstrip())
    # Rebind `index` as a running Document counter (shadows the loop variable above).
    index = 0
    documents = []
    # One Document per unique requirement; the sets are stringified because
    # Chroma metadata values must be scalar types.
    for key, value in requirements_dict_v2.items():
        page_content = key
        metadata = {
            "index": index,
            "version":2,
            "PO": str([po for po in value['PO'] if po]),
            "safeguard":str([safeguard for safeguard in value['safeguard']])
        }
        index += 1
        document=Document(
            page_content=page_content,
            metadata=metadata
        )
        documents.append(document)
    # max_qpm throttles embedding calls to stay under the provider's query-per-minute quota.
    embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
    embedding = EmbeddingFunction(embeddingmodel)
    requirement_v2_vector_store = get_or_create_vector_base('requirement_v2_database', embedding, documents)