File size: 12,613 Bytes
77f904f
69a88fc
e024b69
ba5f8d5
 
2335fe9
80a186f
643501b
 
2a092a5
3c09b95
 
84d8e35
f8f4b5f
26ff28f
d79611d
75dc5f6
a3039e9
 
b71cd64
26ff28f
f8f4b5f
01fb40f
3c09b95
 
c89a171
3c09b95
3c7207a
 
22da72a
af31e8a
3c7207a
 
 
 
 
 
 
192d8ff
3c7207a
 
 
b71cd64
d6a49a1
b71cd64
9ee5788
d6a49a1
 
 
 
 
 
 
84d8e35
 
 
 
 
d6a49a1
84d8e35
 
d6a49a1
84d8e35
 
 
 
 
 
 
 
 
 
 
d6a49a1
84d8e35
 
d6a49a1
84d8e35
 
 
d6a49a1
84d8e35
 
 
 
 
 
de80a7e
d6a49a1
2335fe9
d6a49a1
 
 
9412e99
 
3c09b95
 
 
 
 
 
 
b71cd64
d6a49a1
3c7207a
 
26ff28f
 
 
e5faf6c
 
 
 
92e286d
3c09b95
 
 
 
22da72a
 
 
6bd8770
26ff28f
 
52f1084
d6a49a1
3890fa9
901b98e
90f56df
 
 
92e286d
 
 
3c09b95
4d90a4c
3c09b95
92e286d
7e83a81
d6a49a1
 
2335fe9
b21cabd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90f56df
92e286d
 
db06d58
d339e73
3c09b95
 
 
 
4604f63
2335fe9
4604f63
 
 
 
 
 
 
 
 
 
 
 
247496e
 
 
2335fe9
 
 
 
 
 
 
 
b21cabd
2335fe9
dd435da
2335fe9
 
d8b3aa1
85a06e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import joblib
import streamlit as st
import json
import requests
from bs4 import BeautifulSoup
from datetime import date
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch

# Load all the models and the vectorizer (global vocabulary) once at app start-up.
# NOTE(review): all paths are relative to the working directory — confirm the
# artifact files are shipped alongside this script.
Seq_model = load_model("LSTM.h5") # Sequential
SVM_model = joblib.load("SVM_Linear_Kernel.joblib") # SVM
logistic_model = joblib.load("Logistic_Model.joblib") # Logistic
# Auxiliary model used below to turn SVM decision_function scores into probabilities.
svm_model = joblib.load('svm_model.joblib')

vectorizer = joblib.load("vectorizer.joblib") # global vocabulary (used for Logistic, SVC)
tokenizer = joblib.load("tokenizer.joblib") # used for LSTM

# DistilBERT: tokenizer from a local folder, base weights from the hub, then
# fine-tuned weights loaded onto whichever device is available (GPU or CPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer1 = DistilBertTokenizer.from_pretrained("tokenizer_bert")
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=5)
model.load_state_dict(torch.load("fine_tuned_bert_model1.pth", map_location=device))

# Decode label function
# {'business': 0, 'entertainment': 1, 'health': 2, 'politics': 3, 'sport': 4}
def decodedLabel(input_number):
    """Map an encoded label index to its human-readable category name.

    Encoding: {'business': 0, 'entertainment': 1, 'health': 2,
               'politics': 3, 'sport': 4}

    Args:
        input_number: integer label index produced by one of the models.

    Returns:
        The capitalized category name, or None for an unknown index.
    """
    categories = {
      0: 'Business',
      1: 'Entertainment',
      2: 'Health',
      3: 'Politics',
      4: 'Sport'
    }
    # Debug prints removed: they spammed the server log on every prediction.
    return categories.get(input_number)

# Web Crawler function
def crawURL(url):
    """Fetch a CNN article page and return its body text.

    Args:
        url: absolute URL of a CNN article.

    Returns:
        The article body as a single string (paragraphs concatenated with
        no separator, matching the original behavior), or None on any
        failure (network error, non-200 status, unexpected page layout).
    """
    try:
        print(f"Crawling page: {url}")
        # Single fetch — the original fetched the same URL twice and also
        # built an unused sitemap-link list. The timeout keeps the Streamlit
        # app from hanging forever on a dead host.
        page_response = requests.get(url, timeout=30)
        page_response.raise_for_status()

        # Parse page content with BeautifulSoup
        soup = BeautifulSoup(page_response.content, 'html.parser')

        # CNN marks the article body with itemprop="articleBody"; the unused
        # metadata extraction (author, keywords, ...) was dropped so a missing
        # meta tag can no longer abort body extraction.
        text = soup.find(itemprop="articleBody")
        # Find all <p> tags with class "paragraph inline-placeholder"
        paragraphs = text.find_all('p', class_="paragraph inline-placeholder")

        # Join the text content of all paragraphs into a single string.
        # NOTE: no separator, to preserve the original output exactly.
        full_text = ''.join(p.text.strip() for p in paragraphs)
        return full_text

    except Exception as e:
        print(f"Failed to crawl page: {url}, Error: {str(e)}")
        return None

