# Source listing metadata: Spaces page (status: Running), file size 1,673 bytes,
# revision 80a1334.
import torch
from diffusers import FluxPipeline
# Configuration targeting Apple Silicon (MPS backend), kept memory-conscious.
# Weights are fetched as safetensors and held in bfloat16.
# NOTE(review): bfloat16 on MPS needs a recent macOS/PyTorch combination —
# confirm on the target machine; float16 is the usual fallback.
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
pipe = FluxPipeline.from_pretrained(
    MODEL_ID,
    use_safetensors=True,
    torch_dtype=torch.bfloat16,
)
# Put every pipeline component on the Metal GPU (MPS). DiffusionPipeline.to()
# returns the pipeline itself, so rebinding `pipe` is equivalent.
pipe = pipe.to("mps")
# Apple Silicon memory optimizations.
# Compute attention in slices to lower peak memory.
# NOTE(review): diffusers applies this only to sub-modules that expose
# `set_attention_slice`; confirm the Flux transformer honours it.
pipe.enable_attention_slicing()
# Decode VAE latents in slices. Call the VAE directly: the pipeline-level
# `pipe.enable_vae_slicing()` wrapper is deprecated in recent diffusers
# releases and forwards to this method anyway.
pipe.vae.enable_slicing()
# Model CPU offload is intentionally not enabled. If memory is tight, use
# `pipe.enable_model_cpu_offload(device="mps")` *instead of* the earlier
# `pipe.to("mps")` call: offloading manages device placement itself and is
# not meant to be combined with moving the whole pipeline to the GPU first.
# Optional torch.compile of the denoiser.
# - FluxPipeline's denoiser is `pipe.transformer`; it has no `unet`
#   attribute, so targeting `pipe.unet` can never compile anything.
# - torch.compile is lazy: real compilation errors surface on the first
#   pipe(...) call, not here, and MPS support is still experimental, so
#   this is opt-in rather than on by default.
# - The default mode is used because "reduce-overhead" relies on CUDA graphs.
ENABLE_TORCH_COMPILE = False
if ENABLE_TORCH_COMPILE:
    try:
        pipe.transformer = torch.compile(pipe.transformer)
    except Exception as exc:  # e.g. Dynamo unsupported on this Python/PyTorch build
        print(f"Torch compile not supported ({exc}), proceeding without compilation")
prompt = "A cat holding a sign that says hello world"

# Seeded generator on the MPS device so repeated runs give the same image.
rng = torch.Generator(device="mps").manual_seed(42)

# Run the pipeline without autograd tracking, using settings for FLUX.1-schnell:
# guidance disabled and only 4 denoising steps (the settings the original
# comments describe as best for schnell), plus a reduced text-token limit
# to save memory.
with torch.inference_mode():
    result = pipe(
        prompt=prompt,
        height=768,
        width=1360,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=rng,
    )
    out = result.images[0]
# Persist the generated image and report what was configured.
# A single constant keeps the saved path and the status message in sync.
# (The stray trailing " |" from the pasted listing was a SyntaxError.)
OUTPUT_PATH = "image.png"
out.save(OUTPUT_PATH)
print(f"Image generated and saved as '{OUTPUT_PATH}'")
print("Optimizations applied: MPS device, bfloat16 precision, attention slicing, VAE slicing")