import torch
from torch import nn
from transformers import CLIPModel
# private HF helper: expands a (batch, seq_len) attention mask into the additive
# (batch, 1, tgt_len, src_len) form the CLIP encoder expects
from transformers.models.clip.modeling_clip import _expand_mask

from .utils import drop_sequence_mask

def position_embedding(input, d_model):
    """Sinusoidal position embedding; assumes an even ``d_model``."""
    input = input.view(-1, 1)
    dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
    sin = torch.sin(input / 10000 ** (2 * dim / d_model))
    cos = torch.cos(input / 10000 ** (2 * dim / d_model))

    out = torch.zeros((input.shape[0], d_model), device=input.device)
    out[:, ::2] = sin   # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    out[:, 1::2] = cos  # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    return out

def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
    """Build a (max_len, d_model) table of sinusoidal position embeddings."""
    pos = torch.arange(max_len, dtype=torch.float32)
    out = position_embedding(pos, d_model)
    if padding_idx is not None:
        out[padding_idx] = 0  # the padding position carries no positional signal
    return out
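
# One typical use of the table (hedged sketch; not exercised in this file): freeze it
# inside an nn.Embedding so positions can be looked up like any learned embedding, e.g.
#   pos_emb = nn.Embedding.from_pretrained(
#       sinusoid_encoding_table(max_len=64, d_model=768), freeze=True
#   )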

class KnwlModel(nn.Module):
    """Embeds retrieved knowledge (objects, attributes, actions) and query tokens into a shared d_out space."""

    def __init__(self, d_knwl, d_out, pt=0.1):
        super().__init__()
        self.pt = pt  # drop rate for knowledge tokens during training
        self.fc_knwl = nn.Linear(d_knwl, d_out, bias=False)
        self.fc_query = nn.Linear(d_knwl, d_out)
        self.pos = nn.Embedding(9, d_out)  # learned position embedding (up to 9 positions)

        # learned scale/offset for retrieval scores, plus a type embedding for each
        # knowledge source and for the query tokens
        self.score1 = nn.Parameter(torch.randn(1, 1, d_out))
        self.score2 = nn.Parameter(torch.randn(1, 1, d_out))
        self.obj = nn.Parameter(torch.randn(1, 1, d_out))
        self.attr = nn.Parameter(torch.randn(1, 1, d_out))
        self.act = nn.Parameter(torch.randn(1, 1, d_out))
        self.query = nn.Parameter(torch.randn(1, 1, d_out))

@property
def device(self):
return self.score1.device

    def prepare_input(self, knowledge):
        # knowledge entries: projected embedding + position embedding + score-conditioned offset
        e = self.fc_knwl(knowledge["embed"])
        p = self.pos(knowledge["pos"])
        s = knowledge["score"].unsqueeze(-1) * self.score1 + self.score2
        e_knwl = e + p + s
        # randomly drop a fraction of the knowledge tokens during training
        m_knwl = drop_sequence_mask(
            *e_knwl.shape[:2], self.device, self.pt, self.training
        )

        # query tokens: projected embedding + position embedding, never masked
        e = self.fc_query(knowledge["query"])
        p = torch.arange(knowledge["query"].shape[1], device=self.device)
        p = self.pos(p[None, :])
        e_query = e + p
        m_query = torch.ones(
            e_query.shape[:2], dtype=torch.long, device=self.device
        )

        return e_knwl, m_knwl, e_query, m_query

    def forward(self, knowledge):
        e_obj, m_obj, e_query, m_query = self.prepare_input(knowledge["obj"])
        e_attr, m_attr, _, _ = self.prepare_input(knowledge["attr"])
        e_act, m_act, _, _ = self.prepare_input(knowledge["act"])

        # add a learned type embedding to each knowledge source and to the query tokens
        e_obj = e_obj + self.obj
        e_attr = e_attr + self.attr
        e_act = e_act + self.act
        e_query = e_query + self.query

        # concatenate along the sequence dimension: [query | objects | attributes | actions]
        embeds = torch.cat([e_query, e_obj, e_attr, e_act], dim=1)
        masks = torch.cat([m_query, m_obj, m_attr, m_act], dim=1)
        return embeds, masks
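
# Expected input (hedged sketch, inferred from the code above): `knowledge` is assumed to be
# a nested dict with one entry per source ("obj", "attr", "act"), each holding
#   embed: (B, K, d_knwl) float  retrieved knowledge embeddings
#   pos:   (B, K)         long   position indices in [0, 9)
#   score: (B, K)         float  retrieval scores
#   query: (B, Q, d_knwl) float  query-token embeddings
# Assuming each source retrieves K entries,
#   knwl = KnwlModel(d_knwl=512, d_out=768)
#   embeds, masks = knwl(knowledge)  # embeds: (B, Q + 3K, 768), masks: (B, Q + 3K)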

class KnwlEncoder(nn.Module):
    """Encodes knowledge token embeddings with the vision transformer of a pretrained CLIP model."""

    def __init__(self, d_out, num_layers=None, grad_ckpt=True):
        super().__init__()
        self.model = CLIPModel.from_pretrained(
            "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
        ).vision_model
        self.model.encoder.gradient_checkpointing = grad_ckpt
        if num_layers is not None:
            # keep only the last `num_layers` transformer layers
            self.model.encoder.layers = nn.ModuleList([
                self.model.encoder.layers[i] for i in range(-num_layers, 0)
            ])

        self.fc = nn.Linear(self.model.config.hidden_size, d_out, bias=False)
        self.d = self.model.config.hidden_size

    def forward(self, inputs_embeds, attention_mask):
        # `pre_layrnorm` (sic) is the attribute name used by HF's CLIPVisionTransformer
        embed = self.model.pre_layrnorm(inputs_embeds)
        # expand the (batch, seq_len) mask to the additive 4-D form the encoder expects
        mask = _expand_mask(attention_mask, embed.dtype)
        embed = self.model.encoder(
            inputs_embeds=embed,
            attention_mask=mask,
            return_dict=True,
        )[0]
        embed = self.fc(self.model.post_layernorm(embed))
        return embed
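

# Minimal smoke-test sketch (not part of the original demo; all sizes below are assumptions).
# It fabricates random knowledge features just to exercise KnwlModel end to end. KnwlEncoder
# is left out because it downloads an fp16 CLIP checkpoint, but its forward would take the
# resulting (embeds, masks) as (inputs_embeds, attention_mask) once d_out matches the CLIP
# vision hidden size (768 for ViT-B/32).
if __name__ == "__main__":
    B, K, Q, d_knwl = 2, 5, 3, 512  # hypothetical batch / retrieval / query / feature sizes

    def fake_source():
        return {
            "embed": torch.randn(B, K, d_knwl),
            "pos": torch.randint(0, 9, (B, K)),
            "score": torch.rand(B, K),
            "query": torch.randn(B, Q, d_knwl),
        }

    knowledge = {"obj": fake_source(), "attr": fake_source(), "act": fake_source()}
    knwl = KnwlModel(d_knwl=d_knwl, d_out=768).eval()

    with torch.no_grad():
        embeds, masks = knwl(knowledge)
    print(embeds.shape, masks.shape)  # expected: (2, 3 + 3*5, 768) and (2, 18)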