| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from ppdet.core.workspace import create |
| | from ppdet.utils.logger import setup_logger |
| | logger = setup_logger('ppdet.engine') |
| |
|
| | from . import Trainer |
| | __all__ = ['TrainerCot'] |
| |
|
class TrainerCot(Trainer):
    """
    Trainer for label co-tuning.

    Besides the regular ``Trainer`` setup, it calculates the relationship
    between base classes and novel classes and uses that relationship to
    initialise the model's co-tuning head.
    """

    def __init__(self, cfg, mode='train'):
        """Run the base ``Trainer`` initialisation, then the co-tuning setup.

        Args:
            cfg: configuration object; forwarded unchanged to ``Trainer``.
            mode: trainer mode; forwarded unchanged to ``Trainer``.
        """
        super().__init__(cfg, mode)
        self.cotuning_init()

    def cotuning_init(self):
        """Prepare the model for co-tuning.

        Steps, in order: read ``cfg['num_classes']``, load
        ``cfg.pretrain_weights``, switch the model to eval mode, ask the
        model for the relationship computed over ``self.loader``, pass that
        relationship to ``init_cot_head`` and finally rebuild
        ``self.optimizer``.
        """
        # The class count is read before any weights are touched.
        novel_class_count = self.cfg['num_classes']

        self.load_weights(self.cfg.pretrain_weights)

        # The relationship is computed with the model in eval mode.
        self.model.eval()
        class_relationship = self.model.relationship_learning(
            self.loader, novel_class_count)

        self.model.init_cot_head(class_relationship)

        # The optimizer is (re)created after the co-tuning head is initialised.
        build_optimizer = create('OptimizerBuilder')
        self.optimizer = build_optimizer(self.lr, self.model)
| |
|
| |
|
| |
|