Upload ProtoMorph-DINO scaffold and random head checkpoint
Browse files- .env.example +8 -0
- .gitattributes +5 -32
- Dockerfile +22 -0
- HF_UPLOAD_GUIDE.md +67 -0
- LICENSE-WEIGHTS.md +15 -0
- README.md +132 -172
- README_RUNPOD.md +181 -0
- checkpoints/config.json +17 -0
- checkpoints/labels.txt +10 -0
- checkpoints/protomorph_head.safetensors +3 -0
- config.json +17 -0
- infer.py +47 -0
- labels.txt +10 -0
- notebooks/ProtoMorph_DINOv3_Inference.ipynb +127 -0
- pyproject.toml +8 -0
- requirements-core.txt +10 -0
- runpod/setup_runpod.sh +30 -0
- runpod/start_jupyter.sh +19 -0
- runpod/upload_to_hf.sh +22 -0
- scripts/create_random_head.py +62 -0
- scripts/smoke_test_head_only.py +35 -0
- scripts/upload_to_hf.py +149 -0
- src/protomorph/__init__.py +15 -0
- src/protomorph/config.py +51 -0
- src/protomorph/hf_utils.py +67 -0
- src/protomorph/inference.py +97 -0
- src/protomorph/model.py +420 -0
.env.example
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RunPod environment variables
|
| 2 |
+
# Do not commit real secrets.
|
| 3 |
+
hf_key=hf_your_huggingface_write_token_here
|
| 4 |
+
hf_repo=shiowo/DINO-Protomorph
|
| 5 |
+
|
| 6 |
+
# Standard names are also supported:
|
| 7 |
+
# HF_TOKEN=hf_your_huggingface_write_token_here
|
| 8 |
+
# HF_REPO_ID=shiowo/DINO-Protomorph
|
.gitattributes
CHANGED
|
@@ -1,35 +1,8 @@
|
|
| 1 |
-
*.
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.
|
| 25 |
-
*.
|
| 26 |
-
|
| 27 |
-
*.
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 2 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 4 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Dockerfile
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Optional Dockerfile. On RunPod, it is usually easier to start from a PyTorch
|
| 2 |
+
# 2.4.0 / CUDA 12.4 template and run runpod/setup_runpod.sh.
|
| 3 |
+
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
|
| 4 |
+
|
| 5 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 6 |
+
HF_HOME=/workspace/hf_cache \
|
| 7 |
+
TRANSFORMERS_CACHE=/workspace/hf_cache \
|
| 8 |
+
PYTHONUNBUFFERED=1
|
| 9 |
+
|
| 10 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 11 |
+
software-properties-common git curl wget ca-certificates build-essential \
|
| 12 |
+
&& add-apt-repository ppa:deadsnakes/ppa -y \
|
| 13 |
+
&& apt-get update && apt-get install -y --no-install-recommends \
|
| 14 |
+
python3.11 python3.11-venv python3.11-dev \
|
| 15 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 16 |
+
|
| 17 |
+
WORKDIR /workspace/protomorph_dinov3_runpod
|
| 18 |
+
COPY . /workspace/protomorph_dinov3_runpod
|
| 19 |
+
RUN bash runpod/setup_runpod.sh
|
| 20 |
+
|
| 21 |
+
EXPOSE 8888
|
| 22 |
+
CMD ["bash", "runpod/start_jupyter.sh"]
|
HF_UPLOAD_GUIDE.md
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Upload Guide
|
| 2 |
+
|
| 3 |
+
This project is configured for the Hugging Face model repo:
|
| 4 |
+
|
| 5 |
+
```text
|
| 6 |
+
shiowo/DINO-Protomorph
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
## 1. Set RunPod environment variables
|
| 10 |
+
|
| 11 |
+
In RunPod, add:
|
| 12 |
+
|
| 13 |
+
```text
|
| 14 |
+
hf_key=hf_your_huggingface_write_token_here
|
| 15 |
+
hf_repo=shiowo/DINO-Protomorph
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
`hf_key` must be a Hugging Face token with write access to the target repo.
|
| 19 |
+
|
| 20 |
+
The script also supports standard names:
|
| 21 |
+
|
| 22 |
+
```text
|
| 23 |
+
HF_TOKEN=hf_your_huggingface_write_token_here
|
| 24 |
+
HF_REPO_ID=shiowo/DINO-Protomorph
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
Never commit real tokens.
|
| 28 |
+
|
| 29 |
+
## 2. Install dependencies
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
cd /workspace/protomorph_dinov3_runpod
|
| 33 |
+
bash runpod/setup_runpod.sh
|
| 34 |
+
source .venv/bin/activate
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
## 3. Dry run
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
python scripts/upload_to_hf.py --dry-run
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
This checks the required files and prints the file list without uploading.
|
| 44 |
+
|
| 45 |
+
## 4. Upload
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
python scripts/upload_to_hf.py
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
Or:
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
bash runpod/upload_to_hf.sh
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## 5. Important notes
|
| 58 |
+
|
| 59 |
+
This upload includes the custom ProtoMorph head checkpoint:
|
| 60 |
+
|
| 61 |
+
```text
|
| 62 |
+
checkpoints/protomorph_head.safetensors
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
It does not include DINOv3 backbone weights. DINOv3 is loaded separately during inference.
|
| 66 |
+
|
| 67 |
+
The model card marks all results as **Pending** because training and benchmarking have not been completed yet.
|
LICENSE-WEIGHTS.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ProtoMorph Head Weights License
|
| 2 |
+
|
| 3 |
+
The file below is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License:
|
| 4 |
+
|
| 5 |
+
```text
|
| 6 |
+
checkpoints/protomorph_head.safetensors
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
You may use, share, and adapt these weights, including for commercial purposes, provided that you give appropriate credit and distribute adapted versions under CC BY-SA 4.0 or a compatible license.
|
| 10 |
+
|
| 11 |
+
This license applies only to the ProtoMorph head weights released by this project.
|
| 12 |
+
|
| 13 |
+
It does not apply to DINOv3, PyTorch, Hugging Face Transformers, third-party datasets, third-party model weights, or upstream dependencies.
|
| 14 |
+
|
| 15 |
+
DINOv3 weights are not redistributed in this repository. Users are responsible for obtaining DINOv3 separately from its official source and complying with its license.
|
README.md
CHANGED
|
@@ -9,15 +9,10 @@ tags:
|
|
| 9 |
- dinov3
|
| 10 |
- pytorch
|
| 11 |
- safetensors
|
| 12 |
-
- architecture
|
| 13 |
-
- research
|
| 14 |
-
- untrained
|
| 15 |
- prototype-learning
|
| 16 |
- hard-example-mining
|
| 17 |
- feedback-routing
|
| 18 |
- experimental
|
| 19 |
-
datasets:
|
| 20 |
-
- pending
|
| 21 |
metrics:
|
| 22 |
- accuracy
|
| 23 |
- f1
|
|
@@ -25,27 +20,21 @@ metrics:
|
|
| 25 |
- recall
|
| 26 |
---
|
| 27 |
|
| 28 |
-
#
|
| 29 |
|
| 30 |
**Feedback-Gated Prototype Morphing for Hard-Case Image Classification**
|
| 31 |
|
| 32 |
ProtoMorph-DINO is an experimental image classification head designed to run on top of a frozen DINOv3 vision backbone.
|
| 33 |
|
| 34 |
-
|
| 35 |
|
| 36 |
-
|
| 37 |
-
-
|
| 38 |
-
|
| 39 |
-
- confidence-based hard-case routing
|
| 40 |
-
- top-2 probability feedback
|
| 41 |
-
- Delta-RBF hard expert refinement
|
| 42 |
-
- logit fusion for difficult samples
|
| 43 |
-
|
| 44 |
-
This repository currently contains the early project/model-card setup for ProtoMorph-DINO. Training and evaluation results are still pending.
|
| 45 |
|
| 46 |
-
This repository
|
| 47 |
|
| 48 |
-
This project is
|
| 49 |
|
| 50 |
---
|
| 51 |
|
|
@@ -81,13 +70,48 @@ Hard-case gate
|
|
| 81 |
|
| 82 |
## Model Summary
|
| 83 |
|
| 84 |
-
ProtoMorph-DINO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
-
|
| 89 |
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
---
|
| 93 |
|
|
@@ -100,11 +124,11 @@ This model is intended for:
|
|
| 100 |
- prototype learning experiments
|
| 101 |
- frozen-backbone classifier research
|
| 102 |
- fine-grained classification experiments
|
| 103 |
-
- educational
|
| 104 |
|
| 105 |
This model is **not** intended for safety-critical use.
|
| 106 |
|
| 107 |
-
Do not use this model for medical, legal, financial, biometric, security-critical, or production decisions without
|
| 108 |
|
| 109 |
---
|
| 110 |
|
|
@@ -115,23 +139,29 @@ Recommended repository layout:
|
|
| 115 |
```text
|
| 116 |
.
|
| 117 |
├── README.md
|
|
|
|
| 118 |
├── config.json
|
| 119 |
├── labels.txt
|
| 120 |
-
├──
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
```
|
| 125 |
|
| 126 |
-
The main weight file is
|
| 127 |
|
| 128 |
```text
|
| 129 |
-
protomorph_head.safetensors
|
| 130 |
```
|
| 131 |
|
| 132 |
This file contains only the custom ProtoMorph classification head.
|
| 133 |
|
| 134 |
-
DINOv3 backbone weights are not included.
|
| 135 |
|
| 136 |
---
|
| 137 |
|
|
@@ -145,7 +175,7 @@ facebook/dinov3-vits16-pretrain-lvd1689m
|
|
| 145 |
|
| 146 |
The backbone is used as a frozen visual feature extractor.
|
| 147 |
|
| 148 |
-
For RTX 3090-class GPUs,
|
| 149 |
|
| 150 |
---
|
| 151 |
|
|
@@ -162,159 +192,105 @@ CUDA 12.4 PyTorch wheel
|
|
| 162 |
Install PyTorch:
|
| 163 |
|
| 164 |
```bash
|
| 165 |
-
pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124
|
| 166 |
```
|
| 167 |
|
| 168 |
Install dependencies:
|
| 169 |
|
| 170 |
```bash
|
| 171 |
-
pip install
|
| 172 |
```
|
| 173 |
|
| 174 |
---
|
| 175 |
|
| 176 |
-
##
|
| 177 |
-
|
| 178 |
-
```python
|
| 179 |
-
import torch
|
| 180 |
-
from PIL import Image
|
| 181 |
-
from transformers import AutoImageProcessor, AutoModel
|
| 182 |
-
from safetensors.torch import load_file
|
| 183 |
-
|
| 184 |
-
# Replace with your local or Hugging Face repo path.
|
| 185 |
-
REPO_ID = "shiowo/DINO-Protomorph"
|
| 186 |
-
|
| 187 |
-
# DINOv3 is loaded separately.
|
| 188 |
-
BACKBONE_NAME = "facebook/dinov3-vits16-pretrain-lvd1689m"
|
| 189 |
-
|
| 190 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 191 |
-
|
| 192 |
-
processor = AutoImageProcessor.from_pretrained(BACKBONE_NAME)
|
| 193 |
-
backbone = AutoModel.from_pretrained(
|
| 194 |
-
BACKBONE_NAME,
|
| 195 |
-
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
| 196 |
-
).to(device)
|
| 197 |
-
|
| 198 |
-
backbone.eval()
|
| 199 |
-
for p in backbone.parameters():
|
| 200 |
-
p.requires_grad = False
|
| 201 |
-
|
| 202 |
-
# Load your ProtoMorph model class from your local code.
|
| 203 |
-
# from model import ProtoMorphDINOClassifier
|
| 204 |
-
#
|
| 205 |
-
# model = ProtoMorphDINOClassifier(...)
|
| 206 |
-
# state = load_file("protomorph_head.safetensors")
|
| 207 |
-
# model.load_state_dict(state, strict=True)
|
| 208 |
-
# model.to(device)
|
| 209 |
-
# model.eval()
|
| 210 |
-
|
| 211 |
-
image = Image.open("example.jpg").convert("RGB")
|
| 212 |
-
inputs = processor(images=image, return_tensors="pt").to(device)
|
| 213 |
-
|
| 214 |
-
with torch.no_grad():
|
| 215 |
-
outputs = backbone(**inputs)
|
| 216 |
-
tokens = outputs.last_hidden_state
|
| 217 |
-
|
| 218 |
-
# DINOv3 ViT outputs include special tokens before patch tokens.
|
| 219 |
-
# Your implementation should remove CLS/register tokens according to its config.
|
| 220 |
-
#
|
| 221 |
-
# logits = model(tokens)
|
| 222 |
-
# probs = torch.softmax(logits, dim=-1)
|
| 223 |
-
# print(probs)
|
| 224 |
-
```
|
| 225 |
|
| 226 |
-
|
| 227 |
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
```json
|
| 233 |
-
{
|
| 234 |
-
"model_name": "ProtoMorph-DINO",
|
| 235 |
-
"backbone_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
|
| 236 |
-
"num_classes": "pending",
|
| 237 |
-
"patch_dim": 384,
|
| 238 |
-
"hidden_dim": 512,
|
| 239 |
-
"num_prototypes": 64,
|
| 240 |
-
"memory_heads": 8,
|
| 241 |
-
"hard_gate_confidence_threshold": 0.65,
|
| 242 |
-
"hard_gate_margin_threshold": 0.15,
|
| 243 |
-
"hard_expert_weight": 0.5,
|
| 244 |
-
"dtype": "float16"
|
| 245 |
-
}
|
| 246 |
```
|
| 247 |
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
## Training Status
|
| 251 |
-
|
| 252 |
-
**Status: Pending**
|
| 253 |
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
-
|
| 257 |
|
| 258 |
---
|
| 259 |
|
| 260 |
-
##
|
| 261 |
|
| 262 |
-
|
| 263 |
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
-
|
| 269 |
-
-
|
| 270 |
-
-
|
| 271 |
-
|
| 272 |
-
- augmentation strategy
|
| 273 |
-
- label mapping
|
| 274 |
|
| 275 |
-
|
| 276 |
|
| 277 |
-
```
|
| 278 |
-
|
| 279 |
```
|
| 280 |
|
|
|
|
|
|
|
| 281 |
---
|
| 282 |
|
| 283 |
-
##
|
| 284 |
|
| 285 |
-
|
| 286 |
|
| 287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
-
|
| 290 |
-
|---|---:|
|
| 291 |
-
| Accuracy | Pending |
|
| 292 |
-
| F1 | Pending |
|
| 293 |
-
| Precision | Pending |
|
| 294 |
-
| Recall | Pending |
|
| 295 |
|
| 296 |
-
|
|
|
|
|
|
|
| 297 |
|
| 298 |
-
|
| 299 |
-
|---|---|
|
| 300 |
-
| DINOv3 + Linear Probe | Minimal frozen-backbone baseline |
|
| 301 |
-
| DINOv3 + MLP Head | Strong simple head baseline |
|
| 302 |
-
| CLIP + Linear Probe | Popular vision-language baseline |
|
| 303 |
-
| ConvNeXt | Strong CNN-style baseline |
|
| 304 |
-
| ViT | Standard transformer baseline |
|
| 305 |
|
| 306 |
-
|
|
|
|
|
|
|
| 307 |
|
| 308 |
-
|
| 309 |
|
| 310 |
-
|
| 311 |
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
|
| 319 |
---
|
| 320 |
|
|
@@ -323,12 +299,12 @@ Planned research questions:
|
|
| 323 |
Known limitations:
|
| 324 |
|
| 325 |
- The architecture is experimental.
|
| 326 |
-
-
|
| 327 |
- The hard-case gate requires threshold tuning.
|
| 328 |
- The Delta-RBF hard expert may overfit small datasets.
|
| 329 |
- Inference may be slower for hard samples.
|
| 330 |
- The model should be compared against simple baselines before claiming improvement.
|
| 331 |
-
- This
|
| 332 |
- The custom head may not generalize outside the dataset it was trained on.
|
| 333 |
|
| 334 |
---
|
|
@@ -365,8 +341,8 @@ If you use this model or build on it, please credit:
|
|
| 365 |
|
| 366 |
```text
|
| 367 |
ProtoMorph-DINO: Feedback-Gated Prototype Morphing for Hard-Case Image Classification
|
| 368 |
-
Author:
|
| 369 |
-
Repository: https://huggingface.co/
|
| 370 |
```
|
| 371 |
|
| 372 |
BibTeX:
|
|
@@ -374,9 +350,9 @@ BibTeX:
|
|
| 374 |
```bibtex
|
| 375 |
@software{protomorph_dino_2026,
|
| 376 |
title = {ProtoMorph-DINO: Feedback-Gated Prototype Morphing for Hard-Case Image Classification},
|
| 377 |
-
author = {
|
| 378 |
year = {2026},
|
| 379 |
-
url = {https://huggingface.co/
|
| 380 |
}
|
| 381 |
```
|
| 382 |
|
|
@@ -387,19 +363,3 @@ BibTeX:
|
|
| 387 |
This is a research prototype.
|
| 388 |
|
| 389 |
The model is provided for experimentation and educational use. It should not be used in production or high-stakes environments without independent validation, dataset auditing, robustness testing, and bias evaluation.
|
| 390 |
-
|
| 391 |
-
---
|
| 392 |
-
|
| 393 |
-
## Project Links
|
| 394 |
-
|
| 395 |
-
GitHub repository: coming soon
|
| 396 |
-
|
| 397 |
-
```text
|
| 398 |
-
https://github.com/shiowo/DINO-Protomorph
|
| 399 |
-
```
|
| 400 |
-
|
| 401 |
-
Hugging Face model page:
|
| 402 |
-
|
| 403 |
-
```text
|
| 404 |
-
https://huggingface.co/shiowo/DINO-Protomorph
|
| 405 |
-
```
|
|
|
|
| 9 |
- dinov3
|
| 10 |
- pytorch
|
| 11 |
- safetensors
|
|
|
|
|
|
|
|
|
|
| 12 |
- prototype-learning
|
| 13 |
- hard-example-mining
|
| 14 |
- feedback-routing
|
| 15 |
- experimental
|
|
|
|
|
|
|
| 16 |
metrics:
|
| 17 |
- accuracy
|
| 18 |
- f1
|
|
|
|
| 20 |
- recall
|
| 21 |
---
|
| 22 |
|
| 23 |
+
# ProtoMorph-DINO
|
| 24 |
|
| 25 |
**Feedback-Gated Prototype Morphing for Hard-Case Image Classification**
|
| 26 |
|
| 27 |
ProtoMorph-DINO is an experimental image classification head designed to run on top of a frozen DINOv3 vision backbone.
|
| 28 |
|
| 29 |
+
This model card is for the Hugging Face repository:
|
| 30 |
|
| 31 |
+
```text
|
| 32 |
+
shiowo/DINO-Protomorph
|
| 33 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
This repository currently contains an initial research scaffold and custom ProtoMorph head checkpoint. Evaluation results are **pending** because the repository is being created before full training and benchmarking.
|
| 36 |
|
| 37 |
+
This project is independent and is not affiliated with Meta AI, Hugging Face, or the official DINOv3 project.
|
| 38 |
|
| 39 |
---
|
| 40 |
|
|
|
|
| 70 |
|
| 71 |
## Model Summary
|
| 72 |
|
| 73 |
+
ProtoMorph-DINO explores whether a frozen foundation vision backbone can be improved with a custom hard-case refinement head.
|
| 74 |
+
|
| 75 |
+
For easy images, the model returns the main classifier output directly. For difficult or ambiguous images, the model activates a feedback branch. The feedback branch uses the top-2 predicted probabilities to modulate the DINO patch map, sends the modified representation through a Delta-RBF hard expert, and fuses the refined logits with the main logits.
|
| 76 |
+
|
| 77 |
+
The main research question is whether feedback-guided hard-case refinement can improve classification performance over simpler frozen-backbone heads such as a linear probe or MLP classifier.
|
| 78 |
+
|
| 79 |
+
---
|
| 80 |
+
|
| 81 |
+
## Current Status
|
| 82 |
+
|
| 83 |
+
**Status: research scaffold / pre-training setup**
|
| 84 |
+
|
| 85 |
+
The current checkpoint may be randomly initialized or only intended for smoke testing unless a later release says otherwise.
|
| 86 |
+
|
| 87 |
+
Predictions are **not meaningful** until the ProtoMorph head is trained on a real dataset.
|
| 88 |
|
| 89 |
+
---
|
| 90 |
+
|
| 91 |
+
## Results
|
| 92 |
+
|
| 93 |
+
**Evaluation results: Pending**
|
| 94 |
+
|
| 95 |
+
No benchmark results are reported yet because the repository is being prepared before training and evaluation.
|
| 96 |
+
|
| 97 |
+
| Metric | Value |
|
| 98 |
+
|---|---:|
|
| 99 |
+
| Accuracy | Pending |
|
| 100 |
+
| F1 | Pending |
|
| 101 |
+
| Precision | Pending |
|
| 102 |
+
| Recall | Pending |
|
| 103 |
+
| Confusion-pair improvement | Pending |
|
| 104 |
+
| Hard-case routing benefit | Pending |
|
| 105 |
|
| 106 |
+
Recommended future baselines:
|
| 107 |
|
| 108 |
+
| Baseline | Purpose |
|
| 109 |
+
|---|---|
|
| 110 |
+
| DINOv3 + Linear Probe | Minimal frozen-backbone baseline |
|
| 111 |
+
| DINOv3 + MLP Head | Strong simple head baseline |
|
| 112 |
+
| CLIP + Linear Probe | Popular vision-language comparison |
|
| 113 |
+
| ConvNeXt | Strong CNN-style baseline |
|
| 114 |
+
| ViT | Standard transformer baseline |
|
| 115 |
|
| 116 |
---
|
| 117 |
|
|
|
|
| 124 |
- prototype learning experiments
|
| 125 |
- frozen-backbone classifier research
|
| 126 |
- fine-grained classification experiments
|
| 127 |
+
- educational computer vision experiments
|
| 128 |
|
| 129 |
This model is **not** intended for safety-critical use.
|
| 130 |
|
| 131 |
+
Do not use this model for medical, legal, financial, biometric, security-critical, or production decisions without independent validation.
|
| 132 |
|
| 133 |
---
|
| 134 |
|
|
|
|
| 139 |
```text
|
| 140 |
.
|
| 141 |
├── README.md
|
| 142 |
+
├── LICENSE-WEIGHTS.md
|
| 143 |
├── config.json
|
| 144 |
├── labels.txt
|
| 145 |
+
├── checkpoints/
|
| 146 |
+
│ ├── config.json
|
| 147 |
+
│ ├── labels.txt
|
| 148 |
+
│ └── protomorph_head.safetensors
|
| 149 |
+
├── infer.py
|
| 150 |
+
├── scripts/
|
| 151 |
+
│ └── upload_to_hf.py
|
| 152 |
+
└── src/
|
| 153 |
+
└── protomorph/
|
| 154 |
```
|
| 155 |
|
| 156 |
+
The main weight file is:
|
| 157 |
|
| 158 |
```text
|
| 159 |
+
checkpoints/protomorph_head.safetensors
|
| 160 |
```
|
| 161 |
|
| 162 |
This file contains only the custom ProtoMorph classification head.
|
| 163 |
|
| 164 |
+
DINOv3 backbone weights are **not** included in this repository.
|
| 165 |
|
| 166 |
---
|
| 167 |
|
|
|
|
| 175 |
|
| 176 |
The backbone is used as a frozen visual feature extractor.
|
| 177 |
|
| 178 |
+
For RTX 3090-class GPUs, ViT-S/16 is a practical starting point because it keeps VRAM usage manageable while still producing useful patch embeddings.
|
| 179 |
|
| 180 |
---
|
| 181 |
|
|
|
|
| 192 |
Install PyTorch:
|
| 193 |
|
| 194 |
```bash
|
| 195 |
+
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
|
| 196 |
```
|
| 197 |
|
| 198 |
Install dependencies:
|
| 199 |
|
| 200 |
```bash
|
| 201 |
+
pip install -r requirements-core.txt
|
| 202 |
```
|
| 203 |
|
| 204 |
---
|
| 205 |
|
| 206 |
+
## RunPod Environment Variables
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
+
This project supports the RunPod environment variable names shown below:
|
| 209 |
|
| 210 |
+
```text
|
| 211 |
+
hf_key=hf_your_huggingface_write_token_here
|
| 212 |
+
hf_repo=shiowo/DINO-Protomorph
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
```
|
| 214 |
|
| 215 |
+
Standard Hugging Face names are also supported:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
+
```text
|
| 218 |
+
HF_TOKEN=hf_your_huggingface_write_token_here
|
| 219 |
+
HF_REPO_ID=shiowo/DINO-Protomorph
|
| 220 |
+
```
|
| 221 |
|
| 222 |
+
Never commit your real Hugging Face token to the repository.
|
| 223 |
|
| 224 |
---
|
| 225 |
|
| 226 |
+
## Inference
|
| 227 |
|
| 228 |
+
Run inference from the command line:
|
| 229 |
|
| 230 |
+
```bash
|
| 231 |
+
python infer.py \
|
| 232 |
+
--image examples/sample_image.jpg \
|
| 233 |
+
--config checkpoints/config.json \
|
| 234 |
+
--checkpoint checkpoints/protomorph_head.safetensors \
|
| 235 |
+
--labels checkpoints/labels.txt \
|
| 236 |
+
--topk 5
|
| 237 |
+
```
|
|
|
|
|
|
|
| 238 |
|
| 239 |
+
For smoke testing only:
|
| 240 |
|
| 241 |
+
```bash
|
| 242 |
+
python infer.py --image examples/sample_image.jpg --allow-random-head
|
| 243 |
```
|
| 244 |
|
| 245 |
+
If the head is untrained, the output is only useful for checking that the pipeline runs.
|
| 246 |
+
|
| 247 |
---
|
| 248 |
|
| 249 |
+
## Upload to Hugging Face from RunPod
|
| 250 |
|
| 251 |
+
After setting `hf_key` and `hf_repo` in RunPod, run:
|
| 252 |
|
| 253 |
+
```bash
|
| 254 |
+
cd /workspace/protomorph_dinov3_runpod
|
| 255 |
+
source .venv/bin/activate
|
| 256 |
+
python scripts/upload_to_hf.py
|
| 257 |
+
```
|
| 258 |
|
| 259 |
+
Or use the helper script:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
+
```bash
|
| 262 |
+
bash runpod/upload_to_hf.sh
|
| 263 |
+
```
|
| 264 |
|
| 265 |
+
Dry run before upload:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
+
```bash
|
| 268 |
+
python scripts/upload_to_hf.py --dry-run
|
| 269 |
+
```
|
| 270 |
|
| 271 |
+
---
|
| 272 |
|
| 273 |
+
## Config Example
|
| 274 |
|
| 275 |
+
```json
|
| 276 |
+
{
|
| 277 |
+
"dino_model_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
|
| 278 |
+
"num_classes": 10,
|
| 279 |
+
"embed_dim": 384,
|
| 280 |
+
"patch_size": 16,
|
| 281 |
+
"proto_count": 64,
|
| 282 |
+
"memory_tokens": 16,
|
| 283 |
+
"rbf_count": 128,
|
| 284 |
+
"num_heads": 8,
|
| 285 |
+
"dropout": 0.0,
|
| 286 |
+
"hard_pmax_threshold": 0.65,
|
| 287 |
+
"hard_margin_threshold": 0.15,
|
| 288 |
+
"hard_entropy_threshold": 1.35,
|
| 289 |
+
"image_size": 512,
|
| 290 |
+
"use_bf16_autocast": true,
|
| 291 |
+
"normalize_patch_tokens": true
|
| 292 |
+
}
|
| 293 |
+
```
|
| 294 |
|
| 295 |
---
|
| 296 |
|
|
|
|
| 299 |
Known limitations:
|
| 300 |
|
| 301 |
- The architecture is experimental.
|
| 302 |
+
- Evaluation results are pending.
|
| 303 |
- The hard-case gate requires threshold tuning.
|
| 304 |
- The Delta-RBF hard expert may overfit small datasets.
|
| 305 |
- Inference may be slower for hard samples.
|
| 306 |
- The model should be compared against simple baselines before claiming improvement.
|
| 307 |
+
- This repository does not include DINOv3 weights.
|
| 308 |
- The custom head may not generalize outside the dataset it was trained on.
|
| 309 |
|
| 310 |
---
|
|
|
|
| 341 |
|
| 342 |
```text
|
| 343 |
ProtoMorph-DINO: Feedback-Gated Prototype Morphing for Hard-Case Image Classification
|
| 344 |
+
Author: shiowo
|
| 345 |
+
Repository: https://huggingface.co/shiowo/DINO-Protomorph
|
| 346 |
```
|
| 347 |
|
| 348 |
BibTeX:
|
|
|
|
| 350 |
```bibtex
|
| 351 |
@software{protomorph_dino_2026,
|
| 352 |
title = {ProtoMorph-DINO: Feedback-Gated Prototype Morphing for Hard-Case Image Classification},
|
| 353 |
+
author = {shiowo},
|
| 354 |
year = {2026},
|
| 355 |
+
url = {https://huggingface.co/shiowo/DINO-Protomorph}
|
| 356 |
}
|
| 357 |
```
|
| 358 |
|
|
|
|
| 363 |
This is a research prototype.
|
| 364 |
|
| 365 |
The model is provided for experimentation and educational use. It should not be used in production or high-stakes environments without independent validation, dataset auditing, robustness testing, and bias evaluation.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README_RUNPOD.md
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ProtoMorph-DINOv3 RunPod Inference + Hugging Face Upload Template
|
| 2 |
+
|
| 3 |
+
This is a runnable experimental inference scaffold for:
|
| 4 |
+
|
| 5 |
+
```text
|
| 6 |
+
Image
|
| 7 |
+
↓
|
| 8 |
+
Frozen DINOv3
|
| 9 |
+
↓
|
| 10 |
+
Patch map z0
|
| 11 |
+
↓
|
| 12 |
+
ProtoMorph block 1
|
| 13 |
+
↓
|
| 14 |
+
Layer Memory Attention
|
| 15 |
+
↓
|
| 16 |
+
ProtoMorph block 2
|
| 17 |
+
↓
|
| 18 |
+
Layer Memory Attention
|
| 19 |
+
↓
|
| 20 |
+
Main logits
|
| 21 |
+
↓
|
| 22 |
+
Hard-case gate
|
| 23 |
+
├── easy: return main logits
|
| 24 |
+
└── hard:
|
| 25 |
+
feedback from top-2 probabilities
|
| 26 |
+
modulate DINO patch map
|
| 27 |
+
run Delta-RBF hard expert
|
| 28 |
+
fuse logits
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Practical GPU choice
|
| 32 |
+
|
| 33 |
+
Default backbone: `facebook/dinov3-vits16-pretrain-lvd1689m`.
|
| 34 |
+
|
| 35 |
+
Reason: RTX 3090 has 24 GB VRAM. ViT-S/16 gives 384-dim patch tokens, leaves room for Jupyter, batch inference, the custom hard expert, and future training experiments. You can switch to ViT-B/16 by recreating the head with `--dino-model-name facebook/dinov3-vitb16-pretrain-lvd1689m --embed-dim 768`, but start with ViT-S until the plumbing is stable.
|
| 36 |
+
|
| 37 |
+
## Important compatibility note
|
| 38 |
+
|
| 39 |
+
You said PyTorch 2.4.0 and CUDA 13. PyTorch 2.4.0 official wheels are for CUDA 11.8, 12.1, and 12.4. On RunPod, use the CUDA 12.4 wheel even when the host driver/toolkit is newer:
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## RunPod setup
|
| 46 |
+
|
| 47 |
+
Recommended RunPod template:
|
| 48 |
+
|
| 49 |
+
- RTX 3090
|
| 50 |
+
- Python 3.11
|
| 51 |
+
- PyTorch 2.4.0 if available, otherwise a clean CUDA 12.4 Ubuntu image
|
| 52 |
+
- Persistent volume mounted at `/workspace`
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
cd /workspace
|
| 56 |
+
git clone <your-repo-or-upload-this-folder> protomorph_dinov3_runpod
|
| 57 |
+
cd /workspace/protomorph_dinov3_runpod
|
| 58 |
+
bash runpod/setup_runpod.sh
|
| 59 |
+
bash runpod/start_jupyter.sh
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
Open Jupyter on port `8888`. Default token: `protomorph`, or set `JUPYTER_TOKEN`.
|
| 63 |
+
|
| 64 |
+
## Hugging Face access and upload env
|
| 65 |
+
|
| 66 |
+
This package is configured for your Hugging Face model repo:
|
| 67 |
+
|
| 68 |
+
```text
|
| 69 |
+
https://huggingface.co/shiowo/DINO-Protomorph
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
In RunPod, you can use the environment variable names you added:
|
| 73 |
+
|
| 74 |
+
```text
|
| 75 |
+
hf_key=hf_your_huggingface_write_token_here
|
| 76 |
+
hf_repo=shiowo/DINO-Protomorph
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
The code also supports standard Hugging Face names:
|
| 80 |
+
|
| 81 |
+
```text
|
| 82 |
+
HF_TOKEN=hf_your_huggingface_write_token_here
|
| 83 |
+
HF_REPO_ID=shiowo/DINO-Protomorph
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
`hf_key` is never printed by the scripts. Do not commit real tokens.
|
| 87 |
+
|
| 88 |
+
DINOv3 checkpoints may require accepting the model/license on Hugging Face before the frozen backbone can be downloaded. The inference code automatically passes `hf_key`/`HF_TOKEN` to `transformers`.
|
| 89 |
+
|
| 90 |
+
## Create the initial safetensors head
|
| 91 |
+
|
| 92 |
+
The setup script creates a random custom head only if the checkpoint bundle does not already exist:
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
python scripts/create_random_head.py --num-classes 10 --out-dir checkpoints
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
This writes:
|
| 99 |
+
|
| 100 |
+
- `checkpoints/config.json`
|
| 101 |
+
- `checkpoints/protomorph_head.safetensors`
|
| 102 |
+
- `checkpoints/labels.txt`
|
| 103 |
+
|
| 104 |
+
The random head is only for smoke tests. Train the head before trusting predictions.
|
| 105 |
+
|
| 106 |
+
To intentionally overwrite it:
|
| 107 |
+
|
| 108 |
+
```bash
|
| 109 |
+
python scripts/create_random_head.py --num-classes 10 --out-dir checkpoints --force
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
## CLI inference
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
source .venv/bin/activate
|
| 116 |
+
python infer.py \
|
| 117 |
+
--image /workspace/my_image.jpg \
|
| 118 |
+
--config checkpoints/config.json \
|
| 119 |
+
--checkpoint checkpoints/protomorph_head.safetensors \
|
| 120 |
+
--labels checkpoints/labels.txt \
|
| 121 |
+
--topk 5
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
For plumbing tests without a trained checkpoint:
|
| 125 |
+
|
| 126 |
+
```bash
|
| 127 |
+
python infer.py --image /workspace/my_image.jpg --allow-random-head
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
## Switch to DINOv3 ViT-B/16
|
| 131 |
+
|
| 132 |
+
```bash
|
| 133 |
+
python scripts/create_random_head.py \
|
| 134 |
+
--dino-model-name facebook/dinov3-vitb16-pretrain-lvd1689m \
|
| 135 |
+
--embed-dim 768 \
|
| 136 |
+
--num-classes 10 \
|
| 137 |
+
--out-dir checkpoints_vitb
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
Then use `--config checkpoints_vitb/config.json --checkpoint checkpoints_vitb/protomorph_head.safetensors`.
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
## Upload to Hugging Face
|
| 144 |
+
|
| 145 |
+
After setting `hf_key` and `hf_repo` in RunPod, upload the model card, config, labels, custom head checkpoint, and related inference files with:
|
| 146 |
+
|
| 147 |
+
```bash
|
| 148 |
+
source .venv/bin/activate
|
| 149 |
+
python scripts/upload_to_hf.py --dry-run
|
| 150 |
+
python scripts/upload_to_hf.py
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
Or use the helper script:
|
| 154 |
+
|
| 155 |
+
```bash
|
| 156 |
+
bash runpod/upload_to_hf.sh
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
The target repo defaults to:
|
| 160 |
+
|
| 161 |
+
```text
|
| 162 |
+
shiowo/DINO-Protomorph
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
The upload includes `checkpoints/protomorph_head.safetensors`, but it does not include DINOv3 backbone weights. DINOv3 is loaded separately from Hugging Face during inference.
|
| 166 |
+
|
| 167 |
+
## What is actually saved in safetensors?
|
| 168 |
+
|
| 169 |
+
The `.safetensors` file stores the custom ProtoMorph head only. DINOv3 remains frozen and is loaded from Hugging Face cache. This keeps the experiment checkpoint small and avoids duplicating the foundation model weights.
|
| 170 |
+
|
| 171 |
+
## Files
|
| 172 |
+
|
| 173 |
+
- `src/protomorph/model.py`: architecture implementation
|
| 174 |
+
- `src/protomorph/config.py`: config dataclass
|
| 175 |
+
- `src/protomorph/inference.py`: image loading and prediction helpers
|
| 176 |
+
- `infer.py`: CLI inference
|
| 177 |
+
- `scripts/create_random_head.py`: initialize config + safetensors
|
| 178 |
+
- `scripts/smoke_test_head_only.py`: tests custom head without downloading DINOv3
|
| 179 |
+
- `scripts/upload_to_hf.py`: uploads model card/checkpoint/source files to Hugging Face
|
| 180 |
+
- `runpod/upload_to_hf.sh`: RunPod helper for `hf_key` and `hf_repo` env variables
|
| 181 |
+
- `notebooks/ProtoMorph_DINOv3_Inference.ipynb`: Jupyter inference notebook
|
checkpoints/config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"dino_model_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
|
| 3 |
+
"num_classes": 10,
|
| 4 |
+
"embed_dim": 384,
|
| 5 |
+
"patch_size": 16,
|
| 6 |
+
"proto_count": 64,
|
| 7 |
+
"memory_tokens": 16,
|
| 8 |
+
"rbf_count": 128,
|
| 9 |
+
"num_heads": 8,
|
| 10 |
+
"dropout": 0.0,
|
| 11 |
+
"hard_pmax_threshold": 0.65,
|
| 12 |
+
"hard_margin_threshold": 0.15,
|
| 13 |
+
"hard_entropy_threshold": 1.35,
|
| 14 |
+
"image_size": 512,
|
| 15 |
+
"use_bf16_autocast": true,
|
| 16 |
+
"normalize_patch_tokens": true
|
| 17 |
+
}
|
checkpoints/labels.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class_0
|
| 2 |
+
class_1
|
| 3 |
+
class_2
|
| 4 |
+
class_3
|
| 5 |
+
class_4
|
| 6 |
+
class_5
|
| 7 |
+
class_6
|
| 8 |
+
class_7
|
| 9 |
+
class_8
|
| 10 |
+
class_9
|
checkpoints/protomorph_head.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ab309078e2db41027fe0148415bb3e2e8e3e6059e0dd89633fd3476d0b72bebe
|
| 3 |
+
size 32451516
|
config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"dino_model_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
|
| 3 |
+
"num_classes": 10,
|
| 4 |
+
"embed_dim": 384,
|
| 5 |
+
"patch_size": 16,
|
| 6 |
+
"proto_count": 64,
|
| 7 |
+
"memory_tokens": 16,
|
| 8 |
+
"rbf_count": 128,
|
| 9 |
+
"num_heads": 8,
|
| 10 |
+
"dropout": 0.0,
|
| 11 |
+
"hard_pmax_threshold": 0.65,
|
| 12 |
+
"hard_margin_threshold": 0.15,
|
| 13 |
+
"hard_entropy_threshold": 1.35,
|
| 14 |
+
"image_size": 512,
|
| 15 |
+
"use_bf16_autocast": true,
|
| 16 |
+
"normalize_patch_tokens": true
|
| 17 |
+
}
|
infer.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from src.protomorph.inference import build_model, load_labels, predict_paths
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line arguments for the inference CLI."""
    parser = argparse.ArgumentParser(description="ProtoMorph-DINOv3 inference CLI")
    parser.add_argument("--config", default="checkpoints/config.json")
    parser.add_argument("--checkpoint", default="checkpoints/protomorph_head.safetensors")
    parser.add_argument("--labels", default=None, help="txt/json labels. Defaults to class_0..class_N")
    parser.add_argument("--image", action="append", required=True, help="Image path. Repeat for batch inference.")
    parser.add_argument("--topk", type=int, default=5)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--force-hard", action="store_true", help="Always run/fuse hard expert branch.")
    parser.add_argument("--local-files-only", action="store_true")
    parser.add_argument("--allow-random-head", action="store_true", help="Smoke test only; logits are random.")
    parser.add_argument("--output", default=None, help="Optional JSON output path")
    return parser.parse_args()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main() -> None:
    """CLI entry point: load the model, run inference, print and optionally save JSON results."""
    args = parse_args()
    # Fall back to CPU when a CUDA device was requested but none is available.
    if args.device == "cpu" or torch.cuda.is_available():
        device = args.device
    else:
        device = "cpu"
    model = build_model(
        args.config,
        args.checkpoint,
        device=device,
        local_files_only=args.local_files_only,
        allow_random_head=args.allow_random_head,
    )
    labels = load_labels(args.labels, model.cfg.num_classes)
    results = predict_paths(model, args.image, labels, topk=args.topk, device=device, force_hard=args.force_hard)
    text = json.dumps(results, indent=2)
    print(text)
    if args.output:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(text + "\n")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if __name__ == "__main__":
|
| 47 |
+
main()
|
labels.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class_0
|
| 2 |
+
class_1
|
| 3 |
+
class_2
|
| 4 |
+
class_3
|
| 5 |
+
class_4
|
| 6 |
+
class_5
|
| 7 |
+
class_6
|
| 8 |
+
class_7
|
| 9 |
+
class_8
|
| 10 |
+
class_9
|
notebooks/ProtoMorph_DINOv3_Inference.ipynb
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# ProtoMorph-DINOv3 Inference Notebook\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"This notebook loads a frozen DINOv3 backbone, the custom ProtoMorph head from `safetensors`, and runs the hard-case gated inference path.\n"
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"cell_type": "code",
|
| 14 |
+
"execution_count": null,
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"import os, sys, json, torch\n",
|
| 19 |
+
"from pathlib import Path\n",
|
| 20 |
+
"ROOT = Path('/workspace/protomorph_dinov3_runpod') if Path('/workspace/protomorph_dinov3_runpod').exists() else Path.cwd().parent\n",
|
| 21 |
+
"os.chdir(ROOT)\n",
|
| 22 |
+
"sys.path.insert(0, str(ROOT))\n",
|
| 23 |
+
"print('cwd:', ROOT)\n",
|
| 24 |
+
"print('torch:', torch.__version__)\n",
|
| 25 |
+
"print('cuda available:', torch.cuda.is_available())\n",
|
| 26 |
+
"if torch.cuda.is_available():\n",
|
| 27 |
+
" print('gpu:', torch.cuda.get_device_name(0))\n",
|
| 28 |
+
" print('torch cuda runtime:', torch.version.cuda)\n"
|
| 29 |
+
]
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"cell_type": "markdown",
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"source": [
|
| 35 |
+
"## Create a random head checkpoint for smoke testing\n",
|
| 36 |
+
"Run this once. Random logits are not meaningful; this just proves the pipeline works.\n"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "code",
|
| 41 |
+
"execution_count": null,
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [],
|
| 44 |
+
"source": [
|
| 45 |
+
"!python scripts/create_random_head.py --num-classes 10 --out-dir checkpoints\n"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": null,
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"from src.protomorph.inference import build_model, load_labels, predict_paths\n",
|
| 55 |
+
"CONFIG = 'checkpoints/config.json'\n",
|
| 56 |
+
"CKPT = 'checkpoints/protomorph_head.safetensors'\n",
|
| 57 |
+
"LABELS = 'checkpoints/labels.txt'\n",
|
| 58 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 59 |
+
"labels = load_labels(LABELS, num_classes=json.load(open(CONFIG))['num_classes'])\n",
|
| 60 |
+
"model = build_model(CONFIG, CKPT, device=device, allow_random_head=False)\n",
|
| 61 |
+
"print('loaded model on', device)\n"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "markdown",
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"source": [
|
| 68 |
+
"## Run inference\n",
|
| 69 |
+
"Set `IMAGE_PATH` to an image on your RunPod volume.\n"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"execution_count": null,
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"outputs": [],
|
| 77 |
+
"source": [
|
| 78 |
+
"from PIL import Image\n",
|
| 79 |
+
"import matplotlib.pyplot as plt\n",
|
| 80 |
+
"IMAGE_PATH = '/workspace/my_image.jpg' # change this\n",
|
| 81 |
+
"img = Image.open(IMAGE_PATH).convert('RGB')\n",
|
| 82 |
+
"plt.imshow(img); plt.axis('off');\n"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"cell_type": "code",
|
| 87 |
+
"execution_count": null,
|
| 88 |
+
"metadata": {},
|
| 89 |
+
"outputs": [],
|
| 90 |
+
"source": [
|
| 91 |
+
"results = predict_paths(model, [IMAGE_PATH], labels, topk=5, device=device)\n",
|
| 92 |
+
"print(json.dumps(results[0], indent=2))\n"
|
| 93 |
+
]
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"cell_type": "markdown",
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"source": [
|
| 99 |
+
"## Force hard branch for debugging\n",
|
| 100 |
+
"This runs the feedback + Delta-RBF expert even if the gate says the image is easy.\n"
|
| 101 |
+
]
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"cell_type": "code",
|
| 105 |
+
"execution_count": null,
|
| 106 |
+
"metadata": {},
|
| 107 |
+
"outputs": [],
|
| 108 |
+
"source": [
|
| 109 |
+
"debug_results = predict_paths(model, [IMAGE_PATH], labels, topk=5, device=device, force_hard=True)\n",
|
| 110 |
+
"print(json.dumps(debug_results[0], indent=2))\n"
|
| 111 |
+
]
|
| 112 |
+
}
|
| 113 |
+
],
|
| 114 |
+
"metadata": {
|
| 115 |
+
"kernelspec": {
|
| 116 |
+
"display_name": "Python 3",
|
| 117 |
+
"language": "python",
|
| 118 |
+
"name": "python3"
|
| 119 |
+
},
|
| 120 |
+
"language_info": {
|
| 121 |
+
"name": "python",
|
| 122 |
+
"version": "3.11"
|
| 123 |
+
}
|
| 124 |
+
},
|
| 125 |
+
"nbformat": 4,
|
| 126 |
+
"nbformat_minor": 5
|
| 127 |
+
}
|
pyproject.toml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "protomorph-dinov3-runpod"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Experimental ProtoMorph + frozen DINOv3 hard-case inference scaffold"
|
| 5 |
+
requires-python = ">=3.11"
|
| 6 |
+
|
| 7 |
+
[tool.setuptools]
|
| 8 |
+
packages = ["src.protomorph"]
|
requirements-core.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers>=4.56.0,<5
|
| 2 |
+
safetensors>=0.4.5
|
| 3 |
+
accelerate>=0.33.0
|
| 4 |
+
pillow>=10.0.0
|
| 5 |
+
numpy>=1.26.0
|
| 6 |
+
tqdm>=4.66.0
|
| 7 |
+
jupyterlab>=4.2.0
|
| 8 |
+
ipywidgets>=8.1.0
|
| 9 |
+
matplotlib>=3.8.0
|
| 10 |
+
huggingface_hub>=0.24.0
|
runpod/setup_runpod.sh
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# One-time RunPod setup: create a venv, install PyTorch + deps, map env vars,
# then create a random head checkpoint and run the head-only smoke test.
set -euo pipefail

PROJECT_DIR="${PROJECT_DIR:-/workspace/protomorph_dinov3_runpod}"
cd "$PROJECT_DIR"

# Fresh virtualenv with up-to-date packaging tools.
python3.11 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip wheel setuptools

# PyTorch 2.4.0 does not have official CUDA 13 wheels. Use cu124 on RunPod/RTX 3090.
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
pip install -r requirements-core.txt

# Map RunPod variable names to standard Hugging Face names for download/upload tools.
if [[ -n "${hf_key:-}" && -z "${HF_TOKEN:-}" ]]; then export HF_TOKEN="$hf_key"; fi
if [[ -n "${hf_repo:-}" && -z "${HF_REPO_ID:-}" ]]; then export HF_REPO_ID="$hf_repo"; fi

mkdir -p "${HF_HOME:-/workspace/hf_cache}"
python scripts/create_random_head.py --num-classes 10 --out-dir checkpoints
python scripts/smoke_test_head_only.py

echo
echo "Setup complete."
echo "To start Jupyter: bash runpod/start_jupyter.sh"
echo "To upload to Hugging Face: bash runpod/upload_to_hf.sh"
|
runpod/start_jupyter.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Start JupyterLab on the RunPod instance with HF env mapping and a fixed token.
set -euo pipefail

PROJECT_DIR="${PROJECT_DIR:-/workspace/protomorph_dinov3_runpod}"
cd "$PROJECT_DIR"
source .venv/bin/activate

# Map RunPod-style env names to the standard Hugging Face ones.
if [[ -n "${hf_key:-}" && -z "${HF_TOKEN:-}" ]]; then
  export HF_TOKEN="$hf_key"
fi
if [[ -n "${hf_repo:-}" && -z "${HF_REPO_ID:-}" ]]; then
  export HF_REPO_ID="$hf_repo"
fi

export HF_HOME="${HF_HOME:-/workspace/hf_cache}"
export JUPYTER_TOKEN="${JUPYTER_TOKEN:-protomorph}"

echo "Jupyter token: $JUPYTER_TOKEN"
# JupyterLab 4 / jupyter-server 2 reads --IdentityProvider.token; the legacy
# --NotebookApp.token spelling is deprecated and may be silently ignored,
# leaving the server protected by a random token instead of $JUPYTER_TOKEN.
jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root --IdentityProvider.token="$JUPYTER_TOKEN"
|
runpod/upload_to_hf.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Helper: push the project (model card, config, labels, head checkpoint, sources)
# to the Hugging Face Hub using the Python uploader.
set -euo pipefail

PROJECT_DIR="${PROJECT_DIR:-/workspace/protomorph_dinov3_runpod}"
cd "$PROJECT_DIR"

# Use the project venv when present.
if [[ -d .venv ]]; then source .venv/bin/activate; fi

# Map the user's RunPod env names to Hugging Face standard names for tools
# that only look for HF_TOKEN/HF_REPO_ID.
if [[ -n "${hf_key:-}" && -z "${HF_TOKEN:-}" ]]; then export HF_TOKEN="$hf_key"; fi
if [[ -n "${hf_repo:-}" && -z "${HF_REPO_ID:-}" ]]; then export HF_REPO_ID="$hf_repo"; fi

python scripts/upload_to_hf.py \
  --repo-id "${HF_REPO_ID:-${hf_repo:-shiowo/DINO-Protomorph}}" \
  --commit-message "${HF_COMMIT_MESSAGE:-Upload ProtoMorph-DINO scaffold and checkpoint}"
|
scripts/create_random_head.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for checkpoint-bundle creation."""
    parser = argparse.ArgumentParser(description="Create initial ProtoMorph custom-head safetensors checkpoint")
    parser.add_argument("--out-dir", default="checkpoints")
    parser.add_argument("--dino-model-name", default="facebook/dinov3-vits16-pretrain-lvd1689m")
    parser.add_argument("--num-classes", type=int, default=10)
    parser.add_argument("--embed-dim", type=int, default=None)
    parser.add_argument("--image-size", type=int, default=512)
    parser.add_argument("--proto-count", type=int, default=64)
    parser.add_argument("--memory-tokens", type=int, default=16)
    parser.add_argument("--rbf-count", type=int, default=128)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--force", action="store_true", help="Overwrite existing config/checkpoint/labels")
    return parser.parse_args()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def main() -> None:
    """Create a freshly initialized (random) ProtoMorph head checkpoint bundle.

    Writes config.json, protomorph_head.safetensors, and labels.txt into
    --out-dir. The head is randomly initialized, so it is only useful for
    plumbing/smoke tests until trained.
    """
    args = parse_args()
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    cfg_path = out_dir / "config.json"
    ckpt_path = out_dir / "protomorph_head.safetensors"
    labels_path = out_dir / "labels.txt"

    # Idempotent by default: refuse to clobber an existing bundle.
    if not args.force and cfg_path.exists() and ckpt_path.exists() and labels_path.exists():
        print(f"Existing checkpoint bundle found in {out_dir}; not overwriting. Pass --force to recreate it.")
        return

    # Deferred imports so --help and the overwrite check work without torch installed.
    from safetensors.torch import save_file
    from src.protomorph.config import ProtoMorphConfig
    from src.protomorph.model import ProtoMorphHead, infer_embed_dim_from_model_name

    # Derive the embedding width from the backbone name unless overridden.
    embed_dim = args.embed_dim or infer_embed_dim_from_model_name(args.dino_model_name)
    cfg = ProtoMorphConfig(
        dino_model_name=args.dino_model_name,
        num_classes=args.num_classes,
        embed_dim=embed_dim,
        image_size=args.image_size,
        proto_count=args.proto_count,
        memory_tokens=args.memory_tokens,
        rbf_count=args.rbf_count,
        num_heads=args.num_heads,
    )
    head = ProtoMorphHead(cfg)
    cfg.to_json(cfg_path)
    save_file(head.state_dict(), str(ckpt_path))
    labels_path.write_text("\n".join(f"class_{i}" for i in range(args.num_classes)) + "\n")
    print(f"Wrote {cfg_path}")
    print(f"Wrote {ckpt_path}")
    # Previously labels.txt was written but never reported; list all three outputs.
    print(f"Wrote {labels_path}")
    print("Important: this is a random head for plumbing/smoke tests. Train it before real predictions.")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
