# chinese_lantern_riddles / data_preparation.py
# Committed by 3v324v23 — commit 1ae7d73: "first POC"
# File size: 1.49 kB
# %%
import shutil
import os
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
# %%
# Load the raw riddle data file into LangChain documents.
data_file = "data/riddles_data"
# Decode explicitly as UTF-8. The riddles are Chinese text, and TextLoader's
# default (encoding=None) falls back to the platform's locale encoding, which
# is not UTF-8 on e.g. Windows and would fail or garble the characters.
loader = TextLoader(data_file, encoding="utf-8")
docs = loader.load()

# Split the text strictly line-by-line: every "\n"-separated line becomes its
# own chunk. With chunk_size=0 any non-empty line already exceeds the limit,
# so no two lines are ever merged into a single chunk.
# NOTE(review): the splitter presumably logs a "longer than the specified"
# warning for every chunk because of chunk_size=0 — confirm this is acceptable.
text_splitter = CharacterTextSplitter(
    separator="\n",
    chunk_size=0,
    chunk_overlap=0,
    length_function=len,
    is_separator_regex=False,
)
# One Document per line of the data file (presumably one riddle per line).
splits = text_splitter.split_documents(docs)
# %%
# Sentence-embedding model that turns each riddle line into a vector.
# The model runs on the CPU. Normalization is left off; the Chroma collection
# built later is configured for cosine distance, which ignores vector length.
model_name = "shibing624/text2vec-base-chinese"
model_kwargs = dict(device="cpu")
encode_kwargs = dict(normalize_embeddings=False)
huggingface_embeddings = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
# %%
# On-disk location of the Chroma vector database.
persist_directory = 'chroma/'
# %%
# Delete any previously built database before rebuilding it, so the new
# store starts empty. Otherwise from_documents would presumably add to the
# old persisted collection.
if os.path.exists(persist_directory):
    shutil.rmtree(persist_directory)  # remove old database files, if any
# %%
# Embed every line-level split with the HuggingFace encoder and store the
# vectors in a Chroma collection persisted under `persist_directory`.
# The collection's HNSW index is configured to use cosine distance.
vectordb = Chroma.from_documents(
    persist_directory=persist_directory,
    collection_metadata={"hnsw:space": "cosine"},
    embedding=huggingface_embeddings,
    documents=splits,
)
# %%
# Flush the vector store to disk, then print how many entries its
# underlying collection holds.
vectordb.persist()
# NOTE(review): `_collection` is a private attribute of the LangChain Chroma
# wrapper and may change between versions.
stored_count = vectordb._collection.count()
print(stored_count)
# %%