Spaces:
Runtime error
Runtime error
""" | |
""" | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torchvision | |
import DirectedDiffusion | |
# Enlarge matplotlib's default figure size by 50% in both dimensions; this
# runs once at import time and affects figures created afterwards.
plt.rcParams["figure.figsize"] = [
    float(side) * 1.5 for side in plt.rcParams["figure.figsize"]
]
def plot_activation(filepath, unet, prompt, clip_tokenizer):
    """Save a heat map of the attention activations for each word of ``prompt``.

    The attention array is obtained from ``unet`` through
    ``DirectedDiffusion.AttnEditorUtils.get_attn``. For every space-separated
    word of ``prompt``, the slice at the matching index of the array's last
    axis is divided by its own maximum (plus a small epsilon), and the
    resulting maps are tiled into a single image written to ``filepath``.

    Args:
        filepath: Destination handed to ``Figure.savefig``.
        unet: Object passed to ``AttnEditorUtils.get_attn``.
        prompt: Text prompt; split on single spaces to decide how many
            slices to plot, and also used as the figure title.
        clip_tokenizer: Currently unused; kept so existing callers keep
            working.
    """
    attn = DirectedDiffusion.AttnEditorUtils.get_attn(unet)
    word_count = len(prompt.split(" "))

    # Slices 1..word_count of the last axis are plotted.
    # NOTE(review): index 0 is skipped -- presumably a start-of-text token,
    # and this assumes one attention slot per space-separated word; confirm
    # against the tokenizer.
    maps = []
    for token_idx in range(1, word_count + 1):
        channel = attn[..., token_idx]
        # The epsilon avoids a division by zero for an all-zero slice.
        maps.append((channel / (channel.max() + 0.001)).T)

    # Join the transposed maps horizontally, then transpose the result back
    # (for 2-D maps this stacks the word tiles along the first image axis).
    grid = np.vstack([np.hstack(maps)]).T

    # Draw on a dedicated figure instead of pyplot's implicit current axes so
    # repeated calls do not pile images on top of each other, and always
    # release the figure afterwards.
    fig, ax = plt.subplots()
    try:
        ax.imshow(grid, cmap="jet", vmin=0, vmax=.8)
        ax.set_title(prompt)
        fig.savefig(filepath)
    finally:
        plt.close(fig)