File size: 2,193 Bytes
e857da4 c2c5fc6 e857da4 99ce67b e857da4 5116708 c2c5fc6 4ed8f54 e857da4 4ed8f54 a9a9d3a e857da4 c2c5fc6 2f5f99d c2c5fc6 2f5f99d c2c5fc6 e857da4 2f5f99d c3395ab eec495e e857da4 c2c5fc6 c3395ab e857da4 c2c5fc6 c3395ab eec495e 50cbe24 2f5f99d eec495e e857da4 a9a9d3a e857da4 cf90c5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from news_extractor import get_news
from db_operations.db_write import DBWrite
from db_operations.db_read import DBRead
from news_category_prediction import predict_news_category
import json
from flask import Flask, Response
from flask_cors import cross_origin, CORS
import logging
import tensorflow as tf
import cloudpickle
from transformers import DistilBertTokenizerFast
import os
from logger import get_logger
# Obtain the project logger first so the start-up message is the very first
# thing this module emits.
logger = get_logger()
logger.warning('Entering application')

# The Flask application; CORS(app) registers CORS handling on the whole app.
app = Flask(__name__)
CORS(app)
def load_model(model_dir="models", model_checkpoint="distilbert-base-uncased"):
    """Load the artefacts needed for news-category prediction.

    Args:
        model_dir: Directory holding the TFLite model and the pickled label
            encoder. It is resolved relative to the current working directory.
        model_checkpoint: Name or path passed to
            ``DistilBertTokenizerFast.from_pretrained``.

    Returns:
        A ``(interpreter, label_encoder, tokenizer)`` tuple: a
        ``tf.lite.Interpreter`` for the DistilBERT classifier, the label
        encoder loaded with ``cloudpickle``, and the fast DistilBERT tokenizer.
    """
    logger.warning('Entering load transformer')
    # NOTE(review): allocate_tensors() is not called here; presumably
    # predict_news_category does it before inference - confirm.
    interpreter = tf.lite.Interpreter(
        model_path=os.path.join(model_dir, "news_classification_hf_distilbert.tflite")
    )
    # cloudpickle.load can execute arbitrary code; only load trusted, locally
    # shipped files from this path.
    label_encoder_path = os.path.join(model_dir, "news_classification_labelencoder.bin")
    with open(label_encoder_path, "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    logger.warning('Exiting load transformer')
    return interpreter, label_encoder, tokenizer


# Load once at import time so every request reuses the same model objects.
interpreter, label_encoder, tokenizer = load_model()
@app.route("/")
@cross_origin()
def update_news():
    """Fetch fresh news, predict categories, and store the result.

    Steps:
    1. Read the previously stored news with ``DBRead``.
    2. Fetch new news with ``get_news``.
    3. Categorise it with ``predict_news_category``, using the module-level
       interpreter, label encoder and tokenizer.
    4. Insert the resulting rows with ``DBWrite``.

    Returns:
        A JSON ``Response``:
        - On success: ``{"status": "success", "message": "success"}`` with
          HTTP 200.
        - If any step raises: ``{"status": "failure", "message": <error text>}``
          with HTTP 500.
    """
    logger.warning('Entering update_news()')
    status = {'status': 'success', 'message': 'success'}
    status_code = 200
    try:
        db_read = DBRead()
        db_write = DBWrite()
        old_news = db_read.read_news_from_db()
        new_news = get_news()
        news_df = predict_news_category(old_news, new_news, interpreter, label_encoder, tokenizer)
        if news_df is None:
            raise RuntimeError('Could not generate category predictions')
        # to_json(orient="index") yields {"0": row, "1": row, ...}; keep only the
        # row dicts, in index order, as plain JSON-compatible Python objects.
        news_json = list(json.loads(news_df.reset_index(drop=True).to_json(orient="index")).values())
        db_write.insert_news_into_db(news_json)
    except Exception as e:
        # Log the traceback and report the failure to the client. The previous
        # bare `raise` here discarded this failure response entirely.
        logger.exception('update_news() failed')
        status = {'status': 'failure', 'message': str(e)}
        status_code = 500
    logger.warning('Exiting update_news()')
    # json.dumps gives a valid JSON body and escapes the error text.
    return Response(json.dumps(status), status=status_code, mimetype='application/json')
# NOTE: this runs once at import time, after the routes are registered; it
# does not mark process shutdown.
logger.warning('Exiting application')
if __name__ == "__main__":
    # Flask's development server (werkzeug run_simple) does not accept the
    # gunicorn-style `timeout`, `workers` or `threads` options; passing them
    # raises TypeError.
    # threaded=False / processes=1 give the intended single-worker,
    # single-thread server. The dev server has no request-timeout option, so
    # set one on the production WSGI server instead (e.g. gunicorn --timeout).
    app.run(host="0.0.0.0", port=7860, threaded=False, processes=1)
|