lalithadevi committed on
Commit
1462b37
1 Parent(s): f963240

Rename news_category_prediction.py to news_category_similar_news_prediction.py

Browse files
news_category_prediction.py → news_category_similar_news_prediction.py RENAMED
@@ -6,6 +6,7 @@ from config import (DISTILBERT_TOKENIZER_N_TOKENS,
6
  CLASSIFIER_THRESHOLD)
7
 
8
  from logger import get_logger
 
9
 
10
  logger = get_logger()
11
 
@@ -43,7 +44,8 @@ def cols_check(new_cols, old_cols):
43
  return all([new_col==old_col for new_col, old_col in zip(new_cols, old_cols)])
44
 
45
 
46
- def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer):
 
47
  try:
48
  db_updation_required = 1
49
  logger.warning('Entering predict_news_category()')
@@ -73,6 +75,8 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
73
  final_df.drop_duplicates(subset='url', keep='first', inplace=True)
74
  headlines = [*final_df['title']].copy()
75
  label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
 
 
76
  final_df['category'] = label
77
  final_df['pred_proba'] = prob
78
  final_df.reset_index(drop=True, inplace=True)
 
6
  CLASSIFIER_THRESHOLD)
7
 
8
  from logger import get_logger
9
+ from find_similar_news import find_similar_news
10
 
11
  logger = get_logger()
12
 
 
44
  return all([new_col==old_col for new_col, old_col in zip(new_cols, old_cols)])
45
 
46
 
47
+ def predict_news_category_similar_news(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer,
48
+ collection, vectorizer, sent_model, ce_model):
49
  try:
50
  db_updation_required = 1
51
  logger.warning('Entering predict_news_category()')
 
75
  final_df.drop_duplicates(subset='url', keep='first', inplace=True)
76
  headlines = [*final_df['title']].copy()
77
  label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
78
+ sent_embs = vectorizer.vectorize(headlines)
79
+ sim_news = [find_similar_news(text, collection, vectorizer, sent_model, ce_model) for text in sent_embs]
80
  final_df['category'] = label
81
  final_df['pred_proba'] = prob
82
  final_df.reset_index(drop=True, inplace=True)