|
import torch |
|
from torch import nn |
|
|
|
|
|
def to(x, device):
    """Move ``x`` onto ``device`` and return it.

    Args:
        x: Either a dict or any object exposing ``.to(device)`` (for
            example a ``torch.Tensor``).
        device: Target passed straight through to ``.to()``.

    Returns:
        For a dict, the *same* dict object, updated in place: every
        ``torch.Tensor`` value is replaced by its moved copy, nested dicts
        are processed recursively, and all other values are left untouched.
        For anything else, the result of ``x.to(device)``.

    Raises:
        AttributeError: If ``x`` is not a dict and has no ``.to`` method.
    """
    if not isinstance(x, dict):
        return x.to(device)

    # Only existing keys are reassigned, so the dict's size never changes
    # while it is being iterated.
    for key, value in x.items():
        # Recurse into nested dicts so tensors below the top level are not
        # silently left behind on the old device.
        if isinstance(value, (torch.Tensor, dict)):
            x[key] = to(value, device)
    return x
|
|
|
|
|
def get_cur_acc(testset, hyps, model, shuffle, iter_index):
    """Return ``model.get_accuracy`` evaluated on one split of ``testset``.

    Args:
        testset: Dataset handed to ``split_dataset``.
        hyps: Mapping that must contain ``'val_batch_size'``,
            ``'train_batch_size'`` and ``'num_workers'``.
        model: Object exposing ``get_accuracy(dataloader)``.
        shuffle: Forwarded as the last positional argument of
            ``build_dataloader``.
        iter_index: Forwarded as the third argument of ``split_dataset``.

    Returns:
        Whatever ``model.get_accuracy`` returns for the constructed loader.
    """
    # Function-scope import, as in the original.
    from data import build_dataloader, split_dataset

    # Only the first element of the split result is evaluated.
    batch_subset = split_dataset(
        testset, hyps['val_batch_size'], iter_index
    )[0]

    # NOTE(review): the split above uses 'val_batch_size' while the loader
    # uses 'train_batch_size' -- confirm this mismatch is intentional.
    loader = build_dataloader(
        batch_subset,
        hyps['train_batch_size'],
        hyps['num_workers'],
        False,
        shuffle,
    )
    return model.get_accuracy(loader)
|
|