YchKhan commited on
Commit
3f8cc48
1 Parent(s): c6cfdf4

Update classification.py

Browse files
Files changed (1) hide show
  1. classification.py +1 -1
classification.py CHANGED
@@ -178,7 +178,7 @@ def match_categories(df, category_df, treshold=0.45):
178
  cos_scores = util.cos_sim(ebd_content, torch.stack(list(category_df['Embeddings']), dim=0))[0]
179
  high_score_indices = [i for i, score in enumerate(cos_scores) if score > treshold]
180
  for j in high_score_indices:
181
- df.loc[index, category_df.loc[j, 'topic']] = 'float(cos_scores[j])'
182
  return df
183
 
184
  def save_data(df, filename):
 
178
  cos_scores = util.cos_sim(ebd_content, torch.stack(list(category_df['Embeddings']), dim=0))[0]
179
  high_score_indices = [i for i, score in enumerate(cos_scores) if score > treshold]
180
  for j in high_score_indices:
181
+ df.loc[index, category_df.loc[j, 'topic']] = float(cos_scores[j])
182
  return df
183
 
184
  def save_data(df, filename):