# mist-v2/attacks/mist.py
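"""Mist v2 attack script.

Alternately runs projected gradient descent (PGD) on the input images and
LoRA fine-tuning of a surrogate Stable Diffusion model, producing perturbed
("misted") images that resist DreamBooth-style fine-tuning.
"""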
import argparse
import copy
import hashlib
import itertools
import logging
import os
import sys
import gc
from pathlib import Path
from colorama import Fore, Style, init, Back
import random, time
# some system-level settings
init(autoreset=True)
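# make repo-root packages (lora_diffusion, attacks) importable when this
# script is run from inside the attacks/ folder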
sys.path.insert(0, sys.path[0]+"/../")
import lpips
import datasets
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from torch import autograd
from typing import Optional, Tuple
import pynvml
# from utils import print_tensor
from lora_diffusion import (
extract_lora_ups_down,
inject_trainable_lora,
)
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
from attacks.utils import LatentAttack
logger = get_logger(__name__)
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--cuda",
action="store_true",
help="Use gpu for attack",
)
parser.add_argument(
"--pretrained_model_name_or_path",
"-p",
type=str,
default="./stable-diffusion/stable-diffusion-1-5",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default="",
required=False,
help="A folder containing the images to add adversarial noise",
)
parser.add_argument(
"--class_data_dir",
type=str,
default="",
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default="a picture",
required=False,
help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--class_prompt",
type=str,
default="a picture",
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--with_prior_preservation",
default=True,
help="Flag to add prior preservation loss.",
)
parser.add_argument(
"--prior_loss_weight",
type=float,
default=0.1,
help="The weight of prior preservation loss.",
)
parser.add_argument(
"--num_class_images",
type=int,
default=50,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="",
help="The output directory where the perturbed data is stored",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=True,
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--train_text_encoder",
action="store_false",
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
)
parser.add_argument(
"--train_batch_size",
type=int,
default=1,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--sample_batch_size",
type=int,
default=1,
help="Batch size (per device) for sampling images.",
)
parser.add_argument(
"--max_train_steps",
type=int,
default=5,
help="Total number of training steps to perform.",
)
parser.add_argument(
"--max_f_train_steps",
type=int,
default=10,
help="Total number of sub-steps to train surogate model.",
)
parser.add_argument(
"--max_adv_train_steps",
type=int,
default=30,
help="Total number of sub-steps to train adversarial noise.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--checkpointing_iterations",
type=int,
default=5,
help=("Save a checkpoint of the training state every X iterations."),
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="bf16",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--low_vram_mode",
action="store_false",
help="Whether or not to use low vram mode.",
)
parser.add_argument(
"--pgd_alpha",
type=float,
default=5e-3,
help="The step size for pgd.",
)
parser.add_argument(
"--pgd_eps",
type=float,
default=float(8.0/255.0),
help="The noise budget for pgd.",
)
parser.add_argument(
"--lpips_bound",
type=float,
default=0.1,
help="The noise budget for pgd.",
)
parser.add_argument(
"--lpips_weight",
type=float,
default=0.5,
help="The noise budget for pgd.",
)
parser.add_argument(
"--fused_weight",
type=float,
default=1e-5,
help="The decay of alpha and eps when applying pre_attack",
)
parser.add_argument(
"--target_image_path",
default="data/MIST.png",
help="target image for attacking",
)
parser.add_argument(
"--lora_rank",
type=int,
default=4,
help="Rank of LoRA approximation.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--learning_rate_text",
type=float,
default=5e-6,
help="Initial learning rate for text encoder (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--mode",
type=str,
        choices=['lunet', 'fused', 'anti-db'],
default='lunet',
help="The mode of attack",
)
parser.add_argument(
"--constraint",
type=str,
choices=['eps','lpips'],
default='eps',
help="The constraint of attack",
)
parser.add_argument(
"--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes.",
)
parser.add_argument(
"--adam_beta1",
type=float,
default=0.9,
help="The beta1 parameter for the Adam optimizer.",
)
parser.add_argument(
"--adam_beta2",
type=float,
default=0.999,
help="The beta2 parameter for the Adam optimizer.",
)
parser.add_argument(
"--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
)
parser.add_argument(
"--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer",
)
parser.add_argument(
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
)
parser.add_argument(
"--local_rank",
type=int,
default=-1,
help="For distributed training: local_rank",
)
parser.add_argument(
"--resume_unet",
type=str,
default=None,
help=("File path for unet lora to resume training."),
)
parser.add_argument(
"--resume_text_encoder",
type=str,
default=None,
help=("File path for text encoder lora to resume training."),
)
parser.add_argument(
"--resize",
action='store_true',
required=False,
help="Should images be resized to --resolution after attacking?",
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
if args.output_dir != "":
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir,exist_ok=True)
print(Back.BLUE+Fore.GREEN+'create output dir: {}'.format(args.output_dir))
return args
class DreamBoothDatasetFromTensor(Dataset):
"""Just like DreamBoothDataset, but take instance_images_tensor instead of path"""
def __init__(
self,
instance_images_tensor,
prompts,
instance_prompt,
tokenizer,
class_data_root=None,
class_prompt=None,
size=512,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.instance_images_tensor = instance_images_tensor
self.instance_prompts = prompts
self.num_instance_images = len(self.instance_images_tensor)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path)
# self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
def __getitem__(self, index):
example = {}
instance_image = self.instance_images_tensor[index % self.num_instance_images]
instance_prompt = self.instance_prompts[index % self.num_instance_images]
        if instance_prompt is None:
instance_prompt = self.instance_prompt
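        # prepend generic quality tags so the surrogate is trained with
        # prompts that resemble typical Stable Diffusion usage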
instance_prompt = \
'masterpiece,best quality,extremely detailed CG unity 8k wallpaper,illustration,cinematic lighting,beautiful detailed glow' + instance_prompt
example["instance_images"] = instance_image
example["instance_prompt_ids"] = self.tokenizer(
instance_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
return example
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
else:
raise ValueError(f"{model_class} is not supported.")
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
def load_data(data_dir, size=512, center_crop=True) -> Tuple[torch.Tensor, list, list]:
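    """Load images (and optional per-image .txt prompts) from data_dir.

    Returns (image tensor normalized to [-1, 1], prompts, original sizes).
    Note: center_crop is currently unused; images are resized to size x size.
    """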
image_transforms = transforms.Compose(
[
transforms.Resize((size,size), interpolation=transforms.InterpolationMode.BILINEAR),
# transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
# transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
# load images & prompts
images, prompts = [], []
num_image = 0
for filename in os.listdir(data_dir):
if filename.endswith(".png") or filename.endswith(".jpg"):
file_path = os.path.join(data_dir, filename)
images.append(Image.open(file_path).convert("RGB"))
num_image += 1
            prompt_name = os.path.splitext(filename)[0] + '.txt'
prompt_path = os.path.join(data_dir, prompt_name)
if os.path.exists(prompt_path):
with open(prompt_path, "r") as file:
text_string = file.read()
prompts.append(text_string)
print("==load image {} from {}, prompt: {}==".format(num_image-1, file_path, text_string))
else:
prompts.append(None)
print("==load image {} from {}, prompt: None, args.instance_prompt used==".format(num_image-1, file_path))
# load sizes
sizes = [img.size for img in images]
# preprocess images
images = [image_transforms(img) for img in images]
images = torch.stack(images)
print("==tensor shape: {}==".format(images.shape))
return images, prompts, sizes
def train_one_epoch(
args,
accelerator,
models,
tokenizer,
noise_scheduler,
vae,
data_tensor: torch.Tensor,
prompts,
weight_dtype=torch.bfloat16,
):
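    """Fine-tune LoRA-injected copies of the UNet (and optionally the text
    encoder) on the current perturbed instance images plus class images for
    args.max_f_train_steps steps. This is the surrogate-update half of the
    alternating optimization; the trained copies are returned for the next
    PGD round.
    """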
# prepare training data
train_dataset = DreamBoothDatasetFromTensor(
data_tensor,
prompts,
args.instance_prompt,
tokenizer,
args.class_data_dir,
args.class_prompt,
args.resolution,
args.center_crop,
)
device = accelerator.device
# prepare models & inject lora layers
unet, text_encoder = copy.deepcopy(models[0]), copy.deepcopy(models[1])
vae.to(device, dtype=weight_dtype)
vae.requires_grad_(False)
text_encoder.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
if args.low_vram_mode:
set_use_memory_efficient_attention_xformers(unet,True)
# this is only done at the first epoch
unet_lora_params, _ = inject_trainable_lora(
unet, r=args.lora_rank, loras=args.resume_unet
)
if args.train_text_encoder:
text_encoder_lora_params, _ = inject_trainable_lora(
text_encoder,
target_replace_module=["CLIPAttention"],
r=args.lora_rank,
)
# for _up, _down in extract_lora_ups_down(
# text_encoder, target_replace_module=["CLIPAttention"]
# ):
# print("Before training: text encoder First Layer lora up", _up.weight.data)
# print(
# "Before training: text encoder First Layer lora down", _down.weight.data
# )
# break
# build the optimizer
optimizer_class = torch.optim.AdamW
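    # NOTE: --use_8bit_adam is parsed but not wired up here; standard AdamW is always used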
text_lr = (
args.learning_rate
if args.learning_rate_text is None
else args.learning_rate_text
)
params_to_optimize = (
[
{
"params": itertools.chain(*unet_lora_params),
"lr": args.learning_rate},
{
"params": itertools.chain(*text_encoder_lora_params),
"lr": text_lr,
},
]
if args.train_text_encoder
else itertools.chain(*unet_lora_params)
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# begin training
for step in range(args.max_f_train_steps):
unet.train()
text_encoder.train()
random.seed(time.time())
instance_idx = random.randint(0, len(train_dataset)-1)
step_data = train_dataset[instance_idx]
pixel_values = torch.stack([step_data["instance_images"], step_data["class_images"]])
#print("pixel_values shape: {}".format(pixel_values.shape))
input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)
        for k in range(pixel_values.shape[0]):
            # calculate the instance loss (k == 0) and the class/prior loss (k == 1) separately
pixel_value = pixel_values[k, :].unsqueeze(0).to(device, dtype=weight_dtype)
latents = vae.encode(pixel_value).latent_dist.sample().detach().clone()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# encode text
input_id = input_ids[k, :].unsqueeze(0)
encode_hidden_states = text_encoder(input_id)[0]
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
            model_pred = unet(noisy_latents, timesteps, encode_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
if k == 1:
# calculate loss of class(prior)
loss *= args.prior_loss_weight
loss.backward()
if k == 1:
print(f"==loss - image index {instance_idx}, loss: {loss.detach().item() / args.prior_loss_weight}, prior")
else:
print(f"==loss - image index {instance_idx}, loss: {loss.detach().item()}, instance")
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
torch.nn.utils.clip_grad_norm_(params_to_clip, 1.0, error_if_nonfinite=True)
optimizer.step()
optimizer.zero_grad()
return [unet, text_encoder]
def pgd_attack(
args,
accelerator,
models,
tokenizer,
    noise_scheduler: DDIMScheduler,
    vae: AutoencoderKL,
    data_tensor: torch.Tensor,
    original_images: torch.Tensor,
    target_tensor: torch.Tensor,
    weight_dtype=torch.bfloat16,
):
"""Return new perturbed data"""
num_steps = args.max_adv_train_steps
unet, text_encoder = models
device = accelerator.device
if args.constraint == 'lpips':
lpips_vgg = lpips.LPIPS(net='vgg')
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
if args.low_vram_mode:
unet.set_use_memory_efficient_attention_xformers(True)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
data_tensor = data_tensor.detach().clone()
num_image = len(data_tensor)
image_list = []
tbar = tqdm(range(num_image))
tbar.set_description("PGD attack")
    for img_id in range(num_image):
        tbar.update(1)
        perturbed_image = data_tensor[img_id, :].unsqueeze(0)
        perturbed_image.requires_grad = True
        original_image = original_images[img_id, :].unsqueeze(0)
input_ids = tokenizer(
args.instance_prompt,
truncation=True,
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
).input_ids
input_ids = input_ids.to(device)
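        # Each PGD step computes gradients in two stages to cut peak memory:
        # (1) encode without grad, then take d(loss)/d(latents) through the UNet;
        # (2) re-encode with grad enabled and push those latent gradients
        #     through the VAE encoder only (see gc_latents.backward below).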
for step in range(num_steps):
perturbed_image.requires_grad = False
with torch.no_grad():
latents = vae.encode(perturbed_image.to(device, dtype=weight_dtype)).latent_dist.mean
#offload vae
latents = latents.detach().clone()
latents.requires_grad = True
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(input_ids)[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
unet.zero_grad()
text_encoder.zero_grad()
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            # target-shift loss: steer the model's prediction toward the target latent
            if target_tensor is not None:
                if args.mode != 'anti-db':
                    loss = - F.mse_loss(model_pred, target_tensor)
                    # fused mode: add a latent-space attack term that pulls the
                    # image latents toward the target latents
                    if args.mode == 'fused':
                        latent_attack = LatentAttack()
                        loss = loss - 1e2 * latent_attack(latents, target_tensor=target_tensor)
loss = loss / args.gradient_accumulation_steps
grads = autograd.grad(loss, latents)[0].detach().clone()
# now loss is backproped to latents
#print('grads: {}'.format(grads))
#do forward on vae again
perturbed_image.requires_grad = True
gc_latents = vae.encode(perturbed_image.to(device, dtype=weight_dtype)).latent_dist.mean
gc_latents.backward(gradient=grads)
if step % args.gradient_accumulation_steps == args.gradient_accumulation_steps - 1:
if args.constraint == 'eps':
alpha = args.pgd_alpha
adv_images = perturbed_image + alpha * perturbed_image.grad.sign()
# hard constraint
eps = args.pgd_eps
eta = torch.clamp(adv_images - original_image, min=-eps, max=+eps)
perturbed_image = torch.clamp(original_image + eta, min=-1, max=+1).detach_()
perturbed_image.requires_grad = True
                elif args.constraint == 'lpips':
                    # soft constraint: hinge penalty on the LPIPS distance,
                    # active only once it exceeds lpips_bound
                    lpips_distance = lpips_vgg(perturbed_image, original_image)
                    reg_loss = args.lpips_weight * torch.clamp(lpips_distance - args.lpips_bound, min=0).squeeze()
                    reg_loss.backward()
                    alpha = args.pgd_alpha
                    adv_images = perturbed_image + alpha * perturbed_image.grad.sign()
                    eta = adv_images - original_image
                    perturbed_image = torch.clamp(original_image + eta, min=-1, max=+1).detach_()
                    perturbed_image.requires_grad = True
else:
raise NotImplementedError
#print(f"PGD loss - step {step}, loss: {loss.detach().item()}")
image_list.append(perturbed_image.detach().clone().squeeze(0))
outputs = torch.stack(image_list)
return outputs
def main(args):
    if args.cuda:
        try:
            pynvml.nvmlInit()
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        except pynvml.NVMLError:
            raise NotImplementedError("No GPU found in GPU mode. Please try CPU mode.")
        mem_free = mem_info.free / float(1024 ** 3)  # bytes -> GiB
        if mem_free < 5.5:
            raise NotImplementedError("Your GPU memory is not enough for running Mist on GPU. Please try CPU mode.")
logging_dir = Path(args.output_dir, args.logging_dir)
if not args.cuda:
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_dir=logging_dir,
cpu=True
)
else:
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_dir=logging_dir
)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
if args.seed is not None:
set_seed(args.seed)
weight_dtype = torch.float32
if args.cuda:
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
print("==precision: {}==".format(weight_dtype))
# Generate class images if prior preservation is enabled.
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
if args.mixed_precision == "fp32":
torch_dtype = torch.float32
elif args.mixed_precision == "fp16":
torch_dtype = torch.float16
elif args.mixed_precision == "bf16":
torch_dtype = torch.bfloat16
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
)
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
for example in tqdm(
sample_dataloader,
desc="Generating class images",
disable=not accelerator.is_local_main_process,
):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load scheduler and models
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# add by lora
unet.requires_grad_(False)
# end: added by lora
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
noise_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    if args.cuda:
        vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
        ).cuda()
    else:
        vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
        )
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
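    # manually flag the VAE encoder for gradient checkpointing so that the
    # attack's backward pass through the encoder uses less memory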
vae.encoder.training = True
vae.encoder.gradient_checkpointing = True
#print info about train_text_encoder
if not args.train_text_encoder:
text_encoder.requires_grad_(False)
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
perturbed_data, prompts, data_sizes = load_data(
args.instance_data_dir,
size=args.resolution,
center_crop=args.center_crop,
)
original_data = perturbed_data.clone()
original_data.requires_grad_(False)
target_latent_tensor = None
if args.target_image_path is not None and args.target_image_path != "":
# print(Style.BRIGHT+Back.BLUE+Fore.GREEN+'load target image from {}'.format(args.target_image_path))
target_image_path = Path(args.target_image_path)
assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist"
target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution))
target_image = np.array(target_image)[None].transpose(0, 3, 1, 2)
if args.cuda:
target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=weight_dtype) / 127.5 - 1.0
else:
target_image_tensor = torch.from_numpy(target_image).to(dtype=weight_dtype) / 127.5 - 1.0
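        # encode the target image once; the PGD loss later pulls predictions
        # and latents toward this target latent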
target_latent_tensor = (
vae.encode(target_image_tensor).latent_dist.sample().to(dtype=weight_dtype) * vae.config.scaling_factor
)
target_image_tensor = target_image_tensor.to('cpu')
del target_image_tensor
#target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda()
f = [unet, text_encoder]
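    # Alternating optimization: each epoch first runs PGD against a copy of
    # the current surrogate (f_sur), then fine-tunes the surrogate on the
    # newly perturbed images.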
for i in range(args.max_train_steps):
f_sur = copy.deepcopy(f)
perturbed_data = pgd_attack(
args,
accelerator,
f_sur,
tokenizer,
noise_scheduler,
vae,
perturbed_data,
original_data,
target_latent_tensor,
weight_dtype,
)
del f_sur
if args.cuda:
gc.collect()
f = train_one_epoch(
args,
accelerator,
f,
tokenizer,
noise_scheduler,
vae,
perturbed_data,
prompts,
weight_dtype,
)
for model in f:
            if model is not None:
model.to('cpu')
        if args.cuda:
            gc.collect()
            pynvml.nvmlInit()
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            print("=======Epoch {} ends! Memory cost: {:.2f} GiB======".format(i, mem_info.used / float(1024 ** 3)))
else:
print("=======Epoch {} ends!======".format(i))
if (i + 1) % args.max_train_steps == 0:
save_folder = f"{args.output_dir}"
os.makedirs(save_folder, exist_ok=True)
noised_imgs = perturbed_data.detach().cpu()
origin_imgs = original_data.detach().cpu()
img_names = []
for filename in os.listdir(args.instance_data_dir):
if filename.endswith(".png") or filename.endswith(".jpg"):
img_names.append(str(filename))
for img_pixel, ori_img_pixel, img_name, img_size in zip(noised_imgs, origin_imgs, img_names, data_sizes):
save_path = os.path.join(save_folder, f"{i+1}_noise_{img_name}")
if not args.resize:
Image.fromarray(
(img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).numpy()
).save(save_path)
else:
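                    # --resize: upscale the perturbation back to the original
                    # resolution and re-apply it to the unresized source image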
ori_img_path = os.path.join(args.instance_data_dir, img_name)
ori_img = np.array(Image.open(ori_img_path).convert("RGB"))
ori_img_duzzy = np.array(Image.fromarray(
(ori_img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).numpy()
).resize(img_size), dtype=np.int32)
perturbed_img_duzzy = np.array(Image.fromarray(
(img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).numpy()
).resize(img_size), dtype=np.int32)
perturbation = perturbed_img_duzzy - ori_img_duzzy
assert perturbation.shape == ori_img.shape
perturbed_img = (ori_img + perturbation).clip(0, 255).astype(np.uint8)
# print("perturbation: {}, ori: {}, res: {}".format(
# perturbed_img_duzzy[:2, :2, :], ori_img_duzzy[:2, :2, :], perturbed_img_duzzy[:2, :2, :]))
Image.fromarray(perturbed_img).save(save_path)
print(f"==Saved misted image to {save_path}, size: {img_size}==")
# print(f"Saved noise at step {i+1} to {save_folder}")
del noised_imgs
def update_args_with_config(args, config):
    '''
    Update the default arguments in args with the config tuple assigned by the user.
    Config fields (in the unpacking order below) include:
        eps,
        device (gpu normal / gpu low vram / cpu),
        mode (Mode 1 = lunet, Mode 2 = fused, Mode 3 = anti-db),
        resize flag,
        data / output / model / class paths and prompts,
        train step counts and learning rates,
        LoRA rank and loss weights,
        constraint settings (eps or LPIPS).
    '''
args = parse_args()
eps, device, mode, resize, data_path, output_path, model_path, class_path, prompt, \
class_prompt, max_train_steps, max_f_train_steps, max_adv_train_steps, lora_lr, pgd_lr, \
rank, prior_loss_weight, fused_weight, constraint_mode, lpips_bound, lpips_weight = config
args.pgd_eps = float(eps)/255.0
if device == 'cpu':
args.cuda, args.low_vram_mode = False, False
else:
args.cuda, args.low_vram_mode = True, True
# if precision == 'bfloat16':
# args.mixed_precision = 'bf16'
# else:
# args.mixed_precision = 'fp16'
if mode == 'Mode 1':
args.mode = 'lunet'
elif mode == 'Mode 2':
args.mode = 'fused'
elif mode == 'Mode 3':
args.mode = 'anti-db'
if resize:
args.resize = True
assert os.path.exists(data_path) and os.path.exists(output_path)
args.instance_data_dir = data_path
args.output_dir = output_path
args.pretrained_model_name_or_path = model_path
args.class_data_dir = class_path
args.instance_prompt = prompt
args.class_prompt = class_prompt
args.max_train_steps = max_train_steps
args.max_f_train_steps = max_f_train_steps
args.max_adv_train_steps = max_adv_train_steps
args.learning_rate = lora_lr
args.pgd_alpha = pgd_lr
    args.lora_rank = rank  # the parser stores this as args.lora_rank
args.prior_loss_weight = prior_loss_weight
args.fused_weight = fused_weight
if constraint_mode == 'LPIPS':
args.constraint = 'lpips'
else:
args.constraint = 'eps'
args.lpips_bound = lpips_bound
args.lpips_weight = lpips_weight
return args
if __name__ == "__main__":
args = parse_args()
main(args)