#!/usr/bin/env python3
"""
QLoRA training script for Qwen2.5-7B on glaive-function-calling-v2.
Multi-GPU version (e.g., 4x L40S).

Usage:
    accelerate launch --num_processes 4 train_multi_gpu.py
"""
import os
import time
from datetime import datetime

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from transformers.trainer_callback import TrainerCallback
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTConfig, SFTTrainer
# ============================================================
# Configuration
# ============================================================
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
OUTPUT_MODEL_ID = "hajimemat/qwen2.5-7b-glaive-fc-lora"
DATASET_NAME = "glaiveai/glaive-function-calling-v2"
CHECKPOINT_DIR = "./checkpoints"
FINAL_OUTPUT_DIR = "./output/final"
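# NOTE: pushing to OUTPUT_MODEL_ID at the end requires Hugging Face
# authentication (e.g., `huggingface-cli login` or an HF_TOKEN env var)
# and write access to that repo.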

# ============================================================
# QLoRA quantization config
# ============================================================
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
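# NF4 stores the frozen base weights at ~0.5 bytes/param (roughly 4-5 GB
# for the 7B model per GPU), dequantizing to bfloat16 on the fly for the
# matmuls. Double quantization also quantizes the quantization constants
# themselves for a further small memory saving.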

# ============================================================
# LoRA config
# ============================================================
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    bias="none",
    task_type="CAUSAL_LM",
)
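# Adapters cover every attention and MLP projection. The LoRA update is
# scaled by lora_alpha / r = 16 / 64 = 0.25, so the effective step on the
# adapted weights stays modest despite the high rank.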

# ============================================================
# Custom callback
# ============================================================
class VerboseLoggingCallback(TrainerCallback):
    """Prints progress, loss, learning rate, elapsed time, and ETA."""

    def __init__(self):
        self.start_time = None

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()
        if state.is_world_process_zero:
            print("\n" + "=" * 70)
            print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training started")
            print(f"  Total steps: {state.max_steps}")
            print(f"  Num GPUs: {args.world_size}")
            print(f"  Per device batch: {args.per_device_train_batch_size}")
            print(f"  Gradient accum: {args.gradient_accumulation_steps}")
            print(f"  Effective batch: {args.per_device_train_batch_size * args.gradient_accumulation_steps * args.world_size}")
            print("=" * 70 + "\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None or not state.is_world_process_zero:
            return
        elapsed = time.time() - self.start_time
        elapsed_str = time.strftime("%H:%M:%S", time.gmtime(elapsed))
        progress = state.global_step / state.max_steps * 100 if state.max_steps > 0 else 0
        if state.global_step > 0:
            time_per_step = elapsed / state.global_step
            remaining_steps = state.max_steps - state.global_step
            eta_str = time.strftime("%H:%M:%S", time.gmtime(time_per_step * remaining_steps))
        else:
            eta_str = "calculating..."
        # Format conditionally *outside* the f-string: a format spec like
        # `{loss:.4f if isinstance(loss, float) else loss}` is not valid
        # Python and raises ValueError at runtime.
        loss = logs.get("loss", "N/A")
        lr = logs.get("learning_rate", "N/A")
        loss_str = f"{loss:.4f}" if isinstance(loss, float) else str(loss)
        lr_str = f"{lr:.2e}" if isinstance(lr, float) else str(lr)
        print(f"[{datetime.now().strftime('%H:%M:%S')}] "
              f"Step {state.global_step}/{state.max_steps} ({progress:.1f}%) | "
              f"Loss: {loss_str} | LR: {lr_str} | "
              f"Elapsed: {elapsed_str} | ETA: {eta_str}")

    def on_save(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            print(f"\n[{datetime.now().strftime('%H:%M:%S')}] "
                  f"💾 Checkpoint saved at step {state.global_step}\n")

    def on_train_end(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            total_time = time.time() - self.start_time
            total_str = time.strftime("%H:%M:%S", time.gmtime(total_time))
            print("\n" + "=" * 70)
            print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training completed!")
            print(f"  Total time: {total_str}")
            print("=" * 70 + "\n")

# ============================================================
# Dataset conversion
# ============================================================
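# glaive-function-calling-v2 stores each sample as a raw transcript string,
# e.g. "USER: ... ASSISTANT: ... <functioncall> ... FUNCTION RESPONSE: ...".
# The parser below splits turns on "USER:"/"ASSISTANT:" markers; lines with
# other prefixes (such as function responses) are folded into the turn in
# progress rather than starting a new one.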
def convert_glaive_to_chatml(example: dict) -> dict:
    """Flatten one glaive transcript into a single ChatML-formatted string."""
    parts = []
    if example.get("system"):
        parts.append(f"<|im_start|>system\n{example['system']}<|im_end|>")
    chat = example.get("chat", "")
    if chat:
        current_role = None
        current_content = []
        for line in chat.split("\n"):
            line = line.strip()
            if line.startswith("USER:"):
                # Flush the previous turn before starting a new one.
                if current_role and current_content:
                    content = "\n".join(current_content).strip()
                    if content:
                        parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
                current_role = "user"
                current_content = [line[len("USER:"):].strip()]
            elif line.startswith("ASSISTANT:"):
                if current_role and current_content:
                    content = "\n".join(current_content).strip()
                    if content:
                        parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
                current_role = "assistant"
                current_content = [line[len("ASSISTANT:"):].strip()]
            elif current_role:
                # Continuation lines stay attached to the current turn.
                current_content.append(line)
        # Flush the final turn.
        if current_role and current_content:
            content = "\n".join(current_content).strip()
            if content:
                parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
    return {"text": "\n".join(parts)}

def load_and_prepare_dataset():
    print(f"\nLoading dataset: {DATASET_NAME}")
    dataset = load_dataset(DATASET_NAME, split="train")
    print(f"Original size: {len(dataset)} examples")
    dataset = dataset.map(
        convert_glaive_to_chatml,
        remove_columns=dataset.column_names,
        num_proc=4,
        desc="Converting",
    )
    # Drop near-empty conversions (e.g., samples that failed to parse).
    dataset = dataset.filter(lambda x: len(x["text"]) > 50)
    print(f"After filtering: {len(dataset)} examples")
    dataset = dataset.shuffle(seed=42)
    split = dataset.train_test_split(test_size=0.02, seed=42)
    print(f"Train: {len(split['train'])}, Test: {len(split['test'])}")
    return split
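# NOTE: glaive-function-calling-v2 has roughly ~113k examples, so a 2% test
# split leaves a few thousand held-out samples for the periodic evals below.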

# ============================================================
# Training parameters (multi-GPU optimized)
# ============================================================
training_args = SFTConfig(
    output_dir=CHECKPOINT_DIR,
    num_train_epochs=2,
    # Multi-GPU: the L40S has 48 GB of VRAM, so the per-device batch can be large.
    per_device_train_batch_size=8,   # per GPU (L40S, 48 GB)
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,   # effective batch: 8 * 2 * 4 GPUs = 64
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    fp16=False,
    bf16=True,
    max_grad_norm=0.3,
    logging_steps=10,
    save_steps=500,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=500,
    report_to="none",
    group_by_length=False,  # redundant with packing: packed sequences share one length
    gradient_checkpointing=True,
    # Non-reentrant checkpointing avoids DDP "marked ready twice" errors.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Multi-GPU settings
    ddp_find_unused_parameters=False,
    dataloader_num_workers=4,
    save_safetensors=True,
    # SFT-specific options (moved here from the SFTTrainer call; with
    # SFTConfig these belong on the config object).
    max_seq_length=2048,
    packing=True,
    dataset_text_field="text",
)
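# NOTE: SFTConfig keyword names have shifted across TRL releases (e.g.,
# max_seq_length vs. max_length); if construction raises a TypeError,
# check the field names for your installed TRL version.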

# ============================================================
# Main
# ============================================================
def main():
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    is_main = local_rank == 0

    if is_main:
        print("\n" + "=" * 70)
        print("  Qwen2.5-7B + glaive-function-calling-v2 QLoRA Training")
        print("  Multi-GPU Version")
        print("=" * 70)
        print(f"Start: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"GPUs available: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
        print("=" * 70 + "\n")

    # Dataset
    dataset = load_and_prepare_dataset()

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    tokenizer.padding_side = "right"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
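    # Qwen2.5 tokenizers usually ship with a pad token already set
    # (<|endoftext|>), so the eos fallback above is just a safety net.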

    # Model
    if is_main:
        print(f"\nLoading model: {BASE_MODEL}")
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map={"": local_rank},  # pin the whole model to this rank's GPU
        attn_implementation="sdpa",
        trust_remote_code=True,
    )
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    if is_main:
        model.print_trainable_parameters()

    # Trainer
    # The model is already a PeftModel (wrapped above), so peft_config is
    # not passed again; SFT-specific options live on the SFTConfig.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        args=training_args,
        processing_class=tokenizer,
        callbacks=[VerboseLoggingCallback()],
    )
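    # With packing=True, TRL concatenates tokenized examples into fixed
    # 2048-token sequences, so the reported step count reflects packed
    # sequences rather than raw dataset rows.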

    # Resume from the latest checkpoint if one exists
    resume_from = None
    if os.path.exists(CHECKPOINT_DIR):
        checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith("checkpoint-")]
        if checkpoints:
            latest = max(checkpoints, key=lambda x: int(x.split("-")[1]))
            resume_from = os.path.join(CHECKPOINT_DIR, latest)
            if is_main:
                print(f"\n📂 Resuming from: {resume_from}")

    # Train
    trainer.train(resume_from_checkpoint=resume_from)

    # Save (main process only)
    if is_main:
        print(f"\nSaving to {FINAL_OUTPUT_DIR}...")
        trainer.save_model(FINAL_OUTPUT_DIR)
        tokenizer.save_pretrained(FINAL_OUTPUT_DIR)
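        # save_model on a PEFT-wrapped model writes only the LoRA adapter
        # (adapter_model.safetensors + adapter_config.json), not the 7B base
        # weights; load it later with PeftModel.from_pretrained on the base.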
| print(f"\nUploading to: {OUTPUT_MODEL_ID}") | |
| try: | |
| trainer.model.push_to_hub(OUTPUT_MODEL_ID, private=True) | |
| tokenizer.push_to_hub(OUTPUT_MODEL_ID, private=True) | |
| print(f"✅ Uploaded: https://huggingface.co/{OUTPUT_MODEL_ID}") | |
| except Exception as e: | |
| print(f"⚠️ Upload failed: {e}") | |
| print("\n🎉 Training complete!") | |

if __name__ == "__main__":
    main()