# Predict for text category by Models
# Predict for text category by Models
def process_api(text):
    """Classify *text* with all four models and return a result dict.

    Args:
        text: the raw article body to classify.

    Returns:
        Dict with a predicted label and a truncated whole-percent
        probability string for each model (Logistic, SVM, LSTM, BERT),
        plus the original text under 'Article_Content'.
    """
    # ----------- Classic models (shared TF-IDF features) -----------
    processed_text = vectorizer.transform([text])
    Logistic_Predicted = logistic_model.predict(processed_text).tolist()
    SVM_Predicted = SVM_model.predict(processed_text).tolist()
    Logistic_Predicted_proba = logistic_model.predict_proba(processed_text)
    # The linear SVM has no predict_proba: its decision_function scores are
    # calibrated through the auxiliary svm_model.
    svm_scores = SVM_model.decision_function(processed_text)
    svm_probs = svm_model.predict_proba(svm_scores)

    # ----------- LSTM -----------
    sequence = tokenizer.texts_to_sequences([text])
    padded_sequence = pad_sequences(sequence, maxlen=1000, padding='post')
    Seq_Predicted = Seq_model.predict(padded_sequence)
    # (the original computed this argmax twice)
    predicted_label_index = np.argmax(Seq_Predicted)

    # ----------- DistilBERT -----------
    new_encoding = tokenizer1([text], truncation=True, padding=True, return_tensors="pt")
    input_ids = new_encoding['input_ids'].to(device)
    attention_mask = new_encoding['attention_mask'].to(device)
    model.to(device)   # keep model and inputs on the same device
    model.eval()       # disable dropout — inference must be deterministic
    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)
        logits = output.logits
    bert_probabilities = torch.softmax(logits, dim=1)
    max_probability = torch.max(bert_probabilities).item()
    predicted_label_bert = torch.argmax(logits, dim=1).item()

    def _pct(p):
        # Truncated (not rounded) whole-percent string, e.g. 0.987 -> "98%".
        return f"{int(float(p) * 10000 // 100)}%"

    return {
            'predicted_label_logistic': decodedLabel(int(Logistic_Predicted[0])),
            'probability_logistic': _pct(np.max(Logistic_Predicted_proba)),

            'predicted_label_svm': decodedLabel(int(SVM_Predicted[0])),
            'probability_svm': _pct(np.max(svm_probs)),

            'predicted_label_lstm': decodedLabel(int(predicted_label_index)),
            'probability_lstm': _pct(np.max(Seq_Predicted)),

            'predicted_label_bert': decodedLabel(int(predicted_label_bert)),
            'probability_bert': _pct(max_probability),

            'Article_Content': text
        }

# Init web crawling, process article content by Model and return result as JSON
# Init web crawling, process article content by Model and return result as JSON
def categorize(url):
    """Crawl *url* and classify its article body.

    Args:
        url: absolute URL of a CNN article.

    Returns:
        The prediction dict from process_api on success, or
        {"error_message": <string>} on any failure.
    """
    try:
        article_content = crawURL(url)
        result = process_api(article_content)
        return result
    except Exception as error:
        # str(error) is always JSON-serializable. The original returned the
        # raw exception object (which breaks st.json downstream) and probed
        # the Python-2-era `.message` attribute that Python 3 exceptions
        # don't have.
        return {"error_message": str(error)}

        
# Main App
# ------------ Main App: page header ------------
st.title('Instant Category Classification')
st.write("Unsure what category a CNN article belongs to? Our clever tool can help! Paste the URL below and press Enter. We'll sort it into one of our 5 categories in a flash! ⚡️")

# Static descriptive copy for the five supported categories
# (modify content and bullet points as needed).
categories = {
    "Business": [
        "Analyze market trends and investment opportunities.",
        "Gain insights into company performance and industry news.",
        "Stay informed about economic developments and regulations."
    ],
    "Health": [
        "Discover healthy recipes and exercise tips.",
        "Learn about the latest medical research and advancements.",
        "Find resources for managing chronic conditions and improving well-being."
    ],
    "Sport": [
        "Follow your favorite sports teams and athletes.",
        "Explore news and analysis from various sports categories.",
        "Stay updated on upcoming games and competitions."
    ],
    "Politics": [
        "Get informed about current political events and policies.",
        "Understand different perspectives on political issues.",
        "Engage in discussions and debates about politics."
    ],
    "Entertainment": [
        "Find recommendations for movies, TV shows, and music.",
        "Explore reviews and insights from entertainment critics.",
        "Stay updated on celebrity news and cultural trends."
    ]
}

