npphuong committed on
Commit
1aa8526
1 Parent(s): d8e0eb9

Upload 12 files

LSTM.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0d6fba55ed617b95fc7c1f5322f3878bed3fe7d4962b8a263aaf2f1a16c16970
+ size 341626144
Logistic_Model.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:190bc17002c82fd3a418d7cd9835ec4a590bc45eb1964d52e625a836bda1a6a9
+ size 400895
SVM_Linear_Kernel.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d56dda49b361a586459a3fb2d79f1fbffa65ccd78c8bf208ab37abdf0515ccc7
+ size 400764
app.py ADDED
@@ -0,0 +1,305 @@
+ import joblib
+ import streamlit as st
+ import json
+ import requests
+ from bs4 import BeautifulSoup
+ from datetime import date
+ from tensorflow.keras.models import load_model
+ from tensorflow.keras.preprocessing.text import Tokenizer
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+ import numpy as np
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
+ import torch
+
+ # Load all the models and the vectorizer (global vocabulary)
+ Seq_model = load_model("LSTM.h5")  # LSTM (Sequential)
+ SVM_model = joblib.load("SVM_Linear_Kernel.joblib")  # SVM
+ logistic_model = joblib.load("Logistic_Model.joblib")  # Logistic Regression
+ svm_model = joblib.load('svm_model.joblib')  # maps SVM decision scores to probabilities (see process_api)
+
+ vectorizer = joblib.load("vectorizer.joblib")  # global vocabulary (used for Logistic, SVC)
+ tokenizer = joblib.load("tokenizer.joblib")  # used for LSTM
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ tokenizer1 = DistilBertTokenizer.from_pretrained("tokenizer_bert")
+ model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=5)
+ model.load_state_dict(torch.load("fine_tuned_bert_model1.pth", map_location=device))
+ model.eval()  # inference only: disable dropout so predictions are deterministic
+
+ # Decode label function
+ # {'business': 0, 'entertainment': 1, 'health': 2, 'politics': 3, 'sport': 4}
+ def decodedLabel(input_number):
+     print('receive label encoded', input_number)
+     categories = {
+         0: 'Business',
+         1: 'Entertainment',
+         2: 'Health',
+         3: 'Politics',
+         4: 'Sport'
+     }
+     result = categories.get(input_number)  # Ex: Health
+     print('decoded result', result)
+     return result
+
+ # Web crawler function
+ def crawURL(url):
+     # Fetch the URL content
+     response = requests.get(url)
+     # Parse the sitemap HTML
+     soup = BeautifulSoup(response.content, 'html.parser')
+
+     # Find all anchor tags that are children of span tags with class 'sitemap-link'
+     urls = [span.a['href'] for span in soup.find_all('span', class_='sitemap-link') if span.a]
+
+     # Crawl the page and extract data
+     try:
+         print(f"Crawling page: {url}")
+         # Fetch page content
+         page_response = requests.get(url)
+         page_content = page_response.content
+
+         # Parse page content with BeautifulSoup
+         soup = BeautifulSoup(page_content, 'html.parser')
+
+         # Extract the metadata we need from the page
+         author = soup.find("meta", {"name": "author"}).attrs['content'].strip()
+         date_published = soup.find("meta", {"property": "article:published_time"}).attrs['content'].strip()
+         article_section = soup.find("meta", {"name": "meta-section"}).attrs['content']
+         url = soup.find("meta", {"property": "og:url"}).attrs['content']
+         headline = soup.find("h1", {"data-editable": "headlineText"}).text.strip()
+         description = soup.find("meta", {"name": "description"}).attrs['content'].strip()
+         keywords = soup.find("meta", {"name": "keywords"}).attrs['content'].strip()
+         text = soup.find(itemprop="articleBody")
+         # Find all <p> tags with class "paragraph inline-placeholder"
+         paragraphs = text.find_all('p', class_="paragraph inline-placeholder")
+
+         # Initialize an empty list to store the text content of each paragraph
+         paragraph_texts = []
+
+         # Iterate over each <p> tag and extract its text content
+         for paragraph in paragraphs:
+             paragraph_texts.append(paragraph.text.strip())
+
+         # Join the paragraphs into a single string (space-separated so sentences do not run together)
+         full_text = ' '.join(paragraph_texts)
+         return full_text
+
+     except Exception as e:
+         print(f"Failed to crawl page: {url}, Error: {str(e)}")
+         return None
+
+ # Predict the text category with each model
+ def process_api(text):
+     # Vectorize the text data
+     processed_text = vectorizer.transform([text])
+     sequence = tokenizer.texts_to_sequences([text])
+     padded_sequence = pad_sequences(sequence, maxlen=1000, padding='post')
+
+     new_encoding = tokenizer1([text], truncation=True, padding=True, return_tensors="pt")
+     input_ids = new_encoding['input_ids']
+     attention_mask = new_encoding['attention_mask']
+     with torch.no_grad():
+         output = model(input_ids, attention_mask=attention_mask)
+         logits = output.logits
+
+     # Get the predicted result from the models
+     Logistic_Predicted = logistic_model.predict(processed_text).tolist()  # Logistic Model
+     SVM_Predicted = SVM_model.predict(processed_text).tolist()  # SVC Model
+     Seq_Predicted = Seq_model.predict(padded_sequence)
+     predicted_label_index = np.argmax(Seq_Predicted)
+
+     # ----------- Probabilities -----------
+     Logistic_Predicted_proba = logistic_model.predict_proba(processed_text)
+     svm_new_probs = SVM_model.decision_function(processed_text)
+     svm_probs = svm_model.predict_proba(svm_new_probs)
+
+     bert_probabilities = torch.softmax(logits, dim=1)
+     max_probability = torch.max(bert_probabilities).item()
+     predicted_label_bert = torch.argmax(logits, dim=1).item()
+     # ----------- Debug logs -----------
+     logistic_debug = decodedLabel(int(Logistic_Predicted[0]))
+     svc_debug = decodedLabel(int(SVM_Predicted[0]))
+     # predicted_label_index = np.argmax(Seq_Predicted)
+     # print('Logistic', int(Logistic_Predicted[0]), logistic_debug)
+     # print('SVM', int(SVM_Predicted[0]), svc_debug)
+
+     return {
+         'predicted_label_logistic': decodedLabel(int(Logistic_Predicted[0])),
+         'probability_logistic': f"{int(float(np.max(Logistic_Predicted_proba))*10000//100)}%",
+
+         'predicted_label_svm': decodedLabel(int(SVM_Predicted[0])),
+         'probability_svm': f"{int(float(np.max(svm_probs))*10000//100)}%",
+
+         'predicted_label_lstm': decodedLabel(int(predicted_label_index)),
+         'probability_lstm': f"{int(float(np.max(Seq_Predicted))*10000//100)}%",
+
+         'predicted_label_bert': decodedLabel(int(predicted_label_bert)),
+         'probability_bert': f"{int(float(max_probability)*10000//100)}%",
+
+         'Article_Content': text
+     }
+
+ # Crawl the article, classify its content with the models, and return the result as a dict (rendered as JSON)
+ def categorize(url):
+     try:
+         article_content = crawURL(url)
+         result = process_api(article_content)
+         return result
+     except Exception as error:
+         if hasattr(error, 'message'):
+             return {"error_message": error.message}
+         else:
+             return {"error_message": str(error)}  # str() so the error is serializable
+
+
+ # Main App
+ st.title('Instant Category Classification')
+ st.write("Unsure what category a CNN article belongs to? Our clever tool can help! Paste the URL below and press Enter. We'll sort it into one of our 5 categories in a flash! ⚡️")
+
+ # Define category information (modify content and bullet points as needed)
+ categories = {
+     "Business": [
+         "Analyze market trends and investment opportunities.",
+         "Gain insights into company performance and industry news.",
+         "Stay informed about economic developments and regulations."
+     ],
+     "Health": [
+         "Discover healthy recipes and exercise tips.",
+         "Learn about the latest medical research and advancements.",
+         "Find resources for managing chronic conditions and improving well-being."
+     ],
+     "Sport": [
+         "Follow your favorite sports teams and athletes.",
+         "Explore news and analysis from various sports categories.",
+         "Stay updated on upcoming games and competitions."
+     ],
+     "Politics": [
+         "Get informed about current political events and policies.",
+         "Understand different perspectives on political issues.",
+         "Engage in discussions and debates about politics."
+     ],
+     "Entertainment": [
+         "Find recommendations for movies, TV shows, and music.",
+         "Explore reviews and insights from entertainment critics.",
+         "Stay updated on celebrity news and cultural trends."
+     ]
+ }
+
+ # Define model information (modify descriptions as needed)
+ models = {
+     "Logistic Regression": "A widely used statistical method for classification problems. It excels at identifying linear relationships between features and the target variable.",
+     "SVC (Support Vector Classifier)": "A powerful machine learning model that seeks to find a hyperplane that best separates data points of different classes. It's effective for high-dimensional data and can handle some non-linear relationships.",
+     "LSTM (Long Short-Term Memory)": "A type of recurrent neural network (RNN) particularly well-suited for sequential data like text or time series. LSTMs can effectively capture long-term dependencies within the data.",
+     "BERT (Bidirectional Encoder Representations from Transformers)": "A powerful pre-trained model based on the Transformer architecture. It excels at understanding the nuances of language and can be fine-tuned for various NLP tasks like text classification."
+ }
+
+
+ # CNN URL example list
+ URL_Example = [
+     'https://edition.cnn.com/2012/01/31/health/frank-njenga-mental-health/index.html',
+     'https://edition.cnn.com/2024/04/30/entertainment/barbra-streisand-melissa-mccarthy-ozempic/index.html',
+     'https://edition.cnn.com/2024/04/30/sport/lebron-james-lakers-future-nba-spt-intl/index.html',
+     'https://edition.cnn.com/2024/04/30/business/us-home-prices-rose-in-february/index.html'
+ ]
+
+ # Expander listing the categories that can be classified
+ with st.expander("Category List"):
+     # Title for each category
+     st.subheader("Available Categories:")
+     for category in categories.keys():
+         st.write(f"- {category}")
+     # Content for each category (separated by a horizontal line)
+     st.write("---")
+     for category, content in categories.items():
+         st.subheader(category)
+         for item in content:
+             st.write(f"- {item}")
+
+
+ # Expander listing the models used in this project
+ with st.expander("Available Models"):
+     st.subheader("List of Models:")
+     for model_name in models.keys():
+         st.write(f"- {model_name}")
+     st.write("---")
+     for model_name, description in models.items():
+         st.subheader(model_name)
+         st.write(description)
+
+ with st.expander("URLs Example"):
+     for url in URL_Example:
+         st.write(f"- {url}")
+
+ # Explain to the user why this project currently only works for the CNN domain
+ with st.expander("Tips", expanded=True):
+     st.write(
+         '''
+         This project works best with CNN articles right now.
+         Our web crawler is like a special tool for CNN's website.
+         It can't quite understand other websites because they're built differently.
+         '''
+     )
+
+ st.divider()  # 👈 Draws a horizontal rule
+
+ st.title('Dive in! See what category your CNN story belongs to 😉.')
+ # Paste URL input
+ url = st.text_input("Find your favorite CNN story! Paste the URL and press ENTER 🔍.", placeholder='Ex: https://edition.cnn.com/2012/01/31/health/frank-njenga-mental-health/index.html')
+
+ if url:
+     st.divider()  # 👈 Draws a horizontal rule
+     result = categorize(url)
+     article_content = result.get('Article_Content')
+     st.title('Article Content Fetched')
+     st.text_area("", value=article_content, height=400)  # render the article content as a textarea element
+     st.divider()  # 👈 Draws a horizontal rule
+     st.title('Predicted Results')
+     st.json({
+         "Logistic": {
+             "predicted_label": result.get("predicted_label_logistic"),
+             "probability": result.get("probability_logistic")
+         },
+         "SVC": {
+             "predicted_label": result.get("predicted_label_svm"),
+             "probability": result.get("probability_svm")
+         },
+         "LSTM": {
+             "predicted_label": result.get("predicted_label_lstm"),
+             "probability": result.get("probability_lstm")
+         },
+         "BERT": {
+             "predicted_label": result.get("predicted_label_bert"),
+             "probability": result.get("probability_bert")
+         }
+     })
+
+     st.divider()  # 👈 Draws a horizontal rule
+
+ # Category labels and corresponding counts in the training data
+ categories = ["Sport", "Health", "Entertainment", "Politics", "Business"]
+ counts = [5638, 4547, 2658, 2461, 1362]
+
+ # Optional: Add a chart title
+ st.title("Training Data Category Distribution")
+
+ # Optional: Display additional information
+ st.write("Here's a breakdown of the number of articles in each category:")
+ for category, count in zip(categories, counts):
+     st.write(f"- {category}: {count}")
+
+ # Create the bar chart
+ st.bar_chart(data=dict(zip(categories, counts)))
+
+ st.divider()  # 👈 Draws a horizontal rule
+
+ # ------------ Copyright Section ------------
+ # Get the current year
+ current_year = date.today().year
+ # Format the copyright statement with the dynamic year
+ copyright_text = f"Copyright © {current_year}"
+ st.title(copyright_text)
+ author_names = ["Trần Thanh Phước (Mentor)", "Lương Ngọc Phương (Member)", "Trịnh Cẩm Minh (Member)"]
+ st.write("Meet the minds behind the work!")
+ for author in author_names:
+     if author == "Trịnh Cẩm Minh (Member)": st.markdown("- [Trịnh Cẩm Minh (Member)](https://minhct.netlify.app/)")
+     else: st.markdown(f"- {author}\n")  # Use f-string for bullet and newline
fine_tuned_bert_model1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9666f4dc527e68a8ad8f528b0b946d86fc05f5cdf151cc35cd92b71932e22095
+ size 267872438
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ tensorflow==2.15.0
+ joblib
+ scikit-learn
+ transformers==4.40.1
+ streamlit
+ numpy
+ requests
+ beautifulsoup4
+ torch
svm_model.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5de24bdba9f805bde0ac379c14ab68f45da8419a0349f87c8ae1596766173f15
+ size 1135
tokenizer.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:89848b480e0cc8f4555bc6b9d79b9fd2f369a4b4d2dd2247561c16b7dfabf7f9
+ size 20743621
tokenizer_bert/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer_bert/tokenizer_config.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "DistilBertTokenizer",
+   "unk_token": "[UNK]"
+ }
tokenizer_bert/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
vectorizer.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4fe068fd0f45e7dcef9baebd5d6813ab5d730cbd7018030728c6b00e66e03acc
+ size 4122050