| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import math |
| import os |
| import time |
| import json |
| from dataclasses import dataclass |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from tqdm import tqdm |
|
|
# Run on GPU when available; all extraction and training below follow this.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Teacher models: (HF model id, short cache/file name, tokenizer max length).
# Each teacher's embeddings are aligned into a shared space and averaged
# into the consensus distillation target.
MODELS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 8192),
    ("FacebookAI/roberta-base", "roberta", 512),
    ("albert/albert-base-v2", "albert", 512),
    ("distilbert/distilbert-base-uncased", "distil", 512),
]
|
|
@dataclass
class Config:
    """All hyper-parameters: data extraction, student architecture, training."""

    # ---- data ----
    n_samples: int = 500000        # captions to stream from CC12M
    n_val: int = 5000              # validation split, taken from the tail
    min_caption_len: int = 50      # drop captions shorter than this (chars)
    extract_batch: int = 1024      # batch size for teacher embedding extraction
    cache_dir: str = "/home/claude/consensus_500k"  # NOTE: hard-coded path

    # ---- student architecture ----
    d_model: int = 384             # transformer hidden width
    n_heads: int = 6               # attention heads (must divide d_model)
    n_layers: int = 6              # encoder layers
    d_ff: int = 1536               # feed-forward width
    max_len: int = 8192            # positional-embedding capacity
    tokenize_len: int = 512        # actual sequence length used for training
    output_dim: int = 768          # matches the 768-dim consensus space
    dropout: float = 0.1

    # ---- optimization ----
    epochs: int = 30
    batch_size: int = 128
    lr: float = 3e-4
    weight_decay: float = 0.01
    warmup_steps: int = 1000       # linear warmup before cosine decay
    grad_clip: float = 1.0
    seed: int = 42

    # ---- loss weights ----
    nce_weight: float = 1.0        # InfoNCE contrastive term
    mse_weight: float = 1.0        # direct regression term
    cv_weight: float = 0.1         # simplex-volume CV regularizer
    cv_target: float = 0.084       # target coefficient of variation


# Single module-level config instance used throughout.
CFG = Config()
|
|
print("=" * 65)
# Fix: the banner said "200K Scale" while CFG.n_samples is 500,000 (and the
# cache dir is consensus_500k), and the separator was mojibake ("β").
# Derive the scale from the live config so the label cannot go stale again.
print(f"DISTILLED CONSENSUS BERT — {CFG.n_samples // 1000}K Scale")
print("=" * 65)
print(f" Device: {DEVICE}")
print(f" Samples: {CFG.n_samples:,}")
|
|
|
|
| |
| |
| |
|
|
def load_captions(n, min_len=50):
    """Stream CC12M llava captions, keeping strings longer than *min_len*.

    Stops as soon as *n* captions have been collected (or the stream ends).
    Returns the list of caption strings.
    """
    from datasets import load_dataset
    print(f"\n Loading captions (n={n:,})...")
    stream = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                          split="train", streaming=True)
    kept = []
    for record in stream:
        text = record.get("caption_llava", "")
        if isinstance(text, str) and len(text) > min_len:
            kept.append(text)
            if len(kept) >= n:
                break
    print(f" Got {len(kept):,} captions")
    return kept
|
|
|
|
@torch.no_grad()
def extract_one(model_name, short_name, captions, max_len, batch_size):
    """Compute mean-pooled last-hidden-state embeddings for all captions
    with one Hugging Face teacher model; returns an (N, hidden) CPU tensor.
    The model is freed (and CUDA cache emptied) before returning.
    """
    from transformers import AutoModel, AutoTokenizer
    print(f"\n Extracting: {short_name} ({model_name})...")
    model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    dim = model.config.hidden_size
    n_params = sum(p.numel() for p in model.parameters())
    print(f" dim={dim}, {n_params:,} params")

    pooled_chunks = []
    for start in tqdm(range(0, len(captions), batch_size), desc=f" {short_name}"):
        chunk = captions[start:start + batch_size]
        enc = tokenizer(chunk, max_length=max_len, padding=True,
                        truncation=True, return_tensors="pt").to(DEVICE)
        hidden = model(**enc).last_hidden_state
        # Mask-weighted mean pooling over the token dimension.
        weights = enc.attention_mask.unsqueeze(-1).float()
        mean_pooled = (hidden * weights).sum(1) / weights.sum(1).clamp(min=1)
        pooled_chunks.append(mean_pooled.cpu())

    emb = torch.cat(pooled_chunks)
    print(f" Shape: {emb.shape}")
    del model
    torch.cuda.empty_cache()
    return emb
