| """ |
| self_train.py |
| ============= |
| Autonomous continual-learning pipeline for MyoSeg. |
| Place at the ROOT of your Hugging Face Space repo (same level as Dockerfile). |
| |
| IMPORTANT: This file is completely self-contained. |
| It does NOT import from train_myotube_nuclei_unet.py. |
| The train script is a separate PyCharm tool. |
| |
| Trigger conditions (any one fires a retrain): |
| 1. User submitted corrected label pairs via the app → corrections/ folder |
| 2. N unlabelled images accumulated in queue → retrain_queue/ |
| 3. K consecutive low-confidence images → retrain_queue/ (reason=low_confidence) |
| 4. Nightly scheduled run → APScheduler cron 02:00 UTC |
| 5. User-optimized parameters submitted → corrections/ (reason=user_optimized_params) |
| These submissions include the image, postprocessed masks from the user's |
| tuned parameter set, and a full snapshot of the sidebar settings. The |
| self-training pipeline uses these as additional supervised training pairs |
| and can aggregate parameter statistics to learn optimal defaults. |
| |
| Privacy note: |
| Images processed in Private Mode (toggle in the Streamlit sidebar) are |
| NEVER queued for retraining. Only images explicitly submitted by the |
| user via "Submit for training" or "Submit corrections" are used. |
| |
| After each retrain: |
| • Fine-tunes from current HF Hub weights |
| • Validates on held-out 20% split |
| • Only pushes to Hub if new Dice > previous best |
| • Archives queue → runs/<run_id>/processed_queue/ |
| • Appends entry to manifest.json |
| • Aggregates user-submitted parameter snapshots → optimal_params.json |
| |
| Usage: |
| python self_train.py # check triggers once |
| python self_train.py --manual # force retrain now |
| python self_train.py --scheduler # blocking APScheduler loop (for Docker) |
| |
| Environment variables / HF Secrets: |
| HF_TOKEN write-access Hugging Face token |
| HF_REPO_ID model repo, e.g. "skarugu/myotube-unet" |
| HF_FILENAME model filename, e.g. "model_final.pt" |
| DATA_ROOT path to base training data/ folder |
| BATCH_TRIGGER_N images before batch trigger (default 20) |
| CONF_DROP_K consecutive low-conf before trigger (default 5) |
| FT_EPOCHS fine-tuning epochs per run (default 10) |
| FT_LR fine-tuning learning rate (default 5e-4) |
| SCHEDULE_HOUR nightly retrain UTC hour (default 2) |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import os |
| import random |
| import shutil |
| import tempfile |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Optional |
|
|
| import numpy as np |
| import scipy.ndimage as ndi |
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| from huggingface_hub import HfApi, hf_hub_download |
| from skimage import measure |
| from skimage.feature import peak_local_max |
| from skimage.morphology import disk, opening, remove_small_objects |
| from skimage.segmentation import watershed |
| from torch.utils.data import DataLoader, Dataset, random_split |
|
|
| try: |
| from apscheduler.schedulers.blocking import BlockingScheduler |
| HAS_SCHEDULER = True |
| except ImportError: |
| HAS_SCHEDULER = False |
|
|
| |
| |
| |
|
|
# Directory containing this script; every runtime path below is relative to it.
ROOT = Path(__file__).parent

# --- Hugging Face Hub coordinates and base training data -------------------
HF_REPO_ID = os.getenv("HF_REPO_ID", "skarugu/myotube-unet")
HF_FILENAME = os.getenv("HF_FILENAME", "model_final.pt")
HF_TOKEN = os.getenv("HF_TOKEN")  # None when unset; Hub pushes are skipped then
DATA_ROOT = os.getenv("DATA_ROOT", str(ROOT / "data"))

# --- Retrain triggers -------------------------------------------------------
BATCH_TRIGGER_N = int(os.getenv("BATCH_TRIGGER_N", "20"))
CONF_DROP_K = int(os.getenv("CONF_DROP_K", "5"))
# NOTE(review): CONF_FLAG_THR is not read anywhere in the visible part of this
# file; presumably the app uses it when flagging low-confidence images.
CONF_FLAG_THR = float(os.getenv("CONF_FLAG_THR", "0.60"))
SCHEDULE_HOUR = int(os.getenv("SCHEDULE_HOUR", "2"))

# --- Fine-tuning hyper-parameters ------------------------------------------
FT_EPOCHS = int(os.getenv("FT_EPOCHS", "10"))
FT_LR = float(os.getenv("FT_LR", "5e-4"))
FT_BATCH_SIZE = int(os.getenv("FT_BATCH_SIZE", "4"))
IMAGE_SIZE = int(os.getenv("IMAGE_SIZE", "512"))

# --- On-disk locations used by the pipeline --------------------------------
QUEUE_DIR = ROOT / "retrain_queue"
CORRECTIONS_DIR = ROOT / "corrections"
RUNS_DIR = ROOT / "runs"
STATE_PATH = ROOT / "self_train_state.json"
MANIFEST_PATH = ROOT / "manifest.json"

# Root logging is configured at import time; `log` is the module's logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger("self_train")
|
|
|
|
| |
| |
| |
|
|
def _load_state() -> dict:
    """Return the persisted pipeline state, or fresh defaults if none exists.

    The state is the JSON document stored at ``STATE_PATH``. When that file is
    absent, a default dict with ``best_dice``, ``last_retrain_ts`` and
    ``current_hf_sha`` is returned instead.
    """
    if not STATE_PATH.exists():
        return {"best_dice": 0.0, "last_retrain_ts": None, "current_hf_sha": None}
    raw = STATE_PATH.read_text()
    return json.loads(raw)
|
|
def _save_state(s: dict):
    """Write the state dict *s* to ``STATE_PATH`` as 2-space-indented JSON."""
    payload = json.dumps(s, indent=2)
    STATE_PATH.write_text(payload)
|
|
def _load_manifest() -> list:
    """Return the run manifest parsed from ``MANIFEST_PATH``, or ``[]`` if absent."""
    if not MANIFEST_PATH.exists():
        return []
    return json.loads(MANIFEST_PATH.read_text())
|
|
def _save_manifest(m: list):
    """Write the manifest list *m* to ``MANIFEST_PATH`` as indented JSON.

    Values that the JSON encoder cannot serialise are stored as ``str(value)``.
    """
    serialised = json.dumps(m, indent=2, default=str)
    MANIFEST_PATH.write_text(serialised)
|
|
|
|
def _aggregate_user_params(corrections_dir: Path, run_dir: Path):
    """
    Aggregate the parameter snapshots of ``user_optimized_params`` submissions.

    Scans ``corrections_dir/*/meta.json`` for entries whose ``reason`` is
    ``"user_optimized_params"`` and whose ``parameters`` value is a dict, then
    writes ``optimal_params.json`` into ``run_dir`` containing the number of
    submissions, the per-key median of every numeric parameter, and the raw
    submissions themselves.

    Keys are collected from *all* submissions (not only the first one), so a
    parameter that appears only in later submissions is still aggregated.
    Non-numeric values are ignored for the median. Metadata files that cannot
    be read or parsed are skipped with a warning.

    Returns None. Nothing is written when the folder is missing or holds no
    matching submission.
    """
    # Function-scope import: `statistics` is not in the file-level import list.
    import statistics

    if not corrections_dir.exists():
        return

    all_params = []
    for meta_p in corrections_dir.glob("*/meta.json"):
        try:
            meta = json.loads(meta_p.read_text())
        except (OSError, ValueError) as exc:
            # JSONDecodeError and UnicodeDecodeError are both ValueError.
            log.warning("Skipping unreadable correction metadata %s: %s", meta_p, exc)
            continue
        if not isinstance(meta, dict):
            continue
        params = meta.get("parameters")
        if meta.get("reason") == "user_optimized_params" and isinstance(params, dict):
            all_params.append(params)

    if not all_params:
        return

    # Union of keys over every submission, in first-seen order.
    keys = list(dict.fromkeys(key for params in all_params for key in params))

    aggregated = {}
    for key in keys:
        vals = [p[key] for p in all_params if isinstance(p.get(key), (int, float))]
        if vals:
            aggregated[key] = statistics.median(vals)

    result = {
        "n_submissions": len(all_params),
        "aggregated_params": aggregated,
        "all_submissions": all_params,
    }
    run_dir.mkdir(parents=True, exist_ok=True)
    out = run_dir / "optimal_params.json"
    out.write_text(json.dumps(result, indent=2))
    log.info("Aggregated %d user param submissions → %s", len(all_params), out)
