cartoonize / utils /common.py
YANGYYYY's picture
Upload 8 files
922e494 verified
raw
history blame contribute delete
No virus
4.87 kB
import torch
import gc
import os
import torch.nn as nn
import urllib.request
import cv2
from tqdm import tqdm
# Prefixes marking a path as something `read_image` must fetch with urllib
# (a web URL or an inline JPEG data URI) before OpenCV can decode it.
HTTP_PREFIXES = ['http', 'data:image/jpeg']
# Each released generator is described by a (version tag, download URL) pair.
# Trained with Google Landmark micro as the real-photo dataset.
_HAYAO_V2 = (
    "v2",
    "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.1/GeneratorV2_gldv2_Hayao.pt",
)
_HAYAO_V1 = (
    "v1",
    "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth",
)
_SHINKAI_V1 = (
    "v1",
    "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth",
)

# Lookup table from a lower-case weight name to its (version tag, URL) pair.
# A bare style name ("hayao", "shinkai") is an alias of that style's v1 entry.
RELEASED_WEIGHTS = {
    "hayao:v2": _HAYAO_V2,
    "hayao:v1": _HAYAO_V1,
    "hayao": _HAYAO_V1,
    "shinkai:v1": _SHINKAI_V1,
    "shinkai": _SHINKAI_V1,
}
def is_image_file(path):
    """
    Tell whether `path` ends in a supported image extension
    (.png, .jpg or .jpeg, compared case-insensitively).
    """
    extension = os.path.splitext(path)[1].lower()
    return extension in (".png", ".jpg", ".jpeg")
def read_image(path):
    """
    Read an image from a local path, a web URL or a JPEG data URI.

    Sources starting with one of HTTP_PREFIXES are first fetched into a
    private temporary file, which is removed again once decoding is done,
    so concurrent calls do not overwrite each other and nothing is left
    behind in the working directory.

    Args:
        path: local file path, URL or data URI of the image.

    Returns:
        The decoded image with its last axis reversed (OpenCV decodes to
        BGR channel order, so the result is RGB).

    Raises:
        ValueError: if OpenCV cannot decode the image.
    """
    # Local import: the module is only needed by this function.
    import tempfile

    source = path
    temp_path = None
    if path.startswith(tuple(HTTP_PREFIXES)):
        # mkstemp gives a unique file name; close the descriptor right away
        # because urlretrieve re-opens the file by name.
        fd, temp_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)

    try:
        if temp_path is not None:
            urllib.request.urlretrieve(source, temp_path)
            path = temp_path
        image = cv2.imread(path)
    finally:
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)

    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        # Data URIs can be huge, so only show the start of the source.
        raise ValueError(f"Could not read image from {source[:100]!r}")

    return image[:, :, ::-1]
def save_checkpoint(model, path, optimizer=None, epoch=None):
    """
    Serialize a training checkpoint to `path` with torch.save.

    The saved dict always holds 'model_state_dict' and 'epoch';
    'optimizer_state_dict' is added only when an optimizer is given.
    """
    payload = dict(model_state_dict=model.state_dict(), epoch=epoch)
    if optimizer is not None:
        payload['optimizer_state_dict'] = optimizer.state_dict()

    torch.save(payload, path)
def maybe_remove_module(state_dict):
    """
    Return a copy of `state_dict` whose keys have one leading 'module.'
    prefix removed; all other keys are kept unchanged.

    The prefix is added to parameter names in DDP training, see
    https://discuss.pytorch.org/t/why-are-state-dict-keys-getting-prepended-with-the-string-module/104627/3
    """
    prefix = 'module.'
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }
def load_checkpoint(model, path, optimizer=None, strip_optimizer=False, map_location=None) -> int:
    """
    Load a checkpoint into `model` (and optionally `optimizer`).

    Args:
        model: module receiving the weights; keys must match exactly
            (strict loading) once a leading 'module.' prefix is removed.
        path: checkpoint file, or a name listed in RELEASED_WEIGHTS.
        optimizer: if given and the checkpoint holds an optimizer state,
            that state is loaded into it.
        strip_optimizer: if True and the checkpoint holds an optimizer state,
            the state is dropped and the checkpoint is saved back to `path`.
        map_location: forwarded to load_state_dict; None selects the device
            automatically.

    Returns:
        The epoch stored in the checkpoint, or 0 when it is missing or None.
    """
    state_dict = load_state_dict(path, map_location)
    model_state_dict = maybe_remove_module(state_dict['model_state_dict'])
    model.load_state_dict(
        model_state_dict,
        strict=True
    )

    if 'optimizer_state_dict' in state_dict:
        if optimizer is not None:
            optimizer.load_state_dict(state_dict['optimizer_state_dict'])

        if strip_optimizer:
            del state_dict["optimizer_state_dict"]
            torch.save(state_dict, path)
            print(f"Optimizer stripped and saved to {path}")

    # save_checkpoint writes epoch=None when no epoch is passed, so the key
    # can exist with a None value; normalise that to 0 to honour `-> int`.
    epoch = state_dict.get('epoch')
    return 0 if epoch is None else epoch