|
|
|
|
def extract_all():
    """Return (embeds, captions), reusing the on-disk cache when complete.

    embeds maps each teacher's short name to an (N, 768) tensor; narrower
    teachers are zero-padded and wider ones truncated to 768 dims.
    """
    os.makedirs(CFG.cache_dir, exist_ok=True)
    caps_path = os.path.join(CFG.cache_dir, "captions.json")

    # The cache counts as complete only if every teacher file exists.
    have_all = all(
        os.path.exists(os.path.join(CFG.cache_dir, f"{s}.pt"))
        for _, s, _ in MODELS)

    if have_all and os.path.exists(caps_path):
        print("\n Loading cached embeddings...")
        cached = {}
        for _, short, _ in MODELS:
            cached[short] = torch.load(
                os.path.join(CFG.cache_dir, f"{short}.pt"), weights_only=True)
            print(f" {short}: {cached[short].shape}")
        with open(caps_path) as f:
            captions = json.load(f)
        return cached, captions

    captions = load_captions(CFG.n_samples, CFG.min_caption_len)

    embeds = {}
    for model_name, short, model_max_len in MODELS:
        emb = extract_one(model_name, short, captions,
                          model_max_len, CFG.extract_batch)
        # Coerce every teacher into the common 768-dim space.
        width = emb.shape[1]
        if width < 768:
            emb = F.pad(emb, (0, 768 - width))
        elif width > 768:
            emb = emb[:, :768]
        embeds[short] = emb
        torch.save(emb, os.path.join(CFG.cache_dir, f"{short}.pt"))

    with open(caps_path, "w") as f:
        json.dump(captions, f)

    return embeds, captions
|
|
|
|
| |
| |
| |
|
|
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Return cov^(-1/2) for a symmetric PSD matrix via eigendecomposition.

    Eigenvalues are clamped to *eps* so near-singular covariances stay finite.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(cov)
    safe_vals = eigenvalues.clamp(min=eps)
    # evecs @ diag(vals^-1/2) @ evecs.T, with the diagonal applied by broadcast.
    scaled = eigenvectors * safe_vals.rsqrt().unsqueeze(0)
    return scaled @ eigenvectors.T
|
|
|
|
def procrustes_align(source, target, n_align=10000):
    """Fit a whitened orthogonal-Procrustes map from source to target space.

    Uses at most *n_align* rows. Both sides are centered and whitened, the
    rotation is solved in the whitened coordinates, and the returned dict
    carries everything apply_align needs plus before/after mean cosines.
    """
    n = min(n_align, source.shape[0], target.shape[0])
    src = source[:n].float()
    tgt = target[:n].float()
    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean
    n_rows = src_c.shape[0]

    # Whiten each side so the rotation is fit on decorrelated coordinates.
    denom = max(n_rows - 1, 1)
    src_white = symmetric_inv_sqrt((src_c.T @ src_c) / denom)
    tgt_white = symmetric_inv_sqrt((tgt_c.T @ tgt_c) / denom)

    src_w = F.normalize(src_c @ src_white, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_white, dim=-1)

    cos_before = F.cosine_similarity(src_c, tgt_c, dim=-1).mean().item()

    # Orthogonal Procrustes: R = U @ Vt from the SVD of the cross-covariance.
    U, _, Vt = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    R = U @ Vt

    cos_after = F.cosine_similarity(src_w @ R.T, tgt_w, dim=-1).mean().item()

    return {
        "rotation": R, "source_mean": src_mean.squeeze(0),
        "source_whitener": src_white,
        "target_unwhitener": torch.linalg.pinv(tgt_white),
        "cos_before": cos_before, "cos_after": cos_after,
    }
