| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Resume FLUX.2-klein-4B LoRA training from step 500 checkpoint. |
| Output: Limbicnation/pixel-art-lora |
| """ |
|
|
| import os |
| import sys |
| import torch |
| import torch.nn.functional as F |
| from pathlib import Path |
| from tqdm import tqdm |
| from PIL import Image |
| import numpy as np |
|
|
| |
| token = os.environ.get("HF_TOKEN") |
| if not token or token == "$HF_TOKEN": |
| print("ERROR: HF_TOKEN not set") |
| sys.exit(1) |
|
|
| os.environ["HF_TOKEN"] = token |
|
|
| |
from huggingface_hub import login, hf_hub_download, snapshot_download, create_repo, upload_file
from diffusers import FluxPipeline
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict
from safetensors.torch import load_file, save_file
from accelerate import Accelerator
|
|
| CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500" |
| DATASET_REPO = "Limbicnation/sprite-lora-training-data" |
| OUTPUT_REPO = "Limbicnation/pixel-art-lora" |
| BASE_MODEL = "black-forest-labs/FLUX.2-klein-4B" |
|
|
| def main(): |
| print("="*70) |
| print("π FLUX.2-klein-4B LoRA Training - Final") |
| print("="*70) |
| print(f"Base model: {BASE_MODEL}") |
| print(f"Output: {OUTPUT_REPO}") |
| print(f"Resume: Step 500 -> 1000") |
| |
| |
| print("\nπ Authenticating...") |
| login(token=token, add_to_git_credential=False) |
| print("β
Authenticated") |
| |
| |
| print("\nπ₯ Downloading checkpoint...") |
| os.makedirs("checkpoint", exist_ok=True) |
| hf_hub_download( |
| repo_id=CHECKPOINT_REPO, |
| filename="pytorch_lora_weights.safetensors", |
| repo_type="model", |
| local_dir="checkpoint", |
| token=token |
| ) |
| print("β
Checkpoint downloaded") |
| |
| |
| print("\nπ₯ Downloading dataset...") |
| snapshot_download( |
| repo_id=DATASET_REPO, |
| repo_type="dataset", |
| local_dir="data", |
| token=token |
| ) |
| image_files = list(Path("data").rglob("*.png")) |
| print(f"β
Dataset: {len(image_files)} images") |
| |
| |
| accelerator = Accelerator(gradient_accumulation_steps=4, mixed_precision="bf16") |
| device = accelerator.device |
| print(f"\nβοΈ Device: {device}") |
| |
| |
| print(f"\nπ₯ Loading {BASE_MODEL}...") |
| pipe = FluxPipeline.from_pretrained( |
| BASE_MODEL, |
| torch_dtype=torch.bfloat16, |
| token=token |
| ) |
| pipe.enable_model_cpu_offload() |
| print("β
Model loaded") |
| |
| |
| print("\nπ§ Applying LoRA (rank=64, alpha=128)...") |
| target_modules = [] |
| for i in range(19): |
| target_modules.extend([ |
| f"transformer_blocks.{i}.attn.to_q", |
| f"transformer_blocks.{i}.attn.to_k", |
| f"transformer_blocks.{i}.attn.to_v", |
| ]) |
| |
| lora_config = LoraConfig(r=64, lora_alpha=128, target_modules=target_modules, use_rslora=True) |
| pipe.transformer = get_peft_model(pipe.transformer, lora_config) |
| |
| |
| print("\nπ Loading checkpoint...") |
| state_dict = load_file("checkpoint/pytorch_lora_weights.safetensors") |
| set_peft_model_state_dict(pipe.transformer, state_dict) |
| print("β
Checkpoint loaded, resuming from step 500") |
| |
| global_step = 500 |
| |
| |
| print(f"\nπ€ Creating output repo...") |
| create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model", token=token) |
| |
| |
| trainable = [p for p in pipe.transformer.parameters() if p.requires_grad] |
| import bitsandbytes as bnb |
| optimizer = bnb.optim.AdamW8bit(trainable, lr=1e-4) |
| |
| |
| class Dataset(torch.utils.data.Dataset): |
| def __init__(self, root, res=512): |
| self.imgs = sorted(list(Path(root).rglob("*.png"))) |
| self.res = res |
| def __len__(self): return len(self.imgs) |
| def __getitem__(self, idx): |
| img = Image.open(self.imgs[idx]).convert("RGB").resize((self.res, self.res)) |
| img = torch.from_numpy(np.array(img)).permute(2,0,1).float()/255.0 * 2 - 1 |
| txt = self.imgs[idx].with_suffix(".txt") |
| cap = txt.read_text().strip() if txt.exists() else "" |
| return {"images": img, "captions": cap} |
| |
| dataset = Dataset("data/images") |
| dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True) |
| print(f"β
Dataset ready: {len(dataset)} images") |
| |
| |
| pipe.transformer, optimizer, dataloader = accelerator.prepare( |
| pipe.transformer, optimizer, dataloader |
| ) |
| |
| |
| print("\n" + "="*70) |
| print("ποΈ Training: Step 500 -> 1000") |
| print("="*70) |
| |
| pipe.transformer.train() |
| pbar = tqdm(total=1000, initial=global_step, desc="Training") |
| |
| while global_step < 1000: |
| for batch in dataloader: |
| with accelerator.accumulate(pipe.transformer): |
| imgs = batch["images"].to(device) |
| caps = [f"pixel art sprite, {c}" for c in batch["captions"]] |
| |
| with torch.no_grad(): |
| latents = pipe.vae.encode(imgs).latent_dist.sample() |
| noise = torch.randn_like(latents) |
| t = torch.rand(latents.shape[0], device=device) * 1000 |
| sigmas = t.view(-1,1,1,1) / 1000 |
| noisy = (1-sigmas)*latents + sigmas*noise |
| target = noise - latents |
| |
| with torch.no_grad(): |
| prompt_embeds = pipe.encode_prompt(caps)[0] |
| |
| output = pipe.transformer( |
| hidden_states=noisy, |
| timestep=t, |
| encoder_hidden_states=prompt_embeds, |
| return_dict=False |
| )[0] |
| |
| loss = F.mse_loss(output.float(), target.float()) |
| accelerator.backward(loss) |
| |
| if accelerator.sync_gradients: |
| accelerator.clip_grad_norm_(pipe.transformer.parameters(), 1.0) |
| |
| optimizer.step() |
| optimizer.zero_grad() |
| |
| if accelerator.sync_gradients: |
| global_step += 1 |
| pbar.update(1) |
| pbar.set_postfix({"loss": f"{loss.item():.4f}"}) |
| |
| if global_step % 500 == 0: |
| print(f"\nπΎ Saving checkpoint at step {global_step}...") |
| os.makedirs(f"output/step_{global_step}", exist_ok=True) |
| save_file( |
| get_peft_model_state_dict(accelerator.unwrap_model(pipe.transformer)), |
| f"output/step_{global_step}/pytorch_lora_weights.safetensors" |
| ) |
| upload_file( |
| path_or_fileobj=f"output/step_{global_step}/pytorch_lora_weights.safetensors", |
| path_in_repo=f"step_{global_step}/pytorch_lora_weights.safetensors", |
| repo_id=OUTPUT_REPO, |
| repo_type="model", |
| token=token |
| ) |
| print("β
Checkpoint saved") |
| |
| if global_step >= 1000: |
| break |
| |
| pbar.close() |
| |
| |
| print("\nπΎ Saving final model...") |
| os.makedirs("output/final", exist_ok=True) |
| save_file( |
| get_peft_model_state_dict(accelerator.unwrap_model(pipe.transformer)), |
| "output/final/pytorch_lora_weights.safetensors" |
| ) |
| upload_file( |
| path_or_fileobj="output/final/pytorch_lora_weights.safetensors", |
| path_in_repo="pytorch_lora_weights.safetensors", |
| repo_id=OUTPUT_REPO, |
| repo_type="model", |
| token=token |
| ) |
| |
| print("\n" + "="*70) |
| print("β
Training Complete!") |
| print("="*70) |
| print(f"\nπ€ Model: https://huggingface.co/{OUTPUT_REPO}") |
|
|
| if __name__ == "__main__": |
| main() |
|
|