|
from torch.nn.modules.batchnorm import BatchNorm2d |
|
from torchvision.ops.misc import FrozenBatchNorm2d |
|
|
|
import timm |
|
from timm.utils.model import freeze, unfreeze |
|
|
|
|
|
def test_freeze_unfreeze():
    """Exercise timm's ``freeze``/``unfreeze`` on a ResNet-18.

    Covers four cases:

    * freezing/unfreezing the whole model;
    * freezing/unfreezing a list of named submodules (``'layer1'`` and
      ``'layer2.0'``), checking that untouched siblings are left trainable;
    * targeting a single BatchNorm layer by dotted name from the root model;
    * targeting a BatchNorm layer by name relative to a child module.

    "Frozen" is checked two ways: parameters have ``requires_grad`` disabled,
    and ``BatchNorm2d`` layers are swapped for torchvision's
    ``FrozenBatchNorm2d`` (and swapped back on unfreeze).
    """
    model = timm.create_model('resnet18')

    # --- Whole model ---------------------------------------------------------
    freeze(model)
    assert not model.fc.weight.requires_grad
    assert not model.layer1[0].conv1.weight.requires_grad
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)

    unfreeze(model)
    assert model.fc.weight.requires_grad
    assert model.layer1[0].conv1.weight.requires_grad
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)

    # --- Named submodules ----------------------------------------------------
    freeze(model, ['layer1', 'layer2.0'])
    # The requested submodules are frozen, including their batch-norm layers...
    assert not model.layer1[0].conv1.weight.requires_grad
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    assert not model.layer2[0].conv1.weight.requires_grad
    assert isinstance(model.layer2[0].bn1, FrozenBatchNorm2d)
    # ...while modules outside the requested set stay trainable, including a
    # sibling block (layer2.1) of an explicitly frozen one (layer2.0).
    assert model.layer3[0].conv1.weight.requires_grad
    assert isinstance(model.layer3[0].bn1, BatchNorm2d)
    assert model.layer2[1].conv1.weight.requires_grad
    assert isinstance(model.layer2[1].bn1, BatchNorm2d)

    unfreeze(model, ['layer1', 'layer2.0'])
    assert model.layer1[0].conv1.weight.requires_grad
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
    assert model.layer2[0].conv1.weight.requires_grad
    assert isinstance(model.layer2[0].bn1, BatchNorm2d)

    # --- A single BatchNorm layer addressed by dotted path from the root -----
    freeze(model, ['layer1.0.bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model, ['layer1.0.bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)

    # --- The same layer addressed relative to a child module -----------------
    freeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)