CompactAI committed on
Commit
b3fe8dc
·
verified ·
1 Parent(s): f63dfdb

Delete train.py

Browse files
Files changed (1) hide show
  1. train.py +0 -305
train.py DELETED
@@ -1,305 +0,0 @@
1
- """
2
- AIFinder Training Script
3
- Loads data, trains a GPU provider classifier, reports metrics, and saves the model.
4
-
5
- Usage: python3 train.py
6
- """
7
-
8
- import os
9
- import sys
10
- import time
11
- import joblib
12
- import numpy as np
13
- import torch
14
- import torch.nn as nn
15
- from torch.utils.data import TensorDataset, DataLoader
16
- from sklearn.model_selection import train_test_split
17
- from sklearn.metrics import classification_report
18
- from sklearn.preprocessing import LabelEncoder
19
- from sklearn.utils.class_weight import compute_class_weight
20
-
21
- from config import (
22
- MODEL_DIR,
23
- TEST_SIZE,
24
- RANDOM_STATE,
25
- HIDDEN_DIM,
26
- EMBED_DIM,
27
- DROPOUT,
28
- BATCH_SIZE,
29
- EPOCHS,
30
- LEARNING_RATE,
31
- WEIGHT_DECAY,
32
- EARLY_STOP_PATIENCE,
33
- )
34
- from data_loader import load_all_data
35
- from features import FeaturePipeline
36
- from model import AIFinderNet
37
-
38
-
39
- def _log(msg, t0=None):
40
- """Print a timestamped log message, optionally with elapsed time."""
41
- ts = time.strftime("%H:%M:%S")
42
- if t0 is not None:
43
- elapsed = time.time() - t0
44
- print(f" [{ts}] {msg} ({elapsed:.1f}s)")
45
- else:
46
- print(f" [{ts}] {msg}")
47
-
48
-
49
def main():
    """Train the AIFinder provider classifier end to end.

    Pipeline: select a device, load the raw corpus, label-encode the
    provider names, make a stratified train/test split, fit the feature
    pipeline on the train split only (no leakage), train ``AIFinderNet``
    with class-weighted cross-entropy + OneCycleLR + optional AMP,
    early-stop on validation loss, print a classification report, and
    persist the model plus fitted preprocessing artifacts to MODEL_DIR.

    Exits with status 1 when fewer than 100 samples are available.
    """
    t_start = time.time()

    print("=" * 60)
    print("AIFinder Training - Provider Classification")
    print("=" * 60)

    # ── GPU check ──────────────────────────────────────────────
    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
        _log(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
    else:
        device = torch.device("cpu")
        _log("No GPU available, using CPU")

    # ── Load data ──────────────────────────────────────────────
    _log("Starting data load...")
    t0 = time.time()
    # Only texts and providers are used; model names and the AI flag are
    # returned by the loader but not trained on here.
    texts, providers, models, _is_ai = load_all_data()
    _log("Data load complete", t0)

    if len(texts) < 100:
        print("ERROR: Not enough data loaded. Check dataset access.")
        sys.exit(1)

    # ── Encode labels ──────────────────────────────────────────
    _log("Encoding labels...")
    t0 = time.time()
    provider_enc = LabelEncoder()
    provider_labels = provider_enc.fit_transform(providers)
    num_providers = len(provider_enc.classes_)
    _log(f"Labels encoded — {num_providers} providers", t0)

    # ── Train/test split ───────────────────────────────────────
    # Stratified so every provider appears in both splits.
    _log("Splitting train/test...")
    t0 = time.time()
    indices = np.arange(len(texts))
    train_idx, test_idx = train_test_split(
        indices,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
        stratify=provider_labels,
    )
    train_texts = [texts[i] for i in train_idx]
    test_texts = [texts[i] for i in test_idx]
    _log(f"Split: {len(train_texts)} train / {len(test_texts)} test", t0)

    # ── Build features ─────────────────────────────────────────
    # Fit on the train split only so test statistics never leak in.
    _log("Building feature pipeline (fit on train)...")
    t0 = time.time()
    pipeline = FeaturePipeline()
    X_train = pipeline.fit_transform(train_texts)
    _log(f"Train features: {X_train.shape}", t0)

    _log("Transforming test set...")
    t0 = time.time()
    X_test = pipeline.transform(test_texts)
    _log(f"Test features: {X_test.shape}", t0)

    input_dim = X_train.shape[1]

    # ── Move to device ─────────────────────────────────────────
    # Densify once up front; keeping the full matrices on-device means
    # DataLoader batches need no host→device copies per step.
    _log(f"Moving data to {device}...")
    t0 = time.time()
    X_train_t = torch.tensor(X_train.toarray(), dtype=torch.float32).to(device)
    X_test_t = torch.tensor(X_test.toarray(), dtype=torch.float32).to(device)
    y_prov_train = torch.tensor(provider_labels[train_idx], dtype=torch.long).to(device)
    y_prov_test = torch.tensor(provider_labels[test_idx], dtype=torch.long).to(device)
    if device.type == "cuda":
        mem_used = torch.cuda.memory_allocated() / 1024**3
        _log(f"GPU memory used: {mem_used:.2f} GB", t0)
    else:
        _log(f"Data on {device}", t0)

    # ── DataLoaders ────────────────────────────────────────────
    # Cap the batch size on CPU, where very large batches only hurt.
    batch_size = min(BATCH_SIZE, 512) if device.type == "cpu" else BATCH_SIZE
    train_ds = TensorDataset(X_train_t, y_prov_train)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_ds = TensorDataset(X_test_t, y_prov_test)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # ── Model ──────────────────────────────────────────────────
    _log("Building model...")
    net = AIFinderNet(
        input_dim=input_dim,
        num_providers=num_providers,
        hidden_dim=HIDDEN_DIM,
        embed_dim=EMBED_DIM,
        dropout=DROPOUT,
    ).to(device)
    n_params = sum(p.numel() for p in net.parameters())
    _log(f"Model: {n_params:,} parameters")

    # ── Class-weighted loss ────────────────────────────────────
    # "balanced" weights compensate for provider class imbalance.
    prov_weights = compute_class_weight(
        "balanced", classes=np.arange(num_providers), y=provider_labels[train_idx]
    )
    prov_criterion = nn.CrossEntropyLoss(
        weight=torch.tensor(prov_weights, dtype=torch.float32).to(device)
    )

    # ── Optimizer + scheduler ──────────────────────────────────
    optimizer = torch.optim.AdamW(
        net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    # OneCycleLR steps once per batch, so give it the exact step budget.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        epochs=EPOCHS,
        steps_per_epoch=len(train_loader),
    )
    use_amp = device.type == "cuda"
    # FIX: name the device explicitly — bare GradScaler() is deprecated
    # in recent torch releases.
    scaler = torch.amp.GradScaler("cuda") if use_amp else None

    # ── Training loop ──────────────────────────────────────────
    _log(
        f"Training for {EPOCHS} epochs, batch_size={batch_size}, "
        f"early_stop_patience={EARLY_STOP_PATIENCE}..."
    )
    t0 = time.time()

    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0

    for epoch in range(EPOCHS):
        # ── Train phase ───────────────────────────────────────
        net.train()
        epoch_loss = 0.0
        n_batches = 0

        for batch_X, batch_prov in train_loader:
            optimizer.zero_grad(set_to_none=True)
            if use_amp:
                with torch.amp.autocast(device_type="cuda"):
                    prov_logits = net(batch_X)
                    loss = prov_criterion(prov_logits, batch_prov)
                scaler.scale(loss).backward()
                # Unscale first so the clip norm is measured in real
                # (unscaled) gradient units.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()
            n_batches += 1

        avg_train_loss = epoch_loss / n_batches

        # ── Validation phase ──────────────────────────────────
        net.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch_X, batch_prov in val_loader:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                val_loss += loss.item()
                val_batches += 1
        avg_val_loss = val_loss / val_batches

        # ── Early stopping check ──────────────────────────────
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Clone so subsequent training steps can't mutate the snapshot.
            best_state = {k: v.clone() for k, v in net.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        # ── Logging ───────────────────────────────────────────
        if (epoch + 1) % 5 == 0 or epoch == 0:
            lr = scheduler.get_last_lr()[0]
            marker = " *" if patience_counter == 0 else ""
            _log(
                f"Epoch {epoch + 1:>3d}/{EPOCHS} "
                f"train={avg_train_loss:.4f} "
                f"val={avg_val_loss:.4f} "
                f"lr={lr:.2e}{marker}"
            )

        if patience_counter >= EARLY_STOP_PATIENCE:
            _log(
                f"Early stopping at epoch {epoch + 1} "
                f"(best val_loss={best_val_loss:.4f})"
            )
            break

    # Restore best weights
    if best_state is not None:
        net.load_state_dict(best_state)
        _log(f"Restored best weights (val_loss={best_val_loss:.4f})")

    _log("Training complete", t0)

    # ── Evaluate ───────────────────────────────────────────────
    _log("Evaluating...")
    net.eval()
    with torch.no_grad():
        prov_logits = net(X_test_t)

    prov_preds = prov_logits.argmax(dim=1).cpu().numpy()
    prov_true = y_prov_test.cpu().numpy()

    print("\n === Provider Classification ===")
    print(
        classification_report(
            prov_true,
            prov_preds,
            # FIX: pin the full label set so target_names always lines up
            # even when some provider never occurs among the predictions;
            # without this, classification_report raises a length mismatch.
            labels=np.arange(num_providers),
            target_names=provider_enc.classes_,
            zero_division=0,
        )
    )

    # ── Save ───────────────────────────────────────────────────
    _log(f"Saving to {MODEL_DIR}/ ...")
    t0 = time.time()
    os.makedirs(MODEL_DIR, exist_ok=True)

    # Store hyperparameters alongside the weights so inference can
    # rebuild the network without importing config.
    checkpoint = {
        "input_dim": input_dim,
        "num_providers": num_providers,
        "hidden_dim": HIDDEN_DIM,
        "embed_dim": EMBED_DIM,
        "dropout": DROPOUT,
        "state_dict": net.state_dict(),
    }
    torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt"))
    _log(" Saved classifier.pt")

    joblib.dump(pipeline, os.path.join(MODEL_DIR, "feature_pipeline.joblib"))
    _log(" Saved feature_pipeline.joblib")
    joblib.dump(provider_enc, os.path.join(MODEL_DIR, "provider_enc.joblib"))
    _log(" Saved provider_enc.joblib")

    _log("All artifacts saved", t0)

    elapsed = time.time() - t_start
    if device.type == "cuda":
        mem_peak = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n{'=' * 60}")
        print(f"Training complete in {elapsed:.1f}s (peak GPU mem: {mem_peak:.2f} GB)")
        print(f"{'=' * 60}")
    else:
        print(f"\n{'=' * 60}")
        print(f"Training complete in {elapsed:.1f}s")
        print(f"{'=' * 60}")
302
-
303
-
304
# Script entry point — run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()