import math

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

from ding.policy import Policy
from ding.model import model_wrap
from ding.torch_utils import to_device
from ding.utils import EasyTimer


class ImageClassificationPolicy(Policy):
    """
    Policy for supervised image classification. Only the learn and eval modes are
    implemented; the collect-related hooks at the bottom are deliberately empty,
    since training data comes from a fixed dataset rather than environment
    interaction. ``_init_learn`` expects ``cfg.learn`` to provide
    ``learning_rate``, ``weight_decay``, ``warmup_lr``, ``warmup_epoch``,
    ``decay_epoch`` and ``decay_rate``.
    """
    config = dict(
        type='image_classification',
        on_policy=False,
    )

    def _init_learn(self):
        # Plain SGD with momentum; hyper-parameters come from the learn config.
        self._optimizer = SGD(
            self._model.parameters(),
            lr=self._cfg.learn.learning_rate,
            weight_decay=self._cfg.learn.weight_decay,
            momentum=0.9
        )
        # cuda=True makes the timer synchronize CUDA, so GPU work is timed accurately.
        self._timer = EasyTimer(cuda=True)

        def lr_scheduler_fn(epoch):
            # LambdaLR multiplies the base lr by the returned factor, so the warmup
            # branch divides by the base lr to yield an absolute lr of warmup_lr.
            if epoch <= self._cfg.learn.warmup_epoch:
                return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate
            else:
                # Step decay: multiply by decay_rate once every decay_epoch epochs.
                ratio = epoch // self._cfg.learn.decay_epoch
                return math.pow(self._cfg.learn.decay_rate, ratio)
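        # Worked example with illustrative, assumed numbers (not taken from any
        # real config): learning_rate=0.1, warmup_lr=0.01, warmup_epoch=5,
        # decay_epoch=30, decay_rate=0.1. Epochs 0-5 train at 0.1 * (0.01 / 0.1)
        # = 0.01; epochs 6-29 at the full 0.1; epoch 30 at 0.1 * 0.1 = 0.01;
        # epoch 65 at 0.1 * 0.1 ** (65 // 30) = 1e-3.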

        self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn)
        # LambdaLR already applies the epoch-0 factor at construction; this extra
        # step() starts the schedule from epoch 1 before the first train iteration.
        self._lr_scheduler.step()
        # The 'base' wrapper wraps the raw nn.Module without adding any sampling
        # behaviour.
        self._learn_model = model_wrap(self._model, 'base')
        self._learn_model.reset()

        self._ce_loss = nn.CrossEntropyLoss()

    def _forward_learn(self, data):
        if self._cuda:
            data = to_device(data, self._device)
        self._learn_model.train()

        # Time the forward pass, the backward pass and the (optional) gradient
        # sync separately, so each shows up in the training logs.
        with self._timer:
            img, target = data
            logit = self._learn_model.forward(img)
            loss = self._ce_loss(logit, target)
        forward_time = self._timer.value

        with self._timer:
            self._optimizer.zero_grad()
            loss.backward()
        backward_time = self._timer.value

        with self._timer:
            if self._cfg.multi_gpu:
                self.sync_gradients(self._learn_model)
        sync_time = self._timer.value
        self._optimizer.step()

        # Average the lr over param groups (they are all equal here, but this is
        # robust to per-group schedules).
        cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups]
        cur_lr = sum(cur_lr) / len(cur_lr)
        return {
            'cur_lr': cur_lr,
            'total_loss': loss.item(),
            'forward_time': forward_time,
            'backward_time': backward_time,
            'sync_time': sync_time,
        }

    def _monitor_vars_learn(self):
        return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time']

    def _init_eval(self):
        self._eval_model = model_wrap(self._model, 'base')

    def _forward_eval(self, data):
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data)
        if self._cuda:
            output = to_device(output, 'cpu')
        return output
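
    # Hypothetical consumer of _forward_eval (the names policy/img/target are
    # assumptions, not defined in this file), assuming the wrapped model returns
    # per-class logits:
    #   logit = policy.eval_mode.forward(img)
    #   top1 = (logit.argmax(dim=1) == target).float().mean().item()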

    # Collect mode is unused for supervised learning, so these Policy hooks are
    # intentionally left as no-ops.
    def _init_collect(self):
        pass

    def _forward_collect(self, data):
        pass

    def _process_transition(self):
        pass

    def _get_train_sample(self):
        pass
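

# Standalone sketch of the warmup + step-decay schedule built in _init_learn,
# runnable without ding. All hyper-parameter values below are illustrative
# assumptions, not taken from any real config.
def _lr_schedule_demo():
    model = nn.Linear(8, 10)
    learning_rate, warmup_lr = 0.1, 0.01
    warmup_epoch, decay_epoch, decay_rate = 5, 30, 0.1
    optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9)

    def fn(epoch):
        # Same shape as lr_scheduler_fn above: constant factor during warmup,
        # then decay_rate ** (epoch // decay_epoch) afterwards.
        if epoch <= warmup_epoch:
            return warmup_lr / learning_rate
        return math.pow(decay_rate, epoch // decay_epoch)

    scheduler = LambdaLR(optimizer, fn)
    for epoch in range(100):
        # ...one training epoch would run here...
        scheduler.step()
        if epoch in (0, 5, 6, 30, 65):
            print(epoch, optimizer.param_groups[0]['lr'])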