from news_extractor import get_news
from db_operations.db_write import DBWrite
from db_operations.db_read import DBRead
from news_category_similar_news_prediction import predict_news_category_similar_news
import json
from flask import Flask, Response
from flask_cors import cross_origin, CORS
import logging
import tensorflow as tf
import cloudpickle
from transformers import DistilBertTokenizerFast
import os
from logger import get_logger
from find_similar_news import TextVectorizer, get_milvus_collection, load_sentence_transformer
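# Flask application; CORS is enabled so the single endpoint can be called cross-origin.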
app = Flask(__name__)
CORS(app)
logger = get_logger()
logger.warning('Entering application')
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def load_model():
    logger.warning('Entering load transformer')
    interpreter = tf.lite.Interpreter(model_path="models/news_classification_hf_distilbert.tflite")
    with open("models/news_classification_labelencoder.bin", "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)
    model_checkpoint = "distilbert-base-uncased"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    logger.warning('Exiting load transformer')
    return interpreter, label_encoder, tokenizer
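# The classifier, text vectorizer, Milvus collection handle, and the models returned by
# load_sentence_transformer() are created once at import time so every request reuses them.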
interpreter, label_encoder, tokenizer = load_model()
vectorizer = TextVectorizer()
collection = get_milvus_collection()
sent_model, ce_model = load_sentence_transformer()
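# GET / triggers a full refresh: read the stored articles, scrape new ones, predict each
# article's category and similar news, and insert into the database only when
# predict_news_category_similar_news() signals that an update is required.
# Usage (assuming the default port configured below): curl http://localhost:7860/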
@app.route("/")
@cross_origin()
def update_news():
    logger.warning('Entering update_news()')
    status_json = json.dumps({'status': 'success', 'message': 'success'})
    status_code = 200
    try:
        db_read = DBRead()
        db_write = DBWrite()
        # Previously stored articles vs. freshly scraped ones.
        old_news = db_read.read_news_from_db()
        new_news = get_news()
        news_df, is_db_updation_required = predict_news_category_similar_news(old_news, new_news, interpreter, label_encoder,
                                                                              tokenizer, collection, vectorizer, sent_model, ce_model)
        if news_df is None:
            raise Exception('Could not generate category predictions. Aborting the database insertion. No new articles are inserted into the collection.')
        # old_news_count = 0 if old_news is None else len(old_news)
        # new_news_count = 0 if news_df is None else len(news_df)
        # logger.warning(f'Old News count: {old_news_count}\nNew News count: {new_news_count}')
        # if new_news_count < old_news_count:
        #     raise Exception('New news count < Old news count. Aborting the database insertion. No new articles are inserted into the collection.')
        if is_db_updation_required:
            # Convert the dataframe into a list of row dicts before inserting into the database.
            news_json = [*json.loads(news_df.reset_index(drop=True).to_json(orient="index")).values()]
            db_write.insert_news_into_db(news_json)
        else:
            logger.warning('DB is not updated as it is not required.')
    except Exception as e:
        status_json = json.dumps({'status': 'failure', 'message': str(e)})
        status_code = 500
        logger.warning(f'ERROR IN update_news(): {e}')
    logger.warning('Exiting update_news()')
    return Response(status_json, status=status_code, mimetype='application/json')
logger.warning('Exiting application')
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860, timeout=10000, workers=1, threads=1)