Spaces:
Sleeping
Sleeping
import logging | |
import torch | |
from diffusers import (AutoencoderKL, DDPMScheduler, | |
EulerAncestralDiscreteScheduler, LCMScheduler, | |
Transformer2DModel, UNet2DConditionModel) | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file | |
from models.RewardPixart import RewardPixartPipeline, freeze_params | |
from models.RewardStableDiffusion import RewardStableDiffusion | |
from models.RewardStableDiffusionXL import RewardStableDiffusionXL | |
def _load_sd_turbo(
    dtype: torch.dtype, device: torch.device, cache_dir: str, memsave: bool
):
    """Load stabilityai/sd-turbo (fp16 variant weights) and move it to device/dtype."""
    pipe = RewardStableDiffusion.from_pretrained(
        "stabilityai/sd-turbo",
        torch_dtype=dtype,
        variant="fp16",
        cache_dir=cache_dir,
        memsave=memsave,
    )
    return pipe.to(device, dtype)


def _load_sdxl_turbo(
    dtype: torch.dtype, device: torch.device, cache_dir: str, memsave: bool
):
    """Load stabilityai/sdxl-turbo with the fp16-fix VAE and a trailing-spaced Euler-A scheduler."""
    # Load the VAE directly in `dtype` (it is cast to `dtype` below anyway);
    # this avoids an intermediate float16 cast when a wider dtype is requested.
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=dtype,
        cache_dir=cache_dir,
    )
    pipe = RewardStableDiffusionXL.from_pretrained(
        "stabilityai/sdxl-turbo",
        vae=vae,
        torch_dtype=dtype,
        variant="fp16",
        use_safetensors=True,
        cache_dir=cache_dir,
        memsave=memsave,
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    return pipe.to(device, dtype)


def _load_pixart(
    dtype: torch.dtype, device: torch.device, cache_dir: str, memsave: bool
):
    """Load PixArt-XL-2-1024-MS with the PixArt-Alpha-DMD-XL-2-512x512 transformer and scheduler."""
    dmd_repo = "PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512"
    pipe = RewardPixartPipeline.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS",
        torch_dtype=dtype,
        cache_dir=cache_dir,
        memsave=memsave,
    )
    # Replace the base transformer and scheduler with the DMD checkpoint's;
    # the other components (text encoder, VAE, ...) come from the base repo.
    pipe.transformer = Transformer2DModel.from_pretrained(
        dmd_repo,
        subfolder="transformer",
        torch_dtype=dtype,
        cache_dir=cache_dir,
    )
    pipe.scheduler = DDPMScheduler.from_pretrained(
        dmd_repo,
        subfolder="scheduler",
        cache_dir=cache_dir,
    )
    # Speed up the T5 text encoder.
    # NOTE(review): BetterTransformer conversion relies on `optimum` and is
    # deprecated upstream — confirm it is still needed/available.
    pipe.text_encoder.to_bettertransformer()
    # Transformer weights are frozen, yet gradient checkpointing is enabled —
    # presumably so gradients can still flow to the inputs at lower memory
    # cost; TODO confirm against callers.
    pipe.transformer.eval()
    freeze_params(pipe.transformer.parameters())
    pipe.transformer.enable_gradient_checkpointing()
    # Components were already loaded with torch_dtype=dtype; only move device.
    return pipe.to(device)


def _load_hyper_sd(
    dtype: torch.dtype, device: torch.device, cache_dir: str, memsave: bool
):
    """Load SDXL-base with the ByteDance Hyper-SDXL 1-step UNet and an LCM scheduler."""
    base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    repo_name = "ByteDance/Hyper-SD"
    ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

    # Build the UNet architecture from the base model's config, then load the
    # Hyper-SD weights into it. Passing a repo id straight to `from_config` is
    # deprecated in diffusers; `load_config` + `from_config` is the supported form.
    unet_config = UNet2DConditionModel.load_config(
        base_model_id, subfolder="unet", cache_dir=cache_dir
    )
    unet = UNet2DConditionModel.from_config(unet_config).to(device, dtype)
    # Load the checkpoint onto the requested device instead of a hard-coded
    # "cuda", so CPU-only hosts and non-default GPUs work.
    state_dict = load_file(
        hf_hub_download(repo_name, ckpt_name, cache_dir=cache_dir),
        device=str(device),
    )
    unet.load_state_dict(state_dict)

    pipe = RewardStableDiffusionXL.from_pretrained(
        base_model_id,
        unet=unet,
        torch_dtype=dtype,
        variant="fp16",
        cache_dir=cache_dir,
        is_hyper=True,
        memsave=memsave,
    )
    # Use LCM scheduler instead of DDIM scheduler to support specific timestep
    # number inputs. (A config-based constructor needs no cache_dir.)
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device, dtype)
    # Keep the VAE in float32 regardless of `dtype` — presumably because the
    # stock SDXL VAE is not numerically safe in half precision.
    pipe.vae = pipe.vae.to(dtype=torch.float32)
    return pipe


# Maps each supported model name to its loader; all loaders share the
# signature (dtype, device, cache_dir, memsave) -> pipeline.
_MODEL_LOADERS = {
    "sd-turbo": _load_sd_turbo,
    "sdxl-turbo": _load_sdxl_turbo,
    "pixart": _load_pixart,
    "hyper-sd": _load_hyper_sd,
}


def get_model(
    model_name: str,
    dtype: torch.dtype,
    device: torch.device,
    cache_dir: str,
    memsave: bool = False,
):
    """Load one of the supported Reward* diffusion pipelines by short name.

    Args:
        model_name: One of "sd-turbo", "sdxl-turbo", "pixart", "hyper-sd".
        dtype: Torch dtype used when loading weights and casting the pipeline
            (the "hyper-sd" pipeline keeps its VAE in float32).
        device: Device the pipeline (and, for "hyper-sd", the UNet checkpoint)
            is placed on.
        cache_dir: Hugging Face cache directory passed to every download.
        memsave: Forwarded to the Reward* pipeline constructors.

    Returns:
        The loaded pipeline, moved to ``device``.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
    """
    logging.info("Loading model: %s", model_name)
    try:
        loader = _MODEL_LOADERS[model_name]
    except KeyError:
        raise ValueError(f"Unknown model name: {model_name}") from None
    return loader(dtype, device, cache_dir, memsave)