| import pickle |
| import torch |
| import timm |
| from dataclasses import dataclass |
| from typing import Optional, Tuple, Dict, Any |
|
|
| |
# Optional-dependency guards: each third-party integration degrades to None
# when its package is missing; callers check for None before using it.

try:
    from timm.models.vision_transformer import VisionTransformer
except ImportError:
    VisionTransformer = None


try:
    from transformers import AutoModelForImageClassification
except Exception:
    AutoModelForImageClassification = None


try:
    from safetensors.torch import load_file as load_safetensors
except ImportError:
    load_safetensors = None


# Default compute device: CUDA when available, otherwise CPU.
DEVICE_DEFAULT = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
@dataclass
class ViTConfig:
    """ViT architecture configuration, extracted dynamically from a model."""

    embed_dim: int = 768      # transformer embedding width
    num_heads: int = 12       # attention heads per block
    num_layers: int = 12      # transformer depth
    patch_size: int = 16      # side length of each square patch
    img_size: int = 224       # input image side length
    num_classes: int = 1000   # classification head size
    mlp_ratio: float = 4.0    # MLP hidden width / embed_dim
    qkv_bias: bool = True     # whether the fused qkv projection has a bias

    @property
    def grid_size(self) -> int:
        """Side length of the patch grid (e.g. 224 // 16 = 14)."""
        return self.img_size // self.patch_size

    @property
    def num_patches(self) -> int:
        """Total patch count (grid_size squared, e.g. 14 * 14 = 196)."""
        return self.grid_size * self.grid_size

    @property
    def timm_model_name(self) -> str:
        """Best-effort matching timm model name (informational only)."""
        # Map (embed_dim, num_heads) pairs onto the canonical timm size names.
        known_sizes = {
            (192, 3): 'tiny',
            (384, 6): 'small',
            (768, 12): 'base',
            (1024, 16): 'large',
            (1280, 16): 'huge',
        }
        variant = known_sizes.get((self.embed_dim, self.num_heads), 'custom')
        return f"vit_{variant}_patch{self.patch_size}_{self.img_size}"
|
|
|
|
def create_vit_from_config(config: ViTConfig, device: Optional[torch.device] = None) -> torch.nn.Module:
    """Instantiate a ViT model directly from an inferred configuration.

    Building from raw hyper-parameters supports arbitrary architectures,
    not just the predefined timm names (vit_base_patch16_224, etc.).

    Args:
        config: architecture hyper-parameters to instantiate.
        device: target device; defaults to DEVICE_DEFAULT when None.

    Raises:
        RuntimeError: when timm's VisionTransformer could not be imported.
    """
    if VisionTransformer is None:
        raise RuntimeError("VisionTransformer não disponível. Verifique a instalação do timm.")

    target_device = device if device is not None else DEVICE_DEFAULT

    vit = VisionTransformer(
        img_size=config.img_size,
        patch_size=config.patch_size,
        in_chans=3,
        num_classes=config.num_classes,
        embed_dim=config.embed_dim,
        depth=config.num_layers,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
        qkv_bias=config.qkv_bias,
        class_token=True,
        global_pool='token',
    )
    return vit.to(target_device)
