Commit
•
d5f669f
1
Parent(s):
8a32255
Update news_category_prediction.py
Browse files- news_category_prediction.py +14 -8
news_category_prediction.py
CHANGED
@@ -74,20 +74,26 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
|
|
74 |
label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
|
75 |
final_df['category'] = label
|
76 |
final_df['pred_proba'] = prob
|
|
|
|
|
77 |
else:
|
78 |
logger.warning('Prior predictions found in old news')
|
79 |
if not cols_check([*new_news.columns], [*old_news.columns][:-2]):
|
80 |
raise Exception("New and old cols don't match")
|
81 |
old_urls = [*old_news['url']]
|
82 |
new_news = new_news.loc[new_news['url'].isin(old_urls) == False, :]
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
91 |
|
92 |
if len(final_df) == 0:
|
93 |
final_df = None
|
|
|
74 |
label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
|
75 |
final_df['category'] = label
|
76 |
final_df['pred_proba'] = prob
|
77 |
+
final_df.reset_index(drop=True, inplace=True)
|
78 |
+
final_df.loc[final_df['pred_proba']<CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
|
79 |
else:
|
80 |
logger.warning('Prior predictions found in old news')
|
81 |
if not cols_check([*new_news.columns], [*old_news.columns][:-2]):
|
82 |
raise Exception("New and old cols don't match")
|
83 |
old_urls = [*old_news['url']]
|
84 |
new_news = new_news.loc[new_news['url'].isin(old_urls) == False, :]
|
85 |
+
if len(new_news) > 0:
|
86 |
+
headlines = [*new_news['title']].copy()
|
87 |
+
label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
|
88 |
+
new_news['category'] = label
|
89 |
+
new_news['pred_proba'] = prob
|
90 |
+
final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
|
91 |
+
final_df.drop_duplicates(subset='url', keep='first', inplace=True)
|
92 |
+
final_df.reset_index(drop=True, inplace=True)
|
93 |
+
final_df.loc[final_df['pred_proba']<CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
|
94 |
+
else:
|
95 |
+
raise Exception('INFO: Old & New Articles are the same. There is no requirement of updating them in the database. Database is not updated.')
|
96 |
+
|
97 |
|
98 |
if len(final_df) == 0:
|
99 |
final_df = None
|