import torch
import torch.nn.functional as F

prompt_dataset = [
    "Portrait of an astronaut in space, detailed starry background, reflective helmet,",
    "Painting of a floating island with giant clock gears, populated with mythical creatures,",
    "Landscape of a Japanese garden in autumn, with a bridge over a koi pond,",
    "Painting representing the sound of jazz music, using vibrant colors and erratic shapes,",
    "Painting of a modern smartphone with classic art pieces appearing on the screen,",
    "Battle scene with futuristic robots and a golden palace in the background,",
    "Scene of a bustling city market with different perspectives of people and stalls,",
    "Scene of a ship sailing in a stormy sea, with dramatic lighting and powerful waves,",
    "Portraint of a female botanist surrounded by exotic plants in a greenhouse,",
    "Painting of an ancient castle at night, with a full moon, gargoyles, and shadows,",
]

style_dataset = [
    "Art Nouveau",
    "Romantic",
    "Cubist",
    "Baroque",
    "Pop Art",
    "Abstract",
    "Impressionist",
    "Surrealist",
    "Renaissance",
    "Pointillism",
]



class attack_mixin:
    """Common interface for the attack losses below: each subclass implements
    __call__ and returns a scalar loss for the outer optimization loop."""
    def __call__(
        self,
        latents: torch.Tensor,
        timesteps: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        unet: torch.nn.Module,
        target_tensor: torch.Tensor,
        noise_scheduler
    ):
        raise NotImplementedError
    
class AdvDM(attack_mixin):
    """
    This attack aims to maximize the training loss of diffusion model
    """
    def __call__(
        self,
        latents: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        unet: torch.nn.Module,
        text_encoder: torch.nn.Module,
        input_ids,
        target_tensor: torch.Tensor,
        noise_scheduler
    ):
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(input_ids)[0]

        # Predict the noise residual
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        unet.zero_grad()
        text_encoder.zero_grad()
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # target-shift loss: steer the predicted x_{t-1} toward the (noised) target latents
        if target_tensor is not None:
            xtm1_pred = torch.cat(
                [
                    noise_scheduler.step(
                        model_pred[idx : idx + 1],
                        timesteps[idx : idx + 1],
                        noisy_latents[idx : idx + 1],
                    ).prev_sample
                    for idx in range(len(model_pred))
                ]
            )
            xtm1_target = noise_scheduler.add_noise(target_tensor, noise, timesteps - 1)
            loss = loss - F.mse_loss(xtm1_pred, xtm1_target)

        return loss
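

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original attack code).
# It shows one PGD-style ascent step on the AdvDM loss. The checkpoint id,
# step size, and the random stand-in latents are assumptions, and the helper
# itself is hypothetical rather than an API of this file.
# ---------------------------------------------------------------------------
def _advdm_usage_example():
    """Run a single ascent step that increases the AdvDM loss w.r.t. the latents."""
    from diffusers import DDPMScheduler, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    model_id = "runwayml/stable-diffusion-v1-5"  # assumed checkpoint
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

    prompt = prompt_dataset[0] + " " + style_dataset[0]
    input_ids = tokenizer(
        prompt,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids

    # Random latents stand in for VAE-encoded images in this sketch.
    latents = torch.randn(1, 4, 64, 64, requires_grad=True)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,))

    loss = AdvDM()(
        latents,
        noise,
        timesteps,
        encoder_hidden_states=None,  # recomputed inside AdvDM from input_ids
        unet=unet,
        text_encoder=text_encoder,
        input_ids=input_ids,
        target_tensor=None,
        noise_scheduler=noise_scheduler,
    )
    grad = torch.autograd.grad(loss, latents)[0]
    alpha = 1e-2  # assumed step size
    return (latents + alpha * grad.sign()).detach()
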
    
class LatentAttack(attack_mixin):
    """
    This attack aims to minimize the l2 distance between latent and target_tensor
    """
    def __call__(
        self,
        latents: torch.Tensor,
        timesteps: torch.Tensor=None,
        encoder_hidden_states: torch.Tensor=None,
        unet: torch.nn.Module=None,
        target_tensor: torch.Tensor=None,
        noise_scheduler=None
    ):
        if target_tensor is None:
            raise ValueError("Need a target tensor for pre-attack")
        # Negated MSE: maximizing this loss pulls the latents toward target_tensor.
        loss = -F.mse_loss(latents, target_tensor, reduction="mean")
        return loss
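

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original attack code).
# It drives LatentAttack for a few sign-gradient steps so the latents drift
# toward a target latent. Tensor shapes, step size, and step count are
# illustrative assumptions, and the helper itself is hypothetical.
# ---------------------------------------------------------------------------
def _latent_attack_usage_example():
    """Ascend the (negated-MSE) LatentAttack loss, shrinking the latent-target gap."""
    latents = torch.randn(1, 4, 64, 64, requires_grad=True)
    target = torch.randn(1, 4, 64, 64)

    attack = LatentAttack()
    alpha = 1e-2  # assumed step size
    for _ in range(10):  # assumed number of steps
        loss = attack(latents, target_tensor=target)
        grad = torch.autograd.grad(loss, latents)[0]
        # Ascending on -MSE(latents, target) moves the latents toward the target.
        latents = (latents + alpha * grad.sign()).detach().requires_grad_(True)
    return latents.detach()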