| |
| """ |
| Train a separate REINFORCE agent for each protocol set (e.g. PCR, ELISA). |
| |
| Each protocol has its own presets and outcome model. Training one agent per |
| protocol gives you a policy tailored to that protocol's action/observation |
space. Checkpoints are saved as <checkpoint-dir>/<workflow_id>.pt, where
--checkpoint-dir defaults to "checkpoints".
| |
| Usage: |
| python scripts/train_per_protocol.py --workflows pcr-amplification elisa-readout |
| python scripts/train_per_protocol.py --workflows pcr-amplification --train-episodes 1000 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
# Prepend the script's grandparent directory (the repository root when run as
# `python scripts/train_per_protocol.py`) to sys.path, so the project-local
# `lab_env` and `agents` packages imported below can be resolved.
_REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_REPO_ROOT))
|
|
| from lab_env.env import LabEnv |
| from lab_env.spec import get_spec_for_workflow |
| from agents.rl_agent import ReinforceAgent |
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser, parse ``sys.argv`` and validate counts.

    Returns:
        The parsed arguments. ``--train-episodes`` and ``--eval-episodes`` are
        guaranteed non-negative, and ``--log-every`` is guaranteed >= 1.
    """
    parser = argparse.ArgumentParser(
        description="Train one RL agent per protocol set (different presets / specs)"
    )
    parser.add_argument(
        "--workflows",
        nargs="+",
        default=["pcr-amplification", "elisa-readout"],
        help="Workflow IDs to train (each gets its own agent and checkpoint)",
    )
    parser.add_argument("--train-episodes", type=int, default=1500)
    parser.add_argument("--eval-episodes", type=int, default=50)
    parser.add_argument("--lr", type=float, default=3e-3)
    parser.add_argument("--max-trials", type=int, default=4)
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--eval-seed",
        type=int,
        default=999_000,
        help="Base seed for evaluation episodes (episode i uses eval_seed + i)",
    )
    parser.add_argument(
        "--log-every",
        type=int,
        default=200,
        help="Print training progress every N episodes (and on the last one)",
    )
    args = parser.parse_args()

    # Negative counts would otherwise silently skip training or produce a
    # meaningless eval rate; a zero log interval would crash on `ep % 0`.
    if args.train_episodes < 0:
        parser.error("--train-episodes must be >= 0")
    if args.eval_episodes < 0:
        parser.error("--eval-episodes must be >= 0")
    if args.log_every < 1:
        parser.error("--log-every must be >= 1")
    return args


def _train_and_evaluate(
    workflow_id: str, args: argparse.Namespace, checkpoint_dir: Path
) -> None:
    """Train, checkpoint and evaluate a fresh agent for one workflow.

    Args:
        workflow_id: Workflow ID passed to ``get_spec_for_workflow``.
        args: Parsed CLI arguments (see ``_parse_args``).
        checkpoint_dir: Existing directory that receives ``<workflow_id>.pt``.
    """
    spec = get_spec_for_workflow(workflow_id)
    env = LabEnv(spec=spec)
    # Always release the environment, even if training/saving/eval raises.
    try:
        agent = ReinforceAgent(
            lr=args.lr,
            max_trials=args.max_trials,
            spec=spec,
        )

        print(f"\n{'='*60}")
        print(f" Training for protocol: {workflow_id} (presets={spec.num_presets}, obs_dim={spec.obs_dim})")
        print("=" * 60)

        # Training episode `ep` uses seed `args.seed + ep` (episodes are 1-based).
        for ep in range(1, args.train_episodes + 1):
            result = agent.run_episode(env, seed=args.seed + ep, train=True)
            if ep % args.log_every == 0 or ep == args.train_episodes:
                print(f" Episode {ep:5d} | reward: {result['reward']:7.1f} | success: {result['success']}")

        # Checkpoint before evaluation so a failed eval does not lose training.
        checkpoint_path = checkpoint_dir / f"{workflow_id}.pt"
        agent.save(str(checkpoint_path))
        print(f" Saved checkpoint: {checkpoint_path}")

        # Previously this divided by zero when --eval-episodes was 0.
        if args.eval_episodes == 0:
            print(" Eval skipped (--eval-episodes 0)")
            return

        # Evaluation seeds start at a separate base (default 999_000) so they
        # normally stay clear of the training seeds.
        successes = 0
        for i in range(args.eval_episodes):
            r = agent.run_episode(env, seed=args.eval_seed + i, train=False)
            successes += r["success"]
        print(f" Eval success rate: {successes / args.eval_episodes:.0%}")
    finally:
        env.close()


def main() -> None:
    """Train one REINFORCE agent per requested workflow (CLI entry point).

    For each ID in ``--workflows`` this builds the workflow's spec, an
    environment and a fresh agent, then runs ``--train-episodes`` training
    episodes. It saves the agent to ``<checkpoint-dir>/<workflow_id>.pt`` and
    reports the success rate over ``--eval-episodes`` evaluation episodes.
    """
    args = _parse_args()

    checkpoint_dir = Path(args.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for workflow_id in args.workflows:
        _train_and_evaluate(workflow_id, args, checkpoint_dir)

    print("\nDone. Use each checkpoint with LabEnv(spec=<same_spec>) and ReinforceAgent(spec=spec).load(path).")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|