YchKhan committed on
Commit
8f3a756
1 Parent(s): 533a642

Update classification.py

Browse files
Files changed (1) hide show
  1. classification.py +3 -4
classification.py CHANGED
@@ -172,17 +172,16 @@ def process_categories(categories, model):
172
 
173
  def match_categories(df, category_df, treshold=0.45):
174
  for topic in category_df['topic']:
175
- df[topic] = 0
176
  for i, ebd_content in enumerate(df['Embeddings']):
177
  if isinstance(ebd_content, torch.Tensor):
178
  cos_scores = util.cos_sim(ebd_content, torch.stack(list(category_df['Embeddings']), dim=0))[0]
179
  high_score_indices = [i for i, score in enumerate(cos_scores) if score > treshold]
180
- for j in high_score_indices:
181
- df.loc[i, category_df.loc[j, 'topic']] = float(cos_scores[index])
182
  return df
183
 
184
  def save_data(df, filename):
185
-
186
  df = df.drop(columns=['Embeddings'])
187
  new_filename = filename.replace(".", "_classified.")
188
  df.to_excel(new_filename, index=False)
 
172
 
173
  def match_categories(df, category_df, treshold=0.45):
174
  for topic in category_df['topic']:
175
+ df[topic] = 0
176
  for i, ebd_content in enumerate(df['Embeddings']):
177
  if isinstance(ebd_content, torch.Tensor):
178
  cos_scores = util.cos_sim(ebd_content, torch.stack(list(category_df['Embeddings']), dim=0))[0]
179
  high_score_indices = [i for i, score in enumerate(cos_scores) if score > treshold]
180
+ for j in high_score_indices:
181
+ df.loc[i, category_df.loc[j, 'topic']] = float(cos_scores[index])
182
  return df
183
 
184
  def save_data(df, filename):
 
185
  df = df.drop(columns=['Embeddings'])
186
  new_filename = filename.replace(".", "_classified.")
187
  df.to_excel(new_filename, index=False)