File size: 2,719 Bytes
e857da4 c2c5fc6 e857da4 99ce67b e857da4 5116708 c2c5fc6 4ed8f54 e857da4 4ed8f54 a9a9d3a e857da4 c2c5fc6 2f5f99d c2c5fc6 2f5f99d c2c5fc6 e857da4 2f5f99d c3395ab eec495e e857da4 c2c5fc6 c3395ab 02dba55 e857da4 c2c5fc6 c3395ab eec495e 50cbe24 02dba55 2f5f99d eec495e e857da4 a9a9d3a e857da4 cf90c5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from news_extractor import get_news
from db_operations.db_write import DBWrite
from db_operations.db_read import DBRead
from news_category_prediction import predict_news_category
import json
from flask import Flask, Response
from flask_cors import cross_origin, CORS
import logging
import tensorflow as tf
import cloudpickle
from transformers import DistilBertTokenizerFast
import os
from logger import get_logger
# Application-wide singletons: the Flask app (CORS-enabled for all origins)
# and the shared logger. WARNING level is used for trace messages throughout
# this file — presumably to survive a restrictive log-level config; confirm.
app = Flask(__name__)
CORS(app)
logger = get_logger()
logger.warning('Entering application')
def load_model(
    model_path="models/news_classification_hf_distilbert.tflite",
    label_encoder_path="models/news_classification_labelencoder.bin",
    model_checkpoint="distilbert-base-uncased",
):
    """Load the classification artifacts used by ``update_news``.

    Paths are parameters with the original values as defaults, so callers
    can point at alternative artifact locations without editing this file.

    Args:
        model_path: Filesystem path to the TFLite DistilBERT classifier.
        label_encoder_path: Path to the cloudpickled label encoder.
        model_checkpoint: HuggingFace checkpoint name for the tokenizer.

    Returns:
        Tuple of (tf.lite.Interpreter, label encoder, DistilBertTokenizerFast).
    """
    logger.warning('Entering load transformer')
    # os.path.join with a single argument was a no-op; pass the path directly.
    interpreter = tf.lite.Interpreter(model_path=model_path)
    with open(label_encoder_path, "rb") as model_file_obj:
        # NOTE(review): cloudpickle.load on a local artifact — safe only as
        # long as the file is trusted (never load untrusted pickles).
        label_encoder = cloudpickle.load(model_file_obj)
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    logger.warning('Exiting load transformer')
    return interpreter, label_encoder, tokenizer
interpreter, label_encoder, tokenizer = load_model()
@app.route("/")
@cross_origin()
def update_news():
    """Fetch fresh news, predict categories, and persist them to the DB.

    Reads the existing articles, pulls new ones via ``get_news``, runs
    category prediction, and inserts the result — unless prediction fails
    or the new article count shrinks, in which case the insert is aborted.

    Returns:
        flask.Response: JSON body ``{"status": ..., "message": ...}`` with
        HTTP 200 on success or 500 on any failure.
    """
    logger.warning('Entering update_news()')
    # Build real JSON with json.dumps — the previous single-quoted string
    # literals were not valid JSON despite the application/json mimetype.
    status_json = json.dumps({"status": "success", "message": "success"})
    status_code = 200
    try:
        db_read = DBRead()
        db_write = DBWrite()
        old_news = db_read.read_news_from_db()
        new_news = get_news()
        news_df = predict_news_category(old_news, new_news, interpreter, label_encoder, tokenizer)
        if news_df is None:
            raise Exception('Could not generate category predictions. Aborting the database insertion. No new articles are inserted into the collection.')
        old_news_count = 0 if old_news is None else len(old_news)
        # news_df is guaranteed non-None past the check above.
        new_news_count = len(news_df)
        logger.warning(f'Old News count: {old_news_count}\nNew News count: {new_news_count}')
        # A shrinking article count indicates a bad fetch; refuse to insert.
        if new_news_count < old_news_count:
            raise Exception('New news count < Old news count. Aborting the database insertion. No new articles are inserted into the collection.')
        news_json = [*json.loads(news_df.reset_index(drop=True).to_json(orient="index")).values()]
        db_write.insert_news_into_db(news_json)
    except Exception as e:
        # json.dumps escapes the error text, so quotes in str(e) cannot
        # corrupt the payload (the old string concatenation could).
        status_json = json.dumps({"status": "failure", "message": str(e)})
        status_code = 500
        # Log with traceback instead of re-raising: the previous bare
        # `raise` made the crafted 500 response below unreachable.
        logger.exception('update_news() failed')
    logger.warning('Exiting update_news()')
    return Response(status_json, status=status_code, mimetype='application/json')
logger.warning('Exiting application')
if __name__ == "__main__":
    # Dev-server entry point only. The previous timeout/workers/threads
    # kwargs are gunicorn options, not Flask's: app.run() forwards extra
    # kwargs to werkzeug.run_simple, which rejects them with a TypeError.
    app.run(host="0.0.0.0", port=7860)
|