edit_anything / semantic_coding.py
memef4rmer's picture
Duplicate from jyseo/3DFuse
efe5745
from diffusers import UnCLIPPipeline, DiffusionPipeline
import torch
import os
from lora_diffusion.cli_lora_pti import *
from lora_diffusion.lora import *
from PIL import Image
import numpy as np
import json
from lora_dataset import PivotalTuningDatasetCapation as PVD
# Names passed as ``target_replace_module`` when LoRA weights are
# injected into / extracted from the UNet.
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
# The extended UNet set is the default set plus "ResnetBlock2D".
UNET_EXTENDED_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE | {"ResnetBlock2D"}

# Names used as ``target_replace_module`` for the text encoder. The
# extended variant currently holds the same members but is kept as its
# own set object.
TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = set(TEXT_ENCODER_DEFAULT_TARGET_REPLACE)

# Alias (same object) for the default UNet targets.
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
def _collect_learned_embeds(text_encoder, placeholder_tokens, placeholder_token_ids):
    """Return ``{token: embedding}`` for each placeholder token, detached on CPU.

    ``None`` for either sequence is treated as "no placeholder tokens" and
    yields an empty dict instead of failing inside ``zip``.
    """
    if placeholder_tokens is None or placeholder_token_ids is None:
        return {}

    collected = {}
    for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
        learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
        print(
            f"Current Learned Embeddings for {tok}:, id {tok_id} ",
            learned_embeds[:4],
        )
        collected[tok] = learned_embeds.detach().cpu()
    return collected


def save_all(
    unet,
    text_encoder,
    save_path,
    placeholder_token_ids=None,
    placeholder_tokens=None,
    save_lora=True,
    save_ti=True,
    target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    target_replace_module_unet=DEFAULT_TARGET_REPLACE,
    safe_form=True,
):
    """Export the LoRA weights and/or learned token embeddings.

    Args:
        unet: UNet model carrying LoRA weights.
        text_encoder: Text encoder carrying LoRA weights and the learned
            placeholder-token embeddings.
        save_path: Target path. Must end with ``.safetensors`` when
            ``safe_form`` is true.
        placeholder_token_ids: Row indices of the placeholder tokens in the
            text encoder's input embedding; ``None`` means no tokens.
        placeholder_tokens: Token strings paired element-wise with
            ``placeholder_token_ids``; ``None`` means no tokens.
        save_lora: Include the UNet and text-encoder LoRA weights.
        save_ti: Include the learned placeholder-token embeddings.
        target_replace_module_text: Target module names for the text encoder.
        target_replace_module_unet: Target module names for the UNet.
        safe_form: Choose the combined "safe" form (default) over the legacy
            per-file form.

    Returns:
        With ``safe_form`` true, the value returned by
        ``save_safeloras_with_embeds``. With ``safe_form`` false, ``None``
        (the legacy branch saves via ``torch.save`` / ``save_lora_weight``).

    Raises:
        AssertionError: ``safe_form`` is true and ``save_path`` does not end
            with ``.safetensors``.
    """
    if safe_form:
        assert save_path.endswith(
            ".safetensors"
        ), f"Save path : {save_path} should end with .safetensors"

        loras = {}
        if save_lora:
            loras["unet"] = (unet, target_replace_module_unet)
            loras["text_encoder"] = (text_encoder, target_replace_module_text)

        embeds = {}
        if save_ti:
            embeds = _collect_learned_embeds(
                text_encoder, placeholder_tokens, placeholder_token_ids
            )

        return save_safeloras_with_embeds(loras, embeds, save_path)

    # Legacy form: separate files for the embeddings and each LoRA.
    if save_ti:
        ti_path = ti_lora_path(save_path)
        learned_embeds_dict = _collect_learned_embeds(
            text_encoder, placeholder_tokens, placeholder_token_ids
        )
        torch.save(learned_embeds_dict, ti_path)
        print("Ti saved to ", ti_path)

    if save_lora:
        save_lora_weight(
            unet, save_path, target_replace_module=target_replace_module_unet
        )
        print("Unet saved to ", save_path)

        # NOTE(review): `_text_lora_path` is underscore-prefixed; a
        # `from ... import *` only provides such names when the source module
        # lists them in `__all__` -- confirm it is actually in scope here.
        text_lora_path = _text_lora_path(save_path)
        save_lora_weight(
            text_encoder,
            text_lora_path,
            target_replace_module=target_replace_module_text,
        )
        print("Text Encoder saved to ", text_lora_path)

    return None
