from typing import List, Optional
from collections import namedtuple
from diffusers import StableDiffusionPipeline
import json
import numpy as np
import os
import glob
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import pickle as pkl
from PIL import Image, UnidentifiedImageError
from requests.exceptions import ConnectionError
from transformers import AutoTokenizer, AutoModel, CLIPVisionModel, OPTForCausalLM

from gill import utils
from gill import layers


class GILLArgs:
  freeze_lm: bool = True
  freeze_vm: bool = True
  opt_version: str = 'facebook/opt-6.7b'
  visual_encoder: str = 'openai/clip-vit-large-patch14'
  n_visual_tokens: int = 1
  task: str = 'captioning'
  ret_emb_dim: Optional[int] = 256
  gen_emb_dim: Optional[int] = 256
  text_emb_layers: List[int] = [-1]
  gen_token_idx: List[int] = [0]
  retrieval_token_idx: List[int] = [0]
  text_fc_mode: str = 'gill_mapper'
  ret_text_fc_mode: str = 'linear'
  num_tokens: int = 8
  num_clip_tokens: int = 77


class GILLModel(nn.Module):
  def __init__(self, tokenizer, args: GILLArgs = GILLArgs()):
    super().__init__()
    self.tokenizer = tokenizer
    self.feature_extractor = utils.get_feature_extractor_for_model(args.visual_encoder, train=False)
    self.image_token = self.tokenizer.cls_token_id
    assert len(args.text_emb_layers) == len(set(args.text_emb_layers)), 'text_emb_layers not unique'
    self.args = args
    self.num_tokens = args.num_tokens
    self.num_clip_tokens = args.num_clip_tokens

    opt_version = args.opt_version
    visual_encoder = args.visual_encoder
    n_visual_tokens = args.n_visual_tokens
    print(f"Using {opt_version} for the language model.")
    print(f"Using {visual_encoder} for the visual model with {n_visual_tokens} visual tokens.")

    if 'facebook/opt' in opt_version:
      self.lm = OPTForCausalLM.from_pretrained(opt_version)
    else:
      raise NotImplementedError

    self.opt_version = opt_version

    if self.args.freeze_lm:
      self.lm.eval()
      print("Freezing the LM.")
      for param in self.lm.parameters():
        param.requires_grad = False
    else:
      self.lm.train()

    self.retrieval_token_idx = args.retrieval_token_idx
    self.gen_token_idx = args.gen_token_idx
    self.lm.resize_token_embeddings(len(tokenizer))

    self.input_embeddings = self.lm.get_input_embeddings()

    print("Restoring pretrained weights for the visual model.")
    if 'clip' in visual_encoder:
      self.visual_model = CLIPVisionModel.from_pretrained(visual_encoder)
    else:
      self.visual_model = AutoModel.from_pretrained(visual_encoder)

    if 'clip' in visual_encoder:
      hidden_size = self.visual_model.config.hidden_size
    else:
      raise NotImplementedError

    if self.args.freeze_vm:
      print("Freezing the VM.")
      self.visual_model.eval()
      for param in self.visual_model.parameters():
        param.requires_grad = False
    else:
      self.visual_model.train()

    self.visual_model_name = visual_encoder

    embedding_dim = self.input_embeddings.embedding_dim * self.args.n_visual_tokens
    self.ret_text_hidden_fcs = nn.ModuleList([])
    self.gen_text_hidden_fcs = nn.ModuleList([])

    for layer_idx in self.args.text_emb_layers:
      if (layer_idx == -1 or layer_idx == self.lm.config.num_hidden_layers) and ('bert' not in opt_version):
        if 'opt' in opt_version:  # OPT models
          in_dim = self.lm.config.word_embed_proj_dim
        else:
          raise NotImplementedError

        self.ret_text_hidden_fcs.append(
          layers.TextFcLayer(
            in_dim, self.args.ret_emb_dim,
            num_input_tokens=self.args.num_tokens,
            num_output_tokens=1,
            mode=self.args.ret_text_fc_mode))
        self.gen_text_hidden_fcs.append(
          layers.TextFcLayer(
            in_dim, self.args.gen_emb_dim,
            num_input_tokens=self.args.num_tokens,
            num_output_tokens=self.args.num_clip_tokens,
            mode=self.args.text_fc_mode))
      elif layer_idx < self.lm.config.num_hidden_layers:
        self.ret_text_hidden_fcs.append(
          layers.TextFcLayer(
            self.lm.config.hidden_size, self.args.ret_emb_dim,
            num_input_tokens=self.args.num_tokens,
            num_output_tokens=1,
            mode=self.args.ret_text_fc_mode))
        self.gen_text_hidden_fcs.append(
          layers.TextFcLayer(
            self.lm.config.hidden_size, self.args.gen_emb_dim,
            num_input_tokens=self.args.num_tokens,
            num_output_tokens=self.args.num_clip_tokens,
            mode=self.args.text_fc_mode))
      else:
        raise ValueError(f'Embedding of layer {layer_idx} was requested but model only has {self.lm.config.num_hidden_layers} layers.')

    self.visual_embeddings = nn.Linear(hidden_size, embedding_dim)
    # Retrieval image FC layer.
    self.visual_fc = nn.Linear(hidden_size, self.args.ret_emb_dim)
    self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

  def get_visual_embs(self, pixel_values: torch.FloatTensor, mode: str = 'captioning'):
    if mode not in ['captioning', 'retrieval', 'generation']:
      raise ValueError(f"mode should be one of ['captioning', 'retrieval', 'generation'], got {mode} instead.")

    # Extract visual embeddings from the vision encoder.
    if 'clip' in self.visual_model_name:
      outputs = self.visual_model(pixel_values)
      encoder_outputs = outputs.pooler_output
    else:
      raise NotImplementedError

    # Use the correct fc based on function argument.
    if mode == 'captioning':
      visual_embs = self.visual_embeddings(encoder_outputs)  # (N, D * n_visual_tokens)
      visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], self.args.n_visual_tokens, -1))
    elif mode == 'retrieval':
      visual_embs = self.visual_fc(encoder_outputs)  # (N, ret_emb_dim)
      visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 1, -1))
    elif mode == 'generation':
      visual_embs = torch.zeros((pixel_values.shape[0], 1, 768), device=pixel_values.device)
    else:
      raise NotImplementedError

    return visual_embs

  def train(self, mode=True):
    super(GILLModel, self).train(mode=mode)
    # Overwrite train() to ensure frozen models remain frozen.
    if self.args.freeze_lm:
      self.lm.eval()
    if self.args.freeze_vm:
      self.visual_model.eval()

  def forward(
    self,
    pixel_values: torch.FloatTensor,
    labels: Optional[torch.LongTensor] = None,
    caption_len: Optional[torch.LongTensor] = None,
    mode: str = 'captioning',
    concat_captions: bool = False,
    input_prefix: Optional[str] = None,
  ):
    visual_embs = self.get_visual_embs(pixel_values, mode)

    batch_size, vis_seq_len, _ = visual_embs.shape  # vis_seq_len = n_visual_tokens
    if labels is not None:
      assert labels.shape[0] == batch_size, (visual_embs.shape, labels.shape)

    visual_embs_norm = ((visual_embs ** 2).sum(dim=-1) ** 0.5).mean()

    input_embs = self.input_embeddings(labels)  # (N, T, D)
    input_embs_norm = ((input_embs ** 2).sum(dim=-1) ** 0.5).mean()

    last_embedding_idx = caption_len - 1  # -1 to retrieve the token before the eos token

    if input_prefix is not None:
      prompt_ids = self.tokenizer(input_prefix, add_special_tokens=False, return_tensors="pt").input_ids
      prompt_ids = prompt_ids.to(visual_embs.device)
      prompt_embs = self.input_embeddings(prompt_ids)
      prompt_embs = prompt_embs.repeat(batch_size, 1, 1)
      assert prompt_embs.shape[0] == batch_size, prompt_embs.shape
      assert prompt_embs.shape[2] == input_embs.shape[2], prompt_embs.shape
      assert len(prompt_embs.shape) == 3, prompt_embs.shape

    if mode == 'captioning':
      # Concat to text embeddings.
      condition_seq_len = 0
      if input_prefix is None:
        # Just add visual embeddings.
        input_embs = torch.cat([visual_embs, input_embs], axis=1)
        last_embedding_idx += vis_seq_len
        condition_seq_len += vis_seq_len
        full_labels = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
      else:
        print(f'Adding prefix "{input_prefix}" to captioning.')
        # Add visual and prompt embeddings.
        prefix_embs = torch.cat([visual_embs, prompt_embs], axis=1)
        input_embs = torch.cat([prefix_embs, input_embs], axis=1)

        last_embedding_idx += prefix_embs.shape[1]
        condition_seq_len += prefix_embs.shape[1]
        full_labels = torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100

      # Mask out embedding tokens in the labels.
      full_labels = torch.cat([full_labels, labels], axis=1)

      pad_idx = []
      for label in full_labels:
        for k, token in enumerate(label):
          # Mask out retrieval/gen tokens if they exist.
          if token in [self.tokenizer.pad_token_id] + self.retrieval_token_idx + self.gen_token_idx:
            label[k:] = -100
            pad_idx.append(k)
            break
          if k == len(label) - 1:  # No padding found.
            pad_idx.append(k + 1)
      assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)

      bs, seq_len, embs_dim = input_embs.shape
      if concat_captions:
        print('Concatenating examples for captioning!')
        assert len(input_embs.shape) == 3, input_embs
        assert len(full_labels.shape) == 2, full_labels
        assert batch_size % 2 == 0
        all_concat_input_embs = []
        all_concat_labels = []

        # Rearrange embeddings and labels (and their padding) to concatenate captions.
        for i in range(batch_size // 2):
          first_idx = i * 2
          second_idx = first_idx + 1
          first_emb = input_embs[first_idx, :pad_idx[first_idx], :]
          first_labels = full_labels[first_idx, :pad_idx[first_idx]]
          first_padding = input_embs[first_idx, pad_idx[first_idx]:, :]
          first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:]

          second_emb = input_embs[second_idx, :pad_idx[second_idx], :]
          second_labels = full_labels[second_idx, :pad_idx[second_idx]]
          second_padding = input_embs[second_idx, pad_idx[second_idx]:, :]
          second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:]

          bos_idx = visual_embs.shape[1]
          assert torch.all(first_labels_padding == -100), first_labels_padding
          assert torch.all(second_labels_padding == -100), second_labels_padding
          assert torch.all(second_labels[bos_idx] == self.tokenizer.bos_token_id), (second_labels, bos_idx, self.tokenizer.bos_token_id)
          # Remove BOS token of the second caption.
          second_labels = torch.cat([second_labels[:bos_idx], second_labels[bos_idx + 1:]], axis=0)
          second_emb = torch.cat([second_emb[:bos_idx, :], second_emb[bos_idx + 1:, :]], axis=0)

          concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0)  # (T*2 - 1, D)
          concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0)  # (T*2 - 1,)
          all_concat_input_embs.append(concat_input_embs)
          all_concat_labels.append(concat_labels)

        # Pad to max length.
        input_embs = torch.stack(all_concat_input_embs, axis=0)  # (N/2, T*2 - 1, D)
        full_labels = torch.stack(all_concat_labels, axis=0)  # (N/2, T*2 - 1)
        print("Concatenated full_labels:", full_labels[0, ...])
        assert input_embs.shape == (bs // 2, seq_len * 2 - 1, embs_dim), input_embs.shape
        assert full_labels.shape == (bs // 2, seq_len * 2 - 1), full_labels.shape

      output = self.lm(inputs_embeds=input_embs,
                       labels=full_labels,
                       output_hidden_states=True)
    elif mode in ['retrieval', 'generation']:
      full_labels = torch.clone(labels)
      if input_prefix is not None:
        print(f'Adding prefix "{input_prefix}" to retrieval.')
        # Add prompt embeddings.
        prefix_embs = prompt_embs
        input_embs = torch.cat([prefix_embs, input_embs], axis=1)

        last_embedding_idx += prefix_embs.shape[1]
        full_labels = torch.cat([
          torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(labels.device) - 100,
          full_labels
        ], axis=1)

      pad_idx = []
      for label in full_labels:
        for k, token in enumerate(label):
          if (token == self.tokenizer.pad_token_id):
            label[k:] = -100
            pad_idx.append(k)
            break
          if k == len(label) - 1:  # No padding found.
            pad_idx.append(k + 1)
      assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)

      bs, seq_len, embs_dim = input_embs.shape

      # Concatenate examples for captioning, if specified.
      if concat_captions:
        print(f'Concatenating examples for {mode}!')
        assert len(input_embs.shape) == 3, input_embs
        assert len(full_labels.shape) == 2, full_labels
        assert batch_size % 2 == 0
        all_concat_input_embs = []
        all_concat_labels = []
        all_last_embedding_idx = []

        # Rearrange embeddings and labels (and their padding) to concatenate captions.
        for i in range(batch_size // 2):
          first_idx = i * 2
          second_idx = first_idx + 1
          first_emb = input_embs[first_idx, :pad_idx[first_idx], :]
          first_labels = full_labels[first_idx, :pad_idx[first_idx]]
          first_padding = input_embs[first_idx, pad_idx[first_idx]:, :]
          first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:]

          second_emb = input_embs[second_idx, :pad_idx[second_idx], :]
          second_labels = full_labels[second_idx, :pad_idx[second_idx]]
          second_padding = input_embs[second_idx, pad_idx[second_idx]:, :]
          second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:]

          bos_idx = 0
          assert torch.all(first_labels_padding == -100), first_labels_padding
          assert torch.all(second_labels_padding == -100), second_labels_padding
          assert torch.all(second_labels[bos_idx] == self.tokenizer.bos_token_id), (second_labels, bos_idx, self.tokenizer.bos_token_id)
          # Remove BOS token of second caption.
          second_labels = second_labels[bos_idx + 1:]
          second_emb = second_emb[bos_idx + 1:, :]
          last_embedding_idx[second_idx] = last_embedding_idx[second_idx] - 1

          concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0)  # (T*2 - 1, D)
          concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0)  # (T*2 - 1,)
          all_concat_input_embs.append(concat_input_embs)
          all_concat_labels.append(concat_labels)

          all_last_embedding_idx.append((last_embedding_idx[first_idx], first_emb.shape[0] + last_embedding_idx[second_idx]))

          if mode == 'retrieval':
            assert concat_labels[all_last_embedding_idx[-1][0]] in self.retrieval_token_idx, (concat_labels, all_last_embedding_idx[-1][0])
            assert concat_labels[all_last_embedding_idx[-1][1]] in self.retrieval_token_idx, (concat_labels, all_last_embedding_idx[-1][1])
          elif mode == 'generation':
            # Check that the last n tokens are GEN tokens.
            for gen_i in range(len(self.gen_token_idx)):
              assert concat_labels[all_last_embedding_idx[-1][0]-gen_i] == self.gen_token_idx[-gen_i-1], (concat_labels, all_last_embedding_idx[-1][0]-gen_i, self.gen_token_idx[-gen_i-1])
              assert concat_labels[all_last_embedding_idx[-1][1]-gen_i] == self.gen_token_idx[-gen_i-1], (concat_labels, all_last_embedding_idx[-1][1]-gen_i, self.gen_token_idx[-gen_i-1])

        # Pad to max length.
        input_embs = torch.stack(all_concat_input_embs, axis=0)  # (N/2, T*2 - 1, D)
        full_labels = torch.stack(all_concat_labels, axis=0)  # (N/2, T*2 - 1)
        assert input_embs.shape == (bs // 2, seq_len * 2 - 1, embs_dim), input_embs.shape
        assert full_labels.shape == (bs // 2, seq_len * 2 - 1), full_labels.shape

      # Update labels to pad non-first tokens.
      for label in full_labels:
        for k, token in enumerate(label):
          if (token == self.tokenizer.pad_token_id) or (token in (self.retrieval_token_idx[1:] + self.gen_token_idx[1:])):
            label[k:] = -100
            break

      output = self.lm(inputs_embeds=input_embs,
                       labels=full_labels,
                       output_hidden_states=True)
    else:
      raise NotImplementedError

    last_embedding = None
    last_output_logit = None
    hidden_states = []
    llm_hidden_states = []

    if mode in ['retrieval', 'generation']:
      num_tokens = self.num_tokens
      if mode == 'retrieval':
        text_hidden_fcs = self.ret_text_hidden_fcs
      else:
        text_hidden_fcs = self.gen_text_hidden_fcs

      # Concatenate captions for retrieval / generation, if specified.
      if not concat_captions:
        for idx, fc_layer in zip(self.args.text_emb_layers, text_hidden_fcs):
          input_hidden_state = torch.stack([output.hidden_states[idx][i, last_embedding_idx[i]-num_tokens+1:last_embedding_idx[i]+1, :] for i in range(batch_size)], axis=0)
          input_embedding = torch.stack([input_embs[i, last_embedding_idx[i]-num_tokens+1:last_embedding_idx[i]+1, :] for i in range(batch_size)], axis=0)
          llm_hidden_states.append(input_hidden_state)
          hidden_states.append(fc_layer(input_hidden_state, input_embedding))  # (N, seq_len, 2048)
      else:
        for idx, fc_layer in zip(self.args.text_emb_layers, text_hidden_fcs):
          all_last_embedding = []
          all_input_embedding = []
          all_last_output_logit = []
          for i in range(batch_size // 2):
            first_last_embedding_idx, second_last_embedding_idx = all_last_embedding_idx[i]
            first_last_embedding = output.hidden_states[idx][i, first_last_embedding_idx-num_tokens+1:first_last_embedding_idx+1, :]  # (num_tokens, D)
            second_last_embedding = output.hidden_states[idx][i, second_last_embedding_idx-num_tokens+1:second_last_embedding_idx+1, :]  # (num_tokens, D)
            all_last_embedding.append(first_last_embedding)
            all_last_embedding.append(second_last_embedding)

            first_input_embs = input_embs[i, first_last_embedding_idx-num_tokens+1:first_last_embedding_idx+1, :]  # (num_tokens, D)
            second_input_embs = input_embs[i, second_last_embedding_idx-num_tokens+1:second_last_embedding_idx+1, :]  # (num_tokens, D)
            all_input_embedding.append(first_input_embs)
            all_input_embedding.append(second_input_embs)

            first_last_output_logit = output.logits[i, first_last_embedding_idx - 1, :]  # (vocab_size,)
            second_last_output_logit = output.logits[i, second_last_embedding_idx - 1, :]  # (vocab_size,)
            all_last_output_logit.append(first_last_output_logit)
            all_last_output_logit.append(second_last_output_logit)

          last_embedding = torch.stack(all_last_embedding, axis=0)
          input_embedding = torch.stack(all_input_embedding, axis=0)
          last_output_logit = torch.stack(all_last_output_logit, axis=0)
          llm_hidden_states.append(last_embedding)
          hidden_states.append(fc_layer(last_embedding, input_embedding))  # (N, seq_len, 2048)

      if not concat_captions:
        # Add hidden states together.
        last_embedding = torch.stack(hidden_states, dim=-1).sum(dim=-1)  # (N, T, D)
        last_output_logit = torch.stack([output.logits[i, last_embedding_idx[i] - 1, :] for i in range(batch_size)], axis=0)  # (N, vocab_size)
      else:
        # Add hidden states together.
        last_embedding = torch.stack(hidden_states, dim=-1).sum(dim=-1)

    # Compute retrieval loss.
    if mode == 'retrieval':
      assert visual_embs.shape[1] == 1, visual_embs.shape
      assert last_embedding.shape[1] == 1, last_embedding.shape
      visual_embs = visual_embs[:, 0, :]
      visual_embs = visual_embs / visual_embs.norm(dim=1, keepdim=True)
      last_embedding = last_embedding[:, 0, :]
      last_embedding = last_embedding / last_embedding.norm(dim=1, keepdim=True)

      # Cosine similarity as logits.
      logit_scale = self.logit_scale.exp()
      visual_embs = logit_scale * visual_embs
    elif mode == 'captioning':
      pass
    else:
      raise NotImplementedError

    return output, full_labels, last_embedding, last_output_logit, visual_embs, visual_embs_norm, input_embs_norm, llm_hidden_states

  def generate(self, embeddings: torch.FloatTensor, max_len: int = 32,
               temperature: float = 0.0, top_p: float = 1.0, min_word_tokens: int = 0,
               ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
               filter_value: float = -float('Inf')):
    """Runs greedy (or sampled) decoding and returns generated captions.

    Args:
      min_word_tokens: Minimum number of words to generate before allowing a [IMG] output.
      filter_value: Value to assign to tokens that should never be generated.
    Outputs:
      out: (N, T) int32 sequence of output tokens.
      output_embeddings: (N, T, 256) sequence of text output embeddings.
    """
    self.lm.eval()

    with torch.no_grad():  # no tracking history
      # init output with image tokens
      out = None
      output_embeddings = []
      output_logits = []

      for i in range(max_len):
        output = self.lm(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True)

        for idx in self.args.text_emb_layers:
          output_embeddings.append(output.hidden_states[idx])

        logits = output.logits[:, -1, :]  # (N, vocab_size)
        if top_p == 1.0:
          logits = logits.cpu()
        output_logits.append(logits)

        # Prevent the model from generating the [IMG1..n] tokens.
        logits[:, self.retrieval_token_idx[1:]] = filter_value
        logits[:, self.gen_token_idx[1:]] = filter_value

        if (self.retrieval_token_idx or self.gen_token_idx) and self.retrieval_token_idx[0] != -1 and self.gen_token_idx[0] != -1:
          if i < min_word_tokens:
            # Eliminate probability of generating [IMG] if this is earlier than min_word_tokens.
            logits[:, self.retrieval_token_idx] = filter_value
            logits[:, self.gen_token_idx] = filter_value
          else:
            # Multiply by scaling factor.
            if ret_scale_factor > 1:
              logits[:, self.retrieval_token_idx[0]] = logits[:, self.retrieval_token_idx[0]].abs() * ret_scale_factor
            if gen_scale_factor > 1:
              logits[:, self.gen_token_idx[0]] = logits[:, self.gen_token_idx[0]].abs() * gen_scale_factor

        if temperature == 0.0:
          if top_p != 1.0:
            raise ValueError('top_p cannot be set if temperature is 0 (greedy decoding).')
          next_token = torch.argmax(logits, keepdim=True, dim=-1)  # (N, 1)
        else:
          logits = logits / temperature

          # Apply top-p filtering.
          if top_p < 1.0:
            assert top_p > 0, f'top_p should be above 0, got {top_p} instead.'
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)  # (N, D) and (N, D)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)  # (N, D)

            # Remove tokens with cumulative probability above the threshold.
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            for j in range(sorted_indices.shape[0]):
              indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]]
              logits[j, indices_to_remove] = filter_value

          token_weights = logits.exp()  # (N, vocab_size)
          next_token = torch.multinomial(token_weights, 1)  # (N, 1)

        # Force generation of the remaining [IMG] tokens if [IMG0] is generated.
        if next_token.shape[0] == 1 and next_token.item() == self.retrieval_token_idx[0]:
          assert self.retrieval_token_idx == self.gen_token_idx, (self.retrieval_token_idx, self.gen_token_idx)
          next_token = torch.tensor(self.retrieval_token_idx)[None, :].long().to(embeddings.device)  # (1, num_tokens)
        else:
          next_token = next_token.long().to(embeddings.device)

        if out is not None:
          out = torch.cat([out, next_token], dim=-1)
        else:
          out = next_token

        next_embedding = self.input_embeddings(next_token)
        embeddings = torch.cat([embeddings, next_embedding], dim=1)

    return out, output_embeddings, output_logits


class GILL(nn.Module):
  def __init__(self, tokenizer, model_args: Optional[GILLArgs] = None,
               path_array: Optional[List[str]] = None, emb_matrix: Optional[torch.Tensor] = None,
               load_sd: bool = False, num_gen_images: int = 1,
               decision_model_path: Optional[str] = None):
    super().__init__()
    self.model = GILLModel(tokenizer, model_args)
    self.path_array = path_array
    self.emb_matrix = emb_matrix
    self.load_sd = load_sd
    self.num_gen_images = num_gen_images
    self.idx2dec = {0: 'gen', 1: 'ret', 2: 'same'}
    self.decision_model = None

    # Load the Stable Diffusion model.
    if load_sd:
      model_id = "runwayml/stable-diffusion-v1-5"
      if torch.cuda.is_available():
        self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
      else:
        self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)

    if decision_model_path is not None:
      print('Loading decision model...')
      self.decision_model = nn.Sequential(*[
        nn.Dropout(0.5),
        nn.Linear(4097, 2),
      ])
      if torch.cuda.is_available():
        mlp_checkpoint = torch.load(decision_model_path)
      else:
        mlp_checkpoint = torch.load(decision_model_path, map_location=torch.device('cpu'))
      self.decision_model.load_state_dict(mlp_checkpoint['state_dict'], strict=True)
      self.decision_model.eval()

  def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: Optional[Tensor] = None,
               generate: bool = False, num_words: int = 32, temperature: float = 1.0, top_p: float = 1.0,
               ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
               min_word_tokens: int = 0, mode: str = 'captioning', concat_captions: bool = False,
               input_prefix: Optional[str] = None) -> Tensor:
    if generate:
      return self.model.generate(images, num_words, temperature=temperature, top_p=top_p,
                                 min_word_tokens=min_word_tokens, ret_scale_factor=ret_scale_factor,
                                 gen_scale_factor=gen_scale_factor)
    else:
      output = self.model(
        pixel_values=images,
        labels=tgt_tokens,
        caption_len=caption_len,
        mode=mode,
        concat_captions=concat_captions,
        input_prefix=input_prefix)
      return output

  def generate_for_images_and_texts(
    self, prompts: List, num_words: int = 0, min_word_tokens: int = 0,
    ret_scale_factor: float = 1.0, gen_scale_factor: float = 1.0,
    top_p: float = 1.0, temperature: float = 0.0, max_num_rets: int = 1,
    generator=None, always_add_bos: bool = False, guidance_scale: float = 7.5,
    num_inference_steps: int = 50):
    """
    Encodes prompts into embeddings, and generates text and image outputs accordingly.

    Args:
      prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
      num_words: Maximum number of words to generate for (must be > 0; running only the forward
        pass with num_words = 0 is not implemented).
      min_word_tokens: Minimum number of actual words before generating an image.
      ret_scale_factor: Proportion to scale [IMG] token logits by. A higher value may increase the
        probability of the model generating [IMG] outputs.
      top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to
        top_p or higher are kept for generation.
      temperature: Used to modulate logit distribution.
      max_num_rets: Maximum number of images to return in one generation pass.
    Returns:
      return_outputs: List consisting of either str or List[PIL.Image.Image] objects, representing
        image-text interleaved model outputs.
    """
    input_embs = []
    input_ids = []
    add_bos = True

    with torch.no_grad():
      for p in prompts:
        if type(p) == Image.Image:
          # Encode as image.
          pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
          pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
          pixel_values = pixel_values[None, ...]

          visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning')  # (1, n_visual_tokens, D)
          input_embs.append(visual_embs)
        elif type(p) == str:
          text_ids = self.model.tokenizer(p, add_special_tokens=add_bos, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
          # Only add once unless the flag is set.
          if not always_add_bos:
            add_bos = False

          text_embs = self.model.input_embeddings(text_ids)  # (1, T, D)
          input_embs.append(text_embs)
          input_ids.append(text_ids)
        else:
          raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
      input_embs = torch.cat(input_embs, dim=1)
      input_ids = torch.cat(input_ids, dim=1)

      if num_words == 0:
        raise NotImplementedError('Generation not implemented for num_words=0.')
      elif num_words > 0:
        generated_ids, generated_embeddings, _ = self.model.generate(
          input_embs, num_words, min_word_tokens=min_word_tokens, temperature=temperature, top_p=top_p,
          ret_scale_factor=ret_scale_factor, gen_scale_factor=gen_scale_factor)
        embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]

        # Truncate to newline.
        newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
        trunc_idx = 0
        for j in range(generated_ids.shape[1]):
          if generated_ids[0, j] == newline_token_id:
            trunc_idx = j
            break
        if trunc_idx > 0:
          generated_ids = generated_ids[:, :trunc_idx]
          embeddings = embeddings[:, :trunc_idx]
      else:
        raise ValueError

      # Save outputs as an interleaved list.
      return_outputs = []
      # Find up to max_num_rets [IMG] tokens, and their corresponding scores.
      all_ret_idx = [i for i, x in enumerate(generated_ids[0, :] == self.model.retrieval_token_idx[0]) if x][:max_num_rets]
      seen_image_idx = []  # Avoid showing the same image multiple times.

      last_ret_idx = 0
      if len(all_ret_idx) == 0:
        # No [IMG] tokens.
        caption = self.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return_outputs.append(utils.truncate_caption(caption))
      else:
        for ret_idx in all_ret_idx:
          assert generated_ids[0, ret_idx:ret_idx+self.model.num_tokens].cpu().detach().numpy().tolist() == self.model.retrieval_token_idx, (generated_ids[0, ret_idx:ret_idx+self.model.num_tokens], self.model.retrieval_token_idx)
          raw_emb = embeddings[:, ret_idx:ret_idx+self.model.num_tokens, :]  # (1, 8, 4096)
          assert len(self.model.args.text_emb_layers) == 1

          image_outputs = {
            'gen': [],
            'ret': [],
            'decision': None,
          }

          if self.emb_matrix is not None:
            # Produce retrieval embedding.
            ret_emb = self.model.ret_text_hidden_fcs[0](raw_emb, None)[:, 0, :]  # (1, 256)
            ret_emb = ret_emb / ret_emb.norm(dim=-1, keepdim=True)
            ret_emb = ret_emb.type(self.emb_matrix.dtype)  # (1, 256)
            scores = self.emb_matrix @ ret_emb.T

            # Downweight seen images.
            for seen_idx in seen_image_idx:
              scores[seen_idx, :] -= 1000

            # Get the top 3 images for each image.
            _, top_image_idx = scores.squeeze().topk(3)
            for img_idx in top_image_idx:
              # Find the first image that does not error out.
              try:
                seen_image_idx.append(img_idx)
                img = utils.get_image_from_url(self.path_array[img_idx])
                image_outputs['ret'].append((img, 'ret', scores[img_idx].item()))
                if len(image_outputs) == max_num_rets:
                  break
              except (UnidentifiedImageError, ConnectionError):
                pass

            # Make decision with MLP.
            if self.decision_model is not None:
              decision_emb = raw_emb[:, 0, :]  # (1, 4096)
              assert decision_emb.shape[1] == 4096, decision_emb.shape
              max_ret_score = scores.max().reshape((1, 1)).clone().detach().to(device=decision_emb.device, dtype=decision_emb.dtype)
              decision_logits = self.decision_model(torch.cat([decision_emb, max_ret_score], dim=-1))
              probs = decision_logits.softmax(dim=-1).cpu().float().numpy().tolist()
              image_outputs['decision'] = [self.idx2dec[decision_logits.argmax().item()]] + probs
          else:
            # If no embedding matrix is provided, generate instead.
            image_outputs['decision'] = ['gen', [0, 1]]

          # Produce generation embedding.
          gen_prefix = ' '.join([f'[IMG{i}]' for i in range(self.model.args.num_tokens)])
          gen_prefix_ids = self.model.tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
          gen_prefix_embs = self.model.input_embeddings(gen_prefix_ids)  # (1, T, D)
          gen_emb = self.model.gen_text_hidden_fcs[0](raw_emb, gen_prefix_embs)  # (1, 77, 768)

          if gen_emb.shape[1] != 77:
            print(f"Padding {gen_emb.shape} with zeros")
            bs = gen_emb.shape[0]
            clip_emb = 768
            gen_emb = gen_emb.reshape(bs, -1, clip_emb)  # (bs, T, 768)
            seq_len = gen_emb.shape[1]
            gen_emb = torch.cat([gen_emb, torch.zeros((bs, 77 - seq_len, clip_emb), device=gen_emb.device, dtype=gen_emb.dtype)], dim=1)
            print('Padded to', gen_emb.shape)

          gen_emb = gen_emb.repeat(self.num_gen_images, 1, 1)  # (self.num_gen_images, 77, 768)

          # Only generate if we are showing a generated image.
          if self.load_sd and image_outputs['decision'][0] == 'gen':
            # If num_gen_images > 8, split into multiple batches (for GPU memory reasons).
            gen_max_bs = 8
            gen_images = []
            for i in range(0, self.num_gen_images, gen_max_bs):
              gen_images.extend(
                self.sd_pipe(prompt_embeds=gen_emb[i:i+gen_max_bs], generator=generator,
                             guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images)

            all_gen_pixels = []
            for img in gen_images:
              pixel_values = utils.get_pixel_values_for_model(
                self.model.feature_extractor, img.resize((224, 224)).convert('RGB'))
              pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
              all_gen_pixels.append(pixel_values)

            if self.emb_matrix is not None:
              all_gen_pixels = torch.stack(all_gen_pixels, dim=0)
              gen_visual_embs = self.model.get_visual_embs(all_gen_pixels, mode='retrieval')  # (1, D)
              gen_visual_embs = gen_visual_embs / gen_visual_embs.norm(dim=-1, keepdim=True)
              gen_visual_embs = gen_visual_embs.type(self.emb_matrix.dtype)
              gen_rank_scores = (gen_visual_embs @ ret_emb.T).squeeze()
              sorted_score_idx = torch.argsort(-gen_rank_scores)

              # Rank images by retrieval score.
              if self.num_gen_images > 1:
                image_outputs['gen'] = [(gen_images[idx], gen_rank_scores[idx].item()) for idx in sorted_score_idx]
              else:
                image_outputs['gen'] = [(gen_images[0], gen_rank_scores.item())]
            else:
              image_outputs['gen'] = [(gen_images[0], 0)]
          else:
            image_outputs['gen'] = [gen_emb]

          caption = self.model.tokenizer.batch_decode(generated_ids[:, last_ret_idx:ret_idx], skip_special_tokens=True)[0]
          last_ret_idx = ret_idx + 1
          return_outputs.append(utils.truncate_caption(caption) + f' {gen_prefix}')
          return_outputs.append(image_outputs)

    return return_outputs

  def get_log_likelihood_scores(
    self, prompts: List):
    """
    Outputs the log likelihood of the given interleaved prompts.

    Args:
      prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
    Returns:
      Log likelihood score of the prompt sequence.
    """
    input_embs = []
    input_ids = []
    add_bos = True

    for p in prompts:
      if type(p) == Image.Image:
        # Encode as image.
        pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
        pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
        pixel_values = pixel_values[None, ...]
        visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning')  # (1, n_visual_tokens, D)
        input_embs.append(visual_embs)
        id_ = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
        input_ids.append(id_)
      elif type(p) == str:
        text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
        if not add_bos:
          # Remove the BOS tag.
          text_ids = text_ids[:, 1:]
        else:
          # Only add once.
          add_bos = False

        text_embs = self.model.input_embeddings(text_ids)  # (1, T, D)
        input_embs.append(text_embs)
        input_ids.append(text_ids)
      else:
        raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
    input_embs = torch.cat(input_embs, dim=1)
    input_ids = torch.cat(input_ids, dim=1)

    outputs = self.model.lm(inputs_embeds=input_embs, labels=input_ids, use_cache=False, output_hidden_states=True)
    return -outputs.loss.item()


def load_gill(embeddings_dir: str, model_args_path: str, model_ckpt_path: str, decision_model_path: str) -> GILL:
  embs_paths = glob.glob(os.path.join(embeddings_dir, 'cc3m*.npy'))

  if not os.path.exists(model_args_path):
    raise ValueError(f'model_args.json does not exist at {model_args_path}.')
  if not os.path.exists(model_ckpt_path):
    raise ValueError(f'pretrained_ckpt.pth.tar does not exist at {model_ckpt_path}.')
  if len(embs_paths) == 0:
    print(f'cc3m*.npy files do not exist in {embeddings_dir}. Running the model without retrieval.')
    path_array, emb_matrix = None, None
  else:
    # Load embeddings.
    # Construct embedding matrix for nearest neighbor lookup.
    path_array = []
    emb_matrix = []

    # These were precomputed for all CC3M images with `model.get_visual_embs(image, mode='retrieval')`.
    for p in embs_paths:
      with open(p, 'rb') as wf:
        train_embs_data = pkl.load(wf)
        path_array.extend(train_embs_data['paths'])
        emb_matrix.extend(train_embs_data['embeddings'])
    emb_matrix = np.stack(emb_matrix, axis=0)

    # Number of paths should be equal to number of embeddings.
    assert len(path_array) == emb_matrix.shape[0], (len(path_array), emb_matrix.shape)

  with open(model_args_path, 'r') as f:
    model_kwargs = json.load(f)

  # Initialize tokenizer.
  tokenizer = AutoTokenizer.from_pretrained(model_kwargs['opt_version'], use_fast=False)
  if tokenizer.pad_token is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
  # Add an image token for loss masking (and visualization) purposes.
  tokenizer.add_special_tokens({"cls_token": "<|image|>"})  # add special image token to tokenizer

  # Add [IMG] tokens to the vocabulary.
  model_kwargs['retrieval_token_idx'] = []
  for i in range(model_kwargs['num_tokens']):
    print(f'Adding [IMG{i}] token to vocabulary.')
    print(f'Before adding new token, tokenizer("[IMG{i}]") =', tokenizer(f'[IMG{i}]', add_special_tokens=False))
    num_added_tokens = tokenizer.add_tokens(f'[IMG{i}]')
    print(f'After adding {num_added_tokens} new tokens, tokenizer("[IMG{i}]") =', tokenizer(f'[IMG{i}]', add_special_tokens=False))
    ret_token_idx = tokenizer(f'[IMG{i}]', add_special_tokens=False).input_ids
    assert len(ret_token_idx) == 1, ret_token_idx
    model_kwargs['retrieval_token_idx'].append(ret_token_idx[0])
  # Use the same RET tokens for generation.
  model_kwargs['gen_token_idx'] = model_kwargs['retrieval_token_idx']

  debug = False
  if debug:
    model_kwargs['opt_version'] = 'facebook/opt-125m'
    model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
    decision_model_path = None

  args = namedtuple('args', model_kwargs)(**model_kwargs)

  # Initialize model for inference.
  model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,
               load_sd=not debug, num_gen_images=1, decision_model_path=decision_model_path)
  model = model.eval()
  if torch.cuda.is_available():
    model = model.bfloat16().cuda()

  if not debug:
    # Load pretrained linear mappings and [IMG] embeddings.
    checkpoint = torch.load(model_ckpt_path)
    state_dict = {}
    # This is needed if we train with DDP.
    for k, v in checkpoint['state_dict'].items():
      state_dict[k.replace('module.', '')] = v
    img_token_embeddings = state_dict['model.input_embeddings.weight'].cpu().detach()
    del state_dict['model.input_embeddings.weight']

    model.load_state_dict(state_dict, strict=False)
    # Copy over the embeddings of the [IMG] tokens (while loading the others from the pretrained LLM).
    with torch.no_grad():
      if 'share_ret_gen' in model_kwargs:
        assert model_kwargs['share_ret_gen'], 'Model loading only supports share_ret_gen=True for now.'
      model.model.input_embeddings.weight[-model_kwargs['num_tokens']:, :].copy_(img_token_embeddings)

  if len(embs_paths) > 0:
    logit_scale = model.model.logit_scale.exp()
    emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
    emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)
    emb_matrix = logit_scale * emb_matrix
    model.emb_matrix = emb_matrix

  return model
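
# ---------------------------------------------------------------------------
# Minimal usage sketch. This illustrates how load_gill() and
# generate_for_images_and_texts() fit together; it is not part of the model
# code. The checkpoint directory, file names under it other than
# model_args.json / pretrained_ckpt.pth.tar (which load_gill checks for), and
# the example image path are assumptions -- adjust them to your own layout.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
  ckpt_dir = 'checkpoints/gill_opt'  # hypothetical checkpoint directory
  model = load_gill(
    embeddings_dir=ckpt_dir,  # directory that may contain cc3m*.npy retrieval embeddings
    model_args_path=os.path.join(ckpt_dir, 'model_args.json'),
    model_ckpt_path=os.path.join(ckpt_dir, 'pretrained_ckpt.pth.tar'),
    decision_model_path=os.path.join(ckpt_dir, 'decision_model.pth.tar'))  # hypothetical file name

  # Interleave an image and a text prompt; the model returns an interleaved
  # list of strings and image_outputs dicts (see generate_for_images_and_texts).
  prompts = [
    Image.open('example.jpg').convert('RGB'),  # hypothetical input image
    'Show me a similar image.',
  ]
  outputs = model.generate_for_images_and_texts(prompts, num_words=32, ret_scale_factor=1.3)
  for item in outputs:
    print(item if isinstance(item, str) else item.keys())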