File size: 6,018 Bytes
c2c5fc6
 
 
74a2149
 
 
 
 
1462b37
74a2149
 
c2c5fc6
 
 
 
 
 
 
 
 
74a2149
c2c5fc6
9e645b6
c2c5fc6
bb7fe56
c2c5fc6
 
 
 
 
bb7fe56
 
 
c2c5fc6
 
 
 
 
77aa593
74a2149
c2c5fc6
 
 
 
 
 
1462b37
 
f442802
0b44a23
f442802
 
 
 
 
 
 
 
 
 
 
 
 
f97e412
 
f442802
 
 
 
 
 
 
0cb5606
f442802
 
 
 
1462b37
4b6249a
f442802
 
4b6249a
d5f669f
 
f442802
 
4b6249a
0cb5606
f442802
 
d5f669f
 
 
4b6249a
 
d5f669f
 
4b6249a
d5f669f
 
 
 
 
0b44a23
 
 
d5f669f
2536052
 
 
 
f442802
 
 
 
0b44a23
c2c5fc6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import pandas as pd
import numpy as np
import tensorflow as tf 
from config import (DISTILBERT_TOKENIZER_N_TOKENS, 
                    NEWS_CATEGORY_CLASSIFIER_N_CLASSES, 
                    CLASSIFIER_THRESHOLD)

from logger import get_logger
from find_similar_news import find_similar_news

logger = get_logger()

def parse_prediction(tflite_pred, label_encoder):
    """Turn raw class-probability rows into (labels, confidences).

    Args:
        tflite_pred: 2-D array-like of shape (batch, n_classes) holding
            per-class scores/probabilities for each sample.
        label_encoder: fitted encoder exposing ``inverse_transform`` that maps
            class indices back to their string labels.

    Returns:
        Tuple of (predicted labels, per-sample maximum probability).
    """
    # Index of the highest-scoring class per row, then decode it to a label.
    best_class_idx = np.argmax(tflite_pred, axis=1)
    predicted_labels = label_encoder.inverse_transform(best_class_idx)
    # Confidence is simply the winning class's score.
    confidences = np.max(tflite_pred, axis=1)
    return predicted_labels, confidences

def inference(text, interpreter, label_encoder, tokenizer):
    """Run the TFLite news-category classifier over a batch of headlines.

    Args:
        text: list of headline strings to classify (one prediction per item).
        interpreter: a ``tf.lite.Interpreter`` loaded with the classifier model.
        label_encoder: fitted encoder used to map class indices back to labels.
        tokenizer: DistilBERT tokenizer producing ``input_ids``/``attention_mask``.

    Returns:
        Tuple of (labels, probabilities) as produced by ``parse_prediction``.

    Raises:
        ValueError: if ``text`` is empty — the original code fell through to an
            ``UnboundLocalError`` in that case; fail explicitly instead.
    """
    logger.warning('Entering inference()')
    batch_size = len(text)
    logger.warning(f'Samples to predict: {batch_size}')
    # Guard: an empty batch would leave the output tensor unset and previously
    # crashed with UnboundLocalError at the return statement.
    if not text:
        raise ValueError('inference() called with no text samples')
    tokens = tokenizer(text, max_length=DISTILBERT_TOKENIZER_N_TOKENS,
                       padding="max_length", truncation=True, return_tensors="tf")
    # tflite model inference
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()[0]
    attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
    # Resize I/O tensors to the actual batch size before re-allocating.
    # NOTE(review): assumes input 0 is attention_mask and input 1 is input_ids,
    # matching the set_tensor calls below — confirm against the exported model.
    interpreter.resize_tensor_input(input_details[0]['index'], [batch_size, DISTILBERT_TOKENIZER_N_TOKENS])
    interpreter.resize_tensor_input(input_details[1]['index'], [batch_size, DISTILBERT_TOKENIZER_N_TOKENS])
    interpreter.resize_tensor_input(output_details['index'], [batch_size, NEWS_CATEGORY_CLASSIFIER_N_CLASSES])
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]["index"], attention_mask)
    interpreter.set_tensor(input_details[1]["index"], input_ids)
    interpreter.invoke()
    tflite_pred = interpreter.get_tensor(output_details["index"])
    tflite_pred = parse_prediction(tflite_pred, label_encoder)
    logger.warning('Exiting inference()')
    return tflite_pred

def cols_check(new_cols, old_cols):
    """Return True iff both column sequences are identical (same order AND length).

    The previous implementation used ``zip`` alone, which silently truncates to
    the shorter sequence — a column list with extra trailing columns would pass
    the check as long as its prefix matched. Compare lengths explicitly first.

    Args:
        new_cols: sequence of column names from the new data frame.
        old_cols: sequence of column names from the old data frame.

    Returns:
        bool: True when the sequences match element-for-element.
    """
    if len(new_cols) != len(old_cols):
        return False
    return all(new_col == old_col for new_col, old_col in zip(new_cols, old_cols))
    
