Spaces:
Sleeping
Sleeping
Commit ·
e5fe6f5
1
Parent(s): 1544ce8
feat: enhance SFT training process with new tokenization method, implement custom trainer class for loss computation, and update README with GRPO launcher details for Unsloth LoRA integration
Browse files- README.md +11 -4
- scripts/modal_train_grpo.py +57 -2
- scripts/modal_train_sft.py +57 -2
- tests/test_modal_scenario_cache_static.py +8 -1
README.md
CHANGED
|
@@ -335,13 +335,20 @@ reward metadata passes. The default SFT config trains the full dataset
|
|
| 335 |
(`--max-steps -1`) with bf16/tf32, LoRA rank 32, and Modal GPU fallback
|
| 336 |
`H200 -> H100 -> A100-80GB -> L40S`. TRL does not support packing or
|
| 337 |
assistant-only loss for the Gemma 4 vision-language loader, so both remain
|
| 338 |
-
disabled for this model.
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
|
|
|
|
|
|
|
|
|
| 342 |
|
| 343 |
Continue GRPO from the SFT LoRA:
|
| 344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
```bash
|
| 346 |
uv run --extra modal modal run --detach scripts/modal_train_grpo.py \
|
| 347 |
--initial-adapter-repo-id Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora \
|
|
|
|
| 335 |
(`--max-steps -1`) with bf16/tf32, LoRA rank 32, and Modal GPU fallback
|
| 336 |
`H200 -> H100 -> A100-80GB -> L40S`. TRL does not support packing or
|
| 337 |
assistant-only loss for the Gemma 4 vision-language loader, so both remain
|
| 338 |
+
disabled for this model. The script pre-tokenizes the small JSONL dataset
|
| 339 |
+
serially before constructing `SFTTrainer`, which avoids TRL multiprocessing
|
| 340 |
+
around the Gemma/Unsloth config object. It also uses the base Transformers loss
|
| 341 |
+
path to avoid a TRL entropy-metric incompatibility with Gemma 4 lazy logits. A
|
| 342 |
+
warm run for the 300-400 episode dataset should usually finish in about 20-60
|
| 343 |
+
minutes; first image or model-cache builds can push that closer to 45-90
|
| 344 |
+
minutes.
|
| 345 |
|
| 346 |
Continue GRPO from the SFT LoRA:
|
| 347 |
|
| 348 |
+
The GRPO launcher downloads the Hub adapter, attaches a matching trainable
|
| 349 |
+
Unsloth LoRA to Gemma 4, and then loads the adapter safetensors. This keeps the
|
| 350 |
+
SFT handoff compatible with Gemma 4's Unsloth linear wrappers.
|
| 351 |
+
|
| 352 |
```bash
|
| 353 |
uv run --extra modal modal run --detach scripts/modal_train_grpo.py \
|
| 354 |
--initial-adapter-repo-id Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora \
|
scripts/modal_train_grpo.py
CHANGED
|
@@ -1081,11 +1081,12 @@ def train_cybersecurity_owasp_grpo(
|
|
| 1081 |
trace_log_every = max(0, int(trace_log_every))
|
| 1082 |
|
| 1083 |
import torch
|
|
|
|
| 1084 |
from unsloth import FastVisionModel
|
| 1085 |
import transformers.utils.hub as transformers_hub
|
| 1086 |
from datasets import Dataset
|
| 1087 |
from huggingface_hub import snapshot_download, whoami
|
| 1088 |
-
from peft import
|
| 1089 |
from transformers import TrainerCallback
|
| 1090 |
from trl import GRPOConfig, GRPOTrainer, clone_chat_template
|
| 1091 |
try:
|
|
@@ -1869,7 +1870,61 @@ def train_cybersecurity_owasp_grpo(
|
|
| 1869 |
cache_volume.commit()
|
| 1870 |
if adapter_source:
|
| 1871 |
print(f"Loading initial SFT adapter for trainable GRPO continuation: {adapter_source}")
|
| 1872 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1873 |
if hasattr(model, "print_trainable_parameters"):
|
| 1874 |
model.print_trainable_parameters()
|
| 1875 |
else:
|
|
|
|
| 1081 |
trace_log_every = max(0, int(trace_log_every))
|
| 1082 |
|
| 1083 |
import torch
|
| 1084 |
+
from safetensors.torch import load_file as load_safetensors_file
|
| 1085 |
from unsloth import FastVisionModel
|
| 1086 |
import transformers.utils.hub as transformers_hub
|
| 1087 |
from datasets import Dataset
|
| 1088 |
from huggingface_hub import snapshot_download, whoami
|
| 1089 |
+
from peft import set_peft_model_state_dict
|
| 1090 |
from transformers import TrainerCallback
|
| 1091 |
from trl import GRPOConfig, GRPOTrainer, clone_chat_template
|
| 1092 |
try:
|
|
|
|
| 1870 |
cache_volume.commit()
|
| 1871 |
if adapter_source:
|
| 1872 |
print(f"Loading initial SFT adapter for trainable GRPO continuation: {adapter_source}")
|
| 1873 |
+
adapter_source_path = pathlib.Path(adapter_source)
|
| 1874 |
+
adapter_config_path = adapter_source_path / "adapter_config.json"
|
| 1875 |
+
if not adapter_config_path.exists():
|
| 1876 |
+
raise RuntimeError(f"Initial SFT adapter config not found: {adapter_config_path}")
|
| 1877 |
+
adapter_config = json.loads(adapter_config_path.read_text(encoding="utf-8"))
|
| 1878 |
+
adapter_rank = int(adapter_config.get("r") or lora_rank)
|
| 1879 |
+
adapter_alpha = int(adapter_config.get("lora_alpha") or adapter_rank * 2)
|
| 1880 |
+
adapter_target_modules = adapter_config.get("target_modules") or [
|
| 1881 |
+
"q_proj",
|
| 1882 |
+
"k_proj",
|
| 1883 |
+
"v_proj",
|
| 1884 |
+
"o_proj",
|
| 1885 |
+
"gate_proj",
|
| 1886 |
+
"up_proj",
|
| 1887 |
+
"down_proj",
|
| 1888 |
+
]
|
| 1889 |
+
adapter_target_modules = list(adapter_target_modules)
|
| 1890 |
+
print(
|
| 1891 |
+
"Attaching Unsloth LoRA before loading SFT weights: "
|
| 1892 |
+
f"rank={adapter_rank}, alpha={adapter_alpha}, targets={adapter_target_modules}"
|
| 1893 |
+
)
|
| 1894 |
+
model = model_api.get_peft_model(
|
| 1895 |
+
model,
|
| 1896 |
+
r=adapter_rank,
|
| 1897 |
+
target_modules=adapter_target_modules,
|
| 1898 |
+
lora_alpha=adapter_alpha,
|
| 1899 |
+
use_gradient_checkpointing="unsloth",
|
| 1900 |
+
random_state=3407,
|
| 1901 |
+
)
|
| 1902 |
+
adapter_weights_path = adapter_source_path / "adapter_model.safetensors"
|
| 1903 |
+
if not adapter_weights_path.exists():
|
| 1904 |
+
raise RuntimeError(f"Initial SFT adapter weights not found: {adapter_weights_path}")
|
| 1905 |
+
adapter_state = load_safetensors_file(str(adapter_weights_path), device="cpu")
|
| 1906 |
+
adapter_load_result = set_peft_model_state_dict(
|
| 1907 |
+
model,
|
| 1908 |
+
adapter_state,
|
| 1909 |
+
adapter_name="default",
|
| 1910 |
+
)
|
| 1911 |
+
unexpected_adapter_keys = sorted(
|
| 1912 |
+
key
|
| 1913 |
+
for key in getattr(adapter_load_result, "unexpected_keys", [])
|
| 1914 |
+
if "lora_" in key or "modules_to_save" in key
|
| 1915 |
+
)
|
| 1916 |
+
if unexpected_adapter_keys:
|
| 1917 |
+
raise RuntimeError(
|
| 1918 |
+
"Initial SFT adapter keys do not match the trainable Unsloth LoRA. "
|
| 1919 |
+
f"Unexpected adapter keys: {unexpected_adapter_keys[:10]}"
|
| 1920 |
+
)
|
| 1921 |
+
missing_lora_keys = sorted(
|
| 1922 |
+
key
|
| 1923 |
+
for key in getattr(adapter_load_result, "missing_keys", [])
|
| 1924 |
+
if "lora_" in key or "modules_to_save" in key
|
| 1925 |
+
)
|
| 1926 |
+
if missing_lora_keys:
|
| 1927 |
+
print(f"Missing LoRA keys while loading SFT adapter: {missing_lora_keys[:10]}")
|
| 1928 |
if hasattr(model, "print_trainable_parameters"):
|
| 1929 |
model.print_trainable_parameters()
|
| 1930 |
else:
|
scripts/modal_train_sft.py
CHANGED
|
@@ -373,8 +373,9 @@ def train_cybersecurity_owasp_sft(
|
|
| 373 |
) -> dict[str, Any]:
|
| 374 |
import inspect
|
| 375 |
|
| 376 |
-
from datasets import load_dataset
|
| 377 |
from huggingface_hub import snapshot_download
|
|
|
|
| 378 |
from trl import SFTConfig, SFTTrainer
|
| 379 |
try:
|
| 380 |
from trl.chat_template_utils import add_response_schema
|
|
@@ -454,6 +455,47 @@ def train_cybersecurity_owasp_sft(
|
|
| 454 |
except Exception as exc:
|
| 455 |
print(f"Tokenizer response schema add skipped: {exc!r}")
|
| 456 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
model = model_api.get_peft_model(
|
| 458 |
model,
|
| 459 |
r=lora_rank,
|
|
@@ -522,7 +564,20 @@ def train_cybersecurity_owasp_sft(
|
|
| 522 |
)
|
| 523 |
if skipped_trainer:
|
| 524 |
print(f"Skipping unsupported SFTTrainer keys: {skipped_trainer}")
|
| 525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
**{
|
| 527 |
key: value
|
| 528 |
for key, value in trainer_values.items()
|
|
|
|
| 373 |
) -> dict[str, Any]:
|
| 374 |
import inspect
|
| 375 |
|
| 376 |
+
from datasets import Dataset, load_dataset
|
| 377 |
from huggingface_hub import snapshot_download
|
| 378 |
+
from transformers import Trainer
|
| 379 |
from trl import SFTConfig, SFTTrainer
|
| 380 |
try:
|
| 381 |
from trl.chat_template_utils import add_response_schema
|
|
|
|
| 455 |
except Exception as exc:
|
| 456 |
print(f"Tokenizer response schema add skipped: {exc!r}")
|
| 457 |
|
| 458 |
+
def _tokenize_sft_split(split_name: str, split_dataset) -> Dataset:
|
| 459 |
+
tokenized_rows: list[dict[str, list[int]]] = []
|
| 460 |
+
total_rows = len(split_dataset)
|
| 461 |
+
for row_index, example in enumerate(split_dataset, start=1):
|
| 462 |
+
messages = example["messages"]
|
| 463 |
+
if isinstance(messages, str):
|
| 464 |
+
messages = json.loads(messages)
|
| 465 |
+
rendered = tokenizer.apply_chat_template(
|
| 466 |
+
messages,
|
| 467 |
+
tokenize=False,
|
| 468 |
+
add_generation_prompt=False,
|
| 469 |
+
)
|
| 470 |
+
try:
|
| 471 |
+
encoded = tokenizer(
|
| 472 |
+
rendered,
|
| 473 |
+
add_special_tokens=False,
|
| 474 |
+
truncation=True,
|
| 475 |
+
max_length=max_seq_length,
|
| 476 |
+
)
|
| 477 |
+
except TypeError:
|
| 478 |
+
encoded = tokenizer(
|
| 479 |
+
text=rendered,
|
| 480 |
+
add_special_tokens=False,
|
| 481 |
+
truncation=True,
|
| 482 |
+
max_length=max_seq_length,
|
| 483 |
+
)
|
| 484 |
+
input_ids = encoded["input_ids"]
|
| 485 |
+
if input_ids and isinstance(input_ids[0], list):
|
| 486 |
+
input_ids = input_ids[0]
|
| 487 |
+
input_ids = [int(token_id) for token_id in input_ids[:max_seq_length]]
|
| 488 |
+
if not input_ids:
|
| 489 |
+
raise RuntimeError(f"{split_name} row {row_index} produced no tokens.")
|
| 490 |
+
tokenized_rows.append({"input_ids": input_ids, "labels": list(input_ids)})
|
| 491 |
+
if row_index % 500 == 0 or row_index == total_rows:
|
| 492 |
+
print(f"Tokenized {split_name} rows: {row_index}/{total_rows}")
|
| 493 |
+
return Dataset.from_list(tokenized_rows)
|
| 494 |
+
|
| 495 |
+
dataset["train"] = _tokenize_sft_split("train", dataset["train"])
|
| 496 |
+
if has_validation:
|
| 497 |
+
dataset["validation"] = _tokenize_sft_split("validation", dataset["validation"])
|
| 498 |
+
|
| 499 |
model = model_api.get_peft_model(
|
| 500 |
model,
|
| 501 |
r=lora_rank,
|
|
|
|
| 564 |
)
|
| 565 |
if skipped_trainer:
|
| 566 |
print(f"Skipping unsupported SFTTrainer keys: {skipped_trainer}")
|
| 567 |
+
class CyberSecurityOWASPSFTTrainer(SFTTrainer):
|
| 568 |
+
def compute_loss(
|
| 569 |
+
self,
|
| 570 |
+
model,
|
| 571 |
+
inputs,
|
| 572 |
+
return_outputs: bool = False,
|
| 573 |
+
num_items_in_batch=None,
|
| 574 |
+
):
|
| 575 |
+
compute_loss_kwargs = {"return_outputs": return_outputs}
|
| 576 |
+
if "num_items_in_batch" in inspect.signature(Trainer.compute_loss).parameters:
|
| 577 |
+
compute_loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
| 578 |
+
return Trainer.compute_loss(self, model, inputs, **compute_loss_kwargs)
|
| 579 |
+
|
| 580 |
+
trainer = CyberSecurityOWASPSFTTrainer(
|
| 581 |
**{
|
| 582 |
key: value
|
| 583 |
for key, value in trainer_values.items()
|
tests/test_modal_scenario_cache_static.py
CHANGED
|
@@ -59,6 +59,10 @@ def test_modal_sft_defaults_match_300_episode_fast_handoff_plan():
|
|
| 59 |
assert '"packing": False' in source
|
| 60 |
assert '"packing_strategy": "bfd"' not in source
|
| 61 |
assert '"dataset_num_proc": None' in source
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
assert '"bf16": True' in source
|
| 63 |
assert '"tf32": True' in source
|
| 64 |
assert '"hub_strategy": "every_save"' in source
|
|
@@ -74,4 +78,7 @@ def test_modal_grpo_loads_sft_adapter_from_hub_as_trainable_lora():
|
|
| 74 |
assert "initial_adapter_repo_id" in source
|
| 75 |
assert "Downloading initial SFT adapter" in source
|
| 76 |
assert "snapshot_download(" in source
|
| 77 |
-
assert "
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
assert '"packing": False' in source
|
| 60 |
assert '"packing_strategy": "bfd"' not in source
|
| 61 |
assert '"dataset_num_proc": None' in source
|
| 62 |
+
assert "Dataset.from_list(tokenized_rows)" in source
|
| 63 |
+
assert "tokenizer.apply_chat_template" in source
|
| 64 |
+
assert "class CyberSecurityOWASPSFTTrainer(SFTTrainer)" in source
|
| 65 |
+
assert "Trainer.compute_loss(self, model, inputs" in source
|
| 66 |
assert '"bf16": True' in source
|
| 67 |
assert '"tf32": True' in source
|
| 68 |
assert '"hub_strategy": "every_save"' in source
|
|
|
|
| 78 |
assert "initial_adapter_repo_id" in source
|
| 79 |
assert "Downloading initial SFT adapter" in source
|
| 80 |
assert "snapshot_download(" in source
|
| 81 |
+
assert "Attaching Unsloth LoRA before loading SFT weights" in source
|
| 82 |
+
assert "load_safetensors_file(str(adapter_weights_path), device=\"cpu\")" in source
|
| 83 |
+
assert "set_peft_model_state_dict(" in source
|
| 84 |
+
assert "unexpected_adapter_keys" in source
|