lalithadevi committed
Commit c2c5fc6
1 Parent(s): 960de68

Upload 13 files

README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  python_version: '3.11.5'
- title: newsdb
+ title: latest_news_backend_with_cat_pred
  emoji: 🔥
  colorFrom: red
  colorTo: red
app.py CHANGED
@@ -1,25 +1,46 @@
  from news_extractor import get_news
- from db_operations import DBOperations
+ from db_operations.db_write import DBWrite
+ from db_operations.db_read import DBRead
+ from news_category_prediction import predict_news_category
  import json
  from flask import Flask, Response
  from flask_cors import cross_origin, CORS
  import logging
+ import tensorflow as tf
+ import cloudpickle
+ from transformers import DistilBertTokenizerFast
+ import os

  app = Flask(__name__)
  CORS(app)
  logging.warning('Initiated')


+ def load_model():
+     interpreter = tf.lite.Interpreter(model_path=os.path.join("models/news_classification_hf_distilbert.tflite"))
+     with open("models/news_classification_labelencoder.bin", "rb") as model_file_obj:
+         label_encoder = cloudpickle.load(model_file_obj)
+
+     model_checkpoint = "distilbert-base-uncased"
+     tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
+     return interpreter, label_encoder, tokenizer
+
+ interpreter, label_encoder, tokenizer = load_model()
+
+
  @app.route("/")
  @cross_origin()
  def update_news():
      status_json = "{'status':'success'}"
      status_code = 200
      try:
-         db = DBOperations()
-         news_df = get_news()
+         db_read = DBRead()
+         db_write = DBWrite()
+         old_news = db_read.read_news_from_db()
+         new_news = get_news()
+         news_df = predict_news_category(old_news, new_news, interpreter, label_encoder, tokenizer)
          news_json = [*json.loads(news_df.reset_index(drop=True).to_json(orient="index")).values()]
-         db.insert_news_into_db(news_json)
+         db_write.insert_news_into_db(news_json)
      except:
          status_json = "{'status':'failure'}"
          status_code = 500
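For a quick sanity check of the updated route, a minimal smoke test using Flask's built-in test client is sketched below. It is illustrative and not part of this commit: it assumes app.py is importable as `app`, that the files under models/ are present so load_model() succeeds at import time, and that DB_URL may be unset (the DB helpers then skip the MongoDB round trip while the route still returns a status payload).

# Illustrative smoke test (not part of this commit).
from app import app          # importing app runs load_model() once at startup

def test_update_news_route():
    client = app.test_client()   # Flask test client, no server needed
    resp = client.get("/")       # fetch feeds -> predict categories -> write to DB
    assert resp.status_code in (200, 500)   # success, or the failure path inside the bare except

if __name__ == "__main__":
    test_update_news_route()
    print("route responded")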
db_operations/__init__.py ADDED
File without changes
db_operations/db_read.py ADDED
@@ -0,0 +1,59 @@
+ import pymongo
+ import os
+ import pandas as pd
+
+
+ class DBRead:
+     """
+     Reads news from MongoDB
+     """
+     def __init__(self):
+         self.url = os.getenv('DB_URL')
+         self.database = "rss_news_db"
+         self.collection = "rss_news"
+         self.__client = None
+         self.__error = 0
+
+     def __connect(self):
+         try:
+             self.__client = pymongo.MongoClient(self.url)
+             _ = self.__client.list_database_names()
+         except Exception as conn_exception:
+             self.__error = 1
+             self.__client = None
+             raise
+
+     def __read(self):
+         try:
+             db = self.__client[self.database]
+             coll = db[self.collection]
+             docs = []
+             for doc in coll.find():
+                 docs.append(doc)
+             rss_df = pd.DataFrame(docs)
+         except Exception as read_err:
+             self.__error = 1
+             rss_df = pd.DataFrame({'_id': '', 'title': '', 'url': '',
+                                    'description': '', 'parsed_date': '',
+                                    'src': ''}, index=[0])
+         return rss_df
+
+     def __close_connection(self):
+         if self.__client is not None:
+             self.__client.close()
+             self.__client = None
+
+     def read_news_from_db(self):
+         rss_df = pd.DataFrame({'_id': '', 'title': '', 'url': '',
+                                'description': '', 'parsed_date': '',
+                                'src': ''}, index=[0])
+         if self.url is not None:
+             if self.__error == 0:
+                 self.__connect()
+             if self.__error == 0:
+                 rss_df = self.__read()
+             if self.__error == 0:
+                 print("Read Successful")
+             if self.__client is not None:
+                 self.__close_connection()
+         return rss_df
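A brief usage sketch of DBRead follows (not part of the commit). The connection string below is a hypothetical local URI supplied only for illustration; the class reads it from the DB_URL environment variable, and when that variable is unset read_news_from_db() returns a one-row placeholder frame instead of raising.

# Illustrative use of DBRead; the URI is an assumption.
import os
from db_operations.db_read import DBRead

os.environ.setdefault("DB_URL", "mongodb://localhost:27017")  # hypothetical local MongoDB
reader = DBRead()
old_news = reader.read_news_from_db()
print(old_news.shape)
print(old_news.columns.tolist())  # _id, title, url, description, parsed_date, src (+ category/pred_proba once stored)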
db_operations/db_write.py ADDED
@@ -0,0 +1,52 @@
+ import pymongo
+ import os
+
+
+ class DBWrite:
+     """
+     Inserts processed news into MongoDB
+     """
+     def __init__(self):
+         self.url = os.getenv('DB_URL')
+         self.database = "rss_news_db"
+         self.collection = "rss_news"
+         self.__client = None
+         self.__error = 0
+
+     def __connect(self):
+         try:
+             self.__client = pymongo.MongoClient(self.url)
+             _ = self.__client.list_database_names()
+         except Exception as conn_exception:
+             self.__error = 1
+             self.__close_connection()
+             self.__client = None
+             raise
+
+     def __insert(self, documents):
+         try:
+             db = self.__client[self.database]
+             coll = db[self.collection]
+             coll.drop()
+             coll.insert_many(documents=documents)
+         except Exception as insert_err:
+             self.__error = 1
+             self.__close_connection()
+             raise
+
+     def __close_connection(self):
+         if self.__client is not None:
+             self.__client.close()
+             self.__client = None
+
+     def insert_news_into_db(self, documents: list):
+         if self.url is not None:
+             if self.__error == 0:
+                 self.__connect()
+             if self.__error == 0:
+                 self.__insert(documents=documents)
+             if self.__error == 0:
+                 print("Insertion Successful")
+             if self.__client is not None:
+                 self.__close_connection()
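A matching sketch for DBWrite (not part of the commit). Note that __insert() drops the rss_news collection before calling insert_many, so every call is a full refresh of the stored news rather than an append. The URI and document below are made up for illustration.

# Illustrative use of DBWrite; the URI and document are assumptions.
import os
from db_operations.db_write import DBWrite

os.environ.setdefault("DB_URL", "mongodb://localhost:27017")  # hypothetical local MongoDB
docs = [{"title": "Sample headline", "url": "https://example.com/a",
         "description": "Sample description", "parsed_date": "2023-08-12 13:39:15+05:30",
         "src": "example.com", "category": "OTHERS", "pred_proba": 0.5}]
writer = DBWrite()
writer.insert_news_into_db(docs)  # prints "Insertion Successful" when the write goes through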
models/news_classification_hf_distilbert.tflite ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:283b191892f95889a79e18f4362f207617e56b5c9f93160b61be7db1c480938e
+ size 66788520
models/news_classification_labelencoder.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:65ddceef60d9f1dc95d70a1940c5b382bb58d47ebb7145bf32e887f62e054535
+ size 327
news_category_prediction.py ADDED
@@ -0,0 +1,71 @@
+ import pandas as pd
+ import numpy as np
+ import tensorflow as tf
+
+
+ def parse_prediction(tflite_pred, label_encoder):
+     tflite_pred_argmax = np.argmax(tflite_pred, axis=1)
+     tflite_pred_label = label_encoder.inverse_transform(tflite_pred_argmax)
+     tflite_pred_prob = np.max(tflite_pred, axis=1)
+     return tflite_pred_label, tflite_pred_prob
+
+
+ def inference(text, interpreter, label_encoder, tokenizer):
+     batch_size = len(text)
+     MAX_LEN = 80
+     N_CLASSES = 8
+     if text != "":
+         tokens = tokenizer(text, max_length=MAX_LEN, padding="max_length", truncation=True, return_tensors="tf")
+         # tflite model inference
+         interpreter.allocate_tensors()
+         input_details = interpreter.get_input_details()
+         output_details = interpreter.get_output_details()[0]
+         attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
+         interpreter.resize_tensor_input(input_details[0]['index'], [batch_size, MAX_LEN])
+         interpreter.resize_tensor_input(input_details[1]['index'], [batch_size, MAX_LEN])
+         interpreter.resize_tensor_input(output_details['index'], [batch_size, N_CLASSES])
+         interpreter.allocate_tensors()
+         interpreter.set_tensor(input_details[0]["index"], attention_mask)
+         interpreter.set_tensor(input_details[1]["index"], input_ids)
+         interpreter.invoke()
+         tflite_pred = interpreter.get_tensor(output_details["index"])
+         tflite_pred = parse_prediction(tflite_pred, label_encoder)
+     return tflite_pred
+
+
+ def cols_check(new_cols, old_cols):
+     return all([new_col == old_col for new_col, old_col in zip(new_cols, old_cols)])
+
+
+ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer):
+     old_news = old_news.copy()
+     new_news = new_news.copy()
+     # dbops = DBOperations()
+     # old_news = dbops.read_news_from_db()
+     old_news.drop(columns='_id', inplace=True)
+     # new_news = get_news()
+     if 'category' not in [*old_news.columns]:
+         print('no prior predictions found')
+         if not cols_check([*new_news.columns], [*old_news.columns]):
+             raise Exception("New and old cols don't match")
+         final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
+         final_df.drop_duplicates(subset='url', keep='first', inplace=True)
+         headlines = [*final_df['title']].copy()
+         label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
+         final_df['category'] = label
+         final_df['pred_proba'] = prob
+     else:
+         print('prior predictions found')
+         if not cols_check([*new_news.columns], [*old_news.columns][:-2]):
+             raise Exception("New and old cols don't match")
+         old_urls = [*old_news['url']]
+         new_news = new_news.loc[new_news['url'].isin(old_urls) == False, :]
+         headlines = [*new_news['title']].copy()
+         label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
+         new_news['category'] = label
+         new_news['pred_proba'] = prob
+         final_df = pd.concat([old_news, new_news], axis=0, ignore_index=True)
+     final_df.drop_duplicates(subset='url', keep='first', inplace=True)
+     final_df.reset_index(drop=True, inplace=True)
+     final_df.loc[final_df['pred_proba'] < 0.65, 'category'] = 'OTHERS'
+     return final_df
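parse_prediction maps the row-wise argmax of the classifier's probability matrix back through the label encoder and keeps the maximum probability per row; predict_news_category later thresholds that probability at 0.65, relabelling anything below it as 'OTHERS'. A toy illustration with a stand-in label encoder (not part of the commit; the three-class label set is assumed for demonstration only):

# Toy illustration of parse_prediction; the label set and probabilities are made up.
import numpy as np
from sklearn.preprocessing import LabelEncoder
from news_category_prediction import parse_prediction

le = LabelEncoder().fit(["BUSINESS", "SPORTS", "WORLD"])   # stand-in for the pickled encoder
fake_probs = np.array([[0.70, 0.20, 0.10],
                       [0.10, 0.20, 0.70]])                # pretend model output, 2 rows x 3 classes
labels, probs = parse_prediction(fake_probs, le)
print(labels)  # ['BUSINESS' 'WORLD']
print(probs)   # [0.7 0.7]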
news_extractor/__init__.py ADDED
@@ -0,0 +1 @@
+ from news_extractor.news_extractor import *
news_extractor/news_extractor.py ADDED
@@ -0,0 +1,192 @@
+ import pandas as pd
+ import numpy as np
+ from bs4 import BeautifulSoup
+ import requests as r
+ import regex as re
+ from dateutil import parser
+ import logging
+ import multiprocessing
+
+
+ def date_time_parser(dt):
+     """
+     Computes the minutes elapsed since the published time.
+     :param dt: date
+     :return: int, minutes elapsed.
+     """
+     return int(np.round((dt.now(dt.tz) - dt).total_seconds() / 60, 0))
+
+
+ def text_clean(desc):
+     """
+     Cleans the text by removing special chars.
+     :param desc: string containing description
+     :return: str, cleaned description.
+     """
+     desc = desc.replace("&lt;", "<")
+     desc = desc.replace("&gt;", ">")
+     desc = re.sub("<.*?>", "", desc)
+     desc = desc.replace("#39;", "'")
+     desc = desc.replace('&quot;', '"')
+     desc = desc.replace('&nbsp;', ' ')
+     desc = desc.replace('#32;', ' ')
+     return desc
+
+
+ def rss_parser(i):
+     """
+     Returns a data frame of a parsed news item.
+     :param i: single news item in an RSS feed.
+     :return: data frame of the parsed news item.
+     """
+     b1 = BeautifulSoup(str(i), "xml")
+     title = "" if b1.find("title") is None else b1.find("title").get_text()
+     title = text_clean(title)
+     url = "" if b1.find("link") is None else b1.find("link").get_text()
+     desc = "" if b1.find("description") is None else b1.find("description").get_text()
+     desc = text_clean(desc)
+     desc = f'{desc[:300]}...' if len(desc) >= 300 else desc
+     date = "Sat, 12 Aug 2000 13:39:15 +05:30" if ((b1.find("pubDate") == "") or (b1.find("pubDate") is None)) else b1.find("pubDate").get_text()
+     if url.find("businesstoday.in") >= 0:
+         date = date.replace("GMT", "+0530")
+
+     date1 = parser.parse(date)
+     return pd.DataFrame({"title": title,
+                          "url": url,
+                          "description": desc,
+                          "parsed_date": date1}, index=[0])
+
+
+ def src_parse(rss):
+     """
+     Returns the root domain name (e.g. livemint.com is extracted from www.livemint.com).
+     :param rss: RSS URL
+     :return: str, string containing the source name
+     """
+     if rss.find('ndtvprofit') >= 0:
+         rss = 'ndtv profit'
+     if rss.find('ndtv') >= 0:
+         rss = 'ndtv.com'
+     if rss.find('telanganatoday') >= 0:
+         rss = 'telanganatoday.com'
+
+     rss = rss.replace("https://www.", "")
+     rss = rss.split("/")
+     return rss[0]
+
+
+ def news_agg(rss):
+     """
+     Returns feeds from each 'rss' URL.
+     :param rss: RSS URL.
+     :return: data frame of processed articles.
+     """
+     try:
+         rss_df = pd.DataFrame()
+         headers = {
+             'authority': 'www.google.com',
+             'accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
+             'accept-language': 'en-US,en;q=0.9',
+             'cache-control': 'max-age=0',
+             'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36'
+         }
+
+         timeout = 5
+
+         resp = r.get(rss, timeout=timeout, headers=headers)
+         logging.warning(f'{rss}: {resp.status_code}')
+         b = BeautifulSoup(resp.content, "xml")
+         items = b.find_all("item")
+         for i in items:
+             rss_df = pd.concat([rss_df, rss_parser(i)], axis=0)
+         rss_df.reset_index(drop=True, inplace=True)
+         rss_df["description"] = rss_df["description"].replace([" NULL", ''], np.nan)
+
+         #### UNCOMMENT IN CASE OF OOM ERROR IN RENDER
+         # rss_df.dropna(inplace=True)
+         ####
+
+         rss_df["src"] = src_parse(rss)
+         rss_df["elapsed_time"] = rss_df["parsed_date"].apply(date_time_parser)
+         rss_df["parsed_date"] = rss_df["parsed_date"].astype("str")
+     except Exception as e:
+         print(e)
+         pass
+     return rss_df
+
+
+ # List of RSS feeds
+ rss = ['https://www.economictimes.indiatimes.com/rssfeedstopstories.cms',
+        'https://www.thehindu.com/news/feeder/default.rss',
+        # 'https://telanganatoday.com/feed',
+        'https://www.businesstoday.in/rssfeeds/?id=225346',
+        'https://feeds.feedburner.com/ndtvnews-latest',
+        'https://www.hindustantimes.com/feeds/rss/world-news/rssfeed.xml',
+        'https://www.indiatoday.in/rss/1206578',
+        'https://www.moneycontrol.com/rss/latestnews.xml',
+        'https://www.livemint.com/rss/news',
+        'https://www.zeebiz.com/latest.xml/feed',
+        'https://www.timesofindia.indiatimes.com/rssfeedmostrecent.cms']
+
+
+ def get_news_rss(url):
+     """
+     Fetches, parses, and cleans articles from a single RSS URL.
+     """
+     final_df = news_agg(url)
+     final_df.reset_index(drop=True, inplace=True)
+     final_df.sort_values(by="elapsed_time", inplace=True)
+     final_df.drop(columns=['elapsed_time'], inplace=True)
+
+     #### UNCOMMENT 1ST STATEMENT AND REMOVE 2ND STATEMENT IN CASE OF OOM ERROR IN RENDER
+     # final_df.drop_duplicates(subset='description', inplace=True)
+     final_df.drop_duplicates(subset='url', inplace=True)
+     ####
+
+     final_df = final_df.loc[(final_df["title"] != ""), :].copy()
+     final_df.loc[(final_df['description'].isna()) | (final_df['description'] == '') | (final_df['description'] == ' '), 'description'] = final_df.loc[(final_df['description'].isna()) | (final_df['description'] == '') | (final_df['description'] == ' '), 'title']
+     return final_df
+
+
+ def get_news_multi_process(urls):
+     """
+     Fetches every RSS feed in parallel (one worker per URL) and aggregates
+     the results into a single data frame.
+     """
+     pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
+
+     results = []
+     for url in urls:
+         f = pool.apply_async(get_news_rss, [url])  # each worker fetches and parses one feed
+         results.append(f)
+
+     final_df = pd.DataFrame()
+     for f in results:
+         final_df = pd.concat([final_df, f.get(timeout=120)], axis=0)  # collect the output of each job
+
+     final_df.reset_index(drop=True, inplace=True)
+     logging.warning(final_df['src'].unique())
+     pool.close()
+     pool.join()
+     return final_df
+
+
+ def get_news():
+     return get_news_multi_process(rss)
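A single-feed run is the easiest way to inspect the extractor's output. The sketch below picks one URL from the rss list above, requires network access, and is illustrative only (not part of the commit).

# Illustrative single-feed fetch; needs network access.
from news_extractor import get_news_rss

df = get_news_rss("https://www.thehindu.com/news/feeder/default.rss")
print(df.shape)
print(df[["title", "url", "src", "parsed_date"]].head())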
requirements.txt CHANGED
@@ -10,3 +10,7 @@ flask_cors==3.0.10
  gunicorn==20.1.0
  pymongo==4.3.3
  Werkzeug==2.2.2
+ tensorflow
+ scikit-learn==1.2.2
+ cloudpickle
+ transformers