Spaces:
Build error
Build error
import os.path as osp | |
import math | |
import abc | |
from torch.utils.data import DataLoader | |
import torch.optim | |
import torchvision.transforms as transforms | |
from timer import Timer | |
from logger import colorlogger | |
from torch.nn.parallel.data_parallel import DataParallel | |
from config import cfg | |
from SMPLer_X import get_model | |
# ddp | |
import torch.distributed as dist | |
from torch.utils.data import DistributedSampler | |
import torch.utils.data.distributed | |
from utils.distribute_utils import ( | |
get_rank, is_main_process, time_synchronized, get_group_idx, get_process_groups | |
) | |
class Base(object):
    """Shared runner state: an epoch counter, three timers and a logger.

    ``_make_batch_generator`` and ``_make_model`` are placeholder hooks that
    return ``None``; they are meant to be overridden by subclasses.
    """

    # NOTE: ``__metaclass__`` is only honoured by Python 2. Under Python 3 it
    # is an ordinary class attribute, so this class is NOT enforced as
    # abstract and can be instantiated directly.
    __metaclass__ = abc.ABCMeta

    def __init__(self, log_name='logs.txt'):
        """Set up the epoch counter, the timers and the logger.

        Args:
            log_name: File name handed to ``colorlogger`` together with
                ``cfg.log_dir``.
        """
        self.cur_epoch = 0

        # timer
        self.tot_timer = Timer()
        self.gpu_timer = Timer()
        self.read_timer = Timer()

        # logger
        self.logger = colorlogger(cfg.log_dir, log_name=log_name)

    def _make_batch_generator(self):
        """Hook for building the data loader; the base version does nothing."""
        return

    def _make_model(self):
        """Hook for building the network; the base version does nothing."""
        return
class Demoer(Base):
    """Runner that loads a pretrained checkpoint and iterates over demo data."""

    def __init__(self, test_epoch=None):
        """Create the demo runner.

        Args:
            test_epoch: Optional epoch identifier; converted with ``int`` when
                given. ``self.test_epoch`` is ``None`` when it is omitted, so
                the attribute always exists.
        """
        # Always define the attribute so later reads cannot raise
        # AttributeError when no epoch was supplied.
        self.test_epoch = int(test_epoch) if test_epoch is not None else None
        super(Demoer, self).__init__(log_name='test_logs.txt')

    def _make_batch_generator(self, demo_scene):
        """Build the demo dataset and its ``DataLoader``.

        Stores the dataset on ``self.testset`` and the loader on
        ``self.batch_generator``.

        Args:
            demo_scene: Forwarded unchanged as the third argument of ``UBody``.
        """
        # data load and construct batch generator
        self.logger.info("Creating dataset...")
        # Imported here, as in the original code, rather than at module level.
        from data.UBody.UBody import UBody

        testset_loader = UBody(transforms.ToTensor(), "demo", demo_scene)
        batch_generator = DataLoader(
            dataset=testset_loader,
            batch_size=cfg.num_gpus * cfg.test_batch_size,
            shuffle=False,
            num_workers=cfg.num_thread,
            pin_memory=True,
        )

        self.testset = testset_loader
        self.batch_generator = batch_generator

    @staticmethod
    def _remap_checkpoint_keys(state_dict):
        """Return a copy of ``state_dict`` with keys renamed for this model.

        Every key gets a ``module.`` prefix (required by the ``DataParallel``
        wrapper) unless it already starts with one, then the legacy names
        ``module.backbone``, ``body_rotation_net`` and ``hand_rotation_net``
        are replaced by ``module.encoder``, ``body_regressor`` and
        ``hand_regressor``.
        """
        from collections import OrderedDict

        remapped = OrderedDict()
        for key, value in state_dict.items():
            # Test the prefix, not a substring: a parameter whose name merely
            # contains "module" somewhere still needs the wrapper prefix.
            if not key.startswith('module.'):
                key = 'module.' + key
            key = (
                key.replace('module.backbone', 'module.encoder')
                .replace('body_rotation_net', 'body_regressor')
                .replace('hand_rotation_net', 'hand_regressor')
            )
            remapped[key] = value
        return remapped

    def _make_model(self):
        """Build the test-mode network, load the checkpoint and store it.

        The model is wrapped in ``DataParallel``, moved to ``cfg.device``,
        loaded non-strictly from ``cfg.pretrained_model_path`` and put in
        eval mode before being stored on ``self.model``.
        """
        self.logger.info('Load checkpoint from {}'.format(cfg.pretrained_model_path))

        # prepare network
        self.logger.info("Creating graph...")
        model = get_model('test')
        model = DataParallel(model).to(cfg.device)

        ckpt = torch.load(cfg.pretrained_model_path, map_location=cfg.device)
        new_state_dict = self._remap_checkpoint_keys(ckpt['network'])

        # strict=False tolerates mismatching keys; report them so a broken
        # rename mapping does not silently leave weights uninitialised.
        load_result = model.load_state_dict(new_state_dict, strict=False)
        missing = getattr(load_result, 'missing_keys', None)
        unexpected = getattr(load_result, 'unexpected_keys', None)
        if missing:
            self.logger.info('Checkpoint is missing {} model keys'.format(len(missing)))
        if unexpected:
            self.logger.info('Checkpoint has {} unexpected keys'.format(len(unexpected)))

        model.eval()
        self.model = model

    def _evaluate(self, outs, cur_sample_idx):
        """Delegate evaluation of ``outs`` to ``self.testset.evaluate``.

        Args:
            outs: Model outputs, passed through unchanged.
            cur_sample_idx: Passed through unchanged.

        Returns:
            Whatever ``self.testset.evaluate`` returns.
        """
        return self.testset.evaluate(outs, cur_sample_idx)