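# mist-v2/attacks/utils.py
# Captured from a Hugging Face file-viewer page; the original page header read:
#   "AeroXi's picture" / "Upload folder using huggingface_hub" / commit ece766c
#   "raw" / "history" / "blame" / "No virus" / 3.95 kB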
import torch
import torch.nn.functional as F
# Text prompts describing image content (portraits, landscapes, scenes, ...).
# Every entry ends with a trailing comma, presumably so that a style phrase can
# be appended by the caller -- confirm against the call sites.
prompt_dataset = [
    "Portrait of an astronaut in space, detailed starry background, reflective helmet,",
    "Painting of a floating island with giant clock gears, populated with mythical creatures,",
    "Landscape of a Japanese garden in autumn, with a bridge over a koi pond,",
    "Painting representing the sound of jazz music, using vibrant colors and erratic shapes,",
    "Painting of a modern smartphone with classic art pieces appearing on the screen,",
    "Battle scene with futuristic robots and a golden palace in the background,",
    "Scene of a bustling city market with different perspectives of people and stalls,",
    "Scene of a ship sailing in a stormy sea, with dramatic lighting and powerful waves,",
    # Spelling fixed: this prompt previously began with "Portraint".
    "Portrait of a female botanist surrounded by exotic plants in a greenhouse,",
    "Painting of an ancient castle at night, with a full moon, gargoyles, and shadows,",
]
# Names of art styles / movements, one string per entry.
style_dataset = [
    "Art Nouveau",
    "Romantic",
    "Cubist",
    "Baroque",
    "Pop Art",
    "Abstract",
    "Impressionist",
    "Surrealist",
    "Renaissance",
    "Pointillism",
]
class attack_mixin:
    """Callable placeholder that attack objectives can derive from.

    The base ``__call__`` does no work and always raises, so a subclass has to
    provide its own ``__call__`` to be usable.
    """

    def __call__(
        self,
        latents: torch.Tensor,
        timesteps: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        unet: torch.nn.Module,
        target_tensor: torch.Tensor,
        noise_scheduler,
    ):
        """Placeholder entry point; none of the arguments are inspected.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()
class AdvDM(attack_mixin):
    """
    This attack aims to maximize the training loss of diffusion model
    """

    def __call__(
        self,
        latents: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        unet: torch.nn.Module,
        text_encoder: torch.nn.Module,
        input_ids,
        target_tensor: torch.Tensor,
        noise_scheduler,
    ):
        """Compute the denoising loss, optionally offset by a target-shift term.

        Args:
            latents: Latents that ``noise`` is added to via the scheduler.
            noise: Noise passed to ``noise_scheduler.add_noise``; also the
                regression target for ``"epsilon"`` prediction.
            timesteps: Per-sample timesteps (sliced along the first dimension).
            encoder_hidden_states: Ignored; it is recomputed from
                ``text_encoder(input_ids)[0]``. Kept for interface stability.
            unet: Model called as ``unet(noisy_latents, timesteps, cond)``;
                its ``.sample`` attribute is used as the prediction.
            text_encoder: Model producing the conditioning from ``input_ids``.
            input_ids: Input passed to ``text_encoder``.
            target_tensor: Optional target latents. When not ``None`` the
                target-shift term is subtracted from the loss.
            noise_scheduler: Object providing ``add_noise``, ``step``,
                ``config.prediction_type`` and, for ``"v_prediction"``,
                ``get_velocity``.

        Returns:
            The mean squared error between the prediction and its target,
            minus (when ``target_tensor`` is given) the mean squared error
            between the predicted previous-step sample and the noised target.

        Raises:
            ValueError: If ``noise_scheduler.config.prediction_type`` is
                neither ``"epsilon"`` nor ``"v_prediction"``.
        """
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning (replaces the argument of
        # the same name).
        encoder_hidden_states = text_encoder(input_ids)[0]

        # Predict the noise residual
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Get the target for loss depending on the prediction type
        prediction_type = noise_scheduler.config.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {prediction_type}")

        unet.zero_grad()
        text_encoder.zero_grad()
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # target-shift loss
        if target_tensor is not None:
            # The scheduler step is taken one sample at a time because each
            # sample carries its own timestep; slices keep the batch dimension.
            xtm1_pred = torch.cat(
                [
                    noise_scheduler.step(
                        model_pred[idx : idx + 1],
                        timesteps[idx : idx + 1],
                        noisy_latents[idx : idx + 1],
                    ).prev_sample
                    for idx in range(len(model_pred))
                ]
            )
            # A sample at timestep 0 would otherwise ask for timestep -1. Used
            # as an index into a schedule table (as diffusers' add_noise does),
            # -1 wraps around to the last, noisiest entry, so clamp at 0.
            prev_timesteps = torch.clamp(timesteps - 1, min=0)
            xtm1_target = noise_scheduler.add_noise(target_tensor, noise, prev_timesteps)
            # Cast like the main loss term so mixed-precision inputs are
            # compared in float32.
            loss = loss - F.mse_loss(xtm1_pred.float(), xtm1_target.float())

        return loss
class LatentAttack(attack_mixin):
    """
    This attack aims to minimize the l2 distance between latent and target_tensor
    """

    def __call__(
        self,
        latents: torch.Tensor,
        timesteps: torch.Tensor = None,
        encoder_hidden_states: torch.Tensor = None,
        unet: torch.nn.Module = None,
        target_tensor: torch.Tensor = None,
        noise_scheduler=None,
    ):
        """Return the negated mean squared error between the two latents.

        Args:
            latents: Latents to compare against ``target_tensor``.
            timesteps: Unused; accepted for signature compatibility.
            encoder_hidden_states: Unused; accepted for signature compatibility.
            unet: Unused; accepted for signature compatibility.
            target_tensor: Target latents. Required despite the ``None``
                default.
            noise_scheduler: Unused; accepted for signature compatibility.

        Returns:
            ``-mse_loss(latents, target_tensor)`` with mean reduction.

        Raises:
            ValueError: If ``target_tensor`` is ``None``.
        """
        # Identity check: ``== None`` on a tensor is dispatched to
        # ``Tensor.__eq__`` instead of testing for a missing argument.
        if target_tensor is None:
            raise ValueError("Need a target tensor for pre-attack")
        loss = -F.mse_loss(latents, target_tensor, reduction="mean")
        return loss