rtferraz committed on
Commit
b1bb14c
·
verified ·
1 Parent(s): 631e559

v4 notebook: fix dtype Half/BFloat16 mismatch (explicit bf16), fix tied embeddings path, fix max_length warning

Browse files
Files changed (1) hide show
  1. notebooks/v4_instruct_grpo.ipynb +1 -1
notebooks/v4_instruct_grpo.ipynb CHANGED
@@ -51,7 +51,7 @@
51
  "execution_count": null,
52
  "metadata": {},
53
  "outputs": [],
54
- "source": "from unsloth import FastLanguageModel\n\nprint(\"Loading model...\")\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_ID,\n max_seq_length=MAX_SEQ_LENGTH,\n load_in_4bit=True,\n dtype=None, # auto-detect\n)\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# LoRA ADAPTER β€” ADR-002 Β§9: r=16, Ξ±=32\n# ═══════════════════════════════════════════════════════════════════════════════\nmodel = FastLanguageModel.get_peft_model(\n model,\n r=LORA_R,\n lora_alpha=LORA_ALPHA,\n target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\"],\n lora_dropout=0,\n bias=\"none\",\n use_gradient_checkpointing=\"unsloth\",\n random_state=42,\n)\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# CRITICAL OVERRIDES β€” generation_config ships with values that destroy GRPO\n# Source: Polygl0t/Tucano2-qwen-0.5B-Instruct/generation_config.json\n# temperature: 0.1 β†’ override to 1.0\n# repetition_penalty: 1.2 β†’ override to 1.0\n# use_cache: false β†’ override to true\n# ═══════════════════════════════════════════════════════════════════════════════\n\nmodel.config.use_cache = True\nmodel.generation_config.use_cache = True\nmodel.generation_config.temperature = TEMPERATURE\nmodel.generation_config.repetition_penalty = 1.0 # CRITICAL: 1.2 suppresses diversity\nmodel.generation_config.do_sample = True\nmodel.generation_config.top_k = 0 # disable top-k β€” let temperature control diversity\nmodel.generation_config.top_p = 1.0 # disable top-p\n\n# Pad token\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\nprint(f\"βœ“ Model loaded on {model.device}\")\nprint(f\" use_cache: {model.config.use_cache}\")\nprint(f\" temperature: {model.generation_config.temperature}\")\nprint(f\" repetition_penalty: {model.generation_config.repetition_penalty}\")\nprint(f\" top_k: 
{model.generation_config.top_k}\")\nprint(f\" Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M\")\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# TIED EMBEDDINGS CHECK β€” ADR-002 Decision 4\n# Source: config.json has \"tie_word_embeddings\": true\n# After LoRA patching, verify lm_head and embed_tokens still share weights.\n# ═══════════════════════════════════════════════════════════════════════════════\n\ntry:\n lm_ptr = model.lm_head.weight.data_ptr()\n embed_ptr = model.model.embed_tokens.weight.data_ptr()\n tied = lm_ptr == embed_ptr\n print(f\" Tied embeddings intact: {tied}\")\n if not tied:\n print(\" ⚠️ WARNING: Tied embeddings broken after LoRA patching. May affect output head gradients.\")\nexcept AttributeError as e:\n print(f\" ⚠️ Could not check tied embeddings: {e}\")"
55
  },
56
  {
57
  "cell_type": "markdown",
 
51
  "execution_count": null,
52
  "metadata": {},
53
  "outputs": [],
54
+ "source": "from unsloth import FastLanguageModel\n\nprint(\"Loading model...\")\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_ID,\n max_seq_length=MAX_SEQ_LENGTH,\n load_in_4bit=True,\n dtype=torch.bfloat16, # explicit bf16 β€” dtype=None can cause Half/BFloat16 mismatch in Unsloth LoRA kernels\n)\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# LoRA ADAPTER β€” ADR-002 Β§9: r=16, Ξ±=32\n# ═══════════════════════════════════════════════════════════════════════════════\nmodel = FastLanguageModel.get_peft_model(\n model,\n r=LORA_R,\n lora_alpha=LORA_ALPHA,\n target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\"],\n lora_dropout=0,\n bias=\"none\",\n use_gradient_checkpointing=\"unsloth\",\n random_state=42,\n)\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# CRITICAL OVERRIDES β€” generation_config ships with values that destroy GRPO\n# Source: Polygl0t/Tucano2-qwen-0.5B-Instruct/generation_config.json\n# temperature: 0.1 β†’ override to 1.0\n# repetition_penalty: 1.2 β†’ override to 1.0\n# use_cache: false β†’ override to true\n# ═══════════════════════════════════════════════════════════════════════════════\n\nmodel.config.use_cache = True\nmodel.generation_config.use_cache = True\nmodel.generation_config.temperature = TEMPERATURE\nmodel.generation_config.repetition_penalty = 1.0 # CRITICAL: 1.2 suppresses diversity\nmodel.generation_config.do_sample = True\nmodel.generation_config.top_k = 0 # disable top-k β€” let temperature control diversity\nmodel.generation_config.top_p = 1.0 # disable top-p\nmodel.generation_config.max_length = None # remove conflict with max_new_tokens\n\n# Pad token\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\nprint(f\"βœ“ Model loaded on {model.device}\")\nprint(f\" use_cache: {model.config.use_cache}\")\nprint(f\" temperature: 
{model.generation_config.temperature}\")\nprint(f\" repetition_penalty: {model.generation_config.repetition_penalty}\")\nprint(f\" top_k: {model.generation_config.top_k}\")\nprint(f\" Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M\")\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# TIED EMBEDDINGS CHECK β€” ADR-002 Decision 4\n# Source: config.json has \"tie_word_embeddings\": true\n# After LoRA patching, verify lm_head and embed_tokens still share weights.\n# ═══════════════════════════════════════════════════════════════════════════════\n\ntry:\n lm_ptr = model.base_model.model.lm_head.weight.data_ptr()\n embed_ptr = model.base_model.model.model.embed_tokens.weight.data_ptr()\n tied = lm_ptr == embed_ptr\n print(f\" Tied embeddings intact: {tied}\")\n if not tied:\n print(\" ⚠️ WARNING: Tied embeddings broken after LoRA patching. May affect output head gradients.\")\nexcept AttributeError as e:\n print(f\" ⚠️ Could not check tied embeddings: {e}\")"
55
  },
56
  {
57
  "cell_type": "markdown",