FactChecker / app.py
Saty01's picture
FactChecker files(final_deploy)
844d2b9 verified
from flask import Flask, request, jsonify, send_from_directory
import pickle
import torch
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from transformers import DistilBertTokenizer, DistilBertModel
import torch.nn as nn
import os
import numpy
# Register the shared NLTK data directory, then fetch the tokenizer models,
# stop-word list and WordNet data used by the text-cleaning step.
nltk.data.path.append('/usr/local/share/nltk_data')
for _nltk_resource in ('punkt_tab', 'stopwords', 'wordnet', 'punkt'):
    nltk.download(_nltk_resource)

# Flask application; files under 'build' are served as static assets from the
# URL root (empty static_url_path).
app = Flask(__name__, static_folder='build', static_url_path='')
class DistilBERTClassifier(nn.Module):
    """Two-class text classifier: DistilBERT encoder plus a linear head.

    The submodule attribute names (``distilbert``, ``dropout``,
    ``classifier``) determine the state-dict keys, so they must stay as they
    are for previously saved weights to load.
    """

    def __init__(self, dropout_rate=0.2):
        """Create the encoder, the dropout layer and the 768 -> 2 linear head.

        Args:
            dropout_rate: Probability passed to ``nn.Dropout``.
        """
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask):
        """Compute raw (pre-softmax) logits for a batch of encoded texts.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            The output of the linear head, whose last dimension has size 2.
        """
        encoder_out = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 as the summary vector
        # (presumably the [CLS] token for this tokenizer).
        first_token_state = encoder_out.last_hidden_state[:, 0]
        return self.classifier(self.dropout(first_token_state))
def clean_text(text):
    """Normalise raw text into a space-joined string of lemmatized tokens.

    Steps, in order: lower-case; remove URLs, ``<...>`` tags, punctuation and
    digit runs; tokenize with NLTK; drop English stop words; lemmatize each
    remaining token with WordNet.

    Args:
        text: The raw input string.

    Returns:
        The cleaned tokens joined by single spaces.
    """
    normalized = text.lower()
    # Each pattern's matches are deleted outright (replaced with '').
    for pattern in (r'http\S+|www\S+|https\S+', r'<.*?>', r'[^\w\s]', r'\d+'):
        normalized = re.sub(pattern, '', normalized)

    raw_tokens = nltk.word_tokenize(normalized)
    english_stop_words = set(stopwords.words('english'))
    lemmatizer = WordNetLemmatizer()

    kept = []
    for token in raw_tokens:
        if token in english_stop_words:
            continue
        kept.append(lemmatizer.lemmatize(token))
    return ' '.join(kept)
def _load_pickle(path):
    """Unpickle and return the object stored in the file at *path*."""
    # NOTE: unpickling can execute arbitrary code. Only load model files that
    # come from a trusted source.
    with open(path, 'rb') as f:
        return pickle.load(f)


def load_models(model_dir='models'):
    """Load every model artifact needed for inference.

    Args:
        model_dir: Directory containing ``tfidf_vectorizer.pkl``,
            ``lr_model.pkl``, ``rf_model.pkl`` and ``distilbert_model.pt``.
            Defaults to ``'models'``, resolved against the current working
            directory.

    Returns:
        A tuple ``(tfidf_vectorizer, lr_model, rf_model, distilbert_model,
        tokenizer, device)``. The DistilBERT model has been moved to
        ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If one of the artifact files does not exist.
    """
    # Prefer the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pickled TF-IDF vectorizer and the two classifiers that consume its output.
    tfidf_vectorizer = _load_pickle(os.path.join(model_dir, 'tfidf_vectorizer.pkl'))
    lr_model = _load_pickle(os.path.join(model_dir, 'lr_model.pkl'))
    rf_model = _load_pickle(os.path.join(model_dir, 'rf_model.pkl'))

    # DistilBERT: pretrained tokenizer plus the classifier weights saved as a
    # state dict.
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    distilbert_model = DistilBERTClassifier()
    state_dict = torch.load(
        os.path.join(model_dir, 'distilbert_model.pt'),
        map_location=device,
    )
    distilbert_model.load_state_dict(state_dict)
    distilbert_model.to(device)
    distilbert_model.eval()  # disable dropout for inference

    return tfidf_vectorizer, lr_model, rf_model, distilbert_model, tokenizer, device