|
|
|
|
| |
| |
| |
|
|
def should_retrain(force=False):
    """
    Decide whether a retrain should start now.

    Checks, in order:
      1. ``force`` -> ``(True, "manual")``
      2. any ``CORRECTIONS_DIR/*/meta.json`` -> user correction trigger
      3. at least ``BATCH_TRIGGER_N`` ``*.json`` files in ``QUEUE_DIR``
      4. at least ``CONF_DROP_K`` queue entries whose ``reason`` is
         ``"low_confidence"``

    Note that check 4 counts every low-confidence entry currently in the
    queue; it does not verify that they were consecutive.

    Queue metadata that cannot be read or parsed is skipped with a warning, so
    one corrupt file cannot abort the periodic trigger check.

    Returns:
        ``(should_run, reason)``; ``reason`` is ``"none"`` when nothing fired.
    """
    if force:
        return True, "manual"

    corrections = list(CORRECTIONS_DIR.glob("*/meta.json")) if CORRECTIONS_DIR.exists() else []
    if corrections:
        return True, f"user_correction ({len(corrections)} pairs)"

    q_jsons = list(QUEUE_DIR.glob("*.json")) if QUEUE_DIR.exists() else []
    if len(q_jsons) >= BATCH_TRIGGER_N:
        return True, f"batch_trigger ({len(q_jsons)} images)"

    low_conf = 0
    for jf in q_jsons:
        try:
            entry = json.loads(jf.read_text())
        except (OSError, ValueError) as exc:
            log.warning("Skipping unreadable queue metadata %s: %s", jf, exc)
            continue
        if isinstance(entry, dict) and entry.get("reason") == "low_confidence":
            low_conf += 1
    if low_conf >= CONF_DROP_K:
        return True, f"confidence_drop ({low_conf} low-conf images)"

    return False, "none"
