lalithadevi's picture
Update app.py
f7374b3 verified
raw history blame
No virus
3.37 kB
from news_extractor import get_news
from db_operations.db_write import DBWrite
from db_operations.db_read import DBRead
from news_category_similar_news_prediction import predict_news_category_similar_news
import json
from flask import Flask, Response
from flask_cors import cross_origin, CORS
import logging
import tensorflow as tf
import cloudpickle
from transformers import DistilBertTokenizerFast
import os
from logger import get_logger
import gc
from find_similar_news import TextVectorizer, get_milvus_collection, load_sentence_transformer
# Flask application object; the route defined in this module is registered on it.
app = Flask(__name__)
# Apply flask_cors to the whole app with its default settings.
CORS(app)
logger = get_logger()
logger.warning('Entering application')
# Set before load_model() creates the tokenizer further down in this module.
# NOTE(review): confirm that forcing tokenizer parallelism on is safe under the
# deployment server's process/thread model.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def load_model(model_dir="models", model_checkpoint="distilbert-base-uncased"):
    """Load the news-category classifier and its preprocessing artifacts.

    Args:
        model_dir: Directory containing ``news_classification_hf_distilbert.tflite``
            and ``news_classification_labelencoder.bin``. Relative paths are
            resolved against the current working directory.
        model_checkpoint: Hugging Face checkpoint name passed to
            ``DistilBertTokenizerFast.from_pretrained``.

    Returns:
        Tuple ``(interpreter, label_encoder, tokenizer)``: a TFLite interpreter
        for the classifier model, the object unpickled from the label-encoder
        file, and the DistilBERT fast tokenizer.
    """
    logger.warning('Entering load transformer')
    # Join the path components (a single pre-joined string passed to
    # os.path.join is a no-op) so the path is built portably from model_dir.
    # NOTE(review): allocate_tensors() is not called here; presumably the
    # prediction code does it before inference — confirm.
    interpreter = tf.lite.Interpreter(
        model_path=os.path.join(model_dir, "news_classification_hf_distilbert.tflite"))
    # cloudpickle.load can execute arbitrary code: only point model_dir at
    # trusted, locally shipped artifacts.
    label_encoder_path = os.path.join(model_dir, "news_classification_labelencoder.bin")
    with open(label_encoder_path, "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    logger.warning('Exiting load transformer')
    return interpreter, label_encoder, tokenizer
# Load the classifier, label encoder and tokenizer once at import time; the
# update_news() route reads these module globals on every request.
interpreter, label_encoder, tokenizer = load_model()
# Project helpers from find_similar_news, also created once and shared by all
# requests. Presumably a text embedder, a Milvus collection handle, and a
# sentence-transformer / cross-encoder model pair — confirm against that module.
vectorizer = TextVectorizer()
collection = get_milvus_collection()
sent_model, ce_model = load_sentence_transformer()
@app.route("/")
@cross_origin()
def update_news():
    """Refresh the stored news collection and report the outcome as JSON.

    Reads the articles currently in the database, fetches fresh ones with
    ``get_news()``, runs category prediction and similar-news matching, and
    inserts the resulting rows when ``predict_news_category_similar_news``
    reports that an update is required.

    Returns:
        A ``Response`` with mimetype ``application/json`` whose body is an
        object with ``status`` and ``message`` keys: HTTP 200 with
        ``"success"``/``"success"`` when no error occurs, otherwise HTTP 500
        with ``"failure"`` and the exception text as the message.
    """
    logger.warning('Entering update_news()')
    status, message, status_code = 'success', 'success', 200
    try:
        db_read = DBRead()
        db_write = DBWrite()
        old_news = db_read.read_news_from_db()
        new_news = get_news()
        news_df, is_db_updation_required = predict_news_category_similar_news(
            old_news, new_news, interpreter, label_encoder,
            tokenizer, collection, vectorizer, sent_model, ce_model)
        if news_df is None:
            raise Exception('Could not generate category predictions. Aborting the database insertion. No new articles are inserted into the collection.')
        if is_db_updation_required:
            # reset_index gives a unique 0..n-1 index, so orient="index" yields
            # {"0": row0, "1": row1, ...}; .values() is then the list of row
            # dicts to insert.
            news_json = [*json.loads(news_df.reset_index(drop=True).to_json(orient="index")).values()]
            db_write.insert_news_into_db(news_json)
        else:
            logger.warning('DB is not updated as it is not required.')
    except Exception as e:
        status, message, status_code = 'failure', str(e), 500
        # logger.exception also records the traceback for debugging.
        logger.exception(f'ERROR IN update_news(): {e}')
    logger.warning('Exiting update_news()')
    gc.collect()
    # json.dumps produces valid JSON (double-quoted keys) and escapes any
    # quotes or control characters contained in the exception message.
    body = json.dumps({'status': status, 'message': message})
    return Response(body, status=status_code, mimetype='application/json')
# NOTE: this runs at import time, immediately after the route is registered —
# it does not mark server shutdown.
logger.warning('Exiting application')
if __name__ == "__main__":
    # Flask's development server forwards these options to
    # werkzeug.serving.run_simple, which does not accept gunicorn-style
    # ``timeout``/``workers``/``threads`` keywords (passing them raises
    # TypeError). ``threaded=False, processes=1`` gives the intended
    # single-worker, single-thread setup; the dev server has no request timeout.
    app.run(host="0.0.0.0", port=7860, threaded=False, processes=1)