LINC-BIT's picture
Upload 1912 files
b84549f verified
raw
history blame contribute delete
668 Bytes
import torch
from torch import nn
def to(x, device):
    """Move ``x`` to ``device`` and return it.

    Args:
        x: Either a ``dict`` or any object exposing a ``.to(device)`` method.
        device: Passed through unchanged to every ``.to(...)`` call.

    Returns:
        For a ``dict``: the same dict object, with each value that is a
        ``torch.Tensor`` replaced in place by ``value.to(device)``; all other
        values are left untouched. For anything else: ``x.to(device)``.
    """
    # Non-dict inputs are delegated directly to their own ``.to`` method.
    if not isinstance(x, dict):
        return x.to(device)

    for key, value in x.items():
        if isinstance(value, torch.Tensor):
            # Rebinding an existing key never resizes the dict, so updating
            # it while iterating over ``items()`` is safe.
            x[key] = value.to(device)
    return x
def get_cur_acc(testset, hyps, model, shuffle, iter_index):
    """Return the model's accuracy on one slice of ``testset``.

    Args:
        testset: Dataset handed to ``data.split_dataset``.
        hyps: Mapping that must provide ``'val_batch_size'``,
            ``'train_batch_size'`` and ``'num_workers'``.
        model: Object exposing ``get_accuracy(dataloader)``.
        shuffle: Forwarded as the last positional argument of
            ``data.build_dataloader``.
        iter_index: Forwarded as the third argument of ``data.split_dataset``.

    Returns:
        Whatever ``model.get_accuracy`` returns for the constructed loader.
    """
    # Kept as a function-scope import, exactly as in the original.
    from data import build_dataloader, split_dataset

    # Only the first element of the split result is evaluated.
    eval_subset = split_dataset(testset, hyps['val_batch_size'], iter_index)[0]

    # NOTE(review): the split is sized by 'val_batch_size' but the loader is
    # built with 'train_batch_size' -- confirm this mismatch is intentional.
    eval_loader = build_dataloader(
        eval_subset,
        hyps['train_batch_size'],
        hyps['num_workers'],
        False,
        shuffle,
    )
    return model.get_accuracy(eval_loader)