Spaces:
Running on A10G
Running on A10G
Commit ·
d051a6a
1
Parent(s): b3eb082
Fix: Update Unsloth installation and improve path handling in training script
Browse files- Dockerfile.train +1 -5
- scripts/train_grpo.py +15 -4
Dockerfile.train
CHANGED
|
@@ -27,13 +27,9 @@ RUN pip install --no-cache-dir \
|
|
| 27 |
xformers \
|
| 28 |
--index-url https://download.pytorch.org/whl/cu121
|
| 29 |
|
| 30 |
-
# Install Unsloth and
|
| 31 |
RUN pip install --no-cache-dir \
|
| 32 |
"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" \
|
| 33 |
-
trl \
|
| 34 |
-
peft \
|
| 35 |
-
accelerate \
|
| 36 |
-
bitsandbytes \
|
| 37 |
datasets \
|
| 38 |
wandb \
|
| 39 |
matplotlib \
|
|
|
|
| 27 |
xformers \
|
| 28 |
--index-url https://download.pytorch.org/whl/cu121
|
| 29 |
|
| 30 |
+
# Install Unsloth and let it resolve its own compatible TRL/PEFT stack.
|
| 31 |
RUN pip install --no-cache-dir \
|
| 32 |
"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" \
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
datasets \
|
| 34 |
wandb \
|
| 35 |
matplotlib \
|
scripts/train_grpo.py
CHANGED
|
@@ -10,9 +10,11 @@ from datasets import Dataset, load_dataset
|
|
| 10 |
from trl import GRPOConfig, GRPOTrainer
|
| 11 |
from unsloth import FastLanguageModel, PatchFastRL
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
from commitguard_env.parse_action import parse_action
|
| 17 |
from commitguard_env.reward import compute_reward
|
| 18 |
|
|
@@ -23,7 +25,6 @@ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")
|
|
| 23 |
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/commitguard-llama-3b-grpo")
|
| 24 |
WANDB_PROJECT = os.getenv("WANDB_PROJECT", "commitguard")
|
| 25 |
|
| 26 |
-
REPO_ROOT = Path(__file__).resolve().parent.parent
|
| 27 |
CWE_KEYWORDS_PATH = REPO_ROOT / "data" / "cwe_keywords.json"
|
| 28 |
CWE_KEYWORDS: dict[str, list[str]] = {}
|
| 29 |
if CWE_KEYWORDS_PATH.exists():
|
|
@@ -100,6 +101,16 @@ def main():
|
|
| 100 |
ap.add_argument("--hub-model-id", type=str, default="inmodel-labs/commitguard-llama-3b")
|
| 101 |
args = ap.parse_args()
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
if not args.no_wandb:
|
| 104 |
wandb.init(project=WANDB_PROJECT, name=f"grpo-{MODEL_NAME.split('/')[-1]}-run1")
|
| 105 |
|
|
|
|
| 10 |
from trl import GRPOConfig, GRPOTrainer
|
| 11 |
from unsloth import FastLanguageModel, PatchFastRL
|
| 12 |
|
| 13 |
+
REPO_ROOT = Path(__file__).resolve().parent.parent
|
| 14 |
+
if str(REPO_ROOT) not in sys.path:
|
| 15 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 16 |
+
|
| 17 |
+
from agent_prompt import SYSTEM_PROMPT
|
| 18 |
from commitguard_env.parse_action import parse_action
|
| 19 |
from commitguard_env.reward import compute_reward
|
| 20 |
|
|
|
|
| 25 |
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/commitguard-llama-3b-grpo")
|
| 26 |
WANDB_PROJECT = os.getenv("WANDB_PROJECT", "commitguard")
|
| 27 |
|
|
|
|
| 28 |
CWE_KEYWORDS_PATH = REPO_ROOT / "data" / "cwe_keywords.json"
|
| 29 |
CWE_KEYWORDS: dict[str, list[str]] = {}
|
| 30 |
if CWE_KEYWORDS_PATH.exists():
|
|
|
|
| 101 |
ap.add_argument("--hub-model-id", type=str, default="inmodel-labs/commitguard-llama-3b")
|
| 102 |
args = ap.parse_args()
|
| 103 |
|
| 104 |
+
if args.num_generations < 2:
|
| 105 |
+
raise ValueError("--num-generations must be at least 2 for GRPO")
|
| 106 |
+
effective_batch = args.batch_size * args.grad_accum
|
| 107 |
+
if effective_batch % args.num_generations != 0:
|
| 108 |
+
raise ValueError(
|
| 109 |
+
"For single-process GRPO training, --batch-size * --grad-accum "
|
| 110 |
+
f"must be divisible by --num-generations; got {args.batch_size} * "
|
| 111 |
+
f"{args.grad_accum} = {effective_batch}, num_generations={args.num_generations}."
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
if not args.no_wandb:
|
| 115 |
wandb.init(project=WANDB_PROJECT, name=f"grpo-{MODEL_NAME.split('/')[-1]}-run1")
|
| 116 |
|