import os
from typing import Dict, List, Optional, Union

from tqdm import tqdm
import torch

# Import the submodule explicitly: a bare `import safetensors` does not expose
# `safetensors.torch`; it only works if another library imported it first.
import safetensors.torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextModelWithProjection
from diffusers import (
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
    EulerDiscreteScheduler,
)
from diffusers.loaders import LoraLoaderMixin

# Hugging Face Hub repositories used by the merge recipe below.
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
L_REPO = "ByteDance/SDXL-Lightning"


def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
    """Load a checkpoint file from disk into a state dict on ``device``.

    Files whose extension is ``safetensors`` (case-insensitive) are read with
    ``safetensors.torch.load_file``; every other file falls back to ``torch.load``.

    Args:
        checkpoint_file: Path to the checkpoint file.
        device: Device the loaded tensors are mapped to.

    Returns:
        The loaded state dict.
    """
    file_extension = os.path.basename(checkpoint_file).split(".")[-1]
    if file_extension.lower() == "safetensors":
        return safetensors.torch.load_file(checkpoint_file, device=device)
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code embedded in it. This branch is reached for checkpoints downloaded from
    # the Hub (a ``.bin`` UNet is loaded in `load_evosdxl_jp`). Consider passing
    # ``weights_only=True`` (torch >= 1.13) or preferring safetensors files.
    return torch.load(checkpoint_file, map_location=device)


def load_from_pretrained(
    repo_id: str,
    filename: str = "diffusion_pytorch_model.fp16.safetensors",
    subfolder: Optional[str] = "unet",
    device: str = "cuda",
) -> Dict[str, torch.Tensor]:
    """Download a single weight file from the Hub and load it as a state dict.

    Args:
        repo_id: Hugging Face Hub repository id.
        filename: File name inside the repository (or inside ``subfolder``).
        subfolder: Sub-directory of the repository; ``None`` for the repo root.
        device: Device the loaded tensors are mapped to.

    Returns:
        The state dict stored in the downloaded file.
    """
    return load_state_dict(
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
        ),
        device=device,
    )


def reshape_weight_task_tensors(task_tensors, weights):
    """
    Reshapes `weights` so it broadcasts against `task_tensors` by appending
    trailing singleton dimensions until both have the same number of dimensions.

    Args:
        task_tensors (`torch.Tensor`): The tensor whose rank `weights` is expanded to.
        weights (`torch.Tensor`): The tensor to be reshaped.

    Returns:
        `torch.Tensor`: A view of `weights` with shape
        ``weights.shape + (1,) * (task_tensors.dim() - weights.dim())``.
    """
    new_shape = weights.shape + (1,) * (task_tensors.dim() - weights.dim())
    weights = weights.view(new_shape)
    return weights


