#!/usr/bin/env python
# -*- coding: utf-8 -*-

from pathlib import Path
from typing import List, Dict, Union

import torch
import pandas as pd

from ..logger import BaseLogger

logger = BaseLogger.get_logger(__name__)


class LabelLoss:
    """
    Class to store the loss for every batch and the epoch loss of a single label.
    """
    def __init__(self) -> None:
        # Accumulates batch_loss (= loss * batch_size) over an epoch.
        self.train_batch_loss = 0.0
        self.val_batch_loss = 0.0

        # epoch_loss = accumulated batch_loss / dataset_size
        self.train_epoch_loss = []       # List[float]
        self.val_epoch_loss = []         # List[float]

        self.best_val_loss = None        # float
        self.best_epoch = None           # int
        self.is_val_loss_updated = None  # bool

    def get_loss(self, phase: str, target: str) -> Union[float, List[float]]:
        """
        Return loss depending on phase and target.

        Args:
            phase (str): 'train' or 'val'
            target (str): 'batch' or 'epoch'

        Returns:
            Union[float, List[float]]: batch_loss or epoch_loss
        """
        _target = phase + '_' + target + '_loss'
        return getattr(self, _target)

    def store_batch_loss(self, phase: str, new_batch_loss: torch.FloatTensor, batch_size: int) -> None:
        """
        Accumulate a new batch loss, multiplied by batch_size, into the running loss of the phase.

        Args:
            phase (str): 'train' or 'val'
            new_batch_loss (torch.FloatTensor): batch loss calculated by the criterion
            batch_size (int): batch size
        """
        _new = new_batch_loss.item() * batch_size  # torch.FloatTensor -> float
        _prev = self.get_loss(phase, 'batch')
        _added = _prev + _new
        _target = phase + '_' + 'batch_loss'
        setattr(self, _target, _added)

    def append_epoch_loss(self, phase: str, new_epoch_loss: float) -> None:
        """
        Append an epoch loss to the history of the phase.

        Args:
            phase (str): 'train' or 'val'
            new_epoch_loss (float): epoch loss
        """
        _target = phase + '_' + 'epoch_loss'
        getattr(self, _target).append(new_epoch_loss)

    def get_latest_epoch_loss(self, phase: str) -> float:
        """
        Return the latest epoch loss of the phase.

        Args:
            phase (str): 'train' or 'val'

        Returns:
            float: the latest epoch loss
        """
        return self.get_loss(phase, 'epoch')[-1]

    def update_best_val_loss(self, at_epoch: int = None) -> None:
        """
        Update best_val_loss if the latest val epoch loss improves on it.

        Args:
            at_epoch (int): epoch at which the check is done
        """
        _latest_val_loss = self.get_latest_epoch_loss('val')
        if at_epoch == 1:
            self.best_val_loss = _latest_val_loss
            self.best_epoch = at_epoch
            self.is_val_loss_updated = True
        else:
            # When at_epoch > 1
            if _latest_val_loss < self.best_val_loss:
                self.best_val_loss = _latest_val_loss
                self.best_epoch = at_epoch
                self.is_val_loss_updated = True
            else:
                self.is_val_loss_updated = False
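

# A minimal sketch of the accumulate-then-average pattern above, assuming two
# batches of size 4 and 2 over a dataset of size 6 (the numbers are made up):
#
#   ll = LabelLoss()
#   ll.store_batch_loss('train', torch.tensor(0.5), batch_size=4)  # accumulates 0.5 * 4 = 2.0
#   ll.store_batch_loss('train', torch.tensor(0.8), batch_size=2)  # accumulates 0.8 * 2 = 1.6
#   ll.append_epoch_loss('train', ll.get_loss('train', 'batch') / 6)
#   ll.get_latest_epoch_loss('train')  # -> (2.0 + 1.6) / 6 = 0.6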


class LossStore:
    """
    Class for calculating losses and storing them for all labels.
    """
    def __init__(self, label_list: List[str], num_epochs: int, dataset_info: Dict[str, int]) -> None:
        """
        Args:
            label_list (List[str]): list of internal labels
            num_epochs (int): number of epochs
            dataset_info (Dict[str, int]): dataset sizes of 'train' and 'val'
        """
        self.label_list = label_list
        self.num_epochs = num_epochs
        self.dataset_info = dataset_info
        # Add a special label 'total' to store the total of the losses of all labels.
        self.label_losses = {label_name: LabelLoss() for label_name in self.label_list + ['total']}

    def store(self, phase: str, losses: Dict[str, torch.FloatTensor], batch_size: int = None) -> None:
        """
        Accumulate the label-wise batch losses of the phase.

        Args:
            phase (str): 'train' or 'val'
            losses (Dict[str, torch.FloatTensor]): loss for each label calculated by the criterion
            batch_size (int): batch size

        Note:
            losses['total'] is already the total of the losses of all labels, which is
            calculated in criterion.py; therefore, it only needs to be multiplied by
            batch_size. This is done in store_batch_loss().
        """
        for label_name in self.label_list + ['total']:
            _new_batch_loss = losses[label_name]
            self.label_losses[label_name].store_batch_loss(phase, _new_batch_loss, batch_size)

    def cal_epoch_loss(self, at_epoch: int = None) -> None:
        """
        Calculate the epoch loss of each phase all at once.

        Args:
            at_epoch (int): epoch number
        """
        # For each label
        for label_name in self.label_list:
            for phase in ['train', 'val']:
                _batch_loss = self.label_losses[label_name].get_loss(phase, 'batch')
                _dataset_size = self.dataset_info[phase]
                _new_epoch_loss = _batch_loss / _dataset_size
                self.label_losses[label_name].append_epoch_loss(phase, _new_epoch_loss)

        # For 'total', average over dataset_size and the number of labels.
        for phase in ['train', 'val']:
            _batch_loss = self.label_losses['total'].get_loss(phase, 'batch')
            _dataset_size = self.dataset_info[phase]
            _new_epoch_loss = _batch_loss / (_dataset_size * len(self.label_list))
            self.label_losses['total'].append_epoch_loss(phase, _new_epoch_loss)

        # Update best_val_loss and best_epoch.
        for label_name in self.label_list + ['total']:
            self.label_losses[label_name].update_best_val_loss(at_epoch=at_epoch)

        # Initialize batch_loss after calculating the epoch loss.
        for label_name in self.label_list + ['total']:
            self.label_losses[label_name].train_batch_loss = 0.0
            self.label_losses[label_name].val_batch_loss = 0.0
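
    # For example, assuming dataset_info == {'train': 100, 'val': 50} and three
    # labels, the 'total' train epoch loss is the accumulated 'total' batch loss
    # divided by 100 * 3, which keeps it on the same scale as a per-label loss.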

    def is_val_loss_updated(self) -> bool:
        """
        Check if val_loss of 'total' is updated.

        Returns:
            bool: updated or not
        """
        return self.label_losses['total'].is_val_loss_updated

    def get_best_epoch(self) -> int:
        """
        Return the best epoch of 'total'.

        Returns:
            int: best epoch
        """
        return self.label_losses['total'].best_epoch

    def print_epoch_loss(self, at_epoch: int = None) -> None:
        """
        Print train_loss and val_loss of 'total' for the given epoch.

        Args:
            at_epoch (int): epoch number
        """
        train_epoch_loss = self.label_losses['total'].get_latest_epoch_loss('train')
        val_epoch_loss = self.label_losses['total'].get_latest_epoch_loss('val')
        _epoch_comm = f"epoch [{at_epoch:>3}/{self.num_epochs:<3}]"
        _train_comm = f"train_loss: {train_epoch_loss:>8.4f}"
        _val_comm = f"val_loss: {val_epoch_loss:>8.4f}"
        _updated_comment = ''
        if (at_epoch > 1) and (self.is_val_loss_updated()):
            _updated_comment = ' Updated best val_loss!'
        comment = _epoch_comm + ', ' + _train_comm + ', ' + _val_comm + _updated_comment
        logger.info(comment)
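
    # A typical line emitted by print_epoch_loss (the values are made up):
    #   epoch [  2/10 ], train_loss:   0.6931, val_loss:   0.5108 Updated best val_loss!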

    def save_learning_curve(self, save_datetime_dir: str) -> None:
        """
        Save the learning curve of each label as a CSV file.

        Args:
            save_datetime_dir (str): directory in which the learning curves are saved
        """
        save_dir = Path(save_datetime_dir, 'learning_curve')
        save_dir.mkdir(parents=True, exist_ok=True)
        for label_name in self.label_list + ['total']:
            _label_loss = self.label_losses[label_name]
            _train_epoch_loss = _label_loss.get_loss('train', 'epoch')
            _val_epoch_loss = _label_loss.get_loss('val', 'epoch')
            df_label_epoch_loss = pd.DataFrame({
                'train_loss': _train_epoch_loss,
                'val_loss': _val_epoch_loss
            })
            _best_epoch = str(_label_loss.best_epoch).zfill(3)
            _best_val_loss = f"{_label_loss.best_val_loss:.4f}"
            save_name = 'learning_curve_' + label_name + '_val-best-epoch-' + _best_epoch + '_val-best-loss-' + _best_val_loss + '.csv'
            save_path = Path(save_dir, save_name)
            df_label_epoch_loss.to_csv(save_path, index=False)
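
    # The file name encodes the best epoch and the best val loss, e.g.
    #   learning_curve_total_val-best-epoch-003_val-best-loss-0.1234.csv
    # (the values are made up).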


def set_loss_store(label_list: List[str], num_epochs: int, dataset_info: Dict[str, int]) -> LossStore:
    """
    Construct and return a LossStore.

    Args:
        label_list (List[str]): label list
        num_epochs (int): number of epochs
        dataset_info (Dict[str, int]): dataset sizes of 'train' and 'val'

    Returns:
        LossStore: LossStore
    """
    return LossStore(label_list, num_epochs, dataset_info)
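

if __name__ == '__main__':
    # A minimal smoke test, as a sketch only: the label names, dataset sizes,
    # and loss values below are made up, and 'total' is pre-summed by hand the
    # way criterion.py is said to do it. Because of the relative import of
    # BaseLogger above, this module has to be run in its package context,
    # e.g. `python -m <package>.loss`.
    loss_store = set_loss_store(
        label_list=['label_A', 'label_B'],
        num_epochs=2,
        dataset_info={'train': 8, 'val': 4},
    )
    for epoch in range(1, 3):
        # One batch covers each split here, so batch_size equals the dataset size.
        for phase, batch_size in [('train', 8), ('val', 4)]:
            fake_losses = {
                'label_A': torch.tensor(1.0 / epoch),
                'label_B': torch.tensor(2.0 / epoch),
                'total': torch.tensor(3.0 / epoch),  # sum of the label losses
            }
            loss_store.store(phase, fake_losses, batch_size=batch_size)
        loss_store.cal_epoch_loss(at_epoch=epoch)
        loss_store.print_epoch_loss(at_epoch=epoch)
    print('best epoch:', loss_store.get_best_epoch())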