therealestcoder commited on
Commit
53ec820
·
verified ·
1 Parent(s): 0d3fe75

Upload src\train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src//train.py +106 -0
src//train.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Обучение модели детекции дефектов окраски кузова.
2
+
3
+ Запуск: python -m src.train
4
+ """
5
+ from __future__ import annotations
6
+ import time
7
+ import json
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.optim import AdamW
14
+ from torch.optim.lr_scheduler import CosineAnnealingLR
15
+ from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix
16
+ from tqdm import tqdm
17
+
18
+ from . import config as C
19
+ from .dataset import make_loaders
20
+ from .model import build_model
21
+
22
+
23
+ def set_seed(seed: int) -> None:
24
+ import random
25
+ random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
26
+ torch.cuda.manual_seed_all(seed)
27
+
28
+
29
+ def evaluate(model: nn.Module, loader, device) -> dict:
30
+ model.eval()
31
+ all_p, all_y = [], []
32
+ with torch.no_grad():
33
+ for x, y in loader:
34
+ x = x.to(device, non_blocking=True)
35
+ logits = model(x)
36
+ prob = torch.softmax(logits, dim=1)[:, 1]
37
+ all_p.append(prob.cpu().numpy())
38
+ all_y.append(y.numpy())
39
+ p = np.concatenate(all_p); y = np.concatenate(all_y)
40
+ pred = (p >= C.DEFECT_THRESHOLD).astype(int)
41
+ metrics = {
42
+ "auc": float(roc_auc_score(y, p)) if len(np.unique(y)) > 1 else float("nan"),
43
+ "f1": float(f1_score(y, pred, zero_division=0)),
44
+ "acc": float((pred == y).mean()),
45
+ "cm": confusion_matrix(y, pred, labels=[0, 1]).tolist(),
46
+ }
47
+ return metrics
48
+
49
+
50
+ def main() -> None:
51
+ set_seed(C.SEED)
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+ print(f"Устройство: {device}")
54
+
55
+ train_loader, val_loader = make_loaders()
56
+ model = build_model(pretrained=True).to(device)
57
+
58
+ optim = AdamW(model.parameters(), lr=C.LR, weight_decay=C.WEIGHT_DECAY)
59
+ sched = CosineAnnealingLR(optim, T_max=C.EPOCHS)
60
+ criterion = nn.CrossEntropyLoss(label_smoothing=C.LABEL_SMOOTH)
61
+
62
+ C.CHECKPOINTS.mkdir(parents=True, exist_ok=True)
63
+ C.RUNS.mkdir(parents=True, exist_ok=True)
64
+ history = []
65
+ best_score = -1.0
66
+ best_path = C.CHECKPOINTS / "best.pt"
67
+
68
+ for epoch in range(1, C.EPOCHS + 1):
69
+ model.train()
70
+ running = 0.0
71
+ n = 0
72
+ t0 = time.time()
73
+ pbar = tqdm(train_loader, desc=f"Эпоха {epoch}/{C.EPOCHS}")
74
+ for x, y in pbar:
75
+ x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
76
+ optim.zero_grad(set_to_none=True)
77
+ logits = model(x)
78
+ loss = criterion(logits, y)
79
+ loss.backward()
80
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
81
+ optim.step()
82
+ running += float(loss.item()) * x.size(0); n += x.size(0)
83
+ pbar.set_postfix(loss=f"{running / n:.4f}")
84
+ sched.step()
85
+
86
+ metrics = evaluate(model, val_loader, device)
87
+ score = metrics["auc"] if not np.isnan(metrics["auc"]) else metrics["f1"]
88
+ elapsed = time.time() - t0
89
+ print(f" val: AUC={metrics['auc']:.3f} F1={metrics['f1']:.3f} "
90
+ f"acc={metrics['acc']:.3f} cm={metrics['cm']} ({elapsed:.1f}s)")
91
+ history.append({"epoch": epoch, "train_loss": running / max(n, 1), **metrics})
92
+
93
+ if score > best_score:
94
+ best_score = score
95
+ torch.save({"model": model.state_dict(),
96
+ "backbone": C.BACKBONE,
97
+ "img_size": C.IMG_SIZE,
98
+ "metrics": metrics}, best_path)
99
+ print(f" ✓ сохранён лучший чекпоинт {best_path.name} (score={best_score:.3f})")
100
+
101
+ (C.RUNS / "history.json").write_text(json.dumps(history, indent=2, ensure_ascii=False))
102
+ print(f"\nГотово. Лучший score: {best_score:.3f}\nЧекпоинт: {best_path}")
103
+
104
+
105
+ if __name__ == "__main__":
106
+ main()