# Source: brightly-ai/old_experiments/llama3-gpu2.py
# (commit 9189e38 "first" by beweinreich; 2.09 kB as listed on the hosting page)
import os
import pickle

import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
# Sentence-embedding model used by every encode call in this script.
# It is loaded once, at import time, by its sentence-transformers model name.
model_name = 'all-MiniLM-L6-v2'
model = SentenceTransformer(model_name)
def generate_st_embedding(sentence):
    """Encode *sentence* with the module-level SentenceTransformer ``model``.

    The embedding is requested as a torch tensor (``convert_to_tensor=True``)
    so it can be passed directly to the ``sentence_transformers.util``
    similarity helpers.
    """
    return model.encode(sentence, convert_to_tensor=True)
def cosine_similarity_st(embedding1, embedding2):
    """Return the cosine similarity of two embeddings as a Python float.

    ``util.pytorch_cos_sim`` produces a similarity tensor. ``.item()`` only
    succeeds when that tensor has exactly one element, so both arguments must
    be single embeddings, not batches.
    """
    return util.pytorch_cos_sim(embedding1, embedding2).item()
def _load_descriptions(path):
    """Read the CSV at *path*; return the frame and its 'description' column as a list."""
    frame = pd.read_csv(path)
    return frame, frame['description'].tolist()


# Reference dictionary the inputs are matched against, and the inputs themselves.
# Both files are read in the same order as before: dictionary first, then inputs.
csv_file_path = './dictionary/dictionary.csv'
input_file_path = 'raw/test.csv'
df_dictionary, dictionary = _load_descriptions(csv_file_path)
df_input, input_words = _load_descriptions(input_file_path)
print("Everything loaded...")
# Dictionary embeddings are cached on disk, keyed by description text, so that
# re-runs only pay for descriptions that have not been embedded before.
# NOTE(review): pickle.load can execute arbitrary code from the file; only point
# this at a cache file that this script produced itself.
pickle_file_path = './sbert_dictionary_embeddings.pkl'
cached_embeddings = {}
if os.path.exists(pickle_file_path):
    with open(pickle_file_path, 'rb') as f:
        cached_embeddings = pickle.load(f)

# Embed every description the cache lacks (all of them on a cold start), so that
# edits to dictionary.csv are picked up instead of reusing a stale cache.
# dict.fromkeys de-duplicates while keeping first-occurrence order.
missing = [desc for desc in dict.fromkeys(dictionary) if desc not in cached_embeddings]
if missing:
    for desc in tqdm(missing, desc="Generating embeddings for dictionary words"):
        cached_embeddings[desc] = generate_st_embedding(desc)
    # Save the updated embeddings so the next run can skip these descriptions.
    with open(pickle_file_path, 'wb') as f:
        pickle.dump(cached_embeddings, f)

# Match only against the current dictionary. Entries that were removed from
# dictionary.csv but remain in the cache file are ignored.
dictionary_embeddings = {desc: cached_embeddings[desc] for desc in dictionary}
# Find the most similar dictionary description for each input word.
# The dictionary embeddings are stacked once into a single matrix, so each input
# is scored against the whole dictionary in one similarity call. This replaces
# one call plus one `.item()` (a device-to-host sync on GPU) per entry.
dictionary_descriptions = list(dictionary_embeddings)
if not dictionary_descriptions:
    # Fail with a clear message instead of an opaque error inside the loop.
    raise ValueError("No dictionary embeddings to compare input words against.")
dictionary_matrix = torch.stack(list(dictionary_embeddings.values()))

results = []
for input_word in tqdm(input_words, desc="Processing input words"):
    input_embedding = generate_st_embedding(input_word)
    # Keep the matrix on the query's device (e.g. if the cache was built on CPU).
    # `.to` returns the same tensor once it is already on that device.
    dictionary_matrix = dictionary_matrix.to(input_embedding.device)
    # Assumes input_embedding is a single vector (one string was encoded), so
    # row 0 holds its similarity to every dictionary entry.
    scores = util.pytorch_cos_sim(input_embedding, dictionary_matrix)[0]
    # argmax returns the first maximal index, matching max()'s tie-breaking.
    best_index = int(scores.argmax())
    results.append((input_word, dictionary_descriptions[best_index], scores[best_index].item()))
# Report the best dictionary match and its cosine similarity for each input
# word, with a blank line between records.
for input_word, most_similar_word, score in results:
    print(f"Input word: {input_word}")
    print(f"Most similar word: {most_similar_word}")
    print(f"Similarity score: {score}\n")