# DrUM / wrapper.py
# Committed by Burf (e600401): "change custom pipeline init issue"
import os
import torch
from diffusers import DiffusionPipeline, FluxPipeline, StableDiffusion3Pipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from .model import DrUM as backbone
from .sampling import coreset_sampling
def stable_diffusion(large):
    """
    Build the conditioning function for Stable Diffusion v1.x.

    Text encoder: openai/clip-vit-large-patch14 (CLIPTextModel), hidden state taken
    at ``skip`` (default -1).

    Args:
        large: DrUM backbone wrapping the CLIP-L text encoder.

    Returns:
        ``inference(prompt, ref_prompt=None, weight=None, alpha=0.3, skip=-1, batch_size=64, **kwargs)``
        which returns ``(prompt_embeds, None)``; no pooled embedding is produced.
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, **kwargs):
        options = dict(pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size)
        prompt_embeds = large(prompt, ref_prompt, **options, **kwargs)
        return prompt_embeds, None
    return inference
def stable_diffusion_v2(huge):
    """
    Build the conditioning function for Stable Diffusion v2.x.

    Text encoder: openai/clip-vit-huge-patch14 (CLIPTextModel), hidden state taken
    at ``skip`` (default -1).

    Args:
        huge: DrUM backbone wrapping the CLIP-H text encoder.

    Returns:
        ``inference(prompt, ref_prompt=None, weight=None, alpha=0.3, skip=-1, batch_size=64, **kwargs)``
        which returns ``(prompt_embeds, None)``; no pooled embedding is produced.
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, **kwargs):
        options = dict(pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size)
        prompt_embeds = huge(prompt, ref_prompt, **options, **kwargs)
        return prompt_embeds, None
    return inference
