MMM-Demo / models /t2m_trans_uplow.py
samadi10's picture
Added necessary files
eeaa83d
raw
history blame
28.1 kB
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import Categorical
import models.pos_encoding as pos_encoding
from exit.utils import cosine_schedule, uniform, top_k, gumbel_sample, top_p
from tqdm import tqdm
from einops import rearrange, repeat
from exit.utils import get_model, generate_src_mask
class PatchUpSampling(nn.Module):
    """Expand every frame into four frames with a learned linear projection.

    The input is layer-normalised, projected from ``dim`` to ``4 * dim``
    channels, and the widened channel axis is split into four strided groups
    (channels ``k, k + 4, k + 8, ...`` for ``k = 0..3``).  The groups are
    stacked along the frame axis, so ``(B, F, dim)`` becomes ``(B, 4F, dim)``.

    NOTE(review): the groups are concatenated group-major -- output frame
    ``k * F + f`` is derived from input frame ``f`` -- rather than interleaved
    (``4 * f + k``).  Trained weights depend on this layout, so it is kept;
    confirm it is the intended temporal ordering.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.up_sampling = nn.Linear(dim, 4 * dim, bias=False)
        self.norm = norm_layer(dim)

    def forward(self, x):
        """Map ``x`` of shape (B, F, C) to (B, 4F, C)."""
        widened = self.up_sampling(self.norm(x))  # (B, F, 4C)
        groups = [widened[:, :, offset::4] for offset in range(4)]  # 4 x (B, F, C)
        return torch.cat(groups, 1)
class Decoder_Transformer(nn.Module):
    """Transformer decoder from code vectors to per-frame features.

    Pipeline: linear embedding -> positional encoding -> ``num_layers``
    unmasked transformer blocks -> 4x frame up-sampling (``PatchUpSampling``)
    -> LayerNorm + linear head to ``output_dim`` channels.  Input and output
    are channel-first: (B, code_dim, T) -> (B, output_dim, 4T), assuming
    ``pos_embed`` preserves the shape of its input (TODO confirm).
    """

    def __init__(self,
                 code_dim=1024,
                 embed_dim=512,
                 output_dim=263,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        # Submodules are created in the same order as before so random
        # initialisation consumes the RNG identically.
        self.joint_embed = nn.Linear(code_dim, embed_dim)
        self.drop = nn.Dropout(drop_out_rate)
        layers = [Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)]
        self.blocks = nn.Sequential(*layers)
        self.up_sample = PatchUpSampling(embed_dim)
        self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)
        self.head = nn.Sequential(nn.LayerNorm(embed_dim),
                                  nn.Linear(embed_dim, output_dim))
        self.block_size = block_size
        self.n_head = n_head
        self.apply(self._init_weights)

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        """Init: N(0, 0.02) Linear/Embedding weights, zero Linear biases, LayerNorm to (1, 0)."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def forward(self, token_embeddings):
        """Decode (B, code_dim, T) code vectors into (B, output_dim, 4T) features."""
        hidden = self.joint_embed(token_embeddings.permute(0, 2, 1))  # (B, T, embed_dim)
        hidden = self.pos_embed(hidden)
        for layer in self.blocks:
            hidden = layer(hidden)  # no attention mask
        hidden = self.up_sample(hidden)  # (B, T, E) -> (B, 4T, E)
        return self.head(hidden).permute(0, 2, 1)
# https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L342C9-L343C33
class PatchMerging(nn.Module):
    """Merge every four consecutive frames into one (1-D Swin-style patch merging).

    Frames ``4i .. 4i+3`` are concatenated channel-wise, layer-normalised and
    projected to ``dim`` channels: (B, F, input_feats) -> (B, F/4, dim).
    F must be a multiple of 4; otherwise the strided slices differ in length
    and ``torch.cat`` raises.
    """

    def __init__(self, input_feats, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * input_feats, dim, bias=False)
        self.norm = norm_layer(4 * input_feats)

    def forward(self, x):
        """x: (B, F, C) -> (B, F/4, dim)."""
        phases = [x[:, offset::4, :] for offset in range(4)]  # 4 x (B, F/4, C)
        merged = torch.cat(phases, -1)  # (B, F/4, 4C)
        return self.reduction(self.norm(merged))
