# 1st_langchain/injection.py
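"""Chunk a CSV of reviews with RecursiveCharacterTextSplitter, embed each
chunk, and upsert the resulting vectors into a vector index."""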
import pandas as pd
from tqdm.auto import tqdm
from uuid import uuid4
from langchain.text_splitter import RecursiveCharacterTextSplitter
def inject(index, embedder, data_file):
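    """Chunk, embed, and upsert reviews from data_file into a vector index.

    Assumes a Pinecone-style `index` exposing `upsert(vectors=...)` and a
    LangChain-style `embedder` exposing `embed_documents(texts)`.
    """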
    # Load the raw reviews and preview the first few rows.
    data = pd.read_csv(data_file)
    print(data.head())

    # Upsert vectors in batches of at most BATCH_LIMIT entries per request.
    BATCH_LIMIT = 100
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,      # character length of each chunk
        chunk_overlap=100,    # character overlap between consecutive chunks
        length_function=len,  # length measured in characters (python's len())
    )
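
    # Buffers accumulating chunk texts and matching per-chunk metadata.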
texts = []
metadatas = []
    for i in tqdm(range(len(data))):
        record = data.iloc[i]
        metadata = {
            "review-url": str(record["Review_Url"]),
            "review-date": str(record["Review_Date"]),
            "author": str(record["Author"]),
            "rating": str(record["Rating"]),
            "review-title": str(record["Review_Title"]),
        }
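        # Split the review body into overlapping chunks; each chunk carries
        # the record's metadata plus its chunk index and raw text.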
record_texts = text_splitter.split_text(record["Review"])
record_metadatas = [{
"chunk": j, "text": text, **metadata
} for j, text in enumerate(record_texts)]
texts.extend(record_texts)
metadatas.extend(record_metadatas)
        # Once a full batch accumulates, embed the chunks and upsert
        # (id, vector, metadata) triples, then reset the buffers.
        if len(texts) >= BATCH_LIMIT:
            ids = [str(uuid4()) for _ in range(len(texts))]
            embeds = embedder.embed_documents(texts)
            index.upsert(vectors=list(zip(ids, embeds, metadatas)))
            texts = []
            metadatas = []

    # Flush any remaining chunks that didn't fill a complete batch.
    if len(texts) > 0:
        ids = [str(uuid4()) for _ in range(len(texts))]
        embeds = embedder.embed_documents(texts)
        index.upsert(vectors=list(zip(ids, embeds, metadatas)))
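

# Minimal usage sketch (an assumption, not part of this file: `index` as a
# Pinecone index and `embedder` as a LangChain OpenAIEmbeddings instance;
# the index name, keys, and CSV path below are illustrative placeholders):
#
#   import pinecone
#   from langchain.embeddings.openai import OpenAIEmbeddings
#
#   pinecone.init(api_key="...", environment="...")
#   index = pinecone.Index("reviews")
#   embedder = OpenAIEmbeddings()
#   inject(index, embedder, "reviews.csv")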