File size: 14,108 Bytes
43b7e92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
from typing import Optional
import numpy as np
import torch
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.modeling_utils import ModuleUtilsMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    """
    Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
    generate text from the UniDiffuser image-text embedding.

    The model wraps a GPT-2 language model (`self.transformer`) and two optional linear projections
    (`self.encode_prefix` / `self.decode_prefix`) that map prefix embeddings into and out of a hidden space of size
    `prefix_hidden_dim`. When `prefix_hidden_dim` is `None` both projections are identity layers.

    Parameters:
        prefix_length (`int`):
            Max number of prefix tokens that will be supplied to the model.
        prefix_inner_dim (`int`):
            The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
            CLIP text encoder.
        prefix_hidden_dim (`int`, *optional*):
            Hidden dim of the MLP if we encode the prefix. Must be set when `prefix_inner_dim != n_embd`.
        vocab_size (`int`, *optional*, defaults to 50257):
            Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
        n_positions (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_inner (`int`, *optional*, defaults to None):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Scale attention weights by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
            Whether to additionally scale attention weights by `1 / layer_idx + 1`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
            dot-product/softmax to float() when training with mixed precision.
    """

    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]

    @register_to_config
    def __init__(
        self,
        prefix_length: int,
        prefix_inner_dim: int,
        prefix_hidden_dim: Optional[int] = None,
        vocab_size: int = 50257,  # Start of GPT2 config args
        n_positions: int = 1024,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_inner: Optional[int] = None,
        activation_function: str = "gelu_new",
        resid_pdrop: float = 0.1,
        embd_pdrop: float = 0.1,
        attn_pdrop: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        scale_attn_weights: bool = True,
        use_cache: bool = True,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
    ):
        super().__init__()

        self.prefix_length = prefix_length

        # Without a projection the prefix is fed to GPT-2 unchanged, so its width must already match `n_embd`.
        if prefix_inner_dim != n_embd and prefix_hidden_dim is None:
            raise ValueError(
                f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_inner_dim} and"
                f" `n_embd`: {n_embd} are not equal."
            )

        self.prefix_inner_dim = prefix_inner_dim
        self.prefix_hidden_dim = prefix_hidden_dim

        # prefix_inner_dim -> prefix_hidden_dim (identity when no hidden dim is configured)
        self.encode_prefix = (
            nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim)
            if self.prefix_hidden_dim is not None
            else nn.Identity()
        )
        # prefix_hidden_dim -> n_embd (identity when no hidden dim is configured)
        self.decode_prefix = (
            nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
        )

        gpt_config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            n_inner=n_inner,
            activation_function=activation_function,
            resid_pdrop=resid_pdrop,
            embd_pdrop=embd_pdrop,
            attn_pdrop=attn_pdrop,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_range=initializer_range,
            scale_attn_weights=scale_attn_weights,
            use_cache=use_cache,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )
        self.transformer = GPT2LMHeadModel(gpt_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        prefix_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """
        Run the GPT-2 model on the projected prefix embeddings followed by the embedded text tokens.

        Args:
            input_ids (`torch.Tensor` of shape `(N, max_seq_len)`):
                Text tokens to use for inference.
            prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, prefix_inner_dim)`):
                Prefix embedding to prepend to the embedded tokens.
            attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len)`, *optional*):
                Attention mask for the concatenated prefix and text positions. It is forwarded unchanged to the
                GPT-2 model.
            labels (`torch.Tensor`, *optional*):
                Labels to use for language modeling. Only whether this argument is `None` is checked: when it is
                supplied, the labels actually used are `prefix_length` dummy (zero) tokens followed by `input_ids`.

        Returns:
            The GPT-2 model output. When `prefix_hidden_dim` is set, a tuple `(output, hidden)` is returned instead,
            where `hidden` is the prefix after `encode_prefix`.
        """
        embedding_text = self.transformer.transformer.wte(input_ids)
        hidden = self.encode_prefix(prefix_embeds)
        prefix_embeds = self.decode_prefix(hidden)
        embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1)

        if labels is not None:
            # Labels must line up with the concatenated sequence, so the prefix positions get dummy tokens.
            dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device)
            labels = torch.cat((dummy_token, input_ids), dim=1)
        out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask)
        if self.prefix_hidden_dim is not None:
            return out, hidden
        else:
            return out

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return an int64 tensor of zeros with shape `(batch_size, prefix_length)` on `device`."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def encode(self, prefix):
        """Project `prefix` with `encode_prefix` (identity when `prefix_hidden_dim` is `None`)."""
        return self.encode_prefix(prefix)

    @torch.no_grad()
    def generate_captions(self, features: torch.Tensor, eos_token_id: int, device):
        """
        Generate caption token sequences from text embedding features, one beam search per batch element.

        Args:
            features (`torch.Tensor` of shape `(B, L, D)`):
                Text embedding features to generate captions from.
            eos_token_id (`int`):
                The token ID of the EOS token for the text decoder model.
            device:
                Device to perform text generation on.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: The best-scoring token sequence for each batch element, stacked
            along the first dimension, and the corresponding sequence lengths. Sequences shorter than the longest
            one in the batch are right-padded with token ID 0; use the returned lengths to trim them.
        """
        features = torch.split(features, 1, dim=0)
        generated_tokens = []
        generated_seq_lengths = []
        for feature in features:
            feature = self.decode_prefix(feature.to(device))  # back to the clip feature
            # Only support beam search for now
            output_tokens, seq_lengths = self.generate_beam(
                input_embeds=feature, device=device, eos_token_id=eos_token_id
            )
            # Beams are returned best-first, so index 0 is the top candidate.
            generated_tokens.append(output_tokens[0])
            generated_seq_lengths.append(seq_lengths[0])

        # Beam search stops as soon as every beam has emitted EOS, so different batch elements can yield
        # sequences of different length. Pad to the longest one instead of stacking, which would fail on
        # unequal sizes. Token 0 is also what the beam search appends to finished beams.
        generated_tokens = nn.utils.rnn.pad_sequence(generated_tokens, batch_first=True, padding_value=0)
        generated_seq_lengths = torch.stack(generated_seq_lengths)
        return generated_tokens, generated_seq_lengths

    @torch.no_grad()
    def generate_beam(
        self,
        input_ids=None,
        input_embeds=None,
        device=None,
        beam_size: int = 5,
        entry_length: int = 67,
        temperature: float = 1.0,
        eos_token_id: Optional[int] = None,
    ):
        """
        Generates text using the given tokenizer and text prompt or token embedding via beam search. This
        implementation is based on the beam search implementation from the [original UniDiffuser
        code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89).

        Args:
            eos_token_id (`int`, *optional*):
                The token ID of the EOS token for the text decoder model. If `None`, no early stopping is performed
                and the search always runs for `entry_length` iterations.
            input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
                Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds`
                must be supplied.
            input_embeds (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                An embedded representation to directly pass to the transformer as a prefix for beam search. One of
                `input_ids` and `input_embeds` must be supplied. Takes precedence over `input_ids` if both are given.
            device:
                The device to perform beam search on.
            beam_size (`int`, *optional*, defaults to `5`):
                The number of best states to store during beam search.
            entry_length (`int`, *optional*, defaults to `67`):
                The number of iterations to run beam search.
            temperature (`float`, *optional*, defaults to 1.0):
                The temperature to use when performing the softmax over logits from the decoding model. Values that
                are not strictly positive are treated as `1.0`.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated
            token sequences sorted by score in descending order, and the second element is the sequence lengths
            corresponding to those sequences.

        Raises:
            ValueError: If neither `input_ids` nor `input_embeds` is supplied.
        """
        if input_embeds is None and input_ids is None:
            raise ValueError("One of `input_ids` and `input_embeds` must be supplied.")

        # Generates text until stop_token is reached using beam search with the desired beam size.
        stop_token_index = eos_token_id
        tokens = None
        scores = None
        seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)

        if input_embeds is not None:
            generated = input_embeds
        else:
            generated = self.transformer.transformer.wte(input_ids)

        for _ in range(entry_length):
            outputs = self.transformer(inputs_embeds=generated)
            logits = outputs.logits
            # Only the distribution over the next token (last position) is needed.
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = logits.softmax(-1).log()

            if scores is None:
                # First step: seed the beams with the top-`beam_size` tokens of the single input sequence.
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                tokens = next_tokens
            else:
                # Finished beams may only be extended with token 0 at zero cost, which keeps their score fixed.
                logits[is_stopped] = -float("inf")
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                # Rank candidates by length-normalised log-probability.
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                # The flat top-k index encodes (source beam, token id).
                num_candidates = scores_sum.shape[1]
                next_tokens_source = next_tokens // num_candidates
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % num_candidates
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                # Undo the length normalisation so `scores` stays a summed log-probability.
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)

            # Without an EOS token there is nothing to stop on; run for the full `entry_length`.
            if stop_token_index is not None:
                is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
                if is_stopped.all():
                    break

        scores = scores / seq_lengths
        order = scores.argsort(descending=True)
        # tokens tensors are already padded to max_seq_length
        output_texts = tokens[order]
        # Lengths are returned as a CPU tensor; the token sequences stay on the compute device.
        seq_lengths = seq_lengths[order].cpu()
        return output_texts, seq_lengths
|