class Encoder_Transformer(nn.Module):
    """Transformer encoder that pools every ``block_size`` frames into one vector.

    Input (B, input_feats, T) is embedded to (B, T, embed_dim) and cut into
    T / block_size non-overlapping windows of ``block_size`` consecutive
    frames.  Each window runs through the transformer blocks independently
    and is reduced to one vector by ``weighted_mean``, a 1x1 Conv1d that
    treats the window's frames as channels (a learned weighted sum plus bias).
    Output: (B, output_dim, T / block_size), assuming ``pos_embed`` preserves
    its input shape (TODO confirm).
    """

    def __init__(self,
                 input_feats=1024,
                 embed_dim=512,
                 output_dim=263,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        # Creation order unchanged so random initialisation is reproducible.
        self.joint_embed = nn.Linear(input_feats, embed_dim)
        self.drop = nn.Dropout(drop_out_rate)
        # transformer block
        self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)])
        self.weighted_mean_norm = nn.LayerNorm(embed_dim)
        self.weighted_mean = torch.nn.Conv1d(in_channels=block_size, out_channels=1, kernel_size=1)
        self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)
        self.head = nn.Sequential(nn.LayerNorm(embed_dim),
                                  nn.Linear(embed_dim, output_dim))
        self.block_size = block_size
        self.n_head = n_head
        self.apply(self._init_weights)

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        """Init: N(0, 0.02) Linear/Embedding weights, zero Linear biases, LayerNorm to (1, 0)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, joints):
        """Encode (B, input_feats, T) into (B, output_dim, T / block_size).

        Raises:
            ValueError: if T is not a positive multiple of ``block_size``.
        """
        x = self.joint_embed(joints.permute(0, 2, 1))  # (B, T, embed_dim)
        batch, frames, channels = x.shape
        # A non-multiple frame count would make view(..., -1) silently fold
        # frames into the channel axis, failing only later in LayerNorm.
        if frames == 0 or frames % self.block_size != 0:
            raise ValueError(
                f'Encoder_Transformer expects a frame count that is a positive '
                f'multiple of block_size={self.block_size}, got {frames}')
        token_len = frames // self.block_size

        # One window of block_size consecutive frames per row.
        x = x.view(batch * token_len, self.block_size, channels)
        x = self.pos_embed(x)
        for block in self.blocks:
            x = block(x)
        x = self.weighted_mean_norm(x)
        x = self.weighted_mean(x)  # (B * token_len, 1, embed_dim)
        x = x.view(batch, token_len, channels)
        return self.head(x).permute(0, 2, 1)
class Text2Motion_Transformer(nn.Module):
    """Masked-token transformer that generates upper-body motion tokens.

    ``trans_base`` embeds upper- and lower-body token ids together with a
    condition feature (``clip_feature``), and ``trans_head`` turns the hidden
    states into logits over the ``num_vq`` codebook entries.  Special ids sit
    just past the codebook: ``num_vq`` = end, ``num_vq + 1`` = pad,
    ``num_vq + 2`` = mask.
    """

    def __init__(self,
                 vqvae,
                 num_vq=1024,
                 embed_dim=512,
                 clip_dim=512,
                 block_size=16,
                 num_layers=2,
                 num_local_layer=0,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        self.n_head = n_head
        self.trans_base = CrossCondTransBase(vqvae, num_vq, embed_dim, clip_dim, block_size, num_layers, num_local_layer, n_head, drop_out_rate, fc_rate)
        self.trans_head = CrossCondTransHead(num_vq, embed_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
        self.block_size = block_size
        self.num_vq = num_vq

    def get_block_size(self):
        return self.block_size

    def forward(self, *args, type='forward', **kwargs):
        """Dispatch on ``type``: 'forward' (logits), 'sample' or 'inpaint'."""
        if type == 'forward':
            return self.forward_function(*args, **kwargs)
        elif type == 'sample':
            return self.sample(*args, **kwargs)
        elif type == 'inpaint':
            return self.inpaint(*args, **kwargs)
        else:
            raise ValueError(f'Unknown "{type}" type')

    def get_attn_mask(self, src_mask, att_txt=None, txt_mark=None):
        """Expand a (B, T) key mask into a (B, n_head, T+1, T+1) attention mask.

        A column for the prepended condition token is added in front of
        ``src_mask`` (taken from ``att_txt``, all True by default).  Every
        query row may attend to key ``j`` iff that column is True.  When
        ``txt_mark`` is given, it replaces column 0 per query row.
        """
        if att_txt is None:
            att_txt = torch.tensor([[True]]*src_mask.shape[0]).to(src_mask.device)
        src_mask = torch.cat([att_txt, src_mask], dim=1)
        B, T = src_mask.shape
        src_mask = src_mask.view(B, 1, 1, T).repeat(1, self.n_head, T, 1)
        if txt_mark is not None:
            att_txt_txt = torch.tensor([[True]]*txt_mark.shape[0]).to(txt_mark.device)
            txt_mark = torch.cat([att_txt_txt, txt_mark], dim=1)
            src_mask[:, :, :, 0] = txt_mark.view(B, 1, T).repeat(1, self.n_head, 1)
        return src_mask

    def forward_function(self, idx_upper, idx_lower, clip_feature, src_mask=None, att_txt=None, txt_mark=None, word_emb=None):
        """Return logits for every position; position 0 is the condition token.

        ``src_mask`` is an optional (B, T) key mask over the token positions,
        expanded to a per-head attention mask by ``get_attn_mask``.
        """
        if src_mask is not None:
            src_mask = self.get_attn_mask(src_mask, att_txt, txt_mark)
        feat = self.trans_base(idx_upper, idx_lower, clip_feature, src_mask, word_emb)
        logits = self.trans_head(feat, src_mask)
        return logits

    @staticmethod
    def _remask(ids, scores, num_token_masked, m_tokens_len, src_token_mask, mask_id, pad_id, end_id):
        """Write pad/end ids and re-mask the highest-scoring positions of ``ids`` in place.

        ``scores`` (B, L) must already be 0 at positions that may not be
        re-masked.  For each row the ``num_token_masked`` highest-score
        positions are set to ``mask_id``.
        """
        scores = scores / scores.sum(-1)[:, None]  # normalize only unmasked token
        _, sorted_score_indices = scores.sort(descending=True)
        ids[~src_token_mask] = pad_id
        ids.scatter_(-1, m_tokens_len[..., None].long(), end_id)
        # Keep the first num_token_masked sorted indices per row; every other
        # slot is replaced by the last kept index so scatter_ only rewrites an
        # already-masked position instead of masking extra ones.
        select_masked_indices = generate_src_mask(sorted_score_indices.shape[1], num_token_masked)
        last_index = sorted_score_indices.gather(-1, num_token_masked.unsqueeze(-1) - 1)
        sorted_score_indices = sorted_score_indices * select_masked_indices + (last_index * ~select_masked_indices)
        ids.scatter_(-1, sorted_score_indices, mask_id)

    @staticmethod
    def _fill_masked(ids, logits, mask_id, temperature):
        """Replace every ``mask_id`` in ``ids`` by a token sampled from ``logits``.

        Returns ``(ids, is_mask, scores)``: the new ids, the positions that were
        filled, and re-masking scores for the next step (``1 - p`` of each
        sampled token at filled positions, 0 elsewhere).
        """
        # temperature 0 makes gumbel_sample equivalent to argmax.
        pred_ids = gumbel_sample(logits, temperature=temperature, dim=-1)
        is_mask = ids == mask_id
        ids = torch.where(is_mask, pred_ids, ids)
        probs_without_temperature = logits.softmax(dim=-1)
        scores = 1 - probs_without_temperature.gather(-1, pred_ids[..., None])
        scores = rearrange(scores, '... 1 -> ...')
        scores = scores.masked_fill(~is_mask, 0)
        return ids, is_mask, scores

    def sample(self, clip_feature, idx_lower, word_emb, m_length=None, if_test=False, rand_pos=False, CFG=-1):
        """Iteratively decode upper-body token ids.

        Args:
            clip_feature: condition features, one row per sample.
            idx_lower: lower-body token ids; must have the shape of the
                generated ids, (B, block_size - 1).
            word_emb: per-word embeddings for the cross-attention layers, or
                None when the base has no local layers.
            m_length: motion lengths (tensor).  Required despite the default;
                the token length is ``ceil(m_length / 4)``.
            if_test: also return the per-step mask history of sample 0.
            rand_pos: sample with temperature 1 instead of greedy decoding.
            CFG: classifier-free guidance scale; -1 disables guidance.

        Returns:
            ``ids`` (B, block_size - 1), or ``(ids, temp)`` when ``if_test``.
        """
        max_steps = 20
        max_length = 49
        batch_size = clip_feature.shape[0]
        mask_id = self.num_vq + 2
        pad_id = self.num_vq + 1
        end_id = self.num_vq
        shape = (batch_size, self.block_size - 1)
        device = clip_feature.device

        scores = torch.ones(shape, dtype=torch.float32, device=device)
        m_tokens_len = torch.ceil((m_length) / 4)
        src_token_mask = generate_src_mask(self.block_size - 1, m_tokens_len + 1)
        src_token_mask_noend = generate_src_mask(self.block_size - 1, m_tokens_len)
        # Pad / end ids are written by _remask at the top of every step,
        # before ids is read, so no separate pre-loop fill is needed.
        ids = torch.full(shape, mask_id, dtype=torch.long, device=device)
        temperature = 1 if rand_pos else 0

        if CFG != -1:
            # Doubled batch: first half conditioned, second half with the
            # condition token masked out (att_txt False).  Every per-sample
            # input must be repeated -- including idx_lower and word_emb,
            # which previously were not (shape mismatch / missing word_emb).
            cfg_idx_lower = idx_lower.repeat(2, 1)
            cfg_clip_feature = clip_feature.repeat(2, 1)
            cfg_src_token_mask = src_token_mask.repeat(2, 1)
            cfg_word_emb = None if word_emb is None else word_emb.repeat(2, *([1] * (word_emb.dim() - 1)))
            cfg_att_txt = torch.cat((torch.ones((batch_size, 1), dtype=torch.bool),
                                     torch.zeros((batch_size, 1), dtype=torch.bool))).to(device)

        temp = []
        sample_max_steps = torch.round(max_steps / max_length * m_tokens_len) + 1e-8
        for step in range(max_steps):
            timestep = torch.clip(step / sample_max_steps, max=1)
            rand_mask_prob = cosine_schedule(timestep)
            num_token_masked = (rand_mask_prob * m_tokens_len).long().clip(min=1)

            # Positions past the motion (end token and padding) are never re-masked.
            scores[~src_token_mask_noend] = 0
            self._remask(ids, scores, num_token_masked, m_tokens_len, src_token_mask,
                         mask_id, pad_id, end_id)

            if CFG != -1:
                logits = self.forward(ids.repeat(2, 1), cfg_idx_lower, cfg_clip_feature,
                                      cfg_src_token_mask, cfg_att_txt, word_emb=cfg_word_emb)[:, 1:]
                logits_textcond = logits[:batch_size]
                logits_uncond = logits[batch_size:]
                logits = (1 + CFG) * logits_textcond - CFG * logits_uncond
            else:
                logits = self.forward(ids, idx_lower, clip_feature, src_token_mask, word_emb=word_emb)[:, 1:]

            ids, is_mask, scores = self._fill_masked(ids, logits, mask_id, temperature)
            temp.append(is_mask[:1])

        if if_test:
            return ids, temp
        return ids

    def inpaint(self, first_tokens, last_tokens, clip_feature=None, inpaint_len=2, rand_pos=False,
                idx_lower=None, word_emb=None):
        """Generate ``inpaint_len`` upper-body tokens between two fixed token spans.

        Generated id layout: ``first_tokens``, ``inpaint_len`` masked slots,
        ``last_tokens``, then the end token and padding.  Only batch size 1 is
        supported.

        Args:
            first_tokens / last_tokens: (1, n) fixed upper-body ids.
            clip_feature: condition feature; None means unconditional (zero
                feature with the condition token masked out).
            inpaint_len: number of tokens to generate between the spans.
            rand_pos: sample with temperature 1 instead of greedy decoding.
            idx_lower: lower-body ids with the shape of the generated ids,
                (1, block_size - 1).  Required: the base transformer embeds
                upper and lower ids together.
            word_emb: per-word embeddings, needed when the base has
                cross-attention layers.

        Raises:
            ValueError: if ``idx_lower`` is not given.
        """
        # support only one sample
        assert first_tokens.shape[0] == 1
        assert last_tokens.shape[0] == 1
        if idx_lower is None:
            raise ValueError('inpaint requires idx_lower (lower-body token ids)')
        max_steps = 20
        max_length = 49
        batch_size = first_tokens.shape[0]
        mask_id = self.num_vq + 2
        pad_id = self.num_vq + 1
        end_id = self.num_vq
        shape = (batch_size, self.block_size - 1)
        device = first_tokens.device
        scores = torch.ones(shape, dtype=torch.float32, device=device)

        first_partition_pos_idx = first_tokens.shape[1]
        second_partition_pos_idx = first_partition_pos_idx + inpaint_len
        end_pos_idx = second_partition_pos_idx + last_tokens.shape[1]
        m_tokens_len = torch.ones(batch_size, device=device) * end_pos_idx
        src_token_mask = generate_src_mask(self.block_size - 1, m_tokens_len + 1)
        src_token_mask_noend = generate_src_mask(self.block_size - 1, m_tokens_len)
        # The given spans are fixed: exclude them from re-masking.
        src_token_mask_noend[:, :first_partition_pos_idx] = False
        src_token_mask_noend[:, second_partition_pos_idx:end_pos_idx] = False

        ids = torch.full(shape, mask_id, dtype=torch.long, device=device)
        ids[:, :first_partition_pos_idx] = first_tokens
        ids[:, second_partition_pos_idx:end_pos_idx] = last_tokens

        sample_max_steps = torch.round(max_steps / max_length * m_tokens_len) + 1e-8
        if clip_feature is None:
            # Zero condition sized to the base's condition projection.
            clip_feature = torch.zeros(1, self.trans_base.cond_emb.in_features).to(device)
            att_txt = torch.zeros((batch_size, 1), dtype=torch.bool, device=device)
        else:
            att_txt = torch.ones((batch_size, 1), dtype=torch.bool, device=device)
        temperature = 1 if rand_pos else 0

        for step in range(max_steps):
            timestep = torch.clip(step / sample_max_steps, max=1)
            rand_mask_prob = cosine_schedule(timestep)
            num_token_masked = (rand_mask_prob * m_tokens_len).long().clip(min=1)

            # Never re-mask padding, the end token, or the fixed spans.
            scores[~src_token_mask_noend] = 0
            scores[:, :first_partition_pos_idx] = 0
            scores[:, second_partition_pos_idx:end_pos_idx] = 0
            self._remask(ids, scores, num_token_masked, m_tokens_len, src_token_mask,
                         mask_id, pad_id, end_id)
            # num_token_masked is computed from the whole length, so it can
            # exceed the gap and hit the fixed spans: restore them.
            ids[:, :first_partition_pos_idx] = first_tokens
            ids[:, second_partition_pos_idx:end_pos_idx] = last_tokens

            logits = self.forward(ids, idx_lower, clip_feature, src_token_mask, att_txt,
                                  word_emb=word_emb)[:, 1:]
            ids, _, scores = self._fill_masked(ids, logits, mask_id, temperature)
        return ids
class Attention(nn.Module):
    """Multi-head self-attention with an optional boolean attention mask.

    ``block_size`` is accepted for signature compatibility but not used.
    """

    def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1):
        super().__init__()
        # Heads must split embed_dim evenly.  The previous check
        # (embed_dim % 8) only matched the default n_head.
        assert embed_dim % n_head == 0, f'embed_dim ({embed_dim}) must be divisible by n_head ({n_head})'
        # key, query, value projections for all heads
        self.key = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(drop_out_rate)
        self.resid_drop = nn.Dropout(drop_out_rate)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.n_head = n_head

    def forward(self, x, src_mask):
        """Attend over ``x`` (B, T, C) and return a (B, T, C) tensor.

        ``src_mask`` is a boolean mask (True = may attend) broadcastable to
        (B, n_head, T, T), or None for no masking.  A query row whose keys
        are all masked yields NaN (softmax over all -inf).
        """
        B, T, C = x.size()
        head_dim = C // self.n_head
        # calculate query, key, values for all heads; move head to the batch dim
        k = self.key(x).view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        if src_mask is not None:
            # Out-of-place and broadcasting, unlike att[~src_mask] = -inf.
            att = att.masked_fill(~src_mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble heads side by side
        # output projection
        return self.resid_drop(self.proj(y))
class Block(nn.Module):
    """Pre-LayerNorm transformer block: residual self-attention, then residual MLP."""

    def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1, fc_rate=4):
        super().__init__()
        # Creation order unchanged so random initialisation is reproducible.
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = Attention(embed_dim, block_size, n_head, drop_out_rate)
        hidden_dim = fc_rate * embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(drop_out_rate),
        )

    def forward(self, x, src_mask=None):
        """Apply the block to ``x``; ``src_mask`` is forwarded to the attention."""
        attended = x + self.attn(self.ln1(x), src_mask)
        return attended + self.mlp(self.ln2(attended))
from models.t2m_trans import Block_crossatt
class CrossCondTransBase(nn.Module):
    """Embed upper/lower token ids plus a condition vector and run transformer blocks.

    Each position's input vector has ``code_dim`` channels.  The first half
    is the upper-body embedding and the second half the lower-body embedding.
    Ids below the VQ-VAE's ``num_code`` are dequantised with its upper/lower
    quantizers (no gradient flows into them).  Ids at or above it use the
    learned ``learn_tok_emb`` (3 entries: end, blank/pad, mask).  The
    projected condition ``cond_emb(clip_feature)`` is prepended as position 0.
    """

    def __init__(self,
                 vqvae,
                 num_vq=1024,
                 embed_dim=512,
                 clip_dim=512,
                 block_size=16,
                 num_layers=2,
                 num_local_layer=1,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        # Creation order unchanged so random initialisation is reproducible.
        self.vqvae = vqvae
        self.learn_tok_emb = nn.Embedding(3, int(self.vqvae.vqvae.code_dim/2))  # [INFO] 3 = [end_id, blank_id, mask_id]
        self.to_emb = nn.Linear(self.vqvae.vqvae.code_dim, embed_dim)
        self.cond_emb = nn.Linear(clip_dim, embed_dim)
        # NOTE(review): not used in forward; kept so existing checkpoints load.
        self.pos_embedding = nn.Embedding(block_size, embed_dim)
        self.drop = nn.Dropout(drop_out_rate)
        # transformer block
        self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers-num_local_layer)])
        self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)
        self.num_local_layer = num_local_layer
        if num_local_layer > 0:
            self.word_emb = nn.Linear(clip_dim, embed_dim)
            self.cross_att = nn.Sequential(*[Block_crossatt(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_local_layer)])
        self.block_size = block_size
        self.apply(self._init_weights)

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        """Init: N(0, 0.02) Linear/Embedding weights, zero Linear biases, LayerNorm to (1, 0)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx_upper, idx_lower, clip_feature, src_mask, word_emb):
        """Return hidden states for the condition token followed by the T token positions.

        Raises:
            ValueError: if ``idx_lower`` and ``idx_upper`` shapes differ.
        """
        if len(idx_upper) == 0:
            token_embeddings = self.cond_emb(clip_feature).unsqueeze(1)
        else:
            b, t = idx_upper.size()
            assert t <= self.block_size, "Cannot forward, model block size is exhausted."
            # The boolean masks below index the same (B, T) grid for both halves.
            if idx_lower.shape != idx_upper.shape:
                raise ValueError(
                    f'idx_lower shape {tuple(idx_lower.shape)} must match '
                    f'idx_upper shape {tuple(idx_upper.shape)}')
            vq = self.vqvae.vqvae
            num_code = vq.num_code
            code_dim = vq.code_dim
            half = int(code_dim / 2)
            learn_idx_upper = idx_upper >= num_code
            learn_idx_lower = idx_lower >= num_code
            token_embeddings = torch.empty((*idx_upper.shape, code_dim), device=idx_upper.device)
            # detach(): codebook vectors are fixed inputs.  requires_grad_(False)
            # raises on non-leaf tensors, i.e. whenever the VQ-VAE params require grad.
            token_embeddings[..., :half][~learn_idx_upper] = vq.quantizer_upper.dequantize(idx_upper[~learn_idx_upper]).detach()
            token_embeddings[..., :half][learn_idx_upper] = self.learn_tok_emb(idx_upper[learn_idx_upper] - num_code)
            token_embeddings[..., half:][~learn_idx_lower] = vq.quantizer_lower.dequantize(idx_lower[~learn_idx_lower]).detach()
            token_embeddings[..., half:][learn_idx_lower] = self.learn_tok_emb(idx_lower[learn_idx_lower] - num_code)
            token_embeddings = self.to_emb(token_embeddings)
            if self.num_local_layer > 0:
                word_emb = self.word_emb(word_emb)
                token_embeddings = self.pos_embed(token_embeddings)
                for module in self.cross_att:
                    token_embeddings = module(token_embeddings, word_emb)
            token_embeddings = torch.cat([self.cond_emb(clip_feature).unsqueeze(1), token_embeddings], dim=1)
        x = self.pos_embed(token_embeddings)
        for block in self.blocks:
            x = block(x, src_mask)
        return x
class CrossCondTransHead(nn.Module):
    """Final transformer blocks, LayerNorm and a bias-free projection to ``num_vq`` logits."""

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        # Creation order unchanged so random initialisation is reproducible.
        layers = [Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_vq, bias=False)
        self.block_size = block_size
        self.apply(self._init_weights)

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        """Init: N(0, 0.02) Linear/Embedding weights, zero Linear biases, LayerNorm to (1, 0)."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x, src_mask):
        """Run the blocks with ``src_mask`` and return per-position logits."""
        for layer in self.blocks:
            x = layer(x, src_mask)
        return self.head(self.ln_f(x))