# Source: storydalle/dalle/utils/sampling.py
# (author: adymaharana; commit 3d5e231 "Added files"; 13.8 kB)
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
import torch
from typing import Optional
from tqdm import tqdm
from torch.nn import functional as F
# Compact tensor printing: 2 decimal places, and summarize ("...") any tensor
# with more than 10 elements.
# NOTE(review): this mutates torch's *global* print options for every module
# that imports this file; it looks like a debugging leftover — consider
# scoping it locally or removing it.
torch.set_printoptions(threshold=10, precision=2)
def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor:
    """Mask every logit outside the top-``k`` of its row with ``-inf``.

    Filtering is applied along the last dimension, so ``logits`` may have any
    number of leading (batch) dimensions.

    Args:
        logits: unnormalized scores; the last dimension is the vocabulary.
        k: number of highest-scoring entries to keep per row (>= 1). ``None``
            disables filtering. Values larger than the vocabulary size keep
            everything; integral floats such as ``32.0`` are accepted.

    Returns:
        A new tensor (the input is left untouched) in which every entry that
        is strictly below the k-th largest value of its row is ``-inf``.
        Entries tied with the k-th value are kept, so a row may retain more
        than ``k`` finite entries. For ``k=None`` the input object itself is
        returned.
    """
    if k is None:
        return logits
    # torch.topk needs an int and raises when k exceeds the dimension size.
    k = min(int(k), logits.size(-1))
    # Keep the trailing dim (size 1) so the threshold broadcasts per row.
    kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_largest, -float('Inf'))
def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor:
    """Nucleus (top-p) filtering of a probability distribution.

    Along the last dimension, keeps the highest-probability entries up to and
    including the one whose cumulative mass first reaches ``p``. All other
    entries are zeroed, and each row is renormalized to sum to one. The single
    most likely entry is always kept. ``p=None`` returns ``probs`` itself.
    """
    if p is None:
        return probs

    ranked, order = torch.sort(probs, dim=-1, descending=True)
    reached = torch.cumsum(ranked, dim=-1) >= p

    # Shift the mask right by one so the entry that crosses ``p`` survives;
    # position 0 (the most likely entry) is never dropped.
    drop_ranked = torch.zeros_like(reached)
    drop_ranked[..., 1:] = reached[..., :-1]

    # ``order`` is a permutation of the last dim, so scattering writes every
    # slot: this maps the mask from sorted order back to vocabulary order.
    drop = torch.zeros_like(drop_ranked).scatter(-1, order, drop_ranked)

    kept = probs.masked_fill(drop, 0.0)
    return kept / kept.sum(dim=-1, keepdim=True)
def get_positional_encoding(inputs: torch.LongTensor, mode: str = '1d') -> torch.LongTensor:
    """Build integer position indices matching the shape of ``inputs``.

    Only ``inputs.shape`` and ``inputs.device`` are read; the values are
    ignored.

    * ``mode='1d'``: ``inputs`` must be 2-D ``(B, N)``. Returns a ``(B, N)``
      tensor in which every row is ``0 .. N-1``.
    * ``mode='2d'``: ``inputs`` must be 3-D ``(B, H, W)``. Returns a *tuple*
      ``(rows, cols)`` of ``(B, H, W)`` tensors holding each cell's row index
      and column index (despite the ``LongTensor`` return annotation).

    Raises:
        ValueError: for any other ``mode``.
    """
    dev = inputs.device
    if mode == '2d':
        batch, height, width = inputs.shape
        rows = torch.arange(height, device=dev).repeat(batch, width, 1).transpose(1, 2)
        cols = torch.arange(width, device=dev).repeat(batch, height, 1)
        return (rows, cols)
    if mode == '1d':
        batch, length = inputs.shape
        return torch.arange(length, device=dev).repeat((batch, 1))
    raise ValueError('%s positional encoding invalid' % mode)
