# NOTE: residue from the source-viewer page this file was copied from:
# file size 2,204 bytes; blame commits c9b1bf6, b1dde27, 2ec882e, 35bd3cf,
# 0a3593d, aff7e63; line-number gutter 1–81. Kept as a comment so the file
# parses as Python.
import os

import torch
from diffusers import (
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    FluxPipeline,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict
from transformers import CLIPTextModel, CLIPTokenizer

# LoRA fine-tuning skeleton for FLUX.1-dev: download the checkpoint once,
# load it, attach a LoRA adapter to the denoising transformer, run a
# (placeholder) training loop, and save only the adapter weights.

# Hugging Face repo to fine-tune. FLUX.1-dev is a gated repo, so the account
# behind the saved token (`huggingface-cli login`) must have access to it.
MODEL_ID = "black-forest-labs/FLUX.1-dev"
# Local directory that receives the downloaded snapshot.
MODEL_DIR = "./fluxdev-model"
# Where the trained LoRA adapter is written (previously referenced but never
# defined, which made the final save raise NameError).
output_dir = "./fluxdev-lora"
# Number of (placeholder) training steps.
NUM_STEPS = 100

# 1) grab the model locally, once. `token=True` replaces the deprecated
#    `use_auth_token=True` and uses the locally saved Hugging Face token.
print("📥 Downloading Flux‑Dev model…")
model_path = snapshot_download(MODEL_ID, local_dir=MODEL_DIR, token=True)

# 2) load the pipeline. FLUX.1-dev is not a Stable Diffusion checkpoint: it
#    ships a `transformer/` (no `unet/`), a second (T5) text encoder and a
#    flow-matching scheduler, so the StableDiffusion/UNet/CLIP/DPMSolver
#    loaders cannot assemble it. FluxPipeline builds every component from the
#    repo's model_index.json. bfloat16 matches the dtype used in the diffusers
#    Flux documentation.
print("🔄 Loading Flux pipeline…")
pipe = FluxPipeline.from_pretrained(
    model_path, torch_dtype=torch.bfloat16
).to("cuda")

# 3) freeze the base model; only the LoRA adapter added below should train.
pipe.transformer.requires_grad_(False)
pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)
pipe.text_encoder_2.requires_grad_(False)

# 4) apply LoRA to the transformer's attention projections. The previous
#    get_peft_model(..., task_type="CAUSAL_LM") targeted a language-model
#    wrapper and named no target_modules, which PEFT cannot infer for a
#    diffusion model; add_adapter injects the adapter in place instead.
print("🧠 Applying LoRA…")
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
pipe.transformer.add_adapter(lora_config)
trainable = sum(
    p.numel() for p in pipe.transformer.parameters() if p.requires_grad
)
print(f"   trainable LoRA parameters: {trainable:,}")

# 5) your training loop (or dummy loop for illustration)
print("🚀 Starting fine‑tuning…")
for step in range(NUM_STEPS):
    print(f"Training step {step + 1}/{NUM_STEPS}")
    # …insert your actual data‑loader and loss/backprop here…

# 6) save only the adapter weights, not a full copy of the base pipeline
#    (the old save_pretrained call wrote every component despite the
#    "LoRA weights" message).
os.makedirs(output_dir, exist_ok=True)
FluxPipeline.save_lora_weights(
    output_dir,
    transformer_lora_layers=get_peft_model_state_dict(pipe.transformer),
)
print("✅ Done. LoRA weights in", output_dir)