# weibo_senti_cls / test.py
# Uploaded by ZaynSu99 -- commit 995278d ("Update weibo_senti_cls"), 947 bytes.
from cnn import CNN,Model
from utils import loader_test
import torch
# Prefer the GPU, but fall back to the CPU so the script also runs on hosts
# without CUDA instead of failing at start-up.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the CNN model and restore the parameters stored in net_params.pth.
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written from a GPU can still be loaded on a CPU-only machine.
model = CNN().to(device)
model.load_state_dict(torch.load('net_params.pth', map_location=device))

# Alternative: evaluate the `Model` class from cnn.py with its own checkpoint.
# model = Model().to(device)
# model.load_state_dict(torch.load('cls_params.pth', map_location=device))
def test():
    """Evaluate the module-level ``model`` on ``loader_test`` and print accuracy.

    Every batch yielded by ``loader_test`` is unpacked as
    ``(input_ids, attention_mask, token_type_ids, labels)``, moved to
    ``device`` and passed through the model. The predicted class is the
    argmax over dimension 1 of the model output, which is compared
    element-wise with ``labels``.

    Returns:
        float: ``correct / total`` over all batches, or ``0.0`` when the
        loader yields no samples (avoids a ZeroDivisionError).
    """
    # Evaluation mode; for torch modules this switches layers such as
    # dropout / batch norm to their inference behaviour.
    model.eval()

    correct = 0
    total = 0

    # Nothing in this loop needs gradients, so disable autograd for all of it
    # rather than only around the forward pass.
    with torch.no_grad():
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
            print(i)  # batch index, printed as a progress indicator

            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)
            labels = labels.to(device)

            out = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
            predictions = out.argmax(dim=1)

            # .item() turns the 0-dim tensor into a plain Python number so the
            # running counters stay ordinary ints.
            correct += (predictions == labels).sum().item()
            total += len(labels)

    print('correct: ', correct, 'total: ', total)

    # An empty loader leaves total at 0; report 0.0 instead of dividing by zero.
    accuracy = correct / total if total else 0.0
    print('accuracy:', accuracy)
    return accuracy
# Run the evaluation immediately when this file is executed (or imported).
test()