|
|
|
|
| def _strip_state_dict_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
| """Remove prefixos comuns de frameworks (Lightning, DDP, etc.) das keys do state_dict. |
| |
| Prefixos tratados: |
| - 'model.' (PyTorch Lightning) |
| - 'module.' (DataParallel/DistributedDataParallel) |
| - 'encoder.' (alguns frameworks de self-supervised learning) |
| - 'backbone.' (alguns frameworks de detecção) |
| |
| Returns: |
| state_dict com keys sem prefixo |
| """ |
| prefixes = ['model.', 'module.', 'encoder.', 'backbone.'] |
| |
| |
| has_prefix = False |
| detected_prefix = None |
| for key in state_dict.keys(): |
| for prefix in prefixes: |
| if key.startswith(prefix): |
| has_prefix = True |
| detected_prefix = prefix |
| break |
| if has_prefix: |
| break |
| |
| if not has_prefix: |
| return state_dict |
| |
| print(f"[ViTViz] Detectado prefixo '{detected_prefix}' nas keys do state_dict (Lightning/DDP). Removendo...") |
| |
| new_sd: Dict[str, torch.Tensor] = {} |
| for key, value in state_dict.items(): |
| new_key = key |
| for prefix in prefixes: |
| if key.startswith(prefix): |
| new_key = key[len(prefix):] |
| break |
| new_sd[new_key] = value |
| |
| return new_sd |
|
|
|
|
def validate_vit_structure(model: torch.nn.Module) -> Tuple[bool, str]:
    """Check that *model* exposes the structure of a timm-compatible ViT.

    Returns:
        (is_valid, error_message) — on failure, error_message explains why.
    """
    if not hasattr(model, 'blocks'):
        return False, "Modelo não tem atributo 'blocks'. Não é um ViT compatível."
    if len(model.blocks) == 0:
        return False, "Modelo tem 'blocks' vazio."

    first_block = model.blocks[0]
    if not hasattr(first_block, 'attn'):
        return False, "Bloco não tem atributo 'attn'. Estrutura incompatível."

    # The attention module must expose the fused qkv projection and its
    # head count for downstream introspection/visualization.
    attention = first_block.attn
    required = (
        ('qkv', "Módulo de atenção não tem 'qkv'. Estrutura incompatível."),
        ('num_heads', "Módulo de atenção não tem 'num_heads'. Estrutura incompatível."),
    )
    for attr_name, message in required:
        if not hasattr(attention, attr_name):
            return False, message

    return True, ""
|
|
|
|
def infer_config_from_model(model: torch.nn.Module) -> ViTConfig:
    """Infer the ViT architecture configuration from a loaded timm model."""
    config = ViTConfig()

    # Patch embedding: image and patch sizes (timm may store them as tuples).
    if hasattr(model, 'patch_embed'):
        patch_embed = model.patch_embed
        if hasattr(patch_embed, 'img_size'):
            raw_img = patch_embed.img_size
            config.img_size = raw_img[0] if isinstance(raw_img, (tuple, list)) else raw_img
        if hasattr(patch_embed, 'patch_size'):
            raw_patch = patch_embed.patch_size
            config.patch_size = raw_patch[0] if isinstance(raw_patch, (tuple, list)) else raw_patch

    # Transformer stack: depth, head count and embedding width.
    if hasattr(model, 'blocks') and len(model.blocks) > 0:
        config.num_layers = len(model.blocks)
        first_block = model.blocks[0]
        if hasattr(first_block, 'attn'):
            attention = first_block.attn
            if hasattr(attention, 'num_heads'):
                config.num_heads = attention.num_heads
            if hasattr(attention, 'qkv') and hasattr(attention.qkv, 'in_features'):
                config.embed_dim = attention.qkv.in_features

    # Classification head: prefer out_features, fall back to the weight shape.
    if hasattr(model, 'head') and hasattr(model.head, 'out_features'):
        config.num_classes = model.head.out_features
    elif hasattr(model, 'head') and hasattr(model.head, 'weight'):
        config.num_classes = model.head.weight.shape[0]

    return config
