# KV-Edit / flux / sampling.py
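"""Sampling utilities for KV-Edit on Flux: seeded noise generation, prompt
preparation, timestep schedules, and the KV-cached denoise/inversion loops."""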
import math
from typing import Callable

import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm.contrib import tzip

from .model import Flux, Flux_kv
from .modules.conditioner import HFEmbedder


def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
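    """Sample seeded Gaussian noise for a 16-channel latent whose spatial dims
    are rounded up so the 2x2 patch packing divides evenly."""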
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=device,
dtype=dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)


def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
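    """Pack the latent image into 2x2-patch tokens, build 3-channel position ids,
    and embed the prompt with T5 (token sequence) and CLIP (pooled vector)."""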
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}


def prepare_flowedit(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, source_prompt: str | list[str], target_prompt: str | list[str]) -> dict[str, Tensor]:
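    """Like prepare, but embeds both a source and a target prompt so the caller
    can compute source and target velocities over the same packed latent."""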
bs, c, h, w = img.shape
if bs == 1 and not isinstance(source_prompt, str):
bs = len(source_prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(source_prompt, str):
source_prompt = [source_prompt]
source_txt = t5(source_prompt)
if source_txt.shape[0] == 1 and bs > 1:
source_txt = repeat(source_txt, "1 ... -> bs ...", bs=bs)
source_txt_ids = torch.zeros(bs, source_txt.shape[1], 3)
    source_vec = clip(source_prompt)
if source_vec.shape[0] == 1 and bs > 1:
source_vec = repeat(source_vec, "1 ... -> bs ...", bs=bs)
if isinstance(target_prompt, str):
target_prompt = [target_prompt]
target_txt = t5(target_prompt)
if target_txt.shape[0] == 1 and bs > 1:
target_txt = repeat(target_txt, "1 ... -> bs ...", bs=bs)
target_txt_ids = torch.zeros(bs, target_txt.shape[1], 3)
target_vec = clip(target_prompt)
if target_vec.shape[0] == 1 and bs > 1:
target_vec = repeat(target_vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"source_txt": source_txt.to(img.device),
"source_txt_ids": source_txt_ids.to(img.device),
"source_vec": source_vec.to(img.device),
"target_txt": target_txt.to(img.device),
"target_txt_ids": target_txt_ids.to(img.device),
"target_vec": target_vec.to(img.device)
}


def time_shift(mu: float, sigma: float, t: Tensor):
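    """Resolution-dependent timestep shift: larger mu pushes the schedule
    toward t = 1, i.e. spends more steps at high noise levels."""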
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
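    """Return the line through (x1, y1) and (x2, y2); used to pick the shift
    parameter mu from the image sequence length."""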
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b


def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
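    """Build num_steps + 1 timesteps descending from 1 to 0, optionally shifted
    for the given image sequence length."""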
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()


def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
guidance: float = 4.0,
):
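    """Plain Euler sampler: step the latent along the predicted velocity from
    t = 1 toward t = 0."""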
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
return img


def unpack(x: Tensor, height: int, width: int) -> Tensor:
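    """Undo the 2x2 patch packing, returning a (b, c, h, w) latent tensor."""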
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)


def denoise_kv(
model: Flux_kv,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
    inverse: bool,
    info: dict,
guidance: float = 4.0
):
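    """KV-Edit denoising loop.

    With inverse=True the timesteps are reversed (noising direction) and the
    intermediate latents are cached in info['feature']; otherwise the cached
    background latents are replayed outside the mask while the masked region
    is denoised.
    """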
if inverse:
timesteps = timesteps[::-1]
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for i, (t_curr, t_prev) in enumerate(tzip(timesteps[:-1], timesteps[1:])):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        info['t'] = t_prev if inverse else t_curr
        img_name = str(info['t']) + '_' + 'img'
        if inverse:
            # cache the latent at this timestep so the edit pass can replay it
            info['feature'][img_name] = img.cpu()
        else:
            # outside the mask keep the cached source latent; inside it, keep the running latent
            source_img = info['feature'][img_name].to(img.device)
            mask = info['mask'][:, info['mask_indices'], ...]
            img = source_img[:, info['mask_indices'], ...] * (1 - mask) + img * mask
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
info=info
)
img = img + (t_prev - t_curr) * pred
return img, info


def denoise_kv_inf(
model: Flux_kv,
# model input
img: Tensor,
img_ids: Tensor,
source_txt: Tensor,
source_txt_ids: Tensor,
source_vec: Tensor,
target_txt: Tensor,
target_txt_ids: Tensor,
target_vec: Tensor,
# sampling parameters
timesteps: list[float],
target_guidance: float = 4.0,
source_guidance: float = 4.0,
    info: dict | None = None,
):
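    """Inversion-free editing loop (cf. FlowEdit): at each step the source latent
    is re-noised directly from the input image, source and target velocities are
    predicted, and only their difference, restricted to the mask, is integrated
    into the edited feature z_fe.
    """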
    if info is None:
        info = {}
    target_guidance_vec = torch.full((img.shape[0],), target_guidance, device=img.device, dtype=img.dtype)
source_guidance_vec = torch.full((img.shape[0],), source_guidance, device=img.device, dtype=img.dtype)
    mask_indices = info['mask_indices']
    init_img = img.clone()  # e.g. torch.Size([1, 4080, 64])
    z_fe = img[:, mask_indices, ...]  # edited feature: the running latent restricted to the masked tokens
    noise_list = []
    for i in range(len(timesteps)):
        # draw noise for each step (fixed seed 0); the source latent is re-noised with it according to t below
        noise = torch.randn(init_img.size(), dtype=init_img.dtype,
                            layout=init_img.layout, device=init_img.device,
                            generator=torch.Generator(device=init_img.device).manual_seed(0))
        noise_list.append(noise)
    for i, (t_curr, t_prev) in enumerate(tzip(timesteps[:-1], timesteps[1:])):  # t runs from high to low
info['t'] = 'inf'
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        z_src = (1 - t_curr) * init_img + t_curr * noise_list[i]
        z_tar = z_src[:, mask_indices, ...] - init_img[:, mask_indices, ...] + z_fe
        info['inverse'] = True
        info['feature'] = {}  # clear the cached KV features
v_src = model(
img=z_src,
img_ids=img_ids,
txt=source_txt,
txt_ids=source_txt_ids,
y=source_vec,
timesteps=t_vec,
guidance=source_guidance_vec,
info=info
)
info['inverse'] = False
v_tar = model(
img=z_tar,
img_ids=img_ids,
txt=target_txt,
txt_ids=target_txt_ids,
y=target_vec,
timesteps=t_vec,
guidance=target_guidance_vec,
info=info
)
        v_fe = v_tar - v_src[:, mask_indices, ...]
        z_fe = z_fe + (t_prev - t_curr) * v_fe * info['mask'][:, mask_indices, ...]
return z_fe, info