import random

from BeamDiffusionModel.models.CoSeD.cross_attention import get_softmax
from BeamDiffusionModel.models.diffusionModel.configs.config_loader import CONFIG
from BeamDiffusionModel.tree.tree import BeamSearchTree
from BeamDiffusionModel.utils.utils import gen_img


def set_softmax(nodes, softmax, n_latents, n_max_latents):
    """Assign each candidate node its softmax score (paired by position with `softmax`)."""
    for node, softmax_value in zip(nodes, softmax):
        node.set_softmax(softmax_value, n_latents, n_max_latents)


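# beam_inference runs a beam search over diffusion generations: the first step is
# generated from random seeds, each later step expands candidates by reusing latents
# from earlier nodes on the path (plus an optional fresh random seed), scores the
# candidates with the CoSeD cross-attention softmax, and prunes the frontier back to
# the best paths once the step index passes `window_size`.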
def beam_inference(steps, latents_idx, n_seeds=1, seeds=None, steps_back=2, beam_width=4,
                   window_size=2, use_rand=True):
    # Avoid a mutable default argument: a shared default list would accumulate seeds across calls.
    seeds = list(seeds) if seeds else []
    while len(seeds) < n_seeds:
        seeds.append(random.randint(0, 10**6))

    captions = steps
    tree = BeamSearchTree(steps_back, beam_width, latents_idx, len(captions))
    nodes_to_explore = []

    for i, caption in enumerate(captions):
        if i == 0:
            # First step: create one candidate per random seed under the root.
            for seed in seeds:
                latents, img = gen_img(caption, seed=seed)
                new_node = tree.add_node(tree.root, caption, i + 1, "Rand Seed", "Rand Seed",
                                         img, latents, None)
                nodes_to_explore.append(new_node)
        else:
            next_nodes = []
            for parent_node in nodes_to_explore:
                parent_childs = []
                current_step_embeddings, current_image_embeddings = [], []

                if use_rand:
                    # Optionally add a fresh random-seed candidate for this step.
                    seed = random.randint(0, 10**6)
                    latents, img = gen_img(caption, seed=seed)
                    new_node = tree.add_node(parent_node, caption, i + 1, "Rand Seed", "Rand Seed",
                                             img, latents, None)
                    parent_childs.append(new_node)
                    current_step_embedding, current_image_embedding = new_node.get_features()
                    current_step_embeddings.append(current_step_embedding)
                    current_image_embeddings.append(current_image_embedding)

                # Generate additional candidates from latents stored earlier along this path.
                ancestors = parent_node.get_ancestors(steps_back - 1)
                for ancestor in ancestors:
                    for latent in latents_idx:
                        ancestor_latent = ancestor.get_latent(latent)
                        latents, img = gen_img(caption, latent=ancestor_latent)
                        new_node = tree.add_node(parent_node, caption, i + 1, ancestor.step, latent,
                                                 img, latents, None)
                        parent_childs.append(new_node)
                        current_step_embedding, current_image_embedding = new_node.get_features()
                        current_step_embeddings.append(current_step_embedding)
                        current_image_embeddings.append(current_image_embedding)

                if current_step_embeddings:
                    # Score the new candidates against the path history via the cross-attention softmax.
                    previous_steps_embeddings, previous_images_embeddings = tree.get_previous_steps_features(parent_childs[-1])
                    softmax = get_softmax(previous_steps_embeddings, previous_images_embeddings,
                                          current_step_embeddings,
                                          current_image_embeddings)
                    set_softmax(parent_childs, softmax, len(latents_idx),
                                CONFIG["stable_diffusion"]["diffusion_settings"]["steps"])

                next_nodes += parent_childs

            if i >= window_size:
                # Prune the frontier: keep only nodes that lie on one of the current best paths.
                print("-----------------------------------Cleaning some nodes-----------------------------------")
                best_paths = tree.get_n_best_paths(beam_width, i + 1)
                new_next_nodes = []
                for node in next_nodes:
                    for node_path in best_paths:
                        if node in node_path:
                            new_next_nodes.append(node)
                            break  # avoid adding the same node once per path that contains it
                next_nodes = new_next_nodes

            nodes_to_explore = next_nodes

    return tree.best_path_imgs()
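

# A minimal usage sketch. The step captions and parameter values below are illustrative
# placeholders (not values from this repository); `latents_idx` must match the latent
# indices that `gen_img` stores for each node.
if __name__ == "__main__":
    example_steps = [
        "Whisk the eggs with the sugar until pale.",
        "Fold in the flour and melted butter.",
        "Bake until golden and let the cake cool.",
    ]
    # Returns the images along the highest-scoring beam path, one per step.
    best_images = beam_inference(
        steps=example_steps,   # one caption per generation step
        latents_idx=[0, 1],    # which stored latents to reuse from earlier nodes (illustrative)
        n_seeds=2,             # number of random seeds for the first step
        beam_width=4,
        window_size=2,
    )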