# TruthCheck - src/app.py
# Last change: "Update src/app.py (#52)" by adnaan05 (commit 911ad03)
import streamlit as st
import torch
import pandas as pd
import numpy as np
from pathlib import Path
import sys
import plotly.express as px
import plotly.graph_objects as go
from transformers import BertTokenizer
import nltk
# Make sure the NLTK resources used by the text preprocessor are available,
# downloading only the ones that cannot be found locally. Each entry pairs the
# path passed to nltk.data.find with the package name passed to nltk.download.
_NLTK_RESOURCES = (
    ('tokenizers/punkt', 'punkt'),
    ('corpora/stopwords', 'stopwords'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('corpora/wordnet', 'wordnet'),
)
for _resource_path, _package_name in _NLTK_RESOURCES:
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package_name)
del _resource_path, _package_name
# Add the project root (two levels above this file) to the import path so the
# `src.*` imports below resolve. The module may be executed repeatedly (e.g. on
# Streamlit reruns), so only append the entry if it is not already present;
# otherwise sys.path grows by one duplicate per execution.
project_root = Path(__file__).parent.parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
from src.models.hybrid_model import HybridFakeNewsDetector
from src.config.config import *
from src.data.preprocessor import TextPreprocessor
# Page config (st.set_page_config) is set by the main app.py, not in this module.
@st.cache_resource
def load_model_and_tokenizer():
    """Load the trained hybrid model and its BERT tokenizer (cached).

    The model is built from the architecture settings imported from
    ``src.config.config``. Its weights are read from
    ``SAVED_MODELS_DIR / "final_model.pt"`` and mapped onto the CPU.
    ``st.cache_resource`` keeps the returned pair for reuse on later calls.

    Checkpoint entries the model does not define are dropped, and loading is
    non-strict. Dropped entries, and model parameters the checkpoint does not
    provide (which keep their freshly-initialised values), are reported with
    ``warnings.warn`` rather than being ignored silently.

    Returns:
        tuple: ``(model, tokenizer)``, with the model switched to eval mode.
    """
    import warnings

    def _summarize(keys, limit=5):
        # Keep warning messages short: list at most `limit` key names.
        shown = ', '.join(keys[:limit])
        return shown + (', ...' if len(keys) > limit else '')

    # Initialize model
    model = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE,
    )

    # Load trained weights.
    # NOTE(review): torch.load unpickles the file, so only point this at
    # trusted checkpoints (weights_only=True is safer on torch versions that
    # support it).
    checkpoint_path = SAVED_MODELS_DIR / "final_model.pt"
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # Keep only the entries the current architecture has a slot for.
    model_state_dict = model.state_dict()
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}

    skipped_keys = [k for k in state_dict if k not in model_state_dict]
    if skipped_keys:
        warnings.warn(
            f"Ignoring {len(skipped_keys)} key(s) from {checkpoint_path} that the "
            f"model does not define: {_summarize(skipped_keys)}"
        )

    # With strict=False these parameters would silently keep their initial
    # values. That is a likely sign of a wrong or wrapped checkpoint.
    missing_keys = [k for k in model_state_dict if k not in filtered_state_dict]
    if missing_keys:
        warnings.warn(
            f"{len(missing_keys)} model parameter(s) were not found in "
            f"{checkpoint_path} and keep their initial values: "
            f"{_summarize(missing_keys)}"
        )

    # Load the filtered state dict
    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()

    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
    return model, tokenizer
@st.cache_resource
def get_preprocessor():
    """Return a shared ``TextPreprocessor`` instance.

    Thanks to ``st.cache_resource`` the preprocessor is constructed once and
    the same object is handed back on subsequent calls.
    """
    preprocessor = TextPreprocessor()
    return preprocessor