@torch.no_grad()
def sampling(model: torch.nn.Module,
             tokens: torch.LongTensor,
             top_k: Optional[float] = None,
             top_p: Optional[float] = None,
             softmax_temperature: float = 1.0,
             is_tqdm: bool = True,
             use_fp16: bool = True,
             max_seq_len: int = 256,
             prompt: Optional[torch.Tensor] = None,
             pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor:
    """Autoregressively sample ``max_seq_len`` codes conditioned on ``tokens``.

    Each step feeds only the most recently sampled code (and its position) to
    ``model.sampling``. Earlier context reaches the model through ``past``,
    which is the list of stacked ``present`` tensors returned at each step.

    Args:
        model: object exposing ``sampling(images=, texts=, pos_images=,
            pos_texts=, use_fp16=, past=, prompt=, pos_prompt=)`` that returns
            ``(logits, present)``.
        tokens: 2-D text token ids; 1-D positions are built from its shape.
        top_k: keep only the ``top_k`` largest logits (``None`` disables).
        top_p: nucleus-sampling probability mass (``None`` disables).
        softmax_temperature: logits are divided by this before filtering.
        is_tqdm: wrap the loop in a tqdm progress bar.
        use_fp16: forwarded to ``model.sampling``.
        max_seq_len: number of codes to sample.
        prompt, pos_prompt: forwarded unchanged to ``model.sampling``.

    Returns:
        Sampled code ids with one column per step, i.e. ``(B, max_seq_len)``,
        or ``None`` when ``max_seq_len`` is 0.
    """
    code = None
    past = None
    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
    pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
    for cnt, _ in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            # `code` has exactly `cnt` columns here: its last column is the
            # previous sample, at position `cnt - 1`. Slicing directly avoids
            # re-cloning and re-encoding the whole sequence every step.
            code_ = code[:, -1:].clone()
            pos_enc_code_ = torch.full_like(code_, cnt - 1)
        logits, present = model.sampling(images=code_,
                                         texts=tokens,
                                         pos_images=pos_enc_code_,
                                         pos_texts=pos_enc_tokens,
                                         use_fp16=use_fp16,
                                         past=past,
                                         prompt=prompt,
                                         pos_prompt=pos_prompt)
        logits = logits.to(dtype=torch.float32) / softmax_temperature
        # torch.stack already allocates a fresh tensor; an extra clone would
        # only duplicate the per-step cache copy.
        present = torch.stack(present).detach()
        if past is None:
            past = [present]
        else:
            past.append(present)
        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)
        idx = torch.multinomial(probs, num_samples=1)
        code = idx if code is None else torch.cat([code, idx], dim=1)
    return code
@torch.no_grad()
def sampling_prefix(model: torch.nn.Module,
                    tokens: torch.LongTensor,
                    past: torch.FloatTensor,
                    top_k: Optional[float] = None,
                    top_p: Optional[float] = None,
                    softmax_temperature: float = 1.0,
                    is_tqdm: bool = True,
                    use_fp16: bool = True,
                    max_seq_len: int = 256,
                    labels=None) -> torch.LongTensor:
    """Autoregressively sample codes, starting from a precomputed prefix cache.

    If ``past`` is not ``None``, it is wrapped in a list and becomes the first
    entry of the cache passed to ``model.sampling``. Each step's stacked
    ``present`` is then appended to that list. Each step feeds only the most
    recently sampled code (and its position) to the model.

    Args:
        model: object exposing ``sampling(images=, texts=, pos_images=,
            pos_texts=, use_fp16=, past=)`` that returns ``(logits, present)``.
        tokens: 2-D text token ids; 1-D positions are built from its shape.
        past: prefix cache tensor, or ``None``.
        top_k: keep only the ``top_k`` largest logits (``None`` disables).
        top_p: nucleus-sampling probability mass (``None`` disables).
        softmax_temperature: logits are divided by this before filtering.
        is_tqdm: wrap the loop in a tqdm progress bar.
        use_fp16: forwarded to ``model.sampling``.
        max_seq_len: number of codes to sample.
        labels: optional debugging aid. When given, every step prints the
            top-5 entries of the filtered distribution followed by
            ``labels[cnt]``, so it must be indexable for ``cnt < max_seq_len``.

    Returns:
        Sampled code ids, ``(B, max_seq_len)``, or ``None`` when
        ``max_seq_len`` is 0.
    """
    code = None
    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
    pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
    if past is not None:
        past = [past]
    for cnt, _ in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            # `code` has exactly `cnt` columns: feed only the last one, whose
            # position is `cnt - 1`, instead of re-encoding the whole sequence.
            code_ = code[:, -1:].clone()
            pos_enc_code_ = torch.full_like(code_, cnt - 1)
        logits, present = model.sampling(images=code_,
                                         texts=tokens,
                                         pos_images=pos_enc_code_,
                                         pos_texts=pos_enc_tokens,
                                         use_fp16=use_fp16,
                                         past=past)
        logits = logits.to(dtype=torch.float32) / softmax_temperature
        # torch.stack already returns a new tensor; no extra clone needed.
        present = torch.stack(present).detach()
        if past is None:
            past = [present]
        else:
            past.append(present)
        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)
        if labels is not None:
            # Debug output only on request: previously printed every step
            # unconditionally (and crashed for vocabularies smaller than 5).
            print(torch.topk(probs, min(5, probs.size(-1)), dim=-1))
            print(labels[cnt])
        idx = torch.multinomial(probs, num_samples=1)
        code = idx if code is None else torch.cat([code, idx], dim=1)
    return code
