Image Classification
Transformers
Safetensors
cetaceanet
biology
biodiversity
custom_code
MalloryWittwerEPFL's picture
Fix broken upload
ea195e7
from typing import Dict, List, Optional, Tuple
import numpy as np
import timm
import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from .config import Config, load_config
# from .dataset import WhaleDataset, load_df
from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
from .utils import WarmupCosineLambda, map_dict, topk_average_precision
class SphereClassifier(LightningModule):
    """Multi-stage pooled backbone feeding two ``ArcMarginProductSubcenter`` heads.

    A ``timm`` backbone created with ``features_only=True`` returns one feature map per
    entry of ``cfg.out_indices``. Each map is pooled by its own ``GeM`` module, and the
    pooled vectors are concatenated and normalized by ``self.neck``. The resulting
    embedding is fed to two heads: one over ``cfg.num_classes`` ids and one over
    ``cfg.num_species_classes`` species.

    Args:
        cfg: A ``Config`` instance, or a plain mapping that is wrapped into one.
        id_class_nums: Array-like used to build the adaptive id margins
            (presumably per-class sample counts — verify against the caller).
        species_class_nums: Same as ``id_class_nums``, for the species head.

    Note:
        ``margin_fn_id``/``margin_fn_species`` and the loss functions are only created
        when BOTH ``id_class_nums`` and ``species_class_nums`` are given. Without them,
        ``forward``/``validation_step``/``test_step`` work but ``training_step`` does not.
    """

    def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
        super().__init__()
        if not isinstance(cfg, Config):
            cfg = Config(cfg)
        self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
        # Destination for the arrays written by ``test_epoch_end``; the caller must set it
        # before running the test loop.
        self.test_results_fp = None

        # NN architecture
        self.backbone = timm.create_model(
            cfg.model_name,
            in_chans=3,
            pretrained=cfg.pretrained,
            num_classes=0,
            features_only=True,
            out_indices=cfg.out_indices,
        )
        feature_dims = self.backbone.feature_info.channels()
        print(f"feature dims: {feature_dims}")
        # One pooling module per selected backbone stage (zipped with the stages in get_feat).
        self.global_pools = torch.nn.ModuleList(
            [GeM(p=cfg.global_pool.p, requires_grad=cfg.global_pool.train) for _ in cfg.out_indices]
        )
        # Pooled stage outputs are concatenated, so the embedding width is the channel sum.
        self.mid_features = np.sum(feature_dims)
        if cfg.normalization == "batchnorm":
            self.neck = torch.nn.BatchNorm1d(self.mid_features)
        elif cfg.normalization == "layernorm":
            self.neck = torch.nn.LayerNorm(self.mid_features)
        else:
            # Fail fast: otherwise ``self.neck`` would be missing and the model would only
            # break later, in get_feat or configure_optimizers.
            raise ValueError(
                f"Unsupported cfg.normalization {cfg.normalization!r}; "
                "expected 'batchnorm' or 'layernorm'"
            )
        self.head_id = ArcMarginProductSubcenter(self.mid_features, cfg.num_classes, cfg.n_center_id)
        self.head_species = ArcMarginProductSubcenter(self.mid_features, cfg.num_species_classes, cfg.n_center_species)

        if id_class_nums is not None and species_class_nums is not None:
            # Adaptive margin per class: coef * nums ** power + cons.
            margins_id = np.power(id_class_nums, cfg.margin_power_id) * cfg.margin_coef_id + cfg.margin_cons_id
            margins_species = (
                np.power(species_class_nums, cfg.margin_power_species) * cfg.margin_coef_species
                + cfg.margin_cons_species
            )
            print("margins_id", margins_id)
            print("margins_species", margins_species)
            self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, cfg.num_classes, cfg.s_id)
            self.margin_fn_species = ArcFaceLossAdaptiveMargin(margins_species, cfg.num_species_classes, cfg.s_species)
            self.loss_fn_id = torch.nn.CrossEntropyLoss()
            self.loss_fn_species = torch.nn.CrossEntropyLoss()

    def get_feat(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalized embedding for a batch of images ``x``.

        Each backbone stage is pooled by its matching ``GeM`` module. The results are
        concatenated along dim 1 and passed through ``self.neck``.
        """
        ms = self.backbone(x)
        h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1)
        return self.neck(h)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(id_logits, species_logits)`` computed from the shared embedding."""
        feat = self.get_feat(x)
        return self.head_id(feat), self.head_species(feat)

    def training_step(self, batch, batch_idx):
        """Compute the weighted id/species margin loss for one batch and log metrics.

        ``batch`` must provide the keys ``"image"``, ``"label"`` and ``"label_species"``.
        The returned loss is ``loss_ids * loss_id_ratio + loss_species * (1 - loss_id_ratio)``.
        Requires the margin functions, i.e. the model must have been built with class nums.
        """
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        logits_ids, logits_species = self(x)
        # Margins are applied to the raw head outputs before the cross-entropy.
        margin_logits_ids = self.margin_fn_id(logits_ids, ids)
        loss_ids = self.loss_fn_id(margin_logits_ids, ids)
        loss_species = self.loss_fn_species(self.margin_fn_species(logits_species, species), species)
        self.log_dict({"train/loss_ids": loss_ids.detach()}, on_step=False, on_epoch=True)
        self.log_dict({"train/loss_species": loss_species.detach()}, on_step=False, on_epoch=True)
        with torch.no_grad():
            # Metrics are computed on the margin-free logits.
            self.log_dict(map_dict(logits_ids, ids, "train"), on_step=False, on_epoch=True)
            self.log_dict(
                {"train/acc_species": topk_average_precision(logits_species, species, 1).mean().detach()},
                on_step=False,
                on_epoch=True,
            )
        return loss_ids * self.hparams.loss_id_ratio + loss_species * (1 - self.hparams.loss_id_ratio)

    def validation_step(self, batch, batch_idx):
        """Log validation metrics using flip test-time augmentation.

        Logits for ``x`` and ``x.flip(3)`` (the last axis — width, assuming NCHW input)
        are averaged before the metrics are computed.
        """
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        out1, out_species1 = self(x)
        out2, out_species2 = self(x.flip(3))
        output, output_species = (out1 + out2) / 2, (out_species1 + out_species2) / 2
        self.log_dict(map_dict(output, ids, "val"), on_step=False, on_epoch=True)
        self.log_dict(
            {"val/acc_species": topk_average_precision(output_species, species, 1).mean().detach()},
            on_step=False,
            on_epoch=True,
        )

    def configure_optimizers(self):
        """Build the optimizer (two LR groups) and a warmup+cosine ``LambdaLR`` scheduler.

        The backbone and pooling layers use ``hparams.lr_backbone``. The neck and both heads
        use ``hparams.lr_head``.

        Raises:
            ValueError: if ``hparams.optimizer`` is not ``"Adam"``, ``"AdamW"`` or ``"RAdam"``.
        """
        backbone_params = list(self.backbone.parameters()) + list(self.global_pools.parameters())
        head_params = (
            list(self.neck.parameters()) + list(self.head_id.parameters()) + list(self.head_species.parameters())
        )
        params = [
            {"params": backbone_params, "lr": self.hparams.lr_backbone},
            {"params": head_params, "lr": self.hparams.lr_head},
        ]
        # An if/elif chain (rather than a lookup table) keeps optimizer classes lazily
        # resolved, so e.g. torch.optim.RAdam is only accessed when requested.
        if self.hparams.optimizer == "Adam":
            optimizer = torch.optim.Adam(params)
        elif self.hparams.optimizer == "AdamW":
            optimizer = torch.optim.AdamW(params)
        elif self.hparams.optimizer == "RAdam":
            optimizer = torch.optim.RAdam(params)
        else:
            raise ValueError(
                f"Unsupported optimizer {self.hparams.optimizer!r}; expected 'Adam', 'AdamW' or 'RAdam'"
            )

        # Schedule lengths are expressed in epochs (derived from max_epochs).
        warmup_steps = self.hparams.max_epochs * self.hparams.warmup_steps_ratio
        cycle_steps = self.hparams.max_epochs - warmup_steps
        lr_lambda = WarmupCosineLambda(warmup_steps, cycle_steps, self.hparams.lr_decay_scale)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_idx):
        """Run flip-TTA inference on one batch and return CPU tensors for aggregation.

        Returns a dict with the pass-through keys ``original_index``/``label``/``label_species``,
        and the following results:

        - the first 1000 columns of the averaged id logits sorted in descending order
          (``pred_logit``), with their class indices (``pred_idx``);
        - the averaged species logits (``pred_species``);
        - the embeddings of the original and flipped inputs (``embed_features1``/``2``).
        """
        x = batch["image"]
        feat1 = self.get_feat(x)
        out1, out_species1 = self.head_id(feat1), self.head_species(feat1)
        feat2 = self.get_feat(x.flip(3))
        out2, out_species2 = self.head_id(feat2), self.head_species(feat2)
        pred_logit, pred_idx = ((out1 + out2) / 2).cpu().sort(descending=True)
        return {
            "original_index": batch["original_index"],
            "label": batch["label"],
            "label_species": batch["label_species"],
            "pred_logit": pred_logit[:, :1000],
            "pred_idx": pred_idx[:, :1000],
            "pred_species": ((out_species1 + out_species2) / 2).cpu(),
            "embed_features1": feat1.cpu(),
            "embed_features2": feat2.cpu(),
        }

    def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
        """Gather ``test_step`` outputs from all processes and save them on global rank 0.

        One array per ``test_step`` key is written with ``np.savez_compressed`` to
        ``self.test_results_fp``.

        Raises:
            ValueError: on rank 0, if ``self.test_results_fp`` was never set.
        """
        # Collective call: every rank must participate, even though only rank 0 saves.
        outputs = self.all_gather(outputs)
        if self.trainer.global_rank != 0:
            return
        if self.test_results_fp is None:
            raise ValueError("test_results_fp must be set before running the test loop")
        # all_gather prepends a process dimension only when the run spans more than one
        # process. The number of visible GPUs is not a reliable proxy (e.g. a single-device
        # run on a multi-GPU host), so check the trainer's world size instead.
        gathered_across_processes = self.trainer.world_size > 1
        epoch_results: Dict[str, np.ndarray] = {}
        for key in outputs[0].keys():
            if gathered_across_processes:
                # Each entry is (world_size, batch, ...): join batches along dim 1, then merge
                # the process and sample dimensions.
                result = torch.cat([x[key] for x in outputs], dim=1).flatten(end_dim=1)
            else:
                result = torch.cat([x[key] for x in outputs], dim=0)
            epoch_results[key] = result.detach().cpu().numpy()
        np.savez_compressed(self.test_results_fp, **epoch_results)