# Short blurbs describing each model used in this project
# (modify descriptions as needed).
models = {
  "Logistic Regression": "A widely used statistical method for classification problems. It excels at identifying linear relationships between features and the target variable.",
  "SVC (Support Vector Classifier)": "A powerful machine learning model that seeks to find a hyperplane that best separates data points of different classes. It's effective for high-dimensional data and can handle some non-linear relationships.",
  "LSTM (Long Short-Term Memory)": "A type of recurrent neural network (RNN) particularly well-suited for sequential data like text or time series. LSTMs can effectively capture long-term dependencies within the data.",
  "BERT (Bidirectional Encoder Representations from Transformers)": "A powerful pre-trained model based on the Transformer architecture. It excels at understanding the nuances of language and can be fine-tuned for various NLP tasks like text classification."
}

# Expander: list every category that can be classified, then describe each.
with st.expander("Category List"):
  # Title for each category
  st.subheader("Available Categories:")
  for category in categories.keys():
    st.write(f"- {category}")
  # Content for each category (separated by a horizontal line)
  st.write("---")
  for category, content in categories.items():
    st.subheader(category)
    for item in content:
      st.write(f"- {item}")


# Expander: list every model used in this project, then describe each.
with st.expander("Available Models"):
  st.subheader("List of Models:")
  for model_name in models.keys():
    st.write(f"- {model_name}")
  st.write("---")
  for model_name, description in models.items():
    st.subheader(model_name)
    st.write(description)

# Explain to the user why this project only works for the CNN domain:
# the crawler's selectors are specific to CNN's page markup.
with st.expander("Tips", expanded=True):
    st.write(
        '''
            This project works best with CNN articles right now. 
            Our web crawler is like a special tool for CNN's website. 
            It can't quite understand other websites because they're built differently
        '''
    )

st.divider() # 👈 Draws a horizontal rule

st.title('Dive in! See what category your CNN story belongs to 😉.')
# Paste URL Input — st.text_input reruns the script with the entered value.
url = st.text_input("Find your favorite CNN story! Paste the URL and press ENTER 🔍.", placeholder='Ex: https://edition.cnn.com/2012/01/31/health/frank-njenga-mental-health/index.html')

# Only crawl and classify once the user has entered a URL.
if url:
    st.divider() # 👈 Draws a horizontal rule
    result = categorize(url)
    # NOTE(review): if categorize returned {"error_message": ...} this renders
    # an empty article and None predictions rather than the error — confirm
    # whether an explicit error display is wanted here.
    article_content = result.get('Article_Content')
    st.title('Article Content Fetched')
    st.text_area("", value=article_content, height=400) # render the article content as textarea element
    st.divider()  # 👈 Draws a horizontal rule
    st.title('Predicted Results')
    # One entry per model: label plus truncated whole-percent probability.
    st.json({
        "Logistic": {
            "predicted_label": result.get("predicted_label_logistic"),
            "probability": result.get("probability_logistic")
        },
        "SVC": {
            "predicted_label": result.get("predicted_label_svm"),
            "probability": result.get("probability_svm")
        },
        "LSTM": {
            "predicted_label": result.get("predicted_label_lstm"),
            "probability": result.get("probability_lstm")
        },
        "BERT": {
            "predicted_label": result.get("predicted_label_bert"),
            "probability": result.get("probability_bert")
        }
    })
    
st.divider()  # 👈 Draws a horizontal rule

# Category labels and corresponding counts in the training set.
# NOTE(review): this rebinds the `categories` name defined earlier (the
# description dict) — harmless here since the dict is no longer used, but
# a distinct name would be clearer.
categories = ["Sport", "Health", "Entertainment", "Politics", "Business"]
counts = [5638, 4547, 2658, 2461, 1362]

# Optional: Add a chart title
st.title("Training Data Category Distribution")

# Optional: Display additional information
st.write("Here's a breakdown of the number of articles in each category:")
for category, count in zip(categories, counts):
  st.write(f"- {category}: {count}")

# Create the bar chart
st.bar_chart(data=dict(zip(categories, counts)))

st.divider()  # 👈 Draws a horizontal rule

# ------------ Copyright Section ------------
# Get the current year so the notice never goes stale.
current_year = date.today().year
# Format the copyright statement with dynamic year
copyright_text = f"Copyright © {current_year}"
st.title(copyright_text)
author_names = ["Trần Thanh Phước (Mentor)", "Lương Ngọc Phương (Member)", "Trịnh Cẩm Minh (Member)"]
st.write("Meet the minds behind the work!")
for author in author_names:
    # Minh's entry links to a personal site; the others render as plain bullets.
    if (author == "Trịnh Cẩm Minh (Member)"): st.markdown("- [Trịnh Cẩm Minh (Member)](https://minhct.netlify.app/)")
    else: st.markdown(f"- {author}\n")  # Use f-string for bullet and newline