File size: 3,944 Bytes
bb5cd12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import torch.nn.functional as F
from torchtyping import TensorType
from typing import Union, List


def top_p_filter(logits: 'TensorType[..., "vocab"]', threshold: float = 0.9):
    """
    Nucleus (top-p) filtering of logits along the last dimension.

    Tokens are ranked by descending logit. A token is masked to ``-inf`` when
    the cumulative softmax probability of the tokens ranked *before* it
    already exceeds ``1 - threshold``; the highest-ranked token is always kept.

    NOTE(review): the cutoff is ``1 - threshold``, so a *larger* threshold
    keeps *fewer* tokens (``threshold=0.9`` keeps the smallest set of top
    tokens whose cumulative probability exceeds 0.1). This is left unchanged
    to preserve existing behaviour -- confirm it is the intended convention.

    :param logits: Unnormalised scores; the last dimension is the vocabulary.
        Any number of leading dimensions is accepted.
    :param threshold: Controls the cutoff as described above.
    :return: A new tensor with the same shape as ``logits`` in which the
        filtered-out entries are ``-inf``. ``logits`` itself is not modified.
    """
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    remove_mask = cum_probs > (1 - threshold)
    # Shift the mask one position to the right so that the first token that
    # crosses the cutoff is itself kept, and never drop the top-ranked token.
    remove_mask[..., 1:] = remove_mask[..., :-1].clone()
    remove_mask[..., 0] = False

    sorted_logits = sorted_logits.masked_fill(remove_mask, float("-inf"))
    # Undo the sort. Scatter on the last dim (the one that was sorted) so the
    # function works for 1-D and >2-D inputs, not only (batch, vocab).
    return sorted_logits.scatter(-1, sorted_indices, sorted_logits)


def top_k_filter(logits, k):
    """
    Top-k filtering of logits along the last dimension.

    Keeps the ``k`` largest entries of each vocabulary slice and sets every
    other entry to ``-inf``.

    :param logits: Unnormalised scores; the last dimension is the vocabulary.
        Any number of leading dimensions is accepted.
    :param k: Number of entries to keep; must be positive. Values larger than
        the vocabulary size are clamped, in which case nothing is filtered.
    :return: A new tensor with the same shape as ``logits``.
    """
    assert k > 0
    # torch.topk raises if k exceeds the size of the selected dimension.
    k = min(k, logits.size(-1))
    top_values, top_indices = torch.topk(logits, k, dim=-1)
    filtered = torch.full_like(logits, float("-inf"))
    # Scatter on the same (last) dimension that topk selected from.
    filtered.scatter_(-1, top_indices, top_values)
    return filtered


def remove_tokens_after_eos(tensor, eos_token, image_token):
    """
    Convert a sequence of token ids to a list, dropping everything from the
    first end-of-sequence token onwards as well as all image tokens.

    :param tensor: Token ids; expected to be a 1-D tensor (it is sliced and
        converted with ``tolist()``). The input is not modified.
    :param eos_token: Id of the end-of-sequence token. If ``None``, no
        truncation is performed.
    :param image_token: Id of the image placeholder token to strip.
    :return: List of the remaining token ids.
    """
    if eos_token is not None:
        eos_positions = (tensor == eos_token).nonzero()
        # Test for the *presence* of a match; checking the truthiness of the
        # index values would miss an EOS located at position 0.
        if eos_positions.numel() > 0:
            tensor = tensor[: int(eos_positions[0])]

    return [
        token
        for token in tensor.tolist()
        if token != image_token and token != eos_token
    ]


@torch.no_grad()
def generate(
    model: "Magma",
    embeddings: TensorType["b", "s", "d"],
    max_steps: int = 100,
    temperature: float = 0.7,
    top_k: int = 0,
    top_p: float = 0.9,
    eos_token: int = None,
    decode: bool = True,
) -> Union[List[str], TensorType["b", "s"]]:
    """
    Generates captions for a batch of embeddings.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if generation raises.

    :param model: The model to use for generation.
    :param embeddings: The embeddings to generate captions for.
    :param max_steps: The maximum number of tokens to generate per sequence.
    :param temperature: The temperature to use for sampling. If 0.0, the most
        likely token is chosen greedily and top_k / top_p are ignored.
    :param top_k: value for top k filtering. If 0, no top k filtering is applied.
    :param top_p: value for top p filtering. If 0, no top p filtering is applied.
    :param eos_token: The token to use for end of sequence. If None, falls
        back to ``model.eos_token``.
    :param decode: Whether to decode the output into text, or return the raw tokens.
    :return: If ``decode`` is True, a list with one decoded string per batch
        element (image tokens and everything from the first eos token onwards
        are removed). Otherwise a long tensor holding, per batch element, one
        image-token placeholder per input embedding position followed by the
        generated token ids.
    """

    # init values
    # `is None` rather than `or`, so that an explicit eos token id of 0 is respected
    if eos_token is None:
        eos_token = model.eos_token
    was_training = model.training
    model.eval()

    try:
        batch_size, prefix_len, _ = embeddings.shape
        past_key_values = None

        # init output with image tokens
        out = (
            torch.zeros((batch_size, prefix_len), dtype=torch.long).to(model.device)
            + model.image_token
        )
        # tracks which sequences have already produced an eos token
        finished = torch.zeros(batch_size, dtype=torch.bool, device=out.device)

        # do sampling
        for step in range(max_steps):
            if step == 0:
                # initial input
                outputs = model.lm(
                    inputs_embeds=embeddings,
                    use_cache=True,
                    past_key_values=past_key_values,
                )
            else:
                # now caching past k/v so we can use only the last token
                outputs = model.lm(
                    input_ids=out[:, -1:],
                    use_cache=True,
                    past_key_values=past_key_values,
                )

            logits = outputs.logits[:, -1, :].float()
            past_key_values = outputs.past_key_values

            # filter / temperature sample
            if temperature == 0.0:
                # keepdim so next_token is (b, 1), matching torch.multinomial
                # and allowing concatenation with the (b, s) output below
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                if top_k > 0:
                    logits = top_k_filter(logits, k=top_k)
                if top_p > 0:
                    logits = top_p_filter(logits, threshold=top_p)

                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            out = torch.cat((out, next_token), dim=-1)

            if eos_token is not None:
                # stop once every sequence has emitted eos at some step, not
                # only when all of them happen to emit it on the same step
                finished |= next_token.squeeze(-1) == eos_token
                if finished.all():
                    break

        if decode:
            captions = []
            for sequence in out:
                token_ids = remove_tokens_after_eos(
                    sequence, eos_token, model.image_token
                )
                captions.append(model.tokenizer.decode(token_ids))
            out = captions

        return out
    finally:
        # restore the caller's train / eval mode, also on error
        model.train(was_training)