|
import pandas as pd |
|
import numpy as np |
|
import tensorflow as tf |
|
from config import (DISTILBERT_TOKENIZER_N_TOKENS, |
|
NEWS_CATEGORY_CLASSIFIER_N_CLASSES, |
|
CLASSIFIER_THRESHOLD) |
|
|
|
from logger import get_logger |
|
|
|
# Shared logger obtained from the project's `logger` module; the functions
# below emit all of their diagnostics through it via logger.warning().
logger = get_logger()
|
|
|
def parse_prediction(tflite_pred, label_encoder):
    """Decode raw classifier output into labels and their confidence.

    Args:
        tflite_pred: Array of per-class scores with classes along axis 1
            (one row per sample); presumably softmax probabilities — confirm
            against the exported model.
        label_encoder: Object whose ``inverse_transform`` maps class indices
            back to label values.

    Returns:
        Tuple ``(labels, probs)``: for every row, the decoded label of the
        highest-scoring class and that highest score.
    """
    best_class_idx = np.argmax(tflite_pred, axis=1)
    best_score = np.max(tflite_pred, axis=1)
    decoded_labels = label_encoder.inverse_transform(best_class_idx)
    return decoded_labels, best_score
|
|
|
|
|
def inference(text, interpreter, label_encoder, tokenizer):
    """Classify a batch of headlines with a TFLite model.

    Args:
        text: Sequence of headline strings to classify. A single ``str`` is
            treated as a one-element batch; an empty string or empty sequence
            yields empty results without touching the model.
        interpreter: TFLite interpreter. Its first input is fed the attention
            mask, its second the input ids, and its first output is read as the
            per-class scores.
        label_encoder: Passed to :func:`parse_prediction` to map class indices
            back to labels.
        tokenizer: Callable returning ``attention_mask`` and ``input_ids``
            tensors (presumably a Hugging Face DistilBERT tokenizer — confirm).

    Returns:
        ``(labels, probs)`` as produced by :func:`parse_prediction`, or two
        empty arrays when there is nothing to classify.
    """
    logger.warning('Entering inference()')

    # len() of a bare string counts characters, while the tokenizer encodes it
    # as one sample; normalize to a list so batch_size matches the tensors.
    if isinstance(text, str):
        text = [text] if text else []

    batch_size = len(text)
    logger.warning(f'Samples to predict: {batch_size}')

    if batch_size == 0:
        # Nothing to classify: avoid tokenizing / resizing the interpreter to a
        # zero-sized batch.
        logger.warning('Exiting inference()')
        return np.array([], dtype=object), np.array([], dtype=float)

    tokens = tokenizer(text, max_length=DISTILBERT_TOKENIZER_N_TOKENS,
                       padding="max_length", truncation=True, return_tensors="tf")
    attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']

    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()[0]

    # Resize inputs and output to the actual batch size. The input order
    # (mask first, ids second) presumably matches the converted model's input
    # ordering — verify if the model is re-exported.
    input_shape = [batch_size, DISTILBERT_TOKENIZER_N_TOKENS]
    interpreter.resize_tensor_input(input_details[0]['index'], input_shape)
    interpreter.resize_tensor_input(input_details[1]['index'], input_shape)
    interpreter.resize_tensor_input(output_details['index'],
                                    [batch_size, NEWS_CATEGORY_CLASSIFIER_N_CLASSES])
    # TFLite requires allocate_tensors() again after resize_tensor_input().
    interpreter.allocate_tensors()

    interpreter.set_tensor(input_details[0]["index"], attention_mask)
    interpreter.set_tensor(input_details[1]["index"], input_ids)
    interpreter.invoke()

    tflite_pred = interpreter.get_tensor(output_details["index"])
    result = parse_prediction(tflite_pred, label_encoder)

    logger.warning('Exiting inference()')
    return result
|
|
|
def cols_check(new_cols, old_cols):
    """Return True when both column sequences are identical, in the same order.

    The lengths must match as well: comparing pairwise with zip() alone would
    only look at the common prefix and accept e.g. ``['a', 'b', 'c']`` versus
    ``['a', 'b']``.

    Args:
        new_cols: Iterable of column names.
        old_cols: Iterable of column names to compare against.
    """
    return list(new_cols) == list(old_cols)