def predict_news_category_similar_news(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer,
                         collection, vectorizer, sent_model, ce_model):
    """Predict categories and attach similar-news links for incoming articles.

    Merges previously-seen articles (``old_news``) with freshly-scraped ones
    (``new_news``), classifies any headlines that have not been predicted yet,
    and finds similar news for them via vector search.

    Args:
        old_news: previously stored articles, possibly already carrying the
            three prediction columns (category, pred_proba, similar_news);
            may be None/non-DataFrame when no history exists.
        new_news: freshly scraped articles; must be a DataFrame.
        interpreter: TFLite interpreter for the category classifier.
        label_encoder: encoder mapping class indices to category names.
        tokenizer: DistilBERT tokenizer for the headlines.
        collection: vector-store collection searched for similar news.
        vectorizer: object exposing ``vectorize`` for sentence embeddings.
        sent_model, ce_model: models forwarded to ``find_similar_news``.

    Returns:
        Tuple (final_df, db_updation_required) on success, where final_df may
        be None when the merged frame is empty; returns None on any error.
    """
    try:
        db_updation_required = 1
        logger.warning('Entering predict_news_category()')
        logger.warning(f'old news: {old_news}\nnew_news: {new_news}')
        if not isinstance(new_news, pd.DataFrame):
            raise Exception('No New News Found')
        else:
            new_news = new_news.copy()
            logger.warning(f'new news columns: {[*new_news.columns]}')
            logger.warning(f'{len(new_news)} new news items found')

        if isinstance(old_news, pd.DataFrame):
            old_news = old_news.copy()
            logger.warning(f'old news columns: {[*old_news.columns]}')
            logger.warning(f'{len(old_news)} old news items found')
            # _id comes from the database layer; drop it before merging.
            old_news.drop(columns='_id', inplace=True)
            logger.warning('Dropped _id column from old news data frame.')
        else:
            logger.warning('No old news is found')
            old_news = new_news.copy()

        if 'category' not in [*old_news.columns]:
            # No prior predictions: classify the full de-duplicated frame.
            logger.warning('No prior predictions found in old news')
            if not cols_check([*new_news.columns], [*old_news.columns]):
                raise Exception("New and old cols don't match")
            final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
            final_df.drop_duplicates(subset='url', keep='first', inplace=True)
            headlines = [*final_df['title']].copy()
            label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
            sent_embs = vectorizer.vectorize(headlines)
            sim_news = [find_similar_news(search_vec, collection, vectorizer, sent_model, ce_model) for search_vec in sent_embs]
            final_df['category'] = label
            final_df['pred_proba'] = prob
            final_df['similar_news'] = sim_news
            final_df.reset_index(drop=True, inplace=True)
            # Low-confidence predictions are bucketed into OTHERS.
            final_df.loc[final_df['pred_proba'] < CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
        else:
            # Prior predictions exist: only classify genuinely new URLs.
            logger.warning('Prior predictions found in old news')
            # Old frame carries 3 extra prediction columns at the end.
            if not cols_check([*new_news.columns], [*old_news.columns][:-3]):
                raise Exception("New and old cols don't match")
            old_urls = [*old_news['url']]
            # .copy() so the column assignments below write to a real frame,
            # not a view (avoids pandas chained-assignment warnings).
            new_news = new_news.loc[~new_news['url'].isin(old_urls), :].copy()
            if len(new_news) > 0:
                headlines = [*new_news['title']].copy()
                label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
                sent_embs = vectorizer.vectorize(headlines)
                sim_news = [find_similar_news(search_vec, collection, vectorizer, sent_model, ce_model) for search_vec in sent_embs]
                new_news['category'] = label
                new_news['pred_proba'] = prob
                # BUG FIX: was `final_df['similar_news'] = sim_news`, but
                # final_df does not exist yet in this branch (NameError);
                # the similar-news column belongs on the new rows.
                new_news['similar_news'] = sim_news
                final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
                final_df.drop_duplicates(subset='url', keep='first', inplace=True)
                final_df.reset_index(drop=True, inplace=True)
                final_df.loc[final_df['pred_proba'] < CLASSIFIER_THRESHOLD, 'category'] = 'OTHERS'
            else:
                logger.warning('INFO: Old & New Articles are the same. There is no requirement of updating them in the database. Database is not updated.')
                db_updation_required = 0
                final_df = old_news.copy()

        if len(final_df) == 0:
            final_df = None

        logger.warning('Exiting predict_news_category()')
    except Exception as e:
        logger.warning(f'Unexpected error in predict_news_category()\n{e}')
        return None
    return final_df, db_updation_required