""" 🛰️ core-dino | Training Script for Resolution-Agnostic SSL on Satellite Imagery Trains DINO with a YOLO backbone using multi-resolution Core-Five patches. 👨‍💻 Author: Gajesh Ladhar 🔗 LinkedIn: https://www.linkedin.com/in/gajeshladhar/ 🤗 Hugging Face: https://huggingface.co/gajeshladhar """ # 📦 Imports import torch from torch.utils.data import DataLoader from loss import DinoSpatialLoss from backbone import YOLOBackBone from data import DinoDataset from utils import * # ⚙️ Config CFG = { "imgsz": 1696, "batch_size": 4, "epochs": 100, "device": "cuda" if torch.cuda.is_available() else "cpu", "lr": 1e-4, "queue_size": 1000, "ckpt_path": "yolo11x.pt", "save_path" : "dino-yolo.pt", ## core-DINO logic parameters... "teacher_temperature":0.04, "student_temperature":0.1, "teacher_ema" : 0.998, } # 🔄 Sync Student → Teacher Weights @torch.no_grad() def initialize_teacher(student, teacher): for ps, pt in zip(student.parameters(), teacher.parameters()): pt.data.copy_(ps.data) @torch.no_grad() def update_teacher(student, teacher, m=0.996): for ps, pt in zip(student.parameters(), teacher.parameters()): pt.data.mul_(m).add_(ps.data, alpha=1 - m) # 🧠 Model + Loss + Optimizer def setup_model_and_loss(): student = YOLOBackBone(model_path=CFG["ckpt_path"]).to(CFG["device"]) teacher = YOLOBackBone(model_path=CFG["ckpt_path"]).to(CFG["device"]) for p in teacher.parameters(): p.requires_grad = False loss_fn = DinoSpatialLoss(teacher_temp=CFG["teacher_temperature"],student_temp=CFG["student_temperature"]).to(CFG["device"]) optimizer = torch.optim.AdamW(student.parameters(), lr=CFG["lr"], weight_decay=0.05) return student, teacher, loss_fn, optimizer # 🔁 Training Loop def train(): student, teacher, criterion, optimizer = setup_model_and_loss() dataset = DinoDataset(imgsz=CFG["imgsz"], batch_size=CFG["batch_size"], queue_size=CFG["queue_size"]) num_epochs = CFG["epochs"] device = CFG["device"] for epoch in range(num_epochs): running_loss = 0.0 running_entropy = 0.0 total_count = 0 loop = tqdm(dataset.store, desc=f"📅 Epoch {epoch+1}/{num_epochs}") for batch in loop: images_s = torch.nan_to_num(batch['student'].float() / 255.0, nan=0.0).to(device) images_t = torch.nan_to_num(batch['teacher'].float() / 255.0, nan=0.0).to(device) with torch.no_grad(): teacher_out = teacher(images_t).detach() with autocast(device_type='cuda', enabled=False): student_out = student(images_s) loss = criterion(student_out, teacher_out) optimizer.zero_grad() loss.backward() optimizer.step() update_teacher(student, teacher, m=CFG["teacher_ema"]) running_loss += loss.item() total_count += 1 # 📊 Entropy Calc probs = F.softmax(teacher_out / CFG["teacher_temperature"], dim=1) eps = 1e-6 entropy = -(probs * (probs + eps).log()).sum(dim=1).mean() running_entropy += entropy.item() # 🔄 Live Bar Update loop.set_postfix({ "💥 Loss": f"{loss.item():.4f}", "📈 Entropy": f"{entropy.item():.4f}" }) avg_loss = running_loss / total_count avg_entropy = running_entropy / total_count print(f"✅ Epoch {epoch+1:03} | 🧠 Avg Loss: {avg_loss:.4f} | 🔐 Teacher Entropy: {avg_entropy:.4f} | 💾 Saved → {CFG['save_path']}") torch.save(student.state_dict(), CFG["save_path"]) if __name__=="__main__": train()