from PIL import Image, ExifTags
import numpy as np
import torch
from torch import Tensor
from einops import rearrange
import uuid
import os

from src.flux.modules.layers import (
    SingleStreamBlockProcessor,
    DoubleStreamBlockProcessor,
    SingleStreamBlockLoraProcessor,
    DoubleStreamBlockLoraProcessor,
    IPDoubleStreamBlockProcessor,
    ImageProjModel,
)
from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack
from src.flux.util import (
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    load_controlnet,
    load_flow_model_quintized,
    Annotator,
    get_lora_rank,
    load_checkpoint
)
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

import spaces  # Hugging Face Spaces runtime helper
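
# The pipeline below bundles the FLUX flow-matching model with its text
# encoders (CLIP + T5) and autoencoder, and layers optional LoRA, ControlNet,
# and IP-Adapter conditioning on top of the base sampler.
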
class XFluxPipeline:
    def __init__(self, model_type, device, offload: bool = False):
        self.device = torch.device(device)
        self.offload = offload
        self.model_type = model_type

        self.clip = load_clip(self.device)
        self.t5 = load_t5(self.device, max_length=512)
        self.ae = load_ae(model_type, device="cpu" if offload else self.device)
        if "fp8" in model_type:
            # fp8 checkpoints go through the quantized loader
            # ("quintized" is the function's actual name in src.flux.util)
            self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
        else:
            self.model = load_flow_model(model_type, device="cpu" if offload else self.device)

        self.image_encoder_path = "openai/clip-vit-large-patch14"
        self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
        self.lora_types_to_names = {
            "realism": "lora.safetensors",
        }
        self.controlnet_loaded = False
        self.ip_loaded = False

    def set_ip(self, local_path: str = None, repo_id=None, name: str = None):
        self.model.to(self.device)

        # unpack checkpoint: separate the per-block attention weights from the
        # image-embedding projection weights
        checkpoint = load_checkpoint(local_path, repo_id, name)
        prefix = "double_blocks."
        blocks = {}
        proj = {}

        for key, value in checkpoint.items():
            if key.startswith(prefix):
                blocks[key[len(prefix):].replace('.processor.', '.')] = value
            if key.startswith("ip_adapter_proj_model"):
                proj[key[len("ip_adapter_proj_model."):]] = value

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()

        # setup image embedding projection model
        self.improj = ImageProjModel(4096, 768, 4)
        self.improj.load_state_dict(proj)
        self.improj = self.improj.to(self.device, dtype=torch.bfloat16)

        # install IP-Adapter processors on the blocks that have weights in the
        # checkpoint; every other block keeps its existing processor
        ip_attn_procs = {}
        for name, _ in self.model.attn_processors.items():
            ip_state_dict = {}
            for k in checkpoint.keys():
                if name in k:
                    ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k]
            if ip_state_dict:
                ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
                ip_attn_procs[name].load_state_dict(ip_state_dict)
                ip_attn_procs[name].to(self.device, dtype=torch.bfloat16)
            else:
                ip_attn_procs[name] = self.model.attn_processors[name]
        self.model.set_attn_processor(ip_attn_procs)
        self.ip_loaded = True

    def set_lora(self, local_path: str = None, repo_id: str = None,
                 name: str = None, lora_weight: float = 0.7):
        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.update_model_with_lora(checkpoint, lora_weight)

    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
        checkpoint = load_checkpoint(
            None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
        )
        self.update_model_with_lora(checkpoint, lora_weight)

    def update_model_with_lora(self, checkpoint, lora_weight):
        rank = get_lora_rank(checkpoint)
        lora_attn_procs = {}

        for name, _ in self.model.attn_processors.items():
            # collect this block's LoRA weights, pre-scaled by lora_weight
            lora_state_dict = {}
            for k in checkpoint.keys():
                if name in k:
                    lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight

            if len(lora_state_dict):
                if name.startswith("single_blocks"):
                    lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
                else:
                    lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
                lora_attn_procs[name].load_state_dict(lora_state_dict)
                lora_attn_procs[name].to(self.device)
            else:
                # blocks without LoRA weights fall back to the plain processors
                if name.startswith("single_blocks"):
                    lora_attn_procs[name] = SingleStreamBlockProcessor()
                else:
                    lora_attn_procs[name] = DoubleStreamBlockProcessor()

        self.model.set_attn_processor(lora_attn_procs)

    def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
        self.model.to(self.device)
        self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)
        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.controlnet.load_state_dict(checkpoint, strict=False)
        # the annotator converts an input image into the control signal
        # (e.g. edges or a depth map) for the chosen control_type
        self.annotator = Annotator(control_type, self.device)
        self.controlnet_loaded = True
        self.control_type = control_type

    def get_image_proj(
        self,
        image_prompt: Tensor,
    ):
        # encode the image prompt with the CLIP vision tower
        image_prompt = self.clip_image_processor(
            images=image_prompt,
            return_tensors="pt"
        ).pixel_values
        image_prompt = image_prompt.to(self.image_encoder.device)
        image_prompt_embeds = self.image_encoder(
            image_prompt
        ).image_embeds.to(
            device=self.device, dtype=torch.bfloat16,
        )
        # project the CLIP embedding into the model's conditioning space
        image_proj = self.improj(image_prompt_embeds)
        return image_proj

    def __call__(self,
                 prompt: str,
                 image_prompt: Image.Image = None,
                 controlnet_image: Image.Image = None,
                 width: int = 512,
                 height: int = 512,
                 guidance: float = 4,
                 num_steps: int = 50,
                 seed: int = 123456789,
                 true_gs: float = 3,
                 control_weight: float = 0.9,
                 ip_scale: float = 1.0,
                 neg_ip_scale: float = 1.0,
                 neg_prompt: str = '',
                 neg_image_prompt: Image.Image = None,
                 timestep_to_start_cfg: int = 0,
                 ):
        # round the resolution down to a multiple of 16, as required by the latent packing
        width = 16 * (width // 16)
        height = 16 * (height // 16)

        image_proj = None
        neg_image_proj = None
        if not (image_prompt is None and neg_image_prompt is None):
            assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input'
            # substitute a black image for whichever side of the prompt is missing
            # (note: image arrays are height-first, i.e. (H, W, C))
            if image_prompt is None:
                image_prompt = np.zeros((height, width, 3), dtype=np.uint8)
            if neg_image_prompt is None:
                neg_image_prompt = np.zeros((height, width, 3), dtype=np.uint8)
            image_proj = self.get_image_proj(image_prompt)
            neg_image_proj = self.get_image_proj(neg_image_prompt)

        if self.controlnet_loaded:
            # annotate, then scale pixels from [0, 255] to [-1, 1] and move to NCHW
            controlnet_image = self.annotator(controlnet_image, width, height)
            controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
            controlnet_image = controlnet_image.permute(
                2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)

        return self.forward(
            prompt,
            width,
            height,
            guidance,
            num_steps,
            seed,
            controlnet_image,
            timestep_to_start_cfg=timestep_to_start_cfg,
            true_gs=true_gs,
            control_weight=control_weight,
            neg_prompt=neg_prompt,
            image_proj=image_proj,
            neg_image_proj=neg_image_proj,
            ip_scale=ip_scale,
            neg_ip_scale=neg_ip_scale,
        )

    def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance,
                        num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
                        neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
                        lora_weight, local_path, lora_local_path, ip_local_path):
        if controlnet_image is not None:
            controlnet_image = Image.fromarray(controlnet_image)
            # (re)load the ControlNet when none is loaded yet or the control type changed
            if ((self.controlnet_loaded and control_type != self.control_type)
                    or not self.controlnet_loaded):
                if local_path is not None:
                    self.set_controlnet(control_type, local_path=local_path)
                else:
                    self.set_controlnet(control_type, local_path=None,
                                        repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3",
                                        name=f"flux-{control_type}-controlnet-v3.safetensors")
        if lora_local_path is not None:
            self.set_lora(local_path=lora_local_path, lora_weight=lora_weight)
        if image_prompt is not None:
            image_prompt = Image.fromarray(image_prompt)
            if neg_image_prompt is not None:
                neg_image_prompt = Image.fromarray(neg_image_prompt)
            if not self.ip_loaded:
                if ip_local_path is not None:
                    self.set_ip(local_path=ip_local_path)
                else:
                    self.set_ip(repo_id="xlabs-ai/flux-ip-adapter",
                                name="flux-ip-adapter.safetensors")
        seed = int(seed)
        if seed == -1:
            # -1 requests a random seed
            seed = torch.Generator(device="cpu").seed()

        img = self(prompt, image_prompt, controlnet_image, width, height, guidance,
                   num_steps, seed, true_gs, control_weight, ip_scale, neg_ip_scale, neg_prompt,
                   neg_image_prompt, timestep_to_start_cfg)

        # save the result as JPEG with EXIF metadata identifying the generator
        filename = f"output/gradio/{uuid.uuid4()}.jpg"
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Make] = "XLabs AI"
        exif_data[ExifTags.Base.Model] = self.model_type
        img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
        return img, filename

    def forward(
        self,
        prompt,
        width,
        height,
        guidance,
        num_steps,
        seed,
        controlnet_image=None,
        timestep_to_start_cfg=0,
        true_gs=3.5,
        control_weight=0.9,
        neg_prompt="",
        image_proj=None,
        neg_image_proj=None,
        ip_scale=1.0,
        neg_ip_scale=1.0,
    ):
        # initial latent noise and the matching timestep schedule; the schedule
        # length is derived from the packed latent sequence length
        x = get_noise(
            1, height, width, device=self.device,
            dtype=torch.bfloat16, seed=seed
        )
        timesteps = get_schedule(
            num_steps,
            (width // 8) * (height // 8) // (16 * 16),
            shift=True,
        )
        torch.manual_seed(seed)
        with torch.no_grad():
            if self.offload:
                self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
            # encode the positive and negative text prompts
            inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
            neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)

            if self.offload:
                self.offload_model_to_cpu(self.t5, self.clip)
                self.model = self.model.to(self.device)

            if self.controlnet_loaded:
                x = denoise_controlnet(
                    self.model,
                    **inp_cond,
                    controlnet=self.controlnet,
                    timesteps=timesteps,
                    guidance=guidance,
                    controlnet_cond=controlnet_image,
                    timestep_to_start_cfg=timestep_to_start_cfg,
                    neg_txt=neg_inp_cond['txt'],
                    neg_txt_ids=neg_inp_cond['txt_ids'],
                    neg_vec=neg_inp_cond['vec'],
                    true_gs=true_gs,
                    controlnet_gs=control_weight,
                    image_proj=image_proj,
                    neg_image_proj=neg_image_proj,
                    ip_scale=ip_scale,
                    neg_ip_scale=neg_ip_scale,
                )
            else:
                x = denoise(
                    self.model,
                    **inp_cond,
                    timesteps=timesteps,
                    guidance=guidance,
                    timestep_to_start_cfg=timestep_to_start_cfg,
                    neg_txt=neg_inp_cond['txt'],
                    neg_txt_ids=neg_inp_cond['txt_ids'],
                    neg_vec=neg_inp_cond['vec'],
                    true_gs=true_gs,
                    image_proj=image_proj,
                    neg_image_proj=neg_image_proj,
                    ip_scale=ip_scale,
                    neg_ip_scale=neg_ip_scale,
                )

            if self.offload:
                self.offload_model_to_cpu(self.model)
                self.ae.decoder.to(x.device)
            # unpack the latent sequence and decode it back to pixel space
            x = unpack(x.float(), height, width)
            x = self.ae.decode(x)
            self.offload_model_to_cpu(self.ae.decoder)

        # map from [-1, 1] to an 8-bit RGB image
        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1[-1], "c h w -> h w c")
        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
        return output_img

    def offload_model_to_cpu(self, *models):
        # no-op unless CPU offloading was requested at construction time
        if not self.offload:
            return
        for model in models:
            model.cpu()
            torch.cuda.empty_cache()


class XFluxSampler(XFluxPipeline):
    def __init__(self, clip, t5, ae, model, device):
        # reuses already-loaded components instead of loading them from disk,
        # so it deliberately skips XFluxPipeline.__init__
        self.clip = clip
        self.t5 = t5
        self.ae = ae
        self.model = model
        self.model.eval()
        self.device = device
        self.controlnet_loaded = False
        self.ip_loaded = False
        self.offload = False
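

# Minimal usage sketch (not part of the original module): assumes a CUDA device
# and that the "flux-dev" weights are reachable by the loaders in src.flux.util;
# the prompt, resolution, and step count below are illustrative only.
if __name__ == "__main__":
    pipe = XFluxPipeline(model_type="flux-dev", device="cuda", offload=False)
    result = pipe(
        prompt="a photo of a lighthouse at sunset",
        width=768,
        height=768,
        guidance=4,
        num_steps=25,
        seed=42,
    )
    result.save("example_output.png")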