# Load all models once, at import time; the request handlers read these
# module-level names.
(
    tfidf_vectorizer,
    lr_model,
    rf_model,
    distilbert_model,
    tokenizer,
    device,
) = load_models()
@app.route('/')
def serve():
    """Return ``index.html`` from the application's static folder."""
    static_root = app.static_folder
    return send_from_directory(static_root, 'index.html')
# Values accepted for the "model" field of an /api/analyze request.
_MODEL_OPTIONS = ("lr", "rf", "distilbert", "all")


def _classical_result(model, features):
    """Run a predict/predict_proba classifier and format its verdict.

    A predicted label of 1 is reported as "Real", anything else as "Fake".
    Probability index 0 is reported as ``fake_prob`` and index 1 as
    ``real_prob``.
    """
    label = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]
    return {
        "prediction": "Real" if label == 1 else "Fake",
        "fake_prob": float(probabilities[0]),
        "real_prob": float(probabilities[1]),
    }


def _distilbert_result(cleaned_text):
    """Run the DistilBERT classifier on *cleaned_text* and format its verdict."""
    encoding = tokenizer(
        cleaned_text,
        truncation=True,
        padding='max_length',
        max_length=128,
        return_tensors='pt',
    )
    with torch.no_grad():
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        logits = distilbert_model(input_ids=input_ids, attention_mask=attention_mask)
        print("Raw model output:", logits.cpu().numpy())
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
    print("After softmax:", probs)
    print(f"Text: {cleaned_text[:50]}...")
    # Index 0 is the "Fake" class and index 1 the "Real" class, matching the
    # fake_prob/real_prob fields below.
    print(f"Probabilities: Fake={probs[0]:.4f}, Real={probs[1]:.4f}")
    return {
        "prediction": "Real" if probs[1] > probs[0] else "Fake",
        "fake_prob": float(probs[0]),
        "real_prob": float(probs[1]),
    }


@app.route('/api/analyze', methods=['POST'])
def analyze():
    """Classify a news text as "Real" or "Fake" with the requested model(s).

    Expects a JSON object with two fields:
        text:  the news text to classify (non-blank string).
        model: one of "lr", "rf", "distilbert" or "all".

    Returns:
        ``{"results": {...}}`` with one entry per model that ran, each holding
        ``prediction``, ``fake_prob`` and ``real_prob``. With ``"all"`` an
        extra ``"Overall"`` entry carries the majority verdict and the vote
        counts. Invalid requests get ``{"error": ...}`` with HTTP status 400.
    """
    # silent=True: a malformed or non-JSON body yields None and is answered
    # with the JSON 400 below instead of Flask's default error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text' not in data or 'model' not in data:
        return jsonify({'error': 'Missing required fields'}), 400

    news_text = data['text']
    model_option = data['model']

    if not news_text or (isinstance(news_text, str) and not news_text.strip()):
        return jsonify({'error': 'Text cannot be empty'}), 400
    if not isinstance(news_text, str):
        return jsonify({'error': 'Text must be a string'}), 400
    # Reject unknown options explicitly rather than returning an empty result.
    if model_option not in _MODEL_OPTIONS:
        return jsonify({'error': 'Unknown model option'}), 400

    cleaned_text = clean_text(news_text)
    results = {}

    wants_lr = model_option in ("lr", "all")
    wants_rf = model_option in ("rf", "all")
    if wants_lr or wants_rf:
        # Both classical models share the same TF-IDF features; compute once.
        text_tfidf = tfidf_vectorizer.transform([cleaned_text])
        if wants_lr:
            results["Logistic Regression"] = _classical_result(lr_model, text_tfidf)
        if wants_rf:
            results["Random Forest"] = _classical_result(rf_model, text_tfidf)

    if model_option in ("distilbert", "all"):
        results["DistilBERT"] = _distilbert_result(cleaned_text)

    # Majority vote across the individual models; a tie counts as "Real".
    if model_option == "all":
        total_models = len(results)
        real_votes = sum(
            1 for result in results.values() if result["prediction"] == "Real"
        )
        fake_votes = total_models - real_votes
        results["Overall"] = {
            "prediction": "Real" if real_votes >= fake_votes else "Fake",
            "real_votes": real_votes,
            "fake_votes": fake_votes,
            "total_models": total_models,
        }

    return jsonify({'results': results})
if __name__ == '__main__':
    # Listen on all interfaces; the port comes from the PORT environment
    # variable and falls back to 7860.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))