def stable_diffusion_xl(large, bigG):
    """
    Build the conditioning function for Stable Diffusion XL.

    Text encoders:
      * openai/clip-vit-large-patch14, CLIPTextModel, skip -2, unnormalized
      * laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, CLIPTextModelWithProjection,
        skip -2, unnormalized, pooled + projected

    Args:
        large: DrUM backbone wrapping CLIP-L.
        bigG: DrUM backbone wrapping CLIP-bigG (also provides ``projection_text_hidden_state``).

    Returns:
        ``inference(...)`` returning ``(prompt_embeds, pooled_embeds)``. ``prompt_embeds`` is
        the CLIP-L and CLIP-bigG hidden states concatenated on the last dim and cast to the
        pooled dtype. ``pooled_embeds`` is bigG's projected pooled output.
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -2, batch_size = 64, **kwargs):
        shared = dict(weight = weight, alpha = alpha, batch_size = batch_size, normalize = False)
        hidden_l = large(prompt, ref_prompt, pooling = False, skip = skip, **shared, **kwargs)
        if skip == -1:
            # The requested layer is already skip=-1, so one pass yields both outputs.
            hidden_g, pooled = bigG(prompt, ref_prompt, pooling = True, skip = skip, normalize_pool = True, **shared, **kwargs)
        else:
            hidden_g = bigG(prompt, ref_prompt, pooling = False, skip = skip, **shared, **kwargs)
            # The pooled embedding is always taken with skip=-1, independent of `skip`.
            pooled = bigG(prompt, ref_prompt, pooling = True, skip = -1, normalize_pool = True, **shared, **kwargs)[1]
        prompt_embeds = torch.cat([hidden_l, hidden_g], dim = -1)
        pooled = bigG.projection_text_hidden_state(pooled)
        return prompt_embeds.type(pooled.dtype), pooled
    return inference
def stable_diffusion_v3(large, bigG, t5):
    """
    Build the conditioning function for Stable Diffusion 3.5.

    Text encoders:
      * openai/clip-vit-large-patch14, CLIPTextModelWithProjection, skip -2, unnormalized, pooled + projected
      * laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, CLIPTextModelWithProjection, skip -2, unnormalized, pooled + projected
      * t5-v1_1-xxl, T5EncoderModel

    Args:
        large: DrUM backbone wrapping CLIP-L.
        bigG: DrUM backbone wrapping CLIP-bigG.
        t5: DrUM backbone wrapping T5.

    Returns:
        ``inference(...)`` returning ``(prompt_embeds, pooled_embeds)``.
        ``prompt_embeds`` is built in three steps: the CLIP-L and CLIP-bigG hidden states are
        concatenated on the last dim, zero-padded on the last dim to the T5 width, and then
        concatenated with the T5 hidden state on dim -2. ``pooled_embeds`` is the two
        projected CLIP pooled outputs concatenated on the last dim.
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -2, batch_size = 64, **kwargs):
        shared = dict(weight = weight, alpha = alpha, batch_size = batch_size, normalize = False)
        if skip == -1:
            # One pass per CLIP encoder yields both the hidden state and the pooled output.
            hidden_l, pooled_l = large(prompt, ref_prompt, pooling = True, skip = skip, normalize_pool = True, **shared, **kwargs)
            hidden_g, pooled_g = bigG(prompt, ref_prompt, pooling = True, skip = skip, normalize_pool = True, **shared, **kwargs)
        else:
            hidden_l = large(prompt, ref_prompt, pooling = False, skip = skip, **shared, **kwargs)
            hidden_g = bigG(prompt, ref_prompt, pooling = False, skip = skip, **shared, **kwargs)
            # Pooled outputs always come from skip=-1.
            pooled_l = large(prompt, ref_prompt, pooling = True, skip = -1, normalize_pool = True, **shared, **kwargs)[1]
            pooled_g = bigG(prompt, ref_prompt, pooling = True, skip = -1, normalize_pool = True, **shared, **kwargs)[1]
        hidden_t5 = t5(prompt, ref_prompt, pooling = False, **shared, **kwargs)
        clip_embeds = torch.cat([hidden_l, hidden_g], dim = -1)
        pooled_l = large.projection_text_hidden_state(pooled_l)
        pooled_g = bigG.projection_text_hidden_state(pooled_g)
        # Zero-pad the CLIP features on the last dim to the T5 width so both can be stacked on dim -2.
        width_gap = hidden_t5.shape[-1] - clip_embeds.shape[-1]
        clip_embeds = torch.nn.functional.pad(clip_embeds, (0, width_gap))
        prompt_embeds = torch.cat([clip_embeds, hidden_t5], dim = -2)
        pooled = torch.cat([pooled_l, pooled_g], dim = -1)
        return prompt_embeds.type(pooled.dtype), pooled
    return inference
