# MoodSpace / my_ipadapter_model.py
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
from PIL import Image
torch.backends.cuda.enable_cudnn_sdp(False) # a fix for torch 2.5.0
from ip_adapter import IPAdapterPlus
from ip_adapter import IPAdapter
# %%
def image_grid(imgs, rows, cols):
    # Tile a list of equally sized PIL images into a rows x cols grid.
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
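# %%
# Illustrative usage sketch (not part of the original script): tile four
# solid-colour placeholder images into a 2x2 grid; the colours and the
# 64x64 size are assumptions made up for the demo.
def _example_image_grid():
    placeholders = [Image.new("RGB", (64, 64), color=c)
                    for c in ("red", "green", "blue", "white")]
    return image_grid(placeholders, rows=2, cols=2)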
@torch.inference_mode()
def extract_clip_embedding_pil(pil_image, ip_model):
    # Encode a PIL image with the IP-Adapter's CLIP image encoder and return the
    # penultimate hidden states (the features IP-Adapter Plus conditions on).
    clip_image = ip_model.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
    clip_image = clip_image.to(ip_model.device, dtype=torch.float16)
    clip_image_embeds = ip_model.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
    clip_image_embeds = clip_image_embeds.float()
    return clip_image_embeds
def extract_clip_embedding_pil_batch(pil_images, ip_model):
    # Encode a list of PIL images one by one and concatenate along the batch dimension.
    feats = []
    for image in pil_images:
        feats.append(extract_clip_embedding_pil(image, ip_model))
    feats = torch.cat(feats, dim=0)
    return feats
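# %%
# Illustrative sketch (assumes an ip_model returned by load_ipadapter() below
# and a list of image paths; the paths are placeholders, not from the original file).
def _example_extract_batch(ip_model, image_paths):
    pil_images = [Image.open(p).convert("RGB") for p in image_paths]
    feats = extract_clip_embedding_pil_batch(pil_images, ip_model)
    # feats: (num_images, sequence_length, hidden_dim) float32 CLIP features.
    return feats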
@torch.inference_mode()
def extract_clip_embedding_tensor(tensor_image, ip_model):
    # Same as above, but for an image batch that is already an (N, 3, H, W) tensor.
    # Note: this path only resizes to 224x224; it assumes the tensor is already
    # CLIP-normalized (see clip_image_processor for the expected mean/std).
    tensor_image = tensor_image.to(ip_model.device, dtype=torch.float16)
    tensor_image = torch.nn.functional.interpolate(tensor_image, size=(224, 224), mode="bilinear", align_corners=False)
    clip_image_embeds = ip_model.image_encoder(tensor_image, output_hidden_states=True).hidden_states[-2]
    clip_image_embeds = clip_image_embeds.float()
    return clip_image_embeds
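# %%
# Illustrative sketch (an assumption, not from the original file): one way to
# build a CLIP-normalized tensor batch from equally sized PIL images so it can
# be fed to extract_clip_embedding_tensor, reusing the processor's mean/std.
def _example_pil_to_clip_tensor(pil_images, ip_model):
    mean = torch.tensor(ip_model.clip_image_processor.image_mean).view(1, 3, 1, 1)
    std = torch.tensor(ip_model.clip_image_processor.image_std).view(1, 3, 1, 1)
    batch = torch.stack([
        torch.from_numpy(np.array(img.convert("RGB"))).permute(2, 0, 1).float() / 255.0
        for img in pil_images
    ])
    return (batch - mean) / std  # (N, 3, H, W); resized to 224x224 inside the extractor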
@torch.inference_mode()
def _myheck_ipadapter_get_image_embeds(self, pil_image=None, clip_image_embeds=None):
    # Monkey-patched replacement for IPAdapterPlus.get_image_embeds: unlike the
    # upstream method, it also accepts precomputed clip_image_embeds, so generate()
    # below can work directly from CLIP features without a PIL image.
    if pil_image is not None:
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
    image_prompt_embeds = self.image_proj_model(clip_image_embeds)
    # Unconditional branch: features of an all-zero 224x224 image, used for classifier-free guidance.
    uncond_clip_image_embeds = self.image_encoder(
        torch.zeros(1, 3, 224, 224).to(self.device, dtype=torch.float16),
        output_hidden_states=True,
    ).hidden_states[-2]
    uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
    return image_prompt_embeds, uncond_image_prompt_embeds
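# %%
# Illustrative sketch: once load_ipadapter() installs the patch above, the
# method can be called with precomputed CLIP features instead of a PIL image
# (the half()/device handling mirrors what generate() does further below).
def _example_patched_get_image_embeds(ip_model, clip_embeds):
    cond, uncond = ip_model.get_image_embeds(
        clip_image_embeds=clip_embeds.half().to(ip_model.device))
    # cond / uncond become the extra image tokens fed to the IP-Adapter cross-attention.
    return cond, uncond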
@torch.inference_mode()
def load_sdxl():
    # Note: despite the name, this builds a Stable Diffusion 1.5 pipeline
    # (Realistic Vision V4.0 with the ft-MSE VAE), not an SDXL one.
    base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
    vae_model_path = "stabilityai/sd-vae-ft-mse"
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
    # load SD pipeline
    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )
    return pipe
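# %%
# Illustrative sketch (prompt, step count and device are placeholders): plain
# text-to-image with the pipeline above, without any IP-Adapter conditioning.
def _example_text_to_image(prompt="a photo of a cozy living room, soft light"):
    pipe = load_sdxl().to("cuda")
    return pipe(prompt, num_inference_steps=30).images[0]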
@torch.inference_mode()
def load_ipadapter(device="cuda"):
    # Build the SD 1.5 pipeline, wrap it with IP-Adapter Plus (16 image tokens),
    # then monkey-patch get_image_embeds so precomputed CLIP features can be used.
    base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
    vae_model_path = "stabilityai/sd-vae-ft-mse"
    image_encoder_path = "./downloads/models/image_encoder"
    ip_ckpt = "./downloads/models/ip-adapter-plus_sd15.bin"
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
    # load SD pipeline
    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )
    # load ip-adapter
    ip_model = IPAdapterPlus(pipe, image_encoder_path, ip_ckpt, device, num_tokens=16)
    setattr(ip_model.__class__, "get_image_embeds", _myheck_ipadapter_get_image_embeds)
    return ip_model
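# %%
# Illustrative sketch (the reference file name, sample count and step count are
# placeholders): end-to-end image-prompted generation straight from a reference
# picture, using the patched model with a PIL image.
def _example_generate_from_reference(reference_path="reference.jpg"):
    ip_model = load_ipadapter(device="cuda")
    reference = Image.open(reference_path).convert("RGB")
    images = ip_model.generate(pil_image=reference, num_samples=2,
                               num_inference_steps=30, seed=0)
    return image_grid(images, rows=1, cols=2)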
@torch.inference_mode()
def generate(ip_model, clip_embeds, num_samples=4, num_inference_steps=50, seed=42):
    # Generate images conditioned on a single precomputed CLIP embedding of shape
    # (1, seq_len, hidden_dim); a 2-D input is treated as an unbatched embedding.
    if clip_embeds.ndim == 2:
        clip_embeds = clip_embeds.unsqueeze(0)
    assert clip_embeds.ndim == 3
    assert clip_embeds.shape[0] == 1
    clip_embeds = clip_embeds.half().to(ip_model.device)
    images = ip_model.generate(clip_image_embeds=clip_embeds, pil_image=None,
                               num_samples=num_samples, num_inference_steps=num_inference_steps, seed=seed)
    return images
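# %%
# Illustrative sketch (file names and blend weight are assumptions): the
# precomputed-embedding path above makes it easy to blend two reference
# images in CLIP feature space before generating.
def _example_blend_and_generate(ip_model, path_a="a.jpg", path_b="b.jpg", alpha=0.5):
    emb_a = extract_clip_embedding_pil(Image.open(path_a).convert("RGB"), ip_model)
    emb_b = extract_clip_embedding_pil(Image.open(path_b).convert("RGB"), ip_model)
    blended = alpha * emb_a + (1 - alpha) * emb_b  # still (1, seq_len, hidden_dim)
    images = generate(ip_model, blended, num_samples=4, num_inference_steps=50, seed=42)
    return image_grid(images, rows=2, cols=2)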