Spaces:
Running
Running
from typing import Optional | |
import pandas as pd | |
import os | |
import re | |
from sentence_transformers import SentenceTransformer | |
import sys | |
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
sys.path.append(os.path.dirname(SCRIPT_DIR)) | |
from src.data_directories import * | |
def create_chunks(city, country, text): | |
""" | |
Helper function that creates chunks given paragraph(s) of text based on implicit sections in the text. | |
Args: | |
- city: str | |
- country: str | |
- text: str; document that needs to be chunked | |
""" | |
for i, line in enumerate(text): | |
if line[0] == "\n": | |
del text[i] | |
index = 0 | |
chunks = [] | |
pattern = re.compile("==") | |
ignore = re.compile("===") | |
section = 'Introduction' | |
for i, line in enumerate(text): | |
if pattern.search(line) and not ignore.search(line): | |
chunk = ''.join(text[index:i]) | |
chunks.append({ | |
'city': city, | |
'country': country, | |
'section': section, | |
'text': chunk, | |
# 'vector': f'city: {city}, country: {country}, section: {section}, text: {chunk}' | |
}) | |
index = i + 1 | |
section = re.sub(pattern, '', line).strip() | |
df = pd.DataFrame(chunks) | |
return df | |
def read_docs(): | |
""" | |
Helper function that reads all of the Wikivoyage documents containing information about the city. | |
""" | |
df = pd.DataFrame() | |
cities = pd.read_csv(cities_csv) | |
for file_name in os.listdir(wikivoyage_docs_dir + "cleaned/"): | |
city = file_name.split(".")[0] | |
# print(city) | |
country = cities[cities['city'] == city]['country'].item() | |
with open(wikivoyage_docs_dir + "cleaned/" + file_name) as file: | |
text = file.readlines() | |
chunk_df = create_chunks(city, country, text) | |
df = pd.concat([df, chunk_df]) | |
return df | |
def read_listings(): | |
""" | |
Helper function that reads the Wikivoyage listings csv containing tabular information about 144 cities. | |
""" | |
df = pd.read_csv(wikivoyage_listings_dir + "wikivoyage-listings-cleaned.csv") | |
cities = pd.read_csv(cities_csv) | |
def find_country(city): | |
return cities[cities['city'] == city]['country'].values[0] | |
df['country'] = df['city'].apply(find_country) | |
return df | |
def preprocess_df(df): | |
""" | |
Helper function that preprocesses the dataframe containing chunks of text and removes hyperlinks and strips the \n from the text. | |
Args: | |
- df: dataframe | |
""" | |
section_counts = df['section'].value_counts() | |
sections_to_keep = section_counts[section_counts > 150].index | |
filtered_df = df[df['section'].isin(sections_to_keep)] | |
def preprocess_text(s): | |
s = re.sub(r'http\S+', '', s) | |
s = re.sub(r'=+', '', s) | |
s = s.strip() | |
return s | |
filtered_df['text'] = filtered_df['text'].apply(preprocess_text) | |
return filtered_df | |
def compute_wv_docs_embeddings(df): | |
""" | |
Helper function that computes embeddings for the text. The all-MiniLM-L6-v2 embedding model is used. | |
Args: | |
- df: dataframe | |
""" | |
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') | |
vector_dimension = model.get_sentence_embedding_dimension() | |
print("Computing embeddings") | |
embeddings = [] | |
for i, row in df.iterrows(): | |
emb = model.encode(row['combined'], show_progress_bar=True).tolist() | |
embeddings.append(emb) | |
print("Finished computing embeddings for wikivoyage documents.") | |
df['vector'] = embeddings | |
# df.to_csv(wv_embeddings + "wikivoyage-listings-embeddings.csv") | |
# print("Finished saving file.") | |
return df | |
def embed_query(query): | |
""" | |
Helper function that returns the embedded query. | |
Args: | |
- query: str | |
""" | |
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') | |
# vector_dimension = model.get_sentence_embedding_dimension() | |
embedding = model.encode(query).tolist() | |
return embedding | |
def set_uri(run_local: Optional[bool] = False): | |
if run_local: | |
uri = database_dir | |
current_dir = os.path.split(os.getcwd())[1] | |
if "src" or "tests" in current_dir: # hacky way to get the correct path | |
uri = uri.replace("../../", "../") | |
else: | |
uri = os.environ["BUCKET_NAME"] | |
return uri | |