from datetime import date

import joblib
import numpy as np
import pandas as pd
import requests
import streamlit as st
import torch
from bs4 import BeautifulSoup
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Load the trained classifiers and their preprocessing artifacts.
Seq_model = load_model("LSTM.h5")                      # Keras LSTM classifier
SVM_model = joblib.load("SVM_Linear_Kernel.joblib")    # linear-kernel SVM
logistic_model = joblib.load("Logistic_Model.joblib")  # logistic regression
svm_model = joblib.load("svm_model.joblib")            # maps SVM margins to probabilities (see process_api)

vectorizer = joblib.load("vectorizer.joblib")          # text vectorizer for the scikit-learn models
tokenizer = joblib.load("tokenizer.joblib")            # Keras tokenizer for the LSTM

# Load the fine-tuned DistilBERT classifier.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer1 = DistilBertTokenizer.from_pretrained("tokenizer_bert")
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=5)
model.load_state_dict(torch.load("fine_tuned_bert_model1.pth", map_location=device))
model.to(device)
model.eval()  # disable dropout so inference is deterministic

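# Streamlit reruns this script top to bottom on every interaction, so the
# loads above repeat on each rerun. A minimal sketch of loading once per
# process (assuming Streamlit >= 1.18, where st.cache_resource is available;
# load_artifacts is a hypothetical helper, shown here for two of the files):
#
# @st.cache_resource
# def load_artifacts():
#     return load_model("LSTM.h5"), joblib.load("vectorizer.joblib")
#
# Seq_model, vectorizer = load_artifacts()
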
def decodedLabel(input_number):
    """Map a numeric label index back to its category name, e.g. 3 -> 'Politics'."""
    print('receive label encoded', input_number)
    categories = {
        0: 'Business',
        1: 'Entertainment',
        2: 'Health',
        3: 'Politics',
        4: 'Sport'
    }
    result = categories.get(input_number)
    print('decoded result', result)
    return result

def crawURL(url):
    """Fetch a CNN article page and return its body text, or None on failure."""
    try:
        print(f"Crawling page: {url}")
        page_response = requests.get(url, timeout=10)
        soup = BeautifulSoup(page_response.content, 'html.parser')

        # Article metadata; a missing tag raises AttributeError, which the
        # except block below reports.
        author = soup.find("meta", {"name": "author"}).attrs['content'].strip()
        date_published = soup.find("meta", {"property": "article:published_time"}).attrs['content'].strip()
        article_section = soup.find("meta", {"name": "meta-section"}).attrs['content']
        url = soup.find("meta", {"property": "og:url"}).attrs['content']
        headline = soup.find("h1", {"data-editable": "headlineText"}).text.strip()
        description = soup.find("meta", {"name": "description"}).attrs['content'].strip()
        keywords = soup.find("meta", {"name": "keywords"}).attrs['content'].strip()

        # Collect the article body paragraphs.
        text = soup.find(itemprop="articleBody")
        paragraphs = text.find_all('p', class_="paragraph inline-placeholder")
        paragraph_texts = [paragraph.text.strip() for paragraph in paragraphs]

        # Join with spaces so adjacent paragraphs don't run together.
        full_text = ' '.join(paragraph_texts)
        return full_text

    except Exception as e:
        print(f"Failed to crawl page: {url}, Error: {str(e)}")
        return None

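# The selectors in crawURL target CNN's current markup. As a hedged fallback
# (the class-name change is an assumption, not something the original app
# handles), the paragraph lookup inside crawURL could degrade to plain <p>
# tags when the specific class is missing:
#
# if not paragraphs:
#     paragraphs = text.find_all('p')
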
def process_api(text):
    """Run the article text through all four models and collect their predictions."""
    # Prepare features for each model family.
    processed_text = vectorizer.transform([text])  # sparse features for the scikit-learn models
    sequence = tokenizer.texts_to_sequences([text])
    padded_sequence = pad_sequences(sequence, maxlen=1000, padding='post')  # fixed-length input for the LSTM

    # DistilBERT forward pass.
    new_encoding = tokenizer1([text], truncation=True, padding=True, return_tensors="pt")
    input_ids = new_encoding['input_ids'].to(device)
    attention_mask = new_encoding['attention_mask'].to(device)
    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)
        logits = output.logits

    # Predicted labels.
    Logistic_Predicted = logistic_model.predict(processed_text).tolist()
    SVM_Predicted = SVM_model.predict(processed_text).tolist()
    Seq_Predicted = Seq_model.predict(padded_sequence)
    predicted_label_index = np.argmax(Seq_Predicted)

    # Class probabilities. The linear SVM exposes no predict_proba, so its
    # decision_function margins go through the separate svm_model calibrator.
    Logistic_Predicted_proba = logistic_model.predict_proba(processed_text)
    svm_new_probs = SVM_model.decision_function(processed_text)
    svm_probs = svm_model.predict_proba(svm_new_probs)

    bert_probabilities = torch.softmax(logits, dim=1)
    max_probability = torch.max(bert_probabilities).item()
    predicted_label_bert = torch.argmax(logits, dim=1).item()

    return {
        'predicted_label_logistic': decodedLabel(int(Logistic_Predicted[0])),
        'probability_logistic': f"{int(np.max(Logistic_Predicted_proba) * 100)}%",

        'predicted_label_svm': decodedLabel(int(SVM_Predicted[0])),
        'probability_svm': f"{int(np.max(svm_probs) * 100)}%",

        'predicted_label_lstm': decodedLabel(int(predicted_label_index)),
        'probability_lstm': f"{int(np.max(Seq_Predicted) * 100)}%",

        'predicted_label_bert': decodedLabel(int(predicted_label_bert)),
        'probability_bert': f"{int(max_probability * 100)}%",

        'Article_Content': text
    }

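# svm_model above acts as a calibrator: the linear SVM has no predict_proba,
# so its decision_function margins are mapped to probabilities by a second
# estimator. A hedged sketch of how such a calibrator could be fit offline
# (X_train_features and y_train are hypothetical names; the actual training
# code is not part of this app):
#
# from sklearn.linear_model import LogisticRegression
# margins = SVM_model.decision_function(X_train_features)
# calibrator = LogisticRegression(max_iter=1000).fit(margins, y_train)
# joblib.dump(calibrator, "svm_model.joblib")
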
def categorize(url):
    """Crawl the URL and classify its content; return an error payload on failure."""
    try:
        article_content = crawURL(url)
        if article_content is None:
            return {"error_message": "Could not fetch article content from this URL."}
        result = process_api(article_content)
        return result
    except Exception as error:
        # Python 3 exceptions have no .message attribute; use str() instead.
        return {"error_message": str(error)}

st.title('Instant Category Classification')
st.write("Unsure what category a CNN article belongs to? Our clever tool can help! Paste the URL below and press Enter. We'll sort it into one of our 5 categories in a flash! ⚡️")

categories = {
    "Business": [
        "Analyze market trends and investment opportunities.",
        "Gain insights into company performance and industry news.",
        "Stay informed about economic developments and regulations."
    ],
    "Health": [
        "Discover healthy recipes and exercise tips.",
        "Learn about the latest medical research and advancements.",
        "Find resources for managing chronic conditions and improving well-being."
    ],
    "Sport": [
        "Follow your favorite sports teams and athletes.",
        "Explore news and analysis from various sports categories.",
        "Stay updated on upcoming games and competitions."
    ],
    "Politics": [
        "Get informed about current political events and policies.",
        "Understand different perspectives on political issues.",
        "Engage in discussions and debates about politics."
    ],
    "Entertainment": [
        "Find recommendations for movies, TV shows, and music.",
        "Explore reviews and insights from entertainment critics.",
        "Stay updated on celebrity news and cultural trends."
    ]
}

models = {
    "Logistic Regression": "A widely used statistical method for classification problems. It excels at identifying linear relationships between features and the target variable.",
    "SVC (Support Vector Classifier)": "A powerful machine learning model that seeks to find a hyperplane that best separates data points of different classes. It's effective for high-dimensional data and can handle some non-linear relationships.",
    "LSTM (Long Short-Term Memory)": "A type of recurrent neural network (RNN) particularly well-suited for sequential data like text or time series. LSTMs can effectively capture long-term dependencies within the data.",
    "BERT (Bidirectional Encoder Representations from Transformers)": "A powerful pre-trained model based on the Transformer architecture. It excels at understanding the nuances of language and can be fine-tuned for various NLP tasks like text classification."
}

URL_Example = [
    'https://edition.cnn.com/2012/01/31/health/frank-njenga-mental-health/index.html',
    'https://edition.cnn.com/2024/04/30/entertainment/barbra-streisand-melissa-mccarthy-ozempic/index.html',
    'https://edition.cnn.com/2024/04/30/sport/lebron-james-lakers-future-nba-spt-intl/index.html',
    'https://edition.cnn.com/2024/04/30/business/us-home-prices-rose-in-february/index.html'
]

with st.expander("Category List"):
    st.subheader("Available Categories:")
    for category in categories.keys():
        st.write(f"- {category}")

    st.write("---")
    for category, content in categories.items():
        st.subheader(category)
        for item in content:
            st.write(f"- {item}")

with st.expander("Available Models"):
    st.subheader("List of Models:")
    for model_name in models.keys():
        st.write(f"- {model_name}")
    st.write("---")
    for model_name, description in models.items():
        st.subheader(model_name)
        st.write(description)

with st.expander("URLs Example"):
    for url in URL_Example:
        st.write(f"- {url}")

with st.expander("Tips", expanded=True):
    st.write(
        '''
        This project works best with CNN articles right now.
        Our web crawler is tailored to CNN's website, so it can't quite
        understand other websites because they're built differently.
        '''
    )

st.divider()

st.title('Dive in! See what category your CNN story belongs to 😉.')

url = st.text_input("Find your favorite CNN story! Paste the URL and press ENTER 🔍.", placeholder='Ex: https://edition.cnn.com/2012/01/31/health/frank-njenga-mental-health/index.html')

if url:
    st.divider()
    result = categorize(url)
    if result.get("error_message"):
        # Surface crawl/classification failures instead of rendering empty results.
        st.error(result["error_message"])
    else:
        article_content = result.get('Article_Content')
        st.title('Article Content Fetched')
        st.text_area("Article Content", value=article_content, height=400, label_visibility="collapsed")
        st.divider()
        st.title('Predicted Results')
        st.json({
            "Logistic": {
                "predicted_label": result.get("predicted_label_logistic"),
                "probability": result.get("probability_logistic")
            },
            "SVC": {
                "predicted_label": result.get("predicted_label_svm"),
                "probability": result.get("probability_svm")
            },
            "LSTM": {
                "predicted_label": result.get("predicted_label_lstm"),
                "probability": result.get("probability_lstm")
            },
            "BERT": {
                "predicted_label": result.get("predicted_label_bert"),
                "probability": result.get("probability_bert")
            }
        })

    st.divider()

categories = ["Sport", "Health", "Entertainment", "Politics", "Business"]
counts = [5638, 4547, 2658, 2461, 1362]

st.title("Training Data Category Distribution")

st.write("Here's a breakdown of the number of articles in each category:")
for category, count in zip(categories, counts):
    st.write(f"- {category}: {count}")

# A dict of scalars doesn't convert cleanly to a DataFrame, so give
# st.bar_chart an explicit DataFrame indexed by category.
st.bar_chart(pd.DataFrame({"Articles": counts}, index=categories))

st.divider()

current_year = date.today().year

copyright_text = f"Copyright © {current_year}"
st.title(copyright_text)
author_names = ["Trần Thanh Phước (Mentor)", "Lương Ngọc Phương (Member)", "Trịnh Cẩm Minh (Member)"]
st.write("Meet the minds behind the work!")
for author in author_names:
    if author == "Trịnh Cẩm Minh (Member)":
        st.markdown("- [Trịnh Cẩm Minh (Member)](https://minhct.netlify.app/)")
    else:
        st.markdown(f"- {author}")