Spaces:
Runtime error
Runtime error
import os | |
import glob | |
import time | |
import numpy as np | |
from PIL import Image | |
from pathlib import Path | |
from tqdm.notebook import tqdm | |
import matplotlib.pyplot as plt | |
from skimage.color import rgb2lab, lab2rgb | |
import torch | |
from torch import nn, optim | |
from torchvision import transforms | |
from torchvision.utils import make_grid | |
from torch.utils.data import Dataset, DataLoader | |
# Compute device for the model and tensors: the GPU when CUDA is usable,
# otherwise the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
def init_weights(net, init="norm", gain=0.02):
    """Initialise convolution and 2-D batch-norm parameters of ``net`` in place.

    Every module reached by ``net.apply`` (``net`` itself included) is
    inspected by class name:

    * names containing ``"Conv"`` (e.g. ``Conv2d``, ``ConvTranspose2d``) get
      their weight drawn according to ``init`` and their bias, if present,
      set to zero;
    * names containing ``"BatchNorm2d"`` get their weight drawn from
      ``N(1, gain)`` and their bias set to zero. Parameters that are ``None``
      (e.g. ``BatchNorm2d(affine=False)``) are skipped.

    Args:
        net: The ``nn.Module`` whose parameters are overwritten.
        init: Scheme for convolution weights. ``"norm"`` uses a normal
            distribution with mean 0 and std ``gain``. ``"xavier"`` uses
            Xavier normal with ``gain``. ``"kaiming"`` uses Kaiming normal
            with ``a=0`` and fan-in mode, and ignores ``gain``.
        gain: Std or gain used by the ``"norm"`` and ``"xavier"`` schemes,
            and the std of the batch-norm weights.

    Returns:
        ``net`` itself, so the call can be chained.

    Raises:
        ValueError: If ``init`` is not one of the supported schemes.
    """
    if init not in ("norm", "xavier", "kaiming"):
        # Fail fast. Previously an unknown scheme silently left conv weights
        # untouched while still printing a success message.
        raise ValueError(
            f"unknown init method {init!r}; "
            "expected 'norm', 'xavier' or 'kaiming'"
        )

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, "weight") and "Conv" in classname:
            if m.weight is not None:
                if init == "norm":
                    nn.init.normal_(m.weight, mean=0.0, std=gain)
                elif init == "xavier":
                    nn.init.xavier_normal_(m.weight, gain=gain)
                else:  # "kaiming" (validated above)
                    nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0.0)
        elif "BatchNorm2d" in classname:
            # A non-affine batch norm registers weight/bias as None.
            if getattr(m, "weight", None) is not None:
                nn.init.normal_(m.weight, 1.0, gain)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0.0)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net