|
|
|
|
def apply_align(emb, a):
    """Map embeddings through the fitted pipeline:
    center -> whiten -> rotate -> un-whiten into the target space.
    """
    out = emb.float() - a["source_mean"]
    out = out @ a["source_whitener"]
    out = out @ a["rotation"].T
    return out @ a["target_unwhitener"]
|
|
|
|
def generate_consensus(embeds):
    """Align all teachers to the BERT reference space; return the consensus.

    Each teacher's embeddings are mapped via whitened Procrustes alignment
    into the "bert" space; the consensus target is the L2-normalized mean
    of the aligned teachers. Prints per-teacher alignment quality and the
    agreement of each teacher with the consensus.

    Args:
        embeds: dict mapping each MODELS short name to an (N, 768) tensor.

    Returns:
        (N, 768) tensor of unit-norm consensus embeddings.
    """
    print(f"\n{'='*65}")
    print("WHITENED PROCRUSTES ALIGNMENT + CONSENSUS")
    print(f"{'='*65}")

    ref_name = "bert"
    names = [s for _, s, _ in MODELS]
    aligned = {}

    for name in names:
        info = procrustes_align(embeds[name], embeds[ref_name])
        aligned[name] = apply_align(embeds[name], info)
        label = " (ref)" if name == ref_name else ""
        # Fix: the before/after separator was mojibake ("β"); it is an arrow.
        print(f" {name:10s}: cos {info['cos_before']:.4f} → {info['cos_after']:.4f}{label}")

    # Normalized centroid of the aligned teachers = distillation target.
    centroid = sum(aligned[n] for n in names) / len(names)
    consensus = F.normalize(centroid, dim=-1)

    # Sanity check on a subset: how close each teacher is to the consensus.
    N_check = min(5000, consensus.shape[0])
    for name in names:
        cos = F.cosine_similarity(
            consensus[:N_check], aligned[name][:N_check], dim=-1).mean().item()
        print(f" cos(consensus, {name:10s}): {cos:.4f}")

    return consensus
|
|
|
|
| |
| |
| |
|
|
class CaptionEncoder(nn.Module):
    """Compact pre-norm Transformer that encodes token ids into a single
    L2-normalized embedding in the consensus space.
    """

    def __init__(self, vocab_size=30522, max_len=128, d_model=384,
                 n_heads=6, n_layers=6, d_ff=1536, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        # Learned token + absolute-position embeddings.
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)

        # Pre-norm GELU Transformer stack.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # MLP head projecting pooled features into the output space.
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, output_dim),
        )

    def forward(self, input_ids, attention_mask=None):
        """Encode (B, L) token ids to unit-norm (B, output_dim) embeddings.

        When *attention_mask* is omitted, padding is inferred from
        pad_token_id positions.
        """
        seq_len = input_ids.size(1)
        pos_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        hidden = self.emb_drop(
            self.emb_norm(self.token_emb(input_ids) + self.pos_emb(pos_ids)))

        # True entries of the key-padding mask are ignored by attention.
        if attention_mask is None:
            padding = input_ids.eq(self.pad_token_id)
        else:
            padding = ~attention_mask.bool()

        hidden = self.encoder(hidden, src_key_padding_mask=padding)

        # Mean-pool over non-padding positions only.
        if attention_mask is not None:
            valid = attention_mask.unsqueeze(-1).float()
        else:
            valid = (~padding).unsqueeze(-1).float()
        pooled = (hidden * valid).sum(1) / valid.sum(1).clamp(min=1)

        return F.normalize(self.output_proj(pooled), dim=-1)
