cogen / joint_inference.py
momergul
Update
14eba99
# File: model_utils
# -----------------
# Contain utilities for models, such as loading and saving models
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from transformers import GenerationConfig
from dataset import process_idefics_listener_generation_input
import pdb
def filter_targets(logits, index_to_token):
target_logits = logits[:, index_to_token]
return target_logits
class IdeficsJointInferenceModel(nn.Module):
def __init__(self, listener_lambda, speaker_lambda,
model=None, listener=None, speaker=None):
super().__init__()
self.l_lambda = listener_lambda
self.s_lambda = speaker_lambda
self.has_shared_parameters = model is not None
if self.has_shared_parameters:
self.model = model
else:
self.listener = listener
self.speaker = speaker
def forward(self, inf_mode, arguments):
if inf_mode == "joint_comprehension":
return self.comprehension_side(arguments)
elif inf_mode == "joint_reranking":
return self.reranking_side(arguments)
elif inf_mode == "comprehension":
return self.split_comprehension_forward(arguments)
elif inf_mode == "split_reranking":
return self.split_reranking_forward(arguments)
elif inf_mode == "generation":
return self.split_generation_forward(arguments)
def get_listener(self):
if self.has_shared_parameters:
return self.model
else:
return self.listener
def get_speaker(self):
if self.has_shared_parameters:
return self.model
else:
return self.speaker
def get_image_embeddings(self, pixel_values, pixel_attention_mask, model):
'''
Get image embeddings to avoid repeated computation for images during joint inference.
Adapted from the IDEFICS-2 source code.
'''
# Get the model
model = self.get_listener() if model == "listener" else self.get_speaker()
if len(pixel_attention_mask.shape) == 5:
pixel_attention_mask = pixel_attention_mask[:, 0].contiguous()
# Assume images of form: BxCxcnlxHxW
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.to(dtype=model.dtype) # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Remove padding images from the mask/pP p
pixel_attention_mask = pixel_attention_mask.view(
batch_size * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
patch_size = model.model.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = model.model.model.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
).last_hidden_state
# Modality projection & resampling
image_hidden_states = model.model.model.connector(
image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
)
return image_hidden_states
def split_comprehension_side(self, input_tokens, attn_mask, images, image_attn_mask, index_to_token):
'''
Redundant with split_comprehension_forward except for the final computation.
Used during deployment in ray_models.py.
'''
listener = self.get_listener()
all_logits = listener(
input_ids=input_tokens,
attention_mask=attn_mask,
pixel_values=images,
pixel_attention_mask=image_attn_mask
)['logits']
target_logits = filter_targets(all_logits[:, -1], index_to_token)
listener_log_probs = F.log_softmax(target_logits, dim=1)
return listener_log_probs
def split_comprehension_forward(self, arguments):
input_tokens, attn_mask, images, image_attn_mask = arguments
listener = self.get_listener()
all_logits = listener(
input_ids=input_tokens,
attention_mask=attn_mask,
pixel_values=images,
pixel_attention_mask=image_attn_mask
)['logits']
return all_logits
def split_generation_forward(self, arguments):
input_tokens, attn_mask, images, image_attn_mask = arguments
speaker = self.get_speaker()
all_logits = speaker(
input_ids=input_tokens,
attention_mask=attn_mask,
pixel_values=images,
pixel_attention_mask=image_attn_mask
)['logits']
return all_logits
def split_reranking_forward(self, arguments):
images, input_tokens, attn_mask, image_attn_mask, target_tokens, target_mask = arguments
# Get the image embeddings
image_embeddings = self.get_image_embeddings(images, image_attn_mask, "speaker")
embed_shape = image_embeddings.shape
B, mult = input_tokens.shape[:2]
C = images.shape[1]
image_embeddings = image_embeddings.view(B, C, *embed_shape[1:])
image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:])
annotation_mask = torch.zeros(B, mult, device=image_embeddings.device).bool()
_, speaker_log_probs = self.reranking_speaker_side(image_embeddings, input_tokens, attn_mask,
image_attn_mask, target_tokens, target_mask,
annotation_mask)
return speaker_log_probs
def comprehension_side(self, arguments):
images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token, \
s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label = arguments
if self.has_shared_parameters:
image_embeddings = self.get_image_embeddings(images, l_image_attn_mask, "listener")
listener_log_probs = self.comprehension_listener_side(
image_embeddings, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token
) # TODO
speaker_log_probs = self.comprehension_speaker_side(
image_embeddings, s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label
)
else:
# Deprecated and not used in experiments
listener_embeddings = self.get_image_embeddings(images, l_image_attn_mask, "listener")
listener_log_probs = self.comprehension_listener_side(
listener_embeddings, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token
)
speaker_embeddings = self.get_image_embeddings(images, "speaker")
speaker_log_probs = self.comprehension_speaker_side(
speaker_embeddings, s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label
)
joint_log_probs = self.comprehension_reranking(listener_log_probs, speaker_log_probs)
return listener_log_probs, speaker_log_probs, joint_log_probs
def comprehension_listener_side(self, image_encoder_embeddings, input_tokens, attn_mask, image_attn_mask,
index_to_token):
listener = self.get_listener()
all_logits = listener(
input_ids=input_tokens,
attention_mask=attn_mask,
image_hidden_states=image_encoder_embeddings,
pixel_attention_mask=image_attn_mask
)['logits']
target_logits = filter_targets(all_logits[:, -1], index_to_token) # BxC
listener_log_probs = F.log_softmax(target_logits, dim=1)
return listener_log_probs
def comprehension_speaker_side(self, image_encoder_embeddings, input_tokens, attn_mask, image_attn_mask,
target_mask, target_label):
# Expand embeddings
B, C = input_tokens.shape[:2]
embed_shape = image_encoder_embeddings.shape
image_encoder_embeddings = image_encoder_embeddings.view(B, C, *embed_shape[1:])
image_encoder_embeddings = image_encoder_embeddings.unsqueeze(1).repeat(1, C, 1, 1, 1).view(-1, *embed_shape[1:])
input_tokens = input_tokens.view(B*C, -1)
attn_mask = attn_mask.view(B*C, -1)
# Forward pass
speaker = self.get_speaker()
all_logits = speaker(
input_ids=input_tokens,
attention_mask=attn_mask,
image_hidden_states=image_encoder_embeddings,
)['logits']
# Get tokenwise probabilities
all_log_probs = F.log_softmax(all_logits, dim=2)
target_label = target_label.view(B*C, -1).unsqueeze(2)
target_mask = target_mask.view(B*C, -1)
token_log_probs = torch.gather(all_log_probs, 2, target_label).squeeze(2) # BCxT
# Compute the log probabilities
token_log_probs = token_log_probs * target_mask
utterance_log_probs = torch.sum(token_log_probs, dim=1).view(B, C)
return utterance_log_probs
def comprehension_reranking(self, listener_log_probs, speaker_log_probs):
rerank_weights = self.l_lambda * listener_log_probs + (1 - self.l_lambda) * speaker_log_probs
rerank_denominator = torch.logsumexp(rerank_weights, dim=1).unsqueeze(1)
rerank_log_distribution = rerank_weights - rerank_denominator
return rerank_log_distribution
def reranking_side(self, arguments):
images, label, s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_tokens, s_target_mask, \
l_input_tokens, l_attn_mask, l_image_attn_mask, \
index_to_token, annotation_mask = arguments
# Repeat image embeddings according to number of distractors
if self.has_shared_parameters:
image_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "speaker")
embed_shape = image_embeddings.shape
B, mult = s_input_tokens.shape[:2]
C = images.shape[1]
image_embeddings = image_embeddings.view(B, C, *embed_shape[1:])
image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:])
speaker_logits, speaker_log_probs = self.reranking_speaker_side(image_embeddings, s_input_tokens,
s_attn_mask, s_image_attn_mask,
s_target_tokens, s_target_mask,
annotation_mask)
listener_log_probs = self.reranking_listener_side(image_embeddings, l_input_tokens, l_attn_mask,
l_image_attn_mask, label, index_to_token,
annotation_mask)
else:
# Deprecated and no longer used in main experiments
image_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "speaker")
embed_shape = image_embeddings.shape
B, mult = s_input_tokens.shape[:2]
C = images.shape[1]
image_embeddings = image_embeddings.view(B, C, *embed_shape[1:])
image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:])
speaker_logits, speaker_log_probs = self.reranking_speaker_side(image_embeddings, s_input_tokens,
s_attn_mask, s_image_attn_mask,
s_target_tokens, s_target_mask,
annotation_mask)
image_embeddings = self.get_image_embeddings(images, l_image_attn_mask, "listener")
embed_shape = image_embeddings.shape
B, mult = s_input_tokens.shape[:2]
C = images.shape[1]
image_embeddings = image_embeddings.view(B, C, *embed_shape[1:])
image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:])
listener_log_probs = self.reranking_listener_side(image_embeddings, l_input_tokens, l_attn_mask,
l_image_attn_mask, label, index_to_token, annotation_mask)
# Full forward passes
utterance_distribution = self.reranking_combination(speaker_log_probs, listener_log_probs)
return speaker_logits, speaker_log_probs, listener_log_probs, utterance_distribution
def reranking_speaker_side(self, image_embeddings, input_tokens, attn_mask, image_attn_mask,
target_tokens, target_mask, annotation_mask):
# Flatten inputs and outputs
B, mult = input_tokens.shape[:2]
input_tokens = input_tokens.view(B*mult, -1)
attn_mask = attn_mask.view(B*mult, -1)
target_tokens = target_tokens.view(B*mult, -1).unsqueeze(-1)
target_mask = target_mask.view(B*mult, -1)
# Forward pass: Compute utterance probabilities for all
speaker = self.get_speaker()
all_logits = speaker(
input_ids=input_tokens,
attention_mask=attn_mask,
image_hidden_states=image_embeddings,
)['logits']
# Compute utterance log probabilities
all_log_probs = F.log_softmax(all_logits, dim=2)
token_log_probs = torch.gather(all_log_probs, 2, target_tokens).squeeze(2) # BCxT
token_log_probs = token_log_probs * target_mask
utterance_log_probs = torch.sum(token_log_probs, dim=1).view(B, mult)
utterance_log_probs[annotation_mask] = float('-inf') # Mask in the event there aren't 9 distractors
return all_logits, utterance_log_probs
def reranking_listener_side(self, image_embeddings, input_tokens, attn_mask, image_attn_mask,
label, index_to_token, annotation_mask):
# Flatten inputs and outputs
B, mult = input_tokens.shape[:2]
input_tokens = input_tokens.view(B*mult, -1)
attn_mask = attn_mask.view(B*mult, -1)
label = label.unsqueeze(1).repeat(1, mult).view(-1).unsqueeze(1)
# Forward pass: Compute listener log-probs
listener = self.get_listener()
all_logits = listener(
input_ids=input_tokens,
attention_mask=attn_mask,
image_hidden_states=image_embeddings,
)['logits']
target_logits = filter_targets(all_logits[:, -1], index_to_token) # BmultxC
listener_log_probs = F.log_softmax(target_logits, dim=1) #BmultxC
utterance_log_probs = torch.gather(listener_log_probs, 1, label).squeeze(1).view(B, mult)
utterance_log_probs[annotation_mask] = float('-inf') # Mask in the event there aren't mult distractors
return utterance_log_probs
def reranking_combination(self, speaker_utterance_log_probs, listener_utterance_log_probs):
weights = self.s_lambda * speaker_utterance_log_probs + (1-self.s_lambda) * listener_utterance_log_probs
rerank_denominator = torch.logsumexp(weights, dim=1).unsqueeze(1)
rerank_log_distribution = weights - rerank_denominator
return rerank_log_distribution
def split_generate(self, input_tokens, attn_mask, images, image_attn_mask, processor,
max_steps=25, sampling_type="nucleus", temperature=1.0,
top_k=40, top_p=0.9, repetition_penalty=1, num_samples=1):
# (1) Perform generation
speaker = self.get_speaker()
generation_config = GenerationConfig(
max_new_tokens=max_steps,
do_sample=True,
temperature=temperature,
top_k=top_k, top_p=top_p,
repetition_penalty=repetition_penalty,
num_return_sequences=num_samples,
output_hidden_states=True,
return_dict_in_generate=True
)
outputs = speaker.generate(
input_ids=input_tokens,
attention_mask=attn_mask,
pixel_values=images,
pixel_attention_mask=image_attn_mask,
generation_config=generation_config,
use_cache=True
)
# (2) Get the speaker captions
B = input_tokens.shape[0]
observed_steps = len(outputs['hidden_states'])
filtered_seqs = []
for seq in outputs['sequences']:
filtered_seqs.append(seq[-observed_steps:])
speaker_outputs = processor.batch_decode(filtered_seqs, skip_special_tokens=True)
# (3) Get the speaker log probabilities
target_outputs = torch.stack(filtered_seqs, dim=0) # BNxT
target_mask = target_outputs != 0
final_states = torch.stack([outputs['hidden_states'][i][-1][:, -1] for i in range(observed_steps)], dim=1) # BNxTxD
token_logits = speaker.lm_head(final_states) # BNxTxV
token_log_probs = F.log_softmax(token_logits, dim=2)
token_log_probs = torch.gather(token_log_probs, 2, target_outputs.unsqueeze(2)).squeeze(2)
# (4) Choose the output with the top probability
if B == 1:
utterance_log_probs = torch.sum(token_log_probs * target_mask, dim=1).view(num_samples) # N
best_idx = torch.argmax(utterance_log_probs).item()
return [speaker_outputs[best_idx]]
else:
utterance_log_probs = torch.sum(token_log_probs * target_mask, dim=1).view(B, num_samples) # N
best_indices = torch.argmax(utterance_log_probs, dim=1)
choices = []
for i in range(B):
curr_index = num_samples * i + best_indices[i].item()
choices.append(speaker_outputs[curr_index])
return choices
def generate(self, images, s_input_tokens, s_attn_mask, s_image_attn_mask, label,
image_paths, processor, image_dir, index_to_token,
max_steps=25, sampling_type="nucleus", temperature=1.0, top_k=40,
top_p=0.9, repetition_penalty=1, num_samples=10):
# Get the repeated image embeddings; assume parameter sharing
image_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "speaker")
# Sample utterances from the speaker
speaker_utterance_log_probs, speaker_utterances = self.generate_speaker_side(processor, images, s_input_tokens,
s_attn_mask, s_image_attn_mask, max_steps,
sampling_type, temperature,
top_k, top_p, repetition_penalty,
num_samples) # BxN, BN list
# Get probabilities for the utterances from the listener
listener_log_probs = self.generate_listener_side(image_embeddings, speaker_utterances, label, image_paths, processor,
image_dir, index_to_token, num_samples)
# Reranked selection
utterance_weights = self.s_lambda*speaker_utterance_log_probs + (1-self.s_lambda)*listener_log_probs
chosen_indices = torch.argmax(utterance_weights, dim=1)
choices = []
for i in range(speaker_utterance_log_probs.shape[0]):
curr_index = num_samples * i + chosen_indices[i].item()
choices.append(speaker_utterances[curr_index])
return choices, speaker_utterances, listener_log_probs, speaker_utterance_log_probs, utterance_weights
def generate_speaker_side(self, processor, images, s_input_tokens, s_attn_mask, s_image_attn_mask, max_steps,
sampling_type, temperature, top_k, top_p, repetition_penalty, num_samples):
# (1) Perform generation
speaker = self.get_speaker()
generation_config = GenerationConfig(
max_new_tokens=max_steps,
min_new_tokens=1,
do_sample=True,
temperature=temperature,
top_k=top_k, top_p=top_p,
repetition_penalty=repetition_penalty,
num_return_sequences=num_samples,
output_hidden_states=True,
return_dict_in_generate=True
)
outputs = speaker.generate(
input_ids=s_input_tokens,
attention_mask=s_attn_mask,
pixel_values=images,
pixel_attention_mask=s_image_attn_mask,
generation_config=generation_config,
use_cache=True
)
# (2) Get the speaker captions
B = s_input_tokens.shape[0]
observed_steps = len(outputs['hidden_states'])
filtered_seqs = []
for seq in outputs['sequences']:
filtered_seqs.append(seq[-observed_steps:])
speaker_outputs = processor.batch_decode(filtered_seqs, skip_special_tokens=True)
# (3) Get the speaker log probabilities
target_outputs = torch.stack(filtered_seqs, dim=0) # BNxT
target_mask = target_outputs != 0
final_states = torch.stack([outputs['hidden_states'][i][-1][:, -1] for i in range(observed_steps)], dim=1) # BNxTxD
token_logits = speaker.lm_head(final_states) # BNxTxV
token_log_probs = F.log_softmax(token_logits, dim=2)
token_log_probs = torch.gather(token_log_probs, 2, target_outputs.unsqueeze(2)).squeeze(2)
utterance_log_probs = torch.sum(token_log_probs * target_mask, dim=1).view(B, num_samples) # BxN
return utterance_log_probs, speaker_outputs
def generate_listener_side(self, image_embeddings, speaker_utterances, label, image_paths, processor,
image_dir, index_to_token, num_samples):
# Construct the inputs
B = label.shape[0]
embed_shape = image_embeddings.shape
image_embeddings = image_embeddings.view(B, -1, *embed_shape[1:])
image_embeddings = image_embeddings.unsqueeze(1).repeat(1, num_samples, 1, 1, 1).view(-1, *embed_shape[1:])
l_batch = process_idefics_listener_generation_input(image_paths, speaker_utterances, processor,
image_dir, num_samples, image_embeddings.device)
l_input_tokens, l_attn_mask, _, l_image_attn_mask = l_batch
label = label.unsqueeze(1).repeat(1, num_samples).view(-1).unsqueeze(1)
# Forward pass
listener = self.get_listener()
all_logits = listener(
input_ids=l_input_tokens,
attention_mask=l_attn_mask,
image_hidden_states=image_embeddings,
pixel_attention_mask=l_image_attn_mask
)['logits']
target_logits = filter_targets(all_logits[:, -1], index_to_token)
listener_log_probs = F.log_softmax(target_logits, dim=1)
utterance_log_probs = torch.gather(listener_log_probs, 1, label).squeeze(1).view(B, num_samples)
return utterance_log_probs