File size: 3,952 Bytes
173ea2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import random

from BeamDiffusionModel.models.CoSeD.cross_attention import get_softmax
from BeamDiffusionModel.models.diffusionModel.configs.config_loader import CONFIG
from BeamDiffusionModel.tree.tree import BeamSearchTree
from BeamDiffusionModel.utils.utils import gen_img


def set_softmax(nodes, softmax, n_latents, n_max_latents):
    """Hand each node its softmax weight.

    ``nodes`` and ``softmax`` are paired positionally (stopping at the shorter
    of the two, exactly as ``zip`` does). Each node's own ``set_softmax`` is
    called with its weight plus the shared ``n_latents`` / ``n_max_latents``.
    """
    paired = zip(nodes, softmax)
    for target_node, weight in paired:
        target_node.set_softmax(weight, n_latents, n_max_latents)

def beam_inference(steps, latents_idx, n_seeds=1, seeds=None, steps_back=2, beam_width=4,
                   window_size=2, use_rand=True):
    """Run a beam search over generated images for a sequence of step captions.

    The first caption is rendered once per seed and attached under
    ``tree.root``. Every later caption expands each node kept from the
    previous level. Each node gets, optionally, one child from a fresh random
    seed, plus one child per (ancestor, latent index) pair generated from that
    ancestor's stored latent. The new siblings are scored together with
    ``get_softmax``. For caption indices ``i >= window_size`` (the first
    caption is never pruned), only children lying on one of the
    ``beam_width`` best paths reported by the tree are kept for expansion.

    Args:
        steps: Sequence of captions, one per generation step.
        latents_idx: Latent indices to reuse from each ancestor.
        n_seeds: Minimum number of seeds for the first caption. Missing seeds
            are drawn with ``random.randint(0, 10**6)``.
        seeds: Optional list of seeds for the first caption. A list passed by
            the caller is extended in place when it holds fewer than
            ``n_seeds`` entries. When omitted, a fresh list is used per call.
        steps_back: Passed to ``BeamSearchTree``. Ancestors are fetched via
            ``parent.get_ancestors(steps_back - 1)``.
        beam_width: Passed to ``BeamSearchTree``. Also the number of best
            paths requested when pruning.
        window_size: First 0-based caption index at which pruning happens.
        use_rand: If true, each expanded node also gets one child generated
            from a fresh random seed.

    Returns:
        Whatever ``tree.best_path_imgs()`` returns. Presumably these are the
        images along the best-scoring path (TODO confirm in
        ``BeamSearchTree``).
    """
    # A ``[]`` default would be shared across calls. After the first call it
    # would already hold seeds, so later calls silently reused the same seed.
    if seeds is None:
        seeds = []
    while len(seeds) < n_seeds:
        seeds.append(random.randint(0, 10**6))

    captions = steps
    tree = BeamSearchTree(steps_back, beam_width, latents_idx, len(captions))
    nodes_to_explore = []

    for i, caption in enumerate(captions):
        depth = i + 1
        if i == 0:
            # First level: one unscored child of the root per seed.
            for seed in seeds:
                latents, img = gen_img(caption, seed=seed)
                new_node = tree.add_node(tree.root, caption, depth, "Rand Seed", "Rand Seed",
                                         img, latents, None)
                nodes_to_explore.append(new_node)
            continue

        next_nodes = []
        for parent_node in nodes_to_explore:
            next_nodes += _expand_parent(tree, parent_node, caption, depth, latents_idx,
                                         steps_back, use_rand)

        if i >= window_size:
            print("-----------------------------------Cleaning some nodes-----------------------------------")
            next_nodes = _keep_nodes_on_best_paths(tree, next_nodes, beam_width, depth)
        nodes_to_explore = next_nodes

    return tree.best_path_imgs()


def _expand_parent(tree, parent_node, caption, depth, latents_idx, steps_back, use_rand):
    """Create, score and return all candidate children of ``parent_node`` for ``caption``."""
    children, step_embeddings, image_embeddings = [], [], []

    def _attach(step_label, latent_label, latents, img):
        # Add one child and record its features right away, in creation order.
        child = tree.add_node(parent_node, caption, depth, step_label, latent_label,
                              img, latents, None)
        children.append(child)
        step_embedding, image_embedding = child.get_features()
        step_embeddings.append(step_embedding)
        image_embeddings.append(image_embedding)

    if use_rand:
        seed = random.randint(0, 10 ** 6)
        latents, img = gen_img(caption, seed=seed)
        _attach("Rand Seed", "Rand Seed", latents, img)

    for ancestor in parent_node.get_ancestors(steps_back - 1):
        for latent in latents_idx:
            ancestor_latent = ancestor.get_latent(latent)
            latents, img = gen_img(caption, latent=ancestor_latent)
            _attach(ancestor.step, latent, latents, img)

    if step_embeddings:
        # Siblings are scored jointly against the features of earlier steps.
        previous_steps_embeddings, previous_images_embeddings = \
            tree.get_previous_steps_features(children[-1])
        softmax = get_softmax(previous_steps_embeddings, previous_images_embeddings,
                              step_embeddings, image_embeddings)
        set_softmax(children, softmax, len(latents_idx),
                    CONFIG["stable_diffusion"]["diffusion_settings"]["steps"])
    return children


def _keep_nodes_on_best_paths(tree, nodes, beam_width, depth):
    """Return, in order and without duplicates, the ``nodes`` lying on one of the best paths."""
    best_paths = tree.get_n_best_paths(beam_width, depth)
    # ``any`` keeps each node once, even if several best paths contain it.
    return [node for node in nodes if any(node in path for path in best_paths)]