Spaces:
Runtime error
Runtime error
from fastai.core import * | |
from fastai.torch_core import * | |
from fastai.vision import * | |
from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand | |
# Keyword arguments applied to every conv layer built through `_conv`.
_conv_args = {"leaky": 0.2, "norm_type": NormType.Spectral}
def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
    """Build a `conv_layer` from `ni` to `nf` channels using the module defaults.

    The defaults in `_conv_args` are forwarded to `conv_layer` together with
    `ks`, `stride` and any extra `kwargs`.
    """
    # Merge first so a caller-supplied key (e.g. `leaky`) overrides the default
    # instead of raising "got multiple values for keyword argument".
    merged = {**_conv_args, **kwargs}
    return conv_layer(ni, nf, ks=ks, stride=stride, **merged)
def custom_gan_critic(
    n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: float = 0.15
) -> nn.Sequential:
    """Critic to train a `GAN`.

    Layer layout, in order:

    * a stem conv `n_channels -> nf` (ks=4, stride=2) followed by
      `Dropout2d(p / 2)`;
    * `n_blocks` repetitions of: conv `nf -> nf` (ks=3, stride=1),
      `Dropout2d(p)`, conv `nf -> 2 * nf` (ks=4, stride=2); `nf` doubles after
      each repetition and only the first one requests `self_attention`;
    * a head made of conv `nf -> nf` (ks=3, stride=1), conv `nf -> 1`
      (ks=4, no bias, no padding, no activation) and `Flatten()`.

    Args:
        n_channels: number of input channels.
        nf: number of filters produced by the stem conv.
        n_blocks: how many times the filter-doubling block is repeated.
        p: dropout probability inside the blocks (the stem uses half of it).

    Returns:
        The layers wrapped in an `nn.Sequential`.
    """
    layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
    for i in range(n_blocks):
        layers += [
            _conv(nf, nf, ks=3, stride=1),
            nn.Dropout2d(p),
            # Self-attention is requested on the first block only.
            _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
        ]
        nf *= 2
    layers += [
        _conv(nf, nf, ks=3, stride=1),
        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten(),
    ]
    return nn.Sequential(*layers)
def colorize_crit_learner(
    data: ImageDataBunch,
    loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
    nf: int = 256,
) -> Learner:
    """Wrap a `custom_gan_critic` model in a `Learner` for `data`.

    Args:
        data: data bunch handed to the `Learner`.
        loss_critic: loss function for the critic. The default object is
            created once, when this function is defined, and reused by every
            call that does not pass its own.
        nf: base filter count forwarded to `custom_gan_critic`.

    Returns:
        A `Learner` using `accuracy_thresh_expand` as its metric and a weight
        decay of 1e-3.
    """
    critic = custom_gan_critic(nf=nf)
    learner_kwargs = {
        "metrics": accuracy_thresh_expand,
        "loss_func": loss_critic,
        "wd": 1e-3,
    }
    return Learner(data, critic, **learner_kwargs)