# ReNO/models/utils.py
import logging
import torch
from diffusers import (AutoencoderKL, DDPMScheduler,
EulerAncestralDiscreteScheduler, LCMScheduler,
Transformer2DModel, UNet2DConditionModel)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from models.RewardPixart import RewardPixartPipeline, freeze_params
from models.RewardStableDiffusion import RewardStableDiffusion
from models.RewardStableDiffusionXL import RewardStableDiffusionXL
def get_model(
    model_name: str,
    dtype: torch.dtype,
    device: torch.device,
    cache_dir: str,
    memsave: bool = False,
):
    """Build and return a reward-enabled diffusion pipeline for ``model_name``.

    Args:
        model_name: One of ``"sd-turbo"``, ``"sdxl-turbo"``, ``"pixart"`` or
            ``"hyper-sd"``.
        dtype: Torch dtype the pipeline weights are loaded/cast to.
        device: Device the pipeline is moved to.
        cache_dir: Hugging Face Hub cache directory for downloaded weights.
        memsave: Forwarded to the Reward* pipelines to enable their
            memory-saving mode.

    Returns:
        The constructed pipeline, moved to ``device``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    logging.info(f"Loading model: {model_name}")
    if model_name == "sd-turbo":
        pipe = RewardStableDiffusion.from_pretrained(
            "stabilityai/sd-turbo",
            torch_dtype=dtype,
            variant="fp16",
            cache_dir=cache_dir,
            memsave=memsave,
        )
        pipe = pipe.to(device, dtype)
    elif model_name == "sdxl-turbo":
        # fp16-safe VAE: the stock SDXL VAE produces NaNs in half precision.
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
        )
        pipe = RewardStableDiffusionXL.from_pretrained(
            "stabilityai/sdxl-turbo",
            vae=vae,
            torch_dtype=dtype,
            variant="fp16",
            use_safetensors=True,
            cache_dir=cache_dir,
            memsave=memsave,
        )
        # "trailing" spacing is required for the turbo (few-step) models.
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        pipe = pipe.to(device, dtype)
    elif model_name == "pixart":
        pipe = RewardPixartPipeline.from_pretrained(
            "PixArt-alpha/PixArt-XL-2-1024-MS",
            torch_dtype=dtype,
            cache_dir=cache_dir,
            memsave=memsave,
        )
        # Swap in the DMD-distilled transformer + matching scheduler.
        pipe.transformer = Transformer2DModel.from_pretrained(
            "PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512",
            subfolder="transformer",
            torch_dtype=dtype,
            cache_dir=cache_dir,
        )
        pipe.scheduler = DDPMScheduler.from_pretrained(
            "PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512",
            subfolder="scheduler",
            cache_dir=cache_dir,
        )
        # speed-up T5
        pipe.text_encoder.to_bettertransformer()
        pipe.transformer.eval()
        freeze_params(pipe.transformer.parameters())
        pipe.transformer.enable_gradient_checkpointing()
        pipe = pipe.to(device)
    elif model_name == "hyper-sd":
        base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        repo_name = "ByteDance/Hyper-SD"
        ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
        # Instantiate the UNet from config only, then load the distilled
        # one-step checkpoint into it.
        unet = UNet2DConditionModel.from_config(
            base_model_id, subfolder="unet", cache_dir=cache_dir
        ).to(device, dtype)
        # BUG FIX: the checkpoint was previously loaded with a hardcoded
        # device="cuda", which crashed on CPU/MPS hosts and ignored the
        # caller-selected CUDA index. Load onto the requested device instead.
        unet.load_state_dict(
            load_file(
                hf_hub_download(repo_name, ckpt_name, cache_dir=cache_dir),
                device=str(device),
            )
        )
        pipe = RewardStableDiffusionXL.from_pretrained(
            base_model_id,
            unet=unet,
            torch_dtype=dtype,
            variant="fp16",
            cache_dir=cache_dir,
            is_hyper=True,
            memsave=memsave,
        )
        # Use LCM scheduler instead of ddim scheduler to support specific timestep number inputs
        pipe.scheduler = LCMScheduler.from_config(
            pipe.scheduler.config, cache_dir=cache_dir
        )
        pipe = pipe.to(device, dtype)
        # upcast vae — keep VAE in fp32 to avoid half-precision artifacts
        pipe.vae = pipe.vae.to(dtype=torch.float32)
        # pipe.enable_sequential_cpu_offload()
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    return pipe