Instructions to use szxllm/MSD with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use szxllm/MSD with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("szxllm/MSD", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import argparse | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import torch | |
| from diffusers import ( | |
| StableDiffusionPipeline, | |
| UNet2DConditionModel, | |
| AutoencoderKL, | |
| DDPMScheduler, | |
| ) | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from PIL import Image | |
# --- Mamba utilities (required) ---
# msd_utils.py has to be importable (same directory or on the Python path).
# The script cannot do anything useful without it, so every import failure
# is reported and the process exits with status 1.
try:
    from msd_utils import MambaSequentialBlock, replace_unet_self_attention_with_mamba
    print("Successfully imported Mamba utilities from msd_utils.py")
except ImportError as import_error:
    # The module (or one of the names requested from it) could not be found.
    print("ERROR: Could not import from msd_utils.py. Make sure it's in the same directory.")
    print(f"Import Error: {import_error}")
    raise SystemExit(1)
except Exception as unexpected_error:
    # The module was found but raised while being executed.
    print(f"ERROR: An unexpected error occurred while importing msd_utils.py: {unexpected_error}")
    raise SystemExit(1)
# --- End Mamba utilities ---
def parse_args(argv=None):
    """Parse the command-line options for Mamba-UNet image generation.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            script entry point relies on. Passing a list makes the parser
            usable from other code and from tests.

    Returns:
        argparse.Namespace with the attributes ``base_model``,
        ``checkpoint_dir``, ``unet_subfolder``, ``prompt``, ``output_path``,
        ``device``, ``seed``, ``num_inference_steps``, ``guidance_scale``,
        ``mamba_d_state``, ``mamba_d_conv``, ``mamba_expand`` and
        ``pipeline_dtype``.

    Raises:
        SystemExit: raised by argparse (status 2) when a required option is
            missing or a value is invalid.
    """
    parser = argparse.ArgumentParser(
        description="Generate images using a fine-tuned Stable Diffusion Mamba UNet checkpoint."
    )

    # --- Model / checkpoint location ---
    parser.add_argument(
        "--base_model", type=str, default="runwayml/stable-diffusion-v1-5",
        help="Path or Hub ID of the base Stable Diffusion model used for training (e.g., 'runwayml/stable-diffusion-v1-5').",
    )
    parser.add_argument(
        "--checkpoint_dir", type=str, required=True,
        help="Path to the specific checkpoint directory (e.g., 'sd-mamba-mscoco-urltext-5k-run1/checkpoint-5000').",
    )
    parser.add_argument(
        "--unet_subfolder", type=str, default="unet_mamba",
        help="Name of the subfolder within the checkpoint directory containing the saved UNet weights.",
    )

    # --- Generation settings ---
    parser.add_argument(
        "--prompt", type=str, default="A photo of an astronaut riding a horse on the moon",
        help="Text prompt for image generation.",
    )
    parser.add_argument(
        "--output_path", type=str, default="generated_image_mamba.png",
        help="Path to save the generated image.",
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use for generation ('cuda' or 'cpu').",
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Optional random seed for reproducibility.",
    )
    parser.add_argument(
        "--num_inference_steps", type=int, default=30,
        help="Number of denoising steps.",
    )
    parser.add_argument(
        "--guidance_scale", type=float, default=7.5,
        help="Scale for classifier-free guidance.",
    )

    # --- Mamba parameters (MUST match training) ---
    # These are deliberately required so the user states the training values
    # explicitly. No default is declared: argparse never applies the default
    # of a required option, so one would only mislead readers.
    parser.add_argument(
        "--mamba_d_state", type=int, required=True,
        help="Mamba ssm state dimension used during training.",
    )
    parser.add_argument(
        "--mamba_d_conv", type=int, required=True,
        help="Mamba ssm convolution dimension used during training.",
    )
    parser.add_argument(
        "--mamba_expand", type=int, required=True,
        help="Mamba ssm expansion factor used during training.",
    )
    # --- End Mamba parameters ---

    parser.add_argument(
        "--pipeline_dtype", type=str, default="float32", choices=["float32", "float16"],
        help="Run pipeline inference in float32 or float16. float32 is generally more stable.",
    )

    return parser.parse_args(argv)
def _exit_with_error(*lines):
    """Print each line to stdout, then terminate the process with exit status 1."""
    for line in lines:
        print(line)
    sys.exit(1)


def _read_unet_state_dict(weights_dir):
    """Read the saved UNet state dict from ``weights_dir`` onto the CPU.

    ``diffusion_pytorch_model.safetensors`` is preferred;
    ``diffusion_pytorch_model.bin`` is used as a fallback.

    Args:
        weights_dir: ``pathlib.Path`` of the directory holding the weights.

    Returns:
        The state dict loaded from whichever file was found.

    Raises:
        FileNotFoundError: if neither weights file exists in ``weights_dir``.
    """
    safetensors_path = weights_dir / "diffusion_pytorch_model.safetensors"
    bin_path = weights_dir / "diffusion_pytorch_model.bin"

    if safetensors_path.exists():
        from safetensors.torch import load_file

        state_dict = load_file(safetensors_path, device="cpu")
        print(f"Loaded state dict from {safetensors_path}")
        return state_dict

    if bin_path.exists():
        # .bin checkpoints are pickle files. weights_only=True restricts
        # unpickling to tensors and plain containers, so a tampered checkpoint
        # cannot execute arbitrary code while being loaded.
        state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
        print(f"Loaded state dict from {bin_path}")
        return state_dict

    raise FileNotFoundError(f"Neither safetensors nor bin file found in {weights_dir}")


def main():
    """Generate one image with a Mamba-modified Stable Diffusion UNet and save it.

    Stages, in order:
      1. Load tokenizer, scheduler, VAE and text encoder from ``--base_model``.
      2. Build a UNet from the base model's config (no pretrained weights).
      3. Pass the UNet through ``replace_unet_self_attention_with_mamba`` with
         the ``--mamba_*`` values.
      4. Load the fine-tuned weights from ``<checkpoint_dir>/<unet_subfolder>``
         with ``strict=True`` so any structural mismatch is reported.
      5. Assemble a ``StableDiffusionPipeline`` (safety checker disabled).
      6. Run the pipeline for ``--prompt``.
      7. Save the first generated image to ``--output_path``.

    A failure in any stage prints an error message and exits the process with
    status 1.
    """
    args = parse_args()

    print("--- Configuration ---")
    print(f"Base Model: {args.base_model}")
    print(f"Checkpoint Dir: {args.checkpoint_dir}")
    print(f"UNet Subfolder: {args.unet_subfolder}")
    print(f"Prompt: '{args.prompt}'")
    print(f"Output Path: {args.output_path}")
    print(f"Device: {args.device}")
    print(f"Seed: {args.seed}")
    print(f"Inference Steps: {args.num_inference_steps}")
    print(f"Guidance Scale: {args.guidance_scale}")
    print(f"Pipeline dtype: {args.pipeline_dtype}")
    print(f"Mamba Params: d_state={args.mamba_d_state}, d_conv={args.mamba_d_conv}, expand={args.mamba_expand}")
    print("--------------------")

    # Resolve device and inference precision.
    device = torch.device(args.device)
    pipeline_torch_dtype = torch.float32 if args.pipeline_dtype == "float32" else torch.float16

    # A seeded generator is only created when the user asked for one.
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=device).manual_seed(args.seed)
        print(f"Using random seed: {args.seed}")

    # Keyword arguments handed to the Mamba replacement helper.
    mamba_kwargs = {
        'd_state': args.mamba_d_state,
        'd_conv': args.mamba_d_conv,
        'expand': args.mamba_expand,
    }
    print("Prepared Mamba kwargs for UNet replacement.")

    # --- 1. Load Base Components (Tokenizer, Scheduler, VAE, Text Encoder) ---
    print(f"Loading base components from {args.base_model}...")
    try:
        tokenizer = CLIPTokenizer.from_pretrained(args.base_model, subfolder="tokenizer")
        scheduler = DDPMScheduler.from_pretrained(args.base_model, subfolder="scheduler")
        # VAE and text encoder are kept in float32 for stability.
        vae = AutoencoderKL.from_pretrained(
            args.base_model, subfolder="vae", torch_dtype=torch.float32
        ).to(device)
        text_encoder = CLIPTextModel.from_pretrained(
            args.base_model, subfolder="text_encoder", torch_dtype=torch.float32
        ).to(device)
        print("Base components loaded.")
    except Exception as e:
        _exit_with_error(
            f"ERROR: Failed to load base components from {args.base_model}. Check path/name.",
            f"Error details: {e}",
        )

    # --- 2. Create Base UNet Structure ---
    print(f"Creating UNet structure from {args.base_model} config...")
    try:
        unet_config = UNet2DConditionModel.load_config(args.base_model, subfolder="unet")
        # NOTE(review): from_config may not honor torch_dtype (it looks like an
        # unused kwarg there); reduced precision appears to come from the
        # autocast context in step 6 instead - confirm against diffusers docs.
        unet = UNet2DConditionModel.from_config(unet_config, torch_dtype=pipeline_torch_dtype)
        print("Base UNet structure created.")
    except Exception as e:
        _exit_with_error(
            f"ERROR: Failed to create UNet structure from config {args.base_model}.",
            f"Error details: {e}",
        )

    # --- 3. Modify UNet Structure with Mamba ---
    print("Replacing UNet Self-Attention with Mamba blocks (using provided parameters)...")
    try:
        unet = replace_unet_self_attention_with_mamba(unet, mamba_kwargs)
        print("UNet structure modified with Mamba blocks.")
    except Exception as e:
        _exit_with_error(
            "ERROR: Failed during UNet modification with Mamba blocks.",
            f"Error details: {e}",
        )

    # --- 4. Load Fine-tuned UNet Weights ---
    unet_weights_dir = Path(args.checkpoint_dir) / args.unet_subfolder
    print(f"Attempting to load fine-tuned UNet weights from: {unet_weights_dir}")
    if not unet_weights_dir.is_dir():
        _exit_with_error(
            f"ERROR: UNet weights directory not found: {unet_weights_dir}",
            "Please ensure '--checkpoint_dir' points to the correct checkpoint folder (e.g., checkpoint-5000)",
            "and '--unet_subfolder' is correct (likely 'unet_mamba').",
        )

    try:
        print(f"Loading state dict from {unet_weights_dir}...")
        unet_state_dict = _read_unet_state_dict(unet_weights_dir)
        # strict=True: the weights must match the already-modified UNet
        # exactly, which surfaces wrong Mamba parameters immediately.
        load_result = unet.load_state_dict(unet_state_dict, strict=True)
        print(f"UNet state dict loaded successfully. Load result: {load_result}")
        del unet_state_dict  # Free the CPU copy of the weights.
        print("Fine-tuned UNet weights loaded.")
    except Exception as e:
        _exit_with_error(
            f"ERROR: Failed to load UNet weights from {unet_weights_dir}.",
            "Make sure the directory exists and contains the model weights ('diffusion_pytorch_model.safetensors' or '.bin').",
            "Also ensure Mamba parameters match those used during training.",
            f"Error details: {e}",
        )

    # Move UNet to device and set to eval mode.
    unet = unet.to(device)
    unet.eval()
    print("UNet moved to device and set to eval mode.")

    # --- 5. Create Stable Diffusion Pipeline ---
    print("Creating Stable Diffusion Pipeline with modified UNet...")
    try:
        # Every component is already on the target device, so the pipeline
        # itself is not moved again.
        pipeline = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,  # Disabled during training, keep disabled
            feature_extractor=None,
            requires_safety_checker=False,
        )
        print("Pipeline created successfully.")
    except Exception as e:
        _exit_with_error(
            "ERROR: Failed to create Stable Diffusion Pipeline.",
            f"Error details: {e}",
        )

    # --- 6. Generate Image ---
    print(f"Generating image for prompt: '{args.prompt}'...")
    try:
        # Autocast is only enabled for reduced precision. device.type strips
        # any ":<index>" suffix (e.g. "cuda:0" -> "cuda").
        use_autocast = pipeline_torch_dtype != torch.float32
        with torch.no_grad(), torch.autocast(
            device_type=device.type, dtype=pipeline_torch_dtype, enabled=use_autocast
        ):
            result = pipeline(
                prompt=args.prompt,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale,
                generator=generator,
            )
        image = result.images[0]
        print("Image generation complete.")
    except Exception as e:
        _exit_with_error(
            "ERROR: Image generation failed.",
            f"Error details: {e}",
        )

    # --- 7. Save Image ---
    try:
        # Create the output directory first so saving does not fail on a
        # missing parent folder.
        Path(args.output_path).parent.mkdir(parents=True, exist_ok=True)
        image.save(args.output_path)
        print(f"Image saved successfully to: {args.output_path}")
    except Exception as e:
        _exit_with_error(
            f"ERROR: Failed to save image to {args.output_path}.",
            f"Error details: {e}",
        )
# Run the generation workflow only when this file is executed as a script,
# so importing the module has no side effects beyond its definitions.
if __name__ == "__main__":
    main()