import os

import torch

from diffusers import DiffusionPipeline, FluxPipeline, StableDiffusion3Pipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from .model import DrUM as backbone
from .sampling import coreset_sampling

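# Each factory below wraps one or more DrUM text-encoder backbones and returns an
# `inference(prompt, ref_prompt, ...)` closure that maps the prompt and the weighted
# reference prompts to the (prompt embeddings, pooled embeddings) pair consumed by the
# corresponding diffusers pipeline (the pooled part is None for SD 1.x / 2.x).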
def stable_diffusion(large):
    """
    openai/clip-vit-large-patch14, CLIPTextModel, skip -1
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, **kwargs):
        return large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, **kwargs), None
    return inference

def stable_diffusion_v2(huge):
    """
    laion/CLIP-ViT-H-14-laion2B-s32B-b79K, CLIPTextModel, skip -1
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, **kwargs):
        return huge(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, **kwargs), None
    return inference

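# SDXL conditions on two text encoders: the per-token states of CLIP ViT-L and OpenCLIP
# ViT-bigG are concatenated along the channel axis, while the pooled embedding comes from
# the projected bigG pool only. When skip != -1, the pooled embedding is still computed
# from the final layer (skip = -1); only the per-token states use the earlier layer.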
def stable_diffusion_xl(large, bigG):
    """
    openai/clip-vit-large-patch14, CLIPTextModel, skip -2, unnorm
    laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, CLIPTextModelWithProjection, skip -2, unnorm, pooling + proj
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -2, batch_size = 64, **kwargs):
        hidden_state = large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
        if skip == -1:
            hidden_state2, pool_hidden_state = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
        else:
            hidden_state2 = bigG(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
            pool_hidden_state = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
        hidden_state = torch.cat([hidden_state, hidden_state2], dim = -1)
        pool_hidden_state = bigG.projection_text_hidden_state(pool_hidden_state)
        return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
    return inference

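# SD3 concatenates the CLIP-L and bigG token states along the channel axis, zero-pads the
# result to the T5 hidden width, then appends the T5 token states along the sequence axis;
# the pooled embedding is the concatenation of the two projected CLIP pooled vectors.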
def stable_diffusion_v3(large, bigG, t5):
    """
    openai/clip-vit-large-patch14, CLIPTextModelWithProjection, skip -2, unnorm, pooling + proj
    laion/CLIP-ViT-bigG-14-laion2B-39B-b160k, CLIPTextModelWithProjection, skip -2, unnorm, pooling + proj
    t5-v1_1-xxl, T5EncoderModel
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = -2, batch_size = 64, **kwargs):
        if skip == -1:
            hidden_state, pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
            hidden_state2, pool_hidden_state2 = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)
        else:
            hidden_state = large(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
            hidden_state2 = bigG(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, skip = skip, batch_size = batch_size, normalize = False, **kwargs)
            pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
            pool_hidden_state2 = bigG(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
        hidden_state3 = t5(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, batch_size = batch_size, normalize = False, **kwargs)
        hidden_state = torch.cat([hidden_state, hidden_state2], dim = -1)
        pool_hidden_state = large.projection_text_hidden_state(pool_hidden_state)
        pool_hidden_state2 = bigG.projection_text_hidden_state(pool_hidden_state2)
        hidden_state = torch.nn.functional.pad(hidden_state, (0, hidden_state3.shape[-1] - hidden_state.shape[-1]))
        hidden_state = torch.cat([hidden_state, hidden_state3], dim = -2)
        pool_hidden_state = torch.cat([pool_hidden_state, pool_hidden_state2], dim = -1)
        return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
    return inference

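# FLUX takes its sequence conditioning from T5 only; the CLIP ViT-L encoder contributes
# just the pooled embedding, so no layer selection (skip) is applied to the T5 branch.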
def flux(large, t5):
    """
    openai/clip-vit-large-patch14, CLIPTextModel, pooling
    t5-v1_1-xxl, T5EncoderModel
    """
    def inference(prompt, ref_prompt = None, weight = None, alpha = 0.3, skip = None, batch_size = 64, **kwargs):
        hidden_state = t5(prompt, ref_prompt, pooling = False, weight = weight, alpha = alpha, batch_size = batch_size, normalize = False, **kwargs)
        pool_hidden_state = large(prompt, ref_prompt, pooling = True, weight = weight, alpha = alpha, skip = -1, batch_size = batch_size, normalize = False, normalize_pool = True, **kwargs)[1]
        return hidden_state.type(pool_hidden_state.dtype), pool_hidden_state
    return inference

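# peca() builds the DrUM adapter stack for a given diffusers pipeline: it selects the
# matching text encoders, loads the per-encoder adapter weights (L / bigG / H / T5) from
# `save_path`, registers the empty-prompt embedding as each adapter's base query, and
# returns (encoder, text-feature function, default size, default steps, default skip).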
def peca(pipeline, save_path = "./weight", n_layer = 10):
    if os.path.exists(os.path.join(save_path, "L.pth")) or os.path.exists(os.path.join(save_path, "H.pth")):
        load_func = torch.load
        postfix = "pth"
    else:
        from safetensors.torch import load_file as load_func
        postfix = "safetensors"

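    # FLUX family: CLIP ViT-L (pooled) + T5 text conditioning, 1024px, 28 steps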
if "flux" in pipeline.config._name_or_path.split("/")[-1].lower(): |
|
model = pipeline.text_encoder |
|
processor = pipeline.tokenizer |
|
model2 = pipeline.text_encoder_2 |
|
processor2 = pipeline.tokenizer_2 |
|
|
|
large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval() |
|
large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix)))) |
|
t5 = backbone(model2, processor2, n_layer = n_layer, encode_ratio = 4, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval() |
|
t5.adapter.load_state_dict(load_func(os.path.join(save_path, "T5.{0}".format(postfix)))) |
|
empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False) |
|
large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1)) |
|
empty, pool = t5.encode_prompt("", pooling = True, normalize = False, normalize_pool = False) |
|
t5.adapter.set_base_query(empty) |
|
|
|
feature_encoder = large |
|
encoder = flux(large, t5) |
|
size = 1024 |
|
num_inference_steps = 28 |
|
skip = -2 |
|
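    # Stable Diffusion 3.5 family: CLIP ViT-L + OpenCLIP bigG + T5, 1024px, 28 steps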
elif "stable-diffusion-3.5" in pipeline.config._name_or_path.split("/")[-1].lower(): |
|
model = pipeline.text_encoder |
|
processor = pipeline.tokenizer |
|
model2 = pipeline.text_encoder_2 |
|
processor2 = pipeline.tokenizer_2 |
|
model3 = pipeline.text_encoder_3 |
|
processor3 = pipeline.tokenizer_3 |
|
|
|
large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval() |
|
large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix)))) |
|
bigG = backbone(model2, processor2, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval() |
|
bigG.adapter.load_state_dict(load_func(os.path.join(save_path, "bigG.{0}".format(postfix)))) |
|
t5 = backbone(model3, processor3, n_layer = n_layer, encode_ratio = 4, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval() |
|
t5.adapter.load_state_dict(load_func(os.path.join(save_path, "T5.{0}".format(postfix)))) |
|
empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False) |
|
large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1)) |
|
empty, pool = bigG.encode_prompt("", pooling = True, normalize = False, normalize_pool = False) |
|
bigG.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1)) |
|
empty, pool = t5.encode_prompt("", pooling = True, normalize = False, normalize_pool = False) |
|
t5.adapter.set_base_query(empty) |
|
|
|
feature_encoder = large |
|
encoder = stable_diffusion_v3(large, bigG, t5) |
|
size = 1024 |
|
num_inference_steps = 28 |
|
skip = -2 |
|
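    # Stable Diffusion XL family: CLIP ViT-L + OpenCLIP bigG, 1024px, 50 steps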
elif "xl-base" in pipeline.config._name_or_path.split("/")[-1].lower(): |
|
model = pipeline.text_encoder |
|
processor = pipeline.tokenizer |
|
model2 = pipeline.text_encoder_2 |
|
processor2 = pipeline.tokenizer_2 |
|
|
|
large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval() |
|
large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix)))) |
|
bigG = backbone(model2, processor2, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval() |
|
bigG.adapter.load_state_dict(load_func(os.path.join(save_path, "bigG.{0}".format(postfix)))) |
|
empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False) |
|
large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1)) |
|
empty, pool = bigG.encode_prompt("", pooling = True, normalize = False, normalize_pool = False) |
|
bigG.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1)) |
|
|
|
feature_encoder = large |
|
encoder = stable_diffusion_xl(large, bigG) |
|
size = 1024 |
|
num_inference_steps = 50 |
|
skip = -2 |
|
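    # Stable Diffusion 2.x family: single ViT-H text encoder, 768px, 50 steps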
elif "stable-diffusion-2" in pipeline.config._name_or_path.split("/")[-1].lower(): |
|
model = pipeline.text_encoder |
|
processor = pipeline.tokenizer |
|
|
|
huge = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval() |
|
huge.adapter.load_state_dict(load_func(os.path.join(save_path, "H.{0}".format(postfix)))) |
|
empty, pool = huge.encode_prompt("", pooling = True, normalize = False, normalize_pool = False) |
|
huge.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1)) |
|
|
|
feature_encoder = huge |
|
encoder = stable_diffusion_v2(huge) |
|
size = 768 |
|
num_inference_steps = 50 |
|
skip = -1 |
|
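    # Default (Stable Diffusion 1.x and other CLIP ViT-L pipelines): 512px, 50 steps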
    else:
        model = pipeline.text_encoder
        processor = pipeline.tokenizer

        large = backbone(model, processor, n_layer = n_layer, pos = False, cls_pos = False, dropout = 0.0).to(pipeline.device).eval()
        large.adapter.load_state_dict(load_func(os.path.join(save_path, "L.{0}".format(postfix))))
        empty, pool = large.encode_prompt("", pooling = True, normalize = False, normalize_pool = False)
        large.adapter.set_base_query(torch.cat([empty, pool.unsqueeze(1)], dim = 1))

        feature_encoder = large
        encoder = stable_diffusion(large)
        size = 512
        num_inference_steps = 50
        skip = -1
    return encoder, feature_encoder.get_text_feature, size, num_inference_steps, skip

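# DiffusionPipeline wrapper that attaches the matching DrUM adapters to a base
# text-to-image pipeline and swaps its prompt embeddings for the personalized ones
# at generation time.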
class DrUM(DiffusionPipeline):
    def __init__(self, pipeline, repo_id = "Burf/DrUM", weight = None, torch_dtype = torch.bfloat16, device = "cuda"):
        """
        DrUM for various T2I diffusion models
        """
        self.pipeline = pipeline if not isinstance(pipeline, str) else self.load_pipeline(pipeline, torch_dtype = torch_dtype, device = device)
        self.repo_id = repo_id

        self.adapter, self.feature_encoder, self.size, self.num_inference_steps, self.skip = self.load_peca(self.pipeline, repo_id, weight)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, repo_id = "Burf/DrUM", torch_dtype = torch.bfloat16, device = "cuda", weight = None):
        """
        Load DrUM adapter with appropriate pipeline
        """
        pipeline = cls.load_pipeline(pretrained_model_name_or_path, torch_dtype, device)
        return cls(pipeline = pipeline, repo_id = repo_id, weight = weight, torch_dtype = torch_dtype, device = device)

    @staticmethod
    def load_pipeline(model_id, torch_dtype = torch.bfloat16, device = "cuda"):
        name = model_id.split("/")[-1].lower()
        if "flux" in name:
            pipeline = FluxPipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
        elif "stable-diffusion-3.5" in name:
            pipeline = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype = torch_dtype)
        else:
            pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype = torch_dtype)

        pipeline = pipeline.to(device if torch.cuda.is_available() else "cpu")

        return pipeline

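    # Resolve the adapter weight directory: prefer a local directory passed via `weight`,
    # otherwise download the required safetensors files from the Hub repo (repo_id).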
    def load_weight(self, pipeline, repo_id = "Burf/DrUM", weight = None):
        name = pipeline.config._name_or_path.split("/")[-1].lower()

        weights = []
        if "flux" in name:
            weights = ["L.safetensors", "T5.safetensors"]
        elif "stable-diffusion-3.5" in name:
            weights = ["L.safetensors", "bigG.safetensors", "T5.safetensors"]
        elif "xl-base" in name:
            weights = ["L.safetensors", "bigG.safetensors"]
        elif "stable-diffusion-2" in name:
            weights = ["H.safetensors"]
        else:
            weights = ["L.safetensors"]

        for weight_file in weights:
            if isinstance(weight, str) and os.path.exists(os.path.join(weight, weight_file)):
                weight_path = weight
                break
            else:
                safetensor_path = hf_hub_download(repo_id = repo_id, filename = "weight/" + weight_file)
                weight_path = os.path.dirname(safetensor_path)
        return weight_path

    def load_peca(self, pipeline, repo_id = "Burf/DrUM", weight = None):
        adapter, feature_encoder, size, num_inference_steps, skip = peca(pipeline, save_path = self.load_weight(pipeline, repo_id, weight))
        return adapter, feature_encoder, size, num_inference_steps, skip

    def __call__(self, prompt, ref = None, weight = None, alpha = 0.3, skip = None, sampling = False, seed = 42,
                 size = None, num_inference_steps = None, num_images_per_prompt = 1):
        """
        Generate images using DrUM adapter

        Args:
            prompt: Text prompt for generation
            ref: Reference prompts (list of strings)
            weight: Weights for reference prompts (list of floats)
            alpha: Personalization strength (0-1)
            skip: Text-encoder hidden layer used for conditioning (e.g. -1 for the last layer, -2 for the penultimate one)
            sampling: Whether to use coreset sampling for reference selection (default: False)
            seed: Random seed
            size: Image size
            num_inference_steps: Inference steps
            num_images_per_prompt: Number of images to generate

        Returns:
            Personalized images (list of PIL Images)
        """
        size = self.size if size is None else size
        num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
        skip = self.skip if skip is None else skip

        if sampling and isinstance(ref, (tuple, list)) and 1 < len(ref):
            import numpy as np

            with torch.no_grad():
                feature = self.feature_encoder(ref).cpu().float().numpy()

            indices = coreset_sampling(feature, weight = weight, seed = seed)

            if isinstance(weight, (tuple, list)) and len(weight) == len(ref):
                weight = np.array(weight)[indices].tolist()

            ref = np.array(ref)[indices].tolist()

        generator = torch.Generator(self.pipeline.device).manual_seed(seed)
        with torch.no_grad():
            cond, pool_cond = self.adapter(prompt, ref, weight = weight, alpha = alpha, skip = skip)

        pipe_kwargs = {
            "num_images_per_prompt": num_images_per_prompt,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "height": size,
            "width": size
        }

        pipe_kwargs["prompt_embeds"] = cond.type(self.pipeline.dtype)
        if pool_cond is not None:
            pipe_kwargs["pooled_prompt_embeds"] = pool_cond.type(self.pipeline.dtype)

        name = self.pipeline.config._name_or_path.split("/")[-1].lower()
        if "flux" in name or "stable-diffusion-3" in name:
            pipe_kwargs["max_sequence_length"] = 256

        images = self.pipeline(**pipe_kwargs).images
        return images
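
# Illustrative usage (a sketch only; the model id, prompts, and weights below are examples,
# not part of this module):
#
#   drum = DrUM.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
#   images = drum("a cozy cabin in the woods",
#                 ref = ["oil painting", "warm golden lighting"],
#                 weight = [0.7, 0.3], alpha = 0.3)
#   images[0].save("drum_output.png")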