| |
| |
|
|
| |
|
|
| from typing import Optional, List, Any, Dict |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers import AutoModel, AutoConfig |
| from transformers.modeling_outputs import ImageClassifierOutput |
|
|
| |
| from torchvision import models as tv_models |
|
|
| try: |
| from .ds_cfg import BackboneMLPHeadConfig, BACKBONE_META |
| except ImportError: |
| from ds_cfg import BackboneMLPHeadConfig, BACKBONE_META |
| |
|
|
class MLPHead(nn.Module):
    """
    Two-layer MLP classification head.

    Parameters
    ----------
    in_dim : int
        Feature dimension produced by the backbone.
    num_labels : int
        Number of output classes.
    bottleneck : int
        Hidden (bottleneck) dimension.
    p : float
        Dropout probability applied between the two linear layers.
    """
    def __init__(self, in_dim: int, num_labels: int, bottleneck: int = 256, p: float = 0.2):
        super().__init__()
        # Construction order matches the forward() pipeline below.
        self.fc1 = nn.Linear(in_dim, bottleneck)
        self.act = nn.GELU()
        self.drop = nn.Dropout(p)
        self.fc2 = nn.Linear(bottleneck, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features to logits: linear -> GELU -> dropout -> linear."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        hidden = self.drop(hidden)
        return self.fc2(hidden)
|
|
| |
| |
| |
def _resolve_backbone_meta(config: BackboneMLPHeadConfig, fallback_table: Dict[str, Dict[str, Any]] | None = None) -> Dict[str, Any]:
    """
    Resolve the runtime backbone meta for *config*.

    Lookup order:
    1) ``config.backbone_meta`` — preferred; required for Hub runtime determinism.
    2) ``fallback_table[config.backbone_name_or_path]`` — backward compatibility
       for local/dev runs.

    Returns
    -------
    dict
        At least ``type``, ``feat_rule``, ``feat_dim`` (optionally ``has_bn`` /
        ``unfreeze``).

    Raises
    ------
    ValueError
        If neither source yields a non-empty meta dict.
    """
    inline_meta = getattr(config, "backbone_meta", None)
    if isinstance(inline_meta, dict) and inline_meta:
        return inline_meta

    backbone_id = getattr(config, "backbone_name_or_path", None)
    if fallback_table is not None and backbone_id in fallback_table:
        return fallback_table[backbone_id]

    raise ValueError(
        "config.backbone_meta is missing/empty and no fallback meta is available. "
        "Populate config.backbone_meta when saving to the Hub (single source of truth)."
    )
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
class BackboneWithMLPHeadForImageClassification(PreTrainedModel):
    """
    Image classifier: pluggable backbone (HF / timm / torchvision) + MLPHead.

    The backbone is built as an architecture-only "skeleton" in __init__ (no
    pretrained weights). Weights arrive either from a HF checkpoint via
    from_pretrained(), or — for fresh starts — via load_backbone_pretrained_().
    The resolved backbone meta (type / feat_rule / feat_dim, ...) drives both
    skeleton construction and feature extraction.
    """

    # Config class consumed by from_pretrained()/save_pretrained().
    config_class = BackboneMLPHeadConfig

    def __init__(self, config: BackboneMLPHeadConfig):
        super().__init__(config)

        # A backbone id is mandatory — it selects which skeleton builder runs.
        if config.backbone_name_or_path is None:
            raise ValueError(
                "config.backbone_name_or_path is None. "
                "Provide a valid backbone id (whitelist key in BACKBONE_META)."
            )

        # The MLP head needs a positive class count.
        if int(getattr(config, "num_labels", 0)) <= 0:
            raise ValueError(
                f"config.num_labels must be > 0, got {getattr(config, 'num_labels', None)}. "
                "Set num_labels (or id2label/label2id) when creating the config."
            )

        # Prefer config.backbone_meta; fall back to the local BACKBONE_META table.
        self._meta = _resolve_backbone_meta(config, fallback_table=BACKBONE_META)

        # Architecture only — pretrained weights are injected separately.
        self.backbone = self._build_backbone_skeleton(config.backbone_name_or_path)

        self.classifier = MLPHead(
            in_dim=int(self._meta["feat_dim"]),
            num_labels=int(config.num_labels),
            bottleneck=int(config.mlp_head_bottleneck),
            p=float(config.mlp_head_dropout),
        )

        # Triggers the overridden init_weights() below (head-only init).
        self.post_init()

    def init_weights(self):
        """
        Initialize only the classifier head; never touch the backbone skeleton.

        HF's default init may traverse the entire module tree, which is
        undesirable here. An earlier design loaded backbone weights inside
        __init__ (for convenience), but HF's post_init() could then undo that
        load for timm/torchvision backbones. This override therefore
        initializes the classifier only.
        """
        if getattr(self, "classifier", None) is not None:
            self.classifier.apply(self._init_weights)
        self.tie_weights()

    def _build_backbone_skeleton(self, backbone_id: str) -> nn.Module:
        # Build an architecture-only backbone (no pretrained weights).
        # Reuse the already-resolved meta when the id matches the config;
        # otherwise consult the whitelist table directly.
        meta = self._meta if backbone_id == self.config.backbone_name_or_path else BACKBONE_META.get(backbone_id)
        if meta is None:
            raise KeyError(f"Unknown backbone_id={backbone_id}. Provide backbone_meta in config or extend BACKBONE_META.")

        t = meta["type"]

        if t == "timm_densenet":
            return self._build_timm_densenet_skeleton(backbone_id)

        if t == "torchvision_densenet":
            return self._build_torchvision_densenet_skeleton(backbone_id)

        # Default path: any HF backbone — instantiate from config only
        # (from_config does not download weights).
        bb_cfg = AutoConfig.from_pretrained(backbone_id)
        return AutoModel.from_config(bb_cfg)

    @staticmethod
    def _build_timm_densenet_skeleton(hf_repo_id: str) -> nn.Module:
        # timm is an optional dependency; import lazily with a clear error.
        try:
            import timm
        except Exception as e:
            raise ImportError(
                "DenseNet(timm) backbone requires `timm`. Install: pip install timm"
            ) from e

        # num_classes=0 strips timm's classifier head; pretrained=False keeps
        # this a weight-less skeleton.
        return timm.create_model(
            f"hf_hub:{hf_repo_id}",
            pretrained=False,
            num_classes=0,
        )

    @staticmethod
    def _build_torchvision_densenet_skeleton(model_id: str) -> nn.Module:
        # Only the 224px densenet121 id is whitelisted.
        if model_id != "torchvision/densenet121":
            raise ValueError(f"Unsupported torchvision DenseNet id (224 whitelist only): {model_id}")

        # weights=None -> architecture only.
        m = tv_models.densenet121(weights=None)
        return m

    @torch.no_grad()
    def load_backbone_pretrained_(
        self,
        *,
        low_cpu_mem_usage: bool = False,
        device_map=None,
    ):
        """
        Fresh-start only: inject pretrained backbone weights into the skeleton.

        Do NOT call this after from_pretrained(), because it would overwrite
        checkpoint weights.

        Parameters
        ----------
        low_cpu_mem_usage : bool
            Forwarded to AutoModel.from_pretrained (HF backbones only).
        device_map :
            Forwarded to AutoModel.from_pretrained (HF backbones only).
        """
        bb = self.config.backbone_name_or_path
        meta = self._meta
        t = meta["type"]

        if t == "timm_densenet":
            self._load_timm_pretrained_into_skeleton_(bb)
            return

        if t == "torchvision_densenet":
            self._load_torchvision_pretrained_into_skeleton_(bb)
            return

        # HF backbone: load a pretrained reference model...
        ref = AutoModel.from_pretrained(
            bb,
            low_cpu_mem_usage=low_cpu_mem_usage,
            device_map=device_map,
        )

        # ...and copy its weights. strict=False tolerates keys present on only
        # one side (e.g. pooler/head differences between checkpoints).
        self.backbone.load_state_dict(ref.state_dict(), strict=False)
        del ref

    @torch.no_grad()
    def _load_timm_pretrained_into_skeleton_(self, hf_repo_id: str):
        # timm availability was already verified when the skeleton was built.
        import timm

        # Same creation args as the skeleton but pretrained=True, so the state
        # dicts should match exactly — hence strict=True below.
        ref = timm.create_model(
            f"hf_hub:{hf_repo_id}",
            pretrained=True,
            num_classes=0,
        ).eval()

        self.backbone.load_state_dict(ref.state_dict(), strict=True)
        del ref

    @torch.no_grad()
    def _load_torchvision_pretrained_into_skeleton_(self, model_id: str):
        # Only the 224px densenet121 id is whitelisted.
        if model_id != "torchvision/densenet121":
            raise ValueError(f"Unsupported torchvision DenseNet id (224 whitelist only): {model_id}")

        # Default pretrained weights; identical architecture, so strict=True.
        ref = tv_models.densenet121(weights=tv_models.DenseNet121_Weights.DEFAULT).eval()

        self.backbone.load_state_dict(ref.state_dict(), strict=True)
        del ref

    @staticmethod
    def _pool_or_gap(outputs) -> torch.Tensor:
        # Prefer the backbone's own pooled output when it is present.
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            x = outputs.pooler_output
            if x.dim() == 2:
                return x
            # Some CNN backbones pool to (B, C, 1, 1); flatten to (B, C).
            if x.dim() == 4 and x.size(-1) == 1 and x.size(-2) == 1:
                return x.flatten(1)
            raise RuntimeError(f"Unexpected pooler_output shape: {tuple(x.shape)}")

        # Otherwise global-average-pool a (B, C, H, W) feature map.
        x = outputs.last_hidden_state
        if x.dim() == 4:
            return x.mean(dim=(2, 3))

        raise RuntimeError(
            "Expected pooler_output or (B,C,H,W) last_hidden_state for CNN backbones. "
            f"Got last_hidden_state shape={tuple(x.shape)}"
        )

    def _extract_features(self, outputs, pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Turn raw backbone outputs into a (B, feat_dim) tensor per the meta's
        # feat_rule. `pixel_values` is currently unused but kept in the
        # signature for feature rules that might need the input.
        rule = self._meta["feat_rule"]

        if rule == "cls":
            # ViT-style: take the first (CLS-position) token embedding.
            return outputs.last_hidden_state[:, 0, :]

        if rule == "pool_or_mean":
            # Prefer pooler_output; otherwise mean over the token dimension.
            if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
                return outputs.pooler_output
            return outputs.last_hidden_state.mean(dim=1)

        if rule == "pool_or_gap":
            # CNN-style HF backbones (see _pool_or_gap).
            return self._pool_or_gap(outputs)

        if rule == "timm_gap":
            # timm forward_features() returns a raw (B,C,H,W) tensor; GAP it.
            if not isinstance(outputs, torch.Tensor):
                raise TypeError(f"timm_gap expects Tensor features, got {type(outputs)}")
            if outputs.dim() != 4:
                raise RuntimeError(f"Expected (B,C,H,W), got {tuple(outputs.shape)}")
            return outputs.mean(dim=(2, 3))

        if rule == "torchvision_densenet_gap":
            # torchvision DenseNet features(): raw (B,C,H,W) map; GAP it.
            if not isinstance(outputs, torch.Tensor):
                raise TypeError(f"torchvision_densenet_gap expects Tensor, got {type(outputs)}")
            if outputs.dim() != 4:
                raise RuntimeError(f"Expected (B,C,H,W), got {tuple(outputs.shape)}")
            return outputs.mean(dim=(2, 3))

        raise RuntimeError(f"unknown feat_rule={rule}")

    def forward(
        self,
        pixel_values=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
        **kwargs,
    ):
        """
        Run backbone -> feature extraction -> MLP head.

        Returns an ImageClassifierOutput (or a tuple when return_dict=False).
        Cross-entropy loss is computed when `labels` is given.
        """
        t = self._meta["type"]

        if t == "timm_densenet":
            if pixel_values is None:
                raise ValueError("timm DenseNet backbone requires pixel_values.")
            if pixel_values.dim() != 4:
                raise ValueError(f"pixel_values must be (B,C,H,W), got {tuple(pixel_values.shape)}")

            # timm path yields a raw feature map; no HF-style hidden states or
            # attentions are available.
            features_map = self.backbone.forward_features(pixel_values)
            feats = self._extract_features(features_map, pixel_values=pixel_values)
            hidden_states = None
            attentions = None

        elif t == "torchvision_densenet":
            if pixel_values is None:
                raise ValueError("torchvision DenseNet backbone requires pixel_values.")
            if pixel_values.dim() != 4:
                raise ValueError(f"pixel_values must be (B,C,H,W), got {tuple(pixel_values.shape)}")

            # ReLU on the raw feature map before pooling — presumably mirrors
            # torchvision's own DenseNet forward (ReLU before GAP); non-inplace
            # to avoid mutating the backbone's output tensor.
            features_map = self.backbone.features(pixel_values)
            features_map = F.relu(features_map, inplace=False)
            feats = self._extract_features(features_map, pixel_values=pixel_values)
            hidden_states = None
            attentions = None

        else:
            # HF backbone: standard keyword-argument forward; always request a
            # dict-style output internally so attribute access is uniform.
            outputs = self.backbone(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                **kwargs,
            )
            feats = self._extract_features(outputs, pixel_values=pixel_values)
            hidden_states = getattr(outputs, "hidden_states", None)
            attentions = getattr(outputs, "attentions", None)

        logits = self.classifier(feats)

        loss = None
        if labels is not None:
            # Single-label classification; labels are class indices.
            loss = F.cross_entropy(logits, labels)

        if not return_dict:
            # NOTE(review): the tuple path returns only (loss?, logits) and
            # drops hidden_states/attentions — confirm callers don't need them.
            out = (logits,)
            return ((loss,) + out) if loss is not None else out

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )
|
|
|
|
| |
| |
| |
| |
def _set_requires_grad(module: nn.Module, flag: bool):
    """Set ``requires_grad`` to *flag* on every parameter of *module*."""
    for param in module.parameters():
        param.requires_grad_(flag)
|
|
|
|
def set_bn_eval(module: nn.Module):
    """Put every BatchNorm layer under *module* into eval mode (freezes running stats)."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for sub in module.modules():
        if isinstance(sub, bn_types):
            sub.eval()
|
|
|
|
def freeze_backbone(model: BackboneWithMLPHeadForImageClassification, freeze_bn: bool = True):
    """
    Freeze the backbone and keep only the classifier head trainable.

    Parameters
    ----------
    model : BackboneWithMLPHeadForImageClassification
        Model whose ``backbone`` is frozen and ``classifier`` unfrozen.
    freeze_bn : bool
        When True and the backbone meta declares ``has_bn``, also put BatchNorm
        layers into eval mode so their running stats stop updating.
    """
    _set_requires_grad(model.backbone, False)
    _set_requires_grad(model.classifier, True)

    # Fix: `meta` can be None when the model lacks `_meta` and
    # config.backbone_meta is absent/None; guard before calling .get().
    meta = getattr(model, "_meta", None) or getattr(model.config, "backbone_meta", None) or {}
    if freeze_bn and meta.get("has_bn", False):
        set_bn_eval(model.backbone)
|
|
|
|
def finetune_train_mode(model: BackboneWithMLPHeadForImageClassification, keep_bn_eval: bool = True):
    """
    Switch the model to train() for fine-tuning, optionally keeping BN frozen.

    Parameters
    ----------
    model : BackboneWithMLPHeadForImageClassification
        Model to put into training mode.
    keep_bn_eval : bool
        When True and the backbone meta declares ``has_bn``, re-apply eval mode
        to BatchNorm layers after train() so running stats stay frozen.
    """
    model.train()
    # Fix: `meta` can be None when the model lacks `_meta` and
    # config.backbone_meta is absent/None; guard before calling .get().
    meta = getattr(model, "_meta", None) or getattr(model.config, "backbone_meta", None) or {}
    if keep_bn_eval and meta.get("has_bn", False):
        set_bn_eval(model.backbone)
|
|
|
|
def trainable_summary(model: nn.Module):
    """
    Print and return the trainable/total parameter counts of *model*.

    Returns
    -------
    dict
        ``{"trainable": int, "total": int, "ratio": float}`` where ``ratio`` is
        0.0 for a parameter-less model.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    ratio = trainable / total if total > 0 else 0.0
    print(f"trainable: {trainable:,} / total: {total:,} ({ratio*100:.2f}%)")
    return {"trainable": trainable, "total": total, "ratio": ratio}
|
|
|
|
def unfreeze_last_stage(
    model: BackboneWithMLPHeadForImageClassification,
    last_n: int = 2,
    keep_bn_eval: bool = True,
):
    """
    Freeze the whole backbone, then unfreeze its last *last_n* blocks.

    Supported backbone types (from meta["type"]): vit, swin, resnet,
    efficientnet, timm_densenet, torchvision_densenet. The backbone meta must
    declare ``unfreeze == "last_n"``.

    Parameters
    ----------
    model : BackboneWithMLPHeadForImageClassification
        Model whose backbone tail should become trainable.
    last_n : int
        Number of trailing blocks to unfreeze; <= 0 leaves everything frozen.
    keep_bn_eval : bool
        Keep BatchNorm layers in eval mode even inside unfrozen blocks.

    Raises
    ------
    RuntimeError
        If the meta's unfreeze rule is not "last_n", the backbone structure is
        unexpected, or the backbone type is unsupported.
    """
    # Start from a fully frozen backbone (classifier stays trainable).
    freeze_backbone(model, freeze_bn=keep_bn_eval)

    n = int(last_n)
    if n <= 0:
        return

    # Fix: `meta` could be None (missing `_meta` and config.backbone_meta),
    # which previously raised AttributeError on .get() instead of the intended
    # RuntimeError below.
    meta = getattr(model, "_meta", None) or getattr(model.config, "backbone_meta", None) or {}
    if meta.get("unfreeze") != "last_n":
        raise RuntimeError(f"Unexpected unfreeze rule: {meta.get('unfreeze')} (expected 'last_n')")

    bb_type = meta["type"]

    if bb_type == "vit":
        # HF ViT: transformer blocks live at backbone.encoder.layer.
        blocks = list(model.backbone.encoder.layer)
        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)
        return

    if bb_type == "swin":
        # HF Swin: flatten every stage's blocks so "last n" may span stages.
        stages = list(model.backbone.encoder.layers)
        blocks: List[nn.Module] = []
        for st in stages:
            blocks.extend(list(st.blocks))
        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)
        return

    if bb_type == "resnet":
        # ResNet: residual blocks are the children of layer1..layer4.
        bb = model.backbone
        for name in ("layer1", "layer2", "layer3", "layer4"):
            if not hasattr(bb, name):
                raise RuntimeError(f"Unexpected ResNet structure: missing {name}")

        blocks: List[nn.Module] = []
        blocks.extend(list(bb.layer1.children()))
        blocks.extend(list(bb.layer2.children()))
        blocks.extend(list(bb.layer3.children()))
        blocks.extend(list(bb.layer4.children()))

        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)

        # Unfreezing flips requires_grad only; BN inside the unfrozen blocks
        # should still keep frozen running stats when requested.
        if keep_bn_eval:
            set_bn_eval(bb)
        return

    if bb_type == "efficientnet":
        # EfficientNet: MBConv blocks are the grandchildren of `features`.
        bb = model.backbone
        if not hasattr(bb, "features"):
            raise RuntimeError("Unexpected EfficientNet structure: missing features")

        blocks: List[nn.Module] = []
        for st in bb.features.children():
            blocks.extend(list(st.children()))

        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)

        if keep_bn_eval:
            set_bn_eval(bb)
        return

    if bb_type in ("timm_densenet", "torchvision_densenet"):
        bb = model.backbone
        if not hasattr(bb, "features"):
            raise RuntimeError("Unexpected DenseNet: missing features")
        f = bb.features

        # Validate the canonical DenseNet layout before flattening it.
        req = [
            "conv0", "norm0", "relu0", "pool0",
            "denseblock1", "transition1",
            "denseblock2", "transition2",
            "denseblock3", "transition3",
            "denseblock4", "norm5",
        ]
        for name in req:
            if not hasattr(f, name):
                raise RuntimeError(f"Unexpected DenseNet features: missing {name}")

        def _denselayers(db: nn.Module) -> List[nn.Module]:
            # Individual denselayers of a denseblock, in forward order.
            return list(db.children())

        blocks: List[nn.Module] = []
        blocks.extend([f.conv0, f.norm0, f.relu0, f.pool0])
        blocks.extend(_denselayers(f.denseblock1)); blocks.append(f.transition1)
        blocks.extend(_denselayers(f.denseblock2)); blocks.append(f.transition2)
        blocks.extend(_denselayers(f.denseblock3)); blocks.append(f.transition3)
        blocks.extend(_denselayers(f.denseblock4)); blocks.append(f.norm5)

        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)

        if keep_bn_eval:
            set_bn_eval(bb)
        return

    raise RuntimeError(f"Unsupported backbone type: {bb_type}")
|
|
|
|
| |
| |
| |
| |
| |
| |
# Map this custom class to AutoModelForImageClassification so it can be loaded
# from the Hub via the auto-class machinery (trust_remote_code workflows).
BackboneWithMLPHeadForImageClassification.register_for_auto_class("AutoModelForImageClassification")
|
|