lalithadevi commited on
Commit
74a2149
1 Parent(s): 0a863df

Update news_category_prediction.py

Browse files
Files changed (1) hide show
  1. news_category_prediction.py +13 -3
news_category_prediction.py CHANGED
@@ -1,7 +1,13 @@
1
  import pandas as pd
2
  import numpy as np
3
  import tensorflow as tf
4
- from config import DISTILBERT_TOKENIZER_N_TOKENS, NEWS_CATEGORY_CLASSIFIER_N_CLASSES, CLASSIFIER_THRESHOLD
 
 
 
 
 
 
5
 
6
  def parse_prediction(tflite_pred, label_encoder):
7
  tflite_pred_argmax = np.argmax(tflite_pred, axis=1)
@@ -11,6 +17,7 @@ def parse_prediction(tflite_pred, label_encoder):
11
 
12
 
13
  def inference(text, interpreter, label_encoder, tokenizer):
 
14
  batch_size = len(text)
15
  if text != "":
16
  tokens = tokenizer(text, max_length=DISTILBERT_TOKENIZER_N_TOKENS, padding="max_length", truncation=True, return_tensors="tf")
@@ -28,6 +35,7 @@ def inference(text, interpreter, label_encoder, tokenizer):
28
  interpreter.invoke()
29
  tflite_pred = interpreter.get_tensor(output_details["index"])
30
  tflite_pred = parse_prediction(tflite_pred)
 
31
  return tflite_pred
32
 
33
  def cols_check(new_cols, old_cols):
@@ -35,6 +43,7 @@ def cols_check(new_cols, old_cols):
35
 
36
 
37
  def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer):
 
38
  old_news = old_news.copy()
39
  new_news = new_news.copy()
40
  # dbops = DBOperations()
@@ -42,7 +51,7 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
42
  old_news.drop(columns='_id', inplace=True)
43
  # new_news = get_news()
44
  if 'category' not in [*old_news.columns]:
45
- print('no prior predictions found')
46
  if not cols_check([*new_news.columns], [*old_news.columns]):
47
  raise Exeption("New and old cols don't match")
48
  final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
@@ -52,7 +61,7 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
52
  final_df['category'] = label
53
  final_df['pred_proba'] = prob
54
  else:
55
- print('prior predictions found')
56
  if not cols_check([*new_news.columns], [*old_news.columns][:-2]):
57
  raise Exeption("New and old cols don't match")
58
  old_urls = [*old_news['url']]
@@ -65,5 +74,6 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
65
  final_df.drop_duplicates(subset='url', keep='first', inplace=True)
66
  final_df.reset_index(drop=True, inplace=True)
67
  final_df.loc[final_df['pred_proba']<CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
 
68
  return final_df
69
 
 
1
  import pandas as pd
2
  import numpy as np
3
  import tensorflow as tf
4
+ from config import (DISTILBERT_TOKENIZER_N_TOKENS,
5
+ NEWS_CATEGORY_CLASSIFIER_N_CLASSES,
6
+ CLASSIFIER_THRESHOLD)
7
+
8
+ from logger import get_logger
9
+
10
+ logger = get_logger()
11
 
12
  def parse_prediction(tflite_pred, label_encoder):
13
  tflite_pred_argmax = np.argmax(tflite_pred, axis=1)
 
17
 
18
 
19
  def inference(text, interpreter, label_encoder, tokenizer):
20
+ logger.warning('Entering inference()')
21
  batch_size = len(text)
22
  if text != "":
23
  tokens = tokenizer(text, max_length=DISTILBERT_TOKENIZER_N_TOKENS, padding="max_length", truncation=True, return_tensors="tf")
 
35
  interpreter.invoke()
36
  tflite_pred = interpreter.get_tensor(output_details["index"])
37
  tflite_pred = parse_prediction(tflite_pred)
38
+ logger.warning('Exiting inference()')
39
  return tflite_pred
40
 
41
  def cols_check(new_cols, old_cols):
 
43
 
44
 
45
  def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer):
46
+ logger.warning('Entering predicting_news_category()')
47
  old_news = old_news.copy()
48
  new_news = new_news.copy()
49
  # dbops = DBOperations()
 
51
  old_news.drop(columns='_id', inplace=True)
52
  # new_news = get_news()
53
  if 'category' not in [*old_news.columns]:
54
+ logger.warning('No prior predictions found in old news')
55
  if not cols_check([*new_news.columns], [*old_news.columns]):
56
  raise Exeption("New and old cols don't match")
57
  final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
 
61
  final_df['category'] = label
62
  final_df['pred_proba'] = prob
63
  else:
64
+ logger.warning('Prior predictions found in old news')
65
  if not cols_check([*new_news.columns], [*old_news.columns][:-2]):
66
  raise Exeption("New and old cols don't match")
67
  old_urls = [*old_news['url']]
 
74
  final_df.drop_duplicates(subset='url', keep='first', inplace=True)
75
  final_df.reset_index(drop=True, inplace=True)
76
  final_df.loc[final_df['pred_proba']<CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
77
+ logger.warning('Exiting predicting_news_category()')
78
  return final_df
79