eggarsway's picture
add DD
b0369c2
"""
"""
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import DirectedDiffusion
plt.rcParams["figure.figsize"] = [float(v)*1.5 for v in plt.rcParams["figure.figsize"]]
def plot_activation(filepath, unet, prompt, clip_tokenizer):
a = DirectedDiffusion.AttnEditorUtils.get_attn(unet)
splitted_prompt = prompt.split(" ")
n = len(splitted_prompt)
start = 0
arrs = []
for j in range(1):
arr = []
for i in range(start,start+n):
b = a[..., i+1] / (a[..., i+1].max() + 0.001)
arr.append(b.T)
start += n
arr = np.hstack(arr)
arrs.append(arr)
arrs = np.vstack(arrs).T
plt.imshow(arrs, cmap='jet', vmin=0, vmax=.8)
plt.title(prompt)
plt.savefig(filepath)