lalithadevi commited on
Commit
bb7fe56
1 Parent(s): fa2ba95

Update news_category_prediction.py

Browse files
Files changed (1) hide show
  1. news_category_prediction.py +5 -7
news_category_prediction.py CHANGED
@@ -12,18 +12,16 @@ def parse_prediction(tflite_pred, label_encoder):
12
 
13
  def inference(text, interpreter, label_encoder, tokenizer):
14
  batch_size = len(text)
15
- MAX_LEN = 80
16
- N_CLASSES = 8
17
  if text != "":
18
- tokens = tokenizer(text, max_length=MAX_LEN, padding="max_length", truncation=True, return_tensors="tf")
19
  # tflite model inference
20
  interpreter.allocate_tensors()
21
  input_details = interpreter.get_input_details()
22
  output_details = interpreter.get_output_details()[0]
23
  attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
24
- interpreter.resize_tensor_input(input_details[0]['index'],[batch_size, MAX_LEN])
25
- interpreter.resize_tensor_input(input_details[1]['index'],[batch_size, MAX_LEN])
26
- interpreter.resize_tensor_input(output_details['index'],[batch_size, N_CLASSES])
27
  interpreter.allocate_tensors()
28
  interpreter.set_tensor(input_details[0]["index"], attention_mask)
29
  interpreter.set_tensor(input_details[1]["index"], input_ids)
@@ -66,6 +64,6 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
66
  final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
67
  final_df.drop_duplicates(subset='url', keep='first', inplace=True)
68
  final_df.reset_index(drop=True, inplace=True)
69
- final_df.loc[final_df['pred_proba']<0.65, 'category'] = 'OTHERS'
70
  return final_df
71
 
 
12
 
13
  def inference(text, interpreter, label_encoder, tokenizer):
14
  batch_size = len(text)
 
 
15
  if text != "":
16
+ tokens = tokenizer(text, max_length=DISTILBERT_TOKENIZER_N_TOKENS, padding="max_length", truncation=True, return_tensors="tf")
17
  # tflite model inference
18
  interpreter.allocate_tensors()
19
  input_details = interpreter.get_input_details()
20
  output_details = interpreter.get_output_details()[0]
21
  attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
22
+ interpreter.resize_tensor_input(input_details[0]['index'],[batch_size, DISTILBERT_TOKENIZER_N_TOKENS])
23
+ interpreter.resize_tensor_input(input_details[1]['index'],[batch_size, DISTILBERT_TOKENIZER_N_TOKENS])
24
+ interpreter.resize_tensor_input(output_details['index'],[batch_size, NEWS_CATEGORY_CLASSIFIER_N_CLASSES])
25
  interpreter.allocate_tensors()
26
  interpreter.set_tensor(input_details[0]["index"], attention_mask)
27
  interpreter.set_tensor(input_details[1]["index"], input_ids)
 
64
  final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
65
  final_df.drop_duplicates(subset='url', keep='first', inplace=True)
66
  final_df.reset_index(drop=True, inplace=True)
67
+ final_df.loc[final_df['pred_proba']<CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
68
  return final_df
69