|
|
|
|
| |
| |
| |
|
|
class DoubleConv(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by BatchNorm and in-place ReLU.

    The layers live in ``self.net`` (an ``nn.Sequential``), so parameter names
    are ``net.0.*`` … ``net.4.*``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages to *x*."""
        return self.net(x)
|
|
|
|
class UNet(nn.Module):
    """Four-level U-Net built from ``DoubleConv`` stages.

    The encoder halves the resolution four times with 2x2 max-pooling and the
    decoder restores it with stride-2 transposed convolutions, concatenating
    the matching encoder output along the channel axis at each level. The
    final 1x1 convolution maps to ``out_ch`` channels of raw logits (no
    activation is applied here).

    Attribute names (``d1``…``d4``, ``bn``, ``u4``…``u1``, ``du4``…``du1``,
    ``out``) determine the state-dict keys and are therefore kept as-is.
    """

    def __init__(self, in_ch=2, out_ch=2, base=32):
        super().__init__()
        c1, c2, c3, c4, c5 = (base * mult for mult in (1, 2, 4, 8, 16))

        # Encoder
        self.d1 = DoubleConv(in_ch, c1)
        self.p1 = nn.MaxPool2d(2)
        self.d2 = DoubleConv(c1, c2)
        self.p2 = nn.MaxPool2d(2)
        self.d3 = DoubleConv(c2, c3)
        self.p3 = nn.MaxPool2d(2)
        self.d4 = DoubleConv(c3, c4)
        self.p4 = nn.MaxPool2d(2)

        # Bottleneck ("bn" is the bottleneck block, not a batch-norm layer)
        self.bn = DoubleConv(c4, c5)

        # Decoder: each DoubleConv sees upsampled + skip channels concatenated
        self.u4 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.du4 = DoubleConv(c5, c4)
        self.u3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.du3 = DoubleConv(c4, c3)
        self.u2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.du2 = DoubleConv(c3, c2)
        self.u1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.du1 = DoubleConv(c2, c1)

        self.out = nn.Conv2d(c1, out_ch, kernel_size=1)

    def forward(self, x):
        """Return per-pixel logits for *x*."""
        skip1 = self.d1(x)
        skip2 = self.d2(self.p1(skip1))
        skip3 = self.d3(self.p2(skip2))
        skip4 = self.d4(self.p3(skip3))

        h = self.bn(self.p4(skip4))

        h = self.du4(torch.cat([self.u4(h), skip4], dim=1))
        h = self.du3(torch.cat([self.u3(h), skip3], dim=1))
        h = self.du2(torch.cat([self.u2(h), skip2], dim=1))
        h = self.du1(torch.cat([self.u1(h), skip1], dim=1))
        return self.out(h)
|
|
|
|
| |
| |
| |
|
|
class _FTDataset(Dataset):
    """Fine-tuning dataset of (image, nuclei mask, myotube mask) triples.

    Expected layout under ``root``::

        images/<stem>.<jpg|jpeg|png|tif|tiff>
        masks/Nuclei_m/<stem>.<tif|tiff|png>
        masks/Myotubes_m/<stem>.<tif|tiff|png>

    Only images that have both masks are kept. Each item is
    ``(x, y, stem)``: ``x`` stacks the red and blue image channels scaled to
    [0, 1], ``y`` stacks the binarised nuclei and myotube masks; both are
    float32 tensors of shape ``(2, size, size)``.
    """

    IMG_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}

    def __init__(self, root, size=512, augment=True):
        """Index all labelled samples under *root*.

        Raises:
            FileNotFoundError: if no image has both masks.
        """
        root = Path(root)
        nuc_dir = root / "masks" / "Nuclei_m"
        myo_dir = root / "masks" / "Myotubes_m"

        candidates = sorted(
            p for p in (root / "images").glob("*") if p.suffix.lower() in self.IMG_EXTS
        )
        triples = (
            (p, self._mp(nuc_dir, p.stem), self._mp(myo_dir, p.stem)) for p in candidates
        )
        self.samples = [t for t in triples if t[1] and t[2]]

        if not self.samples:
            raise FileNotFoundError(f"No labelled samples found under {root}")

        self.size = size
        self.augment = augment

    @staticmethod
    def _mp(d, stem):
        """Return the first existing ``d/<stem>{.tif,.tiff,.png}``, else None."""
        candidates = (d / f"{stem}{ext}" for ext in (".tif", ".tiff", ".png"))
        return next((c for c in candidates if c.exists()), None)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, nuc_path, myo_path = self.samples[idx]
        side = self.size
        rgb = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8)

        def load_channel(plane):
            # Bilinear resize of one 8-bit channel, scaled to [0, 1].
            resized = Image.fromarray(plane, "L").resize((side, side), Image.BILINEAR)
            return np.array(resized, dtype=np.float32) / 255.0

        def load_mask(path):
            # Nearest-neighbour resize keeps the mask binary; any non-zero pixel is foreground.
            resized = Image.open(path).convert("L").resize((side, side), Image.NEAREST)
            return (np.array(resized) > 0).astype(np.uint8)

        planes = [
            load_channel(rgb[..., 0]),  # red
            load_channel(rgb[..., 2]),  # blue
            load_mask(nuc_path),
            load_mask(myo_path),
        ]

        if self.augment:
            # One coin toss per axis (columns first, then rows), then a random
            # multiple of 90 degrees; inputs and masks get the same transform.
            for ax in (1, 0):
                if random.random() < 0.5:
                    planes = [np.flip(a, ax) for a in planes]
            k = random.randint(0, 3)
            if k:
                planes = [np.rot90(a, k) for a in planes]

        red, blue, yn, ym = planes
        x = np.stack([red, blue], 0).astype(np.float32)
        y = np.stack([yn, ym], 0).astype(np.float32)
        # .copy() removes the negative strides left by flip/rot90 before from_numpy.
        return torch.from_numpy(x.copy()), torch.from_numpy(y.copy()), img_path.stem
|
|
|
|
| |
| |
| |
|
|
class _BCEDice(nn.Module):
    """Equal-weight sum of BCE-with-logits and soft Dice loss.

    The Dice term is computed per sample and per channel over the two spatial
    dimensions (dims 2 and 3) of the sigmoid probabilities, then averaged.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, target):
        """Return ``0.5 * BCE + 0.5 * mean(1 - soft Dice)`` as a scalar tensor."""
        eps = 1e-6  # keeps the ratio defined when both prediction and target are empty
        spatial = (2, 3)
        bce_term = self.bce(logits, target)
        probs = torch.sigmoid(logits)
        overlap = (probs * target).sum(dim=spatial)
        total = probs.sum(dim=spatial) + target.sum(dim=spatial)
        dice_loss = 1 - (2 * overlap + eps) / (total + eps)
        return 0.5 * bce_term + 0.5 * dice_loss.mean()
