fake-news-detector-LSTM / inference.py
kimic's picture
Added cm and updated graph titles for clarity
1bb2bdd
import torch
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from model import LSTMModel
def load_model(model_path, vocab_size, map_location="cpu"):
    """Build an ``LSTMModel`` and load saved weights into it for inference.

    Args:
        model_path: path (or file-like object) passed to ``torch.load``; it is
            expected to contain a ``state_dict`` compatible with
            ``LSTMModel(vocab_size)``.
        vocab_size: vocabulary size forwarded to the ``LSTMModel`` constructor;
            it must match the value used when the checkpoint was saved, or
            ``load_state_dict`` will reject the tensor shapes.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved from a GPU can still be deserialized on a
            CPU-only machine. ``load_state_dict`` copies the loaded tensors
            into the model's existing parameters, so this only controls where
            the checkpoint is materialized, not where the model ends up.
            Pass ``None`` to restore tensors to the device they were saved on.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    model = LSTMModel(vocab_size)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Newer PyTorch versions also accept ``weights_only=True``
    # to restrict what can be deserialized.
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    # Put dropout/batch-norm style layers into inference behaviour.
    model.eval()
    return model
def predict(model, titles, texts, device):
    """Run one gradient-free forward pass of *model* on a single batch.

    Both input tensors and the model are moved to *device* before the call.
    The model is invoked as ``model(titles, texts)`` and its output has all
    size-1 dimensions removed via ``squeeze()``. Note that for a batch of one
    sample this can produce a 0-d tensor.

    Args:
        model: module called with ``(titles, texts)``.
        titles: tensor of title inputs for the batch.
        texts: tensor of body-text inputs for the batch.
        device: target device for the model and the inputs.

    Returns:
        The squeezed model output, computed without autograd tracking.
    """
    batch_titles = titles.to(device)
    batch_texts = texts.to(device)
    model.to(device)
    with torch.no_grad():
        raw_scores = model(batch_titles, batch_texts)
        squeezed = raw_scores.squeeze()
    return squeezed
def evaluate_model(model, data_loader, device, labels, threshold=0.5):
    """Run *model* over *data_loader* and score it against *labels*.

    Args:
        model: module accepted by ``predict``; it is moved to *device* and
            put in evaluation mode.
        data_loader: iterable yielding ``(titles, texts)`` batches. Scores are
            matched to *labels* purely by position, so the loader must iterate
            in the same order as *labels*. It is presumably unshuffled;
            confirm with the caller.
        device: device used for inference.
        labels: ground-truth 0/1 labels (a tensor, or anything
            ``torch.tensor`` accepts, such as a list).
        threshold: a score strictly greater than this is labelled 1, otherwise
            0. Defaults to 0.5. This assumes the model emits probability-like
            scores; TODO confirm.

    Returns:
        Tuple ``(accuracy, f1, auc_roc, labels, predicted_labels)``, where
        ``labels`` is a NumPy array and ``predicted_labels`` is a list of
        0/1 ints.

    Raises:
        ValueError: propagated from scikit-learn, e.g. when *labels* contains
            only one class (ROC AUC is undefined) or its length differs from
            the number of predictions.
    """
    model.to(device)
    model.eval()

    # The labels are only consumed by scikit-learn on the host, so convert
    # straight to NumPy instead of round-tripping through ``device``.
    # Avoid ``torch.tensor(tensor)``, which copies with a UserWarning.
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    else:
        labels = torch.tensor(labels).numpy()

    predictions = []
    for titles, texts in data_loader:
        # ``predict`` moves the batch to ``device`` itself.
        outputs = predict(model, titles, texts, device)
        # ``predict`` squeezes its output, so a batch of one sample yields a
        # 0-d tensor that cannot be iterated; flatten so ``extend`` always
        # receives a 1-d sequence.
        predictions.extend(outputs.reshape(-1).cpu().tolist())

    predicted_labels = [1 if p > threshold else 0 for p in predictions]

    accuracy = accuracy_score(labels, predicted_labels)
    f1 = f1_score(labels, predicted_labels)
    # ROC AUC is computed from the raw scores, not the thresholded labels.
    auc_roc = roc_auc_score(labels, predictions)
    return accuracy, f1, auc_roc, labels, predicted_labels