File size: 3,952 Bytes
173ea2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import random
from BeamDiffusionModel.models.CoSeD.cross_attention import get_softmax
from BeamDiffusionModel.models.diffusionModel.configs.config_loader import CONFIG
from BeamDiffusionModel.tree.tree import BeamSearchTree
from BeamDiffusionModel.utils.utils import gen_img
def set_softmax(nodes, softmax, n_latents, n_max_latents):
    """Propagate softmax scores onto their matching tree nodes.

    Walks ``nodes`` and ``softmax`` in lockstep (stopping at the shorter
    sequence, as ``zip`` does) and hands each node its score together with
    the latent counts, unchanged.
    """
    for node, score in zip(nodes, softmax):
        node.set_softmax(score, n_latents, n_max_latents)
def beam_inference(steps, latents_idx, n_seeds=1, seeds=None, steps_back=2, beam_width=4, window_size=2, use_rand=True):
    """Run beam-search image generation over a sequence of step captions.

    Builds a ``BeamSearchTree`` with one level per caption in ``steps``.
    Level 0 spawns one root child per seed.  Every later level expands each
    frontier node with (a) optionally one fresh random-seed child and (b) one
    child per (ancestor, latent index) pair reusing that ancestor's stored
    latent.  Children of a parent are scored together via ``get_softmax``;
    once ``i >= window_size`` the frontier is pruned to nodes lying on the
    ``beam_width`` best paths.

    Args:
        steps: Ordered step captions (one tree level per caption).
        latents_idx: Latent indices to pull from ancestor nodes on expansion.
        n_seeds: Number of level-0 seeds; random ones are appended when
            ``seeds`` supplies fewer.
        seeds: Optional explicit seeds. ``None`` (default) means start empty.
            A caller-supplied list is still extended in place, matching the
            previous behavior.
        steps_back: Look-back depth for both the tree and ancestor latents.
        beam_width: Number of best paths kept when pruning.
        window_size: Level index at which pruning starts.
        use_rand: Whether each expansion also gets a fresh random-seed child.

    Returns:
        ``tree.best_path_imgs()`` — presumably the images along the best
        path (project-defined; confirm against BeamSearchTree).
    """
    # Bug fix: the original signature used a mutable default (`seeds=[]`),
    # so random seeds appended in one call leaked into every later call.
    if seeds is None:
        seeds = []
    while len(seeds) < n_seeds:
        seeds.append(random.randint(0, 10**6))

    captions = steps
    tree = BeamSearchTree(steps_back, beam_width, latents_idx, len(captions))
    nodes_to_explore = []
    for i, caption in enumerate(captions):
        if i == 0:
            # Root level: one child of the tree root per seed.
            for seed in seeds:
                latents, img = gen_img(caption, seed=seed)
                new_node = tree.add_node(tree.root, caption, i + 1, "Rand Seed", "Rand Seed",
                                         img, latents, None)
                nodes_to_explore.append(new_node)
        else:
            next_nodes = []
            for parent_node in nodes_to_explore:  # (unused enumerate index removed)
                parent_childs = []
                current_step_embeddings, current_image_embeddings = [], []
                if use_rand:
                    # Exploration: one child generated from a fresh random seed.
                    seed = random.randint(0, 10**6)
                    latents, img = gen_img(caption, seed=seed)
                    new_node = tree.add_node(parent_node, caption, i + 1, "Rand Seed", "Rand Seed",
                                             img, latents, None)
                    parent_childs.append(new_node)
                    current_step_embedding, current_image_embedding = new_node.get_features()
                    current_step_embeddings.append(current_step_embedding)
                    current_image_embeddings.append(current_image_embedding)
                # Exploitation: reuse stored latents from up to `steps_back - 1` ancestors.
                ancestors = parent_node.get_ancestors(steps_back - 1)
                for ancestor in ancestors:
                    for latent in latents_idx:
                        ancestor_latent = ancestor.get_latent(latent)
                        latents, img = gen_img(caption, latent=ancestor_latent)
                        new_node = tree.add_node(parent_node, caption, i + 1, ancestor.step, latent,
                                                 img, latents, None)
                        parent_childs.append(new_node)
                        current_step_embedding, current_image_embedding = new_node.get_features()
                        current_step_embeddings.append(current_step_embedding)
                        current_image_embeddings.append(current_image_embedding)
                if current_step_embeddings:  # idiomatic truthiness (was `!= []`)
                    previous_steps_embeddings, previous_images_embeddings = \
                        tree.get_previous_steps_features(parent_childs[-1])
                    softmax = get_softmax(previous_steps_embeddings, previous_images_embeddings,
                                          current_step_embeddings,
                                          current_image_embeddings)
                    set_softmax(parent_childs, softmax, len(latents_idx),
                                CONFIG["stable_diffusion"]["diffusion_settings"]["steps"])
                next_nodes += parent_childs
            if i >= window_size:
                print("-----------------------------------Cleaning some nodes-----------------------------------")
                best_paths = tree.get_n_best_paths(beam_width, i + 1)
                # Bug fix: the original appended a node once per best path it
                # appeared in, so a node on several paths was kept (and later
                # explored) multiple times.  Keep each surviving node once.
                next_nodes = [node for node in next_nodes
                              if any(node in path for path in best_paths)]
            nodes_to_explore = next_nodes
    return tree.best_path_imgs()