|
|
|
|
| |
| |
| |
|
|
def cayley_menger_vol2(pts):
    """Squared simplex volume from the Cayley-Menger determinant.

    *pts* is (B, V, D): B batches of V vertices. Returns a (B,) tensor of
    squared volumes of the (V-1)-simplex spanned by each vertex set.
    """
    coords = pts.float()
    pairwise = coords.unsqueeze(-2) - coords.unsqueeze(-3)
    sq_dists = pairwise.pow(2).sum(-1)
    batch, n_verts, _ = sq_dists.shape
    # Bordered distance matrix: first row/col of ones, zero corner.
    border = torch.zeros(batch, n_verts + 1, n_verts + 1,
                         device=sq_dists.device, dtype=torch.float32)
    border[:, 0, 1:] = 1
    border[:, 1:, 0] = 1
    border[:, 1:, 1:] = sq_dists
    sign = (-1.0) ** n_verts
    fact = math.factorial(n_verts - 1)
    coeff = sign / ((2.0 ** (n_verts - 1)) * fact * fact)
    return coeff * torch.linalg.det(border)
|
|
def cv_loss(emb, target=0.084, n_samples=16):
    """|CV - target| over Monte-Carlo sampled 5-point simplex volumes.

    Draws *n_samples* random 5-subsets of the batch, computes each subset's
    simplex volume, and penalizes deviation of the volumes' coefficient of
    variation from *target*. Returns 0 when the batch has fewer than 5 rows.
    """
    batch = emb.shape[0]
    if batch < 5:
        return torch.tensor(0.0, device=emb.device)
    volumes = []
    for _ in range(n_samples):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        volumes.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vol_t = torch.stack(volumes)
    coeff_var = vol_t.std() / (vol_t.mean() + 1e-8)
    return (coeff_var - target).abs()
|
|
def cv_metric(emb, n=200):
    """Evaluation-time coefficient of variation of sampled simplex volumes.

    Like cv_loss but returns a plain float, samples more subsets, keeps only
    strictly positive volumes, and bails out to 0.0 when fewer than 10
    usable volumes (or fewer than 5 rows) are available.
    """
    batch = emb.shape[0]
    if batch < 5:
        return 0.0
    volumes = []
    for _ in range(n):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        vol = torch.sqrt(F.relu(sq_vol[0]) + 1e-12).item()
        if vol > 0:
            volumes.append(vol)
    if len(volumes) < 10:
        return 0.0
    arr = np.array(volumes)
    return float(arr.std() / (arr.mean() + 1e-8))
