import gc
import json
import os
from collections import OrderedDict
from copy import deepcopy
from math import ceil
from typing import Any, Callable, Dict

import safetensors.torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import yaml
from diffusers import (
    DiffusionPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    DDPMScheduler,
    UNet2DConditionModel,
)


def remove_all_forward_hooks(model: torch.nn.Module) -> None:
    for _name, child in model._modules.items():  # pylint: disable=protected-access
        if child is not None:
            if hasattr(child, "_forward_hooks"):
                child._forward_hooks: Dict[int, Callable] = OrderedDict()
            remove_all_forward_hooks(child)


# Inspired by transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock
class SparseMoeBlock(nn.Module):
    def __init__(self, config, experts):
        super().__init__()
        self.hidden_dim = config["hidden_size"]
        self.num_experts = config["num_local_experts"]
        self.top_k = config["num_experts_per_tok"]
        self.out_dim = config.get("out_dim", self.hidden_dim)

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = nn.ModuleList([deepcopy(exp) for exp in experts])

    def forward(self, hidden_states: torch.Tensor, scale=None) -> torch.Tensor:  # pylint: disable=unused-argument
        batch_size, sequence_length, f_map_sz = hidden_states.shape
        hidden_states = hidden_states.view(-1, f_map_sz)

        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)
        # Select the same top-k experts for every token by ranking the router logits
        # summed over all tokens in the batch.
        _, selected_experts = torch.topk(
            router_logits.sum(dim=0, keepdim=True), self.top_k, dim=1
        )
        routing_weights = F.softmax(
            router_logits[:, selected_experts[0]], dim=1, dtype=torch.float
        )

        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, self.out_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # Loop over the selected experts and accumulate their routing-weighted outputs.
        for i, expert_idx in enumerate(selected_experts[0].tolist()):
            expert_layer = self.experts[expert_idx]
            current_hidden_states = routing_weights[:, i].view(
                batch_size * sequence_length, -1
            ) * expert_layer(hidden_states)
            final_hidden_states = final_hidden_states + current_hidden_states

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, self.out_dim
        )
        return final_hidden_states


def getActivation(activation, name):
    def hook(model, inp, output):  # pylint: disable=unused-argument
        activation[name] = inp

    return hook


class SegMoEPipeline:
    def __init__(self, config_or_path, **kwargs) -> None:
        """
        Instantiates the SegMoEPipeline.

        SegMoEPipeline implements the Segmind Mixture of Diffusion Experts, efficiently
        combining Stable Diffusion and Stable Diffusion XL models.

        Usage:
            from segmoe import SegMoEPipeline
            pipeline = SegMoEPipeline(config_or_path, **kwargs)

        config_or_path: Path to a config file, a directory containing a SegMoE
            checkpoint, or the Hugging Face Hub id of a SegMoE checkpoint.

        Other keyword arguments:
            torch_dtype: Data type to load the pipeline in. (Default: torch.float16)
            variant: Variant of the model. (Default: fp16)
            device: Device to load the model on. (Default: cuda)

        Other arguments supported by diffusers.DiffusionPipeline are also supported.

        For more details visit https://github.com/segmind/segmoe.
        """
        self.torch_dtype = kwargs.pop("torch_dtype", torch.float16)
        self.use_safetensors = kwargs.pop("use_safetensors", True)
        self.variant = kwargs.pop("variant", "fp16")
        self.device = kwargs.pop("device", "cuda")
        if os.path.isfile(config_or_path):
            self.load_from_scratch(config_or_path, **kwargs)
        else:
            if not os.path.isdir(config_or_path):
                cached_folder = DiffusionPipeline.download(config_or_path)
            else:
                cached_folder = config_or_path
            unet = self.create_empty(cached_folder)
            unet.load_state_dict(
                safetensors.torch.load_file(
                    f"{cached_folder}/unet/diffusion_pytorch_model.safetensors"
                )
            )
            self.pipe = DiffusionPipeline.from_pretrained(
                cached_folder,
                unet=unet,
                torch_dtype=self.torch_dtype,
                use_safetensors=self.use_safetensors,
            )
            self.pipe.to(self.device)
            self.pipe.unet.to(
                device=self.device,
                dtype=self.torch_dtype,
                memory_format=torch.channels_last,
            )
    def to(self, *args, **kwargs):
        # TODO: added as a simple passthrough to the underlying pipeline to avoid errors
        # from callers that expect a DiffusionPipeline-style .to().
        self.pipe.to(*args, **kwargs)
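    # A hedged sketch (inferred from the keys read below, not an official schema) of the
    # YAML layout that load_from_scratch expects; the model ids and prompts are
    # placeholders, not recommendations:
    #
    #   base_model: stabilityai/stable-diffusion-xl-base-1.0   # or a civitai download URL
    #   num_experts_per_tok: 2
    #   moe_layers: all          # "attn" (default), "ff", or anything else for both
    #   init: hidden             # "hidden" (default) initializes gates from expert prompts
    #   experts:
    #     - source_model: some-org/some-sdxl-finetune          # placeholder id
    #       positive_prompt: "photorealistic, highly detailed"
    #       negative_prompt: "blurry, low quality"
    #       loras:                                             # optional, per expert
    #         - source_model: some-org/some-lora               # placeholder id
    #           positive_prompt: "extra style keywords"
    #           negative_prompt: ""
    #   loras: []                # optional top-level LoRAs, used when no experts are given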
    def load_from_scratch(self, config: str, **kwargs) -> None:
        # Load Config
        with open(config, "r", encoding="utf8") as f:
            config = yaml.load(f, Loader=yaml.SafeLoader)
        self.config = config
        if self.config.get("num_experts", None):
            self.num_experts = self.config["num_experts"]
        elif self.config.get("experts", None):
            self.num_experts = len(self.config["experts"])
        elif self.config.get("loras", None):
            self.num_experts = len(self.config["loras"])
        else:
            self.num_experts = 1
        num_experts_per_tok = self.config.get("num_experts_per_tok", 1)
        self.config["num_experts_per_tok"] = num_experts_per_tok
        moe_layers = self.config.get("moe_layers", "attn")
        self.config["moe_layers"] = moe_layers

        # Load Base Model
        if self.config["base_model"].startswith(
            "https://civitai.com/api/download/models/"
        ):
            os.makedirs("base", exist_ok=True)
            if not os.path.isfile("base/model.safetensors"):
                os.system(
                    "wget -O base/model.safetensors "
                    + self.config["base_model"]
                    + " --content-disposition"
                )
            self.config["base_model"] = "base/model.safetensors"
            self.pipe = DiffusionPipeline.from_single_file(
                self.config["base_model"], torch_dtype=self.torch_dtype
            )
        else:
            try:
                self.pipe = DiffusionPipeline.from_pretrained(
                    self.config["base_model"],
                    torch_dtype=self.torch_dtype,
                    use_safetensors=self.use_safetensors,
                    variant=self.variant,
                    **kwargs,
                )
            except Exception:
                self.pipe = DiffusionPipeline.from_pretrained(
                    self.config["base_model"], torch_dtype=self.torch_dtype, **kwargs
                )

        # Only a subset of the down/up blocks contains attention layers; the ranges
        # differ between SD and SDXL UNets.
        if self.pipe.__class__ == StableDiffusionPipeline:
            self.up_idx_start = 1
            self.up_idx_end = len(self.pipe.unet.up_blocks)
            self.down_idx_start = 0
            self.down_idx_end = len(self.pipe.unet.down_blocks) - 1
        elif self.pipe.__class__ == StableDiffusionXLPipeline:
            self.up_idx_start = 0
            self.up_idx_end = len(self.pipe.unet.up_blocks) - 1
            self.down_idx_start = 1
            self.down_idx_end = len(self.pipe.unet.down_blocks)
        self.config["up_idx_start"] = self.up_idx_start
        self.config["up_idx_end"] = self.up_idx_end
        self.config["down_idx_start"] = self.down_idx_start
        self.config["down_idx_end"] = self.down_idx_end

        # TODO: Add Support for Scheduler Selection
        self.pipe.scheduler = DDPMScheduler.from_config(self.pipe.scheduler.config)

        # Load Experts
        experts = []
        positive = []
        negative = []
        if self.config.get("experts", None):
            for i, exp in enumerate(self.config["experts"]):
                positive.append(exp["positive_prompt"])
                negative.append(exp["negative_prompt"])
                if exp["source_model"].startswith(
                    "https://civitai.com/api/download/models/"
                ):
                    try:
                        if not os.path.isfile(f"expert_{i}/model.safetensors"):
                            os.makedirs(f"expert_{i}", exist_ok=True)
                            os.system(
                                f"wget {exp['source_model']} -O "
                                + f"expert_{i}/model.safetensors"
                                + " --content-disposition"
                            )
                        exp["source_model"] = f"expert_{i}/model.safetensors"
                        expert = DiffusionPipeline.from_single_file(
                            exp["source_model"]
                        ).to(self.device, self.torch_dtype)
                    except Exception as e:
                        print(f"Expert {i} {exp['source_model']} failed to load")
                        print("Error:", e)
                else:
                    try:
                        expert = DiffusionPipeline.from_pretrained(
                            exp["source_model"],
                            torch_dtype=self.torch_dtype,
                            use_safetensors=self.use_safetensors,
                            variant=self.variant,
                            **kwargs,
                        )
                        # TODO: Add Support for Scheduler Selection
                        expert.scheduler = DDPMScheduler.from_config(
                            expert.scheduler.config
                        )
                    except Exception:
                        expert = DiffusionPipeline.from_pretrained(
                            exp["source_model"], torch_dtype=self.torch_dtype, **kwargs
                        )
                        expert.scheduler = DDPMScheduler.from_config(
                            expert.scheduler.config
                        )
                if exp.get("loras", None):
                    for j, lora in enumerate(exp["loras"]):
                        if lora.get("positive_prompt", None):
                            positive[-1] += " " + lora["positive_prompt"]
                        if lora.get("negative_prompt", None):
                            negative[-1] += " " + lora["negative_prompt"]
                        if lora["source_model"].startswith(
                            "https://civitai.com/api/download/models/"
                        ):
                            try:
                                os.makedirs(f"expert_{i}/lora_{j}", exist_ok=True)
                                if not os.path.isfile(
                                    f"expert_{i}/lora_{j}/pytorch_lora_weights.safetensors"
                                ):
                                    os.system(
                                        f"wget {lora['source_model']} -O "
                                        + f"expert_{i}/lora_{j}/pytorch_lora_weights.safetensors"
                                        + " --content-disposition"
                                    )
                                lora["source_model"] = f"expert_{i}/lora_{j}"
                                expert.load_lora_weights(lora["source_model"])
                                if len(exp["loras"]) == 1:
                                    expert.fuse_lora()
                            except Exception as e:
                                print(
                                    f"Expert {i} LoRA {j} {lora['source_model']} failed to load"
                                )
                                print("Error:", e)
                        else:
                            expert.load_lora_weights(lora["source_model"])
                            if len(exp["loras"]) == 1:
                                expert.fuse_lora()
                experts.append(expert)
        else:
            experts = [deepcopy(self.pipe) for _ in range(self.num_experts)]

        if self.config.get("experts", None):
            if self.config.get("loras", None):
                # Top-level LoRAs are applied to the base pipeline when experts are given.
                for i, lora in enumerate(self.config["loras"]):
                    if lora["source_model"].startswith(
                        "https://civitai.com/api/download/models/"
                    ):
                        try:
                            os.makedirs(f"lora_{i}", exist_ok=True)
                            if not os.path.isfile(
                                f"lora_{i}/pytorch_lora_weights.safetensors"
                            ):
                                os.system(
                                    f"wget {lora['source_model']} -O "
                                    + f"lora_{i}/pytorch_lora_weights.safetensors"
                                    + " --content-disposition"
                                )
                            lora["source_model"] = f"lora_{i}"
                            self.pipe.load_lora_weights(lora["source_model"])
                            if len(self.config["loras"]) == 1:
                                self.pipe.fuse_lora()
                        except Exception as e:
                            print(f"LoRA {i} {lora['source_model']} failed to load")
                            print("Error:", e)
                    else:
                        self.pipe.load_lora_weights(lora["source_model"])
                        if len(self.config["loras"]) == 1:
                            self.pipe.fuse_lora()
        else:
            if self.config.get("loras", None):
                # Without explicit experts, distribute the LoRAs (and their prompts)
                # roughly evenly across the copies of the base pipeline.
                j = []
                n_loras = len(self.config["loras"])
                i = 0
                positive = [""] * len(experts)
                negative = [""] * len(experts)
                while n_loras:
                    n = ceil(n_loras / len(experts))
                    j += [i] * n
                    n_loras -= n
                    i += 1
                for i, lora in enumerate(self.config["loras"]):
                    positive[j[i]] += lora["positive_prompt"] + " "
                    negative[j[i]] += lora["negative_prompt"] + " "
                    if lora["source_model"].startswith(
                        "https://civitai.com/api/download/models/"
                    ):
                        try:
                            os.makedirs(f"lora_{i}", exist_ok=True)
                            if not os.path.isfile(
                                f"lora_{i}/pytorch_lora_weights.safetensors"
                            ):
                                os.system(
                                    f"wget {lora['source_model']} -O "
                                    + f"lora_{i}/pytorch_lora_weights.safetensors"
                                    + " --content-disposition"
                                )
                            lora["source_model"] = f"lora_{i}"
                            experts[j[i]].load_lora_weights(lora["source_model"])
                            experts[j[i]].fuse_lora()
                        except Exception:
                            print(f"LoRA {i} {lora['source_model']} failed to load")
                    else:
                        experts[j[i]].load_lora_weights(lora["source_model"])
                        experts[j[i]].fuse_lora()

        # Replace FF and Attention Layers with Sparse MoE Layers. The original
        # per-block, per-projection replacements are identical up to indices, so they
        # are folded into loops over the down and up blocks here.
        for block_name, idx_start, idx_end in (
            ("down_blocks", self.down_idx_start, self.down_idx_end),
            ("up_blocks", self.up_idx_start, self.up_idx_end),
        ):
            base_blocks = getattr(self.pipe.unet, block_name)
            for i in range(idx_start, idx_end):
                for j in range(len(base_blocks[i].attentions)):
                    for k in range(
                        len(base_blocks[i].attentions[j].transformer_blocks)
                    ):
                        base_tb = base_blocks[i].attentions[j].transformer_blocks[k]
                        expert_tbs = [
                            getattr(e.unet, block_name)[i]
                            .attentions[j]
                            .transformer_blocks[k]
                            for e in experts
                        ]
                        # FF Layers
                        if moe_layers != "attn":
                            config = {
                                "hidden_size": next(base_tb.ff.parameters()).size()[-1],
                                "num_experts_per_tok": num_experts_per_tok,
                                "num_local_experts": len(experts),
                            }
                            layers = [deepcopy(e_tb.ff) for e_tb in expert_tbs]
                            base_tb.ff = SparseMoeBlock(config, layers)
                        # Self-attention (attn1) and cross-attention (attn2) projections
                        if moe_layers != "ff":
                            for attn_name in ("attn1", "attn2"):
                                base_attn = getattr(base_tb, attn_name)
                                for proj_name in ("to_q", "to_k", "to_v"):
                                    base_proj = getattr(base_attn, proj_name)
                                    config = {
                                        "hidden_size": base_proj.weight.size()[-1],
                                        "num_experts_per_tok": num_experts_per_tok,
                                        "num_local_experts": len(experts),
                                    }
                                    # Cross-attention k/v project from the text-encoder
                                    # dimension, so an explicit out_dim is needed.
                                    if attn_name == "attn2" and proj_name in ("to_k", "to_v"):
                                        config["out_dim"] = base_proj.weight.size()[0]
                                    layers = [
                                        deepcopy(
                                            getattr(getattr(e_tb, attn_name), proj_name)
                                        )
                                        for e_tb in expert_tbs
                                    ]
                                    setattr(
                                        base_attn,
                                        proj_name,
                                        SparseMoeBlock(config, layers),
                                    )

        # Routing Weight Initialization: set each gate from the hidden states the experts
        # produce for their own prompts (see get_gate_params). Keys follow the naming used
        # by cast_hook: "d"/"u" for down/up blocks, "sattn"/"cattn" for self-/cross-attention,
        # then q/k/v and the block indices, e.g. "sattnqd0a1t0".
        if self.config.get("init", "hidden") == "hidden":
            gate_params = self.get_gate_params(experts, positive, negative)
            for block_name, prefix, idx_start, idx_end in (
                ("down_blocks", "d", self.down_idx_start, self.down_idx_end),
                ("up_blocks", "u", self.up_idx_start, self.up_idx_end),
            ):
                base_blocks = getattr(self.pipe.unet, block_name)
                for i in range(idx_start, idx_end):
                    for j in range(len(base_blocks[i].attentions)):
                        for k in range(
                            len(base_blocks[i].attentions[j].transformer_blocks)
                        ):
                            base_tb = base_blocks[i].attentions[j].transformer_blocks[k]
                            suffix = f"{prefix}{i}a{j}t{k}"
                            # FF Layers
                            if moe_layers != "attn":
                                base_tb.ff.gate.weight = nn.Parameter(
                                    gate_params[suffix]
                                )
                            # Attns
                            if moe_layers != "ff":
                                for attn_name, attn_tag in (
                                    ("attn1", "sattn"),
                                    ("attn2", "cattn"),
                                ):
                                    base_attn = getattr(base_tb, attn_name)
                                    for proj_name, proj_tag in (
                                        ("to_q", "q"),
                                        ("to_k", "k"),
                                        ("to_v", "v"),
                                    ):
                                        getattr(base_attn, proj_name).gate.weight = nn.Parameter(
                                            gate_params[f"{attn_tag}{proj_tag}{suffix}"]
                                        )

        self.config["num_experts"] = len(experts)
        remove_all_forward_hooks(self.pipe.unet)
        try:
            del experts
            del expert
        except Exception:
            pass

        # Move Model to Device
        self.pipe.to(self.device)
        self.pipe.unet.to(
            device=self.device,
            dtype=self.torch_dtype,
            memory_format=torch.channels_last,
        )
        gc.collect()
        torch.cuda.empty_cache()

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """
        Runs inference with the SegMoEPipeline by forwarding all arguments to the
        underlying diffusers.DiffusionPipeline.

        See https://github.com/segmind/segmoe#usage for detailed usage.
        """
        return self.pipe(*args, **kwds)
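    # Editorial note on the checkpoint layout (inferred from __init__, create_empty and
    # save_pretrained, not a formally documented format): a saved SegMoE checkpoint is a
    # regular diffusers directory whose unet/config.json carries an extra "segmoe_config"
    # entry, with the merged MoE weights stored in
    # unet/diffusion_pytorch_model.safetensors. create_empty() rebuilds the MoE-shaped
    # UNet from that config so the saved state dict can be loaded into it.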
    def create_empty(self, path):
        with open(f"{path}/unet/config.json", encoding="utf8") as f:
            config = json.load(f)
        self.config = config["segmoe_config"]
        unet = UNet2DConditionModel.from_config(config)
        num_experts_per_tok = self.config["num_experts_per_tok"]
        num_experts = self.config["num_experts"]
        moe_layers = self.config["moe_layers"]
        self.up_idx_start = self.config["up_idx_start"]
        self.up_idx_end = self.config["up_idx_end"]
        self.down_idx_start = self.config["down_idx_start"]
        self.down_idx_end = self.config["down_idx_end"]

        # Rebuild the same Sparse MoE structure as load_from_scratch, but with the base
        # layers repeated as placeholder experts; the real weights are loaded from the
        # checkpoint afterwards.
        for block_name, idx_start, idx_end in (
            ("down_blocks", self.down_idx_start, self.down_idx_end),
            ("up_blocks", self.up_idx_start, self.up_idx_end),
        ):
            blocks = getattr(unet, block_name)
            for i in range(idx_start, idx_end):
                for j in range(len(blocks[i].attentions)):
                    for k in range(len(blocks[i].attentions[j].transformer_blocks)):
                        tb = blocks[i].attentions[j].transformer_blocks[k]
                        # FF Layers
                        if moe_layers != "attn":
                            config = {
                                "hidden_size": next(tb.ff.parameters()).size()[-1],
                                "num_experts_per_tok": num_experts_per_tok,
                                "num_local_experts": num_experts,
                            }
                            layers = [tb.ff] * num_experts
                            tb.ff = SparseMoeBlock(config, layers)
                        # Attns
                        if moe_layers != "ff":
                            for attn_name in ("attn1", "attn2"):
                                attn = getattr(tb, attn_name)
                                for proj_name in ("to_q", "to_k", "to_v"):
                                    proj = getattr(attn, proj_name)
                                    config = {
                                        "hidden_size": proj.weight.size()[-1],
                                        "num_experts_per_tok": num_experts_per_tok,
                                        "num_local_experts": num_experts,
                                    }
                                    if attn_name == "attn2" and proj_name in ("to_k", "to_v"):
                                        config["out_dim"] = proj.weight.size()[0]
                                    layers = [proj] * num_experts
                                    setattr(attn, proj_name, SparseMoeBlock(config, layers))
        return unet

    def save_pretrained(self, path):
        """
        Saves the SegMoEPipeline to disk.

        Usage:
            pipeline.save_pretrained(path)

        Parameters:
            path: Path to the directory to save the model in.
        """
        for param in self.pipe.unet.parameters():
            param.data = param.data.contiguous()
        self.pipe.unet.config["segmoe_config"] = self.config
        self.pipe.save_pretrained(path)
        safetensors.torch.save_file(
            self.pipe.unet.state_dict(),
            f"{path}/unet/diffusion_pytorch_model.safetensors",
        )
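    # Editorial note on the activation keys used by cast_hook, get_hidden_states and
    # get_gate_params: every hooked layer is identified by block side ("d" for down_blocks,
    # "u" for up_blocks) plus its attention/transformer-block indices, written as
    # f"{side}{i}a{j}t{k}". FF layers use that suffix directly, while attention projections
    # prepend "sattn" (self-attention, attn1) or "cattn" (cross-attention, attn2) and the
    # projection letter q/k/v, e.g. "cattnvu1a0t2" refers to
    # up_blocks[1].attentions[0].transformer_blocks[2].attn2.to_v. The same keys index
    # gate_params during routing-weight initialization in load_from_scratch.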
    def cast_hook(self, pipe, dicts):
        # Register forward hooks that record the inputs of every FF layer and every
        # self-/cross-attention q/k/v projection, keyed as described above.
        for block_name, prefix, idx_start, idx_end in (
            ("down_blocks", "d", self.down_idx_start, self.down_idx_end),
            ("up_blocks", "u", self.up_idx_start, self.up_idx_end),
        ):
            blocks = getattr(pipe.unet, block_name)
            for i in range(idx_start, idx_end):
                for j in range(len(blocks[i].attentions)):
                    for k in range(len(blocks[i].attentions[j].transformer_blocks)):
                        tb = blocks[i].attentions[j].transformer_blocks[k]
                        suffix = f"{prefix}{i}a{j}t{k}"
                        # FF Layers
                        tb.ff.register_forward_hook(getActivation(dicts, suffix))
                        # Self and Cross Attns
                        for attn_name, attn_tag in (("attn1", "sattn"), ("attn2", "cattn")):
                            attn = getattr(tb, attn_name)
                            for proj_name, proj_tag in (
                                ("to_q", "q"),
                                ("to_k", "k"),
                                ("to_v", "v"),
                            ):
                                getattr(attn, proj_name).register_forward_hook(
                                    getActivation(dicts, f"{attn_tag}{proj_tag}{suffix}")
                                )
    @torch.no_grad()
    def get_hidden_states(self, model, positive, negative, average: bool = True):
        intermediate = {}
        self.cast_hook(model, intermediate)
        with torch.no_grad():
            _ = model(positive, negative_prompt=negative, num_inference_steps=25)
        hidden = {}
        for key in intermediate:
            hidden_states = intermediate[key][0][-1]
            if average:
                # use average over sequence
                hidden_states = hidden_states.sum(dim=0) / hidden_states.shape[0]
            else:
                # take last value
                hidden_states = hidden_states[:-1]
            hidden[key] = hidden_states.to(self.device)
        del intermediate
        gc.collect()
        torch.cuda.empty_cache()
        return hidden
    @torch.no_grad()
    def get_gate_params(
        self,
        experts,
        positive,
        negative,
    ):
        # For every hooked layer, record the hidden state each expert sees for its own
        # prompts, L2-normalize it, and stack the per-expert vectors into the gate weight
        # matrix. At inference the router logits are then dot products between the current
        # hidden state and each expert's prompt activation.
        gate_vects = {}
        for i, expert in enumerate(tqdm.tqdm(experts, desc="Expert Prompts")):
            expert.to(self.device)
            expert.unet.to(
                device=self.device,
                dtype=self.torch_dtype,
                memory_format=torch.channels_last,
            )
            hidden_states = self.get_hidden_states(expert, positive[i], negative[i])
            del expert
            gc.collect()
            torch.cuda.empty_cache()
            for h in hidden_states:
                if i == 0:
                    gate_vects[h] = []
                hidden_states[h] /= (
                    hidden_states[h].norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
                )
                gate_vects[h].append(hidden_states[h])
        for h in gate_vects:
            # (num_experts, hidden_size), matching the weight layout of the gate nn.Linear
            gate_vects[h] = torch.stack(gate_vects[h], dim=0)
        return gate_vects
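# The two sketches below are illustrative only and are never called by the pipeline. They
# show, under assumed placeholder names, sizes, prompts and paths, (1) how SparseMoeBlock
# routes a batch of hidden states, and (2) how SegMoEPipeline is typically created, used
# for generation, and saved. See https://github.com/segmind/segmoe for documented usage.
def _sparse_moe_block_demo():
    torch.manual_seed(0)
    hidden_size, num_experts, top_k = 16, 4, 2
    # Plain linear layers stand in for the FF/attention projections the pipeline wraps.
    experts = [nn.Linear(hidden_size, hidden_size) for _ in range(num_experts)]
    config = {
        "hidden_size": hidden_size,
        "num_local_experts": num_experts,
        "num_experts_per_tok": top_k,
    }
    moe = SparseMoeBlock(config, experts)
    # forward() expects (batch, sequence_length, hidden_size)
    hidden_states = torch.randn(2, 8, hidden_size)
    out = moe(hidden_states)
    print(out.shape)  # torch.Size([2, 8, 16]): sequence layout kept, mapped to out_dim
    return out


def _segmoe_usage_demo():
    # Build a mixture from a YAML config (see the sketch above load_from_scratch), or pass
    # a saved checkpoint directory / Hub id instead of the config path.
    pipeline = SegMoEPipeline("config.yaml", device="cuda")
    image = pipeline(
        "cosmic canvas, orange city background, painting of a chubby cat",
        negative_prompt="nsfw, bad quality, worse quality",
        num_inference_steps=25,
        guidance_scale=7.5,
    ).images[0]
    image.save("image.png")
    # Persist the merged mixture so it can be reloaded without the source experts.
    pipeline.save_pretrained("segmoe_checkpoint")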