|
from cnn import CNN,Model |
|
from utils import loader_test |
|
import torch |
|
|
|
|
|
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# script still works on machines without CUDA (a hard-coded 'cuda' device
# fails there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = CNN().to(device)

# map_location remaps every tensor in the checkpoint onto `device`; without it
# a checkpoint written on a GPU cannot be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load('net_params.pth', map_location=device))
|
|
|
|
|
|
|
def test(loader=None, net=None, dev=None):
    """Evaluate the classifier on a labelled dataset and print its accuracy.

    Each batch yielded by the loader is unpacked as
    ``(input_ids, attention_mask, token_type_ids, labels)``. The first three
    are moved to the device and passed to the model as keyword arguments; the
    predicted class is the argmax over dimension 1 of the model output, and it
    is compared element-wise against ``labels``.

    Args:
        loader: Iterable of batches. Defaults to the module-level
            ``loader_test``.
        net: Model to evaluate. Defaults to the module-level ``model``.
        dev: Device every batch tensor is moved to before the forward pass.
            Defaults to the module-level ``device``.

    Returns:
        The accuracy ``correct / total`` as a float, or ``None`` when the
        loader yielded no samples (accuracy is undefined in that case).
    """
    if loader is None:
        loader = loader_test
    if net is None:
        net = model
    if dev is None:
        dev = device

    # Put the model in evaluation mode (affects layers such as dropout or
    # batch-norm, if the model contains any).
    net.eval()

    correct = 0
    total = 0

    # Nothing in evaluation needs gradients; one no_grad scope around the
    # whole loop avoids tracking autograd state for any batch.
    with torch.no_grad():
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
            print(i)  # progress: index of the batch being evaluated

            input_ids = input_ids.to(dev)
            attention_mask = attention_mask.to(dev)
            token_type_ids = token_type_ids.to(dev)
            labels = labels.to(dev)

            out = net(input_ids=input_ids,
                      attention_mask=attention_mask,
                      token_type_ids=token_type_ids)
            preds = out.argmax(dim=1)

            correct += (preds == labels).sum().item()
            total += len(labels)

    print('correct: ', correct, 'total: ', total)

    if total == 0:
        # An empty loader previously crashed with ZeroDivisionError.
        print('accuracy: undefined (no samples evaluated)')
        return None

    accuracy = correct / total
    print('accuracy:', accuracy)
    return accuracy
|
|
|
# Entry point: runs the evaluation whenever this module is executed — this
# top-level call also fires if the module is imported.
# NOTE(review): wrap in `if __name__ == "__main__":` if this module is ever
# imported for its `test` function rather than run as a script.
test()