import torch
from torch import nn
from transformers import CLIPModel
from transformers.models.clip.modeling_clip import _expand_mask

from .utils import drop_sequence_mask

def position_embedding(input, d_model):
    # Interleaved sinusoidal embedding: sin on even dimensions, cos on odd ones.
    input = input.view(-1, 1)
    dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
    sin = torch.sin(input / 10000 ** (2 * dim / d_model))
    cos = torch.cos(input / 10000 ** (2 * dim / d_model))

    out = torch.zeros((input.shape[0], d_model), device=input.device)
    out[:, ::2] = sin
    out[:, 1::2] = cos
    return out

def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
    pos = torch.arange(max_len, dtype=torch.float32)
    out = position_embedding(pos, d_model)

    if padding_idx is not None:
        out[padding_idx] = 0
    return out
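
# Illustrative use of the helpers above (sizes are arbitrary, not taken from
# this repo): build a frozen positional-embedding table whose padding row is
# zeroed so padded positions carry no positional signal.
#
#   table = sinusoid_encoding_table(77, 512, padding_idx=0)      # (77, 512)
#   pos_embed = nn.Embedding.from_pretrained(table, freeze=True)
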
class KnwlModel(nn.Module):
    def __init__(self, d_knwl, d_out, pt=0.1):
        super().__init__()
        self.pt = pt

        self.fc_knwl = nn.Linear(d_knwl, d_out, bias=False)
        self.fc_query = nn.Linear(d_knwl, d_out)
        self.pos = nn.Embedding(9, d_out)
        # Learnable scale/shift applied to the knowledge confidence scores.
        self.score1 = nn.Parameter(torch.randn(1, 1, d_out))
        self.score2 = nn.Parameter(torch.randn(1, 1, d_out))
        # Learnable type embeddings for object, attribute, action, and query tokens.
        self.obj = nn.Parameter(torch.randn(1, 1, d_out))
        self.attr = nn.Parameter(torch.randn(1, 1, d_out))
        self.act = nn.Parameter(torch.randn(1, 1, d_out))
        self.query = nn.Parameter(torch.randn(1, 1, d_out))

    @property
    def device(self):
        return self.score1.device

    def prepare_input(self, knowledge):
        # Knowledge tokens: projected embeddings plus position and score terms.
        e = self.fc_knwl(knowledge["embed"])
        p = self.pos(knowledge["pos"])
        s = knowledge["score"].unsqueeze(-1) * self.score1 + self.score2
        e_knwl = e + p + s
        m_knwl = drop_sequence_mask(
            *e_knwl.shape[:2], self.device, self.pt, self.training
        )

        # Query tokens: projected embeddings plus position terms; always kept.
        e = self.fc_query(knowledge["query"])
        p = torch.arange(knowledge["query"].shape[1], device=self.device)
        p = self.pos(p[None, :])
        e_query = e + p
        m_query = torch.ones(
            e_query.shape[:2], dtype=torch.long, device=self.device
        )

        return e_knwl, m_knwl, e_query, m_query

    def forward(self, knowledge):
        e_obj, m_obj, e_query, m_query = self.prepare_input(knowledge["obj"])
        e_attr, m_attr, _, _ = self.prepare_input(knowledge["attr"])
        e_act, m_act, _, _ = self.prepare_input(knowledge["act"])

        e_obj = e_obj + self.obj
        e_attr = e_attr + self.attr
        e_act = e_act + self.act
        e_query = e_query + self.query

        embeds = torch.cat([e_query, e_obj, e_attr, e_act], dim=1)
        masks = torch.cat([m_query, m_obj, m_attr, m_act], dim=1)
        return embeds, masks
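
# Illustrative input for KnwlModel.forward (shapes are assumptions inferred from
# prepare_input, not taken from this repo): each knowledge type provides token
# embeddings, integer position ids in [0, 8] (self.pos has 9 rows), confidence
# scores, and a query tensor whose length must also be at most 9.
#
#   knwl = {"embed": torch.randn(2, 5, d_knwl),
#           "pos": torch.randint(0, 9, (2, 5)),
#           "score": torch.rand(2, 5),
#           "query": torch.randn(2, 4, d_knwl)}
#   knowledge = {"obj": knwl, "attr": knwl, "act": knwl}
#   embeds, masks = KnwlModel(d_knwl, d_out)(knowledge)   # embeds: (2, 19, d_out)
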

class KnwlEncoder(nn.Module):
    def __init__(self, d_out, num_layers=None, grad_ckpt=True):
        super().__init__()
        self.model = CLIPModel.from_pretrained(
            "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
        ).vision_model
        self.model.encoder.gradient_checkpointing = grad_ckpt

        # Optionally keep only the last `num_layers` transformer layers.
        if num_layers is not None:
            self.model.encoder.layers = nn.ModuleList([
                self.model.encoder.layers[i] for i in range(-num_layers, 0)
            ])

        self.fc = nn.Linear(self.model.config.hidden_size, d_out, bias=False)
        self.d = self.model.config.hidden_size

    def forward(self, inputs_embeds, attention_mask):
        embed = self.model.pre_layrnorm(inputs_embeds)
        mask = _expand_mask(attention_mask, embed.dtype)
        embed = self.model.encoder(
            inputs_embeds=embed,
            attention_mask=mask,
            return_dict=True,
        )[0]
        embed = self.fc(self.model.post_layernorm(embed))
        return embed
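
# Illustrative end-to-end sketch (an assumption, not code from this repo):
# KnwlModel emits d_out-dimensional tokens that KnwlEncoder feeds through the
# (optionally truncated) CLIP ViT-B/32 vision encoder, so the KnwlModel output
# width must match the CLIP hidden size (768).  Because the CLIP weights are
# loaded in float16, the forward pass is assumed to run on GPU under autocast.
#
#   knwl_model = KnwlModel(d_knwl=512, d_out=768).cuda()
#   encoder = KnwlEncoder(d_out=768, num_layers=4).cuda()
#   embeds, masks = knwl_model(knowledge)          # knowledge tensors on GPU
#   with torch.autocast("cuda", dtype=torch.float16):
#       out = encoder(embeds, masks)               # (batch, seq_len, 768)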