def save_safeloras_with_embeds(
    modelmap=None,
    embeds=None,
    outpath="./lora.safetensors",
):
    """Gather LoRA tensors and learned embeddings into one in-memory state.

    Despite its name, this function does not write anything to disk; it
    only builds and returns the state. ``outpath`` is accepted for
    backward compatibility and is not used.

    Args:
        modelmap: Mapping of
            ``"module name" -> (module, target_replace_module)``.
            ``None`` (the default) is treated as an empty mapping.
        embeds: Mapping of ``token -> embedding tensor``. ``None`` (the
            default) is treated as an empty mapping.
        outpath: Unused.

    Returns:
        A dict with two entries:
        ``"weights"``: tensors keyed ``"{name}:{i}:up"`` /
        ``"{name}:{i}:down"`` plus one entry per token, ordered by key;
        ``"metadata"``: per-module JSON list of target names, per-pair
        ``"{name}:{i}:rank"`` strings, and ``EMBED_FLAG`` for each token.
    """
    # Fresh containers per call instead of shared mutable defaults.
    modelmap = {} if modelmap is None else modelmap
    embeds = {} if embeds is None else embeds

    weights = {}
    metadata = {}

    for name, (model, target_replace_module) in modelmap.items():
        metadata[name] = json.dumps(list(target_replace_module))

        lora_pairs = extract_lora_as_tensor(model, target_replace_module)
        for i, (_up, _down) in enumerate(lora_pairs):
            # The rank is read from the first dimension of the down tensor.
            metadata[f"{name}:{i}:rank"] = str(_down.shape[0])
            weights[f"{name}:{i}:up"] = _up
            weights[f"{name}:{i}:down"] = _down

    for token, tensor in embeds.items():
        metadata[token] = EMBED_FLAG
        weights[token] = tensor

    # Writing to `outpath` (previously `safe_save(weights, outpath, metadata)`)
    # is disabled; callers receive the state directly.
    return {
        "weights": {key: weights[key] for key in sorted(weights)},
        "metadata": metadata,
    }
def _mean_lora_moved(model):
    """Return the mean over all values that ``inspect_lora`` reports for ``model``."""
    per_module_values = inspect_lora(model).values()
    flattened = list(itertools.chain.from_iterable(per_module_values))
    return torch.tensor(flattened).mean().item()


