# storydalle/dalle/models/stage2/transformer.py
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from minGPT (https://github.com/karpathy/minGPT)
# Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
# ------------------------------------------------------------------------------------
import torch
import torch.nn as nn
from typing import Optional, Tuple, List
from torch.cuda.amp import autocast
from omegaconf import OmegaConf
from .layers import Block
class Transformer1d(nn.Module):
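    """GPT-style decoder over a concatenated text-plus-image token sequence.

    Text tokens (optionally preceded by a learned prompt) come first and image tokens
    follow; separate embedding tables and output heads serve the two vocabularies.
    """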
def __init__(self,
vocab_size_txt: int,
vocab_size_img: int,
hparams: OmegaConf) -> None:
super().__init__()
assert hparams.n_layers == hparams.n_dense_layers
# input embedding for image and text
self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
self.tok_emb_txt = nn.Embedding(vocab_size_txt, hparams.embed_dim)
self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
self.pos_emb_txt = nn.Embedding(hparams.ctx_len_txt, hparams.embed_dim)
self.drop = nn.Dropout(hparams.embd_pdrop)
# transformer blocks
self.blocks = [Block(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
embed_dim=hparams.embed_dim,
n_heads=hparams.n_heads,
mlp_bias=hparams.mlp_bias,
attn_bias=hparams.attn_bias,
resid_pdrop=hparams.resid_pdrop,
attn_pdrop=hparams.attn_pdrop,
                             gelu_use_approx=hparams.gelu_use_approx) for _ in range(hparams.n_layers)]
self.blocks = nn.Sequential(*self.blocks)
# heads for image and text
self.ln_f = nn.LayerNorm(hparams.embed_dim)
self.head_img = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
self.head_txt = nn.Linear(hparams.embed_dim, vocab_size_txt, bias=False)
self.ctx_len_img = hparams.ctx_len_img
self.ctx_len_txt = hparams.ctx_len_txt
self.n_layers = hparams.n_layers
self.apply(self._init_weights)
def _init_weights(self, module: nn.Module) -> None:
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def resize_token_embeddings(self, new_num_tokens):
old_num_tokens, old_embedding_dim = self.tok_emb_txt.weight.size()
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
new_embeddings.to(self.tok_emb_txt.weight.device, dtype=self.tok_emb_txt.weight.dtype)
self._init_weights(new_embeddings)
        # number of tokens to copy
n = min(old_num_tokens, new_num_tokens)
new_embeddings.weight.data[:n, :] = self.tok_emb_txt.weight.data[:n, :]
self.tok_emb_txt = new_embeddings
self.resize_lm_head(new_num_tokens)
# TODO: also change config to reflect new vocab size
return new_embeddings
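
    # Usage sketch (illustrative numbers, not from this file): after adding special
    # tokens to the text tokenizer, grow the embedding table and text head together:
    #   model.resize_token_embeddings(new_num_tokens=vocab_size_txt + num_added_tokens)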
def resize_lm_head(
self, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False) -> nn.Linear:
old_num_tokens, old_lm_head_dim = (
self.head_txt.weight.size() if not transposed else self.head_txt.weight.t().size()
)
# Build new lm head
new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
has_new_lm_head_bias = self.head_txt.bias is not None
new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
new_lm_head = new_lm_head.to(self.head_txt.weight.device, dtype=self.head_txt.weight.dtype)
# initialize new lm head (in particular added tokens)
self._init_weights(new_lm_head)
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
# Copy old lm head weights to new lm head
if not transposed:
new_lm_head.weight.data[:num_tokens_to_copy, :] = self.head_txt.weight.data[:num_tokens_to_copy, :]
else:
new_lm_head.weight.data[:, :num_tokens_to_copy] = self.head_txt.weight.data[:, :num_tokens_to_copy]
# Copy bias weights to new lm head
if has_new_lm_head_bias:
new_lm_head.bias.data[:num_tokens_to_copy] = self.head_txt.bias.data[:num_tokens_to_copy]
self.head_txt = new_lm_head
return new_lm_head
def forward(self,
images: torch.LongTensor,
texts: torch.LongTensor,
pos_images: torch.LongTensor,
pos_texts: torch.LongTensor,
past: Optional[List[torch.Tensor]] = None,
prompt: Optional[List[torch.Tensor]] = None,
pos_prompt: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
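        """Teacher-forced training pass over text (plus optional prompt) and image tokens.

        Returns next-token logits for the image stream and the text stream.
        """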
B, T = images.shape
_, N = texts.shape
        assert T <= self.ctx_len_img, "Image sequence length exceeds the image context length."
        assert N == self.ctx_len_txt, "Text sequence length must equal the text context length."
texts = self.tok_emb_txt(texts)
images = self.tok_emb_img(images)
texts = texts + self.pos_emb_txt(pos_texts)
images = images + self.pos_emb_img(pos_images)
if prompt is not None:
prompt = prompt + self.pos_emb_txt(pos_prompt)
texts = torch.cat([prompt, texts], dim=1).contiguous()
P = prompt.shape[1]
x = torch.cat([texts, images], dim=1).contiguous()
x = self.drop(x)
        # run the blocks step-wise so cached key/values (`past`) can be reused
        for i, block in enumerate(self.blocks):
            x, _ = block.sample(x, layer_past=None if past is None else past[i])
x = self.ln_f(x)
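        # positions are shifted by one for next-token prediction: text logits come from
        # the text span, image logits from the last text position through the
        # second-to-last image position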
if prompt is not None:
texts = x[:, P:N+P-1].contiguous()
images = x[:, N+P-1:-1].contiguous()
else:
texts = x[:, :N-1].contiguous()
images = x[:, N-1:-1].contiguous()
logits_txt = self.head_txt(texts)
logits_img = self.head_img(images)
return logits_img, logits_txt
def forward_with_context(self,
images: torch.LongTensor,
texts: torch.LongTensor,
pos_images: torch.LongTensor,
pos_texts: torch.LongTensor,
src_images: torch.LongTensor,
src_pos_images: torch.LongTensor,
cross_attention_idxs: List,
cross_attention_layers,
past: Optional[List[torch.Tensor]] = None,
prompt: Optional[List[torch.Tensor]] = None,
pos_prompt: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
B, T = images.shape
_, N = texts.shape
        assert T <= self.ctx_len_img, "Image sequence length exceeds the image context length."
        assert N == self.ctx_len_txt, "Text sequence length must equal the text context length."
texts = self.tok_emb_txt(texts)
images = self.tok_emb_img(images)
src_images = self.tok_emb_img(src_images)
texts = texts + self.pos_emb_txt(pos_texts)
images = images + self.pos_emb_img(pos_images)
src_images = src_images + self.pos_emb_img(src_pos_images)
if prompt is not None:
prompt = prompt + self.pos_emb_txt(pos_prompt)
texts = torch.cat([prompt, texts], dim=1).contiguous()
P = prompt.shape[1]
else:
P = 0
        x = torch.cat([texts, images], dim=1).contiguous()
x = self.drop(x)
        # build a mask flagging the image positions; prompt and text positions stay zero
        mask = torch.zeros_like(x[0])
        mask[self.ctx_len_txt + P - 1:, :].fill_(1.0)
        mask = mask.unsqueeze(0)
        for i, block in enumerate(self.blocks):
            if i in cross_attention_idxs:
                # these blocks also cross-attend to the source image tokens; the index
                # arithmetic assumes one cross-attention module per every third block
                x, _ = block.sample_with_context(
                    x, src_images, mask,
                    cross_attention_layers[int(((i + 1) / 3) - 1)],
                    layer_past=None if past is None else past[i])
            else:
                x, _ = block.sample(x, layer_past=None if past is None else past[i])
x = self.ln_f(x)
if prompt is not None:
texts = x[:, P:N+P-1].contiguous()
images = x[:, N+P-1:-1].contiguous()
else:
texts = x[:, :N-1].contiguous()
images = x[:, N-1:-1].contiguous()
logits_txt = self.head_txt(texts)
logits_img = self.head_img(images)
return logits_img, logits_txt
@torch.no_grad()
def sampling(self,
images: torch.LongTensor,
texts: torch.LongTensor,
pos_images: torch.LongTensor,
pos_texts: torch.LongTensor,
use_fp16: bool = True,
past: Optional[List[torch.Tensor]] = None,
prompt: Optional[List[torch.Tensor]] = None,
pos_prompt: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
_, N = texts.shape
        assert N == self.ctx_len_txt, "Text sequence length must equal the text context length."
with autocast(enabled=use_fp16):
if images is None:
                texts = self.tok_emb_txt(texts)
                texts = texts + self.pos_emb_txt(pos_texts)
                if prompt is not None:
                    # fold the prompt in before dropout so it actually reaches the blocks
                    prompt = prompt + self.pos_emb_txt(pos_prompt)
                    texts = torch.cat([prompt, texts], dim=1).contiguous()
                x = self.drop(texts)
if past is not None:
past = torch.cat(past, dim=-2)
presents = []
for i, block in enumerate(self.blocks):
x, present = block.sample(x, layer_past=None if past is None else past[i])
presents.append(present)
x = self.ln_f(x)
                x = x[:, -1].contiguous()  # last text position predicts the first image token
logits = self.head_img(x)
else:
if past is None:
texts = self.tok_emb_txt(texts)
images = self.tok_emb_img(images)
texts = texts + self.pos_emb_txt(pos_texts)
images = images + self.pos_emb_img(pos_images)
if prompt is not None:
prompt = prompt + self.pos_emb_txt(pos_prompt)
texts = torch.cat([prompt, texts], dim=1).contiguous()
                    x = torch.cat([texts, images], dim=1).contiguous()
else:
images = self.tok_emb_img(images)
x = images + self.pos_emb_img(pos_images)
x = self.drop(x)
                if past is not None:
                    past = torch.cat(past, dim=-2)
                presents = []
for i, block in enumerate(self.blocks):
x, present = block.sample(x, layer_past=None if past is None else past[i])
presents.append(present)
x = self.ln_f(x)
x = x[:, -1].contiguous()
logits = self.head_img(x)
return logits, presents
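
    # Decoding sketch (illustrative only; `sample_from` is a hypothetical top-k/top-p
    # helper, and tensor names are assumptions rather than part of this module):
    #   logits, presents = model.sampling(None, texts, None, pos_texts)
    #   past = [torch.stack(presents)]
    #   for t in range(model.ctx_len_img):
    #       code = sample_from(logits)                      # (B, 1) next image token
    #       logits, presents = model.sampling(code, texts, pos_images[:, t:t + 1],
    #                                         pos_texts, past=past)
    #       past.append(torch.stack(presents))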
@torch.no_grad()
def sampling_with_context(self,
images: torch.LongTensor,
cross_attention_idxs,
cross_attention_layers,
texts: torch.LongTensor,
pos_images: torch.LongTensor,
pos_texts: torch.LongTensor,
source_image: torch.LongTensor,
use_fp16: bool = True,
past: Optional[List[torch.Tensor]] = None,
prompt: Optional[List[torch.Tensor]] = None,
pos_prompt: Optional[List[torch.Tensor]] = None
) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
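        """Sampling pass with cross-attention to a source image.

        Mirrors `sampling`, except that blocks listed in `cross_attention_idxs` also
        attend to `source_image`; the mask built below flags the image positions.
        """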
_, N = texts.shape
        assert N == self.ctx_len_txt, "Text sequence length must equal the text context length."
if prompt is not None:
P = prompt.shape[1]
else:
P = 0
with autocast(enabled=use_fp16):
if images is None:
texts = self.tok_emb_txt(texts)
texts = texts + self.pos_emb_txt(pos_texts)
if prompt is not None:
prompt = prompt + self.pos_emb_txt(pos_prompt)
texts = torch.cat([prompt, texts], dim=1).contiguous()
x = self.drop(texts)
if past is not None:
past = torch.cat(past, dim=-2)
# prepare mask
mask = torch.zeros_like(x[0])
                mask[self.ctx_len_txt + P - 1:, :].fill_(1.0)
mask = mask.unsqueeze(0)
presents = []
for i, block in enumerate(self.blocks):
if i in cross_attention_idxs:
x, present = block.sample_with_context(x, source_image, mask,
cross_attention_layers[int(((i + 1) / 3) - 1)],
layer_past=None if past is None else past[i])
else:
x, present = block.sample(x, layer_past=None if past is None else past[i])
presents.append(present)
x = self.ln_f(x)
                x = x[:, -1].contiguous()  # last text position predicts the first image token
logits = self.head_img(x)
else:
if past is None:
texts = self.tok_emb_txt(texts)
images = self.tok_emb_img(images)
texts = texts + self.pos_emb_txt(pos_texts)
images = images + self.pos_emb_img(pos_images)
if prompt is not None:
prompt = prompt + self.pos_emb_txt(pos_prompt)
texts = torch.cat([prompt, texts], dim=1).contiguous()
                    x = torch.cat([texts, images], dim=1).contiguous()
else:
images = self.tok_emb_img(images)
x = images + self.pos_emb_img(pos_images)
x = self.drop(x)
                if past is not None:
                    past = torch.cat(past, dim=-2)
                presents = []
                # prepare a mask flagging the image positions
                mask = torch.zeros_like(x[0])
                mask[self.ctx_len_txt + P - 1:, :].fill_(1.0)
                mask = mask.unsqueeze(0)
for i, block in enumerate(self.blocks):
if i in cross_attention_idxs:
x, present = block.sample_with_context(x, source_image, mask,
cross_attention_layers[int(((i + 1) / 3) - 1)],
layer_past=None if past is None else past[i])
else:
x, present = block.sample(x, layer_past=None if past is None else past[i])
presents.append(present)
x = self.ln_f(x)
x = x[:, -1].contiguous()
logits = self.head_img(x)
return logits, presents
def from_ckpt(self, path: str) -> None:
ckpt = torch.load(path, map_location='cpu')['state_dict']
self.load_state_dict(ckpt, strict=True)
        print(f'{path} successfully restored.')
class iGPT(nn.Module):
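    """Image-only GPT over VQ code sequences, optionally class-conditional.

    A learned SOS embedding (one per class when `use_cls_cond` is True) is prepended
    to the shifted code sequence, and a single head predicts the next image code.
    """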
def __init__(self,
vocab_size_img: int,
use_cls_cond: bool,
hparams: OmegaConf) -> None:
super().__init__()
self.use_cls_cond = use_cls_cond
# sos token embedding
if self.use_cls_cond:
self.sos = nn.Embedding(hparams.n_classes, hparams.embed_dim)
else:
self.sos = nn.Parameter(torch.randn(1, 1, hparams.embed_dim))
# input embedding
self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
self.drop = nn.Dropout(hparams.embd_pdrop)
# transformer blocks
self.blocks = [Block(ctx_len=hparams.ctx_len_img + 1,
embed_dim=hparams.embed_dim,
n_heads=hparams.n_heads,
mlp_bias=hparams.mlp_bias,
attn_bias=hparams.attn_bias,
resid_pdrop=hparams.resid_pdrop,
attn_pdrop=hparams.attn_pdrop,
                             gelu_use_approx=hparams.gelu_use_approx) for _ in range(hparams.n_layers)]
self.blocks = nn.Sequential(*self.blocks)
# head
self.ln_f = nn.LayerNorm(hparams.embed_dim)
self.head = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
self.ctx_len_img = hparams.ctx_len_img
self.n_layers = hparams.n_layers
self.apply(self._init_weights)
def _init_weights(self, module: nn.Module) -> None:
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@torch.no_grad()
def sampling(self,
sos: torch.FloatTensor,
codes: torch.LongTensor,
pos_codes: torch.LongTensor,
n_samples: int = 16,
use_fp16: bool = True,
past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
with autocast(enabled=use_fp16):
if codes is None:
assert past is None
xs = self.drop(sos)
presents = []
for i, block in enumerate(self.blocks):
xs, present = block.sample(xs, layer_past=None)
presents.append(present)
xs = self.ln_f(xs)
logits = self.head(xs)[:, -1]
else:
if past is None:
xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
xs = torch.cat([sos, xs], dim=1)
else:
xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
xs = self.drop(xs)
past = torch.cat(past, dim=-2) if past is not None else past
presents = []
for i, block in enumerate(self.blocks):
xs, present = block.sample(xs, layer_past=None if past is None else past[i])
presents.append(present)
xs = self.ln_f(xs)
logits = self.head(xs)[:, -1]
return logits, presents
def forward(self,
codes: torch.LongTensor,
labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
B, T = codes.shape
xps = torch.arange(T, device=codes.device).repeat((B, 1))
sos = self.sos.repeat((B, 1, 1)) if labels is None else self.sos(labels).unsqueeze(1)
h = self.tok_emb_img(codes) + self.pos_emb_img(xps)
h = torch.cat([sos, h[:, :-1]], dim=1).contiguous()
h = self.drop(h)
h = self.blocks(h)
h = self.ln_f(h)
logits = self.head(h)
return logits
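
    # Training sketch (assumed standard next-token cross-entropy; `codes` are VQ
    # indices, `labels` class ids when use_cls_cond=True):
    #   logits = model(codes, labels)           # (B, T, vocab_size_img)
    #   loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), codes.reshape(-1))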
def from_ckpt(self, path: str, strict: bool = True) -> None:
ckpt = torch.load(path, map_location='cpu')['state_dict']
self.load_state_dict(ckpt, strict=strict)
        print(f'{path} successfully restored.')