| | import gradio as gr |
| | import os |
| | import torch |
| |
|
# Shared state between the training worker and the UI: "running" flags an
# in-flight job; "log" accumulates the text rendered in the Training Log box.
training_status = {"running": False, "log": ""}
| |
|
def run_training(
    base_model: str,
    dataset_id: str,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    lora_r: int,
    output_repo: str,
    progress=gr.Progress(),  # gradio's documented injection pattern, not a mutable-default bug
):
    """Fine-tune ``base_model`` on ``dataset_id`` with 4-bit QLoRA and push to the Hub.

    Args:
        base_model: Hub id of the causal LM to fine-tune.
        dataset_id: Hub id of the training dataset (``train`` split is used).
        epochs: Number of training epochs.
        batch_size: Per-device train batch size (grad accumulation x4 on top).
        learning_rate: Peak learning rate (cosine schedule, 10% warmup).
        lora_r: LoRA rank; ``lora_alpha`` is fixed at ``2 * lora_r``.
        output_repo: Hub repo id the trained model is pushed to.
        progress: Gradio progress tracker, injected by the UI at call time.

    Returns:
        The accumulated log text (also mirrored in ``training_status["log"]``).
    """
    # No ``global`` needed: training_status is only mutated, never rebound.
    training_status["running"] = True
    training_status["log"] = ""

    def log(msg):
        # Append to the shared log (polled by the UI) and mirror to stdout.
        training_status["log"] += msg + "\n"
        print(msg)

    try:
        log("=" * 50)
        log("Agent Zero Music Workflow Trainer")
        log("Intuition Labs • terminals.tech")
        log("=" * 50)

        progress(0.05, desc="Installing dependencies...")
        log("\n[1/6] Installing dependencies...")
        # NOTE(review): runtime pip via os.system silently ignores failures;
        # prefer pinning these in the Space's requirements.txt.
        os.system("pip install -q transformers trl peft datasets accelerate bitsandbytes")

        progress(0.1, desc="Loading libraries...")
        log("[2/6] Loading libraries...")

        # Imported lazily: these packages may only exist after the install above.
        from datasets import load_dataset
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        from trl import SFTTrainer, SFTConfig

        progress(0.15, desc="Loading tokenizer...")
        log(f"[3/6] Loading tokenizer: {base_model}")
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        if tokenizer.pad_token is None:
            # Many chat LMs ship without a pad token; reuse EOS for padding.
            tokenizer.pad_token = tokenizer.eos_token

        progress(0.2, desc="Loading model with 4-bit quantization...")
        log("[4/6] Loading model with 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        model = prepare_model_for_kbit_training(model)

        log(f"[4/6] Applying LoRA (r={lora_r})...")
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_r * 2,
            lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

        progress(0.3, desc="Loading dataset...")
        log(f"[5/6] Loading dataset: {dataset_id}")
        dataset = load_dataset(dataset_id, split="train")

        def format_example(example):
            # Normalize heterogeneous rows into a single "text" field (ChatML).
            if "instruction" in example and "response" in example:
                return {"text": f"<|im_start|>user\n{example['instruction']}<|im_end|>\n<|im_start|>assistant\n{example['response']}<|im_end|>"}
            elif "text" in example:
                return {"text": example["text"]}
            else:
                # Fallback: concatenate every string-valued column.
                return {"text": " ".join(str(v) for v in example.values() if isinstance(v, str))}

        dataset = dataset.map(format_example)
        log(f"Dataset size: {len(dataset)} examples")

        progress(0.4, desc="Setting up trainer...")
        log(f"[6/6] Starting training: {epochs} epochs, batch={batch_size}, lr={learning_rate}")

        sft_config = SFTConfig(
            output_dir="./outputs",
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            learning_rate=learning_rate,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
            logging_steps=10,
            save_steps=100,
            bf16=True,
            gradient_checkpointing=True,
            push_to_hub=True,
            hub_model_id=output_repo,
            hub_token=os.environ.get("HF_TOKEN"),
            max_length=4096,
            dataset_text_field="text",
        )

        trainer = SFTTrainer(
            model=model,
            args=sft_config,
            train_dataset=dataset,
            processing_class=tokenizer,
        )

        log("\n" + "=" * 50)
        log("TRAINING STARTED")
        log("=" * 50)

        trainer.train()

        progress(0.95, desc="Pushing to Hub...")
        log("\nPushing model to Hub...")
        trainer.push_to_hub()

        progress(1.0, desc="Complete!")
        log("\n" + "=" * 50)
        log("TRAINING COMPLETE!")
        log(f"Model saved to: https://huggingface.co/{output_repo}")
        log("=" * 50)

        return training_status["log"]

    except Exception as e:
        # Surface the full traceback in the UI log instead of crashing the app.
        log(f"\nERROR: {str(e)}")
        import traceback
        log(traceback.format_exc())
        return training_status["log"]
    finally:
        # Reset on every exit path — success, handled error, or an exit the
        # except clause doesn't catch (KeyboardInterrupt/SystemExit) — so the
        # UI never stays stuck in "running" state.
        training_status["running"] = False
| |
|
# --- Gradio UI ------------------------------------------------------------
# Two-column layout: hyperparameter controls on the left, streaming training
# log on the right. ``demo`` is launched from the __main__ guard at file end.
with gr.Blocks(title="Agent Zero Trainer") as demo:
    gr.Markdown("""
    # Agent Zero Music Workflow Trainer
    **Intuition Labs** • terminals.tech

    Fine-tune models for coherent multi-context orchestration.
    Running on L40S GPU (48GB VRAM) - $1.80/hr
    """)

    with gr.Row():
        with gr.Column():
            # Training hyperparameters; defaults match the QLoRA setup in
            # run_training (7B base model, rank-16 LoRA, small batch).
            base_model = gr.Textbox(value="Qwen/Qwen2.5-7B-Instruct", label="Base Model")
            dataset_id = gr.Textbox(value="wheattoast11/agent-zero-training-data", label="Dataset ID")
            epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
            batch_size = gr.Slider(1, 8, value=2, step=1, label="Batch Size")
            learning_rate = gr.Number(value=2e-5, label="Learning Rate")
            lora_r = gr.Slider(8, 64, value=16, step=8, label="LoRA Rank")
            output_repo = gr.Textbox(value="wheattoast11/agent-zero-music-workflow", label="Output Repo")
            submit_btn = gr.Button("Start Training", variant="primary")

        with gr.Column():
            # Read-only textbox fed by run_training's returned log text.
            output = gr.Textbox(label="Training Log", lines=25, max_lines=50)

    # Wire the button to the trainer: input order must match run_training's
    # parameter order (progress is injected automatically by gradio).
    submit_btn.click(
        fn=run_training,
        inputs=[base_model, dataset_id, epochs, batch_size, learning_rate, lora_r, output_repo],
        outputs=output,
    )
| |
|
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()
| |
|