# fake-news-detector-LSTM / inference.py
# Uploaded by kimic — commit c5cd586 ("Initial commit"), 1.17 kB.
import torch
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from model import LSTMModel
def load_model(model_path, vocab_size, *, map_location="cpu", **model_kwargs):
    """Build an ``LSTMModel`` and load trained weights into it for inference.

    Args:
        model_path: Path (or file-like object) of the checkpoint to read with
            ``torch.load``. Assumed to contain a bare state dict, as produced by
            ``torch.save(model.state_dict(), ...)`` — TODO confirm against the
            training script.
        vocab_size: Vocabulary size passed to the ``LSTMModel`` constructor.
            It must match the value used at training time, or the stored weight
            shapes will not fit.
        map_location: Where ``torch.load`` places the stored tensors. Defaults
            to ``"cpu"`` so a checkpoint saved from a GPU still loads on a
            CPU-only machine. ``load_state_dict`` copies the values into the
            model's existing parameters, so this does not decide the returned
            model's device.
        **model_kwargs: Extra keyword arguments forwarded to ``LSTMModel``, for
            checkpoints trained with non-default hyperparameters.

    Returns:
        The model with its weights loaded, switched to eval mode. It is not
        moved to any device here; callers use ``model.to(device)``.
    """
    model = LSTMModel(vocab_size, **model_kwargs)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources. Newer torch versions
    # also accept ``weights_only=True`` to restrict what may be unpickled.
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    # Put layers such as dropout into inference behaviour.
    model.eval()
    return model
def predict(model, titles, texts, device):
    """Score one batch of (title, text) inputs with ``model``, without autograd.

    The model and both input tensors are moved onto ``device`` before the
    forward pass. ``model.eval()`` is *not* called here; that is the caller's
    job.

    Returns:
        The model output with every size-1 dimension squeezed away. If the model
        emits one score per example, a batch of one therefore comes back as a
        0-d tensor.
    """
    model.to(device)
    title_batch = titles.to(device)
    text_batch = texts.to(device)
    with torch.no_grad():
        return model(title_batch, text_batch).squeeze()
def evaluate_model(model, data_loader, device, labels, *, threshold=0.5):
    """Run ``model`` over ``data_loader`` and score its predictions against ``labels``.

    Args:
        model: Module called as ``model(titles, texts)`` (via ``predict``). Its
            output is treated as one score per example and compared against
            ``threshold``, so it presumably ends in a sigmoid — TODO confirm in
            ``LSTMModel``.
        data_loader: Iterable yielding ``(titles, texts)`` batches. It carries
            no labels, so it must iterate in the same order as ``labels``
            (i.e. no shuffling).
        device: Device that the model and each batch are moved to.
        labels: Ground-truth 0/1 labels, one per example, in loader order.
            Accepts anything ``torch.as_tensor`` accepts (list, NumPy array,
            tensor on any device).
        threshold: Score above which an example is predicted as class 1.
            Defaults to 0.5, the previously hard-coded value.

    Returns:
        ``(accuracy, f1, auc_roc)`` as computed by scikit-learn.

    Raises:
        ValueError: If the loader produced a different number of scores than
            there are labels. The scikit-learn metrics can also raise
            ``ValueError``, e.g. ROC AUC when ``labels`` holds a single class.
    """
    model.to(device)
    model.eval()

    scores = []
    for titles, texts in data_loader:
        # predict() moves the batch onto `device` itself.
        outputs = predict(model, titles, texts, device)
        # predict() squeezes its output, so a batch of one (e.g. a short final
        # batch) comes back 0-d, and `extend` cannot iterate a 0-d array.
        # Flattening keeps every batch 1-d.
        scores.extend(outputs.reshape(-1).cpu().numpy())

    # Labels are only consumed by scikit-learn, so keep them on the CPU
    # instead of round-tripping them through `device`.
    y_true = torch.as_tensor(labels).cpu().numpy()
    if len(scores) != len(y_true):
        raise ValueError(
            f"data_loader produced {len(scores)} scores "
            f"but {len(y_true)} labels were given"
        )

    predicted_labels = [1 if score > threshold else 0 for score in scores]
    accuracy = accuracy_score(y_true, predicted_labels)
    f1 = f1_score(y_true, predicted_labels)
    # ROC AUC uses the raw scores, not the thresholded labels.
    auc_roc = roc_auc_score(y_true, scores)
    return accuracy, f1, auc_roc