AmirShabani commited on
Commit
a3e15e0
1 Parent(s): 225ee02

Fuzzy matching

Browse files
Files changed (2) hide show
  1. core.py +5 -1
  2. requirements.txt +1 -0
core.py CHANGED
@@ -28,6 +28,9 @@ from torch_geometric.nn.conv.gcn_conv import gcn_norm
28
  from torch_geometric.nn.conv import MessagePassing
29
  from torch_geometric.typing import Adj
30
  from sklearn.neighbors import BallTree
 
 
 
31
  class LightGCN(MessagePassing):
32
  def __init__(self, num_users, num_items, embedding_dim=64, diffusion_steps=3, add_self_loops=False):
33
  super().__init__()
@@ -186,7 +189,8 @@ def drop_non_numerical_columns(df):
186
  def output_list(input_dict, movies_df = movie_embeds, tree = btree, user_embeddings = user_embeds, movies = final_movies):
187
  movie_ratings = {}
188
  for movie_title in input_dict:
189
- index = movies.index[movies['title'] == movie_title].tolist()[0]
 
190
  movie_ratings[index] = input_dict[movie_title]
191
  user_embed = create_user_embedding(movie_ratings, movie_embeds)
192
  # Call the find_closest_user function with the pre-built BallTree
 
28
  from torch_geometric.nn.conv import MessagePassing
29
  from torch_geometric.typing import Adj
30
  from sklearn.neighbors import BallTree
31
+ from thefuzz import fuzz
32
+ from thefuzz import process
33
+
34
  class LightGCN(MessagePassing):
35
  def __init__(self, num_users, num_items, embedding_dim=64, diffusion_steps=3, add_self_loops=False):
36
  super().__init__()
 
189
  def output_list(input_dict, movies_df = movie_embeds, tree = btree, user_embeddings = user_embeds, movies = final_movies):
190
  movie_ratings = {}
191
  for movie_title in input_dict:
192
+ matching_title = process.extractOne(movie_title, final_movies['title'].values, scorer=fuzz.token_sort_ratio)[0]
193
+ index = movies.index[movies['title'] == matching_title].tolist()[0]
194
  movie_ratings[index] = input_dict[movie_title]
195
  user_embed = create_user_embedding(movie_ratings, movie_embeds)
196
  # Call the find_closest_user function with the pre-built BallTree
requirements.txt CHANGED
@@ -2,6 +2,7 @@ requests==2.29.0
2
  pillow
3
  numpy==1.23.5
4
  pandas==1.5.3
 
5
  scikit-learn==1.2.2
6
  torch==2.0.0
7
  torchvision==0.15.1
 
2
  pillow
3
  numpy==1.23.5
4
  pandas==1.5.3
5
+ thefuzz[speedup]
6
  scikit-learn==1.2.2
7
  torch==2.0.0
8
  torchvision==0.15.1