lalithadevi
committed on
Commit
•
f442802
1
Parent(s):
1bbda15
Update news_category_prediction.py
Browse files- news_category_prediction.py +46 -41
news_category_prediction.py
CHANGED
@@ -43,46 +43,51 @@ def cols_check(new_cols, old_cols):
|
|
43 |
|
44 |
|
45 |
def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer):
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
old_news
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
if not
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
87 |
return final_df
|
88 |
|
|
|
43 |
|
44 |
|
45 |
def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer):
    """Predict categories for news headlines and merge new items with old ones.

    Two paths are taken depending on whether ``old_news`` already has a
    ``'category'`` column:

    * No prior predictions: old and new rows are concatenated, de-duplicated
      on ``'url'``, and every headline in ``'title'`` is passed to
      ``inference``.
    * Prior predictions present: only rows of ``new_news`` whose ``'url'`` is
      not already in ``old_news`` are passed to ``inference`` and then
      appended to ``old_news``.

    In both paths the results of ``inference`` are stored in the
    ``'category'`` and ``'pred_proba'`` columns, and any row whose
    ``'pred_proba'`` is below ``CLASSIFIER_THRESHOLD`` has its category set
    to ``'OTHERS'``.

    Args:
        old_news: Previously stored news. Anything that is not a DataFrame is
            treated as "no old news" and replaced by a copy of ``new_news``.
        new_news: Freshly fetched news. Must be a DataFrame.
        interpreter: Forwarded unchanged to ``inference``.
        label_encoder: Forwarded unchanged to ``inference``.
        tokenizer: Forwarded unchanged to ``inference``.

    Returns:
        The merged DataFrame, or ``None`` if any exception occurred (the
        exception is logged, not propagated). The inputs are not mutated;
        the function works on copies.
    """
    try:
        logger.warning('Entering predict_news_category()')
        logger.warning(f'old news: {old_news}\nnew_news: {new_news}')

        # Guard clause: nothing can be done without a frame of new items.
        if not isinstance(new_news, pd.DataFrame):
            raise Exception('No New News Found')
        new_news = new_news.copy()
        logger.warning(f'new news columns: {[*new_news.columns]}')
        logger.warning(f'{len(new_news)} new news items found')

        if isinstance(old_news, pd.DataFrame):
            old_news = old_news.copy()
            logger.warning(f'old news columns: {[*old_news.columns]}')
            logger.warning(f'{len(old_news)} old news items found')
        else:
            logger.warning('No old news is found')
            # Fall back to the new items so the merge below still works; the
            # duplicates this creates are removed by drop_duplicates.
            old_news = new_news.copy()

        if 'category' not in old_news.columns:
            logger.warning('No prior predictions found in old news')
            if not cols_check([*new_news.columns], [*old_news.columns]):
                raise Exception("New and old cols don't match")
            final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
            final_df.drop_duplicates(subset='url', keep='first', inplace=True)
            headlines = list(final_df['title'])
            label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
            final_df['category'] = label
            final_df['pred_proba'] = prob
        else:
            logger.warning('Prior predictions found in old news')
            # NOTE(review): the [:-2] slice assumes the two prediction columns
            # are the last two columns of old_news -- confirm against callers.
            if not cols_check([*new_news.columns], [*old_news.columns][:-2]):
                raise Exception("New and old cols don't match")
            old_urls = [*old_news['url']]
            # Keep only unseen URLs. The explicit copy makes the column
            # assignments below write to an independent frame rather than to
            # a slice of the earlier one.
            new_news = new_news.loc[~new_news['url'].isin(old_urls), :].copy()
            headlines = list(new_news['title'])
            label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
            new_news['category'] = label
            new_news['pred_proba'] = prob
            final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
            final_df.drop_duplicates(subset='url', keep='first', inplace=True)

        final_df.reset_index(drop=True, inplace=True)
        # Low-confidence predictions are folded into a catch-all category.
        final_df.loc[final_df['pred_proba'] < CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
        logger.warning('Exiting predict_news_category()')
        return final_df
    except Exception as e:
        # Best-effort boundary: callers receive None instead of an exception.
        logger.warning(f'Unexcpected error in predict_news_category()\n{e}')
        return None
|
93 |
|