|
|
|
|
def infer_config_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> ViTConfig:
    """Infer a ViT architecture configuration from a timm-keyed state_dict.

    Any field that cannot be determined keeps the ViTConfig default.
    Assumes timm key naming ('blocks.N.attn.qkv.weight',
    'patch_embed.proj.weight', 'pos_embed', ...).
    """
    config = ViTConfig()

    # Depth: highest transformer-block index seen among attention keys, plus one.
    layer_indices = set()
    for key in state_dict.keys():
        if key.startswith('blocks.') and '.attn.' in key:
            idx = int(key.split('.')[1])
            layer_indices.add(idx)
    if layer_indices:
        config.num_layers = max(layer_indices) + 1

    # Embedding width: input dimension of the fused qkv projection.
    qkv_key = 'blocks.0.attn.qkv.weight'
    if qkv_key in state_dict:
        qkv_weight = state_dict[qkv_key]
        config.embed_dim = qkv_weight.shape[1]

    # The head count is not stored in the weights, so it must be guessed.
    proj_key = 'blocks.0.attn.proj.weight'
    if proj_key in state_dict and qkv_key in state_dict:
        embed_dim = state_dict[proj_key].shape[0]
        qkv_out = state_dict[qkv_key].shape[0]

        if qkv_out == 3 * embed_dim:
            # Standard fused qkv: try common per-head dimensions (64 first,
            # the ViT default). Heuristic — may be wrong for exotic configs.
            for head_dim in [64, 32, 96, 48, 128]:
                if embed_dim % head_dim == 0:
                    config.num_heads = embed_dim // head_dim
                    break
        else:
            # Non-standard qkv output shape: fall back to trying common
            # head counts directly.
            for nh in [12, 16, 8, 6, 24, 4, 3]:
                if embed_dim % nh == 0:
                    config.num_heads = nh
                    break

    # qkv bias is present iff its tensor exists in the checkpoint.
    qkv_bias_key = 'blocks.0.attn.qkv.bias'
    config.qkv_bias = qkv_bias_key in state_dict

    # MLP ratio: hidden width of fc1 relative to the embedding width.
    mlp_fc1_key = 'blocks.0.mlp.fc1.weight'
    if mlp_fc1_key in state_dict and config.embed_dim > 0:
        mlp_hidden = state_dict[mlp_fc1_key].shape[0]
        config.mlp_ratio = mlp_hidden / config.embed_dim

    # Number of classes: rows of the classification-head weight matrix.
    head_key = 'head.weight'
    if head_key in state_dict:
        config.num_classes = state_dict[head_key].shape[0]

    # Patch size: spatial kernel size of the patch-embedding convolution
    # (weight laid out as [embed_dim, in_chans, kH, kW]).
    patch_proj_key = 'patch_embed.proj.weight'
    if patch_proj_key in state_dict:
        patch_weight = state_dict[patch_proj_key]
        config.patch_size = patch_weight.shape[2]

    # Image size: token count minus one class token gives the patch grid.
    # NOTE(review): assumes exactly one extra (cls) token and a square grid —
    # distilled models with a distillation token would be inferred wrong.
    # Must run after patch_size inference above.
    pos_embed_key = 'pos_embed'
    if pos_embed_key in state_dict:
        num_tokens = state_dict[pos_embed_key].shape[1]
        num_patches = num_tokens - 1
        grid_size = int(num_patches ** 0.5)
        config.img_size = grid_size * config.patch_size

    return config
