import gc
import os
import urllib.request

import cv2
import torch
import torch.nn as nn
from tqdm import tqdm

# Path prefixes that make `read_image` download the image instead of reading
# it from the local filesystem.
HTTP_PREFIXES = [
    'http',
    'data:image/jpeg',
]

# Map of released-weight key -> (version tag, download URL).
RELEASED_WEIGHTS = {
    "hayao:v2": (
        # Dataset trained on Google Landmark micro as training real photo
        "v2",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.1/GeneratorV2_gldv2_Hayao.pt"
    ),
    "hayao:v1": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
    ),
    "hayao": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
    ),
    "shinkai:v1": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
    ),
    "shinkai": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
    ),
}


def is_image_file(path):
    """Return True if `path` has a .png, .jpg or .jpeg extension (case-insensitive)."""
    _, ext = os.path.splitext(path)
    return ext.lower() in (".png", ".jpg", ".jpeg")


def read_image(path):
    """Read an image from a local path or URL and return it in RGB channel order.

    If `path` starts with one of `HTTP_PREFIXES`, it is first downloaded to
    ``temp.jpg`` in the current working directory and read from there.

    Raises:
        OSError: if OpenCV cannot read or decode the image.
    """
    source = path
    if any(path.startswith(p) for p in HTTP_PREFIXES):
        urllib.request.urlretrieve(path, "temp.jpg")
        path = "temp.jpg"
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the slice below dies with an opaque TypeError.
        raise OSError(f"Could not read image from {source!r}")
    # OpenCV loads channels as BGR; reversing the last axis yields RGB.
    return image[:, :, ::-1]


def save_checkpoint(model, path, optimizer=None, epoch=None):
    """Save the model state (and optionally the optimizer state and epoch) to `path`."""
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
    }
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()

    torch.save(checkpoint, path)


def maybe_remove_module(state_dict):
    """Return a copy of `state_dict` with any leading ``'module.'`` stripped from keys.

    Remove the "module." prefix that DDP / DataParallel training adds to
    state_dict keys, so the weights load into an unwrapped model.
    https://discuss.pytorch.org/t/why-are-state-dict-keys-getting-prepended-with-the-string-module/104627/3
    """
    module_str = 'module.'
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith(module_str):
            k = k[len(module_str):]
        new_state_dict[k] = v
    return new_state_dict


def load_checkpoint(model, path, optimizer=None, strip_optimizer=False, map_location=None) -> int:
    """Load a checkpoint into `model` (and `optimizer`, if given).

    Args:
        model: module whose weights are loaded (strict key matching).
        path: checkpoint file path, or a key of `RELEASED_WEIGHTS`.
        optimizer: optional optimizer to restore, if the checkpoint has one.
        strip_optimizer: if True and the checkpoint contains optimizer state,
            remove it and re-save the checkpoint to `path`.
        map_location: forwarded to `torch.load`; auto-selected when None.

    Returns:
        The stored epoch, or 0 when the checkpoint has no epoch (or stored None).
    """
    state_dict = load_state_dict(path, map_location)
    model_state_dict = maybe_remove_module(state_dict['model_state_dict'])
    model.load_state_dict(
        model_state_dict,
        strict=True
    )
    if 'optimizer_state_dict' in state_dict:
        if optimizer is not None:
            optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        if strip_optimizer:
            # NOTE: if `path` is a RELEASED_WEIGHTS key, this writes to a file
            # named after the key, not to the cached download.
            del state_dict["optimizer_state_dict"]
            torch.save(state_dict, path)
            print(f"Optimizer stripped and saved to {path}")

    # save_checkpoint stores epoch=None by default; honour the `-> int` contract.
    epoch = state_dict.get('epoch')
    return epoch if epoch is not None else 0


def load_state_dict(weight, map_location) -> dict:
    """Load a checkpoint dict from a file path or a `RELEASED_WEIGHTS` key.

    Released-weight keys are matched case-insensitively and downloaded on
    first use. When `map_location` is None, 'cuda' is used if available,
    otherwise 'cpu'.
    """
    if weight.lower() in RELEASED_WEIGHTS:
        weight = _download_weight(weight.lower())

    if map_location is None:
        # auto select
        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(weight, map_location=map_location)

    return state_dict


def initialize_weights(net):
    """Initialise the layers of `net` in place.

    Conv2d / ConvTranspose2d / Linear: Xavier-uniform weights, zero bias.
    BatchNorm2d: weight 1, bias 0. Parameters that are absent (e.g.
    ``bias=False`` or ``affine=False``) are skipped.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            # Layers built with bias=False have m.bias set to None.
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            # Non-affine BatchNorm has neither weight nor bias.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()


def set_lr(optimizer, lr):
    """Set the learning rate of every param group of `optimizer` to `lr`."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


class DownloadProgressBar(tqdm):
    '''
    tqdm bar exposing `update_to`, usable as a `urlretrieve` reporthook.

    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
    '''
    def update_to(self, b=1, bsize=1, tsize=None):
        """Advance the bar to `b * bsize` bytes; set the total when `tsize` is given."""
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)


def _download_weight(weight):
    '''
    Download a released weight into `.cache` (if not cached) and return its path.

    The file is fetched to a temporary ``.part`` path and only moved into
    place once complete, so an interrupted download never leaves a truncated
    file that would later be mistaken for a valid cached weight.
    '''
    os.makedirs('.cache', exist_ok=True)
    url = RELEASED_WEIGHTS[weight][1]
    filename = os.path.basename(url)
    save_path = f'.cache/{filename}'

    if os.path.isfile(save_path):
        return save_path

    tmp_path = save_path + '.part'
    desc = f'Downloading {url} to {save_path}'
    try:
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
            urllib.request.urlretrieve(url, tmp_path, reporthook=t.update_to)
        os.replace(tmp_path, save_path)
    finally:
        # Only present if the download failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return save_path