Spaces:
Runtime error
Runtime error
File size: 4,425 Bytes
8c1c4dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...modeling_utils import ModelMixin
class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
    """
    Projection module that merges CLIP image and text embeddings into the inputs consumed by the decoder.

    `forward` produces two tensors: an additive term for the timestep embedding, and the projected text encoder
    hidden states with extra context tokens (derived from the image embeddings) concatenated in front of them.

    Background: https://arxiv.org/abs/2204.06125, section 2.1
    """

    @register_to_config
    def __init__(
        self,
        *,
        clip_extra_context_tokens: int = 4,
        clip_embeddings_dim: int = 768,
        time_embed_dim: int,
        cross_attention_dim,
    ):
        """
        Build the learned guidance embedding and the projection layers.

        Args:
            clip_extra_context_tokens: number of extra context tokens produced from each image embedding.
            clip_embeddings_dim: input feature size of every linear projection in this module.
            time_embed_dim: output size of the two projections that are summed into the time-embedding term.
            cross_attention_dim: output size of the hidden-state projection and the size normalized by the
                layer norm; the extra-token projection outputs `clip_extra_context_tokens * cross_attention_dim`.
        """
        super().__init__()

        # Zero-initialized learned vector that is prepended to the image embeddings when guidance is enabled.
        self.learned_classifier_free_guidance_embeddings = self.create_parameter(
            (clip_embeddings_dim,),
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Constant(0.0),
        )

        # Projections whose outputs are summed into the additive time-embedding term.
        self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim)
        self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clipembeddings_dim := clip_embeddings_dim, time_embed_dim) if False else nn.Linear(clip_embeddings_dim, time_embed_dim)

        # Projections that build the encoder hidden states handed to the decoder.
        self.clip_extra_context_tokens = clip_extra_context_tokens
        extra_tokens_width = self.clip_extra_context_tokens * cross_attention_dim
        self.clip_extra_context_tokens_proj = nn.Linear(clip_embeddings_dim, extra_tokens_width)
        self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim)
        self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, *, image_embeddings, text_embeddings, text_encoder_hidden_states, do_classifier_free_guidance):
        """
        Combine image and text embeddings into decoder conditioning.

        Args:
            image_embeddings: tensor whose first axis is the batch. When `do_classifier_free_guidance` is truthy,
                one copy of the learned guidance embedding per input row is prepended along axis 0.
            text_embeddings: tensor whose first axis must match the (possibly extended) image batch.
            text_encoder_hidden_states: tensor passed through the hidden-state projection and layer norm.
                Presumably laid out as (batch, sequence, clip_embeddings_dim) -- confirm against the caller.
            do_classifier_free_guidance: whether to prepend the learned guidance embeddings.

        Returns:
            A tuple `(text_encoder_hidden_states, additive_clip_time_embeddings)`. The first element is the
            normalized hidden states transposed with permutation `[0, 2, 1]` and concatenated after the extra
            context tokens along axis 2; the second is the sum of the projected image and text embeddings.

        Raises:
            AssertionError: if the image and text batch sizes differ.
        """
        if do_classifier_free_guidance:
            # Prepend one copy of the learned guidance embedding for every incoming image embedding.
            num_images = image_embeddings.shape[0]
            guidance_rows = self.learned_classifier_free_guidance_embeddings.unsqueeze(0).expand([num_images, -1])
            image_embeddings = paddle.concat([guidance_rows, image_embeddings], axis=0)

        # Both embedding batches must line up row for row before they are added together.
        batch_size = text_embeddings.shape[0]
        assert image_embeddings.shape[0] == batch_size

        # Project both CLIP embeddings to the time-embedding width and sum them; the result is returned so the
        # caller can add it to the timestep embedding.
        text_time_term = self.embedding_proj(text_embeddings)
        image_time_term = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
        additive_clip_time_embeddings = image_time_term + text_time_term

        # Project the image embeddings into the extra context tokens; the last axis indexes the tokens.
        extra_tokens = self.clip_extra_context_tokens_proj(image_embeddings).reshape(
            [batch_size, -1, self.clip_extra_context_tokens]
        )

        # Project and normalize the text encoder output, swap its last two axes to match the extra tokens'
        # layout, then join the two along the last axis.
        projected_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
        projected_states = self.text_encoder_hidden_states_norm(projected_states).transpose([0, 2, 1])
        text_encoder_hidden_states = paddle.concat([extra_tokens, projected_states], axis=2)

        return text_encoder_hidden_states, additive_clip_time_embeddings
|