|
|
|
|
| def _hf_id2label_to_class_names(id2label: Any) -> Optional[Dict[int, str]]: |
| if not isinstance(id2label, dict): |
| return None |
| out: Dict[int, str] = {} |
| for k, v in id2label.items(): |
| try: |
| out[int(k)] = str(v) |
| except Exception: |
| continue |
| return out or None |
|
|
|
|
| def _convert_hf_vit_to_timm_state_dict(hf_sd: Dict[str, torch.Tensor], num_layers: int) -> Dict[str, torch.Tensor]: |
| """Converte state_dict de ViT (Hugging Face Transformers) para chaves do timm ViT. |
| |
| Alvo: timm "vit_base_patch16_224". |
| """ |
| out: Dict[str, torch.Tensor] = {} |
|
|
| def get(key: str) -> torch.Tensor: |
| if key not in hf_sd: |
| raise KeyError(f"Missing key in HF state_dict: {key}") |
| return hf_sd[key] |
|
|
| |
| out["cls_token"] = get("vit.embeddings.cls_token") |
| out["pos_embed"] = get("vit.embeddings.position_embeddings") |
| out["patch_embed.proj.weight"] = get("vit.embeddings.patch_embeddings.projection.weight") |
| out["patch_embed.proj.bias"] = get("vit.embeddings.patch_embeddings.projection.bias") |
|
|
| |
| for i in range(num_layers): |
| prefix = f"vit.encoder.layer.{i}" |
| out[f"blocks.{i}.norm1.weight"] = get(f"{prefix}.layernorm_before.weight") |
| out[f"blocks.{i}.norm1.bias"] = get(f"{prefix}.layernorm_before.bias") |
| out[f"blocks.{i}.norm2.weight"] = get(f"{prefix}.layernorm_after.weight") |
| out[f"blocks.{i}.norm2.bias"] = get(f"{prefix}.layernorm_after.bias") |
|
|
| qw = get(f"{prefix}.attention.attention.query.weight") |
| kw = get(f"{prefix}.attention.attention.key.weight") |
| vw = get(f"{prefix}.attention.attention.value.weight") |
| qb = get(f"{prefix}.attention.attention.query.bias") |
| kb = get(f"{prefix}.attention.attention.key.bias") |
| vb = get(f"{prefix}.attention.attention.value.bias") |
| out[f"blocks.{i}.attn.qkv.weight"] = torch.cat([qw, kw, vw], dim=0) |
| out[f"blocks.{i}.attn.qkv.bias"] = torch.cat([qb, kb, vb], dim=0) |
|
|
| out[f"blocks.{i}.attn.proj.weight"] = get(f"{prefix}.attention.output.dense.weight") |
| out[f"blocks.{i}.attn.proj.bias"] = get(f"{prefix}.attention.output.dense.bias") |
|
|
| out[f"blocks.{i}.mlp.fc1.weight"] = get(f"{prefix}.intermediate.dense.weight") |
| out[f"blocks.{i}.mlp.fc1.bias"] = get(f"{prefix}.intermediate.dense.bias") |
| out[f"blocks.{i}.mlp.fc2.weight"] = get(f"{prefix}.output.dense.weight") |
| out[f"blocks.{i}.mlp.fc2.bias"] = get(f"{prefix}.output.dense.bias") |
|
|
| out["norm.weight"] = get("vit.layernorm.weight") |
| out["norm.bias"] = get("vit.layernorm.bias") |
|
|
| |
| if "classifier.weight" in hf_sd and "classifier.bias" in hf_sd: |
| out["head.weight"] = get("classifier.weight") |
| out["head.bias"] = get("classifier.bias") |
|
|
| return out |
|
|
|
|
| def _convert_hf_timm_wrapper_to_timm_state_dict(hf_sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
| """Converte state_dict de TimmWrapper (Transformers) para formato timm ViT. |
| |
| Exemplo de origem: chaves com prefixo ``timm_model.``. |
| """ |
| out: Dict[str, torch.Tensor] = {} |
|
|
| for key, value in hf_sd.items(): |
| if key.startswith("timm_model."): |
| out[key[len("timm_model."):]] = value |
| elif key.startswith("classifier."): |
| |
| out[f"head.{key[len('classifier.'):]}"] = value |
|
|
| if not out: |
| raise ValueError("State_dict de TimmWrapper sem chaves reconhecidas (timm_model.* / classifier.*).") |
|
|
| return out |
|
|
|
|
def load_vit_from_huggingface(model_id: str, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], ViTConfig]:
    """Load a ViT from the Hugging Face Hub and return an equivalent timm model.

    Returns:
        (model, class_names, config)

    Raises:
        RuntimeError: when the ``transformers`` package is not installed.
    """
    if AutoModelForImageClassification is None:
        raise RuntimeError("transformers não está instalado; instale 'transformers' para carregar do Hugging Face.")

    target_device = device if device is not None else DEVICE_DEFAULT

    hf_model = AutoModelForImageClassification.from_pretrained(model_id)
    hf_model.eval()

    hf_config = getattr(hf_model, "config", None)
    class_names = None
    if hf_config is not None:
        class_names = _hf_id2label_to_class_names(getattr(hf_config, "id2label", None))

    # Pick the converter matching the checkpoint flavour:
    # TimmWrapper checkpoints carry a 'timm_model.' prefix, native HF ViTs don't.
    hf_sd = hf_model.state_dict()
    if any(key.startswith("timm_model.") for key in hf_sd.keys()):
        timm_sd = _convert_hf_timm_wrapper_to_timm_state_dict(hf_sd)
    else:
        depth = int(getattr(hf_config, "num_hidden_layers", 12)) if hf_config is not None else 12
        timm_sd = _convert_hf_vit_to_timm_state_dict(hf_sd, num_layers=depth)

    vit_config = infer_config_from_state_dict(timm_sd)
    # Trust the HF config's label count over the weight-shape inference.
    if hf_config is not None and hasattr(hf_config, "num_labels"):
        try:
            vit_config.num_classes = int(getattr(hf_config, "num_labels"))
        except Exception:
            pass

    print(f"[ViTViz] Carregando do HuggingFace: {vit_config.timm_model_name} "
          f"(embed_dim={vit_config.embed_dim}, heads={vit_config.num_heads}, "
          f"layers={vit_config.num_layers}, patch={vit_config.patch_size}, img={vit_config.img_size})")

    timm_model = create_vit_from_config(vit_config, device=target_device)
    # strict=False tolerates benign mismatches (e.g. missing head weights).
    timm_model.load_state_dict(timm_sd, strict=False)
    timm_model.eval()

    return timm_model, class_names, vit_config
