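"""
Flask entrypoint for the news service.

On each request to "/", the app fetches fresh articles with get_news(), predicts
their categories with a TFLite DistilBERT classifier, finds similar articles via a
Milvus collection and sentence/cross-encoder models, and writes the results to the
production and prediction databases.
"""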
from news_extractor import get_news
from db_operations.db_write import DBWrite
from db_operations.db_read import DBRead
from news_category_similar_news_prediction import predict_news_category_similar_news
import json
from flask import Flask, Response
from flask_cors import cross_origin, CORS
import tensorflow as tf
import cloudpickle
from transformers import DistilBertTokenizerFast
import os
from logger import get_logger
import gc
from find_similar_news import TextVectorizer, get_milvus_collection, load_sentence_transformer

app = Flask(__name__)
CORS(app)

logger = get_logger()
logger.warning('Entering application')
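# Setting TOKENIZERS_PARALLELISM explicitly enables the fast tokenizers' parallelism
# and avoids the Hugging Face warning about forked processes.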
os.environ["TOKENIZERS_PARALLELISM"] = "true"

def load_model():
    """Load the TFLite DistilBERT category classifier, its label encoder, and the matching tokenizer."""
    logger.warning('Entering load transformer')
    interpreter = tf.lite.Interpreter(model_path=os.path.join("models", "news_classification_hf_distilbert.tflite"))
    with open(os.path.join("models", "news_classification_labelencoder.bin"), "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)

    # The tokenizer must match the checkpoint the classifier was fine-tuned from.
    model_checkpoint = "distilbert-base-uncased"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    logger.warning('Exiting load transformer')
    return interpreter, label_encoder, tokenizer

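# Heavy resources are created once at import time and shared across requests:
# the TFLite interpreter, label encoder, tokenizer, text vectorizer, Milvus
# collection handle, and the sentence/cross-encoder models.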
interpreter, label_encoder, tokenizer = load_model()
vectorizer = TextVectorizer()
collection = get_milvus_collection()
sent_model, ce_model = load_sentence_transformer()


@app.route("/")
@cross_origin()
def update_news():
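    """Fetch new articles, predict their categories and similar news, store the results, and return a JSON status."""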
    logger.warning('Entering update_news()')
    # Default response; replaced with a failure payload if anything below raises.
    status_json = json.dumps({'status': 'success', 'message': 'success'})
    status_code = 200
    try:
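        # Read existing articles, fetch the latest ones, and run category +
        # similar-news prediction on whatever is new.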
        db_read = DBRead()
        db_write = DBWrite(db_type="production")
        prediction_db_write = DBWrite(db_type="prediction")
        old_news = db_read.read_news_from_db()
        new_news = get_news()
        news_df, prediction_df, is_db_updation_required = predict_news_category_similar_news(
            old_news, new_news, interpreter, label_encoder, tokenizer,
            collection, vectorizer, sent_model, ce_model)
        if news_df is None:
            raise Exception('Could not generate category predictions. Aborting the database insertion. No new articles are inserted into the collection.')
            
        # old_news_count = 0 if old_news is None else len(old_news)
        # new_news_count = 0 if news_df is None else len(news_df)
        # logger.warning(f'Old News count: {old_news_count}\nNew News count: {new_news_count}')
        # if new_news_count < old_news_count:
            # raise Exception('New news count < Old news count. Aborting the database insertion. No new articles are inserted into the collection.')

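        # When updates are needed, serialize the DataFrames to row-wise JSON
        # records and write them to the production and prediction databases.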
        if is_db_updation_required:
            news_json = [*json.loads(news_df.reset_index(drop=True).to_json(orient="index")).values()]
            prediction_json = [*json.loads(prediction_df.reset_index(drop=True).to_json(orient="index")).values()]
            
            db_write.insert_news_into_db(news_json)
            prediction_db_write.insert_news_into_db(prediction_json)
            
        else:
            logger.warning('DB is not updated as it is not required.')
    except Exception as e:
        status_json = json.dumps({'status': 'failure', 'message': str(e)})
        status_code = 500
        logger.warning(f'ERROR IN update_news(): {e}')
    
    logger.warning('Exiting update_news()')
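    # Explicitly run garbage collection to release memory before returning the response.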
    gc.collect()
    return Response(status_json, status=status_code, mimetype='application/json')

logger.warning('Exiting application')

if __name__ == "__main__":
    # Flask's built-in server does not accept gunicorn-style options such as
    # timeout/workers/threads; configure those on the WSGI server (e.g. gunicorn) instead.
    app.run(host="0.0.0.0", port=7860)