|
|
|
|
|
def _attach_predictions(df, interpreter, label_encoder, tokenizer):
    """Return a copy of ``df`` with ``category``/``pred_proba`` predicted from its ``title`` column."""
    # Work on an explicit copy so column assignment never targets a filtered view.
    df = df.copy()
    if len(df) == 0:
        # Skip the model for an empty frame, but still add the columns so the
        # threshold step in predict_news_category can rely on them.
        df['category'] = pd.Series(dtype=object)
        df['pred_proba'] = pd.Series(dtype=float)
        return df
    label, prob = inference([*df['title']], interpreter, label_encoder, tokenizer)
    # Array assignment is positional: predictions follow df's row order.
    df['category'] = label
    df['pred_proba'] = prob
    return df


def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer):
    """Merge new news into the stored set and classify headlines lacking a prediction.

    Rows are de-duplicated on ``url`` (first occurrence wins, so existing
    ``old_news`` rows take precedence). Headlines from the ``title`` column are
    classified with :func:`inference`, filling ``category`` and ``pred_proba``;
    every row whose ``pred_proba`` is below ``CLASSIFIER_THRESHOLD`` is then
    relabelled ``'OTHERS'``.

    Args:
        old_news: Previously stored items, or any non-DataFrame (e.g. ``None``)
            when there are none. An ``_id`` column, if present, is dropped. If
            it already has a ``category`` column, its last two columns are
            assumed to be ``category`` and ``pred_proba`` and only rows with
            unseen URLs are classified; otherwise every row is classified.
        new_news: Freshly collected items; must be a DataFrame whose columns
            match ``old_news`` (excluding the two prediction columns).
        interpreter, label_encoder, tokenizer: Forwarded to :func:`inference`.

    Returns:
        The merged DataFrame with a fresh RangeIndex, or ``None`` when it is
        empty or when any error occurs (the error is logged, not raised).
    """
    try:
        logger.warning('Entering predict_news_category()')
        logger.warning(f'old news: {old_news}\nnew_news: {new_news}')

        if not isinstance(new_news, pd.DataFrame):
            raise Exception('No New News Found')

        new_news = new_news.copy()
        logger.warning(f'new news columns: {[*new_news.columns]}')
        logger.warning(f'{len(new_news)} new news items found')

        if isinstance(old_news, pd.DataFrame):
            old_news = old_news.copy()
            logger.warning(f'old news columns: {[*old_news.columns]}')
            logger.warning(f'{len(old_news)} old news items found')
            # A missing '_id' should not abort the whole merge with a KeyError.
            if '_id' in old_news.columns:
                old_news.drop(columns='_id', inplace=True)
                logger.warning('Dropped _id column from old news data frame.')
        else:
            logger.warning('No old news is found')
            old_news = new_news.copy()

        if 'category' not in old_news.columns:
            logger.warning('No prior predictions found in old news')
            if not cols_check([*new_news.columns], [*old_news.columns]):
                raise Exception("New and old cols don't match")
            final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
            final_df.drop_duplicates(subset='url', keep='first', inplace=True)
            final_df = _attach_predictions(final_df, interpreter, label_encoder, tokenizer)
        else:
            logger.warning('Prior predictions found in old news')
            # The two trailing columns of old_news are the prediction columns.
            if not cols_check([*new_news.columns], [*old_news.columns][:-2]):
                raise Exception("New and old cols don't match")
            unseen = new_news.loc[~new_news['url'].isin(old_news['url'])]
            if len(unseen) == 0:
                # Everything is already classified: keep old_news rather than
                # running the model on an empty batch (which errored and made
                # the caller lose the stored news).
                logger.warning('No unseen news items to classify')
                final_df = old_news
            else:
                unseen = _attach_predictions(unseen, interpreter, label_encoder, tokenizer)
                final_df = pd.concat([old_news, unseen], axis=0, ignore_index=True)
            final_df.drop_duplicates(subset='url', keep='first', inplace=True)

        final_df.reset_index(drop=True, inplace=True)
        final_df.loc[final_df['pred_proba'] < CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'

        if len(final_df) == 0:
            final_df = None

        logger.warning('Exiting predict_news_category()')
    except Exception as e:
        logger.warning(f'Unexcpected error in predict_news_category()\n{e}')
        return None

    return final_df
|
|