|
|
@torch.no_grad()
def _dice(probs, target, thr=0.5):
    """Return the per-channel hard Dice score, averaged over the batch.

    *probs* is binarised with ``probs > thr``; Dice is then computed over the
    two spatial dimensions (dims 2 and 3) for every sample and channel and
    averaged over dim 0, giving one value per channel. An empty prediction
    against an empty target scores 1 because of the epsilon terms.
    """
    eps = 1e-6
    spatial = (2, 3)
    mask = (probs > thr).float()
    overlap = (mask * target).sum(dim=spatial)
    total = mask.sum(dim=spatial) + target.sum(dim=spatial)
    per_sample = (2 * overlap + eps) / (total + eps)
    return per_sample.mean(dim=0)
|
|
|
|
| |
| |
| |
|
|
def _prepare_data(base: str) -> str:
    """
    Stage the fine-tuning dataset in a fresh temporary directory.

    The base dataset at *base* is copied when it has an ``images/`` folder;
    otherwise staging starts empty and a warning is logged. Every complete
    correction under ``CORRECTIONS_DIR`` (a folder holding ``meta.json``,
    ``image.png``, ``nuclei_mask.png`` and ``myotube_mask.png``) is then copied
    in, named after its folder.

    The three target sub-folders are always created, so corrections can be
    injected even when the base dataset ships ``images/`` without mask
    folders. If staging fails part-way, the temporary directory is removed
    before the error is re-raised.

    Returns:
        Path (as ``str``) of the staged dataset, ``<tempdir>/ft``. The caller
        owns the directory and is responsible for deleting it.
    """
    staging_root = Path(tempfile.mkdtemp(prefix="myoseg_ft_"))
    tmp = staging_root / "ft"
    orig = Path(base)
    try:
        if (orig / "images").exists():
            shutil.copytree(str(orig), str(tmp), dirs_exist_ok=True)
        else:
            log.warning("DATA_ROOT %s has no images/ — training on corrections only.", orig)

        for sub in ("images", "masks/Nuclei_m", "masks/Myotubes_m"):
            (tmp / sub).mkdir(parents=True, exist_ok=True)

        injected = 0
        if CORRECTIONS_DIR.exists():
            for meta_p in CORRECTIONS_DIR.glob("*/meta.json"):
                folder = meta_p.parent
                img = folder / "image.png"
                nuc = folder / "nuclei_mask.png"
                myo = folder / "myotube_mask.png"
                if not (img.exists() and nuc.exists() and myo.exists()):
                    continue  # incomplete submission: nothing to train on
                stem = folder.name
                shutil.copy(img, tmp / "images" / f"{stem}.png")
                shutil.copy(nuc, tmp / "masks" / "Nuclei_m" / f"{stem}.png")
                shutil.copy(myo, tmp / "masks" / "Myotubes_m" / f"{stem}.png")
                injected += 1
    except Exception:
        # Do not leave a half-copied dataset behind in the temp area.
        shutil.rmtree(staging_root, ignore_errors=True)
        raise

    log.info("Fine-tune data ready: %d correction(s) injected → %s", injected, tmp)
    return str(tmp)
