YchKhan commited on
Commit
57bf1c3
1 Parent(s): 9b3fe22

Update classification.py

Browse files
Files changed (1) hide show
  1. classification.py +12 -6
classification.py CHANGED
@@ -17,9 +17,15 @@ def initialize_models():
17
  def generate_embeddings(df, model, Column):
18
  embeddings_list = []
19
  for index, row in df.iterrows():
20
- if type(row["Title"]) == str and type(row[Column]) == str:
21
  print(index)
22
- content = row["Title"] + "\n" + row[Column]
 
 
 
 
 
 
23
  embeddings = model.encode(content, convert_to_tensor=True)
24
  embeddings_list.append(embeddings)
25
  else:
@@ -39,13 +45,13 @@ def process_categories(categories, model):
39
 
40
 
41
 
42
- def match_categories(df, category_df):
43
 
44
  categories_list, experts_list, topic_list, scores_list = [], [], [], []
45
  for ebd_content in df['Embeddings']:
46
  if isinstance(ebd_content, torch.Tensor):
47
  cos_scores = util.cos_sim(ebd_content, torch.stack(list(category_df['Embeddings']), dim=0))[0]
48
- high_score_indices = [i for i, score in enumerate(cos_scores) if score > 0.45]
49
 
50
  # Append the corresponding categories, experts, and topics for each high-scoring index
51
  categories_list.append([category_df.loc[index, 'description'] for index in high_score_indices])
@@ -86,7 +92,7 @@ def save_data(df, filename):
86
  df.to_excel(new_filename, index=False)
87
  return new_filename
88
 
89
- def classification(column, file_path, categories):
90
  # Load data
91
  df = load_data(file_path)
92
 
@@ -100,7 +106,7 @@ def classification(column, file_path, categories):
100
  category_df = process_categories(categories, model_ST)
101
 
102
  # Match categories
103
- df = match_categories(df, category_df)
104
 
105
  # Save data
106
  return save_data(df,file_path), df
 
17
  def generate_embeddings(df, model, Column):
18
  embeddings_list = []
19
  for index, row in df.iterrows():
20
+ if type(row[Column]) == str:
21
  print(index)
22
+ if 'Title' in df.columns:
23
+ if type(row["Title"]) == str:
24
+ content = row["Title"] + "\n" + row[Column]
25
+ else:
26
+ content = row[Column]
27
+ else:
28
+ content = row[Column]
29
  embeddings = model.encode(content, convert_to_tensor=True)
30
  embeddings_list.append(embeddings)
31
  else:
 
45
 
46
 
47
 
48
+ def match_categories(df, category_df, treshold=0.45):
49
 
50
  categories_list, experts_list, topic_list, scores_list = [], [], [], []
51
  for ebd_content in df['Embeddings']:
52
  if isinstance(ebd_content, torch.Tensor):
53
  cos_scores = util.cos_sim(ebd_content, torch.stack(list(category_df['Embeddings']), dim=0))[0]
54
+ high_score_indices = [i for i, score in enumerate(cos_scores) if score > treshold]
55
 
56
  # Append the corresponding categories, experts, and topics for each high-scoring index
57
  categories_list.append([category_df.loc[index, 'description'] for index in high_score_indices])
 
92
  df.to_excel(new_filename, index=False)
93
  return new_filename
94
 
95
+ def classification(column, file_path, categories, treshold):
96
  # Load data
97
  df = load_data(file_path)
98
 
 
106
  category_df = process_categories(categories, model_ST)
107
 
108
  # Match categories
109
+ df = match_categories(df, category_df, treshold=treshold)
110
 
111
  # Save data
112
  return save_data(df,file_path), df