| |
| import tempfile |
| from unittest import TestCase |
| from unittest.mock import Mock |
|
|
| import torch |
| import torch.nn as nn |
| from mmengine.evaluator import Evaluator |
| from mmengine.model import BaseModel |
| from mmengine.optim import OptimWrapper |
| from mmengine.runner import Runner |
| from torch.utils.data import Dataset |
|
|
| from mmdet.registry import DATASETS |
| from mmdet.utils import register_all_modules |
|
|
# Populate the mmdet registries at import time so that string ``type=``
# names used in the Runner config below (e.g. 'TeacherStudentValLoop' under
# default_scope='mmdet') can presumably be resolved — TODO confirm which
# registries this touches.
register_all_modules()
|
|
|
|
class ToyModel(nn.Module):
    """Minimal single-layer regressor used as a stand-in model.

    Each 2-element input vector is mapped to one value by ``nn.Linear(2, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, inputs, data_samples, mode='tensor'):
        """Apply the linear layer to a batch of samples.

        Args:
            inputs: Sequence of per-sample input tensors, stacked along a new
                leading batch dimension.
            data_samples: Sequence of per-sample label tensors, stacked the
                same way. This is done for every ``mode``, not only 'loss'.
            mode: 'loss' returns ``{'loss': <scalar tensor>}``; any other
                value returns the raw linear outputs.

        Returns:
            The linear outputs, or a dict holding the summed residual.
        """
        targets = torch.stack(data_samples)
        batch = torch.stack(inputs)
        preds = self.linear(batch)
        if mode != 'loss':
            return preds
        # NOTE(review): with 0-d labels ``targets`` is (B,) while ``preds`` is
        # (B, 1), so the subtraction broadcasts to (B, B) before summing.
        # That is harmless for a smoke test, but it is not a per-sample residual.
        return dict(loss=(targets - preds).sum())
|
|
|
|
class ToyModel1(BaseModel, ToyModel):
    """``ToyModel`` exposed through mmengine's ``BaseModel`` interface.

    ``BaseModel`` comes first in the bases. ``ToyModel.forward`` relies on
    ``self.linear``, so the ``super().__init__()`` chain is presumably
    cooperative and reaches ``ToyModel.__init__`` — TODO confirm against
    mmengine's BaseModel/BaseModule.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        """Run ``ToyModel.forward`` directly, bypassing ``BaseModel.forward``.

        ``BaseModel`` precedes ``ToyModel`` in the MRO. Naming the
        implementation explicitly skips ``BaseModel``'s version.
        """
        return ToyModel.forward(self, *args, **kwargs)
|
|
|
|
class ToyModel2(BaseModel):
    """Teacher/student container shaped like a semi-supervised model.

    It exposes ``teacher``, ``student`` and a ``semi_test_cfg`` dict whose
    'predict_on' entry starts as 'teacher'. ``forward`` always runs the
    student and never reads ``semi_test_cfg``.
    """

    def __init__(self):
        super().__init__()
        # Keep the registration order (teacher first, then student) so that
        # parameter and state_dict ordering are unchanged.
        self.teacher = ToyModel1()
        self.student = ToyModel1()
        self.semi_test_cfg = {'predict_on': 'teacher'}

    def forward(self, *args, **kwargs):
        """Pass every argument straight through to the student branch."""
        student = self.student
        return student(*args, **kwargs)
|
|
|
|
@DATASETS.register_module(force=True)
class DummyDataset(Dataset):
    """Twelve random 2-D points, each labelled ``1.0``.

    The tensors are class attributes. They are sampled once when the class is
    defined and shared by every instance, so the train and val loaders see the
    same data. ``force=True`` presumably lets the registration overwrite an
    existing 'DummyDataset' entry — TODO confirm against mmengine Registry.
    """

    METAINFO = {}
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Return the (empty) class-level ``METAINFO`` dict."""
        return self.METAINFO

    def __len__(self):
        # Number of rows in ``data``; len() of a tensor is its first dim size.
        return len(self.data)

    def __getitem__(self, index):
        sample = {'inputs': self.data[index]}
        sample['data_samples'] = self.label[index]
        return sample
|
|
|
|
class TestTeacherStudentValLoop(TestCase):
    """Smoke test for the 'TeacherStudentValLoop' driven by a real Runner."""

    def setUp(self):
        # The Runner writes logs into work_dir; give each test its own dir.
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_teacher_student_val_loop(self):
        """Train for 2 epochs with validation after each, then check the loop.

        The evaluator is a ``Mock`` whose ``evaluate`` returns a constant
        metric dict. Setting ``__class__`` makes ``isinstance(evaluator,
        Evaluator)`` true, so it is presumably accepted as already built.
        """
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model = ToyModel2().to(device)
        evaluator = Mock()
        evaluator.evaluate = Mock(return_value=dict(acc=0.5))
        evaluator.__class__ = Evaluator
        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            work_dir=self.temp_dir.name,
            default_scope='mmdet',
            # Optimise the model under test. Building the optimizer over a
            # fresh, unrelated ToyModel() meant training never updated it.
            optim_wrapper=OptimWrapper(torch.optim.Adam(model.parameters())),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
            val_cfg=dict(type='TeacherStudentValLoop'),
            default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
            experiment_name='test1')
        runner.train()

        # Two validation rounds (max_epochs=2, val_interval=1). Each round
        # should evaluate both the teacher and the student branch.
        self.assertEqual(evaluator.evaluate.call_count, 4)
        # The loop switches 'predict_on' while validating and should restore
        # the model's original value afterwards.
        self.assertEqual(model.semi_test_cfg['predict_on'], 'teacher')
|
|