File size: 3,328 Bytes
c2c5fc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import pandas as pd
import numpy as np
import tensorflow as tf 


def parse_prediction(tflite_pred, label_encoder):
    """Turn a batch of per-class model scores into labels and confidences.

    Args:
        tflite_pred: 2-D array-like of scores, one row per sample and one
            column per class index (presumably softmax probabilities from
            the TFLite model -- confirm against the exported model).
        label_encoder: Object whose ``inverse_transform`` maps an array of
            class indices back to the original labels.

    Returns:
        A ``(labels, probs)`` tuple: the decoded label of the top-scoring
        class for each row, and that row's maximum score.
    """
    scores = np.asarray(tflite_pred)
    winners = np.argmax(scores, axis=1)
    decoded = label_encoder.inverse_transform(winners)
    confidence = np.max(scores, axis=1)
    return decoded, confidence
    

def inference(text, interpreter, label_encoder, tokenizer, max_len=80):
    """Classify a batch of headlines with a TFLite text-classification model.

    Args:
        text: List of headline strings. A single string is treated as a
            one-element batch; an empty string or empty list yields an empty
            result instead of running the model.
        interpreter: TFLite interpreter. Its first input receives the
            tokenizer's ``attention_mask`` and its second input receives
            ``input_ids``.
        label_encoder: Forwarded to ``parse_prediction`` to map predicted
            class indices back to labels.
        tokenizer: Callable that takes a list of strings and returns a mapping
            with ``attention_mask`` and ``input_ids``.
        max_len: Sequence length every headline is padded/truncated to.

    Returns:
        ``(labels, probabilities)`` as returned by ``parse_prediction``, or two
        empty arrays when there is nothing to classify.
    """
    # Normalise the input: a bare string would otherwise be batched by
    # character count, and "" previously left the result unbound.
    if isinstance(text, str):
        text = [text] if text else []
    else:
        text = list(text)
    if not text:
        return np.array([], dtype=object), np.array([], dtype=float)

    batch_size = len(text)
    tokens = tokenizer(text, max_length=max_len, padding="max_length",
                       truncation=True, return_tensors="tf")

    # tflite model inference
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_index = interpreter.get_output_details()[0]["index"]
    # NOTE(review): inputs are fed by position (0 = attention_mask,
    # 1 = input_ids); confirm this matches the exported model's input order.
    mask_detail, ids_detail = input_details[0], input_details[1]
    for detail in (mask_detail, ids_detail):
        interpreter.resize_tensor_input(detail["index"], [batch_size, max_len])
    # allocate_tensors() re-runs shape inference, so the output tensor picks up
    # the new batch size without being resized (or hard-coding a class count).
    interpreter.allocate_tensors()
    # Cast to the dtype the model declares; set_tensor rejects mismatched dtypes.
    interpreter.set_tensor(
        mask_detail["index"],
        np.asarray(tokens["attention_mask"], dtype=mask_detail["dtype"]),
    )
    interpreter.set_tensor(
        ids_detail["index"],
        np.asarray(tokens["input_ids"], dtype=ids_detail["dtype"]),
    )
    interpreter.invoke()
    tflite_pred = interpreter.get_tensor(output_index)
    return parse_prediction(tflite_pred, label_encoder)

def cols_check(new_cols, old_cols):
    """Return True when both sequences list the same column names in the same order.

    A pairwise ``zip`` comparison silently ignores surplus entries, so
    ``['a']`` would "match" ``['a', 'b']``. Comparing the sequences as lists
    also requires equal length, so missing or extra columns are detected.

    Args:
        new_cols: Iterable of column names from the incoming frame.
        old_cols: Iterable of column names from the stored frame.
    """
    return list(new_cols) == list(old_cols)

    
def _add_predictions(news: pd.DataFrame, interpreter, label_encoder, tokenizer) -> pd.DataFrame:
    """Return a copy of ``news`` with ``category``/``pred_proba`` predicted from its ``title`` column."""
    news = news.copy()
    if news.empty:
        # No headlines: skip tokenisation/inference entirely.
        news['category'] = []
        news['pred_proba'] = []
        return news
    label, prob = inference(news['title'].tolist(), interpreter, label_encoder, tokenizer)
    news['category'] = label
    news['pred_proba'] = prob
    return news


def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer,
                          min_proba: float = 0.65):
    """Merge stored and freshly fetched news and attach predicted categories.

    If ``old_news`` has no ``category`` column, the two frames are merged,
    de-duplicated on ``url`` and every headline is classified. Otherwise only
    rows of ``new_news`` whose ``url`` is not already in ``old_news`` are
    classified and appended. Finally, predictions with ``pred_proba`` below
    ``min_proba`` are relabelled ``'OTHERS'``.

    Args:
        old_news: Previously stored news. An ``_id`` column, if present, is dropped.
        new_news: Newly fetched news with the same columns as ``old_news``
            (excluding ``_id`` and any prior prediction columns).
        interpreter, label_encoder, tokenizer: Forwarded to ``inference``.
        min_proba: Confidence threshold below which the category becomes ``'OTHERS'``.

    Returns:
        A new DataFrame with a fresh RangeIndex; the inputs are not mutated.

    Raises:
        ValueError: If the columns of ``new_news`` do not match those of ``old_news``.
    """
    # drop() returns a new frame, so the caller's old_news is never mutated.
    old_news = old_news.drop(columns='_id', errors='ignore')
    new_news = new_news.copy()
    new_cols = [*new_news.columns]
    old_cols = [*old_news.columns]

    if 'category' not in old_cols:
        print('no prior predictions found')
        # A column-less old frame (e.g. an empty collection) has nothing to compare against.
        if old_cols and not cols_check(new_cols, old_cols):
            raise ValueError("New and old cols don't match")
        combined = pd.concat([old_news, new_news], axis=0, ignore_index=True)
        combined = combined.drop_duplicates(subset='url', keep='first')
        final_df = _add_predictions(combined, interpreter, label_encoder, tokenizer)
    else:
        print('prior predictions found')
        # NOTE(review): assumes 'category' and 'pred_proba' are the last two stored columns.
        if not cols_check(new_cols, old_cols[:-2]):
            raise ValueError("New and old cols don't match")
        unseen = new_news.loc[~new_news['url'].isin(old_news['url'])]
        frames = [old_news]
        if not unseen.empty:
            frames.append(_add_predictions(unseen, interpreter, label_encoder, tokenizer))
        final_df = pd.concat(frames, axis=0, ignore_index=True)
        final_df = final_df.drop_duplicates(subset='url', keep='first')

    final_df = final_df.reset_index(drop=True)
    final_df.loc[final_df['pred_proba'] < min_proba, 'category'] = 'OTHERS'
    return final_df