Cludoy commited on
Commit
bb1c80c
·
verified ·
1 Parent(s): d1625b6

Add auto_trainer.py

Browse files
Files changed (1) hide show
  1. auto_trainer.py +96 -0
auto_trainer.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import shutil
4
+ import subprocess
5
+
6
+ STATE_FILE = "pipeline_state.json"
7
+ RETRAIN_THRESHOLD = 50
8
+ MODEL_PROD_PATH = "prod_tinybert.pt"
9
+ MODEL_NEW_STAGE_PATH = "best_tinybert.pt"
10
+
11
+ def load_state():
12
+ if os.path.exists(STATE_FILE):
13
+ with open(STATE_FILE, "r") as f:
14
+ return json.load(f)
15
+ return {"sessions_since_last_train": 0, "total_sessions": 0}
16
+
17
+ def save_state(state):
18
+ with open(STATE_FILE, "w") as f:
19
+ json.dump(state, f, indent=4)
20
+
21
+ def run_training_pipeline():
22
+ print("\n" + "=" * 50)
23
+ print(">>> Auto-Trainer: Triggering Retraining Pipeline")
24
+ print("=" * 50)
25
+
26
+ print("\n[Step 1] Running data generation (dataset_generator.py)...")
27
+ result = subprocess.run(["python", "dataset_generator.py"], capture_output=True, text=True, encoding="utf-8")
28
+ if result.returncode != 0:
29
+ print("[!] Data pipeline failed:")
30
+ print(result.stderr)
31
+ return False
32
+ print("[+] Data pipeline finished.")
33
+
34
+ print("\n[Step 2] Running training (train.py)...")
35
+ result = subprocess.run(["python", "train.py"], capture_output=True, text=True, encoding="utf-8")
36
+ if result.returncode != 0:
37
+ print("[!] Training failed:")
38
+ print(result.stderr)
39
+ return False
40
+ print("[+] Training finished.")
41
+
42
+ print("\n[Step 3] Validating model quality...")
43
+ if os.path.exists('training_results.json'):
44
+ with open('training_results.json', 'r') as f:
45
+ results = json.load(f)
46
+ metrics = results.get("metrics", {})
47
+ acc = metrics.get("accuracy", 0.0)
48
+ f1 = metrics.get("f1_score", 0.0)
49
+
50
+ print(f"New model validation: Accuracy={acc*100:.2f}%, F1={f1*100:.2f}%")
51
+
52
+ # Validation logic:
53
+ # 1. Must meet minimum quality bar (80% acc, 80% F1)
54
+ # 2. Perfect 100% on test set = pure memorization (reject)
55
+ if acc >= 1.0:
56
+ print(f"[!] Perfect 100% test accuracy. Likely memorization. Rejecting model.")
57
+ return False
58
+ elif acc >= 0.80 and f1 >= 0.80:
59
+ print(f"[+] Metrics meet quality bar. Promoting model to production.")
60
+ if os.path.exists(MODEL_NEW_STAGE_PATH):
61
+ shutil.copy(MODEL_NEW_STAGE_PATH, MODEL_PROD_PATH)
62
+ print(f"[+] Model published to {MODEL_PROD_PATH}")
63
+ return True
64
+ else:
65
+ print(f"[!] Metrics below quality bar. Rejecting model.")
66
+ return False
67
+ else:
68
+ print("[!] Could not find training_results.json.")
69
+ return False
70
+
71
+ def add_session_and_check():
72
+ state = load_state()
73
+ state["sessions_since_last_train"] += 1
74
+ state["total_sessions"] += 1
75
+
76
+ print(f"Logged new session. (Total since train: {state['sessions_since_last_train']})")
77
+
78
+ if state["sessions_since_last_train"] >= RETRAIN_THRESHOLD:
79
+ print("\nThreshold reached! Starting training pipeline...")
80
+ success = run_training_pipeline()
81
+
82
+ if success:
83
+ state["sessions_since_last_train"] = 0
84
+ print("Resetting sessions counter.")
85
+ else:
86
+ print("Retaining count. Will try again on next session.")
87
+
88
+ save_state(state)
89
+ return state
90
+
91
+ if __name__ == "__main__":
92
+ import sys
93
+ if len(sys.argv) > 1 and sys.argv[1] == "--force-train":
94
+ run_training_pipeline()
95
+ else:
96
+ add_session_and_check()