import os
import re
import sys
from typing import Optional

import pandas as pd
from sentence_transformers import SentenceTransformer

# Make the project root importable so `src.*` modules resolve when this file
# is run directly as a script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from src.data_directories import *  # provides cities_csv, wikivoyage_docs_dir, wikivoyage_listings_dir, database_dir

def create_chunks(city, country, text):
    """
    Helper function that splits a Wikivoyage document into chunks based on its
    implicit "== Section ==" headers.

    Args:
        - city: str
        - country: str
        - text: list[str]; lines of the document that needs to be chunked
    """
    # Drop blank lines up front (deleting while iterating skips consecutive blanks).
    text = [line for line in text if line.strip()]

    index = 0
    chunks = []
    pattern = re.compile("==")
    ignore = re.compile("===")
    section = 'Introduction'
    for i, line in enumerate(text):
        # A top-level header looks like "== Get in ==" (subsections "=== ... ===" are ignored).
        if pattern.search(line) and not ignore.search(line):
            chunk = ''.join(text[index:i])
            chunks.append({
                'city': city,
                'country': country,
                'section': section,
                'text': chunk,
                # 'vector': f'city: {city}, country: {country}, section: {section}, text: {chunk}'
            })
            index = i + 1
            section = re.sub(pattern, '', line).strip()

    # Append the final section, which has no header following it.
    chunks.append({
        'city': city,
        'country': country,
        'section': section,
        'text': ''.join(text[index:]),
    })

    df = pd.DataFrame(chunks)
    return df

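# Illustrative sketch (not part of the original pipeline): given lines such as
#   ["Paris is the capital of France.\n", "== Get in ==\n", "By plane or train.\n",
#    "== See ==\n", "The Louvre.\n"]
# create_chunks("Paris", "France", lines) yields one row per section
# ('Introduction', 'Get in', 'See'), each carrying the city and country.
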
def read_docs():
    """
    Helper function that reads all of the cleaned Wikivoyage documents (one per
    city), chunks each one, and returns a single dataframe of chunks.
    """
    cities = pd.read_csv(cities_csv)
    chunk_dfs = []
    for file_name in os.listdir(wikivoyage_docs_dir + "cleaned/"):
        city = file_name.split(".")[0]
        # print(city)
        country = cities[cities['city'] == city]['country'].item()
        with open(wikivoyage_docs_dir + "cleaned/" + file_name) as file:
            text = file.readlines()
        chunk_dfs.append(create_chunks(city, country, text))
    df = pd.concat(chunk_dfs, ignore_index=True)
    return df

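# Note: read_docs() assumes the cleaned documents live under
# <wikivoyage_docs_dir>/cleaned/ and are named <City>.<ext>, since the city
# name is recovered from the part of the filename before the first dot.
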
def read_listings():
    """
    Helper function that reads the Wikivoyage listings CSV containing tabular information about 144 cities.
    """
    df = pd.read_csv(wikivoyage_listings_dir + "wikivoyage-listings-cleaned.csv")
    cities = pd.read_csv(cities_csv)

    def find_country(city):
        return cities[cities['city'] == city]['country'].values[0]

    df['country'] = df['city'].apply(find_country)
    return df

def preprocess_df(df):
    """
    Helper function that preprocesses the dataframe of text chunks: keeps only
    frequently occurring sections, removes hyperlinks, strips leftover section
    markers, and trims surrounding whitespace and newlines.

    Args:
        - df: dataframe
    """
    # Keep only sections that appear in more than 150 chunks across all cities.
    section_counts = df['section'].value_counts()
    sections_to_keep = section_counts[section_counts > 150].index
    filtered_df = df[df['section'].isin(sections_to_keep)].copy()

    def preprocess_text(s):
        s = re.sub(r'http\S+', '', s)  # drop hyperlinks
        s = re.sub(r'=+', '', s)       # drop leftover '=' section markers
        s = s.strip()
        return s

    filtered_df['text'] = filtered_df['text'].apply(preprocess_text)
    return filtered_df

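# Illustrative sketch of the inner cleaning step above:
#   preprocess_text("  Visit https://example.org for tickets ==\n")
# removes the link and the '=' markers and trims whitespace, leaving
# "Visit  for tickets".
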
def compute_wv_docs_embeddings(df):
    """
    Helper function that computes embeddings for the text using the
    all-MiniLM-L6-v2 embedding model. Expects the dataframe to contain a
    'combined' column holding the text to embed.

    Args:
        - df: dataframe
    """
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    vector_dimension = model.get_sentence_embedding_dimension()
    print("Computing embeddings")
    embeddings = []
    for i, row in df.iterrows():
        emb = model.encode(row['combined'], show_progress_bar=True).tolist()
        embeddings.append(emb)
    print("Finished computing embeddings for wikivoyage documents.")
    df['vector'] = embeddings
    # df.to_csv(wv_embeddings + "wikivoyage-listings-embeddings.csv")
    # print("Finished saving file.")
    return df

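# Possible alternative (not what the original does): sentence-transformers can
# encode a whole list in one batched call, e.g.
#   embeddings = model.encode(df['combined'].tolist(), show_progress_bar=True).tolist()
# which avoids the per-row encode loop above.
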
def embed_query(query):
    """
    Helper function that returns the embedded query.

    Args:
        - query: str
    """
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    # vector_dimension = model.get_sentence_embedding_dimension()
    embedding = model.encode(query).tolist()
    return embedding

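# Example (illustrative): embed_query("museums and cafes in a walkable city")
# returns a plain Python list of floats (384 values for all-MiniLM-L6-v2),
# matching the document vectors produced above.
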
def set_uri(run_local: Optional[bool] = False):
    if run_local:
        uri = database_dir
        current_dir = os.path.split(os.getcwd())[1]
        if "src" in current_dir or "tests" in current_dir:  # hacky way to get the correct path
            uri = uri.replace("../../", "../")
    else:
        uri = os.environ["BUCKET_NAME"]
    return uri
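

# Minimal end-to-end sketch, guarded so importing this module has no side
# effects. How the 'combined' column is built elsewhere in the project is an
# assumption here; the concatenation below only mirrors the commented-out
# 'vector' field in create_chunks.
if __name__ == "__main__":
    docs = preprocess_df(read_docs())
    # Assumed construction of the text fed to the embedding model.
    docs['combined'] = (
        "city: " + docs['city'] + ", country: " + docs['country']
        + ", section: " + docs['section'] + ", text: " + docs['text']
    )
    docs = compute_wv_docs_embeddings(docs)
    print(docs[['city', 'section']].head())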