File size: 775 Bytes
b0369c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""
"""

import matplotlib.pyplot as plt
import numpy as np
import torchvision

import DirectedDiffusion

plt.rcParams["figure.figsize"] = [float(v)*1.5 for v in plt.rcParams["figure.figsize"]]

def plot_activation(filepath, unet, prompt, clip_tokenizer):
    a = DirectedDiffusion.AttnEditorUtils.get_attn(unet)
    splitted_prompt = prompt.split(" ")
    n = len(splitted_prompt)
    start = 0
    arrs = []
    for j in range(1):
        arr = []
        for i in range(start,start+n):
            b = a[..., i+1] / (a[..., i+1].max() + 0.001)
            arr.append(b.T)
        start += n
        arr = np.hstack(arr)
        arrs.append(arr)
    arrs = np.vstack(arrs).T
    plt.imshow(arrs, cmap='jet', vmin=0, vmax=.8)
    plt.title(prompt)
    plt.savefig(filepath)