def flux(large, t5):
    """
    Build the conditioning function for FLUX.

    Text encoders:
      * openai/clip-vit-large-patch14, CLIPTextModel, pooled output only
      * t5-v1_1-xxl, T5EncoderModel, hidden states

    Args:
        large: DrUM backbone wrapping CLIP-L.
        t5: DrUM backbone wrapping T5.

    Returns:
        ``inference(...)`` returning ``(prompt_embeds, pooled_embeds)``. ``prompt_embeds`` is
        the T5 hidden state (cast to the pooled dtype). ``pooled_embeds`` is CLIP-L's pooled
        output (no projection is applied).
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = None, batch_size = 64, **kwargs):
        shared = dict(weight = weight, alpha = alpha, batch_size = batch_size, normalize = False)
        prompt_embeds = t5(prompt, ref_prompt, pooling = False, **shared, **kwargs)
        # `skip` is accepted only for signature parity with the other builders; CLIP-L is always queried with skip=-1.
        clip_out = large(prompt, ref_prompt, pooling = True, skip = -1, normalize_pool = True, **shared, **kwargs)
        pooled = clip_out[1]
        return prompt_embeds.type(pooled.dtype), pooled
    return inference
def _load_encoder(model, processor, weight_path, load_func, device, n_layer, **kwargs):
    """
    Wrap one text encoder in a DrUM backbone and load its adapter weights.

    Args:
        model: text encoder module taken from the pipeline.
        processor: the matching tokenizer.
        weight_path: file holding the adapter ``state_dict``.
        load_func: callable that maps ``weight_path`` to a state dict.
        device: device the backbone is moved to.
        n_layer: forwarded to ``backbone``.
        **kwargs: extra ``backbone`` options (e.g. ``encode_ratio``).

    Returns:
        The backbone, switched to eval mode.
    """
    encoder = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0, **kwargs).to(device).eval()
    encoder.adapter.load_state_dict(load_func(weight_path))
    return encoder


def _set_base_query(encoder, with_pool = True):
    """
    Seed ``encoder.adapter``'s base query from the encoding of the empty prompt.

    With ``with_pool`` the pooled output (``pool.unsqueeze(1)``) is appended to the
    unpooled encoding along dim 1; otherwise only the unpooled encoding is used.
    """
    empty, pool = encoder.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
    query = torch.cat([empty, pool.unsqueeze(1)], dim = 1) if with_pool else empty
    encoder.adapter.set_base_query(query)


def peca(pipeline, save_path = "./weight", n_layer = 10):
    """
    Build the DrUM conditioning function for a loaded diffusers pipeline.

    The pipeline family is detected from the last path component of
    ``pipeline.config._name_or_path``. The substrings "flux", "stable-diffusion-3.5",
    "xl-base" and "stable-diffusion-2" are checked in that order; anything else is
    treated as SD v1.x. For each text encoder of that family:
      * a DrUM backbone is built,
      * its adapter weights are loaded from ``save_path``,
      * its base query is seeded from the empty prompt.

    Args:
        pipeline: diffusers pipeline whose text encoders/tokenizers are wrapped.
        save_path: directory with adapter weights named ``L``, ``H``, ``bigG``, ``T5``.
            ``.pth`` files are used when ``L.pth`` or ``H.pth`` exists, otherwise ``.safetensors``.
        n_layer: forwarded to every backbone.

    Returns:
        ``(encoder, get_text_feature, size, num_inference_steps, skip)``: the conditioning
        function, the first encoder's ``get_text_feature``, and the family's default image
        size, step count and skip.
    """
    if os.path.exists(os.path.join(save_path, "L.pth")) or os.path.exists(os.path.join(save_path, "H.pth")):
        # NOTE(review): torch.load unpickles the file; only point save_path at trusted checkpoints.
        def load_func(path):
            # map_location lets CUDA-saved checkpoints load on CPU-only hosts (load_pipeline
            # falls back to CPU); load_state_dict then copies onto the adapter's device.
            return torch.load(path, map_location = "cpu")
        postfix = "pth"
    else:
        from safetensors.torch import load_file as load_func
        postfix = "safetensors"

    def build(model, processor, stem, **kwargs):
        # Load "<stem>.<postfix>" from save_path into a freshly built backbone.
        path = os.path.join(save_path, "{0}.{1}".format(stem, postfix))
        return _load_encoder(model, processor, path, load_func, pipeline.device, n_layer, **kwargs)

    name = pipeline.config._name_or_path.split("/")[-1].lower()
    if "flux" in name:
        large = build(pipeline.text_encoder, pipeline.tokenizer, "L")
        t5 = build(pipeline.text_encoder_2, pipeline.tokenizer_2, "T5", encode_ratio = 4)
        _set_base_query(large)
        _set_base_query(t5, with_pool = False)
        feature_encoder = large
        encoder = flux(large, t5)
        size, num_inference_steps, skip = 1024, 28, -2
    elif "stable-diffusion-3.5" in name: #sd v3
        large = build(pipeline.text_encoder, pipeline.tokenizer, "L")
        bigG = build(pipeline.text_encoder_2, pipeline.tokenizer_2, "bigG")
        t5 = build(pipeline.text_encoder_3, pipeline.tokenizer_3, "T5", encode_ratio = 4)
        _set_base_query(large)
        _set_base_query(bigG)
        _set_base_query(t5, with_pool = False)
        feature_encoder = large
        encoder = stable_diffusion_v3(large, bigG, t5)
        size, num_inference_steps, skip = 1024, 28, -2
    elif "xl-base" in name: #sd xl
        large = build(pipeline.text_encoder, pipeline.tokenizer, "L")
        bigG = build(pipeline.text_encoder_2, pipeline.tokenizer_2, "bigG")
        _set_base_query(large)
        _set_base_query(bigG)
        feature_encoder = large
        encoder = stable_diffusion_xl(large, bigG)
        size, num_inference_steps, skip = 1024, 50, -2
    elif "stable-diffusion-2" in name:
        huge = build(pipeline.text_encoder, pipeline.tokenizer, "H")
        _set_base_query(huge)
        feature_encoder = huge
        encoder = stable_diffusion_v2(huge)
        size, num_inference_steps, skip = 768, 50, -1
    else: #sd
        large = build(pipeline.text_encoder, pipeline.tokenizer, "L")
        _set_base_query(large)
        feature_encoder = large
        encoder = stable_diffusion(large)
        size, num_inference_steps, skip = 512, 50, -1
    return encoder, feature_encoder.get_text_feature, size, num_inference_steps, skip
class DrUM(DiffusionPipeline):
    """
    DrUM personalization wrapper around a diffusers text-to-image pipeline.

    The pipeline's text encoders are wrapped with DrUM adapters (see ``peca``). Calling the
    instance generates images conditioned on a prompt plus optional reference prompts.
    """

    def __init__(self, pipeline, repo_id = "Burf/DrUM", weight = None, torch_dtype = torch.bfloat16, device = "cuda"):
        """
        DrUM for various T2I diffusion models

        Args:
            pipeline: a loaded diffusers pipeline, or a model id/path string passed to ``load_pipeline``.
            repo_id: Hugging Face Hub repo holding the adapter weights under ``weight/``.
            weight: optional local directory with adapter weights. It is used only when it holds
                every file the pipeline needs; otherwise the weights are downloaded from ``repo_id``.
            torch_dtype: dtype used when ``pipeline`` is a string.
            device: target device when ``pipeline`` is a string (CPU if CUDA is unavailable).
        """
        self.pipeline = pipeline if not isinstance(pipeline, str) else self.load_pipeline(pipeline, torch_dtype = torch_dtype, device = device)
        self.repo_id = repo_id
        self.adapter, self.feature_encoder, self.size, self.num_inference_steps, self.skip = self.load_peca(self.pipeline, repo_id, weight)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, repo_id = "Burf/DrUM", torch_dtype = torch.bfloat16, device = "cuda", weight=None):
        """
        Load DrUM adapter with appropriate pipeline
        """
        pipeline = cls.load_pipeline(pretrained_model_name_or_path, torch_dtype, device)
        return cls(pipeline = pipeline, repo_id = repo_id, weight = weight, torch_dtype = torch_dtype, device = device)

    @staticmethod
    def load_pipeline(model_id, torch_dtype = torch.bfloat16, device = "cuda"):
        """
        Load a diffusers pipeline for ``model_id``, choosing the class from its name.

        FLUX and SD 3.5 use their dedicated pipeline classes; everything else goes through
        ``DiffusionPipeline``. The pipeline is moved to ``device``, or to CPU when CUDA is unavailable.
        """
        name = model_id.split("/")[-1].lower()
        if "flux" in name:
            pipeline = FluxPipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
        elif "stable-diffusion-3.5" in name:
            pipeline = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
        else:
            pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
        pipeline = pipeline.to(device if torch.cuda.is_available() else "cpu")
        #pipeline.safety_checker = lambda images, clip_input: (images, [False] * len(images))
        return pipeline

    def load_weight(self, pipeline, repo_id = "Burf/DrUM", weight = None):
        """
        Resolve the directory that holds the adapter weights for ``pipeline``.

        ``weight`` is returned as-is when it is a directory containing every adapter file the
        pipeline family needs, either all as ``.safetensors`` or all as ``.pth``. Otherwise
        every file is fetched from ``repo_id`` (under ``weight/``) and the download directory
        is returned. Previously, a partially-local directory could be returned while some
        files existed only in the download cache.
        """
        import warnings
        name = pipeline.config._name_or_path.split("/")[-1].lower()
        if "flux" in name:
            stems = ["L", "T5"]
        elif "stable-diffusion-3.5" in name:
            stems = ["L", "bigG", "T5"]
        elif "xl-base" in name:
            stems = ["L", "bigG"]
        elif "stable-diffusion-2" in name:
            stems = ["H"]
        else: # SD v1.5
            stems = ["L"]

        if isinstance(weight, str):
            # peca loads every file with a single postfix, so the local dir must be complete in one format.
            for postfix in ("safetensors", "pth"):
                if all(os.path.exists(os.path.join(weight, "{0}.{1}".format(stem, postfix))) for stem in stems):
                    return weight
            warnings.warn("DrUM: '{0}' does not contain all adapter weights {1}; downloading them from '{2}' instead.".format(weight, stems, repo_id))

        weight_path = None
        for stem in stems:
            safetensor_path = hf_hub_download(repo_id = repo_id, filename = "weight/{0}.safetensors".format(stem))
            weight_path = os.path.dirname(safetensor_path)
        return weight_path

    def load_peca(self, pipeline, repo_id = "Burf/DrUM", weight = None):
        """Resolve the weight directory and build the DrUM encoder via ``peca``."""
        adapter, feature_encoder, size, num_inference_steps, skip = peca(pipeline, save_path = self.load_weight(pipeline, repo_id, weight))
        return adapter, feature_encoder, size, num_inference_steps, skip

    def __call__(self, prompt, ref = None, weight = None, alpha = 0.3, skip = None, sampling = False, seed = 42,
                 size = None, num_inference_steps = None, num_images_per_prompt = 1):
        """
        Generate images using DrUM adapter

        Args:
            prompt: Text prompt for generation
            ref: Reference prompts (list of strings)
            weight: Weights for reference prompts (list of floats)
            alpha: Personalization strength (0-1)
            skip: Text condition axis
            sampling: Whether to use coreset sampling for reference selection (default: False).
                When enabled, ``weight`` is subsampled with the same indices as ``ref`` if it
                has one entry per reference.
            seed: Random seed
            size: Image size
            num_inference_steps: Inference steps
            num_images_per_prompt: Number of images to generate
        Returns:
            Personalized images (list of PIL Images)
        """
        size = self.size if size is None else size
        num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
        skip = self.skip if skip is None else skip
        if sampling and isinstance(ref, (tuple, list)) and 1 < len(ref):
            import numpy as np
            with torch.no_grad():
                feature = self.feature_encoder(ref).cpu().float().numpy()
            indices = coreset_sampling(feature, weight = weight, seed = seed)
            # Reindex weights BEFORE `ref` is replaced by its subset: the length check must compare
            # against the original reference list, or weights stay un-subsampled whenever the
            # coreset is smaller than `ref`.
            if isinstance(weight, (tuple, list)) and len(weight) == len(ref):
                weight = np.array(weight)[indices].tolist()
            ref = np.array(ref)[indices].tolist()
        generator = torch.Generator(self.pipeline.device).manual_seed(seed)
        with torch.no_grad():
            cond, pool_cond = self.adapter(prompt, ref, weight = weight, alpha = alpha, skip = skip)
        pipe_kwargs = {
            "num_images_per_prompt": num_images_per_prompt,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "height": size,
            "width": size
        }
        pipe_kwargs["prompt_embeds"] = cond.type(self.pipeline.dtype)
        # SD v1/v2 produce no pooled embedding, so the pipeline must not receive one.
        if pool_cond is not None:
            pipe_kwargs["pooled_prompt_embeds"] = pool_cond.type(self.pipeline.dtype)
        name = self.pipeline.config._name_or_path.split("/")[-1].lower()
        if "flux" in name or "stable-diffusion-3" in name:
            pipe_kwargs["max_sequence_length"] = 256
        images = self.pipeline(**pipe_kwargs).images
        return images