def linear(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
    """
    Merge the task tensors as a weighted sum.

    Args:
        task_tensors (`List[torch.Tensor]`): The task tensors to merge; they are
            stacked along a new leading dimension, so they must share a shape.
        weights (`torch.Tensor`): One weight per task tensor.

    Returns:
        `torch.Tensor`: ``sum_i weights[i] * task_tensors[i]``.
    """
    task_tensors = torch.stack(task_tensors, dim=0)
    # Expand weights to (n, 1, 1, ...) so they broadcast per task tensor.
    weights = reshape_weight_task_tensors(task_tensors, weights)
    weighted_task_tensors = task_tensors * weights
    mixed_task_tensors = weighted_task_tensors.sum(dim=0)
    return mixed_task_tensors


def merge_models(
    task_tensors,
    weights,
):
    """Linearly merge several state dicts key by key.

    The key set is taken from the first state dict. Every key is **popped** from
    each input dict as it is merged, so the inputs are left empty afterwards;
    this releases the source tensors progressively while merging.

    The weights are turned into a tensor with torch's default dtype, so
    half-precision inputs are type-promoted in the result.

    Args:
        task_tensors: Sequence of state dicts (``str`` -> ``torch.Tensor``) that
            share the same keys.
        weights: One merge weight per state dict, in the same order.

    Returns:
        A new dict mapping each key to the weighted sum of the inputs' tensors.

    Raises:
        ValueError: If the number of weights differs from the number of state
            dicts (this would otherwise silently broadcast).
        KeyError: If a state dict lacks a key present in the first one.
    """
    if len(weights) != len(task_tensors):
        raise ValueError(
            f"Expected {len(task_tensors)} weights (one per state dict), "
            f"got {len(weights)}."
        )
    if not task_tensors or not task_tensors[0]:
        return {}

    keys = list(task_tensors[0].keys())
    # Put the weights on the same device as the tensors being merged.
    weights = torch.tensor(weights, device=task_tensors[0][keys[0]].device)
    state_dict = {}
    for key in tqdm(keys, desc="Merging"):
        w_list = [sd.pop(key) for sd in task_tensors]
        state_dict[key] = linear(task_tensors=w_list, weights=weights)
    return state_dict


def split_conv_attn(weights):
    """Partition a state dict into attention-projection and remaining entries.

    Keys containing ``to_k``, ``to_q``, ``to_v`` or ``to_out.0`` go to ``"attn"``;
    every other key goes to ``"conv"``. Entries are popped, so ``weights`` is
    left empty.

    Args:
        weights: State dict to split (mutated).

    Returns:
        ``{"conv": {...}, "attn": {...}}``.
    """
    attn_tensors = {}
    conv_tensors = {}
    # Iterate over a snapshot of the keys because entries are popped in the loop.
    for key in list(weights.keys()):
        if any(k in key for k in ["to_k", "to_q", "to_v", "to_out.0"]):
            attn_tensors[key] = weights.pop(key)
        else:
            conv_tensors[key] = weights.pop(key)
    return {"conv": conv_tensors, "attn": attn_tensors}


def _merge_conv_attn(split_state_dicts, conv_weights, attn_weights):
    """Merge `split_conv_attn` outputs with separate conv and attention weights.

    Consumes (empties) the input dicts through `merge_models` and returns one
    combined state dict: conv entries first, then attention entries.
    """
    new_conv = merge_models([sd["conv"] for sd in split_state_dicts], conv_weights)
    new_attn = merge_models([sd["attn"] for sd in split_state_dicts], attn_weights)
    return {**new_conv, **new_attn}


def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
    """Build the EvoSDXL-JP pipeline by merging several SDXL UNets.

    Steps:
      1. Linearly merge four SDXL UNets (SDXL base, mhdang DPO, Juggernaut-XL-v9,
         Japanese SDXL). Conv and attention weights use separate coefficients.
      2. Fuse the SDXL-Lightning 4-step LoRA into the merged UNet.
      3. Merge that UNet with the SDXL-Lightning and Juggernaut-XL-Lightning
         UNets.
      4. Wrap the result in an SDXL pipeline that uses the Japanese SDXL text
         encoder/tokenizer and an Euler scheduler with "trailing" timesteps.

    The merge coefficients are hard-coded. They are presumably the output of an
    evolutionary merge search (TODO: confirm).

    Args:
        device: Device used for loading, merging and the returned pipeline.

    Returns:
        A float16 `StableDiffusionXLPipeline` on ``device``.
    """
    # --- Stage 1: merge four SDXL UNets. ---
    sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
    dpo_weights = split_conv_attn(
        load_from_pretrained(
            "mhdang/dpo-sdxl-text2image-v1",
            "diffusion_pytorch_model.safetensors",
            device=device,
        )
    )
    jn_weights = split_conv_attn(
        load_from_pretrained("RunDiffusion/Juggernaut-XL-v9", device=device)
    )
    jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))
    merged = _merge_conv_attn(
        [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights],
        conv_weights=[
            0.15928833971605916,
            0.1032449268871776,
            0.6503217149752791,
            0.08714501842148402,
        ],
        attn_weights=[
            0.1877279276437178,
            0.20014114603909822,
            0.3922685507065275,
            0.2198623756106564,
        ],
    )
    del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
    torch.cuda.empty_cache()

    unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
    unet.load_state_dict(merged)
    del merged  # load_state_dict copied the values into the module.

    # --- Stage 2: fuse the SDXL-Lightning 4-step LoRA into the merged UNet. ---
    state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
        L_REPO, weight_name="sdxl_lightning_4step_lora.safetensors"
    )
    LoraLoaderMixin.load_lora_into_unet(state_dict, network_alphas, unet)
    unet.fuse_lora(lora_scale=3.224682864579401)

    # --- Stage 3: merge with the Lightning UNets. ---
    new_weights = split_conv_attn(unet.state_dict())
    l_weights = split_conv_attn(
        load_from_pretrained(
            L_REPO,
            "sdxl_lightning_4step_unet.safetensors",
            subfolder=None,
            device=device,
        )
    )
    jnl_weights = split_conv_attn(
        load_from_pretrained(
            "RunDiffusion/Juggernaut-XL-Lightning",
            "diffusion_pytorch_model.bin",
            device=device,
        )
    )
    merged = _merge_conv_attn(
        [l_weights, jnl_weights, new_weights],
        conv_weights=[0.47222002022088533, 0.48419531030361584, 0.04358466947549889],
        attn_weights=[0.023119324530758375, 0.04924981616469831, 0.9276308593045434],
    )
    # Free the intermediate UNet and the consumed inputs before allocating the
    # final UNet; otherwise two full UNets are resident at once.
    del unet, state_dict, network_alphas, l_weights, jnl_weights, new_weights
    torch.cuda.empty_cache()

    unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
    unet.load_state_dict(merged)
    del merged

    # --- Stage 4: assemble the pipeline with the Japanese text encoder. ---
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        JSDXL_REPO, subfolder="tokenizer", use_fast=False
    )
    pipe = StableDiffusionXLPipeline.from_pretrained(
        SDXL_REPO,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    # Ensure sampler uses "trailing" timesteps.
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    pipe = pipe.to(device, dtype=torch.float16)
    return pipe


if __name__ == "__main__":
    pipe: StableDiffusionXLPipeline = load_evosdxl_jp()
    # 4 steps with guidance disabled, as used with the Lightning-merged UNet above.
    # The prompt "犬" is Japanese for "dog".
    images = pipe("犬", num_inference_steps=4, guidance_scale=0).images
    images[0].save("out.png")