import os
import re
from importlib import import_module

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

import torch.distributed as dist
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from .discriminator import Discriminator

from utils import interact


class Model(nn.Module):
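    """Wrapper that bundles the generator G and, for adversarial losses, a
    discriminator D in an nn.ModuleDict, and centralizes device placement,
    (distributed) data parallelism, checkpointing, and rank synchronization.
    """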
    def __init__(self, args):
        super(Model, self).__init__()

        self.args = args
        self.device = args.device
        self.n_GPUs = args.n_GPUs
        self.save_dir = os.path.join(args.save_dir, 'models')
        os.makedirs(self.save_dir, exist_ok=True)

        # resolve the generator architecture dynamically from the 'model' package
        module = import_module('model.' + args.model)

        self.model = nn.ModuleDict()
        self.model.G = module.build_model(args)
        if 'adv' in self.args.loss.lower():
            self.model.D = Discriminator(self.args)
        else:
            self.model.D = None

        self.to(args.device, dtype=args.dtype, non_blocking=True)
        self.load(args.load_epoch, path=args.pretrained)

    def parallelize(self):
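        """Wrap each submodule in DistributedDataParallel (one device per process)
        or DataParallel (one process driving all n_GPUs devices), depending on the
        run configuration.
        """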
        if self.args.device_type == 'cuda':
            if self.args.distributed:
                Parallel = DistributedDataParallel
                parallel_args = {
                    "device_ids": [self.args.rank],
                    "output_device": self.args.rank,
                }
            else:
                Parallel = DataParallel
                parallel_args = {
                    'device_ids': list(range(self.n_GPUs)),
                    'output_device': self.args.rank
                }

            for model_key in self.model:
                if self.model[model_key] is not None:
                    self.model[model_key] = Parallel(self.model[model_key], **parallel_args)

    def forward(self, input):
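        # Only the generator runs in the forward pass; the discriminator (if any)
        # is expected to be invoked by the adversarial loss / training loop.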
        return self.model.G(input)

    def _save_path(self, epoch):
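        # checkpoints are named 'model-<epoch>.pt' inside self.save_dir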
        model_path = os.path.join(self.save_dir, 'model-{:d}.pt'.format(epoch))
        return model_path

    def state_dict(self):
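        """Collect a state dict per submodule, unwrapping any (Distributed)DataParallel
        wrapper so the saved keys carry no 'module.' prefix."""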
        state_dict = {}
        for model_key in self.model:
            if self.model[model_key] is not None:
                parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel))
                if parallelized:
                    state_dict[model_key] = self.model[model_key].module.state_dict()
                else:
                    state_dict[model_key] = self.model[model_key].state_dict()

        return state_dict

    def load_state_dict(self, state_dict, strict=True):
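        """Restore each submodule from its entry in `state_dict`, mirroring the
        unwrapping done in `state_dict` above."""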
        for model_key in self.model:
            parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel))
            if model_key in state_dict:
                if parallelized:
                    self.model[model_key].module.load_state_dict(state_dict[model_key], strict)
                else:
                    self.model[model_key].load_state_dict(state_dict[model_key], strict)

    def save(self, epoch):
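        # serialize every submodule's weights for the given epoch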
        torch.save(self.state_dict(), self._save_path(epoch))

    def load(self, epoch=None, path=None):
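        """Load weights from an explicit path if given; otherwise from the checkpoint
        for `epoch` (a negative epoch means 'latest', and 0 means nothing to load)."""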
        if path:
            model_name = path
        elif isinstance(epoch, int):
            if epoch < 0:
                epoch = self.get_last_epoch()
            if epoch == 0:  # no checkpoint exists yet; nothing to load
                return

            model_name = self._save_path(epoch)
        else:
            raise Exception('no epoch number or model path specified!')

        print('Loading model from {}'.format(model_name))
        state_dict = torch.load(model_name, map_location=self.args.device)
        self.load_state_dict(state_dict)

        return

    def synchronize(self):
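        """Broadcast rank 0's parameters to all other processes so every rank
        starts from identical weights (parameters are flattened into a single
        vector so one broadcast suffices)."""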
        if self.args.distributed:
            vector = parameters_to_vector(self.parameters())
            dist.broadcast(vector, 0)
            if self.args.rank != 0:
                vector_to_parameters(vector, self.parameters())

            del vector

        return

    def get_last_epoch(self):
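        """Return the largest epoch number among the checkpoints in save_dir,
        or 0 if none have been saved yet."""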
        model_list = sorted(os.listdir(self.save_dir))
        if len(model_list) == 0:
            epoch = 0
        else:
            # parse the epoch number from each filename and take the largest;
            # a plain lexicographic sort would rank 'model-9.pt' after 'model-10.pt'
            epoch = max(int(re.findall(r'\d+', name)[0]) for name in model_list)

        return epoch

    def print(self):
        print(self.model)

        return
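
# Minimal usage sketch (assumes an `args` namespace providing the fields read above,
# e.g. device, device_type, dtype, n_GPUs, rank, distributed, save_dir, model, loss,
# load_epoch, pretrained; the variable names below are illustrative only):
#
#   model = Model(args)
#   model.parallelize()           # wrap G/D in (Distributed)DataParallel
#   output = model(input_batch)   # forward pass through the generator
#   model.save(epoch)             # writes <save_dir>/models/model-<epoch>.pt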