lalithadevi commited on
Commit
d5f669f
1 Parent(s): 8a32255

Update news_category_prediction.py

Browse files
Files changed (1) hide show
  1. news_category_prediction.py +14 -8
news_category_prediction.py CHANGED
@@ -74,20 +74,26 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
74
  label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
75
  final_df['category'] = label
76
  final_df['pred_proba'] = prob
 
 
77
  else:
78
  logger.warning('Prior predictions found in old news')
79
  if not cols_check([*new_news.columns], [*old_news.columns][:-2]):
80
  raise Exception("New and old cols don't match")
81
  old_urls = [*old_news['url']]
82
  new_news = new_news.loc[new_news['url'].isin(old_urls) == False, :]
83
- headlines = [*new_news['title']].copy()
84
- label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
85
- new_news['category'] = label
86
- new_news['pred_proba'] = prob
87
- final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
88
- final_df.drop_duplicates(subset='url', keep='first', inplace=True)
89
- final_df.reset_index(drop=True, inplace=True)
90
- final_df.loc[final_df['pred_proba']<CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
 
 
 
 
91
 
92
  if len(final_df) == 0:
93
  final_df = None
 
74
  label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
75
  final_df['category'] = label
76
  final_df['pred_proba'] = prob
77
+ final_df.reset_index(drop=True, inplace=True)
78
+ final_df.loc[final_df['pred_proba']<CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
79
  else:
80
  logger.warning('Prior predictions found in old news')
81
  if not cols_check([*new_news.columns], [*old_news.columns][:-2]):
82
  raise Exception("New and old cols don't match")
83
  old_urls = [*old_news['url']]
84
  new_news = new_news.loc[new_news['url'].isin(old_urls) == False, :]
85
+ if len(new_news) > 0:
86
+ headlines = [*new_news['title']].copy()
87
+ label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
88
+ new_news['category'] = label
89
+ new_news['pred_proba'] = prob
90
+ final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
91
+ final_df.drop_duplicates(subset='url', keep='first', inplace=True)
92
+ final_df.reset_index(drop=True, inplace=True)
93
+ final_df.loc[final_df['pred_proba']<CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
94
+ else:
95
+ raise Exception('INFO: Old & New Articles are the same. There is no requirement of updating them in the database. Database is not updated.')
96
+
97
 
98
  if len(final_df) == 0:
99
  final_df = None