lalithadevi
committed on
Commit
•
bb7fe56
1
Parent(s):
fa2ba95
Update news_category_prediction.py
Browse files
news_category_prediction.py
CHANGED
@@ -12,18 +12,16 @@ def parse_prediction(tflite_pred, label_encoder):
|
|
12 |
|
13 |
def inference(text, interpreter, label_encoder, tokenizer):
|
14 |
batch_size = len(text)
|
15 |
-
MAX_LEN = 80
|
16 |
-
N_CLASSES = 8
|
17 |
if text != "":
|
18 |
-
tokens = tokenizer(text, max_length=MAX_LEN, padding="max_length", truncation=True, return_tensors="tf")
|
19 |
# tflite model inference
|
20 |
interpreter.allocate_tensors()
|
21 |
input_details = interpreter.get_input_details()
|
22 |
output_details = interpreter.get_output_details()[0]
|
23 |
attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
|
24 |
-
interpreter.resize_tensor_input(input_details[0]['index'],[batch_size, MAX_LEN])
|
25 |
-
interpreter.resize_tensor_input(input_details[1]['index'],[batch_size, MAX_LEN])
|
26 |
-
interpreter.resize_tensor_input(output_details['index'],[batch_size, N_CLASSES])
|
27 |
interpreter.allocate_tensors()
|
28 |
interpreter.set_tensor(input_details[0]["index"], attention_mask)
|
29 |
interpreter.set_tensor(input_details[1]["index"], input_ids)
|
@@ -66,6 +64,6 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
|
|
66 |
final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
|
67 |
final_df.drop_duplicates(subset='url', keep='first', inplace=True)
|
68 |
final_df.reset_index(drop=True, inplace=True)
|
69 |
-
final_df.loc[final_df['pred_proba']<
|
70 |
return final_df
|
71 |
|
|
|
12 |
|
13 |
def inference(text, interpreter, label_encoder, tokenizer):
|
14 |
batch_size = len(text)
|
|
|
|
|
15 |
if text != "":
|
16 |
+
tokens = tokenizer(text, max_length=DISTILBERT_TOKENIZER_N_TOKENS, padding="max_length", truncation=True, return_tensors="tf")
|
17 |
# tflite model inference
|
18 |
interpreter.allocate_tensors()
|
19 |
input_details = interpreter.get_input_details()
|
20 |
output_details = interpreter.get_output_details()[0]
|
21 |
attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
|
22 |
+
interpreter.resize_tensor_input(input_details[0]['index'],[batch_size, DISTILBERT_TOKENIZER_N_TOKENS])
|
23 |
+
interpreter.resize_tensor_input(input_details[1]['index'],[batch_size, DISTILBERT_TOKENIZER_N_TOKENS])
|
24 |
+
interpreter.resize_tensor_input(output_details['index'],[batch_size, NEWS_CATEGORY_CLASSIFIER_N_CLASSES])
|
25 |
interpreter.allocate_tensors()
|
26 |
interpreter.set_tensor(input_details[0]["index"], attention_mask)
|
27 |
interpreter.set_tensor(input_details[1]["index"], input_ids)
|
|
|
64 |
final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
|
65 |
final_df.drop_duplicates(subset='url', keep='first', inplace=True)
|
66 |
final_df.reset_index(drop=True, inplace=True)
|
67 |
+
final_df.loc[final_df['pred_proba']<CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
|
68 |
return final_df
|
69 |
|