# Vnese_crawl / predictor.py
# SonFox2920's picture
# Update predictor.py
# 88c7b44 verified
import os
hf_token = os.getenv('HF_TOKEN')
import warnings
warnings.filterwarnings('ignore')
import logging
logging.disable(logging.WARNING)
from Mbert import SentencePairDataset, MBERTClassifier
import torch
import numpy as np
import random
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader
import pandas as pd
# Set a fixed seed
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and
            ``torch`` (plus the CUDA generators when CUDA is available).

    Besides seeding, cuDNN is switched to deterministic mode and its
    benchmark autotuner is turned off.
    """
    # The three CPU-side generators take the seed the same way.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are only touched when a CUDA device is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Call set_seed with a fixed seed, e.g. 42. This runs at import time.
set_seed(42)
# The encoder, the classifier and (in predict) every batch are moved to this
# device; it is hard-coded to the CPU.
device = torch.device("cpu")
# Checkpoint name shared by the tokenizer and the encoder loaded below.
modelname = "bert-base-multilingual-cased"
# hf_token is read from the HF_TOKEN environment variable at the top of the
# file (os.getenv yields None when the variable is unset).
# NOTE(review): recent transformers releases deprecate `use_auth_token` in
# favour of `token` — confirm against the installed version before changing.
# NOTE(review): trust_remote_code=True allows code shipped in a model repo to
# be executed; presumably unnecessary for this stock checkpoint — confirm.
tokenizer = AutoTokenizer.from_pretrained(modelname, use_auth_token=hf_token,
trust_remote_code=True)
mbert = AutoModel.from_pretrained(modelname, use_auth_token=hf_token,
trust_remote_code=True).to(device)
# Wrap the encoder in the project's MBERTClassifier (imported from Mbert) with
# three output classes; predict maps them to SUPPORTED / REFUTED / NEI.
model = MBERTClassifier(mbert, num_classes=3).to(device)
# Load the classifier weights from a path relative to the working directory.
# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be placed at this path; consider weights_only=True if the installed
# torch version supports it — confirm.
model.load_state_dict(torch.load('Model/classifier.pt', map_location=device))
def predict(context, claim):
    """Classify a claim against a context passage with the loaded classifier.

    The pair is wrapped in a ``SentencePairDataset`` (claim first, context
    second), run through the module-level ``model`` in eval mode without
    gradient tracking, and the logits are turned into softmax probabilities.

    Args:
        context: Evidence text the claim is checked against.
        claim: Statement to verify.

    Returns:
        dict: ``{'verdict': <label>, 'probabilities': {'SUPPORTED': p0,
        'REFUTED': p1, 'NEI': p2}}`` where ``<label>`` is the label of the
        highest-scoring class and each ``p`` is a Python float.
    """
    # Class index -> verdict label, in the order of the model's output columns.
    labels = ("SUPPORTED", "REFUTED", "NEI")

    # The dataset expects one label per pair; 1 is only a placeholder and is
    # never read back here. 256 is forwarded as the dataset's fourth argument
    # (presumably the maximum token length — confirm in Mbert).
    pair_dataset = SentencePairDataset([(claim, context)], [1], tokenizer, 256)
    # The DataLoader collates the single example into a batch of size one.
    loader = DataLoader(pair_dataset, batch_size=1)

    model.eval()
    batch = next(iter(loader))
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probs = torch.nn.functional.softmax(logits, dim=1)

    # Row 0 is the only example in the batch.
    class_probs = probs.cpu().numpy().tolist()[0]
    verdict_index = torch.argmax(logits, dim=1).cpu().numpy().tolist()[0]

    return {
        'verdict': labels[verdict_index],
        'probabilities': {
            label: class_probs[index] for index, label in enumerate(labels)
        },
    }