def load_state_dict(weight, map_location) -> dict:
    """
    Load a checkpoint dict from a file path or a released weight name.

    A name found in RELEASED_WEIGHTS (case-insensitive) is resolved to a
    local file via _download_weight first. When `map_location` is None,
    'cuda' is used if available, otherwise 'cpu'.
    """
    # NOTE(review): torch.load unpickles the file, so only load weights
    # from sources you trust.
    name = weight.lower()
    source = _download_weight(name) if name in RELEASED_WEIGHTS else weight

    device = map_location
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    return torch.load(source, map_location=device)
def initialize_weights(net):
    """
    Re-initialize the parameters of every supported layer in `net` in place.

    - Conv2d / ConvTranspose2d / Linear: Xavier-uniform weight, zero bias.
    - BatchNorm2d: weight filled with 1, zero bias.

    Missing parameters (layers built with bias=False, or BatchNorm2d with
    affine=False) are skipped explicitly. Any unexpected failure on a layer
    is reported as a warning and that layer is skipped, so initialization
    stays best-effort without hiding errors.
    """
    # Local import: the module is only needed by this function.
    import warnings

    xavier_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    for m in net.modules():
        use_xavier = isinstance(m, xavier_layers)
        if not use_xavier and not isinstance(m, nn.BatchNorm2d):
            continue

        try:
            if m.weight is not None:
                if use_xavier:
                    torch.nn.init.xavier_uniform_(m.weight)
                else:
                    m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
        except Exception as e:
            # Previously swallowed silently; keep going but say what happened.
            warnings.warn(f'Skip weight initialization for {type(m).__name__}: {e}')
def set_lr(optimizer, lr):
    """
    Overwrite the learning rate of every parameter group of `optimizer`.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
class DownloadProgressBar(tqdm):
    '''
    tqdm bar that can be driven by a urlretrieve-style report hook.

    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
    '''

    def update_to(self, b=1, bsize=1, tsize=None):
        '''
        Move the bar to `b * bsize` units done.

        b: number of blocks transferred so far
        bsize: size of each block
        tsize: total size; when not None it becomes the bar's total
        '''
        if tsize is not None:
            self.total = tsize

        done = b * bsize
        # update() expects an increment, so subtract what is already counted.
        self.update(done - self.n)
def _download_weight(weight):
    '''
    Download a released weight and save it to a local file.

    The file is cached under `.cache/` using the file name from the URL; if
    it is already there, no download happens. The download is written to a
    temporary sibling file and only moved to the final path once complete,
    so an interrupted download never leaves a truncated file that would be
    mistaken for a valid cache entry later.

    weight: key of RELEASED_WEIGHTS (KeyError if unknown)

    Returns the local path of the weight file.
    '''
    os.makedirs('.cache', exist_ok=True)
    url = RELEASED_WEIGHTS[weight][1]
    filename = os.path.basename(url)
    save_path = f'.cache/{filename}'

    if os.path.isfile(save_path):
        return save_path

    # The pid keeps processes that download the same weight at the same time
    # from writing into each other's partial file.
    partial_path = f'{save_path}.{os.getpid()}.part'
    desc = f'Downloading {url} to {save_path}'
    try:
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
            urllib.request.urlretrieve(url, partial_path, reporthook=t.update_to)
        # Replaces the destination in one step.
        os.replace(partial_path, save_path)
    finally:
        # Still present only if the download or the move failed.
        if os.path.isfile(partial_path):
            os.remove(partial_path)

    return save_path