Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
from pathlib import Path | |
import copy | |
from abc import ABC, abstractmethod | |
import torch | |
import torch.nn as nn | |
from .component import create_net | |
from .logger import BaseLogger | |
from lib import ParamSet | |
from typing import List, Dict, Tuple, Union | |
# Alias of typing | |
# eg. {'labels': {'label_A: torch.Tensor([0, 1, ...]), ...}} | |
LabelDict = Dict[str, Dict[str, Union[torch.IntTensor, torch.FloatTensor]]] | |
logger = BaseLogger.get_logger(__name__) | |
class BaseModel(ABC): | |
""" | |
Class to construct model. This class is the base class to construct model. | |
""" | |
def __init__(self, params: ParamSet) -> None: | |
""" | |
Class to define Model | |
Args: | |
param (ParamSet): parameters | |
""" | |
self.params = params | |
self.device = self.params.device | |
self.network = create_net( | |
mlp=self.params.mlp, | |
net=self.params.net, | |
num_outputs_for_label=self.params.num_outputs_for_label, | |
mlp_num_inputs=self.params.mlp_num_inputs, | |
in_channel=self.params.in_channel, | |
vit_image_size=self.params.vit_image_size, | |
pretrained=self.params.pretrained | |
) | |
self.network.to(self.device) | |
# variables to keep temporary best_weight and best_epoch | |
self.acting_best_weight = None | |
self.acting_best_epoch = None | |
def train(self) -> None: | |
""" | |
Make network training mode. | |
""" | |
self.network.train() | |
def eval(self) -> None: | |
""" | |
Make network evaluation mode. | |
""" | |
self.network.eval() | |
def set_data( | |
self, | |
data: Dict | |
) -> Tuple[ | |
Dict[str, torch.FloatTensor], | |
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]] | |
]: | |
raise NotImplementedError | |
def store_weight(self, at_epoch: int = None) -> None: | |
""" | |
Store weight and epoch number when it is saved. | |
Args: | |
at_epoch (int): epoch number when save weight | |
""" | |
self.acting_best_epoch = at_epoch | |
_network = copy.deepcopy(self.network) | |
if hasattr(_network, 'module'): | |
# When DataParallel used, move weight to CPU. | |
self.acting_best_weight = copy.deepcopy(_network.module.to(torch.device('cpu')).state_dict()) | |
else: | |
self.acting_best_weight = copy.deepcopy(_network.state_dict()) | |
def save_weight(self, save_datetime_dir: str, as_best: bool = None) -> None: | |
""" | |
Save weight. | |
Args: | |
save_datetime_dir (str): save_datetime_dir | |
as_best (bool): True if weight is saved as best, otherwise False. Defaults to None. | |
""" | |
save_dir = Path(save_datetime_dir, 'weights') | |
save_dir.mkdir(parents=True, exist_ok=True) | |
save_name = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3) + '.pt' | |
save_path = Path(save_dir, save_name) | |
if as_best: | |
save_name_as_best = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3) + '_best' + '.pt' | |
save_path_as_best = Path(save_dir, save_name_as_best) | |
if save_path.exists(): | |
# Check if best weight already saved. If exists, rename with '_best' | |
save_path.rename(save_path_as_best) | |
else: | |
torch.save(self.acting_best_weight, save_path_as_best) | |
else: | |
save_name = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3) + '.pt' | |
torch.save(self.acting_best_weight, save_path) | |
def load_weight(self, weight_path: Path) -> None: | |
""" | |
Load wight from weight_path. | |
Args: | |
weight_path (Path): path to weight | |
""" | |
logger.info(f"Load weight: {weight_path}.\n") | |
weight = torch.load(weight_path) | |
self.network.load_state_dict(weight) | |
class ModelMixin: | |
def to_gpu(self, gpu_ids: List[int]) -> None: | |
""" | |
Make model compute on the GPU. | |
Args: | |
gpu_ids (List[int]): GPU ids | |
""" | |
if gpu_ids != []: | |
assert torch.cuda.is_available(), 'No available GPU on this machine.' | |
self.network = nn.DataParallel(self.network, device_ids=gpu_ids) | |
def init_network(self) -> None: | |
""" | |
Initialize network. | |
This method is used at test to reset the current weight by redefining network. | |
""" | |
self.network = create_net( | |
mlp=self.params.mlp, | |
net=self.params.net, | |
num_outputs_for_label=self.params.num_outputs_for_label, | |
mlp_num_inputs=self.params.mlp_num_inputs, | |
in_channel=self.params.in_channel, | |
vit_image_size=self.params.vit_image_size, | |
pretrained=self.params.pretrained | |
) | |
self.network.to(self.device) | |
class ModelWidget(BaseModel, ModelMixin): | |
""" | |
Class for a widget to inherit multiple classes simultaneously | |
""" | |
pass | |
class MLPModel(ModelWidget): | |
""" | |
Class for MLP model | |
""" | |
def __init__(self, params: ParamSet) -> None: | |
""" | |
Args: | |
params: (ParamSet): parameters | |
""" | |
super().__init__(params) | |
def set_data( | |
self, | |
data: Dict | |
) -> Tuple[ | |
Dict[str, torch.FloatTensor], | |
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]] | |
]: | |
""" | |
Unpack data for forwarding of MLP and calculating loss | |
by passing them to device. | |
When deepsurv, period and network are also returned. | |
Args: | |
data (Dict): dictionary of data | |
Returns: | |
Tuple[ | |
Dict[str, torch.FloatTensor], | |
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]] | |
]: input of model and data for calculating loss. | |
eg. | |
([inputs], [labels]), or ([inputs], [labels, periods, network]) when deepsurv | |
""" | |
in_data = {'inputs': data['inputs'].to(self.device)} | |
labels = {'labels': {label_name: label.to(self.device) for label_name, label in data['labels'].items()}} | |
if not any(data['periods']): | |
return in_data, labels | |
# When deepsurv | |
labels = { | |
**labels, | |
**{'periods': data['periods'].to(self.device), 'network': self.network.to(self.device)} | |
} | |
return in_data, labels | |
def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |
""" | |
Forward | |
Args: | |
in_data (Dict[str, torch.Tensor]): data to be input into model | |
Returns: | |
Dict[str, torch.Tensor]: output | |
""" | |
inputs = in_data['inputs'] | |
output = self.network(inputs) | |
return output | |
class CVModel(ModelWidget): | |
""" | |
Class for CNN or ViT model | |
""" | |
def __init__(self, params: ParamSet) -> None: | |
""" | |
Args: | |
params: (ParamSet): parameters | |
""" | |
super().__init__(params) | |
def set_data( | |
self, | |
data: Dict | |
) -> Tuple[ | |
Dict[str, torch.FloatTensor], | |
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]] | |
]: | |
""" | |
Unpack data for forwarding of CNN or ViT and calculating loss by passing them to device. | |
When deepsurv, period and network are also returned. | |
Args: | |
data (Dict): dictionary of data | |
Returns: | |
Tuple[ | |
Dict[str, torch.FloatTensor], | |
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]] | |
]: input of model and data for calculating loss. | |
eg. | |
([image], [labels]), or ([image], [labels, periods, network]) when deepsurv | |
""" | |
in_data = {'image': data['image'].to(self.device)} | |
labels = {'labels': {label_name: label.to(self.device) for label_name, label in data['labels'].items()}} | |
if not any(data['periods']): | |
return in_data, labels | |
# When deepsurv | |
labels = { | |
**labels, | |
**{'periods': data['periods'].to(self.device), 'network': self.network.to(self.device)} | |
} | |
return in_data, labels | |
def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |
""" | |
Forward | |
Args: | |
in_data (Dict[str, torch.Tensor]): data to be input into model | |
Returns: | |
Dict[str, torch.Tensor]: output | |
""" | |
image = in_data['image'] | |
output = self.network(image) | |
return output | |
class FusionModel(ModelWidget): | |
""" | |
Class for MLP+CNN or MLP+ViT model. | |
""" | |
def __init__(self, params: ParamSet) -> None: | |
""" | |
Args: | |
params: (ParamSet): parameters | |
""" | |
super().__init__(params) | |
def set_data( | |
self, | |
data: Dict | |
) -> Tuple[ | |
Dict[str, torch.FloatTensor], | |
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]] | |
]: | |
""" | |
Unpack data for forwarding of MLP+CNN or MLP+ViT and calculating loss | |
by passing them to device. | |
When deepsurv, period and network are also returned. | |
Args: | |
data (Dict): dictionary of data | |
Returns: | |
Tuple[ | |
Dict[str, torch.FloatTensor], | |
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]] | |
]: input of model and data for calculating loss. | |
eg. | |
([inputs, image], [labels]), or ([inputs, image], [labels, periods, network]) when deepsurv | |
""" | |
in_data = { | |
'inputs': data['inputs'].to(self.device), | |
'image': data['image'].to(self.device) | |
} | |
labels = {'labels': {label_name: label.to(self.device) for label_name, label in data['labels'].items()}} | |
if not any(data['periods']): | |
return in_data, labels | |
# When deepsurv | |
labels = { | |
**labels, | |
**{'periods': data['periods'].to(self.device), 'network': self.network.to(self.device)} | |
} | |
return in_data, labels | |
def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |
""" | |
Forward | |
Args: | |
in_data (Dict[str, torch.Tensor]): data to be input into model | |
Returns: | |
Dict[str, torch.Tensor]: output | |
""" | |
inputs = in_data['inputs'] | |
image = in_data['image'] | |
output = self.network(inputs, image) | |
return output | |
def create_model(params: ParamSet) -> nn.Module: | |
""" | |
Construct model. | |
Args: | |
params (ParamSet): parameters | |
Returns: | |
nn.Module: model | |
""" | |
_isMLPModel = (params.mlp is not None) and (params.net is None) | |
_isCVModel = (params.mlp is None) and (params.net is not None) | |
_isFusionModel = (params.mlp is not None) and (params.net is not None) | |
if _isMLPModel: | |
return MLPModel(params) | |
elif _isCVModel: | |
return CVModel(params) | |
elif _isFusionModel: | |
return FusionModel(params) | |
else: | |
raise ValueError(f"Invalid model type: mlp={params.mlp}, net={params.net}.") | |