SMPLer-X / common /base.py
onescotch
clean up for zero gpus
010a8bc
raw
history blame
2.85 kB
import os.path as osp
import math
import abc
from torch.utils.data import DataLoader
import torch.optim
import torchvision.transforms as transforms
from timer import Timer
from logger import colorlogger
from torch.nn.parallel.data_parallel import DataParallel
from config import cfg
from SMPLer_X import get_model
# ddp
import torch.distributed as dist
from torch.utils.data import DistributedSampler
import torch.utils.data.distributed
from utils.distribute_utils import (
get_rank, is_main_process, time_synchronized, get_group_idx, get_process_groups
)
class Base(abc.ABC):
    """Shared scaffolding for runner classes: epoch counter, timers and logger.

    Concrete subclasses must implement ``_make_batch_generator`` and
    ``_make_model``. Deriving from ``abc.ABC`` makes the ``abstractmethod``
    markers effective on Python 3; a class-level ``__metaclass__`` attribute
    is ignored there, so the abstract methods were never enforced.
    """

    def __init__(self, log_name='logs.txt'):
        """Initialise the epoch counter, the timers and the logger.

        Args:
            log_name: Log file name handed to ``colorlogger`` together with
                ``cfg.log_dir``.
        """
        self.cur_epoch = 0

        # timer
        self.tot_timer = Timer()
        self.gpu_timer = Timer()
        self.read_timer = Timer()

        # logger
        self.logger = colorlogger(cfg.log_dir, log_name=log_name)

    @abc.abstractmethod
    def _make_batch_generator(self):
        """Build the data pipeline; implemented by subclasses."""
        return

    @abc.abstractmethod
    def _make_model(self):
        """Build the model; implemented by subclasses."""
        return
class Demoer(Base):
    """Runner that loads a pretrained checkpoint and feeds it demo data."""

    def __init__(self, test_epoch=None):
        """Create the runner and its ``test_logs.txt`` logger.

        Args:
            test_epoch: Optional epoch identifier; converted with ``int`` and
                stored in ``self.test_epoch``. ``self.test_epoch`` is ``None``
                when no epoch is given.
        """
        # Always define the attribute so later reads cannot raise AttributeError.
        self.test_epoch = int(test_epoch) if test_epoch is not None else None
        super(Demoer, self).__init__(log_name='test_logs.txt')

    def _make_batch_generator(self, demo_scene):
        """Create the demo dataset and its ``DataLoader``.

        Stores the dataset in ``self.testset`` and the loader in
        ``self.batch_generator``.

        Args:
            demo_scene: Forwarded unchanged as the third argument of ``UBody``.
        """
        # data load and construct batch generator
        self.logger.info("Creating dataset...")
        from data.UBody.UBody import UBody
        testset_loader = UBody(transforms.ToTensor(), "demo", demo_scene)

        # With cfg.num_gpus == 0 the product would be 0, and DataLoader rejects
        # a non-positive batch size; fall back to one batch worth of samples.
        batch_size = max(cfg.num_gpus, 1) * cfg.test_batch_size
        batch_generator = DataLoader(
            dataset=testset_loader,
            batch_size=batch_size,
            shuffle=False,
            num_workers=cfg.num_thread,
            pin_memory=True,
        )

        self.testset = testset_loader
        self.batch_generator = batch_generator

    @staticmethod
    def _remap_checkpoint_keys(state_dict):
        """Return a new ordered mapping with keys renamed for the wrapped model.

        Every key gets a leading ``module.`` prefix if it does not already start
        with one, then the legacy name fragments are replaced with the current
        ones. Values are passed through untouched.
        """
        from collections import OrderedDict

        remapped = OrderedDict()
        for key, value in state_dict.items():
            # Check the prefix, not a substring: a key that merely contains
            # "module" somewhere in the middle still needs the prefix.
            if not key.startswith('module.'):
                key = 'module.' + key
            key = (
                key.replace('module.backbone', 'module.encoder')
                .replace('body_rotation_net', 'body_regressor')
                .replace('hand_rotation_net', 'hand_regressor')
            )
            remapped[key] = value
        return remapped

    def _make_model(self):
        """Build the test model, load the pretrained weights, set eval mode.

        The resulting model is stored in ``self.model``.
        """
        self.logger.info('Load checkpoint from {}'.format(cfg.pretrained_model_path))

        # prepare network
        self.logger.info("Creating graph...")
        model = get_model('test')
        model = DataParallel(model).to(cfg.device)

        # NOTE(review): torch.load unpickles the file; only point
        # cfg.pretrained_model_path at checkpoints from a trusted source.
        ckpt = torch.load(cfg.pretrained_model_path, map_location=cfg.device)
        new_state_dict = self._remap_checkpoint_keys(ckpt['network'])

        # strict=False: keys that do not match between checkpoint and model are skipped.
        model.load_state_dict(new_state_dict, strict=False)
        model.eval()

        self.model = model

    def _evaluate(self, outs, cur_sample_idx):
        """Delegate to ``self.testset.evaluate`` and return its result."""
        eval_result = self.testset.evaluate(outs, cur_sample_idx)
        return eval_result