|
|
|
|
| |
| |
| |
|
|
def _load_from_hub():
    """Download the current checkpoint from the Hub and return it as a UNet on CPU.

    The file is always re-downloaded (``force_download=True``). The checkpoint
    may be either a bare state dict or a dict holding the state dict under the
    ``"model"`` key.
    """
    ckpt_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_FILENAME,
        token=HF_TOKEN,
        force_download=True,
    )
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects from this file. Only point HF_REPO_ID at a
    # repository you trust; consider weights_only=True once it is confirmed
    # that the checkpoints contain nothing but tensors.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    is_wrapped = isinstance(ckpt, dict) and "model" in ckpt
    weights = ckpt["model"] if is_wrapped else ckpt

    model = UNet(in_ch=2, out_ch=2, base=32)
    model.load_state_dict(weights)
    log.info("Loaded model from Hub (repo=%s, file=%s)", HF_REPO_ID, HF_FILENAME)
    return model
|
|
|
|
def _push_to_hub(model_path: Path, metrics: dict, run_id: str) -> bool:
    """
    Upload a checkpoint and its metrics to the Hugging Face model repo.

    Two commits are made on ``HF_REPO_ID``: the weights file is uploaded as
    ``HF_FILENAME``, then ``auto_retrain_metrics.json`` is uploaded with
    *metrics* plus the run id and a timestamp.

    Args:
        model_path: local checkpoint file to upload.
        metrics: must contain ``dice_nuc`` and ``dice_myo`` (used in the
            commit message); the whole dict is stored in the metrics file.
        run_id: identifier of the retrain run.

    Returns:
        True after both uploads complete, False when ``HF_TOKEN`` is not set
        (nothing is uploaded in that case). Upload errors propagate.
    """
    if not HF_TOKEN:
        log.warning("HF_TOKEN not set — skipping Hub push.")
        return False
    api = HfApi(token=HF_TOKEN)
    api.upload_file(
        path_or_fileobj=str(model_path),
        path_in_repo=HF_FILENAME,
        repo_id=HF_REPO_ID,
        repo_type="model",
        commit_message=(f"Auto-retrain {run_id} | "
                        f"dice_nuc={metrics['dice_nuc']:.3f} "
                        f"dice_myo={metrics['dice_myo']:.3f}"),
    )
    # The metrics document is built in memory and uploaded as raw bytes.
    metrics_doc = json.dumps(
        {**metrics, "run_id": run_id, "timestamp": datetime.now().isoformat()},
        indent=2,
    ).encode()
    api.upload_file(
        path_or_fileobj=metrics_doc,
        path_in_repo="auto_retrain_metrics.json",
        repo_id=HF_REPO_ID,
        repo_type="model",
        commit_message=f"Metrics for auto-retrain {run_id}",
    )
    log.info("✅ Pushed new weights to %s/%s", HF_REPO_ID, HF_FILENAME)
    return True