@torch.no_grad()
def sampling_prefix_new(model: torch.nn.Module,
                        tokens: torch.LongTensor,
                        past: torch.FloatTensor,
                        top_k: Optional[float] = None,
                        top_p: Optional[float] = None,
                        softmax_temperature: float = 1.0,
                        is_tqdm: bool = True,
                        use_fp16: bool = True,
                        max_seq_len: int = 256) -> torch.LongTensor:
    """Sample a single code given a prefix cache (experimental variant).

    NOTE(review): although the loop runs ``max_seq_len`` times, only the
    first iteration queries the model. Every later iteration is a no-op, so
    the result has a single column. That behavior is preserved here, along
    with the progress bar and the ``None`` result when ``max_seq_len`` is 0.
    The earlier version also cloned and position-encoded ``code`` on every
    no-op iteration and stacked ``present`` into a cache that was never
    read; that dead work is removed. Confirm whether multi-step sampling was
    intended.

    Args:
        model: object exposing ``sampling(images=, texts=, pos_images=,
            pos_texts=, use_fp16=, past=)`` that returns ``(logits, present)``.
        tokens: 2-D text token ids; 1-D positions are built from its shape.
        past: prefix cache tensor (wrapped in a list before use), or ``None``.
        top_k: keep only the ``top_k`` largest logits (``None`` disables).
        top_p: nucleus-sampling probability mass (``None`` disables).
        softmax_temperature: logits are divided by this before filtering.
        is_tqdm: wrap the loop in a tqdm progress bar.
        use_fp16: forwarded to ``model.sampling``.
        max_seq_len: number of loop iterations (see note above).

    Returns:
        A one-column tensor of sampled code ids, or ``None`` when
        ``max_seq_len`` is 0.
    """
    code = None
    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
    pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
    if past is not None:
        past = [past]
    for cnt, _ in enumerate(pbar):
        if cnt > 0:
            # Only the first step does any work (see docstring note).
            continue
        # First step: no image codes have been sampled yet.
        logits, _present = model.sampling(images=None,
                                          texts=tokens,
                                          pos_images=None,
                                          pos_texts=pos_enc_tokens,
                                          use_fp16=use_fp16,
                                          past=past)
        logits = logits.to(dtype=torch.float32) / softmax_temperature
        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)
        code = torch.multinomial(probs, num_samples=1)
    return code
