File size: 1,731 Bytes
e857da4
c2c5fc6
 
 
e857da4
99ce67b
e857da4
5116708
c2c5fc6
 
 
 
e857da4
 
 
e85a070
e857da4
 
c2c5fc6
 
 
 
 
 
 
 
 
 
 
 
e857da4
 
 
eec495e
 
e857da4
c2c5fc6
 
 
 
 
e857da4
c2c5fc6
e857da4
eec495e
 
 
e857da4
 
 
b9ffefc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from news_extractor import get_news
from db_operations.db_write import DBWrite
from db_operations.db_read import DBRead
from news_category_prediction import predict_news_category
import json
from flask import Flask, Response
from flask_cors import cross_origin, CORS
import logging
import tensorflow as tf 
import cloudpickle
from transformers import DistilBertTokenizerFast
import os

# WSGI application object; routes below are registered on it.
app = Flask(__name__)
# Apply flask_cors to the whole application (library defaults are used;
# NOTE(review): confirm the default origin policy is what is wanted here).
CORS(app)
# Emitted once at import time, at WARNING level.
logging.warning('Initiated')


def load_model(model_dir="models", model_checkpoint="distilbert-base-uncased"):
    """Load the artefacts needed to classify news articles.

    Args:
        model_dir: Directory containing
            ``news_classification_hf_distilbert.tflite`` and
            ``news_classification_labelencoder.bin``. Defaults to
            ``"models"`` (resolved against the current working directory).
        model_checkpoint: Name or path handed to
            ``DistilBertTokenizerFast.from_pretrained``.

    Returns:
        A ``(interpreter, label_encoder, tokenizer)`` tuple: the TFLite
        interpreter built from the model file (tensors are not allocated
        here), the unpickled label encoder, and the tokenizer.
    """
    interpreter = tf.lite.Interpreter(
        model_path=os.path.join(model_dir, "news_classification_hf_distilbert.tflite")
    )

    label_encoder_path = os.path.join(model_dir, "news_classification_labelencoder.bin")
    # NOTE(security): unpickling executes arbitrary code embedded in the file;
    # only load label-encoder files that come from a trusted source.
    with open(label_encoder_path, "rb") as label_encoder_file:
        label_encoder = cloudpickle.load(label_encoder_file)

    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    return interpreter, label_encoder, tokenizer

# Load the model artefacts once at import time; update_news() reuses these
# module-level objects on every request instead of reloading them.
interpreter, label_encoder, tokenizer = load_model()


@app.route("/")
@cross_origin()
def update_news():
    """Fetch fresh news, classify it and write the result to the database.

    Reads the stored news, fetches new articles with ``get_news``, runs
    ``predict_news_category`` over both, and inserts the resulting rows via
    ``DBWrite.insert_news_into_db``.

    Returns:
        flask.Response: ``{"status": "success"}`` with HTTP 200 when every
        step completes, or ``{"status": "failure"}`` with HTTP 500 when any
        step raises. The failure is logged with its traceback.
    """
    status = "success"
    status_code = 200
    try:
        db_read = DBRead()
        db_write = DBWrite()
        old_news = db_read.read_news_from_db()
        new_news = get_news()
        news_df = predict_news_category(old_news, new_news, interpreter, label_encoder, tokenizer)
        # to_json(orient="index") yields {"0": {...}, "1": {...}, ...}; the
        # database layer receives just the row dicts, in row order.
        rows_by_index = json.loads(news_df.reset_index(drop=True).to_json(orient="index"))
        news_json = list(rows_by_index.values())
        db_write.insert_news_into_db(news_json)
    except Exception:
        # Catch Exception rather than everything so SystemExit and
        # KeyboardInterrupt still propagate, and record why the update failed
        # instead of discarding the error.
        logging.exception("News update failed")
        status = "failure"
        status_code = 500
    # json.dumps produces a valid JSON document; the previous single-quoted
    # literal could not be parsed by clients honouring the JSON mimetype.
    return Response(json.dumps({"status": status}), status=status_code, mimetype='application/json')


if __name__ == "__main__":
    # Flask.run forwards extra keyword arguments to
    # werkzeug.serving.run_simple, which has no ``timeout``, ``workers`` or
    # ``threads`` parameters (those are Gunicorn settings) and raises
    # TypeError when given them. ``threaded=False`` with ``processes=1`` is
    # the development-server equivalent of one worker with one thread; the
    # development server has no request-timeout option.
    app.run(host="0.0.0.0", port=7860, threaded=False, processes=1)