import pandas as pd
import numpy as np
import tensorflow as tf
from config import (DISTILBERT_TOKENIZER_N_TOKENS,
NEWS_CATEGORY_CLASSIFIER_N_CLASSES,
CLASSIFIER_THRESHOLD)
from logger import get_logger
from find_similar_news import find_similar_news
logger = get_logger()  # module-wide logger shared by every function below
def parse_prediction(tflite_pred, label_encoder):
    """Decode raw classifier scores into labels and confidences.

    Args:
        tflite_pred: 2-D array of per-class scores, one row per sample.
        label_encoder: object exposing ``inverse_transform`` that maps
            class indices back to category names.

    Returns:
        Tuple of (decoded labels, max score per row).
    """
    top_class_idx = np.argmax(tflite_pred, axis=1)
    decoded_labels = label_encoder.inverse_transform(top_class_idx)
    confidences = np.max(tflite_pred, axis=1)
    return decoded_labels, confidences
def inference(text, interpreter, label_encoder, tokenizer):
    """Run the TFLite news-category classifier over a batch of texts.

    Args:
        text: non-empty list of strings to classify.
        interpreter: a tf.lite.Interpreter loaded with the classifier model.
        label_encoder: forwarded to parse_prediction to decode class indices.
        tokenizer: DistilBERT-style tokenizer returning TF tensors.

    Returns:
        Tuple of (labels, probabilities) from parse_prediction.

    Raises:
        ValueError: if ``text`` is empty.
    """
    logger.warning('Entering inference()')
    # BUG FIX: the original guard was `if text != ""`, which is always True
    # for a list input; an empty batch then fell through and hit a NameError
    # on the undefined result variable at the return. Fail fast instead.
    if not text:
        raise ValueError('inference() received an empty batch of texts')
    batch_size = len(text)
    logger.warning(f'Samples to predict: {batch_size}')
    tokens = tokenizer(text, max_length=DISTILBERT_TOKENIZER_N_TOKENS, padding="max_length", truncation=True, return_tensors="tf")
    attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
    # TFLite inference: the I/O tensors must be resized to the actual batch
    # size and re-allocated before set_tensor/invoke.
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()[0]
    interpreter.resize_tensor_input(input_details[0]['index'], [batch_size, DISTILBERT_TOKENIZER_N_TOKENS])
    interpreter.resize_tensor_input(input_details[1]['index'], [batch_size, DISTILBERT_TOKENIZER_N_TOKENS])
    interpreter.resize_tensor_input(output_details['index'], [batch_size, NEWS_CATEGORY_CLASSIFIER_N_CLASSES])
    interpreter.allocate_tensors()
    # NOTE(review): input 0 is assumed to be attention_mask and input 1
    # input_ids, matching the original code — confirm against the model.
    interpreter.set_tensor(input_details[0]["index"], attention_mask)
    interpreter.set_tensor(input_details[1]["index"], input_ids)
    interpreter.invoke()
    raw_pred = interpreter.get_tensor(output_details["index"])
    result = parse_prediction(raw_pred, label_encoder)
    logger.warning('Exiting inference()')
    return result
def cols_check(new_cols, old_cols):
    """Return True iff both column sequences have the same names in the same order.

    Args:
        new_cols: iterable of column names from the new frame.
        old_cols: iterable of column names to compare against.

    Returns:
        bool: True only when lengths and every position match.
    """
    new_cols, old_cols = list(new_cols), list(old_cols)
    # BUG FIX: zip() alone truncates to the shorter sequence, so the original
    # reported a match for e.g. (['a'], ['a', 'b']). Check lengths explicitly.
    if len(new_cols) != len(old_cols):
        return False
    return all(new_col == old_col for new_col, old_col in zip(new_cols, old_cols))
