|
from fastai.core import * |
|
from fastai.torch_core import * |
|
from fastai.vision import * |
|
from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand |
|
|
|
# Keyword arguments that `_conv` forwards to every `conv_layer` call:
# a leaky slope of 0.2 and spectral normalisation.
_conv_args = {"leaky": 0.2, "norm_type": NormType.Spectral}
|
|
|
|
|
def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
    """Build a `conv_layer` from `ni` to `nf` channels using the shared `_conv_args`.

    `ks` and `stride` are passed through as keywords; any extra `kwargs` are
    forwarded to `conv_layer` alongside the module-level `_conv_args`.
    """
    layer = conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
    return layer
|
|
|
|
|
def custom_gan_critic(
    n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: float = 0.15
):
    """Critic to train a `GAN`.

    Assembles an `nn.Sequential` from `_conv` layers interleaved with
    `nn.Dropout2d`:

    1. a stem: a `ks=4`, `stride=2` conv from `n_channels` to `nf`, followed
       by dropout at half the base rate (`p / 2`);
    2. `n_blocks` stages, each made of a `ks=3`, `stride=1` conv, dropout at
       `p`, and a `ks=4`, `stride=2` conv that doubles the channel count
       (only the first stage requests `self_attention`);
    3. a head: a `ks=3`, `stride=1` conv, a `ks=4` conv down to a single
       channel with `bias=False`, `padding=0` and `use_activ=False`, and
       finally `Flatten()`.

    Args:
        n_channels: Number of input channels given to the first conv.
        nf: Channel count produced by the stem; doubled after every stage.
        n_blocks: Number of stages between the stem and the head.
        p: Dropout probability used inside each stage (the stem uses `p / 2`).

    Returns:
        The assembled layers wrapped in `nn.Sequential`.
    """
    layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
    for i in range(n_blocks):
        layers += [
            _conv(nf, nf, ks=3, stride=1),
            nn.Dropout2d(p),
            # Self-attention is requested on the first stage only.
            _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
        ]
        # The next stage (and the head) consume the doubled width.
        nf *= 2
    layers += [
        _conv(nf, nf, ks=3, stride=1),
        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten(),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
def colorize_crit_learner(
    data: ImageDataBunch,
    loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
    nf: int = 256,
) -> Learner:
    """Create a `Learner` that trains a `custom_gan_critic` on `data`.

    The critic is built with `nf` base filters and trained with
    `loss_critic` as its loss function, `accuracy_thresh_expand` as its
    metric and a weight decay of 1e-3.

    Note: the default `loss_critic` is constructed once, when this function
    is defined, so calls that rely on the default share that one instance.
    """
    critic = custom_gan_critic(nf=nf)
    learner_kwargs = dict(
        metrics=accuracy_thresh_expand,
        loss_func=loss_critic,
        wd=1e-3,
    )
    return Learner(data, critic, **learner_kwargs)
|
|