from typing import List, Optional
from collections import namedtuple
from diffusers import StableDiffusionPipeline
import json
import numpy as np
import os
import glob
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import pickle as pkl
from PIL import Image, UnidentifiedImageError
from requests.exceptions import ConnectionError
from transformers import AutoTokenizer, AutoModel, CLIPVisionModel, OPTForCausalLM
from gill import utils
from gill import layers
class GILLArgs:
freeze_lm: bool = True
freeze_vm: bool = True
opt_version: str = 'facebook/opt-6.7b'
visual_encoder: str = 'openai/clip-vit-large-patch14'
n_visual_tokens: int = 1
task: str = 'captioning'
ret_emb_dim: Optional[int] = 256
gen_emb_dim: Optional[int] = 256
text_emb_layers: List[int] = [-1]
gen_token_idx: List[int] = [0]
retrieval_token_idx: List[int] = [0]
text_fc_mode: str = 'gill_mapper'
ret_text_fc_mode: str = 'linear'
num_tokens: int = 8
num_clip_tokens: int = 77
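# Illustrative usage sketch (not the released training configuration): GILLArgs fields
# can be overridden per-attribute before constructing GILLModel. The smaller model names
# below mirror the debug settings used in load_gill() and are for quick experimentation only.
#
#   args = GILLArgs()
#   args.opt_version = 'facebook/opt-125m'                 # smaller LM for debugging
#   args.visual_encoder = 'openai/clip-vit-base-patch32'   # smaller CLIP encoder
#   args.text_emb_layers = [-1]                            # use the final LM hidden layer
#   model = GILLModel(tokenizer, args)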
class GILLModel(nn.Module):
def __init__(self, tokenizer, args: GILLArgs = GILLArgs()):
super().__init__()
self.tokenizer = tokenizer
self.feature_extractor = utils.get_feature_extractor_for_model(args.visual_encoder, train=False)
self.image_token = self.tokenizer.cls_token_id
assert len(set(args.text_emb_layers)) == len(args.text_emb_layers), 'text_emb_layers must be unique'
self.args = args
self.num_tokens = args.num_tokens
self.num_clip_tokens = args.num_clip_tokens
opt_version = args.opt_version
visual_encoder = args.visual_encoder
n_visual_tokens = args.n_visual_tokens
print(f"Using {opt_version} for the language model.")
print(f"Using {visual_encoder} for the visual model with {n_visual_tokens} visual tokens.")
if 'facebook/opt' in opt_version:
self.lm = OPTForCausalLM.from_pretrained(opt_version)
else:
raise NotImplementedError
self.opt_version = opt_version
if self.args.freeze_lm:
self.lm.eval()
print("Freezing the LM.")
for param in self.lm.parameters():
param.requires_grad = False
else:
self.lm.train()
self.retrieval_token_idx = args.retrieval_token_idx
self.gen_token_idx = args.gen_token_idx
self.lm.resize_token_embeddings(len(tokenizer))
self.input_embeddings = self.lm.get_input_embeddings()
print("Restoring pretrained weights for the visual model.")
if 'clip' in visual_encoder:
self.visual_model = CLIPVisionModel.from_pretrained(visual_encoder)
else:
self.visual_model = AutoModel.from_pretrained(visual_encoder)
if 'clip' in visual_encoder:
hidden_size = self.visual_model.config.hidden_size
else:
raise NotImplementedError
if self.args.freeze_vm:
print("Freezing the VM.")
self.visual_model.eval()
for param in self.visual_model.parameters():
param.requires_grad = False
else:
self.visual_model.train()
self.visual_model_name = visual_encoder
embedding_dim = self.input_embeddings.embedding_dim * self.args.n_visual_tokens
self.ret_text_hidden_fcs = nn.ModuleList([])
self.gen_text_hidden_fcs = nn.ModuleList([])
for layer_idx in self.args.text_emb_layers:
if (layer_idx == -1 or layer_idx == self.lm.config.num_hidden_layers) and ('bert' not in opt_version):
if 'opt' in opt_version: # OPT models
in_dim = self.lm.config.word_embed_proj_dim
else:
raise NotImplementedError
self.ret_text_hidden_fcs.append(
layers.TextFcLayer(in_dim, self.args.ret_emb_dim, num_input_tokens=self.args.num_tokens,
num_output_tokens=1, mode=self.args.ret_text_fc_mode))
self.gen_text_hidden_fcs.append(
layers.TextFcLayer(in_dim, self.args.gen_emb_dim, num_input_tokens=self.args.num_tokens,
num_output_tokens=self.args.num_clip_tokens, mode=self.args.text_fc_mode))
elif layer_idx < self.lm.config.num_hidden_layers:
self.ret_text_hidden_fcs.append(layers.TextFcLayer(self.lm.config.hidden_size, self.args.ret_emb_dim, num_input_tokens=self.args.num_tokens, num_output_tokens=1, mode=self.args.ret_text_fc_mode))
self.gen_text_hidden_fcs.append(layers.TextFcLayer(self.lm.config.hidden_size, self.args.gen_emb_dim, num_input_tokens=self.args.num_tokens, num_output_tokens=self.args.num_clip_tokens, mode=self.args.text_fc_mode))
else:
raise ValueError(f'Embedding of layer {layer_idx} was requested but model only has {self.lm.config.num_hidden_layers} layers.')
self.visual_embeddings = nn.Linear(hidden_size, embedding_dim)
# Retrieval image FC layer.
self.visual_fc = nn.Linear(hidden_size, self.args.ret_emb_dim)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
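# Note (added for clarity): GILLModel wires together (i) an OPT language model (frozen when
# freeze_lm is True), (ii) a CLIP ViT visual encoder (frozen when freeze_vm is True),
# (iii) the visual_embeddings / visual_fc linear projections that map pooled CLIP features
# into the LM input space (captioning) and the retrieval space, and (iv) per-layer
# TextFcLayer heads (ret_text_hidden_fcs / gen_text_hidden_fcs) that map [IMG] hidden states
# to a retrieval embedding and to num_clip_tokens embeddings of size gen_emb_dim used as
# Stable Diffusion prompt embeddings.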
def get_visual_embs(self, pixel_values: torch.FloatTensor, mode: str = 'captioning'):
if mode not in ['captioning', 'retrieval', 'generation']:
raise ValueError(f"mode should be one of ['captioning', 'retrieval', 'generation'], got {mode} instead.")
# Extract visual embeddings from the vision encoder.
if 'clip' in self.visual_model_name:
outputs = self.visual_model(pixel_values)
encoder_outputs = outputs.pooler_output
else:
raise NotImplementedError
# Use the correct fc based on function argument.
if mode == 'captioning':
visual_embs = self.visual_embeddings(encoder_outputs) # (N, D * n_visual_tokens)
visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], self.args.n_visual_tokens, -1))
elif mode == 'retrieval':
visual_embs = self.visual_fc(encoder_outputs) # (N, ret_emb_dim)
visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 1, -1))
elif mode == 'generation':
visual_embs = torch.zeros((pixel_values.shape[0], 1, 768), device=pixel_values.device)
else:
raise NotImplementedError
return visual_embs
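# Note (added for clarity): shape summary for get_visual_embs (N = batch size):
#   'captioning' -> (N, n_visual_tokens, lm_embedding_dim), via self.visual_embeddings
#   'retrieval'  -> (N, 1, ret_emb_dim), via self.visual_fc
#   'generation' -> (N, 1, 768), an all-zeros placeholder (visual features are not needed in this mode)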
def train(self, mode=True):
super(GILLModel, self).train(mode=mode)
# Overwrite train() to ensure frozen models remain frozen.
if self.args.freeze_lm:
self.lm.eval()
if self.args.freeze_vm:
self.visual_model.eval()
def forward(
self,
pixel_values: torch.FloatTensor,
labels: Optional[torch.LongTensor] = None,
caption_len: Optional[torch.LongTensor] = None,
mode: str = 'captioning',
concat_captions: bool = False,
input_prefix: Optional[str] = None,
):
visual_embs = self.get_visual_embs(pixel_values, mode)
batch_size, vis_seq_len, _ = visual_embs.shape # vis_seq_len = n_visual_tokens
if labels is not None:
assert labels.shape[0] == batch_size, (visual_embs.shape, labels.shape)
visual_embs_norm = ((visual_embs ** 2).sum(dim=-1) ** 0.5).mean()
input_embs = self.input_embeddings(labels) # (N, T, D)
input_embs_norm = ((input_embs ** 2).sum(dim=-1) ** 0.5).mean()
last_embedding_idx = caption_len - 1 # -1 to retrieve the token before the eos token
if input_prefix is not None:
prompt_ids = self.tokenizer(input_prefix, add_special_tokens=False, return_tensors="pt").input_ids
prompt_ids = prompt_ids.to(visual_embs.device)
prompt_embs = self.input_embeddings(prompt_ids)
prompt_embs = prompt_embs.repeat(batch_size, 1, 1)
assert prompt_embs.shape[0] == batch_size, prompt_embs.shape
assert prompt_embs.shape[2] == input_embs.shape[2], prompt_embs.shape
assert len(prompt_embs.shape) == 3, prompt_embs.shape
if mode == 'captioning':
# Concat to text embeddings.
condition_seq_len = 0
if input_prefix is None:
# Just add visual embeddings.
input_embs = torch.cat([visual_embs, input_embs], axis=1)
last_embedding_idx += vis_seq_len
condition_seq_len += vis_seq_len
full_labels = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
else:
print(f'Adding prefix "{input_prefix}" to captioning.')
# Add visual and prompt embeddings.
prefix_embs = torch.cat([visual_embs, prompt_embs], axis=1)
input_embs = torch.cat([prefix_embs, input_embs], axis=1)
last_embedding_idx += prefix_embs.shape[1]
condition_seq_len += prefix_embs.shape[1]
full_labels = torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
# Mask out embedding tokens in the labels.
full_labels = torch.cat([full_labels, labels], axis=1)
pad_idx = []
for label in full_labels:
for k, token in enumerate(label):
# Mask out retrieval/gen tokens if they exist.
if token in [self.tokenizer.pad_token_id] + self.retrieval_token_idx + self.gen_token_idx:
label[k:] = -100
pad_idx.append(k)
break
if k == len(label) - 1: # No padding found.
pad_idx.append(k + 1)
assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
bs, seq_len, embs_dim = input_embs.shape
if concat_captions:
print('Concatenating examples for captioning!')
assert len(input_embs.shape) == 3, input_embs
assert len(full_labels.shape) == 2, full_labels
assert batch_size % 2 == 0
all_concat_input_embs = []
all_concat_labels = []
# Rearrange embeddings and labels (and their padding) to concatenate captions.
for i in range(batch_size // 2):
first_idx = i * 2
second_idx = first_idx + 1
first_emb = input_embs[first_idx, :pad_idx[first_idx], :]
first_labels = full_labels[first_idx, :pad_idx[first_idx]]
first_padding = input_embs[first_idx, pad_idx[first_idx]:, :]
first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:]
second_emb = input_embs[second_idx, :pad_idx[second_idx], :]
second_labels = full_labels[second_idx, :pad_idx[second_idx]]
second_padding = input_embs[second_idx, pad_idx[second_idx]:, :]
second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:]
bos_idx = visual_embs.shape[1]
assert torch.all(first_labels_padding == -100), first_labels_padding
assert torch.all(second_labels_padding == -100), second_labels_padding
assert torch.all(second_labels[bos_idx] == self.tokenizer.bos_token_id), (second_labels, bos_idx, self.tokenizer.bos_token_id)
# Remove BOS token of the second caption.
second_labels = torch.cat([second_labels[:bos_idx], second_labels[bos_idx + 1:]], axis=0)
second_emb = torch.cat([second_emb[:bos_idx, :], second_emb[bos_idx + 1:, :]], axis=0)
concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0) # (T*2, D)
concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0) # (T*2,)
all_concat_input_embs.append(concat_input_embs)
all_concat_labels.append(concat_labels)
# Pad to max length.
input_embs = torch.stack(all_concat_input_embs, axis=0) # (N/2, T*2, D)
full_labels = torch.stack(all_concat_labels, axis=0) # (N/2, T*2)
print("Concatenated full_labels:", full_labels[0, ...])
assert input_embs.shape == (bs // 2, seq_len * 2 - 1, embs_dim), input_embs.shape
assert full_labels.shape == (bs // 2, seq_len * 2 - 1), full_labels.shape
output = self.lm(inputs_embeds=input_embs,
labels=full_labels,
output_hidden_states=True)
elif mode in ['retrieval', 'generation']:
full_labels = torch.clone(labels)
if input_prefix is not None:
print(f'Adding prefix "{input_prefix}" to retrieval.')
# Add prompt embeddings.
prefix_embs = prompt_embs
input_embs = torch.cat([prefix_embs, input_embs], axis=1)
last_embedding_idx += prefix_embs.shape[1]
full_labels = torch.cat([
torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(labels.device) - 100,
full_labels
], axis=1)
pad_idx = []
for label in full_labels:
for k, token in enumerate(label):
if (token == self.tokenizer.pad_token_id):
label[k:] = -100
pad_idx.append(k)
break
if k == len(label) - 1: # No padding found.
pad_idx.append(k + 1)
assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
bs, seq_len, embs_dim = input_embs.shape
# Concatenate examples for captioning, if specified.
if concat_captions:
print(f'Concatenating examples for {mode}!')
assert len(input_embs.shape) == 3, input_embs
assert len(full_labels.shape) == 2, full_labels
assert batch_size % 2 == 0
all_concat_input_embs = []
all_concat_labels = []
all_last_embedding_idx = []
# Rearrange embeddings and labels (and their padding) to concatenate captions.
for i in range(batch_size // 2):
first_idx = i * 2
second_idx = first_idx + 1
first_emb = input_embs[first_idx, :pad_idx[first_idx], :]
first_labels = full_labels[first_idx, :pad_idx[first_idx]]
first_padding = input_embs[first_idx, pad_idx[first_idx]:, :]
first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:]
second_emb = input_embs[second_idx, :pad_idx[second_idx], :]
second_labels = full_labels[second_idx, :pad_idx[second_idx]]
second_padding = input_embs[second_idx, pad_idx[second_idx]:, :]
second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:]
bos_idx = 0
assert torch.all(first_labels_padding == -100), first_labels_padding
assert torch.all(second_labels_padding == -100), second_labels_padding
assert torch.all(second_labels[bos_idx] == self.tokenizer.bos_token_id), (second_labels, bos_idx, self.tokenizer.bos_token_id)
# Remove BOS token of second caption.
second_labels = second_labels[bos_idx + 1:]
second_emb = second_emb[bos_idx + 1:, :]
last_embedding_idx[second_idx] = last_embedding_idx[second_idx] - 1
concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0) # (T*2, D)
concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0) # (T*2,)
all_concat_input_embs.append(concat_input_embs)
all_concat_labels.append(concat_labels)
all_last_embedding_idx.append((last_embedding_idx[first_idx], first_emb.shape[0] + last_embedding_idx[second_idx]))
if mode == 'retrieval':
assert concat_labels[all_last_embedding_idx[-1][0]] in self.retrieval_token_idx, (concat_labels, all_last_embedding_idx[-1][0])
assert concat_labels[all_last_embedding_idx[-1][1]] in self.retrieval_token_idx, (concat_labels, all_last_embedding_idx[-1][1])
elif mode == 'generation':
# Check that the last n tokens are GEN tokens.
for gen_i in range(len(self.gen_token_idx)):
assert concat_labels[all_last_embedding_idx[-1][0]-gen_i] == self.gen_token_idx[-gen_i-1], (concat_labels, all_last_embedding_idx[-1][0]-gen_i, self.gen_token_idx[-gen_i-1])
assert concat_labels[all_last_embedding_idx[-1][1]-gen_i] == self.gen_token_idx[-gen_i-1], (concat_labels, all_last_embedding_idx[-1][1]-gen_i, self.gen_token_idx[-gen_i-1])
# Pad to max length.
input_embs = torch.stack(all_concat_input_embs, axis=0) # (N/2, T*2, D)
full_labels = torch.stack(all_concat_labels, axis=0) # (N/2, T*2)
assert input_embs.shape == (bs // 2, seq_len * 2 - 1, embs_dim), input_embs.shape
assert full_labels.shape == (bs // 2, seq_len * 2 - 1), full_labels.shape
# Update labels to pad non-first tokens.
for label in full_labels:
for k, token in enumerate(label):
if (token == self.tokenizer.pad_token_id) or (token in (self.retrieval_token_idx[1:] + self.gen_token_idx[1:])):
label[k:] = -100
break
output = self.lm(inputs_embeds=input_embs,
labels=full_labels,
output_hidden_states=True)
else:
raise NotImplementedError
last_embedding = None
last_output_logit = None
hidden_states = []
llm_hidden_states = []
if mode in ['retrieval', 'generation']:
num_tokens = self.num_tokens
if mode == 'retrieval':
text_hidden_fcs = self.ret_text_hidden_fcs
else:
text_hidden_fcs = self.gen_text_hidden_fcs
# Concatenate captions for retrieval / generation, if specified.
if not concat_captions:
for idx, fc_layer in zip(self.args.text_emb_layers, text_hidden_fcs):
input_hidden_state = torch.stack([output.hidden_states[idx][i, last_embedding_idx[i]-num_tokens+1:last_embedding_idx[i]+1, :] for i in range(batch_size)], axis=0)
input_embedding = torch.stack([input_embs[i, last_embedding_idx[i]-num_tokens+1:last_embedding_idx[i]+1, :] for i in range(batch_size)], axis=0)
llm_hidden_states.append(input_hidden_state)
hidden_states.append(fc_layer(input_hidden_state, input_embedding)) # (N, seq_len, 2048)
else:
for idx, fc_layer in zip(self.args.text_emb_layers, text_hidden_fcs):
all_last_embedding = []
all_input_embedding = []
all_last_output_logit = []
for i in range(batch_size // 2):
first_last_embedding_idx, second_last_embedding_idx = all_last_embedding_idx[i]
first_last_embedding = output.hidden_states[idx][i, first_last_embedding_idx-num_tokens+1:first_last_embedding_idx+1, :] # (N, D)
second_last_embedding = output.hidden_states[idx][i, second_last_embedding_idx-num_tokens+1:second_last_embedding_idx+1, :] # (N, D)
all_last_embedding.append(first_last_embedding)
all_last_embedding.append(second_last_embedding)
first_input_embs = input_embs[i, first_last_embedding_idx-num_tokens+1:first_last_embedding_idx+1, :] # (N, D)
second_input_embs = input_embs[i, second_last_embedding_idx-num_tokens+1:second_last_embedding_idx+1, :] # (N, D)
all_input_embedding.append(first_input_embs)
all_input_embedding.append(second_input_embs)
first_last_output_logit = output.logits[i, first_last_embedding_idx - 1, :] # (N, D)
second_last_output_logit = output.logits[i, second_last_embedding_idx - 1, :] # (N, D)
all_last_output_logit.append(first_last_output_logit)
all_last_output_logit.append(second_last_output_logit)
last_embedding = torch.stack(all_last_embedding, axis=0)
input_embedding = torch.stack(all_input_embedding, axis=0)
last_output_logit = torch.stack(all_last_output_logit, axis=0)
llm_hidden_states.append(last_embedding)
hidden_states.append(fc_layer(last_embedding, input_embedding)) # (N, seq_len, 2048)
if not concat_captions:
# Add hidden states together.
last_embedding = torch.stack(hidden_states, dim=-1).sum(dim=-1) # (N, T, D)
last_output_logit = torch.stack([output.logits[i, last_embedding_idx[i] - 1, :] for i in range(batch_size)], axis=0) # (N, D)
else:
# Add hidden states together.
last_embedding = torch.stack(hidden_states, dim=-1).sum(dim=-1)
# Compute retrieval loss.
if mode == 'retrieval':
assert visual_embs.shape[1] == 1, visual_embs.shape
assert last_embedding.shape[1] == 1, last_embedding.shape
visual_embs = visual_embs[:, 0, :]
visual_embs = visual_embs / visual_embs.norm(dim=1, keepdim=True)
last_embedding = last_embedding[:, 0, :]
last_embedding = last_embedding / last_embedding.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
visual_embs = logit_scale * visual_embs
elif mode == 'captioning':
pass
else:
raise NotImplementedError
return output, full_labels, last_embedding, last_output_logit, visual_embs, visual_embs_norm, input_embs_norm, llm_hidden_states
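# Note (added for clarity): forward() returns, in order:
#   output             - CausalLMOutputWithPast from the OPT LM (loss, logits, hidden_states)
#   full_labels        - labels with conditioning / padding positions masked to -100
#   last_embedding     - mapped [IMG] hidden states in retrieval/generation mode (None for captioning)
#   last_output_logit  - LM logits at the position preceding the [IMG] span (None for captioning)
#   visual_embs        - output of get_visual_embs (normalized and scaled by logit_scale in retrieval mode)
#   visual_embs_norm / input_embs_norm - scalar norms, useful for logging
#   llm_hidden_states  - raw LM hidden states that fed the ret/gen FC heads (empty for captioning)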
def generate(self, embeddings: torch.FloatTensor, max_len: int = 32,
temperature: float = 0.0, top_p: float = 1.0, min_word_tokens: int = 0,
ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
filter_value: float = -float('Inf')):
"""Runs greedy decoding and returns generated captions.
Args:
min_word_tokens: Minimum number of words to generate before allowing a [IMG] output.
filter_value: Value to assign to tokens that should never be generated.
Outputs:
out: (N, T) int32 sequence of output tokens.
output_embeddings: (N, T, 256) sequence of text output embeddings.
"""
self.lm.eval()
with torch.no_grad(): # no tracking history
# init output with image tokens
out = None
output_embeddings = []
output_logits = []
for i in range(max_len):
output = self.lm(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True)
for idx in self.args.text_emb_layers:
output_embeddings.append(output.hidden_states[idx])
logits = output.logits[:, -1, :] # (N, vocab_size)
if top_p == 1.0:
logits = logits.cpu()
output_logits.append(logits)
# Prevent the model from generating the [IMG1..n] tokens.
logits[:, self.retrieval_token_idx[1:]] = filter_value
logits[:, self.gen_token_idx[1:]] = filter_value
if (self.retrieval_token_idx or self.gen_token_idx) and self.retrieval_token_idx[0] != -1 and self.gen_token_idx[0] != -1:
if i < min_word_tokens:
# Eliminate probability of generating [IMG] if this is earlier than min_word_tokens.
logits[:, self.retrieval_token_idx] = filter_value
logits[:, self.gen_token_idx] = filter_value
else:
# Multiply by scaling factor.
if ret_scale_factor > 1:
logits[:, self.retrieval_token_idx[0]] = logits[:, self.retrieval_token_idx[0]].abs() * ret_scale_factor
if gen_scale_factor > 1:
logits[:, self.gen_token_idx[0]] = logits[:, self.gen_token_idx[0]].abs() * gen_scale_factor
if temperature == 0.0:
if top_p != 1.0:
raise ValueError('top_p cannot be set if temperature is 0 (greedy decoding).')
next_token = torch.argmax(logits, keepdim=True, dim=-1) # (N, 1)
else:
logits = logits / temperature
# Apply top-p filtering.
if top_p < 1.0:
assert top_p > 0, f'top_p should be above 0, got {top_p} instead.'
sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (N, D) and (N, D)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (N, D)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for j in range(sorted_indices.shape[0]):
indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]]
logits[j, indices_to_remove] = filter_value
token_weights = logits.exp() # (N, vocab_size)
next_token = torch.multinomial(token_weights, 1) # (N, 1)
# Force generation of the remaining [IMG] tokens if [IMG0] is generated.
if next_token.shape[0] == 1 and next_token.item() == self.retrieval_token_idx[0]:
assert self.retrieval_token_idx == self.gen_token_idx, (self.retrieval_token_idx, self.gen_token_idx)
next_token = torch.tensor(self.retrieval_token_idx)[None, :].long().to(embeddings.device) # (1, num_tokens)
else:
next_token = next_token.long().to(embeddings.device)
if out is not None:
out = torch.cat([out, next_token], dim=-1)
else:
out = next_token
next_embedding = self.input_embeddings(next_token)
embeddings = torch.cat([embeddings, next_embedding], dim=1)
return out, output_embeddings, output_logits
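# Illustrative sketch (hypothetical names; assumes `tokenizer`, `gill_model`, and `device`
# are defined): decoding from a text-only prompt with GILLModel.generate. Embeddings are
# looked up manually because generate() consumes input embeddings rather than token ids.
#
#   ids = tokenizer('A picture of', return_tensors='pt').input_ids.to(device)
#   embs = gill_model.input_embeddings(ids)                        # (1, T, D)
#   out_ids, _, _ = gill_model.generate(embs, max_len=16, temperature=0.0)
#   print(tokenizer.batch_decode(out_ids, skip_special_tokens=True)[0])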
class GILL(nn.Module):
def __init__(self, tokenizer, model_args: Optional[GILLArgs] = None,
path_array: Optional[List[str]] = None, emb_matrix: Optional[torch.Tensor] = None,
load_sd: bool = False, num_gen_images: int = 1, decision_model_path: Optional[str] = None):
super().__init__()
self.model = GILLModel(tokenizer, model_args)
self.path_array = path_array
self.emb_matrix = emb_matrix
self.load_sd = load_sd
self.num_gen_images = num_gen_images
self.idx2dec = {0: 'gen', 1: 'ret', 2: 'same'}
self.decision_model = None
# Load the Stable Diffusion model.
if load_sd:
model_id = "runwayml/stable-diffusion-v1-5"
if torch.cuda.is_available():
self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
else:
self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
if decision_model_path is not None:
print('Loading decision model...')
self.decision_model = nn.Sequential(*[
nn.Dropout(0.5),
nn.Linear(4097, 2),
])
if torch.cuda.is_available():
mlp_checkpoint = torch.load(decision_model_path)
else:
mlp_checkpoint = torch.load(decision_model_path, map_location=torch.device('cpu'))
self.decision_model.load_state_dict(mlp_checkpoint['state_dict'], strict=True)
self.decision_model.eval()
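# Note (added for clarity): the decision model is a small MLP over the first [IMG] hidden
# state (4096-d for OPT-6.7B) concatenated with the maximum retrieval score (4096 + 1 = 4097
# inputs), producing two logits that choose between generating a new image ('gen') and
# retrieving an existing one ('ret').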
def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: Optional[Tensor] = None,
generate: bool = False, num_words: int = 32, temperature: float = 1.0, top_p: float = 1.0,
ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
min_word_tokens: int = 0, mode: str = 'captioning', concat_captions: bool = False,
input_prefix: Optional[str] = None) -> Tensor:
if generate:
return self.model.generate(images, num_words, temperature=temperature, top_p=top_p,
min_word_tokens=min_word_tokens, ret_scale_factor=ret_scale_factor,
gen_scale_factor=gen_scale_factor)
else:
output = self.model(
pixel_values = images,
labels = tgt_tokens,
caption_len = caption_len,
mode = mode,
concat_captions = concat_captions,
input_prefix = input_prefix)
return output
def generate_for_images_and_texts(
self, prompts: List, num_words: int = 0, min_word_tokens: int = 0, ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
top_p: float = 1.0, temperature: float = 0.0, max_num_rets: int = 1, generator=None,
always_add_bos : bool = False, guidance_scale: float = 7.5, num_inference_steps: int = 50):
"""
Encode prompts into embeddings, and generates text and image outputs accordingly.
Args:
prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
num_words: Maximum number of words to generate for. If num_words = 0, the model will run its forward pass and return the outputs.
min_word_tokens: Minimum number of actual words before generating an image.
ret_scale_factor: Proportion to scale [IMG] token logits by. A higher value may increase the probability of the model generating [IMG] outputs.
top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to top_p or higher are kept for generation.
temperature: Used to modulate logit distribution.
max_num_rets: Maximum number of images to return in one generation pass.
Returns:
return_outputs: List consisting of either str or List[PIL.Image.Image] objects, representing image-text interleaved model outputs.
"""
input_embs = []
input_ids = []
add_bos = True
with torch.no_grad():
for p in prompts:
if type(p) == Image.Image:
# Encode as image.
pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
pixel_values = pixel_values[None, ...]
visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning') # (1, n_visual_tokens, D)
input_embs.append(visual_embs)
elif type(p) == str:
text_ids = self.model.tokenizer(p, add_special_tokens=add_bos, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
# Only add <bos> once unless the flag is set.
if not always_add_bos:
add_bos = False
text_embs = self.model.input_embeddings(text_ids) # (1, T, D)
input_embs.append(text_embs)
input_ids.append(text_ids)
else:
raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
input_embs = torch.cat(input_embs, dim=1)
input_ids = torch.cat(input_ids, dim=1)
if num_words == 0:
raise NotImplementedError('Generation not implemented for num_words=0.')
elif num_words > 0:
generated_ids, generated_embeddings, _ = self.model.generate(input_embs, num_words, min_word_tokens=min_word_tokens,
temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor, gen_scale_factor=gen_scale_factor)
embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]
# Truncate to newline.
newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
trunc_idx = 0
for j in range(generated_ids.shape[1]):
if generated_ids[0, j] == newline_token_id:
trunc_idx = j
break
if trunc_idx > 0:
generated_ids = generated_ids[:, :trunc_idx]
embeddings = embeddings[:, :trunc_idx]
else:
raise ValueError
# Save outputs as an interleaved list.
return_outputs = []
# Find up to max_num_rets [IMG] tokens, and their corresponding scores.
all_ret_idx = [i for i, x in enumerate(generated_ids[0, :] == self.model.retrieval_token_idx[0]) if x][:max_num_rets]
seen_image_idx = [] # Avoid showing the same image multiple times.
last_ret_idx = 0
if len(all_ret_idx) == 0:
# No [IMG] tokens.
caption = self.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return_outputs.append(utils.truncate_caption(caption))
else:
for ret_idx in all_ret_idx:
assert generated_ids[0, ret_idx:ret_idx+self.model.num_tokens].cpu().detach().numpy().tolist() == self.model.retrieval_token_idx, (generated_ids[0, ret_idx:ret_idx+self.model.num_tokens], self.model.retrieval_token_idx)
raw_emb = embeddings[:, ret_idx:ret_idx+self.model.num_tokens, :] # (1, 8, 4096)
assert len(self.model.args.text_emb_layers) == 1
image_outputs = {
'gen': [],
'ret': [],
'decision': None,
}
if self.emb_matrix is not None:
# Produce retrieval embedding.
ret_emb = self.model.ret_text_hidden_fcs[0](raw_emb, None)[:, 0, :] # (1, 256)
ret_emb = ret_emb / ret_emb.norm(dim=-1, keepdim=True)
ret_emb = ret_emb.type(self.emb_matrix.dtype) # (1, 256)
scores = self.emb_matrix @ ret_emb.T
# Downweight seen images.
for seen_idx in seen_image_idx:
scores[seen_idx, :] -= 1000
# Get the top 3 candidate images for this [IMG] query.
_, top_image_idx = scores.squeeze().topk(3)
for img_idx in top_image_idx:
# Find the first image that does not error out.
try:
seen_image_idx.append(img_idx)
img = utils.get_image_from_url(self.path_array[img_idx])
image_outputs['ret'].append((img, 'ret', scores[img_idx].item()))
# Stop once enough images have been retrieved.
if len(image_outputs['ret']) == max_num_rets:
break
except (UnidentifiedImageError, ConnectionError):
pass
# Make decision with MLP.
if self.decision_model is not None:
decision_emb = raw_emb[:, 0, :] # (1, 4096)
assert decision_emb.shape[1] == 4096, decision_emb.shape
max_ret_score = scores.max().reshape((1, 1)).clone().detach().to(device=decision_emb.device, dtype=decision_emb.dtype)
decision_logits = self.decision_model(torch.cat([decision_emb, max_ret_score], dim=-1))
probs = decision_logits.softmax(dim=-1).cpu().float().numpy().tolist()
image_outputs['decision'] = [self.idx2dec[decision_logits.argmax().item()]] + probs
else:
# If no embedding matrix is provided, generate instead.
image_outputs['decision'] = ['gen', [0, 1]]
# Produce generation embedding.
gen_prefix = ' '.join([f'[IMG{i}]' for i in range(self.model.args.num_tokens)])
gen_prefix_ids = self.model.tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
gen_prefix_embs = self.model.input_embeddings(gen_prefix_ids) # (1, T, D)
gen_emb = self.model.gen_text_hidden_fcs[0](raw_emb, gen_prefix_embs) # (1, 77, 768)
if gen_emb.shape[1] != 77:
print(f"Padding {gen_emb.shape} with zeros")
bs = gen_emb.shape[0]
clip_emb = 768
gen_emb = gen_emb.reshape(bs, -1, clip_emb) # (bs, T, 768)
seq_len = gen_emb.shape[1]
gen_emb = torch.cat([gen_emb, torch.zeros((bs, 77 - seq_len, clip_emb), device=gen_emb.device, dtype=gen_emb.dtype)], dim=1)
print('Padded to', gen_emb.shape)
gen_emb = gen_emb.repeat(self.num_gen_images, 1, 1) # (self.num_gen_images, 77, 768)
# Only generate if we are showing a generated image.
if self.load_sd and image_outputs['decision'][0] == 'gen':
# If num_gen_images > 8, split into multiple batches (for GPU memory reasons).
gen_max_bs = 8
gen_images = []
for i in range(0, self.num_gen_images, gen_max_bs):
gen_images.extend(
self.sd_pipe(prompt_embeds=gen_emb[i:i+gen_max_bs], generator=generator,
guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images)
all_gen_pixels = []
for img in gen_images:
pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, img.resize((224, 224)).convert('RGB'))
pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
all_gen_pixels.append(pixel_values)
if self.emb_matrix is not None:
all_gen_pixels = torch.stack(all_gen_pixels, dim=0)
gen_visual_embs = self.model.get_visual_embs(all_gen_pixels, mode='retrieval') # (1, D)
gen_visual_embs = gen_visual_embs / gen_visual_embs.norm(dim=-1, keepdim=True)
gen_visual_embs = gen_visual_embs.type(self.emb_matrix.dtype)
gen_rank_scores = (gen_visual_embs @ ret_emb.T).squeeze()
sorted_score_idx = torch.argsort(-gen_rank_scores)
# Rank images by retrieval score.
if self.num_gen_images > 1:
image_outputs['gen'] = [(gen_images[idx], gen_rank_scores[idx].item()) for idx in sorted_score_idx]
else:
image_outputs['gen'] = [(gen_images[0], gen_rank_scores.item())]
else:
image_outputs['gen'] = [(gen_images[0], 0)]
else:
image_outputs['gen'] = [gen_emb]
caption = self.model.tokenizer.batch_decode(generated_ids[:, last_ret_idx:ret_idx], skip_special_tokens=True)[0]
last_ret_idx = ret_idx + 1
return_outputs.append(utils.truncate_caption(caption) + f' {gen_prefix}')
return_outputs.append(image_outputs)
return return_outputs
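# Illustrative usage sketch (hypothetical image path, prompt text, and `gill` instance):
#
#   img = Image.open('example.jpg').convert('RGB')
#   outputs = gill.generate_for_images_and_texts(
#       [img, 'Q: What is shown in this picture? A:'],
#       num_words=32, ret_scale_factor=1.3, top_p=0.95, temperature=0.6)
#   # `outputs` interleaves generated strings with dicts holding 'gen' / 'ret' / 'decision' results.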
def get_log_likelihood_scores(
self, prompts: List):
"""
Output the log likelihood of the given interleaved prompts.
Args:
prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
Returns:
Log likelihood score of the prompt sequence.
"""
input_embs = []
input_ids = []
add_bos = True
for p in prompts:
if type(p) == Image.Image:
# Encode as image.
pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
pixel_values = pixel_values[None, ...]
visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning') # (1, n_visual_tokens, D)
input_embs.append(visual_embs)
id_ = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
input_ids.append(id_)
elif type(p) == str:
text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
if not add_bos:
# Remove <bos> tag.
text_ids = text_ids[:, 1:]
else:
# Only add <bos> once.
add_bos = False
text_embs = self.model.input_embeddings(text_ids) # (1, T, D)
input_embs.append(text_embs)
input_ids.append(text_ids)
else:
raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
input_embs = torch.cat(input_embs, dim=1)
input_ids = torch.cat(input_ids, dim=1)
outputs = self.model.lm(inputs_embeds=input_embs, labels=input_ids, use_cache=False, output_hidden_states=True)
return -outputs.loss.item()
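# Illustrative sketch (hypothetical `gill` instance and `img`): ranking candidate captions
# for an image by their log likelihood under the model.
#
#   candidates = ['A dog running on a beach.', 'A plate of pasta.']
#   scores = [gill.get_log_likelihood_scores([img, c]) for c in candidates]
#   best_caption = candidates[int(np.argmax(scores))]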
def load_gill(embeddings_dir: str, model_args_path: str, model_ckpt_path: str, decision_model_path: str) -> GILL:
embs_paths = glob.glob(os.path.join(embeddings_dir, 'cc3m*.npy'))
if not os.path.exists(model_args_path):
raise ValueError(f'model_args.json does not exist at {model_args_path}.')
if not os.path.exists(model_ckpt_path):
raise ValueError(f'pretrained_ckpt.pth.tar does not exist at {model_ckpt_path}.')
if len(embs_paths) == 0:
print(f'cc3m*.npy files do not exist in {embeddings_dir}. Running the model without retrieval.')
path_array, emb_matrix = None, None
else:
# Load embeddings.
# Construct embedding matrix for nearest neighbor lookup.
path_array = []
emb_matrix = []
# These were precomputed for all CC3M images with `model.get_visual_embs(image, mode='retrieval')`.
for p in embs_paths:
with open(p, 'rb') as wf:
train_embs_data = pkl.load(wf)
path_array.extend(train_embs_data['paths'])
emb_matrix.extend(train_embs_data['embeddings'])
emb_matrix = np.stack(emb_matrix, axis=0)
# Number of paths should be equal to number of embeddings.
assert len(path_array) == emb_matrix.shape[0], (len(path_array), emb_matrix.shape)
with open(model_args_path, 'r') as f:
model_kwargs = json.load(f)
# Initialize tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_kwargs['opt_version'], use_fast=False)
if tokenizer.pad_token is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
# Add an image token for loss masking (and visualization) purposes.
tokenizer.add_special_tokens({"cls_token": "<|image|>"}) # add special image token to tokenizer
# Add [IMG] tokens to the vocabulary.
model_kwargs['retrieval_token_idx'] = []
for i in range(model_kwargs['num_tokens']):
print(f'Adding [IMG{i}] token to vocabulary.')
print(f'Before adding new token, tokenizer("[IMG{i}]") =', tokenizer(f'[IMG{i}]', add_special_tokens=False))
num_added_tokens = tokenizer.add_tokens(f'[IMG{i}]')
print(f'After adding {num_added_tokens} new tokens, tokenizer("[IMG{i}]") =', tokenizer(f'[IMG{i}]', add_special_tokens=False))
ret_token_idx = tokenizer(f'[IMG{i}]', add_special_tokens=False).input_ids
assert len(ret_token_idx) == 1, ret_token_idx
model_kwargs['retrieval_token_idx'].append(ret_token_idx[0])
# Use the same RET tokens for generation.
model_kwargs['gen_token_idx'] = model_kwargs['retrieval_token_idx']
debug = False
if debug:
model_kwargs['opt_version'] = 'facebook/opt-125m'
model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
decision_model_path = None
args = namedtuple('args', model_kwargs)(**model_kwargs)
# Initialize model for inference.
model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,
load_sd=not debug, num_gen_images=1, decision_model_path=decision_model_path)
model = model.eval()
if torch.cuda.is_available():
model = model.bfloat16().cuda()
if not debug:
# Load pretrained linear mappings and [IMG] embeddings.
checkpoint = torch.load(model_ckpt_path)
state_dict = {}
# This is needed if we train with DDP.
for k, v in checkpoint['state_dict'].items():
state_dict[k.replace('module.', '')] = v
img_token_embeddings = state_dict['model.input_embeddings.weight'].cpu().detach()
del state_dict['model.input_embeddings.weight']
model.load_state_dict(state_dict, strict=False)
# Copy over the embeddings of the [IMG] tokens (while loading the others from the pretrained LLM).
with torch.no_grad():
if 'share_ret_gen' in model_kwargs:
assert model_kwargs['share_ret_gen'], 'Model loading only supports share_ret_gen=True for now.'
model.model.input_embeddings.weight[-model_kwargs['num_tokens']:, :].copy_(img_token_embeddings)
if len(embs_paths) > 0:
logit_scale = model.model.logit_scale.exp()
emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)
emb_matrix = logit_scale * emb_matrix
model.emb_matrix = emb_matrix
return model
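# Minimal example entry point (a sketch only): `ckpt_dir` is a placeholder and must contain
# the downloaded GILL artifacts (model_args.json, pretrained_ckpt.pth.tar, optional cc3m*.npy
# retrieval embeddings, and a decision model checkpoint whose filename may differ from the
# one assumed below).
if __name__ == '__main__':
    ckpt_dir = 'checkpoints/gill_opt'  # hypothetical path
    gill = load_gill(
        embeddings_dir=ckpt_dir,
        model_args_path=os.path.join(ckpt_dir, 'model_args.json'),
        model_ckpt_path=os.path.join(ckpt_dir, 'pretrained_ckpt.pth.tar'),
        decision_model_path=os.path.join(ckpt_dir, 'decision_model.pth.tar'))
    outputs = gill.generate_for_images_and_texts(
        ['A watercolor painting of a lighthouse at sunset.'],
        num_words=32, ret_scale_factor=1.3)
    print(outputs)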