|
|
|
|
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that tolerates missing custom classes by creating dummies."""

    def find_class(self, module, name):
        """Resolve *module.name*, substituting an empty placeholder on failure."""
        try:
            resolved = pickle.Unpickler.find_class(self, module, name)
        except Exception:
            # The class is unavailable in this environment: stand in with an
            # empty dummy type so the rest of the payload can still load.
            resolved = type(name, (), {})
        return resolved
|
|
|
|
def load_checkpoint(model_path: str, device: Optional[torch.device] = None) -> Any:
    """Load a checkpoint/model from the given path.

    Supported formats:
        - .pth / .pt: PyTorch checkpoint (torch.load)
        - .safetensors: modern HuggingFace format (safer and faster)

    Returns the loaded object (full model, state_dict or checkpoint dict).
    """
    target_device = device if device is not None else DEVICE_DEFAULT

    if model_path.endswith('.safetensors'):
        if load_safetensors is None:
            raise ImportError(
                "safetensors não está instalado. Instale com: pip install safetensors"
            )
        return load_safetensors(model_path, device=str(target_device))

    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    try:
        return torch.load(model_path, map_location=target_device, weights_only=False)
    except (AttributeError, ModuleNotFoundError, RuntimeError):
        # The checkpoint references classes unavailable here: retry with the
        # lenient unpickler that substitutes dummy types.
        with open(model_path, 'rb') as handle:
            return CustomUnpickler(handle).load()
|
|
|
|
def infer_num_classes(state_dict: Dict[str, torch.Tensor]) -> int:
    """Infer the number of classes from the state_dict's head layer.

    Falls back to 1000 (the ImageNet default) when no head weight is found.
    """
    for name, tensor in state_dict.items():
        looks_like_head_weight = 'head' in name and 'weight' in name
        if looks_like_head_weight and hasattr(tensor, 'shape'):
            return tensor.shape[0]
    return 1000
|
|
|
|
def extract_class_names(checkpoint: Any) -> Optional[Dict[int, str]]:
    """Try to extract class names from a checkpoint dict, if present."""
    if not isinstance(checkpoint, dict):
        return None

    # Checked in priority order: first matching key wins.
    candidate_keys = (
        'class_names', 'classes', 'class_to_idx', 'idx_to_class',
        'label_names', 'labels', 'class_labels'
    )

    for candidate in candidate_keys:
        if candidate not in checkpoint:
            continue
        labels = checkpoint[candidate]
        if isinstance(labels, list):
            return dict(enumerate(labels))
        if isinstance(labels, dict):
            # Already shaped as {idx: name}?
            if all(isinstance(k, int) for k in labels.keys()):
                return labels
            # Shaped as {name: idx}? Invert it.
            if all(isinstance(v, int) for v in labels.values()):
                return {idx: name for name, idx in labels.items()}
            return labels
    return None
|
|
|
|
def load_class_names_from_file(labels_file: Optional[str]) -> Optional[Dict[int, str]]:
    """Load class names from a .txt file (one per line) or .json (list or dict).

    Returns None on any read/parse failure — labels are best-effort.
    """
    if not labels_file:
        return None
    import json
    try:
        if labels_file.endswith('.json'):
            with open(labels_file, 'r', encoding='utf-8') as fh:
                payload = json.load(fh)
            if isinstance(payload, list):
                return dict(enumerate(payload))
            if isinstance(payload, dict):
                parsed: Dict[int, str] = {}
                for raw_key, label in payload.items():
                    try:
                        parsed[int(raw_key)] = label
                    except Exception:
                        # Non-numeric key: not an {idx: name} entry.
                        pass
                if parsed:
                    return parsed
                # Maybe the mapping is inverted ({name: idx}).
                if all(isinstance(v, int) for v in payload.values()):
                    return {idx: name for name, idx in payload.items()}
            return None
        with open(labels_file, 'r', encoding='utf-8') as fh:
            names = [ln.strip() for ln in fh if ln.strip()]
        return dict(enumerate(names))
    except Exception:
        # Missing file, bad JSON, wrong encoding, ...: behave as "no labels".
        return None
|
|
|
|
def _build_from_state_dict(state_dict: Dict[str, torch.Tensor], device: torch.device) -> Tuple[torch.nn.Module, ViTConfig]:
    """Instantiate a ViT from a raw state_dict.

    Strips framework prefixes, infers the architecture, creates the model and
    loads the weights non-strictly. Shared by all state_dict-shaped branches
    of build_model_from_checkpoint (previously triplicated inline).
    """
    state_dict = _strip_state_dict_prefix(state_dict)
    config = infer_config_from_state_dict(state_dict)
    print(f"[ViTViz] Arquitetura inferida: {config.timm_model_name} "
          f"(embed_dim={config.embed_dim}, heads={config.num_heads}, "
          f"layers={config.num_layers}, patch={config.patch_size}, img={config.img_size})")
    model = create_vit_from_config(config, device=device)
    # strict=False tolerates harmless extras (optimizer buffers, aux heads, ...).
    model.load_state_dict(state_dict, strict=False)
    return model, config


def build_model_from_checkpoint(checkpoint: Any, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, ViTConfig]:
    """Build a model from a checkpoint that may be a dict, a state_dict or the model itself.

    Supports arbitrary ViT architectures, not limited to timm's predefined names.

    Returns:
        (model, config) — the loaded model and its inferred configuration

    Raises:
        ValueError: when a contained/standalone model fails ViT structure validation.
    """
    device = device or DEVICE_DEFAULT
    config: Optional[ViTConfig] = None

    if isinstance(checkpoint, dict) and 'pytorch-lightning_version' in checkpoint:
        print(f"[ViTViz] Detectado checkpoint PyTorch Lightning (v{checkpoint.get('pytorch-lightning_version', '?')})")

    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            # Checkpoint carries a fully-constructed model object.
            model = checkpoint['model']
            config = infer_config_from_model(model)
            is_valid, error_msg = validate_vit_structure(model)
            if not is_valid:
                raise ValueError(f"Modelo inválido: {error_msg}")
        elif 'state_dict' in checkpoint:
            # Lightning-style checkpoint.
            model, config = _build_from_state_dict(checkpoint['state_dict'], device)
        elif 'model_state_dict' in checkpoint:
            # Plain training-loop checkpoint.
            model, config = _build_from_state_dict(checkpoint['model_state_dict'], device)
        else:
            # The dict itself looks like a bare state_dict.
            model, config = _build_from_state_dict(checkpoint, device)
    else:
        # The checkpoint object is the model itself.
        model = checkpoint
        is_valid, error_msg = validate_vit_structure(model)
        if not is_valid:
            raise ValueError(f"Modelo inválido: {error_msg}")
        config = infer_config_from_model(model)

    model = model.to(device)
    model.eval()

    # Safety net: every branch above sets config, but keep the fallback.
    if config is None:
        config = infer_config_from_model(model)

    return model, config
|
|
|
|
def load_model_and_labels(
    model_path: str,
    labels_file: Optional[str] = None,
    device: Optional[torch.device] = None,
) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], Optional[str], ViTConfig]:
    """
    ** Main entry point **
    Load a model and, when available, its class names.

    Returns: (model, class_names, labels_origin, config) where labels_origin ∈ {"file", "checkpoint", "hf", None}
        None when no class names are available.
        config holds the ViT architecture configuration (embed_dim, num_heads, grid_size, etc.)

    NOTE(review): *labels_file* is currently unused in this code path —
    presumably it should feed load_class_names_from_file; confirm intent.
    """
    device = device or DEVICE_DEFAULT

    # Hugging Face Hub reference: delegate entirely to the HF loader.
    if isinstance(model_path, str) and model_path.startswith("hf-model://"):
        hub_id = model_path[len("hf-model://"):].strip("/")
        model, class_names, config = load_vit_from_huggingface(hub_id, device=device)
        return model, class_names, 'hf', config

    checkpoint = load_checkpoint(model_path, device=device)

    names_from_ckpt = extract_class_names(checkpoint)
    class_names = names_from_ckpt
    source = 'checkpoint' if names_from_ckpt else None

    model, config = build_model_from_checkpoint(checkpoint, device=device)
    return model, class_names, source, config
|
|