|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from yolox.utils import adjust_status, freeze_module |
|
from yolox.exp import get_exp |
|
|
|
|
|
class TestModelUtils(unittest.TestCase):
    """Tests for the ``adjust_status`` context manager and ``freeze_module``."""

    def setUp(self):
        # A fresh YOLOX-S model per test: the tests below toggle train/eval
        # flags and freeze parameters in place, so state must not leak.
        self.model: nn.Module = get_exp(exp_name="yolox-s").get_model()

    @staticmethod
    def _snapshot_state(module: nn.Module) -> dict:
        """Return an independent copy of ``module.state_dict()``.

        ``state_dict()`` holds references to the module's live parameters and
        buffers. BatchNorm updates its running statistics in place, so an
        un-cloned "before" dict would mutate together with the module and any
        before/after comparison would trivially report "unchanged".
        """
        return {k: v.clone() for k, v in module.state_dict().items()}

    @staticmethod
    def _states_equal(before: dict, after: dict) -> bool:
        """Return True if every tensor in ``before`` is allclose to ``after[key]``."""
        return all(torch.allclose(v, after[k]) for k, v in before.items())

    def test_model_state_adjust_status(self):
        """adjust_status must control whether a forward pass mutates BN state."""
        data = torch.ones(1, 10, 10, 10)
        # BatchNorm changes its running statistics only in training mode,
        # which makes it a convenient probe for the effective mode.
        model = nn.BatchNorm2d(10)
        prev_state = self._snapshot_state(model)

        # (training flag passed to adjust_status, state expected unchanged?)
        cases = [(False, True), (True, False)]
        for mode, expect_unchanged in cases:
            with self.subTest(training=mode):
                with adjust_status(model, training=mode):
                    model(data)
                model_state = model.state_dict()
                self.assertEqual(len(model_state), len(prev_state))
                self.assertEqual(
                    expect_unchanged,
                    self._states_equal(prev_state, model_state),
                )

        # Nested eval-mode contexts must not update the state either.
        prev_state = self._snapshot_state(model)
        with adjust_status(model, training=False):
            with adjust_status(model, training=False):
                model(data)
        model_state = model.state_dict()
        self.assertEqual(len(model_state), len(prev_state))
        self.assertTrue(self._states_equal(prev_state, model_state))

    def test_model_effect_adjust_status(self):
        """adjust_status switches every submodule and restores prior flags on exit."""
        # Case 1: whole model in training mode.
        self.model.train()
        with adjust_status(self.model, training=False):
            for module in self.model.modules():
                self.assertFalse(module.training)

        for module in self.model.modules():
            self.assertTrue(module.training)

        # Case 2: mixed flags (backbone already in eval) must be restored as-is.
        self.model.backbone.eval()
        with adjust_status(self.model, training=False):
            for module in self.model.modules():
                self.assertFalse(module.training)

        for name, module in self.model.named_modules():
            # named_modules() yields dotted paths; everything under the
            # top-level ``backbone`` attribute starts with "backbone".
            if name.startswith("backbone"):
                self.assertFalse(module.training)
            else:
                self.assertTrue(module.training)

    def test_freeze_module(self):
        """freeze_module must stop BN state updates and freeze named submodules."""
        model = nn.Sequential(
            nn.Conv2d(3, 10, 1),
            nn.BatchNorm2d(10),
            nn.ReLU(),
        )
        data = torch.rand(1, 3, 10, 10)
        model.train()
        bn = model[1]
        self.assertIsInstance(bn, nn.BatchNorm2d)

        # Snapshot (clone) before the forward pass; otherwise the "before"
        # tensors would alias the live buffers and this check could not fail.
        before_states = self._snapshot_state(bn)
        freeze_module(bn)
        model(data)
        after_states = bn.state_dict()
        self.assertTrue(self._states_equal(before_states, after_states))

        # Freezing by name on the full model: only the backbone is affected.
        self.model.train()
        for module in self.model.modules():
            self.assertTrue(module.training)

        freeze_module(self.model, "backbone")
        for module in self.model.backbone.modules():
            self.assertFalse(module.training)
        for p in self.model.backbone.parameters():
            self.assertFalse(p.requires_grad)

        for module in self.model.head.modules():
            self.assertTrue(module.training)
        for p in self.model.head.parameters():
            self.assertTrue(p.requires_grad)
|
|
|
|
|
if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in this module.
    unittest.main(module="__main__")
|
|