Nitishkumar-ai commited on
Commit
d051a6a
·
1 Parent(s): b3eb082

Fix: Update Unsloth installation and improve path handling in training script

Browse files
Files changed (2) hide show
  1. Dockerfile.train +1 -5
  2. scripts/train_grpo.py +15 -4
Dockerfile.train CHANGED
@@ -27,13 +27,9 @@ RUN pip install --no-cache-dir \
27
  xformers \
28
  --index-url https://download.pytorch.org/whl/cu121
29
 
30
- # Install Unsloth and other training dependencies
31
  RUN pip install --no-cache-dir \
32
  "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" \
33
- trl \
34
- peft \
35
- accelerate \
36
- bitsandbytes \
37
  datasets \
38
  wandb \
39
  matplotlib \
 
27
  xformers \
28
  --index-url https://download.pytorch.org/whl/cu121
29
 
30
+ # Install Unsloth and let it resolve its own compatible TRL/PEFT stack.
31
  RUN pip install --no-cache-dir \
32
  "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" \
 
 
 
 
33
  datasets \
34
  wandb \
35
  matplotlib \
scripts/train_grpo.py CHANGED
@@ -10,9 +10,11 @@ from datasets import Dataset, load_dataset
10
  from trl import GRPOConfig, GRPOTrainer
11
  from unsloth import FastLanguageModel, PatchFastRL
12
 
13
- sys.path.insert(0, str(Path(__file__).resolve().parent))
14
- sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
15
- from agent_prompt import SYSTEM_PROMPT, get_agent_prompt
 
 
16
  from commitguard_env.parse_action import parse_action
17
  from commitguard_env.reward import compute_reward
18
 
@@ -23,7 +25,6 @@ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")
23
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/commitguard-llama-3b-grpo")
24
  WANDB_PROJECT = os.getenv("WANDB_PROJECT", "commitguard")
25
 
26
- REPO_ROOT = Path(__file__).resolve().parent.parent
27
  CWE_KEYWORDS_PATH = REPO_ROOT / "data" / "cwe_keywords.json"
28
  CWE_KEYWORDS: dict[str, list[str]] = {}
29
  if CWE_KEYWORDS_PATH.exists():
@@ -100,6 +101,16 @@ def main():
100
  ap.add_argument("--hub-model-id", type=str, default="inmodel-labs/commitguard-llama-3b")
101
  args = ap.parse_args()
102
 
 
 
 
 
 
 
 
 
 
 
103
  if not args.no_wandb:
104
  wandb.init(project=WANDB_PROJECT, name=f"grpo-{MODEL_NAME.split('/')[-1]}-run1")
105
 
 
10
  from trl import GRPOConfig, GRPOTrainer
11
  from unsloth import FastLanguageModel, PatchFastRL
12
 
13
+ REPO_ROOT = Path(__file__).resolve().parent.parent
14
+ if str(REPO_ROOT) not in sys.path:
15
+ sys.path.insert(0, str(REPO_ROOT))
16
+
17
+ from agent_prompt import SYSTEM_PROMPT
18
  from commitguard_env.parse_action import parse_action
19
  from commitguard_env.reward import compute_reward
20
 
 
25
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/commitguard-llama-3b-grpo")
26
  WANDB_PROJECT = os.getenv("WANDB_PROJECT", "commitguard")
27
 
 
28
  CWE_KEYWORDS_PATH = REPO_ROOT / "data" / "cwe_keywords.json"
29
  CWE_KEYWORDS: dict[str, list[str]] = {}
30
  if CWE_KEYWORDS_PATH.exists():
 
101
  ap.add_argument("--hub-model-id", type=str, default="inmodel-labs/commitguard-llama-3b")
102
  args = ap.parse_args()
103
 
104
+ if args.num_generations < 2:
105
+ raise ValueError("--num-generations must be at least 2 for GRPO")
106
+ effective_batch = args.batch_size * args.grad_accum
107
+ if effective_batch % args.num_generations != 0:
108
+ raise ValueError(
109
+ "For single-process GRPO training, --batch-size * --grad-accum "
110
+ f"must be divisible by --num-generations; got {args.batch_size} * "
111
+ f"{args.grad_accum} = {effective_batch}, num_generations={args.num_generations}."
112
+ )
113
+
114
  if not args.no_wandb:
115
  wandb.init(project=WANDB_PROJECT, name=f"grpo-{MODEL_NAME.split('/')[-1]}-run1")
116