# Source: Hugging Face Space (page status at capture time: "Paused").
import os
import pickle

import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
# Pre-trained sentence-transformer checkpoint shared by every embedding call
# in this script (dictionary entries and input descriptions alike).
MODEL_NAME = 'all-MiniLM-L6-v2'
model = SentenceTransformer(MODEL_NAME)
def generate_st_embedding(sentence):
    """Encode *sentence* with the module-level ``model``.

    The result is requested as a tensor (``convert_to_tensor=True``) so it can
    be passed straight to the cosine-similarity helpers in ``util``.
    """
    embedding = model.encode(sentence, convert_to_tensor=True)
    return embedding
def cosine_similarity_st(embedding1, embedding2):
    """Return the cosine similarity of two embeddings as a Python float.

    ``util.pytorch_cos_sim`` produces a similarity matrix; ``.item()`` only
    succeeds when that matrix holds exactly one value, i.e. when each argument
    is a single embedding.
    """
    similarity_matrix = util.pytorch_cos_sim(embedding1, embedding2)
    return similarity_matrix.item()
def _descriptions(frame, column='description'):
    """Return the non-missing values of *column* in *frame* as a list of str.

    pandas reads empty CSV cells as NaN (a float, not text), and an
    all-numeric column as ints; both are normalised here so only strings
    reach the text encoder.
    """
    return frame[column].dropna().astype(str).tolist()


# Load the dictionary (the candidate descriptions to match against).
csv_file_path = './dictionary/dictionary.csv'
df_dictionary = pd.read_csv(csv_file_path)
dictionary = _descriptions(df_dictionary)

# Load the input descriptions that need a dictionary match.
input_file_path = 'raw/test.csv'
df_input = pd.read_csv(input_file_path)
input_words = _descriptions(df_input)

print("Everything loaded...")
# Build dictionary_embeddings (description -> embedding), reusing a pickled
# cache when one exists. The cache is reconciled with the current dictionary:
# entries added to dictionary.csv since the cache was written are encoded, and
# entries that were removed are dropped, so a stale cache is never trusted.
# NOTE(review): pickle.load can execute arbitrary code from the file. This is
# acceptable only because the file is written by this script itself; never
# point pickle_file_path at a file from an untrusted source.
pickle_file_path = './sbert_dictionary_embeddings.pkl'
cached_embeddings = {}
if os.path.exists(pickle_file_path):
    with open(pickle_file_path, 'rb') as f:
        cached_embeddings = pickle.load(f)

# dict.fromkeys de-duplicates while keeping first-occurrence order, so a
# repeated description is only encoded once.
missing = [desc for desc in dict.fromkeys(dictionary) if desc not in cached_embeddings]
if missing:
    for desc in tqdm(missing, desc="Generating embeddings for dictionary words"):
        cached_embeddings[desc] = generate_st_embedding(desc)

# Restrict to the descriptions currently in the dictionary, in dictionary order.
dictionary_embeddings = {desc: cached_embeddings[desc] for desc in dictionary}

# Rewrite the cache only if entries were added or stale ones were dropped.
if missing or len(dictionary_embeddings) != len(cached_embeddings):
    with open(pickle_file_path, 'wb') as f:
        pickle.dump(dictionary_embeddings, f)
# Find the most similar dictionary entry for each input description.
# The dictionary embeddings are stacked once into a single matrix, so each
# input costs one vectorised cosine-similarity call. Previously this took
# M separate calls, each converting a 1x1 result to a Python float.
# Assumes each cached embedding is a 1-D tensor, as model.encode returns for
# a single string — TODO confirm if the cache format ever changes.
dictionary_descs = list(dictionary_embeddings)
if input_words and not dictionary_descs:
    # Fail with a clear message instead of an opaque error from max()/stack.
    raise ValueError("No dictionary embeddings available to match input words against.")
dictionary_matrix = (
    torch.stack([dictionary_embeddings[desc] for desc in dictionary_descs])
    if dictionary_descs
    else None
)

results = []
for input_word in tqdm(input_words, desc="Processing input words"):
    input_embedding = generate_st_embedding(input_word)
    # A 1-D query against the stacked matrix yields a (1, M) score row.
    scores = util.pytorch_cos_sim(input_embedding, dictionary_matrix)[0]
    # argmax returns the first index of the maximum, matching max()'s tie-breaking.
    best = int(torch.argmax(scores))
    results.append((input_word, dictionary_descs[best], scores[best].item()))
# Report each input next to its best dictionary match and cosine score,
# with a blank line separating consecutive entries.
for query, match, similarity in results:
    print(
        f"Input word: {query}\n"
        f"Most similar word: {match}\n"
        f"Similarity score: {similarity}\n"
    )