|
|
|
|
| |
| |
| |
|
|
def _evaluate_dice(model, loader, device):
    """Return ``(dice_nuc, dice_myo)`` averaged over every batch in *loader*."""
    model.eval()
    per_batch = []
    with torch.no_grad():
        for x, y, _ in loader:
            probs = torch.sigmoid(model(x.to(device))).cpu()
            per_batch.append(_dice(probs, y).numpy())
    scores = np.array(per_batch)
    return float(scores[:, 0].mean()), float(scores[:, 1].mean())


def _archive_processed_inputs(archive: Path):
    """Move everything in QUEUE_DIR and every folder in CORRECTIONS_DIR into *archive*."""
    archive.mkdir(parents=True, exist_ok=True)
    if QUEUE_DIR.exists():
        for item in list(QUEUE_DIR.glob("*")):
            shutil.move(str(item), str(archive / item.name))
    if CORRECTIONS_DIR.exists():
        for folder in list(CORRECTIONS_DIR.glob("*")):
            if folder.is_dir():
                shutil.move(str(folder), str(archive / folder.name))


def run_retrain(reason: str = "scheduled"):
    """
    Run one fine-tuning cycle and publish the result if it improves.

    Steps: stage base data plus corrections, split 80/20 with a fixed seed,
    fine-tune the current Hub weights for ``FT_EPOCHS`` epochs, keep the
    checkpoint of the best validation epoch in ``runs/<run_id>/model_best.pt``,
    push it to the Hub when its mean Dice beats the stored best, aggregate
    user parameter snapshots, archive the queue and corrections, and record
    the run in the state file and the manifest.

    Behaviour worth knowing:
      * The training subset is augmented; the validation subset reads the
        same files through a second, non-augmenting dataset instance.
      * ``dice_nuc`` / ``dice_myo`` in the returned metrics belong to the same
        (best) epoch as ``mean_dice`` and the saved checkpoint.
      * The stored ``best_dice`` only advances when the push actually
        happened, because the next run starts again from the Hub weights.
      * The staged temporary copy of the dataset is deleted on every exit
        path.

    Args:
        reason: free-text trigger description stored in metrics and manifest.

    Returns:
        The metrics dict, or None when the run was aborted (no labelled data,
        fewer than two samples, or ``FT_EPOCHS`` < 1).
    """
    # Function-scope import: Subset is not part of the file-level import list.
    from torch.utils.data import Subset

    if FT_EPOCHS < 1:
        log.error("FT_EPOCHS=%d — nothing to train. Aborting.", FT_EPOCHS)
        return None

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = RUNS_DIR / run_id
    run_dir.mkdir(parents=True, exist_ok=True)

    log.info("── Retrain run %s | reason=%s | device=%s ──", run_id, reason, device)

    ft_data = _prepare_data(DATA_ROOT)
    try:
        try:
            ds = _FTDataset(ft_data, size=IMAGE_SIZE, augment=True)
        except FileNotFoundError as e:
            log.error("No data: %s — aborting.", e)
            return None

        n_val = max(1, int(len(ds) * 0.2))
        n_train = len(ds) - n_val
        if n_train < 1:
            log.warning("Only %d samples — need ≥2. Aborting.", len(ds))
            return None

        train_ds, val_split = random_split(
            ds, [n_train, n_val], generator=torch.Generator().manual_seed(42)
        )
        # Both subsets returned by random_split share `ds`, so switching
        # augmentation off on it would also switch it off for training.
        # A second instance over the same directory lists the files in the
        # same sorted order, so the validation indices carry over unchanged.
        eval_ds = _FTDataset(ft_data, size=IMAGE_SIZE, augment=False)
        val_ds = Subset(eval_ds, val_split.indices)

        train_dl = DataLoader(train_ds, batch_size=FT_BATCH_SIZE, shuffle=True, num_workers=0)
        val_dl = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0)

        model = _load_from_hub().to(device)
        loss_fn = _BCEDice()
        opt = torch.optim.Adam(model.parameters(), lr=FT_LR)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=FT_EPOCHS, eta_min=1e-5)

        state = _load_state()
        prev_best = state.get("best_dice", 0.0)
        best_run_dice = -1.0
        best_nuc = best_myo = 0.0
        best_path = run_dir / "model_best.pt"

        for ep in range(1, FT_EPOCHS + 1):
            model.train()
            for x, y, _ in train_dl:
                x, y = x.to(device), y.to(device)
                opt.zero_grad()
                loss_fn(model(x), y).backward()
                opt.step()
            sched.step()

            d_nuc, d_myo = _evaluate_dice(model, val_dl, device)
            score = (d_nuc + d_myo) / 2.0
            log.info("  Ep %02d | dice_nuc=%.3f | dice_myo=%.3f | mean=%.3f", ep, d_nuc, d_myo, score)

            if score > best_run_dice:
                best_run_dice = score
                best_nuc, best_myo = d_nuc, d_myo
                torch.save({"model": model.state_dict()}, best_path)

        metrics = {
            "dice_nuc": round(best_nuc, 4),
            "dice_myo": round(best_myo, 4),
            "mean_dice": round(best_run_dice, 4),
            "reason": reason,
            "n_train": n_train,
            "n_val": n_val,
        }

        pushed = False
        log.info("Best this run: %.4f | Previous best: %.4f", best_run_dice, prev_best)
        if best_run_dice > prev_best:
            pushed = _push_to_hub(best_path, metrics, run_id)
            if pushed:
                state["best_dice"] = best_run_dice
                # NOTE(review): despite the key name this stores the local
                # checkpoint path, not a Hub commit SHA.
                state["current_hf_sha"] = str(best_path)
            else:
                log.info("Push skipped; stored best Dice left unchanged.")
        else:
            log.info("New model did not beat previous best — NOT pushing.")

        # Parameter snapshots must be read before the corrections are archived.
        _aggregate_user_params(CORRECTIONS_DIR, run_dir)
        _archive_processed_inputs(run_dir / "processed_queue")

        state["last_retrain_ts"] = datetime.now().isoformat()
        _save_state(state)

        manifest = _load_manifest()
        manifest.append({"run_id": run_id, "timestamp": state["last_retrain_ts"],
                         "reason": reason, "metrics": metrics, "pushed": pushed})
        _save_manifest(manifest)

        log.info("── Run %s complete | pushed=%s ──", run_id, pushed)
        return metrics
    finally:
        # The staged copy is only needed for this run; remove it and then the
        # enclosing temp directory (rmdir only succeeds if that is empty).
        shutil.rmtree(ft_data, ignore_errors=True)
        try:
            Path(ft_data).parent.rmdir()
        except OSError:
            pass
