import joblib
import streamlit as st
import json
import requests
from bs4 import BeautifulSoup
from datetime import date
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
import pandas as pd
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
# Load all models and the shared vectorizer (global vocabulary)
Seq_model = load_model("LSTM.h5")  # LSTM (Keras Sequential)
SVM_model = joblib.load("SVM_Linear_Kernel.joblib")  # SVM (linear kernel)
logistic_model = joblib.load("Logistic_Model.joblib")  # Logistic Regression
svm_model = joblib.load('svm_model.joblib')  # calibration model: maps SVM decision scores to probabilities
vectorizer = joblib.load("vectorizer.joblib")  # global vocabulary (used for Logistic Regression and SVC)
tokenizer = joblib.load("tokenizer.joblib")  # Keras tokenizer (used for LSTM)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer1 = DistilBertTokenizer.from_pretrained("tokenizer_bert")  # DistilBERT tokenizer (local files)
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=5)
model.load_state_dict(torch.load("fine_tuned_bert_model1.pth", map_location=device))  # fine-tuned weights
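# Note: all artifacts above (LSTM.h5, the .joblib files, tokenizer_bert/ and
# fine_tuned_bert_model1.pth) are loaded via relative paths, so they are assumed
# to sit alongside this script in the Space repository.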
# Decode an encoded label back to its category name
# {'business': 0, 'entertainment': 1, 'health': 2, 'politics': 3, 'sport': 4}
def decodedLabel(input_number):
    print('received encoded label:', input_number)
categories = {
0: 'Business',
1: 'Entertainment',
2: 'Health',
3: 'Politics',
4: 'Sport'
}
result = categories.get(input_number) # Ex: Health
print('decoded result', result)
return result
# Web crawler: fetch a CNN article page and extract its body text
def crawURL(url):
    # Crawl the page and extract the data we need
    try:
        print(f"Crawling page: {url}")
        # Fetch the page content
        page_response = requests.get(url)
        page_response.raise_for_status()
        # Parse the page content with BeautifulSoup
        soup = BeautifulSoup(page_response.content, 'html.parser')
# Extract data you need from the page
author = soup.find("meta", {"name": "author"}).attrs['content'].strip()
date_published = soup.find("meta", {"property": "article:published_time"}).attrs['content'].strip()
article_section = soup.find("meta", {"name": "meta-section"}).attrs['content']
url = soup.find("meta", {"property": "og:url"}).attrs['content']
headline = soup.find("h1", {"data-editable": "headlineText"}).text.strip()
description = soup.find("meta", {"name": "description"}).attrs['content'].strip()
keywords = soup.find("meta", {"name": "keywords"}).attrs['content'].strip()
text = soup.find(itemprop="articleBody")
# Find all <p> tags with class "paragraph inline-placeholder"
paragraphs = text.find_all('p', class_="paragraph inline-placeholder")
# Initialize an empty list to store the text content of each paragraph
paragraph_texts = []
# Iterate over each <p> tag and extract its text content
for paragraph in paragraphs:
paragraph_texts.append(paragraph.text.strip())
            # Join the paragraphs into a single string, separated by spaces
            full_text = ' '.join(paragraph_texts)
return full_text
except Exception as e:
print(f"Failed to crawl page: {url}, Error: {str(e)}")
return None
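# A minimal sketch of calling crawURL on its own (URL taken from the URL_Example
# list below); it returns the article body text, or None if any selector fails:
#
#   body = crawURL("https://edition.cnn.com/2012/01/31/health/frank-njenga-mental-health/index.html")
#   print(body[:200] if body else "crawl failed")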
# Predict the text's category with each model
def process_api(text):
    # Sparse document-term features (shared vectorizer) for Logistic Regression and SVC
    processed_text = vectorizer.transform([text])
    # Token-id sequence, padded to a fixed length of 1000, for the LSTM model
    sequence = tokenizer.texts_to_sequences([text])
    padded_sequence = pad_sequences(sequence, maxlen=1000, padding='post')
    # DistilBERT encoding (input ids + attention mask)
    new_encoding = tokenizer1([text], truncation=True, padding=True, return_tensors="pt")
input_ids = new_encoding['input_ids']
attention_mask = new_encoding['attention_mask']
with torch.no_grad():
output = model(input_ids, attention_mask=attention_mask)
logits = output.logits
# Get the predicted result from models
Logistic_Predicted = logistic_model.predict(processed_text).tolist() # Logistic Model
SVM_Predicted = SVM_model.predict(processed_text).tolist() # SVC Model
Seq_Predicted = Seq_model.predict(padded_sequence)
predicted_label_index = np.argmax(Seq_Predicted)
    # ----------- Probabilities -----------
    Logistic_Predicted_proba = logistic_model.predict_proba(processed_text)
    # The calibration model (svm_model) turns SVM decision scores into class probabilities
    svm_new_probs = SVM_model.decision_function(processed_text)
    svm_probs = svm_model.predict_proba(svm_new_probs)
    bert_probabilities = torch.softmax(logits, dim=1)
    max_probability = torch.max(bert_probabilities).item()
    predicted_label_bert = torch.argmax(logits, dim=1).item()
    # ----------- Debug logs -----------
    logistic_debug = decodedLabel(int(Logistic_Predicted[0]))
    svc_debug = decodedLabel(int(SVM_Predicted[0]))
    print('Logistic', int(Logistic_Predicted[0]), logistic_debug)
    print('SVM', int(SVM_Predicted[0]), svc_debug)
    return {
        'predicted_label_logistic': decodedLabel(int(Logistic_Predicted[0])),
        'probability_logistic': f"{int(np.max(Logistic_Predicted_proba) * 100)}%",
        'predicted_label_svm': decodedLabel(int(SVM_Predicted[0])),
        'probability_svm': f"{int(np.max(svm_probs) * 100)}%",
        'predicted_label_lstm': decodedLabel(int(predicted_label_index)),
        'probability_lstm': f"{int(np.max(Seq_Predicted) * 100)}%",
        'predicted_label_bert': decodedLabel(int(predicted_label_bert)),
        'probability_bert': f"{int(max_probability * 100)}%",
        'Article_Content': text
    }
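# A minimal sketch (assuming the model files above are present) of calling
# process_api directly, e.g. as a quick smoke test outside the Streamlit UI:
#
#   sample_text = "The central bank raised interest rates to curb inflation."
#   out = process_api(sample_text)
#   print(out['predicted_label_logistic'], out['probability_logistic'])
#   print(out['predicted_label_bert'], out['probability_bert'])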
# Crawl the article, run the models on its content, and return the result as a dict
def categorize(url):
try:
article_content = crawURL(url)
result = process_api(article_content)
return result
    except Exception as error:
        # Python 3 exceptions have no .message attribute; return the error as a string
        return {"error_message": str(error)}
# Main App
st.title('Instant Category Classification')
st.write("Unsure what category a CNN article belongs to? Our clever tool can help! Paste the URL below and press Enter. We'll sort it into one of our 5 categories in a flash! ⚡️")
# Define category information (modify content and bullet points as needed)
categories = {
"Business": [
"Analyze market trends and investment opportunities.",
"Gain insights into company performance and industry news.",
"Stay informed about economic developments and regulations."
],
"Health": [
"Discover healthy recipes and exercise tips.",
"Learn about the latest medical research and advancements.",
"Find resources for managing chronic conditions and improving well-being."
],
"Sport": [
"Follow your favorite sports teams and athletes.",
"Explore news and analysis from various sports categories.",
"Stay updated on upcoming games and competitions."
],
"Politics": [
"Get informed about current political events and policies.",
"Understand different perspectives on political issues.",
"Engage in discussions and debates about politics."
],
"Entertainment": [
"Find recommendations for movies, TV shows, and music.",
"Explore reviews and insights from entertainment critics.",
"Stay updated on celebrity news and cultural trends."
]
}
# Define model information (modify descriptions as needed)
models = {
"Logistic Regression": "A widely used statistical method for classification problems. It excels at identifying linear relationships between features and the target variable.",
"SVC (Support Vector Classifier)": "A powerful machine learning model that seeks to find a hyperplane that best separates data points of different classes. It's effective for high-dimensional data and can handle some non-linear relationships.",
"LSTM (Long Short-Term Memory)": "A type of recurrent neural network (RNN) particularly well-suited for sequential data like text or time series. LSTMs can effectively capture long-term dependencies within the data.",
"BERT (Bidirectional Encoder Representations from Transformers)": "A powerful pre-trained model based on the Transformer architecture. It excels at understanding the nuances of language and can be fine-tuned for various NLP tasks like text classification."
}
# CNN URL Example List
URL_Example = [
'https://edition.cnn.com/2012/01/31/health/frank-njenga-mental-health/index.html',
'https://edition.cnn.com/2024/04/30/entertainment/barbra-streisand-melissa-mccarthy-ozempic/index.html',
'https://edition.cnn.com/2024/04/30/sport/lebron-james-lakers-future-nba-spt-intl/index.html',
'https://edition.cnn.com/2024/04/30/business/us-home-prices-rose-in-february/index.html'
]
# Create an expander listing the categories that can be classified
with st.expander("Category List"):
# Title for each category
st.subheader("Available Categories:")
for category in categories.keys():
st.write(f"- {category}")
# Content for each category (separated by a horizontal line)
st.write("---")
for category, content in categories.items():
st.subheader(category)
for item in content:
st.write(f"- {item}")
# Create an expander listing the models used in this project
with st.expander("Available Models"):
st.subheader("List of Models:")
for model_name in models.keys():
st.write(f"- {model_name}")
st.write("---")
for model_name, description in models.items():
st.subheader(model_name)
st.write(description)
with st.expander("URLs Example"):
for url in URL_Example:
st.write(f"- {url}")
# Explain to the user why this project only works for the CNN domain
with st.expander("Tips", expanded=True):
st.write(
'''
This project works best with CNN articles right now.
Our web crawler is like a special tool for CNN's website.
        It can't quite understand other websites because they're built differently.
'''
)
st.divider() # 👈 Draws a horizontal rule
st.title('Dive in! See what category your CNN story belongs to 😉.')
# Paste URL Input
url = st.text_input("Find your favorite CNN story! Paste the URL and press ENTER 🔍.", placeholder='Ex: https://edition.cnn.com/2012/01/31/health/frank-njenga-mental-health/index.html')
if url:
st.divider() # 👈 Draws a horizontal rule
    result = categorize(url)
    # Surface crawler/model errors instead of rendering an empty page
    if "error_message" in result:
        st.error(result["error_message"])
        st.stop()
    article_content = result.get('Article_Content')
st.title('Article Content Fetched')
st.text_area("", value=article_content, height=400) # render the article content as textarea element
st.divider() # 👈 Draws a horizontal rule
st.title('Predicted Results')
st.json({
"Logistic": {
"predicted_label": result.get("predicted_label_logistic"),
"probability": result.get("probability_logistic")
},
"SVC": {
"predicted_label": result.get("predicted_label_svm"),
"probability": result.get("probability_svm")
},
"LSTM": {
"predicted_label": result.get("predicted_label_lstm"),
"probability": result.get("probability_lstm")
},
"BERT": {
"predicted_label": result.get("predicted_label_bert"),
"probability": result.get("probability_bert")
}
})
st.divider() # 👈 Draws a horizontal rule
# Category labels and corresponding counts
categories = ["Sport", "Health", "Entertainment", "Politics", "Business"]
counts = [5638, 4547, 2658, 2461, 1362]
# Optional: Add a chart title
st.title("Training Data Category Distribution")
# Optional: Display additional information
st.write("Here's a breakdown of the number of articles in each category:")
for category, count in zip(categories, counts):
st.write(f"- {category}: {count}")
# Create the bar chart (one bar per category)
st.bar_chart(data=pd.DataFrame({"Articles": counts}, index=categories))
st.divider() # 👈 Draws a horizontal rule
# ------------ Copyright Section ------------
# Get the current year
current_year = date.today().year
# Format the copyright statement with dynamic year
copyright_text = f"Copyright © {current_year}"
st.title(copyright_text)
author_names = ["Trần Thanh Phước (Mentor)", "Lương Ngọc Phương (Member)", "Trịnh Cẩm Minh (Member)"]
st.write("Meet the minds behind the work!")
for author in author_names:
    if author == "Trịnh Cẩm Minh (Member)":
        # Linked entry for the member with a personal site
        st.markdown("- [Trịnh Cẩm Minh (Member)](https://minhct.netlify.app/)")
    else:
        st.markdown(f"- {author}")