# ByteMorph-Demo: src/flux/xflux_pipeline.py
# Last change: "Update src/flux/xflux_pipeline.py" by Boese0601 (commit 8a00db9, verified).
from PIL import Image, ExifTags
import numpy as np
import torch
from torch import Tensor
from einops import rearrange
import uuid
import os
from src.flux.modules.layers import (
SingleStreamBlockProcessor,
DoubleStreamBlockProcessor,
SingleStreamBlockLoraProcessor,
DoubleStreamBlockLoraProcessor,
IPDoubleStreamBlockProcessor,
ImageProjModel,
)
from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack
from src.flux.util import (
load_ae,
load_clip,
load_flow_model,
load_t5,
load_controlnet,
load_flow_model_quintized,
Annotator,
get_lora_rank,
load_checkpoint
)
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
class XFluxPipeline:
    """FLUX text-to-image pipeline with optional LoRA, ControlNet, IP-Adapter and
    source-image (spatial / reference-net) conditioning.

    The constructor loads the CLIP and T5 text encoders, the autoencoder and the
    flow transformer. Optional adapters are attached afterwards with
    ``set_lora``/``set_lora_from_collection``, ``set_controlnet`` and ``set_ip``.
    """

    def __init__(self, model_type, device, offload: bool = False):
        """Load all base models for ``model_type``.

        Args:
            model_type: model name forwarded to the ``load_*`` helpers; names
                containing ``"fp8"`` use ``load_flow_model_quintized``.
            device: torch device (or device string) used for inference.
            offload: when True, the autoencoder and the flow model are loaded on
                CPU and only moved to ``device`` while they are needed.
        """
        self.device = torch.device(device)
        self.offload = offload
        self.model_type = model_type

        # Initial placement of the large models: CPU when offloading.
        load_device = "cpu" if offload else self.device
        self.clip = load_clip(self.device)
        self.t5 = load_t5(self.device, max_length=512)
        self.ae = load_ae(model_type, device=load_device)
        if "fp8" in model_type:
            self.model = load_flow_model_quintized(model_type, device=load_device)
        else:
            self.model = load_flow_model(model_type, device=load_device)

        self.image_encoder_path = "openai/clip-vit-large-patch14"
        self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
        self.lora_types_to_names = {
            "realism": "lora.safetensors",
        }

        # Adapter / conditioning state, toggled by the set_* methods or subclasses.
        self.controlnet_loaded = False
        self.control_type = None  # read by gradio_generate to detect a ControlNet switch
        self.ip_loaded = False
        self.spatial_condition = False
        self.share_position_embedding = False
        self.use_share_weight_referencenet = False
        self.single_block_refnet = False
        self.double_block_refnet = False

    @staticmethod
    def _ip_processor_state_dict(checkpoint, proc_name):
        """Return the checkpoint entries mentioning ``proc_name``, with ``"<proc_name>."`` removed from each key."""
        return {
            key.replace(f"{proc_name}.", ""): value
            for key, value in checkpoint.items()
            if proc_name in key
        }

    @staticmethod
    def _lora_processor_state_dict(checkpoint, proc_name, lora_weight):
        """Return the checkpoint entries mentioning ``proc_name``, keys stripped of the
        ``proc_name`` prefix and its separator, values multiplied by ``lora_weight``."""
        return {
            key[len(proc_name) + 1:]: value * lora_weight
            for key, value in checkpoint.items()
            if proc_name in key
        }

    @staticmethod
    def _blank_image(width: int, height: int) -> np.ndarray:
        """Return an all-zero (black) uint8 image laid out as ``(height, width, 3)``."""
        return np.zeros((height, width, 3), dtype=np.uint8)

    def set_ip(self, local_path: str = None, repo_id: str = None, name: str = None):
        """Attach an IP-Adapter (image-prompt adapter) to the flow model.

        Loads the checkpoint via ``load_checkpoint(local_path, repo_id, name)``, the
        CLIP vision encoder and the image-projection model. Every attention
        processor that has matching checkpoint weights is replaced by an
        ``IPDoubleStreamBlockProcessor``; all others are kept unchanged.
        """
        self.model.to(self.device)
        checkpoint = load_checkpoint(local_path, repo_id, name)

        # Weights of the image-embedding projection model, prefix stripped.
        proj_prefix = "ip_adapter_proj_model."
        proj = {
            key[len(proj_prefix):]: value
            for key, value in checkpoint.items()
            if key.startswith("ip_adapter_proj_model")
        }

        # CLIP vision encoder producing the image embeddings, kept in fp16.
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.image_encoder_path
        ).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()

        # NOTE(review): (4096, 768, 4) presumably = (context dim, CLIP embed dim,
        # number of tokens) — confirm against ImageProjModel.
        self.improj = ImageProjModel(4096, 768, 4)
        self.improj.load_state_dict(proj)
        self.improj = self.improj.to(self.device, dtype=torch.bfloat16)

        ip_attn_procs = {}
        # ``proc_name`` (not ``name``) so the loop does not shadow the argument.
        for proc_name, current_proc in self.model.attn_processors.items():
            ip_state_dict = self._ip_processor_state_dict(checkpoint, proc_name)
            if ip_state_dict:
                proc = IPDoubleStreamBlockProcessor(4096, 3072)
                proc.load_state_dict(ip_state_dict)
                proc.to(self.device, dtype=torch.bfloat16)
                ip_attn_procs[proc_name] = proc
            else:
                ip_attn_procs[proc_name] = current_proc
        self.model.set_attn_processor(ip_attn_procs)
        self.ip_loaded = True

    def set_lora(self, local_path: str = None, repo_id: str = None,
                 name: str = None, lora_weight: float = 0.7):
        """Load a LoRA checkpoint (local file or hub repo) and apply it with ``lora_weight``."""
        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.update_model_with_lora(checkpoint, lora_weight)

    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
        """Load a named LoRA from ``self.hf_lora_collection`` and apply it.

        Raises:
            KeyError: if ``lora_type`` is not a key of ``self.lora_types_to_names``.
        """
        checkpoint = load_checkpoint(
            None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
        )
        self.update_model_with_lora(checkpoint, lora_weight)

    def update_model_with_lora(self, checkpoint, lora_weight):
        """Rebuild every attention processor from a LoRA ``checkpoint``.

        Processors with matching LoRA weights become (Single|Double)StreamBlock
        LoRA processors with the checkpoint's rank.

        NOTE: processors without LoRA weights are reset to the plain
        (Single|Double)StreamBlockProcessor, which also drops any IP-Adapter
        processors previously installed by ``set_ip``.
        NOTE(review): every tensor of the LoRA state dict is multiplied by
        ``lora_weight``; if that includes both down and up projections the
        effective delta scales with ``lora_weight ** 2`` — confirm intent.
        """
        rank = get_lora_rank(checkpoint)
        lora_attn_procs = {}
        for proc_name in self.model.attn_processors:
            is_single = proc_name.startswith("single_blocks")
            lora_state_dict = self._lora_processor_state_dict(checkpoint, proc_name, lora_weight)
            if lora_state_dict:
                if is_single:
                    proc = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
                else:
                    proc = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
                proc.load_state_dict(lora_state_dict)
                proc.to(self.device)
            else:
                proc = SingleStreamBlockProcessor() if is_single else DoubleStreamBlockProcessor()
            lora_attn_procs[proc_name] = proc
        self.model.set_attn_processor(lora_attn_procs)

    def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None,
                       name: str = None):
        """Load a bfloat16 ControlNet plus the matching ``Annotator`` for ``control_type``.

        The checkpoint is loaded with ``strict=False``, so missing/unexpected
        keys are tolerated silently.
        """
        self.model.to(self.device)
        self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)
        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.controlnet.load_state_dict(checkpoint, strict=False)
        self.annotator = Annotator(control_type, self.device)
        self.controlnet_loaded = True
        self.control_type = control_type

    def get_image_proj(
        self,
        image_prompt: "Image.Image | np.ndarray",
    ):
        """Encode an image prompt into IP-Adapter projection tokens (bfloat16, on ``self.device``).

        Requires ``set_ip`` (or equivalent attributes) to have provided
        ``clip_image_processor``, ``image_encoder`` and ``improj``.
        """
        pixel_values = self.clip_image_processor(
            images=image_prompt,
            return_tensors="pt",
        ).pixel_values
        pixel_values = pixel_values.to(self.image_encoder.device)
        image_prompt_embeds = self.image_encoder(pixel_values).image_embeds.to(
            device=self.device, dtype=torch.bfloat16,
        )
        return self.improj(image_prompt_embeds)

    def __call__(self,
                 prompt: str,
                 image_prompt: "Image.Image" = None,
                 source_image: Tensor = None,
                 controlnet_image: "Image.Image" = None,
                 width: int = 512,
                 height: int = 512,
                 guidance: float = 4,
                 num_steps: int = 50,
                 seed: int = 123456789,
                 true_gs: float = 3.5,
                 control_weight: float = 0.9,
                 ip_scale: float = 1.0,
                 neg_ip_scale: float = 1.0,
                 neg_prompt: str = '',
                 neg_image_prompt: "Image.Image" = None,
                 timestep_to_start_cfg: int = 1,
                 ):
        """Generate one image and return it as a PIL image.

        ``width`` and ``height`` are rounded down to multiples of 16. If either
        image prompt is given, the IP-Adapter must already be loaded; a missing
        one is replaced by a black image. ``source_image`` is forwarded to
        ``forward`` and encoded there when spatial conditioning or the
        shared-weight reference net is enabled (presumably a (B, C, H, W) tensor
        in the autoencoder's input range — TODO confirm).
        """
        width = 16 * (width // 16)
        height = 16 * (height // 16)

        image_proj = None
        neg_image_proj = None
        if image_prompt is not None or neg_image_prompt is not None:
            assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input'
            if image_prompt is None:
                image_prompt = self._blank_image(width, height)
            if neg_image_prompt is None:
                neg_image_prompt = self._blank_image(width, height)
            image_proj = self.get_image_proj(image_prompt)
            neg_image_proj = self.get_image_proj(neg_image_prompt)

        if self.controlnet_loaded:
            controlnet_image = self.annotator(controlnet_image, width, height)
            # Map uint8 [0, 255] pixels to [-1, 1], then HWC -> 1xCxHxW bfloat16.
            controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
            controlnet_image = controlnet_image.permute(
                2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)

        return self.forward(
            prompt,
            width,
            height,
            guidance,
            num_steps,
            seed,
            controlnet_image,
            timestep_to_start_cfg=timestep_to_start_cfg,
            true_gs=true_gs,
            control_weight=control_weight,
            neg_prompt=neg_prompt,
            image_proj=image_proj,
            neg_image_proj=neg_image_proj,
            ip_scale=ip_scale,
            neg_ip_scale=neg_ip_scale,
            spatial_condition=self.spatial_condition,
            source_image=source_image,
            share_position_embedding=self.share_position_embedding,
        )

    @torch.inference_mode()
    def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance,
                        num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
                        neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
                        lora_weight, local_path, lora_local_path, ip_local_path):
        """Gradio entry point: lazily load requested adapters, generate, save a JPEG.

        Image inputs arrive as numpy arrays and are converted to PIL images.
        A ``seed`` of -1 draws a random seed. Returns ``(image, filename)`` where
        the file is written under ``output/gradio/`` with Make/Model EXIF tags.
        """
        if controlnet_image is not None:
            controlnet_image = Image.fromarray(controlnet_image)
            # (Re)load when no ControlNet is present or a different type is requested.
            if not self.controlnet_loaded or control_type != self.control_type:
                if local_path is not None:
                    self.set_controlnet(control_type, local_path=local_path)
                else:
                    self.set_controlnet(control_type, local_path=None,
                                        repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3",
                                        name=f"flux-{control_type}-controlnet-v3.safetensors")
        if lora_local_path is not None:
            self.set_lora(local_path=lora_local_path, lora_weight=lora_weight)

        if image_prompt is not None:
            image_prompt = Image.fromarray(image_prompt)
        if neg_image_prompt is not None:
            neg_image_prompt = Image.fromarray(neg_image_prompt)
        # Either image prompt requires the IP-Adapter (see __call__).
        if (image_prompt is not None or neg_image_prompt is not None) and not self.ip_loaded:
            if ip_local_path is not None:
                self.set_ip(local_path=ip_local_path)
            else:
                self.set_ip(repo_id="xlabs-ai/flux-ip-adapter",
                            name="flux-ip-adapter.safetensors")

        seed = int(seed)
        if seed == -1:
            seed = torch.Generator(device="cpu").seed()

        # Keyword arguments are required: __call__ takes ``source_image`` as its
        # third positional parameter, so a positional call shifts every argument.
        img = self(
            prompt,
            image_prompt=image_prompt,
            controlnet_image=controlnet_image,
            width=width,
            height=height,
            guidance=guidance,
            num_steps=num_steps,
            seed=seed,
            true_gs=true_gs,
            control_weight=control_weight,
            ip_scale=ip_scale,
            neg_ip_scale=neg_ip_scale,
            neg_prompt=neg_prompt,
            neg_image_prompt=neg_image_prompt,
            timestep_to_start_cfg=timestep_to_start_cfg,
        )

        filename = f"output/gradio/{uuid.uuid4()}.jpg"
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Make] = "XLabs AI"
        exif_data[ExifTags.Base.Model] = self.model_type
        img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
        return img, filename

    def _encode_source_image(self, source_image, dtype):
        """Encode ``source_image`` with the autoencoder and pack it into 2x2 latent patches.

        Returns a tensor rearranged as ``b c (h 2) (w 2) -> b (h w) (c 2 2)`` and cast to ``dtype``.

        Raises:
            ValueError: if ``source_image`` is None.
        """
        if source_image is None:
            raise ValueError(
                "source_image is required when spatial_condition or "
                "use_share_weight_referencenet is enabled"
            )
        if self.offload:
            # With offloading the autoencoder may sit on CPU; bring the encoder
            # to the inference device for this call only (mirrors the decoder).
            self.ae.encoder.to(self.device)
        latents = self.ae.encode(source_image.to(self.device).to(torch.float32))
        self.offload_model_to_cpu(self.ae.encoder)
        return rearrange(
            latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
        ).to(dtype)

    def forward(
        self,
        prompt,
        width,
        height,
        guidance,
        num_steps,
        seed,
        controlnet_image=None,
        timestep_to_start_cfg=0,
        true_gs=3.5,
        control_weight=0.9,
        neg_prompt="",
        image_proj=None,
        neg_image_proj=None,
        ip_scale=1.0,
        neg_ip_scale=1.0,
        spatial_condition=True,
        source_image=None,
        share_position_embedding=False
    ):
        """Run text encoding, denoising and decoding; return the last batch image as PIL.

        Uses ``denoise_controlnet`` when a ControlNet is loaded (source-image
        conditioning is not passed on that path), otherwise ``denoise``.
        """
        x = get_noise(
            1, height, width, device=self.device,
            dtype=torch.bfloat16, seed=seed
        )
        # NOTE(review): the second argument looks like it is meant to be the packed
        # image sequence length ((h/16)*(w/16)); this expression is 64x smaller.
        # Confirm against get_schedule before changing — it alters the timestep shift.
        timesteps = get_schedule(
            num_steps,
            (width // 8) * (height // 8) // (16 * 16),
            shift=True,
        )
        torch.manual_seed(seed)
        with torch.no_grad():
            if self.offload:
                self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
            prepare_kwargs = dict(
                t5=self.t5,
                clip=self.clip,
                img=x,
                use_spatial_condition=spatial_condition,
                share_position_embedding=share_position_embedding,
                use_share_weight_referencenet=self.use_share_weight_referencenet,
            )
            inp_cond = prepare(prompt=prompt, **prepare_kwargs)
            neg_inp_cond = prepare(prompt=neg_prompt, **prepare_kwargs)

            if spatial_condition or self.use_share_weight_referencenet:
                source_image = self._encode_source_image(source_image, inp_cond['img'].dtype)

            if self.offload:
                self.offload_model_to_cpu(self.t5, self.clip)
                self.model = self.model.to(self.device)

            if self.controlnet_loaded:
                x = denoise_controlnet(
                    self.model,
                    img=inp_cond['img'],
                    img_ids=inp_cond['img_ids'],
                    txt=inp_cond['txt'],
                    txt_ids=inp_cond['txt_ids'],
                    vec=inp_cond['vec'],
                    controlnet=self.controlnet,
                    timesteps=timesteps,
                    guidance=guidance,
                    controlnet_cond=controlnet_image,
                    timestep_to_start_cfg=timestep_to_start_cfg,
                    neg_txt=neg_inp_cond['txt'],
                    neg_txt_ids=neg_inp_cond['txt_ids'],
                    neg_vec=neg_inp_cond['vec'],
                    true_gs=true_gs,
                    controlnet_gs=control_weight,
                    image_proj=image_proj,
                    neg_image_proj=neg_image_proj,
                    ip_scale=ip_scale,
                    neg_ip_scale=neg_ip_scale,
                )
            else:
                use_refnet = self.use_share_weight_referencenet
                x = denoise(
                    self.model,
                    img=inp_cond['img'],
                    img_ids=inp_cond['img_ids'],
                    txt=inp_cond['txt'],
                    txt_ids=inp_cond['txt_ids'],
                    vec=inp_cond['vec'],
                    timesteps=timesteps,
                    guidance=guidance,
                    timestep_to_start_cfg=timestep_to_start_cfg,
                    neg_txt=neg_inp_cond['txt'],
                    neg_txt_ids=neg_inp_cond['txt_ids'],
                    neg_vec=neg_inp_cond['vec'],
                    true_gs=true_gs,
                    image_proj=image_proj,
                    neg_image_proj=neg_image_proj,
                    ip_scale=ip_scale,
                    neg_ip_scale=neg_ip_scale,
                    source_image=source_image,  # packed source-image latents (or None)
                    use_share_weight_referencenet=use_refnet,
                    single_img_ids=inp_cond['single_img_ids'] if use_refnet else None,
                    neg_single_img_ids=neg_inp_cond['single_img_ids'] if use_refnet else None,
                    single_block_refnet=self.single_block_refnet,
                    double_block_refnet=self.double_block_refnet,
                )

            if self.offload:
                self.offload_model_to_cpu(self.model)
                self.ae.decoder.to(x.device)
            x = unpack(x.float(), height, width)
            x = self.ae.decode(x)
            self.offload_model_to_cpu(self.ae.decoder)

        # Map decoder output from [-1, 1] to uint8 HWC for the last batch element.
        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1[-1], "c h w -> h w c")
        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
        return output_img

    def offload_model_to_cpu(self, *models):
        """Move ``models`` to CPU and empty the CUDA cache; does nothing unless ``self.offload``."""
        if not self.offload:
            return
        for model in models:
            model.cpu()
        torch.cuda.empty_cache()