|
|
|
|
| |
| |
| |
|
|
def check_and_retrain(force=False):
    """Evaluate the retrain triggers once and start a retrain if one fired.

    Args:
        force: passed to ``should_retrain``; when true a retrain always runs.

    Returns None; the outcome is only logged.
    """
    ok, reason = should_retrain(force=force)
    if ok:
        log.info("Trigger met: %s — retraining…", reason)
        run_retrain(reason=reason)
    else:
        log.info("No trigger met — skipping.")
|
|
|
|
| |
| |
| |
|
|
def start_scheduler():
    """
    Run the blocking APScheduler loop (nightly retrain plus 30-minute polling).

    Two jobs are registered on a UTC scheduler:
      * ``nightly`` – unconditional retrain at ``SCHEDULE_HOUR``:00, recorded
        with reason ``"scheduled"``.
      * ``poll`` – ``check_and_retrain`` every 30 minutes.

    The scheduler uses a single worker thread so the two jobs can never
    retrain at the same time on the same queue and corrections folders. The
    nightly job gets a one-hour misfire grace period so that it still runs
    when it had to wait behind a polling retrain.

    Returns None. Logs an error and returns immediately when APScheduler is
    not installed. Blocks until interrupted otherwise.
    """
    if not HAS_SCHEDULER:
        log.error("APScheduler not installed. pip install apscheduler")
        return
    s = BlockingScheduler(
        timezone="UTC",
        executors={"default": {"type": "threadpool", "max_workers": 1}},
    )
    s.add_job(
        run_retrain,
        "cron",
        hour=SCHEDULE_HOUR,
        minute=0,
        id="nightly",
        kwargs={"reason": "scheduled"},
        misfire_grace_time=3600,
        coalesce=True,
    )
    s.add_job(check_and_retrain, "interval", minutes=30, id="poll")
    log.info("Scheduler running. Nightly at %02d:00 UTC. Polling every 30 min.", SCHEDULE_HOUR)
    try:
        s.start()
    except (KeyboardInterrupt, SystemExit):
        log.info("Scheduler stopped.")
|
|
|
|
| |
| |
| |
|
|
# Command-line entry point: one-shot trigger check by default, forced retrain
# with --manual, or the blocking scheduler loop with --scheduler.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--manual", action="store_true", help="Force retrain now")
    parser.add_argument("--scheduler", action="store_true", help="Start blocking scheduler")
    parser.add_argument("--data_root", default=None, help="Override DATA_ROOT env var")
    cli = parser.parse_args()

    # Rebinds the module-level DATA_ROOT that run_retrain reads at call time.
    if cli.data_root:
        DATA_ROOT = cli.data_root

    if not cli.scheduler:
        check_and_retrain(force=cli.manual)
    else:
        start_scheduler()
|
|