def predict_news(text):
    """Classify one news article as REAL or FAKE.

    The text is cleaned by the cached preprocessor and tokenized with the BERT
    tokenizer (padded/truncated to ``MAX_SEQUENCE_LENGTH``). It is then scored
    by the cached model without gradient tracking.

    Args:
        text: The raw article text entered by the user.

    Returns:
        dict with keys:
            ``'prediction'``: the arg-max class index (1 means FAKE).
            ``'label'``: ``'FAKE'`` if the index is 1, else ``'REAL'``.
            ``'confidence'``: the largest softmax probability.
            ``'probabilities'``: ``{'REAL': p[0], 'FAKE': p[1]}``.
            ``'attention_weights'``: numpy copy of the model's
            ``attention_weights`` output for the single input sequence.
    """
    model, tokenizer = load_model_and_tokenizer()
    preprocessor = get_preprocessor()

    cleaned_text = preprocessor.preprocess_text(text)
    batch = tokenizer.encode_plus(
        cleaned_text,
        add_special_tokens=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    with torch.no_grad():
        outputs = model(batch['input_ids'], batch['attention_mask'])

    logits = outputs['logits']
    probs = torch.softmax(logits, dim=1)
    class_index = torch.argmax(logits, dim=1).item()

    # Exactly one article is scored, so row 0 of each batched tensor is ours.
    first_row = probs[0]
    weights = outputs['attention_weights'][0].cpu().numpy()

    return {
        'prediction': class_index,
        'label': 'FAKE' if class_index == 1 else 'REAL',
        'confidence': probs.max(dim=1).values.item(),
        'probabilities': {
            'REAL': first_row[0].item(),
            'FAKE': first_row[1].item(),
        },
        'attention_weights': weights,
    }
def plot_confidence(probabilities):
    """Build a bar chart of class probabilities.

    Args:
        probabilities: Mapping of class label to probability.

    Returns:
        A plotly Figure with one bar per label, each annotated with its value
        as a percentage, on a y axis fixed to the range [0, 1].
    """
    labels = list(probabilities.keys())
    values = [probabilities[label] for label in labels]
    percent_labels = [f'{value:.2%}' for value in values]

    confidence_bar = go.Bar(
        x=labels,
        y=values,
        text=percent_labels,
        textposition='auto',
    )
    fig = go.Figure(data=[confidence_bar])
    fig.update_layout(
        title='Prediction Confidence',
        xaxis_title='Class',
        yaxis_title='Probability',
        yaxis_range=[0, 1],
    )
    return fig
def plot_attention(text, attention_weights):
    """Plot attention weights as one bar per whitespace-separated word of *text*.

    Args:
        text: Text whose whitespace-split words label the bars.
        attention_weights: Array-like of attention weights. It is flattened to
            1-D before use, so ``(seq,)``, ``(seq, 1)`` and ``(1, seq)`` inputs
            all behave the same.

    Returns:
        A plotly Figure. Words and weights are paired by position and cut to
        the shorter of the two, so the bar count equals that length.

    NOTE(review): ``predict_news`` presumably returns one weight per position
    of the padded BERT input built from the *preprocessed* text (which starts
    with a special token), whereas these labels are raw words of ``text``.
    The pairing is positional only and may not line up with the tokens the
    model actually attended to. TODO confirm against the model's output.
    """
    tokens = text.split()

    # Flatten BEFORE truncating. Slicing a (1, seq) array along axis 0 would
    # otherwise keep every weight and leave x and y with different lengths.
    weights = np.asarray(attention_weights, dtype=float).reshape(-1)

    # Long texts can have more words than weights (the model input is capped),
    # and short texts fewer. Keep the two sequences the same length.
    n = min(len(tokens), weights.size)
    tokens = tokens[:n]
    weights = weights[:n]

    # Format weights for display
    formatted_weights = [f'{w:.2f}' for w in weights]

    # Use bar positions as x values and show the words as tick labels. Plotly
    # treats x values as categories, so repeated words ("the", "a", ...) would
    # otherwise be drawn on top of each other as a single bar.
    positions = list(range(n))
    fig = go.Figure(data=[
        go.Bar(
            x=positions,
            y=weights,
            text=formatted_weights,
            textposition='auto',
            hovertext=tokens,
        )
    ])
    fig.update_layout(
        title='Attention Weights',
        xaxis_title='Tokens',
        yaxis_title='Attention Weight',
        xaxis_tickangle=45,
    )
    fig.update_xaxes(tickmode='array', tickvals=positions, ticktext=tokens)
    return fig
# Page-wide styling injected once per render by main().
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@200;300;400;500;600;700&display=swap');
* {
    font-family: 'Poppins', sans-serif !important;
    box-sizing: border-box;
}
.stApp {
    background: #ffffff;
    min-height: 100vh;
    color: #1f2a44;
}
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stDeployButton {display: none;}
header {visibility: hidden;}
.stApp > header {visibility: hidden;}
/* Main Container */
.main-container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 1rem 2rem;
}
/* Header Section */
.header-section {
    text-align: center;
    margin-bottom: 2.5rem;
    padding: 1.5rem 0;
}
.header-title {
    font-size: 2.25rem;
    font-weight: 700;
    color: #1f2a44;
    margin: 0;
}
/* Section Styling */
.section {
    margin-bottom: 2.5rem;
    max-width: 1200px;
    margin-left: auto;
    margin-right: auto;
    padding: 0 1rem;
}
.section-title {
    font-size: 1.5rem;
    font-weight: 600;
    color: #1f2a44;
    margin-bottom: 1rem;
    display: flex;
    align-items: center;
    gap: 0.5rem;
}
.section-text {
    font-size: 0.95rem;
    color: #6b7280;
    line-height: 1.6;
    max-width: 800px;
    margin: 0 auto;
}
/* Sidebar */
.stSidebar {
    background: #f4f7fa;
    padding: 1rem;
}
/* Input Section */
.input-container {
    max-width: 800px;
    margin: 0 auto;
}
.stTextArea > div > div > textarea {
    border-radius: 8px !important;
    border: 1px solid #d1d5db !important;
    padding: 1rem !important;
    font-size: 1rem !important;
    background: #ffffff !important;
    min-height: 200px !important;
    transition: all 0.2s ease !important;
}
.stTextArea > div > div > textarea:focus {
    border-color: #6366f1 !important;
    box-shadow: 0 0 0 2px rgba(99, 102, 241, 0.1) !important;
    outline: none !important;
}
.stTextArea > div > div > textarea::placeholder {
    color: #9ca3af !important;
}
/* Button Styling */
.stButton > button {
    background: #6366f1 !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 0.75rem 2rem !important;
    font-size: 1rem !important;
    font-weight: 600 !important;
    transition: all 0.2s ease !important;
    border: none !important;
    width: 100% !important;
    max-width: 300px;
}
.stButton > button:hover {
    background: #4f46e5 !important;
    transform: translateY(-1px) !important;
}
/* Results Section */
.results-container {
    margin-top: 1rem;
    padding: 1rem;
    border-radius: 8px;
    max-width: 1200px;
    margin-left: auto;
    margin-right: auto;
}
.result-card {
    padding: 1rem;
    border-radius: 8px;
    border-left: 4px solid transparent;
    margin-bottom: 1rem;
}
.fake-news {
    background: #fef2f2;
    border-left-color: #ef4444;
}
.real-news {
    background: #ecfdf5;
    border-left-color: #10b981;
}
.prediction-badge {
    font-weight: 600;
    font-size: 1rem;
    margin-bottom: 0.5rem;
    display: flex;
    align-items: center;
    gap: 0.5rem;
}
.confidence-score {
    font-weight: 600;
    margin-left: auto;
    font-size: 1rem;
}
/* Chart Containers */
.chart-container {
    padding: 1rem;
    border-radius: 8px;
    margin: 1rem 0;
    max-width: 1200px;
    margin-left: auto;
    margin-right: auto;
}
/* Footer */
.footer {
    border-top: 1px solid #e5e7eb;
    padding: 1.5rem 0;
    text-align: center;
    max-width: 1200px;
    margin: 2rem auto 0;
}
/* Responsive Design */
@media (max-width: 1024px) {
    .main-container {
        padding: 1rem;
    }
    .section {
        padding: 0 0.5rem;
    }
}
@media (max-width: 768px) {
    .header-title {
        font-size: 1.75rem;
    }
    .section-title {
        font-size: 1.25rem;
    }
    .section-text {
        font-size: 0.9rem;
    }
}
@media (max-width: 480px) {
    .header-title {
        font-size: 1.5rem;
    }
    .section-title {
        font-size: 1.1rem;
    }
    .section-text {
        font-size: 0.85rem;
    }
}
</style>
"""


def _inject_styles():
    """Inject the page-wide custom CSS (Poppins font, layout, result cards)."""
    st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)


def _render_input():
    """Render the article text area and return its current contents."""
    st.markdown('<div class="input-container">', unsafe_allow_html=True)
    news_text = st.text_area(
        "Enter the news article to analyze:",
        height=200,
        placeholder="Paste your news article here...",
    )
    st.markdown('</div>', unsafe_allow_html=True)
    return news_text


def _render_results(news_text, result):
    """Render the prediction card, confidence chart, attention chart and explanation.

    Args:
        news_text: The article text as entered (passed to ``plot_attention``).
        result: The dict returned by ``predict_news``.
    """
    col1, col2 = st.columns(2)
    with col1:
        st.markdown('<div class="results-container">', unsafe_allow_html=True)
        st.markdown('<h3 class="section-title">🔍 Prediction</h3>', unsafe_allow_html=True)
        if result['label'] == 'FAKE':
            card_class, badge = 'fake-news', '🚨 Fake News Detected'
        else:
            card_class, badge = 'real-news', '✅ Authentic News'
        st.markdown(
            f'<div class="result-card {card_class}"><div class="prediction-badge">{badge} '
            f'<span class="confidence-score">{result["confidence"]:.2%}</span></div></div>',
            unsafe_allow_html=True,
        )
        st.markdown('</div>', unsafe_allow_html=True)
    with col2:
        st.markdown('<div class="results-container">', unsafe_allow_html=True)
        st.markdown('<h3 class="section-title">📊 Confidence Scores</h3>', unsafe_allow_html=True)
        st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)
        st.markdown('</div>', unsafe_allow_html=True)

    # Show attention visualization
    st.markdown('<div class="section">', unsafe_allow_html=True)
    st.markdown('<h3 class="section-title">👁️ Attention Analysis</h3>', unsafe_allow_html=True)
    st.markdown('<p class="section-text">The attention weights show which parts of the text the model focused on while making its prediction. Higher weights indicate more important tokens.</p>', unsafe_allow_html=True)
    st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)
    st.markdown('</div>', unsafe_allow_html=True)

    # Show model explanation. A <ul> may not be nested in a <p>: the browser
    # closes the <p> at the list and the list loses the section-text styling.
    # A <div> carries the same class without that problem.
    if result['label'] == 'FAKE':
        explanation = (
            'The model identified this as fake news based on:'
            '<ul><li>Linguistic patterns typical of fake news</li>'
            '<li>Inconsistencies in the content</li>'
            '<li>Attention weights on suspicious phrases</li></ul>'
        )
    else:
        explanation = (
            'The model identified this as real news based on:'
            '<ul><li>Credible language patterns</li>'
            '<li>Consistent information</li>'
            '<li>Attention weights on factual statements</li></ul>'
        )
    st.markdown('<div class="section">', unsafe_allow_html=True)
    st.markdown('<h3 class="section-title">📝 Model Explanation</h3>', unsafe_allow_html=True)
    st.markdown(f'<div class="section-text">{explanation}</div>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)


def _render_footer():
    """Render the page footer."""
    st.markdown(
        '<div class="footer"><p style="text-align: center; font-weight: 600; font-size: 16px;">💻 Developed with ❀️ using Streamlit | © 2025</p></div>',
        unsafe_allow_html=True,
    )


def main():
    """Render the TruthCheck page.

    The page shows the custom styles, a header and introduction, and the
    article input. After "Analyze" is clicked with non-blank text, it also
    shows the prediction, confidence chart, attention chart and explanation.
    Blank or whitespace-only input shows a warning instead.
    """
    # NOTE(review): Streamlit renders each st.markdown call as its own element,
    # so the separate open/close <div> calls below most likely do not wrap the
    # widgets rendered between them. Confirm before relying on those classes.
    st.markdown('<div class="main-container">', unsafe_allow_html=True)
    _inject_styles()

    # Header Section
    st.markdown('<h1 class="header-title">📰 TruthCheck - Advanced Fake News Detection System</h1>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)

    # Main Content
    st.markdown('<div class="section">', unsafe_allow_html=True)
    st.markdown('<p class="section-text">This application uses a hybrid deep learning model (BERT + BiLSTM + Attention) to detect fake news articles. Enter a news article below to analyze it.</p>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)

    # News Analysis Section
    st.markdown('<div class="section">', unsafe_allow_html=True)
    st.markdown('<h2 class="section-title">📋 News Analysis</h2>', unsafe_allow_html=True)
    news_text = _render_input()

    if st.button("Analyze"):
        # Whitespace-only input has nothing to classify; treat it as empty
        # rather than sending it to the model.
        if news_text and news_text.strip():
            with st.spinner("Analyzing the news article..."):
                result = predict_news(news_text)
            _render_results(news_text, result)
        else:
            st.warning("Please enter a news article to analyze.")

    _render_footer()
    st.markdown('</div>', unsafe_allow_html=True)  # Close main-container
# Render the page when this file is executed as the main script rather than
# imported (the comment above load_model_and_tokenizer says a separate main
# app.py sets the page config).
if __name__ == "__main__":
    main()