|
|
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between two embedding batches.

    Rows of *a* and *b* at the same index are positives; everything else in
    the batch is a negative. Returns (loss, top-1 retrieval accuracy).
    """
    a_n = F.normalize(a, dim=-1)
    b_n = F.normalize(b, dim=-1)
    sim = (a_n @ b_n.T) / temperature
    targets = torch.arange(sim.size(0), device=sim.device)
    # Average both retrieval directions (a->b and b->a).
    loss = (F.cross_entropy(sim, targets) + F.cross_entropy(sim.T, targets)) / 2
    with torch.no_grad():
        acc = (sim.argmax(-1) == targets).float().mean().item()
    return loss, acc
|
|
|
|
| |
| |
| |
|
|
def train():
    """End-to-end pipeline: teacher extraction, consensus construction, and
    distillation of the consensus into the small CaptionEncoder student.

    Side effects: reads/writes caches under CFG.cache_dir, saves checkpoints
    and a metrics JSON under CFG.cache_dir/student/, prints progress.
    """
    # Seed every RNG source used below for reproducibility.
    torch.manual_seed(CFG.seed)
    torch.cuda.manual_seed_all(CFG.seed)
    np.random.seed(CFG.seed)

    # Build (or load cached) teacher embeddings and the consensus target.
    embeds, captions = extract_all()
    consensus = generate_consensus(embeds)

    # Per-teacher embeddings are no longer needed once the consensus exists.
    del embeds
    torch.cuda.empty_cache()
    import gc; gc.collect()

    # The student reuses BERT's tokenizer/vocabulary.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    print(f"\n Tokenizer: bert-base-uncased (vocab={tokenizer.vocab_size})")

    print(" Pre-tokenizing...")
    # Tokenize in 50k-caption chunks to bound peak tokenizer memory.
    all_ids, all_masks = [], []
    chunk = 50000
    for i in tqdm(range(0, len(captions), chunk), desc=" Tokenizing"):
        j = min(i + chunk, len(captions))
        tokens = tokenizer(captions[i:j], max_length=CFG.tokenize_len,
                           padding="max_length", truncation=True,
                           return_tensors="pt")
        all_ids.append(tokens["input_ids"])
        all_masks.append(tokens["attention_mask"])

    input_ids = torch.cat(all_ids)
    attention_mask = torch.cat(all_masks)

    # Length statistics on real (non-padding) tokens.
    real_lens = attention_mask.sum(1).float()
    print(f" Token lengths: mean={real_lens.mean():.0f} "
          f"median={real_lens.median():.0f} "
          f">{CFG.tokenize_len}: {(real_lens >= CFG.tokenize_len).float().mean():.1%}")
    print(f" Padded to: {CFG.tokenize_len} (model supports up to {CFG.max_len})")

    # The last n_val captions form the validation split.
    n_train = len(captions) - CFG.n_val
    print(f" Train: {n_train:,}, Val: {CFG.n_val:,}")

    # NOTE(review): all token/target tensors are made resident on DEVICE at
    # once; assumes they fit in device memory at this scale — confirm.
    train_ids = input_ids[:n_train].to(DEVICE)
    train_mask = attention_mask[:n_train].to(DEVICE)
    train_targets = consensus[:n_train].to(DEVICE)
    val_ids = input_ids[n_train:].to(DEVICE)
    val_mask = attention_mask[n_train:].to(DEVICE)
    val_targets = consensus[n_train:].to(DEVICE)

    print(f"\n{'='*65}")
    print("STUDENT MODEL")
    print(f"{'='*65}")

    student = CaptionEncoder(
        vocab_size=tokenizer.vocab_size,
        max_len=CFG.max_len,
        d_model=CFG.d_model,
        n_heads=CFG.n_heads,
        n_layers=CFG.n_layers,
        d_ff=CFG.d_ff,
        output_dim=CFG.output_dim,
        dropout=CFG.dropout,
        pad_token_id=tokenizer.pad_token_id,
    ).to(DEVICE)

    n_params = sum(p.numel() for p in student.parameters())
    print(f" Architecture: {CFG.n_layers}L, {CFG.d_model}d, {CFG.n_heads}h, {CFG.d_ff} FFN")
    print(f" Output: {CFG.output_dim}-dim (consensus space)")
    print(f" Parameters: {n_params:,}")
    size_mb = sum(p.numel() * p.element_size() for p in student.parameters()) / 1e6
    print(f" Size: {size_mb:.1f} MB")

    # Warm-start from the first earlier run found. The for/else "else"
    # branch runs only when no checkpoint directory matched (no break).
    for prev_dir in ["/home/claude/consensus_200k/student",
                     "/home/claude/distilled_consensus"]:
        prev_ckpt = os.path.join(prev_dir, "best_model.pt")
        if os.path.exists(prev_ckpt):
            print(f"\n Warm-starting from: {prev_ckpt}")
            prev_state = torch.load(prev_ckpt, weights_only=True, map_location=DEVICE)
            current_state = student.state_dict()

            # Copy shape-matching tensors; grow positional embeddings in
            # place; skip anything else (renamed or resized layers).
            loaded, extended, skipped = 0, 0, 0
            for name, param in prev_state.items():
                if name not in current_state:
                    skipped += 1
                    continue
                if param.shape == current_state[name].shape:
                    current_state[name] = param
                    loaded += 1
                elif "pos_emb" in name and param.shape[0] < current_state[name].shape[0]:
                    # Longer context now: keep old positional rows, randomly
                    # initialize the newly added tail.
                    old_len = param.shape[0]
                    current_state[name][:old_len] = param
                    nn.init.normal_(current_state[name][old_len:], std=0.02)
                    extended += 1
                    # NOTE(review): "β" below looks like mojibake for an
                    # arrow glyph — confirm intended character.
                    print(f" Extended {name}: {param.shape[0]}β{current_state[name].shape[0]}")
                else:
                    skipped += 1

            student.load_state_dict(current_state)
            print(f" Loaded: {loaded}, Extended: {extended}, Skipped: {skipped}")
            break
    else:
        print("\n Training from scratch (no previous checkpoint found)")

    # AdamW with linear warmup followed by cosine decay to eta_min.
    optimizer = torch.optim.AdamW(student.parameters(), lr=CFG.lr,
                                  weight_decay=CFG.weight_decay)
    n_batches = n_train // CFG.batch_size
    total_steps = n_batches * CFG.epochs
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                           total_iters=CFG.warmup_steps),
         torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer, T_max=max(total_steps - CFG.warmup_steps, 1),
             eta_min=1e-6)],
        milestones=[CFG.warmup_steps])

    os.makedirs(CFG.cache_dir, exist_ok=True)
    save_dir = os.path.join(CFG.cache_dir, "student")
    os.makedirs(save_dir, exist_ok=True)

    print(f"\n{'='*65}")
    print(f"TRAINING ({CFG.epochs} epochs, {n_batches} batches/epoch)")
    print(f"{'='*65}")

    all_metrics = {"config": {k: str(v) for k, v in vars(CFG).items()}, "epochs": []}
    best_val_cos = 0.0

    for epoch in range(CFG.epochs):
        student.train()
        # Fresh shuffle of the training set each epoch.
        perm = torch.randperm(n_train, device=DEVICE)
        losses = {"total": 0, "nce": 0, "mse": 0}
        metrics = {"acc": 0, "cos": 0}
        n = 0
        t0 = time.time()

        for i in range(0, n_train, CFG.batch_size):
            idx = perm[i:i+CFG.batch_size]
            # Skip tiny tail batches: in-batch InfoNCE / CV statistics are
            # meaningless on a handful of samples.
            if len(idx) < 8: continue

            emb = student(train_ids[idx], train_mask[idx])
            tgt = train_targets[idx]

            # Combined objective: contrastive + regression + CV regularizer.
            l_nce, acc = infonce(emb, tgt)
            l_mse = F.mse_loss(emb, tgt)
            l_cv = cv_loss(emb, target=CFG.cv_target)

            loss = CFG.nce_weight * l_nce + CFG.mse_weight * l_mse + CFG.cv_weight * l_cv

            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), CFG.grad_clip)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

            with torch.no_grad():
                cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()

            losses["total"] += loss.item()
            losses["nce"] += l_nce.item()
            losses["mse"] += l_mse.item()
            metrics["acc"] += acc
            metrics["cos"] += cos
            n += 1

        elapsed = time.time() - t0
        d = max(n, 1)  # guard divide-by-zero if every batch was skipped

        # Validation pass in eval mode, batched at 512 to bound memory.
        student.eval()
        with torch.no_grad():
            val_embs = []
            for vi in range(0, CFG.n_val, 512):
                vj = min(vi + 512, CFG.n_val)
                ve = student(val_ids[vi:vj], val_mask[vi:vj])
                val_embs.append(ve)
            val_emb = torch.cat(val_embs)
            _, val_acc = infonce(val_emb[:2000], val_targets[:2000])
            val_cos = F.cosine_similarity(val_emb, val_targets, dim=-1).mean().item()
            val_cv = cv_metric(val_emb[:2000])

        summary = {
            "epoch": epoch + 1, "elapsed": elapsed,
            "loss": losses["total"] / d,
            "train_acc": metrics["acc"] / d,
            "train_cos": metrics["cos"] / d,
            "val_acc": val_acc, "val_cos": val_cos, "val_cv": val_cv,
        }
        all_metrics["epochs"].append(summary)

        print(f" E{epoch+1:2d}: {elapsed:.0f}s "
              f"loss={summary['loss']:.4f} "
              f"t_acc={summary['train_acc']:.3f} t_cos={summary['train_cos']:.3f} "
              f"v_acc={summary['val_acc']:.3f} v_cos={summary['val_cos']:.3f} "
              f"v_cv={summary['val_cv']:.3f}")

        # Best checkpoint is tracked by validation cosine similarity.
        if val_cos > best_val_cos:
            best_val_cos = val_cos
            torch.save(student.state_dict(), os.path.join(save_dir, "best_model.pt"))

        # Periodic snapshot every 10 epochs.
        if (epoch + 1) % 10 == 0:
            torch.save(student.state_dict(),
                       os.path.join(save_dir, f"model_e{epoch+1:02d}.pt"))

    # Persist final weights, the tokenizer, and the per-epoch metric log.
    torch.save(student.state_dict(), os.path.join(save_dir, "final_model.pt"))
    tokenizer.save_pretrained(os.path.join(save_dir, "tokenizer"))
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(all_metrics, f, indent=2, default=str)

    print(f"\n{'='*65}")
    print("FINAL EVALUATION")
    print(f"{'='*65}")

    # Evaluate the best (not the final) checkpoint.
    student.load_state_dict(
        torch.load(os.path.join(save_dir, "best_model.pt"),
                   weights_only=True, map_location=DEVICE))
    student.eval()

    with torch.no_grad():
        val_embs = []
        for vi in range(0, CFG.n_val, 512):
            vj = min(vi + 512, CFG.n_val)
            ve = student(val_ids[vi:vj], val_mask[vi:vj])
            val_embs.append(ve)
        val_emb = torch.cat(val_embs)

        # Retrieval metrics (student embedding vs consensus target) on a
        # 2k-sample subset; diagonal entries are the correct matches.
        sub = min(2000, CFG.n_val)
        sim = val_emb[:sub] @ val_targets[:sub].T
        labels = torch.arange(sub, device=DEVICE)
        r1 = (sim.argmax(1) == labels).float().mean().item()
        r5 = (sim.topk(5, dim=1).indices == labels.unsqueeze(1)).any(1).float().mean().item()
        r10 = (sim.topk(10, dim=1).indices == labels.unsqueeze(1)).any(1).float().mean().item()

        cos_match = F.cosine_similarity(val_emb, val_targets, dim=-1).mean().item()
        final_cv = cv_metric(val_emb[:2000])

    # NOTE(review): "β" in these report strings appears to be mojibake for
    # an arrow glyph — confirm intended characters.
    print(f" Retrieval (student β consensus):")
    print(f" R@1: {r1:.4f}")
    print(f" R@5: {r5:.4f}")
    print(f" R@10: {r10:.4f}")
    print(f" Cosine: {cos_match:.4f}")
    print(f" CV: {final_cv:.4f} (target: {CFG.cv_target})")
    print(f" Model: {n_params:,} params, {size_mb:.1f} MB")

    # Qualitative check: pairwise similarities of a few fresh captions.
    print(f"\n Standalone similarity test:")
    test = [
        "A cat sitting on a windowsill watching birds",
        "A golden retriever playing fetch on the beach",
        "A still life painting with flowers and fruit",
        "An aerial photograph of a city skyline at night",
        "A child riding a bicycle through autumn leaves",
    ]
    with torch.no_grad():
        tok = tokenizer(test, max_length=CFG.tokenize_len, padding="max_length",
                        truncation=True, return_tensors="pt").to(DEVICE)
        embs = student(tok["input_ids"], tok["attention_mask"])
        sim = embs @ embs.T
        for i in range(len(test)):
            for j in range(i+1, len(test)):
                print(f" [{i}]β[{j}]: {sim[i,j]:.3f} "
                      f"({test[i][:35]}β{test[j][:35]})")

    print(f"\n Saved to: {save_dir}/")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
|
|
|
|
# Script entry point: runs the full extract -> consensus -> distill pipeline.
if __name__ == "__main__":
    train()