def predict_news_category_similar_news(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer,
                                       collection, vectorizer, sent_model, ce_model):
    """Predict categories and attach similar-news lookups for incoming articles.

    Merges ``new_news`` with previously stored ``old_news``, runs the TFLite
    classifier over headlines that do not yet have predictions, attaches one
    similar-news lookup per headline, and maps low-confidence predictions
    (below CLASSIFIER_THRESHOLD) to the 'OTHERS' category.

    Args:
        old_news: previously stored articles, possibly already carrying the
            three prediction columns (category, pred_proba, similar_news);
            may be None/non-DataFrame on the first run.
        new_news: freshly scraped articles; must be a DataFrame.
        interpreter, label_encoder, tokenizer: classifier artifacts forwarded
            to inference().
        collection, vectorizer, sent_model, ce_model: search artifacts
            forwarded to find_similar_news().

    Returns:
        Tuple of (final_df, db_updation_required) on success — ``final_df``
        is None when the merged frame is empty, and ``db_updation_required``
        is 0 when nothing new needed predicting. Returns None on any
        unexpected error (logged, not raised).
    """
    try:
        db_updation_required = 1
        logger.warning('Entering predict_news_category()')
        logger.warning(f'old news: {old_news}\nnew_news: {new_news}')
        if not isinstance(new_news, pd.DataFrame):
            raise Exception('No New News Found')
        else:
            new_news = new_news.copy()
            logger.warning(f'new news columns: {[*new_news.columns]}')
            logger.warning(f'{len(new_news)} new news items found')
            if isinstance(old_news, pd.DataFrame):
                old_news = old_news.copy()
                logger.warning(f'old news columns: {[*old_news.columns]}')
                logger.warning(f'{len(old_news)} old news items found')
                # Mongo's _id column is not part of the prediction schema.
                old_news.drop(columns='_id', inplace=True)
                logger.warning('Dropped _id column from old news data frame.')
            else:
                logger.warning('No old news is found')
                old_news = new_news.copy()
            if 'category' not in [*old_news.columns]:
                # First run: nothing has predictions yet — classify the union
                # of old and new headlines.
                logger.warning('No prior predictions found in old news')
                if not cols_check([*new_news.columns], [*old_news.columns]):
                    raise Exception("New and old cols don't match")
                final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
                final_df.drop_duplicates(subset='url', keep='first', inplace=True)
                headlines = [*final_df['title']].copy()
                label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
                sent_embs = vectorizer.vectorize(headlines)
                sim_news = [find_similar_news(search_vec, collection, vectorizer, sent_model, ce_model) for search_vec in sent_embs]
                final_df['category'] = label
                final_df['pred_proba'] = prob
                final_df['similar_news'] = sim_news
                final_df.reset_index(drop=True, inplace=True)
                final_df.loc[final_df['pred_proba'] < CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
            else:
                # Incremental run: only classify articles whose URLs are not
                # already present in the stored predictions.
                logger.warning('Prior predictions found in old news')
                # Old news carries three extra trailing prediction columns
                # (category, pred_proba, similar_news) — exclude them here.
                if not cols_check([*new_news.columns], [*old_news.columns][:-3]):
                    raise Exception("New and old cols don't match")
                old_urls = [*old_news['url']]
                # .copy() so the column assignments below do not write into a
                # filtered view (pandas SettingWithCopyWarning).
                new_news = new_news.loc[~new_news['url'].isin(old_urls), :].copy()
                if len(new_news) > 0:
                    headlines = [*new_news['title']].copy()
                    label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
                    sent_embs = vectorizer.vectorize(headlines)
                    sim_news = [find_similar_news(search_vec, collection, vectorizer, sent_model, ce_model) for search_vec in sent_embs]
                    new_news['category'] = label
                    new_news['pred_proba'] = prob
                    # BUG FIX: this previously assigned to final_df, which is
                    # not defined until the concat below — the resulting
                    # NameError was swallowed by the broad except and the
                    # whole incremental update silently returned None.
                    new_news['similar_news'] = sim_news
                    final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
                    final_df.drop_duplicates(subset='url', keep='first', inplace=True)
                    final_df.reset_index(drop=True, inplace=True)
                    final_df.loc[final_df['pred_proba'] < CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
                else:
                    logger.warning('INFO: Old & New Articles are the same. There is no requirement of updating them in the database. Database is not updated.')
                    db_updation_required = 0
                    final_df = old_news.copy()
        if len(final_df) == 0:
            final_df = None
        logger.warning('Exiting predict_news_category()')
    except Exception as e:
        logger.warning(f'Unexpected error in predict_news_category()\n{e}')
        return None
    return final_df, db_updation_required