Commit
•
74a2149
1
Parent(s):
0a863df
Update news_category_prediction.py
Browse files- news_category_prediction.py +13 -3
news_category_prediction.py
CHANGED
@@ -1,7 +1,13 @@
|
|
1 |
import pandas as pd
|
2 |
import numpy as np
|
3 |
import tensorflow as tf
|
4 |
-
from config import DISTILBERT_TOKENIZER_N_TOKENS,
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
def parse_prediction(tflite_pred, label_encoder):
|
7 |
tflite_pred_argmax = np.argmax(tflite_pred, axis=1)
|
@@ -11,6 +17,7 @@ def parse_prediction(tflite_pred, label_encoder):
|
|
11 |
|
12 |
|
13 |
def inference(text, interpreter, label_encoder, tokenizer):
|
|
|
14 |
batch_size = len(text)
|
15 |
if text != "":
|
16 |
tokens = tokenizer(text, max_length=DISTILBERT_TOKENIZER_N_TOKENS, padding="max_length", truncation=True, return_tensors="tf")
|
@@ -28,6 +35,7 @@ def inference(text, interpreter, label_encoder, tokenizer):
|
|
28 |
interpreter.invoke()
|
29 |
tflite_pred = interpreter.get_tensor(output_details["index"])
|
30 |
tflite_pred = parse_prediction(tflite_pred)
|
|
|
31 |
return tflite_pred
|
32 |
|
33 |
def cols_check(new_cols, old_cols):
|
@@ -35,6 +43,7 @@ def cols_check(new_cols, old_cols):
|
|
35 |
|
36 |
|
37 |
def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer):
|
|
|
38 |
old_news = old_news.copy()
|
39 |
new_news = new_news.copy()
|
40 |
# dbops = DBOperations()
|
@@ -42,7 +51,7 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
|
|
42 |
old_news.drop(columns='_id', inplace=True)
|
43 |
# new_news = get_news()
|
44 |
if 'category' not in [*old_news.columns]:
|
45 |
-
|
46 |
if not cols_check([*new_news.columns], [*old_news.columns]):
|
47 |
raise Exeption("New and old cols don't match")
|
48 |
final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
|
@@ -52,7 +61,7 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
|
|
52 |
final_df['category'] = label
|
53 |
final_df['pred_proba'] = prob
|
54 |
else:
|
55 |
-
|
56 |
if not cols_check([*new_news.columns], [*old_news.columns][:-2]):
|
57 |
raise Exeption("New and old cols don't match")
|
58 |
old_urls = [*old_news['url']]
|
@@ -65,5 +74,6 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
|
|
65 |
final_df.drop_duplicates(subset='url', keep='first', inplace=True)
|
66 |
final_df.reset_index(drop=True, inplace=True)
|
67 |
final_df.loc[final_df['pred_proba']<CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
|
|
|
68 |
return final_df
|
69 |
|
|
|
1 |
import pandas as pd
|
2 |
import numpy as np
|
3 |
import tensorflow as tf
|
4 |
+
from config import (DISTILBERT_TOKENIZER_N_TOKENS,
|
5 |
+
NEWS_CATEGORY_CLASSIFIER_N_CLASSES,
|
6 |
+
CLASSIFIER_THRESHOLD)
|
7 |
+
|
8 |
+
from logger import get_logger
|
9 |
+
|
10 |
+
logger = get_logger()
|
11 |
|
12 |
def parse_prediction(tflite_pred, label_encoder):
|
13 |
tflite_pred_argmax = np.argmax(tflite_pred, axis=1)
|
|
|
17 |
|
18 |
|
19 |
def inference(text, interpreter, label_encoder, tokenizer):
|
20 |
+
logger.warning('Entering inference()')
|
21 |
batch_size = len(text)
|
22 |
if text != "":
|
23 |
tokens = tokenizer(text, max_length=DISTILBERT_TOKENIZER_N_TOKENS, padding="max_length", truncation=True, return_tensors="tf")
|
|
|
35 |
interpreter.invoke()
|
36 |
tflite_pred = interpreter.get_tensor(output_details["index"])
|
37 |
tflite_pred = parse_prediction(tflite_pred)
|
38 |
+
logger.warning('Exiting inference()')
|
39 |
return tflite_pred
|
40 |
|
41 |
def cols_check(new_cols, old_cols):
|
|
|
43 |
|
44 |
|
45 |
def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer):
|
46 |
+
logger.warning('Entering predicting_news_category()')
|
47 |
old_news = old_news.copy()
|
48 |
new_news = new_news.copy()
|
49 |
# dbops = DBOperations()
|
|
|
51 |
old_news.drop(columns='_id', inplace=True)
|
52 |
# new_news = get_news()
|
53 |
if 'category' not in [*old_news.columns]:
|
54 |
+
logger.warning('No prior predictions found in old news')
|
55 |
if not cols_check([*new_news.columns], [*old_news.columns]):
|
56 |
raise Exeption("New and old cols don't match")
|
57 |
final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
|
|
|
61 |
final_df['category'] = label
|
62 |
final_df['pred_proba'] = prob
|
63 |
else:
|
64 |
+
logger.warning('Prior predictions found in old news')
|
65 |
if not cols_check([*new_news.columns], [*old_news.columns][:-2]):
|
66 |
raise Exeption("New and old cols don't match")
|
67 |
old_urls = [*old_news['url']]
|
|
|
74 |
final_df.drop_duplicates(subset='url', keep='first', inplace=True)
|
75 |
final_df.reset_index(drop=True, inplace=True)
|
76 |
final_df.loc[final_df['pred_proba']<CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
|
77 |
+
logger.warning('Exiting predicting_news_category()')
|
78 |
return final_df
|
79 |
|