# train.py
import os
import torch
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model
# 1️⃣ Pick your scheduler class
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
    AutoencoderKL,
)
from transformers import CLIPTextModel, CLIPTokenizer
# ─── 1) CONFIG ────────────────────────────────────────────────────────────────
DATA_DIR = os.getenv("DATA_DIR", "./data")
MODEL_DIR = os.getenv("MODEL_DIR", "./hidream-model")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")
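# Example invocation (illustrative only; these paths are placeholders, not
# paths taken from this repo):
#   DATA_DIR=/path/to/images OUTPUT_DIR=./lora-trained python train.py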
# ─── 2) DOWNLOAD OR VERIFY BASE MODEL ──────────────────────────────────────────
if not os.path.isdir(MODEL_DIR):
    MODEL_DIR = snapshot_download(
        repo_id="HiDream-ai/HiDream-I1-Dev",
        local_dir=MODEL_DIR,
    )
# ─── 3) LOAD EACH PIPELINE COMPONENT ──────────────────────────────────────────
# 3a) Scheduler
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    MODEL_DIR,
    subfolder="scheduler",
)
# 3b) VAE
vae = AutoencoderKL.from_pretrained(
    MODEL_DIR,
    subfolder="vae",
    torch_dtype=torch.float16,
).to("cuda")
# 3c) Text encoder + tokenizer
text_encoder = CLIPTextModel.from_pretrained(
    MODEL_DIR,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
).to("cuda")
tokenizer = CLIPTokenizer.from_pretrained(
    MODEL_DIR,
    subfolder="tokenizer",
)
# 3d) U-Net
unet = UNet2DConditionModel.from_pretrained(
    MODEL_DIR,
    subfolder="unet",
    torch_dtype=torch.float16,
).to("cuda")
# ─── 4) BUILD THE PIPELINE ────────────────────────────────────────────────────
# StableDiffusionPipeline also requires safety_checker and feature_extractor;
# pass None and disable the safety-checker requirement so construction succeeds.
pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
).to("cuda")
# ─── 5) APPLY LORA ────────────────────────────────────────────────────────────
# task_type="CAUSAL_LM" applies to causal language models, not a diffusion
# U-Net; target the U-Net's attention projection layers instead.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    bias="none",
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
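# get_peft_model wraps the U-Net in a PeftModel; printing the trainable
# parameter count is a quick check that only the LoRA layers will be updated.
pipe.unet.print_trainable_parameters()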
# ─── 6) TRAINING LOOP (SIMULATED) ─────────────────────────────────────────────
print(f"📂 Data at {DATA_DIR}")
for step in range(100):
    # … your real data loading + optimizer here …
    print(f"Training step {step+1}/100")
# ─── 7) SAVE THE FINE-TUNED LORA ──────────────────────────────────────────────
os.makedirs(OUTPUT_DIR, exist_ok=True)
pipe.save_pretrained(OUTPUT_DIR)
print("✅ Done! Saved to", OUTPUT_DIR)