from news_extractor import get_news
from db_operations.db_write import DBWrite
from db_operations.db_read import DBRead
from news_category_prediction import predict_news_category
import json
from flask import Flask, Response
from flask_cors import cross_origin, CORS
import logging
import tensorflow as tf
import cloudpickle
from transformers import DistilBertTokenizerFast
import os
# Flask application with CORS enabled for all routes (cross-origin clients
# such as a separately-hosted frontend can call this API).
app = Flask(__name__)
CORS(app)
# Emitted at import time so the hosting platform's log shows startup progress.
logging.warning('Initiated')
def load_model():
    """Load the news-classification artifacts from the local ``models`` dir.

    Returns:
        tuple: (tf.lite.Interpreter for the DistilBERT TFLite model,
        the fitted label encoder, DistilBertTokenizerFast).
    """
    model_dir = "models"
    # Fix: the original called os.path.join with a single argument (a no-op);
    # join the directory and file name properly instead.
    interpreter = tf.lite.Interpreter(
        model_path=os.path.join(model_dir, "news_classification_hf_distilbert.tflite")
    )
    # NOTE(review): unpickling can execute arbitrary code — acceptable only
    # because this artifact ships with the app, never from user input.
    with open(os.path.join(model_dir, "news_classification_labelencoder.bin"), "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)
    # Tokenizer must match the checkpoint the TFLite model was exported from.
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    return interpreter, label_encoder, tokenizer
# Load the model once at import time so every request reuses the same
# interpreter/encoder/tokenizer instead of reloading per call.
interpreter, label_encoder, tokenizer = load_model()
@app.route("/")
@cross_origin()
def update_news():
    """Fetch fresh news, classify it, write it to the DB, and report status.

    Returns:
        Response: ``{"status": "success"}`` with HTTP 200, or
        ``{"status": "failure"}`` with HTTP 500 if any step fails.
    """
    # Fix: the original hand-built "{'status':'success'}" — single quotes are
    # not valid JSON even though the response claims application/json.
    # Serialize a dict with json.dumps instead.
    status = {"status": "success"}
    status_code = 200
    try:
        db_read = DBRead()
        db_write = DBWrite()
        old_news = db_read.read_news_from_db()
        new_news = get_news()
        news_df = predict_news_category(old_news, new_news, interpreter, label_encoder, tokenizer)
        # orient="index" produces {"0": {...}, "1": {...}}; keep only the row
        # dicts so the DB receives a plain list of records.
        news_json = list(json.loads(news_df.reset_index(drop=True).to_json(orient="index")).values())
        db_write.insert_news_into_db(news_json)
    except Exception:
        # Fix: bare `except:` also caught SystemExit/KeyboardInterrupt and
        # silently discarded the traceback; log it so failures are diagnosable.
        logging.exception("News update failed")
        status = {"status": "failure"}
        status_code = 500
    return Response(json.dumps(status), status=status_code, mimetype='application/json')
if __name__ == "__main__":
    # Fix: `timeout`, `workers`, and `threads` are gunicorn options, not
    # Flask/werkzeug ones — app.run forwards unknown kwargs to
    # werkzeug.serving.run_simple, which rejects them with a TypeError.
    # Configure those in the gunicorn command line for production instead.
    app.run(host="0.0.0.0", port=7860)