@torch.no_grad()
def sampling_conditional(model: torch.nn.Module,
                         cross_attention_idxs,
                         cross_attention_layers,
                         tokens: torch.LongTensor,
                         src_codes: torch.FloatTensor,
                         top_k: Optional[float] = None,
                         top_p: Optional[float] = None,
                         softmax_temperature: float = 1.0,
                         is_tqdm: bool = True,
                         use_fp16: bool = True,
                         max_seq_len: int = 256,
                         prompt: Optional[torch.Tensor] = None,
                         pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor:
    """Autoregressively sample codes conditioned on text and a source image.

    ``src_codes`` is embedded once, before the loop, as
    ``model.tok_emb_img(src_codes) + model.pos_emb_img(positions)``. The result
    is passed to every ``model.sampling_with_context`` call as
    ``source_image``. Each step feeds only the most recently sampled code (and
    its position); earlier context comes from the accumulated ``past`` list of
    stacked ``present`` tensors.

    Args:
        model: object exposing ``tok_emb_img``, ``pos_emb_img`` and
            ``sampling_with_context(...)`` that returns ``(logits, present)``.
        cross_attention_idxs, cross_attention_layers: forwarded unchanged to
            ``model.sampling_with_context``.
        tokens: 2-D text token ids; 1-D positions are built from its shape.
        src_codes: 2-D source-image codes (1-D positions are built from its
            shape).
        top_k: keep only the ``top_k`` largest logits (``None`` disables).
        top_p: nucleus-sampling probability mass (``None`` disables).
        softmax_temperature: logits are divided by this before filtering.
        is_tqdm: wrap the loop in a tqdm progress bar.
        use_fp16: forwarded to the model.
        max_seq_len: number of codes to sample.
        prompt, pos_prompt: forwarded unchanged to the model.

    Returns:
        Sampled code ids, ``(B, max_seq_len)``, or ``None`` when
        ``max_seq_len`` is 0.
    """
    code = None
    past = None
    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
    pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
    src_pos_tokens = get_positional_encoding(src_codes, mode='1d')
    src_tokens = model.tok_emb_img(src_codes)
    src_tokens = src_tokens + model.pos_emb_img(src_pos_tokens)
    for cnt, _ in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            # `code` has exactly `cnt` columns: feed only the last one, whose
            # position is `cnt - 1`, instead of re-encoding the whole sequence.
            code_ = code[:, -1:].clone()
            pos_enc_code_ = torch.full_like(code_, cnt - 1)
        logits, present = model.sampling_with_context(images=code_,
                                                      cross_attention_idxs=cross_attention_idxs,
                                                      cross_attention_layers=cross_attention_layers,
                                                      texts=tokens,
                                                      pos_images=pos_enc_code_,
                                                      pos_texts=pos_enc_tokens,
                                                      source_image=src_tokens,
                                                      use_fp16=use_fp16,
                                                      past=past,
                                                      prompt=prompt,
                                                      pos_prompt=pos_prompt)
        logits = logits.to(dtype=torch.float32) / softmax_temperature
        # torch.stack already returns a new tensor; no extra clone needed.
        present = torch.stack(present).detach()
        if past is None:
            past = [present]
        else:
            past.append(present)
        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)
        idx = torch.multinomial(probs, num_samples=1)
        code = idx if code is None else torch.cat([code, idx], dim=1)
    return code
@torch.no_grad()
def sampling_igpt(model: torch.nn.Module,
                  sos: torch.FloatTensor,
                  top_k: Optional[float] = None,
                  top_p: Optional[float] = None,
                  softmax_temperature: float = 1.0,
                  is_tqdm: bool = True,
                  use_fp16: bool = True,
                  max_seq_len: int = 256) -> torch.LongTensor:
    """Autoregressively sample codes from a start-of-sequence input (iGPT-style).

    Each step feeds only the most recently sampled code (and its position) to
    ``model.sampling``, along with ``sos``. Earlier context comes from the
    accumulated ``past`` list of stacked ``present`` tensors.

    Args:
        model: object exposing ``sampling(sos=, codes=, pos_codes=,
            use_fp16=, past=)`` that returns ``(logits, present)``.
        sos: start-of-sequence input, forwarded unchanged at every step.
        top_k: keep only the ``top_k`` largest logits (``None`` disables).
        top_p: nucleus-sampling probability mass (``None`` disables).
        softmax_temperature: logits are divided by this before filtering.
        is_tqdm: wrap the loop in a tqdm progress bar.
        use_fp16: forwarded to ``model.sampling``.
        max_seq_len: number of codes to sample.

    Returns:
        Sampled code ids, ``(B, max_seq_len)``, or ``None`` when
        ``max_seq_len`` is 0.
    """
    code = None
    past = None
    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
    for cnt, _ in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            # `code` has exactly `cnt` columns: feed only the last one, whose
            # position is `cnt - 1`, instead of re-encoding the whole sequence.
            code_ = code[:, -1:].clone()
            pos_enc_code_ = torch.full_like(code_, cnt - 1)
        logits, present = model.sampling(sos=sos,
                                         codes=code_,
                                         pos_codes=pos_enc_code_,
                                         use_fp16=use_fp16,
                                         past=past)
        logits = logits.to(dtype=torch.float32) / softmax_temperature
        # torch.stack already returns a new tensor; no extra clone needed.
        present = torch.stack(present).detach()
        if past is None:
            past = [present]
        else:
            past.append(present)
        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)
        idx = torch.multinomial(probs, num_samples=1)
        code = idx if code is None else torch.cat([code, idx], dim=1)
    return code