def perform_tuning(
    unet,
    vae,
    text_encoder,
    dataloader,
    num_steps,
    scheduler,
    optimizer,
    save_steps: int,
    placeholder_token_ids,
    placeholder_tokens,
    save_path,
    lr_scheduler_lora,
    lora_unet_target_modules,
    lora_clip_target_modules,
    mask_temperature,
    out_name: str,
    tokenizer,
    test_image_path: str,
    cached_latents: bool,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
):
    """Run the LoRA tuning loop and return the final exported state.

    Iterates over ``dataloader`` until ``num_steps`` optimizer steps have
    been taken. Each step computes ``loss_step``, back-propagates, clips
    the gradient norm of the UNet and text-encoder parameters to 1.0 and
    steps ``optimizer``. Every ``save_steps`` steps ``save_all`` is called
    and LoRA movement statistics are printed; with ``log_wandb`` the
    current pipeline is additionally evaluated and logged.

    Args:
        unet, vae, text_encoder: Models forwarded to ``loss_step``.
        dataloader: Batch source; must be non-empty because its length
            divides ``num_steps``.
        num_steps: Total number of optimizer steps.
        scheduler: Noise scheduler forwarded to ``loss_step`` and to the
            evaluation pipeline.
        optimizer: Optimizer over the trainable parameters.
        save_steps: Interval, in steps, of the periodic save/report.
        placeholder_token_ids, placeholder_tokens: Forwarded to
            ``save_all``.
        save_path: Directory joined with the generated file names.
        lr_scheduler_lora: Learning-rate scheduler stepped once per batch.
        lora_unet_target_modules, lora_clip_target_modules: Forwarded to
            ``save_all`` as the target module names.
        mask_temperature, cached_latents, train_inpainting: Forwarded to
            ``loss_step``.
        out_name: Base name of the final ``.safetensors`` path.
        tokenizer: Used only to build the evaluation pipeline.
        test_image_path: Directory scanned for ``.png`` / ``.jpg`` images
            used as evaluation targets when ``log_wandb`` is set.
        log_wandb: Enable evaluation and ``wandb`` logging.
        wandb_log_prompt_cnt: ``n_test`` passed to ``evaluate_pipe``.
        class_token: Class token passed to ``evaluate_pipe``.

    Returns:
        The value returned by the final ``save_all`` call.
    """
    progress_bar = tqdm(range(num_steps))
    progress_bar.set_description("Steps")
    global_step = 0

    unet.train()
    text_encoder.train()

    if log_wandb:
        preped_clip = prepare_clip_model_sets()

    loss_sum = 0.0

    for _ in range(math.ceil(num_steps / len(dataloader))):
        for batch in dataloader:
            # NOTE(review): the LR scheduler is stepped before
            # `optimizer.step()`. PyTorch documents the opposite order, but
            # swapping it would shift the schedule by one step, so the
            # original order is kept -- confirm intent.
            lr_scheduler_lora.step()

            optimizer.zero_grad()

            loss = loss_step(
                batch,
                unet,
                vae,
                text_encoder,
                scheduler,
                train_inpainting=train_inpainting,
                t_mutliplier=0.8,
                mixed_precision=True,
                mask_temperature=mask_temperature,
                cached_latents=cached_latents,
            )
            # Read the scalar once; it feeds both the running sum and the
            # progress-bar postfix.
            loss_value = loss.detach().item()
            loss_sum += loss_value

            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
            )
            optimizer.step()

            progress_bar.update(1)
            logs = {
                "loss": loss_value,
                "lr": lr_scheduler_lora.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            global_step += 1

            if global_step % save_steps == 0:
                # The return value of this periodic export is not captured.
                save_all(
                    unet,
                    text_encoder,
                    placeholder_token_ids=placeholder_token_ids,
                    placeholder_tokens=placeholder_tokens,
                    save_path=os.path.join(
                        save_path, f"step_{global_step}.safetensors"
                    ),
                    target_replace_module_text=lora_clip_target_modules,
                    target_replace_module_unet=lora_unet_target_modules,
                )

                print("LORA Unet Moved", _mean_lora_moved(unet))
                print("LORA CLIP Moved", _mean_lora_moved(text_encoder))

                if log_wandb:
                    with torch.no_grad():
                        pipe = StableDiffusionPipeline(
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            unet=unet,
                            scheduler=scheduler,
                            safety_checker=None,
                            feature_extractor=None,
                        )

                        # Open every .png / .jpg image in test_image_path.
                        images = [
                            Image.open(os.path.join(test_image_path, file))
                            for file in os.listdir(test_image_path)
                            if file.endswith((".png", ".jpg"))
                        ]

                        wandb.log({"loss": loss_sum / save_steps})
                        loss_sum = 0.0
                        wandb.log(
                            evaluate_pipe(
                                pipe,
                                target_images=images,
                                class_token=class_token,
                                learnt_token="".join(placeholder_tokens),
                                n_test=wandb_log_prompt_cnt,
                                n_step=50,
                                clip_model_sets=preped_clip,
                            )
                        )

            if global_step >= num_steps:
                break

    return save_all(
        unet,
        text_encoder,
        placeholder_token_ids=placeholder_token_ids,
        placeholder_tokens=placeholder_tokens,
        save_path=os.path.join(save_path, f"{out_name}.safetensors"),
        target_replace_module_text=lora_clip_target_modules,
        target_replace_module_unet=lora_unet_target_modules,
    )
def _freeze_text_encoder_backbone(text_encoder):
    """Disable gradients on the encoder stack, final layer norm and position embedding."""
    params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    for param in params_to_freeze:
        param.requires_grad = False


def train(
    images,
    caption,
    pretrained_model_name_or_path: str,
    train_text_encoder: bool = True,
    pretrained_vae_name_or_path: Optional[str] = None,
    revision: Optional[str] = None,
    perform_inversion: bool = True,
    use_template: Literal[None, "object", "style"] = None,
    train_inpainting: bool = False,
    placeholder_tokens: str = "",
    placeholder_token_at_data: Optional[str] = None,
    initializer_tokens: Optional[str] = None,
    seed: int = 42,
    resolution: int = 512,
    color_jitter: bool = True,
    train_batch_size: int = 1,
    sample_batch_size: int = 1,
    max_train_steps_tuning: int = 1000,
    max_train_steps_ti: int = 1000,
    save_steps: int = 100,
    gradient_accumulation_steps: int = 4,
    gradient_checkpointing: bool = False,
    lora_rank: int = 4,
    lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
    lora_clip_target_modules={"CLIPAttention"},
    lora_dropout_p: float = 0.0,
    lora_scale: float = 1.0,
    use_extended_lora: bool = False,
    clip_ti_decay: bool = True,
    learning_rate_unet: float = 1e-4,
    learning_rate_text: float = 1e-5,
    learning_rate_ti: float = 5e-4,
    continue_inversion: bool = False,
    continue_inversion_lr: Optional[float] = None,
    use_face_segmentation_condition: bool = False,
    cached_latents: bool = True,
    use_mask_captioned_data: bool = False,
    mask_temperature: float = 1.0,
    scale_lr: bool = False,
    lr_scheduler: str = "linear",
    lr_warmup_steps: int = 0,
    lr_scheduler_lora: str = "linear",
    lr_warmup_steps_lora: int = 0,
    weight_decay_ti: float = 0.00,
    weight_decay_lora: float = 0.001,
    use_8bit_adam: bool = False,
    device="cuda:0",
    extra_args: Optional[dict] = None,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    wandb_project_name: str = "new_pti_project",
    wandb_entity: str = "new_pti_entity",
    proxy_token: str = "person",
    enable_xformers_memory_efficient_attention: bool = False,
    out_name: str = "final_lora",
):
    """Run textual inversion followed by LoRA tuning and return the result.

    Stages, in order:
      1. Parse the ``|``-separated placeholder / initializer tokens.
      2. Load the models via ``get_models`` and build the dataset and
         dataloader from ``images`` and ``caption``.
      3. If ``perform_inversion``: optimize the text encoder's input
         embeddings with ``train_inversion``.
      4. Inject LoRA into the UNet (and into the text encoder when
         ``train_text_encoder``) and run ``perform_tuning``.

    Args:
        images, caption: Forwarded to the dataset constructor.
        pretrained_model_name_or_path: Model id/path for ``get_models`` and
            the noise scheduler config.
        placeholder_tokens: ``|``-separated tokens in sorted order; the
            empty string means none.
        initializer_tokens: ``|``-separated initializers, one per
            placeholder token; ``None`` selects ``"<rand-0.017>"`` for each.
        placeholder_token_at_data: Optional ``"tok|pattern"`` pair used as
            the dataset token map.
        proxy_token: When not ``None``, used as the class token passed on
            for evaluation; otherwise the joined initializer tokens are used.
        scale_lr: Multiply each learning rate by
            ``gradient_accumulation_steps * train_batch_size``.
        lr_scheduler, lr_scheduler_lora: Scheduler names handed to
            ``get_scheduler`` for the inversion and tuning stages.
        continue_inversion: Keep optimizing the input embeddings during the
            tuning stage, at ``continue_inversion_lr`` (falls back to the
            inversion learning rate).
        cached_latents: Forwarded to the dataloader and training loops; when
            true the VAE reference is dropped after the dataloader is built.
        sample_batch_size, use_8bit_adam, extra_args, wandb_project_name,
        wandb_entity: Accepted but not read by this function.

        The remaining parameters are forwarded unchanged to the helper of
        the matching stage (dataset, LoRA injection, optimizers, loops).

    Returns:
        The value returned by ``perform_tuning``.

    Raises:
        AssertionError: Placeholder tokens are not sorted, their count
            differs from the initializer tokens, or ``cached_latents`` is
            combined with ``train_inpainting``.
        ValueError: xformers attention is requested but unavailable.
    """
    torch.manual_seed(seed)

    if not placeholder_tokens:
        placeholder_tokens = []
        print("PTI : Placeholder Tokens not given, using null token")
    else:
        placeholder_tokens = placeholder_tokens.split("|")

        assert (
            sorted(placeholder_tokens) == placeholder_tokens
        ), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}'"

    if initializer_tokens is None:
        print("PTI : Initializer Tokens not given, doing random inits")
        initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
    else:
        initializer_tokens = initializer_tokens.split("|")

    assert len(initializer_tokens) == len(
        placeholder_tokens
    ), "Unequal Initializer token for Placeholder tokens."

    # Previously the proxy token was unconditionally overwritten by the
    # joined initializer tokens, so `proxy_token` never took effect.
    if proxy_token is not None:
        class_token = proxy_token
    else:
        class_token = "".join(initializer_tokens)

    if placeholder_token_at_data is not None:
        tok, pat = placeholder_token_at_data.split("|")
        token_map = {tok: pat}
    else:
        token_map = {"DUMMY": "".join(placeholder_tokens)}

    print("PTI : Placeholder Tokens", placeholder_tokens)
    print("PTI : Initializer Tokens", initializer_tokens)

    # get the models
    text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
        pretrained_model_name_or_path,
        pretrained_vae_name_or_path,
        revision,
        placeholder_tokens,
        initializer_tokens,
        device=device,
    )

    noise_scheduler = DDPMScheduler.from_config(
        pretrained_model_name_or_path, subfolder="scheduler"
    )

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if enable_xformers_memory_efficient_attention:
        from diffusers.utils.import_utils import is_xformers_available

        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if scale_lr:
        unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
        text_encoder_lr = (
            learning_rate_text * gradient_accumulation_steps * train_batch_size
        )
        ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size
    else:
        unet_lr = learning_rate_unet
        text_encoder_lr = learning_rate_text
        ti_lr = learning_rate_ti

    train_dataset = PVD(
        images=images,
        caption=caption,
        token_map=token_map,
        use_template=use_template,
        tokenizer=tokenizer,
        size=resolution,
        color_jitter=color_jitter,
        use_face_segmentation_condition=use_face_segmentation_condition,
        use_mask_captioned_data=use_mask_captioned_data,
        train_inpainting=train_inpainting,
    )

    train_dataset.blur_amount = 200

    if train_inpainting:
        assert not cached_latents, "Cached latents not supported for inpainting"

        train_dataloader = inpainting_dataloader(
            train_dataset, train_batch_size, tokenizer, vae, text_encoder
        )
    else:
        print(cached_latents)
        train_dataloader = text2img_dataloader(
            train_dataset,
            train_batch_size,
            tokenizer,
            vae,
            text_encoder,
            cached_latents=cached_latents,
        )

    # Boolean mask over the vocabulary: True everywhere except at the
    # placeholder token ids.
    index_no_updates = torch.arange(len(tokenizer)) != -1
    for tok_id in placeholder_token_ids:
        index_no_updates[tok_id] = False

    unet.requires_grad_(False)
    vae.requires_grad_(False)
    _freeze_text_encoder_backbone(text_encoder)

    if cached_latents:
        vae = None

    # STEP 1 : Perform Inversion
    if perform_inversion:
        ti_optimizer = optim.AdamW(
            text_encoder.get_input_embeddings().parameters(),
            lr=ti_lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=weight_decay_ti,
        )

        # Kept under its own name so the `lr_scheduler` parameter (a
        # scheduler *name*) is not rebound to a scheduler object.
        ti_lr_scheduler = get_scheduler(
            lr_scheduler,
            optimizer=ti_optimizer,
            num_warmup_steps=lr_warmup_steps,
            num_training_steps=max_train_steps_ti,
        )

        train_inversion(
            unet,
            vae,
            text_encoder,
            train_dataloader,
            max_train_steps_ti,
            cached_latents=cached_latents,
            accum_iter=gradient_accumulation_steps,
            scheduler=noise_scheduler,
            index_no_updates=index_no_updates,
            optimizer=ti_optimizer,
            lr_scheduler=ti_lr_scheduler,
            save_steps=save_steps,
            placeholder_tokens=placeholder_tokens,
            placeholder_token_ids=placeholder_token_ids,
            save_path="./tmps",
            test_image_path="./tmps",
            log_wandb=log_wandb,
            wandb_log_prompt_cnt=wandb_log_prompt_cnt,
            class_token=class_token,
            train_inpainting=train_inpainting,
            mixed_precision=False,
            tokenizer=tokenizer,
            clip_ti_decay=clip_ti_decay,
        )

        del ti_optimizer

    # Next perform Tuning with LoRA:
    if not use_extended_lora:
        unet_lora_params, _ = inject_trainable_lora(
            unet,
            r=lora_rank,
            target_replace_module=lora_unet_target_modules,
            dropout_p=lora_dropout_p,
            scale=lora_scale,
        )
    else:
        print("PTI : USING EXTENDED UNET!!!")
        # `|` builds a new set, so the caller's/default set is not mutated.
        lora_unet_target_modules = (
            lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
        )
        print("PTI : Will replace modules: ", lora_unet_target_modules)

        unet_lora_params, _ = inject_trainable_lora_extended(
            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
        )
    print(f"PTI : has {len(unet_lora_params)} lora")

    print("PTI : Before training:")
    inspect_lora(unet)

    params_to_optimize = [
        {"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
    ]

    text_encoder.requires_grad_(False)

    if continue_inversion:
        params_to_optimize += [
            {
                "params": text_encoder.get_input_embeddings().parameters(),
                "lr": continue_inversion_lr
                if continue_inversion_lr is not None
                else ti_lr,
            }
        ]
        # Re-enable gradients, then freeze the backbone again so the input
        # embeddings keep training during the tuning stage.
        text_encoder.requires_grad_(True)
        _freeze_text_encoder_backbone(text_encoder)

    if train_text_encoder:
        text_encoder_lora_params, _ = inject_trainable_lora(
            text_encoder,
            target_replace_module=lora_clip_target_modules,
            r=lora_rank,
        )
        params_to_optimize += [
            {
                "params": itertools.chain(*text_encoder_lora_params),
                "lr": text_encoder_lr,
            }
        ]
        inspect_lora(text_encoder)

    lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)

    unet.train()
    if train_text_encoder:
        text_encoder.train()

    train_dataset.blur_amount = 70

    tuning_lr_scheduler = get_scheduler(
        lr_scheduler_lora,
        optimizer=lora_optimizers,
        num_warmup_steps=lr_warmup_steps_lora,
        num_training_steps=max_train_steps_tuning,
    )

    return perform_tuning(
        unet,
        vae,
        text_encoder,
        train_dataloader,
        max_train_steps_tuning,
        cached_latents=cached_latents,
        scheduler=noise_scheduler,
        optimizer=lora_optimizers,
        save_steps=save_steps,
        placeholder_tokens=placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        save_path="./tmps",
        lr_scheduler_lora=tuning_lr_scheduler,
        lora_unet_target_modules=lora_unet_target_modules,
        lora_clip_target_modules=lora_clip_target_modules,
        mask_temperature=mask_temperature,
        tokenizer=tokenizer,
        out_name=out_name,
        test_image_path="./tmps",
        log_wandb=log_wandb,
        wandb_log_prompt_cnt=wandb_log_prompt_cnt,
        class_token=class_token,
        train_inpainting=train_inpainting,
    )