main()
|
scripts/smoke_test_head_only.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from src.protomorph.config import ProtoMorphConfig
|
| 11 |
+
from src.protomorph.model import ProtoMorphHead
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def main() -> None:
    """Forward a tiny random batch through ProtoMorphHead on CPU and sanity-check output shapes."""
    torch.set_num_threads(1)
    # Deliberately tiny dimensions so the test runs fast without a GPU.
    small = dict(num_classes=7, embed_dim=32, proto_count=8, memory_tokens=4, rbf_count=16, num_heads=4)
    cfg = ProtoMorphConfig(**small)
    head = ProtoMorphHead(cfg).eval()
    batch = 2
    cls_tokens = torch.randn(batch, cfg.embed_dim)
    patch_tokens = torch.randn(batch, 8 * 8, cfg.embed_dim)
    with torch.no_grad():
        out = head(cls_tokens, patch_tokens)
    assert out["logits"].shape == (batch, cfg.num_classes)
    assert out["hard_mask"].shape == (batch,)
    print("OK head-only smoke test", out["logits"].shape, out["hard_mask"].tolist())
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
main()
|
scripts/upload_to_hf.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Iterable, List
|
| 9 |
+
|
| 10 |
+
# Allow running from the repo root without installing the package.
|
| 11 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
+
if str(ROOT) not in sys.path:
|
| 13 |
+
sys.path.insert(0, str(ROOT))
|
| 14 |
+
|
| 15 |
+
from src.protomorph.hf_utils import get_hf_repo_id, get_hf_token, normalize_repo_id
|
| 16 |
+
|
| 17 |
+
DEFAULT_REPO = "shiowo/DINO-Protomorph"
|
| 18 |
+
REQUIRED_FILES = [
|
| 19 |
+
"README.md",
|
| 20 |
+
"checkpoints/config.json",
|
| 21 |
+
"checkpoints/labels.txt",
|
| 22 |
+
"checkpoints/protomorph_head.safetensors",
|
| 23 |
+
"src/protomorph/model.py",
|
| 24 |
+
"src/protomorph/config.py",
|
| 25 |
+
"infer.py",
|
| 26 |
+
]
|
| 27 |
+
IGNORE_PATTERNS = [
|
| 28 |
+
".git/*",
|
| 29 |
+
".venv/*",
|
| 30 |
+
"venv/*",
|
| 31 |
+
"env/*",
|
| 32 |
+
"__pycache__/*",
|
| 33 |
+
"**/__pycache__/*",
|
| 34 |
+
"*.pyc",
|
| 35 |
+
".ipynb_checkpoints/*",
|
| 36 |
+
"**/.ipynb_checkpoints/*",
|
| 37 |
+
".cache/*",
|
| 38 |
+
"hf_cache/*",
|
| 39 |
+
"outputs/*",
|
| 40 |
+
"wandb/*",
|
| 41 |
+
"data/*",
|
| 42 |
+
"datasets/*",
|
| 43 |
+
"*.zip",
|
| 44 |
+
"*.tar",
|
| 45 |
+
"*.tar.gz",
|
| 46 |
+
"*.7z",
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def human_size(n: int) -> str:
    """Format a byte count as a short human-readable string (e.g. '30.9 MB')."""
    value = float(n)
    for unit in ("B", "KB", "MB", "GB", "TB"):
        # Stop at the first unit under 1024, or at TB regardless of magnitude.
        if value < 1024 or unit == "TB":
            # Bytes are shown as an integer; larger units get one decimal place.
            return f"{int(value)} B" if unit == "B" else f"{value:.1f} {unit}"
        value /= 1024
    return f"{n} B"  # unreachable; kept for parity with the original fallback
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def iter_upload_files(source: Path, ignore_dirs: Iterable[str]) -> List[Path]:
    """Collect files under `source` for upload, skipping ignored directories and archives.

    Only *directory* components of the relative path are matched against
    `ignore_dirs`, so a regular file that happens to share a name with an
    ignored directory (e.g. a file literally named `data`) is still included.
    The suffix skip-list matches IGNORE_PATTERNS (.pyc/.zip/.7z/.tar/.tar.gz).

    Returns the sorted list of paths relative to `source`.
    """
    ignore_dir_names = set(ignore_dirs)
    skip_suffixes = (".pyc", ".zip", ".7z", ".tar", ".tar.gz")
    files: List[Path] = []
    for path in source.rglob("*"):
        if path.is_dir():
            continue
        rel = path.relative_to(source)
        # rel.parts[-1] is the filename itself; only filter the parent dirs.
        if ignore_dir_names.intersection(rel.parts[:-1]):
            continue
        if rel.name.endswith(skip_suffixes):
            continue
        files.append(rel)
    return sorted(files)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def check_required(source: Path) -> None:
    """Raise FileNotFoundError listing every required upload file absent under `source`."""
    missing = [rel for rel in REQUIRED_FILES if not (source / rel).exists()]
    if not missing:
        return
    joined = "\n - ".join(missing)
    raise FileNotFoundError(f"Missing required files for HF upload:\n - {joined}")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for the Hugging Face uploader."""
    parser = argparse.ArgumentParser(description="Upload ProtoMorph-DINO files to a Hugging Face model repo.")
    parser.add_argument("--source", default=".", help="Folder to upload. Default: current project root.")
    parser.add_argument("--repo-id", default=None, help="HF repo id or URL. Default: env hf_repo/HF_REPO_ID, then shiowo/DINO-Protomorph.")
    parser.add_argument("--token", default=None, help="HF token. Default: env hf_key/HF_TOKEN/etc. Do not paste this into logs.")
    parser.add_argument("--revision", default="main", help="Target branch/revision. Default: main.")
    parser.add_argument("--private", action="store_true", help="Create repo as private if it does not exist yet.")
    parser.add_argument("--no-create", action="store_true", help="Do not create the repo if missing.")
    parser.add_argument("--dry-run", action="store_true", help="Print what would be uploaded, then exit.")
    parser.add_argument("--commit-message", default="Upload ProtoMorph-DINO scaffold and checkpoint", help="HF commit message.")
    return parser.parse_args()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def main() -> None:
    """Validate the source tree, then upload it to the Hugging Face Hub.

    Honors RunPod-style env vars (hf_key/hf_repo) as well as the standard
    HF_TOKEN/HF_REPO_ID names. With --dry-run, prints the file list and exits
    before requiring a token or touching the network.
    """
    args = parse_args()
    source = Path(args.source).resolve()
    if not source.exists() or not source.is_dir():
        raise NotADirectoryError(f"Source folder does not exist: {source}")

    # CLI flag wins; otherwise fall back to env vars, then the project default.
    repo_id = normalize_repo_id(args.repo_id) if args.repo_id else get_hf_repo_id(DEFAULT_REPO)
    token = args.token or get_hf_token()

    check_required(source)

    skip_dirs = {
        ".git", ".venv", "venv", "env", "__pycache__", ".ipynb_checkpoints",
        ".cache", "hf_cache", "outputs", "wandb", "data", "datasets",
    }
    files = iter_upload_files(source, ignore_dirs=skip_dirs)
    total_bytes = sum((source / f).stat().st_size for f in files)

    print(f"HF repo: {repo_id}")
    print(f"Source: {source}")
    print(f"Files: {len(files)} files, {human_size(total_bytes)}")
    # Never print the token itself, only whether one was found.
    print("Token: " + ("found" if token else "missing"))

    if args.dry_run:
        print("\nDry run file list:")
        for rel in files:
            print(f"  {rel}")
        print("\nNo upload performed.")
        return

    if not token:
        raise RuntimeError(
            "No Hugging Face token found. In RunPod environment variables, set `hf_key=hf_xxx`, "
            "or set standard `HF_TOKEN=hf_xxx`."
        )

    try:
        from huggingface_hub import HfApi
    except ImportError as e:
        raise ImportError("Install huggingface_hub first: pip install huggingface_hub") from e

    api = HfApi(token=token)
    if not args.no_create:
        api.create_repo(repo_id=repo_id, repo_type="model", private=args.private, exist_ok=True)

    api.upload_folder(
        folder_path=str(source),
        repo_id=repo_id,
        repo_type="model",
        revision=args.revision,
        commit_message=args.commit_message,
        ignore_patterns=IGNORE_PATTERNS,
    )
    print(f"\nUpload complete: https://huggingface.co/{repo_id}")
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
|
| 149 |
+
main()
|
src/protomorph/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Public package surface for src.protomorph.
# All re-exports are grouped before __all__ so the declared API matches the
# imports (previously the hf_utils names were imported after __all__ and
# omitted from it).
from .config import ProtoMorphConfig
from .hf_utils import get_hf_repo_id, get_hf_token, normalize_repo_id
from .inference import build_model, load_labels, predict_paths
from .model import ProtoMorphDINOv3, ProtoMorphHead, infer_embed_dim_from_model_name

__all__ = [
    "ProtoMorphConfig",
    "ProtoMorphDINOv3",
    "ProtoMorphHead",
    "infer_embed_dim_from_model_name",
    "build_model",
    "predict_paths",
    "load_labels",
    "get_hf_token",
    "get_hf_repo_id",
    "normalize_repo_id",
]
|
src/protomorph/config.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import asdict, dataclass
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class ProtoMorphConfig:
    """Configuration for the custom ProtoMorph head.

    The frozen DINOv3 backbone is loaded separately from Hugging Face. The
    safetensors checkpoint stores only the trainable experimental head, which is
    what you will train/tune for your dataset.
    """

    # Backbone / task
    dino_model_name: str = "facebook/dinov3-vits16-pretrain-lvd1689m"
    num_classes: int = 10
    embed_dim: int = 384
    patch_size: int = 16

    # ProtoMorph blocks
    proto_count: int = 64
    memory_tokens: int = 16
    rbf_count: int = 128
    num_heads: int = 8
    dropout: float = 0.0

    # Hard-case gate thresholds
    hard_pmax_threshold: float = 0.65
    hard_margin_threshold: float = 0.15
    hard_entropy_threshold: float = 1.35

    # Inference / performance knobs
    image_size: int = 512
    use_bf16_autocast: bool = True
    normalize_patch_tokens: bool = True

    @classmethod
    def from_json(cls, path: str | Path) -> "ProtoMorphConfig":
        """Load a config from a JSON file, ignoring keys this version does not declare.

        Filtering unknown keys keeps older code able to read configs written by
        newer project versions instead of crashing in the dataclass constructor.
        """
        data = json.loads(Path(path).read_text())
        known = set(cls.__dataclass_fields__)
        return cls(**{k: v for k, v in data.items() if k in known})

    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a plain JSON-serializable dict."""
        return asdict(self)

    def to_json(self, path: str | Path) -> None:
        """Write the config as pretty-printed JSON, creating parent dirs as needed."""
        p = Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(json.dumps(self.to_dict(), indent=2) + "\n")
|
src/protomorph/hf_utils.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import Iterable, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
HF_TOKEN_ENV_NAMES = (
|
| 8 |
+
"hf_key", # RunPod env name used by this project
|
| 9 |
+
"HF_TOKEN", # Hugging Face standard
|
| 10 |
+
"HUGGINGFACE_HUB_TOKEN",
|
| 11 |
+
"HUGGING_FACE_HUB_TOKEN",
|
| 12 |
+
"HUGGINGFACE_TOKEN",
|
| 13 |
+
"HF_API_TOKEN",
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
HF_REPO_ENV_NAMES = (
|
| 17 |
+
"hf_repo", # RunPod env name used by this project
|
| 18 |
+
"HF_REPO",
|
| 19 |
+
"HF_REPO_ID",
|
| 20 |
+
"HUGGINGFACE_REPO",
|
| 21 |
+
"HUGGINGFACE_REPO_ID",
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def first_env(names: Iterable[str]) -> Optional[str]:
    """Return the first non-empty environment variable value from names."""
    for candidate in names:
        raw = os.environ.get(candidate)
        if raw is None:
            continue
        stripped = raw.strip()
        if stripped:
            return stripped
    return None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_hf_token() -> Optional[str]:
    """Read a Hugging Face token from common env names.

    RunPod users can set `hf_key=hf_...`. This helper maps that to the token
    argument used by `transformers` and `huggingface_hub` without printing it.
    """
    token = first_env(HF_TOKEN_ENV_NAMES)
    return token
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def normalize_repo_id(repo_id_or_url: str) -> str:
    """Accept `shiowo/DINO-Protomorph` or full HF URLs and return a repo_id.

    Handles http(s)://huggingface.co/ prefixes, a leading `models/` path
    segment, and /tree/, /blob/, /resolve/ suffixes that browser-copied URLs
    often carry (e.g. .../resolve/main/model.safetensors).
    """
    value = repo_id_or_url.strip()
    for prefix in (
        "https://huggingface.co/",
        "http://huggingface.co/",
        "huggingface.co/",
    ):
        if value.startswith(prefix):
            value = value[len(prefix):]
            break
    value = value.strip("/")
    if value.startswith("models/"):
        value = value[len("models/"):]
    # Strip any trailing revision/file path copied from the web UI.
    for marker in ("/tree/", "/blob/", "/resolve/"):
        if marker in value:
            value = value.split(marker, 1)[0]
    return value
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_hf_repo_id(default: Optional[str] = None) -> Optional[str]:
    """Resolve the target repo id from env vars, falling back to *default*.

    Returns None when neither the environment nor *default* provides a value.
    """
    raw = first_env(HF_REPO_ENV_NAMES)
    if raw is None:
        raw = default
    if not raw:
        return None
    return normalize_repo_id(raw)
|
src/protomorph/inference.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Dict, List, Optional, Sequence
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from .config import ProtoMorphConfig
|
| 11 |
+
from .model import ProtoMorphDINOv3
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_image(path: str | Path) -> Image.Image:
    """Open *path* with PIL and return it converted to RGB."""
    img = Image.open(path)
    return img.convert("RGB")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_labels(path: Optional[str | Path], num_classes: int) -> List[str]:
|
| 19 |
+
if path is None:
|
| 20 |
+
return [f"class_{i}" for i in range(num_classes)]
|
| 21 |
+
p = Path(path)
|
| 22 |
+
if p.suffix.lower() == ".json":
|
| 23 |
+
data = json.loads(p.read_text())
|
| 24 |
+
if isinstance(data, dict):
|
| 25 |
+
return [data.get(str(i), data.get(i, f"class_{i}")) for i in range(num_classes)]
|
| 26 |
+
return list(data)
|
| 27 |
+
labels = [line.strip() for line in p.read_text().splitlines() if line.strip()]
|
| 28 |
+
if len(labels) < num_classes:
|
| 29 |
+
labels += [f"class_{i}" for i in range(len(labels), num_classes)]
|
| 30 |
+
return labels[:num_classes]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def build_model(
    config_path: str | Path,
    checkpoint_path: Optional[str | Path],
    device: str = "cuda",
    local_files_only: bool = False,
    allow_random_head: bool = False,
) -> ProtoMorphDINOv3:
    """Build the frozen-backbone model and load the trained head checkpoint.

    Falls back to CPU when the requested device is not "cpu" and CUDA is
    unavailable. Raises FileNotFoundError when the head checkpoint is missing,
    unless *allow_random_head* is set (smoke tests only).
    """
    cfg = ProtoMorphConfig.from_json(config_path)
    if device == "cpu" or torch.cuda.is_available():
        target = torch.device(device)
    else:
        target = torch.device("cpu")
    model = ProtoMorphDINOv3(cfg, local_files_only=local_files_only).to(target).eval()

    if checkpoint_path is not None and Path(checkpoint_path).exists():
        model.load_custom_head(checkpoint_path)
        return model
    if allow_random_head:
        return model
    raise FileNotFoundError(
        f"Missing custom-head checkpoint: {checkpoint_path}. "
        "Pass --allow-random-head only for smoke tests; random logits are not meaningful."
    )
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@torch.no_grad()
def predict_paths(
    model: ProtoMorphDINOv3,
    image_paths: Sequence[str | Path],
    labels: List[str],
    topk: int = 5,
    device: str = "cuda",
    force_hard: bool = False,
) -> List[Dict]:
    """Run inference on image files and return one JSON-friendly dict per image.

    Each record carries the fused and main-head top-k predictions, the
    hard-case flag, the gate's confidence statistics, and the patch/pixel
    geometry reported by the backbone.
    """
    batch = [load_image(p) for p in image_paths]
    out = model(batch, device=device, force_hard=force_hard)

    fused_probs = out["logits"].softmax(dim=-1).float().cpu()
    main_probs = out["main_logits"].softmax(dim=-1).float().cpu()
    hard_flags = out["hard_mask"].cpu().tolist()
    pmax_vals = out["gate_pmax"].float().cpu().tolist()
    margin_vals = out["gate_margin"].float().cpu().tolist()
    entropy_vals = out["gate_entropy"].float().cpu().tolist()

    def _ranked(prob_row, k: int) -> List[Dict]:
        # One {rank, class_id, label, prob} entry per top-k class.
        vals, idxs = prob_row.topk(k)
        return [
            {"rank": r + 1, "class_id": int(ci), "label": labels[int(ci)], "prob": float(pv)}
            for r, (ci, pv) in enumerate(zip(idxs.tolist(), vals.tolist()))
        ]

    records: List[Dict] = []
    for i, path in enumerate(image_paths):
        k = min(topk, fused_probs.shape[-1])
        records.append(
            {
                "image": str(path),
                "hard_case": bool(hard_flags[i]),
                "gate": {
                    "pmax": float(pmax_vals[i]),
                    "margin": float(margin_vals[i]),
                    "entropy": float(entropy_vals[i]),
                },
                "topk": _ranked(fused_probs[i], k),
                "main_topk": _ranked(main_probs[i], k),
                "patch_hw": out["patch_hw"],
                "pixel_hw": out["pixel_hw"],
            }
        )
    return records
|
src/protomorph/model.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor, nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from safetensors.torch import load_file as safe_load_file
|
| 14 |
+
from safetensors.torch import save_file as safe_save_file
|
| 15 |
+
except Exception: # pragma: no cover - handled at runtime with better error.
|
| 16 |
+
safe_load_file = None
|
| 17 |
+
safe_save_file = None
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from transformers import AutoImageProcessor, AutoModel
|
| 21 |
+
except Exception: # pragma: no cover - handled at runtime with better error.
|
| 22 |
+
AutoImageProcessor = None
|
| 23 |
+
AutoModel = None
|
| 24 |
+
|
| 25 |
+
from .config import ProtoMorphConfig
|
| 26 |
+
from .hf_utils import get_hf_token
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class DinoFeatures:
    """Token bundle produced by one frozen DINOv3 forward pass."""

    # CLS token (tokens[:, 0] of the backbone output).
    cls: Tensor
    # Register tokens, or None when the checkpoint has no register tokens.
    registers: Optional[Tensor]
    # Remaining per-patch tokens.
    patches: Tensor
    # Patch-grid size as (rows, cols); may be a fallback shape when the
    # token count does not match the pixel grid (see FrozenDINOv3.forward).
    patch_hw: Tuple[int, int]
    # Preprocessed input resolution in pixels as (height, width).
    pixel_hw: Tuple[int, int]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class FrozenDINOv3(nn.Module):
    """Hugging Face DINOv3 wrapper that returns CLS/register/patch tokens.

    DINOv3 is kept frozen. Use torch.autocast during forward for memory savings
    on RTX 3090; the custom head remains regular PyTorch modules.
    """

    def __init__(self, model_name: str, image_size: int = 512, local_files_only: bool = False):
        """Load the HF processor and model; freeze all backbone weights.

        Raises ImportError when `transformers` is unavailable (the optional
        import at module top sets AutoImageProcessor/AutoModel to None).
        """
        super().__init__()
        if AutoImageProcessor is None or AutoModel is None:
            raise ImportError(
                "transformers is required. Install transformers>=4.56.0 before loading DINOv3."
            )
        self.model_name = model_name
        self.image_size = image_size
        hf_token = get_hf_token()
        hf_kwargs = {"local_files_only": local_files_only}
        if hf_token:
            # Supports RunPod env variable `hf_key` as well as standard HF_TOKEN.
            hf_kwargs["token"] = hf_token
        self.processor = AutoImageProcessor.from_pretrained(model_name, **hf_kwargs)
        self.model = AutoModel.from_pretrained(model_name, **hf_kwargs)
        # Frozen backbone: eval mode, no gradients anywhere.
        self.model.eval().requires_grad_(False)

        # Cached config values; defaults are fallbacks for configs that omit
        # the attribute (patch_size=16 matches common DINOv3 ViT checkpoints).
        config = self.model.config
        self.patch_size = int(getattr(config, "patch_size", 16))
        self.hidden_size = int(getattr(config, "hidden_size", 0))
        self.num_register_tokens = int(getattr(config, "num_register_tokens", 0))

    def _prepare_images(self, images: Image.Image | Sequence[Image.Image]) -> Dict[str, Tensor]:
        """Preprocess one image or a sequence into a pixel_values batch."""
        if isinstance(images, Image.Image):
            images = [images]
        # HF processors support overriding target size at call time for ViT-like image processors.
        # We request a square size that is divisible by patch_size for clean patch grids.
        size = {"height": self.image_size, "width": self.image_size}
        return self.processor(images=list(images), return_tensors="pt", size=size)

    @torch.no_grad()
    def forward(self, images: Image.Image | Sequence[Image.Image], device: torch.device | str) -> DinoFeatures:
        """Encode images and split the token sequence into CLS/registers/patches.

        Token layout assumed: [CLS, registers..., patches...] — TODO confirm
        this holds for every DINOv3 checkpoint variant used with this wrapper.
        """
        inputs = self._prepare_images(images)
        pixel_values = inputs["pixel_values"].to(device, non_blocking=True)
        outputs = self.model(pixel_values=pixel_values)

        tokens = outputs.last_hidden_state
        cls = tokens[:, 0]
        reg_start = 1
        reg_end = 1 + self.num_register_tokens
        registers = tokens[:, reg_start:reg_end] if self.num_register_tokens > 0 else None
        patches = tokens[:, reg_end:]

        # Derive the patch grid from the actual pixel resolution.
        h, w = pixel_values.shape[-2:]
        ph, pw = h // self.patch_size, w // self.patch_size
        expected = ph * pw
        if patches.shape[1] != expected:
            # Fallback for processors/checkpoints that return a non-square crop or resize.
            # This keeps inference running and makes the mismatch visible to the caller.
            side = int(patches.shape[1] ** 0.5)
            if side * side == patches.shape[1]:
                ph, pw = side, side
            else:
                ph, pw = patches.shape[1], 1
        return DinoFeatures(cls=cls, registers=registers, patches=patches, patch_hw=(ph, pw), pixel_hw=(h, w))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear with dropout on both halves.

    Submodule layout (`self.net`) is kept stable so existing checkpoints load.
    """

    def __init__(self, dim: int, expansion: int = 4, dropout: float = 0.0):
        super().__init__()
        inner = dim * expansion
        self.net = nn.Sequential(
            nn.Linear(dim, inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the MLP; output shape matches the input's."""
        return self.net(x)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class ProtoMorphBlock(nn.Module):
    """Prototype-morphing residual block over DINO patch tokens.

    It computes soft assignment of each patch token to learnable prototypes, then
    mixes original token, nearest prototype context, difference, and product.
    This creates a lightweight nonstandard CNN replacement over patch embeddings.
    """

    def __init__(self, dim: int, proto_count: int, dropout: float = 0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Learnable prototype bank [proto_count, dim]; small init (0.02 std).
        self.prototypes = nn.Parameter(torch.randn(proto_count, dim) * 0.02)
        # softplus(log_temperature) keeps the assignment temperature positive.
        self.log_temperature = nn.Parameter(torch.tensor(0.0))
        self.mix = nn.Sequential(
            nn.Linear(dim * 4, dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 2, dim),
        )
        # tanh-bounded residual gain, initialised small.
        self.gamma = nn.Parameter(torch.tensor(0.1))
        self.out_norm = nn.LayerNorm(dim)

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        """Return (updated tokens, soft prototype assignments)."""
        zn = self.norm(z)
        # Unit-normalise both sides so the matmul below is cosine similarity.
        p = F.normalize(self.prototypes, dim=-1)
        q = F.normalize(zn, dim=-1)
        # cosine distance in [0, 2]
        dist = 1.0 - torch.matmul(q, p.t())
        temp = F.softplus(self.log_temperature) + 1e-4
        assign = F.softmax(-dist / temp, dim=-1)
        # Expected prototype under the soft assignment (un-normalised bank).
        context = torch.matmul(assign, self.prototypes)
        # Mix token, prototype context, and their difference/product features.
        mixed = self.mix(torch.cat([zn, context, zn - context, zn * context], dim=-1))
        z = z + self.gamma.tanh() * mixed
        return self.out_norm(z), assign
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class LayerMemoryAttention(nn.Module):
    """A small learned memory bank attended by every patch token."""

    def __init__(self, dim: int, memory_tokens: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        # Learned memory [memory_tokens, dim]; shared across the batch.
        self.memory = nn.Parameter(torch.randn(memory_tokens, dim) * 0.02)
        self.norm_q = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.ffn = FeedForward(dim, expansion=4, dropout=dropout)
        # tanh-bounded residual gains for the attention and FFN branches.
        self.gamma_attn = nn.Parameter(torch.tensor(0.1))
        self.gamma_ffn = nn.Parameter(torch.tensor(0.1))

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        """Return (updated tokens, attention weights over the memory bank)."""
        b = z.shape[0]
        # Broadcast the shared memory across the batch dimension.
        mem = self.memory.unsqueeze(0).expand(b, -1, -1)
        q = self.norm_q(z)
        # Cross-attention: patch tokens query the memory (keys and values).
        attn_out, attn_weights = self.attn(q, mem, mem, need_weights=True)
        z = z + self.gamma_attn.tanh() * attn_out
        z = z + self.gamma_ffn.tanh() * self.ffn(self.norm_out(z))
        return z, attn_weights
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class MainClassifier(nn.Module):
    """Primary classifier over the CLS token plus mean/max-pooled patch tokens.

    Submodule names (`norm`, `head`) are kept stable so checkpoints load.
    """

    def __init__(self, dim: int, num_classes: int, dropout: float = 0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim * 3)
        self.head = nn.Sequential(
            nn.Linear(dim * 3, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, num_classes),
        )

    def forward(self, cls: Tensor, z: Tensor) -> Tensor:
        """Pool patch tokens, concatenate with CLS, and return class logits."""
        pooled_avg = z.mean(dim=1)
        pooled_max = z.max(dim=1).values
        combined = torch.cat([cls, pooled_avg, pooled_max], dim=-1)
        return self.head(self.norm(combined))
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class Top2FeedbackModulator(nn.Module):
    """Turns top-2 class probabilities into scale/shift over patch tokens."""

    def __init__(self, dim: int, num_classes: int):
        super().__init__()
        self.class_embed = nn.Embedding(num_classes, dim)
        # Maps the 4 confidence stats (p1, p2, margin, entropy) to a vector.
        self.stats_mlp = nn.Sequential(
            nn.Linear(4, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        # Produces concatenated [scale, shift] from class + stat vectors.
        self.to_scale_shift = nn.Sequential(
            nn.LayerNorm(dim * 2),
            nn.Linear(dim * 2, dim * 2),
        )

    def forward(self, z0: Tensor, logits: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Return (modulated tokens, confidence stats) from main-head logits."""
        probs = logits.softmax(dim=-1)
        top_probs, top_idx = probs.topk(k=min(2, probs.shape[-1]), dim=-1)
        if top_probs.shape[-1] == 1:
            # Single-class edge case: pad a zero-probability duplicate so the
            # top-2 slicing below is always valid.
            top_probs = torch.cat([top_probs, torch.zeros_like(top_probs)], dim=-1)
            top_idx = torch.cat([top_idx, top_idx], dim=-1)

        p1 = top_probs[:, 0]
        p2 = top_probs[:, 1]
        margin = p1 - p2
        # clamp_min avoids log(0) for zero-probability classes.
        entropy = -(probs * (probs.clamp_min(1e-8)).log()).sum(dim=-1)
        class_vecs = self.class_embed(top_idx)  # [B, 2, C]
        # Probability-weighted mixture of the two candidate class embeddings.
        weighted_class_vec = (class_vecs * top_probs.unsqueeze(-1)).sum(dim=1)
        stats = torch.stack([p1, p2, margin, entropy], dim=-1)
        stat_vec = self.stats_mlp(stats)
        scale_shift = self.to_scale_shift(torch.cat([weighted_class_vec, stat_vec], dim=-1))
        scale, shift = scale_shift.chunk(2, dim=-1)
        # FiLM-style modulation; tanh and the 0.25 factor bound the change.
        z_mod = z0 * (1.0 + 0.25 * torch.tanh(scale).unsqueeze(1)) + 0.25 * torch.tanh(shift).unsqueeze(1)
        return z_mod, {
            "p1": p1,
            "p2": p2,
            "margin": margin,
            "entropy": entropy,
            "top_idx": top_idx,
            "top_probs": top_probs,
        }
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class DeltaRBFHardExpert(nn.Module):
    """RBF expert for hard examples, driven by feedback-modulated patch deltas."""

    def __init__(self, dim: int, rbf_count: int, num_classes: int, dropout: float = 0.0):
        super().__init__()
        self.delta_norm = nn.LayerNorm(dim)
        # Learnable RBF centers [rbf_count, dim].
        self.rbf_centers = nn.Parameter(torch.randn(rbf_count, dim) * 0.02)
        # Per-center bandwidth; softplus in forward keeps it positive.
        self.log_sigma = nn.Parameter(torch.zeros(rbf_count))
        self.rbf_to_logits = nn.Linear(rbf_count, num_classes)
        # Parallel MLP branch over pooled delta statistics.
        self.delta_mlp = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, num_classes),
        )

    def forward(self, z_base: Tensor, z_mod: Tensor) -> Tuple[Tensor, Tensor]:
        """Return (expert logits, pooled RBF activations) from the token delta."""
        delta = self.delta_norm(z_mod - z_base)
        delta_mean = delta.mean(dim=1)
        delta_max = delta.max(dim=1).values

        # Cosine distance between normalised deltas and normalised centers.
        q = F.normalize(delta, dim=-1)
        c = F.normalize(self.rbf_centers, dim=-1)
        dist = 1.0 - torch.matmul(q, c.t())  # [B, N, R]
        sigma = F.softplus(self.log_sigma).view(1, 1, -1) + 1e-4
        # Mean over patch tokens pools the per-token RBF responses.
        rbf = torch.exp(-dist / sigma).mean(dim=1)  # [B, R]
        expert_logits = self.rbf_to_logits(rbf) + self.delta_mlp(torch.cat([delta_mean, delta_max], dim=-1))
        return expert_logits, rbf
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class LogitFusion(nn.Module):
    """Blend main and expert logits with a learned gate plus a calibration residual.

    Submodule names (`alpha`, `calibrate`) are kept stable so checkpoints load.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(0.35))
        self.calibrate = nn.Sequential(
            nn.LayerNorm(num_classes * 2),
            nn.Linear(num_classes * 2, num_classes),
        )

    def forward(self, main_logits: Tensor, expert_logits: Tensor) -> Tensor:
        """Return main + sigmoid(alpha)*expert + 0.1*calibration(main||expert)."""
        stacked = torch.cat([main_logits, expert_logits], dim=-1)
        correction = self.calibrate(stacked)
        blend_weight = self.alpha.sigmoid()
        return main_logits + blend_weight * expert_logits + 0.1 * correction
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class HardCaseGate(nn.Module):
    """Deterministic inference gate from probability confidence signals.

    A sample is "hard" when its top probability is low, its top-2 margin is
    small, or its predictive entropy is high. No learned parameters.
    """

    def __init__(self, pmax_threshold: float, margin_threshold: float, entropy_threshold: float):
        super().__init__()
        self.pmax_threshold = pmax_threshold
        self.margin_threshold = margin_threshold
        self.entropy_threshold = entropy_threshold

    def forward(self, logits: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Return (boolean hard mask, {pmax, margin, entropy} stats)."""
        probs = logits.softmax(dim=-1)
        k = min(2, probs.shape[-1])
        ranked = probs.topk(k=k, dim=-1).values
        p1 = ranked[:, 0]
        # Single-class edge case: treat the runner-up probability as zero.
        p2 = ranked[:, 1] if k > 1 else torch.zeros_like(p1)
        margin = p1 - p2
        # clamp_min avoids log(0) for zero-probability classes.
        entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=-1)
        low_confidence = p1 < self.pmax_threshold
        close_call = margin < self.margin_threshold
        diffuse = entropy > self.entropy_threshold
        hard = low_confidence | close_call | diffuse
        return hard, {"pmax": p1, "margin": margin, "entropy": entropy}
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class ProtoMorphHead(nn.Module):
    """Full custom head over frozen DINO features.

    Pipeline: two ProtoMorph + memory-attention stages -> main classifier ->
    confidence gate -> top-2 feedback modulation -> RBF hard expert -> fusion.
    The expert branch always runs; the gate only selects which logits each
    sample reports as final.
    """

    def __init__(self, cfg: ProtoMorphConfig):
        super().__init__()
        self.cfg = cfg
        d = cfg.embed_dim
        self.input_norm = nn.LayerNorm(d)
        self.block1 = ProtoMorphBlock(d, cfg.proto_count, cfg.dropout)
        self.mem1 = LayerMemoryAttention(d, cfg.memory_tokens, cfg.num_heads, cfg.dropout)
        self.block2 = ProtoMorphBlock(d, cfg.proto_count, cfg.dropout)
        self.mem2 = LayerMemoryAttention(d, cfg.memory_tokens, cfg.num_heads, cfg.dropout)
        self.main = MainClassifier(d, cfg.num_classes, cfg.dropout)
        self.gate = HardCaseGate(cfg.hard_pmax_threshold, cfg.hard_margin_threshold, cfg.hard_entropy_threshold)
        self.feedback = Top2FeedbackModulator(d, cfg.num_classes)
        self.hard_expert = DeltaRBFHardExpert(d, cfg.rbf_count, cfg.num_classes, cfg.dropout)
        self.fusion = LogitFusion(cfg.num_classes)

    def forward(self, cls: Tensor, patches: Tensor, force_hard: bool = False) -> Dict[str, Tensor]:
        """Classify from CLS + patch tokens.

        Returns a dict with final/main/expert logits, the hard-case mask,
        RBF activations, mean prototype assignments and memory-attention maps,
        plus `gate_*` and `fb_*` diagnostic stats.
        """
        z0 = self.input_norm(patches)
        z, assign1 = self.block1(z0)
        z, mem_attn1 = self.mem1(z)
        z, assign2 = self.block2(z)
        z, mem_attn2 = self.mem2(z)

        main_logits = self.main(cls, z)
        hard_mask, gate_stats = self.gate(main_logits)
        if force_hard:
            # Override: route every sample through the expert path.
            hard_mask = torch.ones_like(hard_mask, dtype=torch.bool)

        # Feedback + expert run unconditionally; the mask only picks outputs.
        z_mod, fb_stats = self.feedback(z0, main_logits)
        expert_logits, rbf = self.hard_expert(z0, z_mod)
        fused_logits = self.fusion(main_logits, expert_logits)
        # Per-sample selection: fused logits for hard cases, main otherwise.
        final_logits = torch.where(hard_mask[:, None], fused_logits, main_logits)

        out = {
            "logits": final_logits,
            "main_logits": main_logits,
            "expert_logits": expert_logits,
            "hard_mask": hard_mask,
            "rbf": rbf,
            "assign1_mean": assign1.mean(dim=1),
            "assign2_mean": assign2.mean(dim=1),
            "mem_attn1_mean": mem_attn1.mean(dim=1),
            "mem_attn2_mean": mem_attn2.mean(dim=1),
        }
        out.update({f"gate_{k}": v for k, v in gate_stats.items()})
        # fb_stats may contain non-tensor entries; only tensors are exported.
        out.update({f"fb_{k}": v for k, v in fb_stats.items() if isinstance(v, Tensor)})
        return out
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class ProtoMorphDINOv3(nn.Module):
    """Full inference graph: frozen DINOv3 + custom ProtoMorph head."""

    def __init__(self, cfg: ProtoMorphConfig, local_files_only: bool = False):
        """Build the backbone and head; raises ValueError on embed_dim mismatch."""
        super().__init__()
        self.cfg = cfg
        self.backbone = FrozenDINOv3(cfg.dino_model_name, image_size=cfg.image_size, local_files_only=local_files_only)
        actual_dim = self.backbone.hidden_size
        # hidden_size may be 0 when the HF config omits it; only validate when known.
        if actual_dim and actual_dim != cfg.embed_dim:
            raise ValueError(
                f"Config embed_dim={cfg.embed_dim} but DINO hidden_size={actual_dim}. "
                f"Use the matching config or run scripts/create_random_head.py with --embed-dim {actual_dim}."
            )
        self.head = ProtoMorphHead(cfg)

    @torch.no_grad()
    def forward(
        self,
        images: Image.Image | Sequence[Image.Image],
        device: torch.device | str,
        force_hard: bool = False,
        use_bf16_autocast: Optional[bool] = None,
    ) -> Dict[str, Tensor | Tuple[int, int]]:
        """Run the full pipeline on PIL images.

        `use_bf16_autocast=None` defers to the config flag; autocast is only
        enabled on CUDA devices. Returns the head's output dict plus the
        backbone's `patch_hw`/`pixel_hw` geometry.
        """
        use_amp = self.cfg.use_bf16_autocast if use_bf16_autocast is None else use_bf16_autocast
        device_obj = torch.device(device)
        amp_enabled = bool(use_amp and device_obj.type == "cuda")
        amp_dtype = torch.bfloat16

        # enabled=False makes this a no-op on CPU despite device_type="cuda".
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp_enabled):
            feats = self.backbone(images, device=device_obj)
            cls = feats.cls
            patches = feats.patches
            if self.cfg.normalize_patch_tokens:
                # Parameter-free LayerNorm over the channel dimension.
                cls = F.layer_norm(cls, cls.shape[-1:])
                patches = F.layer_norm(patches, patches.shape[-1:])
            head_out = self.head(cls, patches, force_hard=force_hard)
        head_out["patch_hw"] = feats.patch_hw
        head_out["pixel_hw"] = feats.pixel_hw
        return head_out

    def save_custom_head(self, checkpoint_path: str | Path) -> None:
        """Write only the head's state_dict to a safetensors file."""
        if safe_save_file is None:
            raise ImportError("safetensors is required: pip install safetensors")
        p = Path(checkpoint_path)
        p.parent.mkdir(parents=True, exist_ok=True)
        safe_save_file(self.head.state_dict(), str(p))

    def load_custom_head(self, checkpoint_path: str | Path, strict: bool = True) -> None:
        """Load head weights from a safetensors checkpoint (backbone untouched)."""
        if safe_load_file is None:
            raise ImportError("safetensors is required: pip install safetensors")
        sd = safe_load_file(str(checkpoint_path), device="cpu")
        self.head.load_state_dict(sd, strict=strict)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def infer_embed_dim_from_model_name(model_name: str) -> int:
    """Useful defaults for DINOv3 ViT checkpoints.

    Matches size markers in precedence order (vits, vitb, vitl, vith) and
    falls back to the ViT-S width when nothing matches.
    """
    lowered = model_name.lower()
    for marker, dim in (("vits", 384), ("vitb", 768), ("vitl", 1024), ("vith", 1280)):
        if marker in lowered:
            return dim
    return 384
|