class XFluxSampler(XFluxPipeline):
    """``XFluxPipeline`` preset for ``flux-dev`` whose adapter state is injected
    by the caller rather than loaded through the ``set_*`` methods.

    The base models are always loaded with ``offload=False`` (directly on
    ``device``). The requested ``offload`` flag is applied only afterwards, so it
    decides whether models are moved back to CPU between pipeline stages.
    """

    def __init__(
        self,
        device,
        controlnet_loaded=False,
        ip_loaded=False,
        spatial_condition=False,
        offload=False,
        clip_image_processor=None,
        image_encoder=None,
        improj=None,
        share_position_embedding=False,
        use_share_weight_referencenet=False,
        single_block_refnet=False,
        double_block_refnet=False,
    ):
        super().__init__(model_type="flux-dev", device=device, offload=False)

        # NOTE(review): replaces the torch.device built by the base class with
        # the raw ``device`` argument (which may be a plain string).
        self.device = device
        self.offload = offload

        # Adapter flags and IP-Adapter components supplied by the caller.
        self.controlnet_loaded, self.ip_loaded = controlnet_loaded, ip_loaded
        self.clip_image_processor, self.image_encoder, self.improj = (
            clip_image_processor,
            image_encoder,
            improj,
        )

        # Source-image conditioning switches consumed by __call__ / forward.
        self.spatial_condition = spatial_condition
        self.share_position_embedding = share_position_embedding
        self.use_share_weight_referencenet = use_share_weight_referencenet
        self.single_block_refnet, self.double_block_refnet = (
            single_block_refnet,
            double_block_refnet,
        )