def _render_initial_views(pipe, prompt, output_dir, num_initial_image, bg_preprocess):
    """Generate view-prompted images with ``pipe`` and save them as PNG files.

    Image ``i`` is prompted with one of four view prefixes (chosen by
    ``i % 4``) followed by ``prompt``; only the first image additionally
    gets the ``", white background"`` suffix. Files are written to
    ``output_dir`` as ``instance{i}.png``. ``output_dir`` is created up
    front, so it exists even when ``num_initial_image`` is 0.

    With ``bg_preprocess`` the background of every generated image is
    replaced by white using a carvekit ``HiInterface`` mask.
    """
    view_prompt = ("front view of ", "overhead view of ", "side view of ", "back view of ")

    interface = None
    if bg_preprocess:
        # Please refer to the code at https://github.com/Ir1d/image-background-remove-tool.
        from carvekit.api.high import HiInterface

        interface = HiInterface(
            object_type="object",
            batch_size_seg=5,
            batch_size_matting=1,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    os.makedirs(output_dir, exist_ok=True)

    for i in range(num_initial_image):
        view = view_prompt[i % 4]
        if i == 0:
            prompt_ = f"{view}{prompt}, white background"
        else:
            prompt_ = f"{view}{prompt}"

        image = pipe(prompt_).images[0]

        if interface is not None:
            # motivated by NeuralLift-360 (removing bg), and Zero-1-to-3 (removing bg and object-centering)
            # NOTE: This option was added during the code organization process.
            # The results reported in the paper were obtained with [bg_preprocess: False] setting.
            # Only the background is whitened here; no object-centering crop is applied.
            img_without_background = interface([image])
            # assumes the interface output is a per-pixel mask matching the
            # image's height/width -- TODO confirm
            mask = np.array(img_without_background[0]) > 127
            pixels = np.array(image)
            pixels[~mask] = [255., 255., 255.]
            image = Image.fromarray(pixels)

        image.save(os.path.join(output_dir, f"instance{i}.png"))


def semantic_karlo(prompt, output_dir, num_initial_image, bg_preprocess=False):
    """Generate the initial images with the Karlo unCLIP pipeline.

    Loads ``kakaobrain/karlo-v1-alpha`` in float16 on CUDA and writes
    ``num_initial_image`` files named ``instance{i}.png`` to
    ``output_dir``.

    Args:
        prompt: Text appended to each view prefix.
        output_dir: Directory receiving the PNG files (created if missing).
        num_initial_image: Number of images to generate.
        bg_preprocess: Replace each image's background with white.
    """
    pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
    pipe = pipe.to('cuda')
    _render_initial_views(pipe, prompt, output_dir, num_initial_image, bg_preprocess)


def semantic_sd(prompt, output_dir, num_initial_image, bg_preprocess=False):
    """Generate the initial images with Stable Diffusion v1.5.

    Loads ``runwayml/stable-diffusion-v1-5`` on CUDA and writes
    ``num_initial_image`` files named ``instance{i}.png`` to
    ``output_dir``.

    Args:
        prompt: Text appended to each view prefix.
        output_dir: Directory receiving the PNG files (created if missing).
        num_initial_image: Number of images to generate.
        bg_preprocess: Replace each image's background with white.
    """
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe = pipe.to('cuda')
    _render_initial_views(pipe, prompt, output_dir, num_initial_image, bg_preprocess)
def semantic_coding(images, cfgs, sd, initial):
    """Train on ``images`` and point ``sd.prompt`` at the learned ``<0>`` token.

    Args:
        images: Training images, forwarded to ``train``.
        cfgs: Config mapping. ``'ti_step'`` and ``'pt_step'`` are popped
            from it (the mapping is mutated) and ``cfgs['sd']['prompt']``
            is read.
        sd: Object whose ``prompt`` attribute is overwritten.
        initial: Initializer token string, also used as the caption; the
            empty string is treated as ``None``.

    Returns:
        The value returned by ``train``.
    """
    inversion_steps = cfgs.pop('ti_step')
    tuning_steps = cfgs.pop('pt_step')
    base_prompt = cfgs['sd']['prompt']

    if initial == "":
        initial = None

    train_kwargs = dict(
        images=images,
        caption=initial,
        pretrained_model_name_or_path='runwayml/stable-diffusion-v1-5',
        gradient_checkpointing=True,
        scale_lr=True,
        lora_rank=1,
        cached_latents=False,
        # One more than either step budget -- presumably so the periodic
        # save interval is never reached; verify against the training loops.
        save_steps=max(inversion_steps, tuning_steps) + 1,
        max_train_steps_ti=inversion_steps,
        max_train_steps_tuning=tuning_steps,
        use_template="object",
        lr_warmup_steps=0,
        lr_warmup_steps_lora=100,
        placeholder_tokens="<0>",
        initializer_tokens=initial,
        continue_inversion=True,
        continue_inversion_lr=1e-4,
        device="cuda:0",
    )
    state = train(**train_kwargs)

    # Swap the initializer word in the prompt for the learned token, or use
    # a generic prompt when no initializer was given.
    if initial is None:
        sd.prompt = "a <0>"
    else:
        sd.prompt = base_prompt.replace(initial, '<0>')

    return state