import torch
from torch import nn
from transformers import CLIPModel
from transformers.models.clip.modeling_clip import _expand_mask

from .utils import drop_sequence_mask


def position_embedding(input, d_model):
    """Sinusoidal position embedding: sin in even dims, cos in odd dims."""
    input = input.view(-1, 1)
    dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
    sin = torch.sin(input / 10000 ** (2 * dim / d_model))
    cos = torch.cos(input / 10000 ** (2 * dim / d_model))

    out = torch.zeros((input.shape[0], d_model), device=input.device)
    out[:, ::2] = sin
    out[:, 1::2] = cos
    return out


def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
    """Full (max_len, d_model) sinusoidal table, optionally zeroing the padding row."""
    pos = torch.arange(max_len, dtype=torch.float32)
    out = position_embedding(pos, d_model)

    if padding_idx is not None:
        out[padding_idx] = 0
    return out


class KnwlModel(nn.Module):
    """Projects retrieved knowledge (object / attribute / action branches plus query
    tokens) into the model dimension, adds positional, score-conditioned, and
    branch-type embeddings, and concatenates everything into one sequence."""

    def __init__(self, d_knwl, d_out, pt=0.1):
        super().__init__()
        self.pt = pt

        self.fc_knwl = nn.Linear(d_knwl, d_out, bias=False)
        self.fc_query = nn.Linear(d_knwl, d_out)
        self.pos = nn.Embedding(9, d_out)
        # Scale and shift applied to the retrieval score of each knowledge token.
        self.score1 = nn.Parameter(torch.randn(1, 1, d_out))
        self.score2 = nn.Parameter(torch.randn(1, 1, d_out))
        # Learned type embeddings, one per branch plus one for the query tokens.
        self.obj = nn.Parameter(torch.randn(1, 1, d_out))
        self.attr = nn.Parameter(torch.randn(1, 1, d_out))
        self.act = nn.Parameter(torch.randn(1, 1, d_out))
        self.query = nn.Parameter(torch.randn(1, 1, d_out))

    @property
    def device(self):
        return self.score1.device

    def prepare_input(self, knowledge):
        # Knowledge tokens: projected embedding + position embedding + scaled score.
        e = self.fc_knwl(knowledge["embed"])
        p = self.pos(knowledge["pos"])
        s = knowledge["score"].unsqueeze(-1) * self.score1 + self.score2
        e_knwl = e + p + s
        # Stochastic token mask over the knowledge sequence (see .utils).
        m_knwl = drop_sequence_mask(
            *e_knwl.shape[:2], self.device, self.pt, self.training
        )

        # Query tokens: projected embedding + position embedding, never masked.
        e = self.fc_query(knowledge["query"])
        p = torch.arange(knowledge["query"].shape[1], device=self.device)
        p = self.pos(p[None, :])
        e_query = e + p
        m_query = torch.ones(
            e_query.shape[:2], dtype=torch.long, device=self.device
        )

        return e_knwl, m_knwl, e_query, m_query

    def forward(self, knowledge):
        e_obj, m_obj, e_query, m_query = self.prepare_input(knowledge["obj"])
        e_attr, m_attr, _, _ = self.prepare_input(knowledge["attr"])
        e_act, m_act, _, _ = self.prepare_input(knowledge["act"])

        # Add the branch-type embeddings before merging the sequences.
        e_obj = e_obj + self.obj
        e_attr = e_attr + self.attr
        e_act = e_act + self.act
        e_query = e_query + self.query

        embeds = torch.cat([e_query, e_obj, e_attr, e_act], dim=1)
        masks = torch.cat([m_query, m_obj, m_attr, m_act], dim=1)
        return embeds, masks


class KnwlEncoder(nn.Module):
    """Wraps the vision tower of a pretrained LAION CLIP ViT-B/32, optionally
    keeping only its last `num_layers` transformer blocks, and projects the
    output to `d_out`."""

    def __init__(self, d_out, num_layers=None, grad_ckpt=True):
        super().__init__()
        self.model = CLIPModel.from_pretrained(
            "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
        ).vision_model
        self.model.encoder.gradient_checkpointing = grad_ckpt

        if num_layers is not None:
            # Keep only the last `num_layers` encoder blocks.
            self.model.encoder.layers = nn.ModuleList([
                self.model.encoder.layers[i] for i in range(-num_layers, 0)
            ])

        self.fc = nn.Linear(self.model.config.hidden_size, d_out, bias=False)
        self.d = self.model.config.hidden_size

    def forward(self, inputs_embeds, attention_mask):
        # `pre_layrnorm` is the attribute name used by HF's CLIPVisionTransformer.
        embed = self.model.pre_layrnorm(inputs_embeds)
        # Expand the (B, L) padding mask to the additive mask the encoder expects.
        mask = _expand_mask(attention_mask, embed.dtype)
        embed = self.model.encoder(
            inputs_embeds=embed,
            attention_mask=mask,
            return_dict=True,
        )[0]
        embed = self.fc(self.model.post_layernorm(embed))
        return embed
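

# ---------------------------------------------------------------------------
# Illustrative usage (a sketch, not part of the original module). It shows the
# knowledge dictionary layout that KnwlModel.forward expects, inferred from
# prepare_input above. The sizes B, N, M, d_knwl, d_out are made-up example
# values, and the snippet assumes this file lives in a package whose
# utils.drop_sequence_mask returns a (B, N) 0/1 mask, so it should be run as
# part of that package (e.g. `python -m <package>.<module>`).
if __name__ == "__main__":
    B, N, M, d_knwl, d_out = 2, 5, 3, 512, 768  # hypothetical sizes

    def dummy_branch():
        # One knowledge branch (objects, attributes, or actions).
        return {
            "embed": torch.randn(B, N, d_knwl),  # retrieved knowledge embeddings
            "pos": torch.randint(0, 9, (B, N)),  # position ids for nn.Embedding(9, d_out)
            "score": torch.rand(B, N),           # retrieval scores in [0, 1]
            "query": torch.randn(B, M, d_knwl),  # query embeddings
        }

    knowledge = {"obj": dummy_branch(), "attr": dummy_branch(), "act": dummy_branch()}
    model = KnwlModel(d_knwl=d_knwl, d_out=d_out).eval()
    with torch.no_grad():
        embeds, masks = model(knowledge)
    # Query tokens come first, followed by the three knowledge branches:
    # embeds is (B, M + 3 * N, d_out) and masks is (B, M + 3 * N).
    print(embeds.shape, masks.shape)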