|
import torch |
|
from torch.nn import functional as F |
|
|
|
from dassl.data import DataManager |
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU |
|
from dassl.metrics import compute_accuracy |
|
from dassl.data.transforms import build_transform |
|
|
|
|
|
@TRAINER_REGISTRY.register()
class FixMatch(TrainerXU):
    """FixMatch: Simplifying Semi-Supervised Learning with
    Consistency and Confidence.

    https://arxiv.org/abs/2001.07685.

    Each training step pseudo-labels weakly augmented images with the
    current model and keeps only predictions whose max softmax probability
    reaches ``cfg.TRAINER.FIXMATCH.CONF_THRE``. The model is then trained to
    predict those pseudo labels on strongly augmented views of the same
    images, in addition to the usual supervised loss on labeled data.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Multiplier applied to the unsupervised (pseudo-label) loss term.
        self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U
        # A pseudo label is used only if its max softmax prob >= this value.
        self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE

    def check_cfg(self, cfg):
        """Ensure at least one strong augmentation is configured.

        Raises:
            AssertionError: if ``cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS``
                is missing or empty.
        """
        strong_tfms = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
        assert strong_tfms is not None and len(strong_tfms) > 0, (
            "FixMatch requires cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS to "
            "list at least one strong augmentation"
        )

    def build_data_loader(self):
        """Build data loaders that provide two augmented views per image.

        Two training transforms are passed to ``DataManager`` as
        ``custom_tfm_train``, in this order:

        1. the default training transform (the weak view);
        2. a transform built from ``cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS``
           (the strong view).

        ``parse_batch_train`` later reads the views back as ``"img"`` and
        ``"img2"``. Presumably these follow the list order above; confirm
        against ``DataManager``.
        """
        cfg = self.cfg
        tfm_weak = build_transform(cfg, is_train=True)
        tfm_strong = build_transform(
            cfg, is_train=True, choices=cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
        )
        self.dm = DataManager(cfg, custom_tfm_train=[tfm_weak, tfm_strong])

        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def assess_y_pred_quality(self, y_pred, y_true, mask):
        """Measure how accurate the pseudo labels are.

        Args:
            y_pred: predicted (pseudo) class indices.
            y_true: ground-truth class indices, comparable elementwise with
                ``y_pred``.
            mask: 0/1 mask (the caller passes a float tensor), 1 where the
                pseudo label passed the confidence threshold.

        Returns:
            dict of 0-dim tensors:
                ``"acc_thre"``: accuracy over the kept (mask == 1) predictions;
                ``"acc_raw"``: accuracy over all predictions;
                ``"keep_rate"``: fraction of predictions kept.
        """
        # Cast before dividing. On older PyTorch, integer-tensor / int
        # floor-divides, which would truncate the accuracy to 0.
        correct = y_pred.eq(y_true).float()

        n_masked_correct = (correct * mask).sum()
        # The epsilon keeps this finite when nothing passes the threshold.
        acc_thre = n_masked_correct / (mask.sum() + 1e-5)
        # Clamp denominators so an empty batch gives 0 instead of NaN.
        acc_raw = correct.sum() / max(y_pred.numel(), 1)
        keep_rate = mask.sum() / max(mask.numel(), 1)

        return {
            "acc_thre": acc_thre,
            "acc_raw": acc_raw,
            "keep_rate": keep_rate,
        }

    def forward_backward(self, batch_x, batch_u):
        """Run one FixMatch optimisation step.

        Args:
            batch_x: labeled batch with keys ``"img"``, ``"img2"``, ``"label"``.
            batch_u: unlabeled batch with the same keys. Its ``"label"`` is
                used only to report pseudo-label quality.

        Returns:
            dict of Python floats for logging: supervised loss and accuracy,
            unsupervised loss, and pseudo-label accuracy / keep rate.
        """
        input_x, input_x2, label_x, input_u, input_u2, label_u = (
            self.parse_batch_train(batch_x, batch_u)
        )
        n_x = input_x.size(0)

        # Labeled images are pseudo-labeled too. The unsupervised loss covers
        # the labeled and unlabeled batches concatenated along dim 0, with
        # the labeled rows first.
        input_weak = torch.cat([input_x, input_u], 0)
        input_strong = torch.cat([input_x2, input_u2], 0)

        # Generate pseudo labels from the weak views, without gradients.
        with torch.no_grad():
            probs_weak = F.softmax(self.model(input_weak), 1)
            max_prob, label_u_pred = probs_weak.max(1)
            mask_u = (max_prob >= self.conf_thre).float()

            # Evaluate pseudo-label quality on the unlabeled rows only
            # (index n_x onward).
            y_u_pred_stats = self.assess_y_pred_quality(
                label_u_pred[n_x:], label_u, mask_u[n_x:]
            )

        # Supervised loss on the labeled images ("img" view).
        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        # Unsupervised loss: strong views must predict the pseudo labels.
        # Masked-out samples contribute 0 but still count in the mean.
        logits_strong = self.model(input_strong)
        loss_u = F.cross_entropy(logits_strong, label_u_pred, reduction="none")
        loss_u = (loss_u * mask_u).mean()

        loss = loss_x + loss_u * self.weight_u
        self.model_backward_and_update(loss)

        # Every logged value is a plain Python float.
        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
            "y_u_pred_acc_raw": y_u_pred_stats["acc_raw"].item(),
            "y_u_pred_acc_thre": y_u_pred_stats["acc_thre"].item(),
            "y_u_pred_keep": y_u_pred_stats["keep_rate"].item(),
        }

        # update_lr runs only after the final batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        """Unpack the labeled and unlabeled batches onto ``self.device``.

        Returns:
            Tuple ``(input_x, input_x2, label_x, input_u, input_u2, label_u)``.
            ``*_x`` come from ``batch_x`` and ``*_u`` from ``batch_u``. The
            plain and ``2`` inputs are the ``"img"`` and ``"img2"`` views.
        """
        device = self.device

        input_x = batch_x["img"].to(device)
        input_x2 = batch_x["img2"].to(device)
        label_x = batch_x["label"].to(device)

        input_u = batch_u["img"].to(device)
        input_u2 = batch_u["img2"].to(device)
        # Ground truth of the unlabeled data, used only for monitoring.
        label_u = batch_u["label"].to(device)

        return input_x, input_x2, label_x, input_u, input_u2, label_u
|
|