import torch
from diffusers import FluxPipeline

# Memory-efficient FLUX.1-schnell configuration for Apple Silicon (MPS)

# Load the pipeline in bfloat16 (the checkpoint's native precision)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", 
    torch_dtype=torch.bfloat16,
    use_safetensors=True
)
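
# bfloat16 on MPS needs a fairly recent PyTorch/macOS combination; if loading or
# inference fails on your machine, retry with torch_dtype=torch.float16 (an
# assumption about an acceptable precision trade-off, not a model requirement).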

# Move to MPS for GPU acceleration on Apple Silicon, falling back to CPU if
# this PyTorch build has no MPS support
device = "mps" if torch.backends.mps.is_available() else "cpu"
pipe.to(device)

# Apple Silicon memory optimizations
pipe.enable_attention_slicing()  # Slice attention computation to cut peak memory
pipe.vae.enable_slicing()        # FluxPipeline has no enable_vae_slicing(); call it on the VAE directly

# Optional: offload submodels to CPU if memory is tight (requires `accelerate`;
# pass the device explicitly on Apple Silicon, and skip pipe.to(device) above)
# pipe.enable_model_cpu_offload(device="mps")
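# For the smallest footprint at a large speed cost, diffusers also offers
# sequential offload (a sketch, with the same caveat about skipping pipe.to):
# pipe.enable_sequential_cpu_offload(device="mps")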

# Optional: compile the transformer for speed. Flux uses a DiT-style transformer,
# not a UNet, so compile pipe.transformer (FluxPipeline has no pipe.unet).
# torch.compile support on MPS is limited, and failures may only surface at the
# first inference call, so treat this as best-effort.
try:
    pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
except Exception:
    print("torch.compile not supported, proceeding without compilation")
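
# torch.compile specializes at the first call, so the first generation pays the
# one-time compilation cost; an optional tiny warm-up run (a sketch with
# hypothetical settings) can absorb that latency:
# _ = pipe("warm-up", num_inference_steps=1, guidance_scale=0.0, height=256, width=256)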

prompt = "A cat holding a sign that says hello world"

# Generate image with optimized settings for Apple Silicon
with torch.inference_mode():
    out = pipe(
        prompt=prompt,
        guidance_scale=0.0,           # FLUX.1-schnell works best with guidance_scale=0
        height=768,
        width=1360,
        num_inference_steps=4,        # FLUX.1-schnell is optimized for 4 steps
        max_sequence_length=256,      # FLUX.1-schnell supports at most 256 prompt tokens
        # A CPU generator keeps results reproducible and sidesteps spotty MPS
        # generator support across PyTorch versions
        generator=torch.Generator("cpu").manual_seed(42)
    ).images[0]

# Save the generated image
out.save("image.png")

print("Image generated and saved as 'image.png'")
print("Optimizations applied: MPS device, bfloat16 precision, attention slicing, VAE slicing")