Spaces:
Sleeping
Sleeping
# File: model_utils | |
# ----------------- | |
# Contain utilities for models, such as loading and saving models | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import os | |
from transformers import GenerationConfig | |
from dataset import process_idefics_listener_generation_input | |
import pdb | |
def filter_targets(logits, index_to_token): | |
target_logits = logits[:, index_to_token] | |
return target_logits | |
class IdeficsJointInferenceModel(nn.Module): | |
def __init__(self, listener_lambda, speaker_lambda, | |
model=None, listener=None, speaker=None): | |
super().__init__() | |
self.l_lambda = listener_lambda | |
self.s_lambda = speaker_lambda | |
self.has_shared_parameters = model is not None | |
if self.has_shared_parameters: | |
self.model = model | |
else: | |
self.listener = listener | |
self.speaker = speaker | |
def forward(self, inf_mode, arguments): | |
if inf_mode == "joint_comprehension": | |
return self.comprehension_side(arguments) | |
elif inf_mode == "joint_reranking": | |
return self.reranking_side(arguments) | |
elif inf_mode == "comprehension": | |
return self.split_comprehension_forward(arguments) | |
elif inf_mode == "split_reranking": | |
return self.split_reranking_forward(arguments) | |
elif inf_mode == "generation": | |
return self.split_generation_forward(arguments) | |
def get_listener(self): | |
if self.has_shared_parameters: | |
return self.model | |
else: | |
return self.listener | |
def get_speaker(self): | |
if self.has_shared_parameters: | |
return self.model | |
else: | |
return self.speaker | |
def get_image_embeddings(self, pixel_values, pixel_attention_mask, model): | |
''' | |
Get image embeddings to avoid repeated computation for images during joint inference. | |
Adapted from the IDEFICS-2 source code. | |
''' | |
# Get the model | |
model = self.get_listener() if model == "listener" else self.get_speaker() | |
if len(pixel_attention_mask.shape) == 5: | |
pixel_attention_mask = pixel_attention_mask[:, 0].contiguous() | |
# Assume images of form: BxCxcnlxHxW | |
batch_size, num_images, num_channels, height, width = pixel_values.shape | |
pixel_values = pixel_values.to(dtype=model.dtype) # fp16 compatibility | |
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) | |
# Remove padding images - padding images are full 0. | |
nb_values_per_image = pixel_values.shape[1:].numel() | |
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image | |
pixel_values = pixel_values[real_images_inds].contiguous() | |
# Remove padding images from the mask/pP p | |
pixel_attention_mask = pixel_attention_mask.view( | |
batch_size * num_images, *pixel_attention_mask.shape[2:] | |
) | |
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() | |
patch_size = model.model.config.vision_config.patch_size | |
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) | |
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) | |
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() | |
# Get sequence from the vision encoder | |
image_hidden_states = model.model.model.vision_model( | |
pixel_values=pixel_values, | |
patch_attention_mask=patch_attention_mask, | |
).last_hidden_state | |
# Modality projection & resampling | |
image_hidden_states = model.model.model.connector( | |
image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1) | |
) | |
return image_hidden_states | |
def split_comprehension_side(self, input_tokens, attn_mask, images, image_attn_mask, index_to_token): | |
''' | |
Redundant with split_comprehension_forward except for the final computation. | |
Used during deployment in ray_models.py. | |
''' | |
listener = self.get_listener() | |
all_logits = listener( | |
input_ids=input_tokens, | |
attention_mask=attn_mask, | |
pixel_values=images, | |
pixel_attention_mask=image_attn_mask | |
)['logits'] | |
target_logits = filter_targets(all_logits[:, -1], index_to_token) | |
listener_log_probs = F.log_softmax(target_logits, dim=1) | |
return listener_log_probs | |
def split_comprehension_forward(self, arguments): | |
input_tokens, attn_mask, images, image_attn_mask = arguments | |
listener = self.get_listener() | |
all_logits = listener( | |
input_ids=input_tokens, | |
attention_mask=attn_mask, | |
pixel_values=images, | |
pixel_attention_mask=image_attn_mask | |
)['logits'] | |
return all_logits | |
def split_generation_forward(self, arguments): | |
input_tokens, attn_mask, images, image_attn_mask = arguments | |
speaker = self.get_speaker() | |
all_logits = speaker( | |
input_ids=input_tokens, | |
attention_mask=attn_mask, | |
pixel_values=images, | |
pixel_attention_mask=image_attn_mask | |
)['logits'] | |
return all_logits | |
def split_reranking_forward(self, arguments): | |
images, input_tokens, attn_mask, image_attn_mask, target_tokens, target_mask = arguments | |
# Get the image embeddings | |
image_embeddings = self.get_image_embeddings(images, image_attn_mask, "speaker") | |
embed_shape = image_embeddings.shape | |
B, mult = input_tokens.shape[:2] | |
C = images.shape[1] | |
image_embeddings = image_embeddings.view(B, C, *embed_shape[1:]) | |
image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:]) | |
annotation_mask = torch.zeros(B, mult, device=image_embeddings.device).bool() | |
_, speaker_log_probs = self.reranking_speaker_side(image_embeddings, input_tokens, attn_mask, | |
image_attn_mask, target_tokens, target_mask, | |
annotation_mask) | |
return speaker_log_probs | |
def comprehension_side(self, arguments): | |
images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token, \ | |
s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label = arguments | |
if self.has_shared_parameters: | |
image_embeddings = self.get_image_embeddings(images, l_image_attn_mask, "listener") | |
listener_log_probs = self.comprehension_listener_side( | |
image_embeddings, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token | |
) # TODO | |
speaker_log_probs = self.comprehension_speaker_side( | |
image_embeddings, s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label | |
) | |
else: | |
# Deprecated and not used in experiments | |
listener_embeddings = self.get_image_embeddings(images, l_image_attn_mask, "listener") | |
listener_log_probs = self.comprehension_listener_side( | |
listener_embeddings, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token | |
) | |
speaker_embeddings = self.get_image_embeddings(images, "speaker") | |
speaker_log_probs = self.comprehension_speaker_side( | |
speaker_embeddings, s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label | |
) | |
joint_log_probs = self.comprehension_reranking(listener_log_probs, speaker_log_probs) | |
return listener_log_probs, speaker_log_probs, joint_log_probs | |
def comprehension_listener_side(self, image_encoder_embeddings, input_tokens, attn_mask, image_attn_mask, | |
index_to_token): | |
listener = self.get_listener() | |
all_logits = listener( | |
input_ids=input_tokens, | |
attention_mask=attn_mask, | |
image_hidden_states=image_encoder_embeddings, | |
pixel_attention_mask=image_attn_mask | |
)['logits'] | |
target_logits = filter_targets(all_logits[:, -1], index_to_token) # BxC | |
listener_log_probs = F.log_softmax(target_logits, dim=1) | |
return listener_log_probs | |
def comprehension_speaker_side(self, image_encoder_embeddings, input_tokens, attn_mask, image_attn_mask, | |
target_mask, target_label): | |
# Expand embeddings | |
B, C = input_tokens.shape[:2] | |
embed_shape = image_encoder_embeddings.shape | |
image_encoder_embeddings = image_encoder_embeddings.view(B, C, *embed_shape[1:]) | |
image_encoder_embeddings = image_encoder_embeddings.unsqueeze(1).repeat(1, C, 1, 1, 1).view(-1, *embed_shape[1:]) | |
input_tokens = input_tokens.view(B*C, -1) | |
attn_mask = attn_mask.view(B*C, -1) | |
# Forward pass | |
speaker = self.get_speaker() | |
all_logits = speaker( | |
input_ids=input_tokens, | |
attention_mask=attn_mask, | |
image_hidden_states=image_encoder_embeddings, | |
)['logits'] | |
# Get tokenwise probabilities | |
all_log_probs = F.log_softmax(all_logits, dim=2) | |
target_label = target_label.view(B*C, -1).unsqueeze(2) | |
target_mask = target_mask.view(B*C, -1) | |
token_log_probs = torch.gather(all_log_probs, 2, target_label).squeeze(2) # BCxT | |
# Compute the log probabilities | |
token_log_probs = token_log_probs * target_mask | |
utterance_log_probs = torch.sum(token_log_probs, dim=1).view(B, C) | |
return utterance_log_probs | |
def comprehension_reranking(self, listener_log_probs, speaker_log_probs): | |
rerank_weights = self.l_lambda * listener_log_probs + (1 - self.l_lambda) * speaker_log_probs | |
rerank_denominator = torch.logsumexp(rerank_weights, dim=1).unsqueeze(1) | |
rerank_log_distribution = rerank_weights - rerank_denominator | |
return rerank_log_distribution | |
def reranking_side(self, arguments): | |
images, label, s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_tokens, s_target_mask, \ | |
l_input_tokens, l_attn_mask, l_image_attn_mask, \ | |
index_to_token, annotation_mask = arguments | |
# Repeat image embeddings according to number of distractors | |
if self.has_shared_parameters: | |
image_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "speaker") | |
embed_shape = image_embeddings.shape | |
B, mult = s_input_tokens.shape[:2] | |
C = images.shape[1] | |
image_embeddings = image_embeddings.view(B, C, *embed_shape[1:]) | |
image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:]) | |
speaker_logits, speaker_log_probs = self.reranking_speaker_side(image_embeddings, s_input_tokens, | |
s_attn_mask, s_image_attn_mask, | |
s_target_tokens, s_target_mask, | |
annotation_mask) | |
listener_log_probs = self.reranking_listener_side(image_embeddings, l_input_tokens, l_attn_mask, | |
l_image_attn_mask, label, index_to_token, | |
annotation_mask) | |
else: | |
# Deprecated and no longer used in main experiments | |
image_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "speaker") | |
embed_shape = image_embeddings.shape | |
B, mult = s_input_tokens.shape[:2] | |
C = images.shape[1] | |
image_embeddings = image_embeddings.view(B, C, *embed_shape[1:]) | |
image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:]) | |
speaker_logits, speaker_log_probs = self.reranking_speaker_side(image_embeddings, s_input_tokens, | |
s_attn_mask, s_image_attn_mask, | |
s_target_tokens, s_target_mask, | |
annotation_mask) | |
image_embeddings = self.get_image_embeddings(images, l_image_attn_mask, "listener") | |
embed_shape = image_embeddings.shape | |
B, mult = s_input_tokens.shape[:2] | |
C = images.shape[1] | |
image_embeddings = image_embeddings.view(B, C, *embed_shape[1:]) | |
image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:]) | |
listener_log_probs = self.reranking_listener_side(image_embeddings, l_input_tokens, l_attn_mask, | |
l_image_attn_mask, label, index_to_token, annotation_mask) | |
# Full forward passes | |
utterance_distribution = self.reranking_combination(speaker_log_probs, listener_log_probs) | |
return speaker_logits, speaker_log_probs, listener_log_probs, utterance_distribution | |
def reranking_speaker_side(self, image_embeddings, input_tokens, attn_mask, image_attn_mask, | |
target_tokens, target_mask, annotation_mask): | |
# Flatten inputs and outputs | |
B, mult = input_tokens.shape[:2] | |
input_tokens = input_tokens.view(B*mult, -1) | |
attn_mask = attn_mask.view(B*mult, -1) | |
target_tokens = target_tokens.view(B*mult, -1).unsqueeze(-1) | |
target_mask = target_mask.view(B*mult, -1) | |
# Forward pass: Compute utterance probabilities for all | |
speaker = self.get_speaker() | |
all_logits = speaker( | |
input_ids=input_tokens, | |
attention_mask=attn_mask, | |
image_hidden_states=image_embeddings, | |
)['logits'] | |
# Compute utterance log probabilities | |
all_log_probs = F.log_softmax(all_logits, dim=2) | |
token_log_probs = torch.gather(all_log_probs, 2, target_tokens).squeeze(2) # BCxT | |
token_log_probs = token_log_probs * target_mask | |
utterance_log_probs = torch.sum(token_log_probs, dim=1).view(B, mult) | |
utterance_log_probs[annotation_mask] = float('-inf') # Mask in the event there aren't 9 distractors | |
return all_logits, utterance_log_probs | |
def reranking_listener_side(self, image_embeddings, input_tokens, attn_mask, image_attn_mask, | |
label, index_to_token, annotation_mask): | |
# Flatten inputs and outputs | |
B, mult = input_tokens.shape[:2] | |
input_tokens = input_tokens.view(B*mult, -1) | |
attn_mask = attn_mask.view(B*mult, -1) | |
label = label.unsqueeze(1).repeat(1, mult).view(-1).unsqueeze(1) | |
# Forward pass: Compute listener log-probs | |
listener = self.get_listener() | |
all_logits = listener( | |
input_ids=input_tokens, | |
attention_mask=attn_mask, | |
image_hidden_states=image_embeddings, | |
)['logits'] | |
target_logits = filter_targets(all_logits[:, -1], index_to_token) # BmultxC | |
listener_log_probs = F.log_softmax(target_logits, dim=1) #BmultxC | |
utterance_log_probs = torch.gather(listener_log_probs, 1, label).squeeze(1).view(B, mult) | |
utterance_log_probs[annotation_mask] = float('-inf') # Mask in the event there aren't mult distractors | |
return utterance_log_probs | |
def reranking_combination(self, speaker_utterance_log_probs, listener_utterance_log_probs): | |
weights = self.s_lambda * speaker_utterance_log_probs + (1-self.s_lambda) * listener_utterance_log_probs | |
rerank_denominator = torch.logsumexp(weights, dim=1).unsqueeze(1) | |
rerank_log_distribution = weights - rerank_denominator | |
return rerank_log_distribution | |
def split_generate(self, input_tokens, attn_mask, images, image_attn_mask, processor, | |
max_steps=25, sampling_type="nucleus", temperature=1.0, | |
top_k=40, top_p=0.9, repetition_penalty=1, num_samples=1): | |
# (1) Perform generation | |
speaker = self.get_speaker() | |
generation_config = GenerationConfig( | |
max_new_tokens=max_steps, | |
do_sample=True, | |
temperature=temperature, | |
top_k=top_k, top_p=top_p, | |
repetition_penalty=repetition_penalty, | |
num_return_sequences=num_samples, | |
output_hidden_states=True, | |
return_dict_in_generate=True | |
) | |
outputs = speaker.generate( | |
input_ids=input_tokens, | |
attention_mask=attn_mask, | |
pixel_values=images, | |
pixel_attention_mask=image_attn_mask, | |
generation_config=generation_config, | |
use_cache=True | |
) | |
# (2) Get the speaker captions | |
B = input_tokens.shape[0] | |
observed_steps = len(outputs['hidden_states']) | |
filtered_seqs = [] | |
for seq in outputs['sequences']: | |
filtered_seqs.append(seq[-observed_steps:]) | |
speaker_outputs = processor.batch_decode(filtered_seqs, skip_special_tokens=True) | |
# (3) Get the speaker log probabilities | |
target_outputs = torch.stack(filtered_seqs, dim=0) # BNxT | |
target_mask = target_outputs != 0 | |
final_states = torch.stack([outputs['hidden_states'][i][-1][:, -1] for i in range(observed_steps)], dim=1) # BNxTxD | |
token_logits = speaker.lm_head(final_states) # BNxTxV | |
token_log_probs = F.log_softmax(token_logits, dim=2) | |
token_log_probs = torch.gather(token_log_probs, 2, target_outputs.unsqueeze(2)).squeeze(2) | |
# (4) Choose the output with the top probability | |
if B == 1: | |
utterance_log_probs = torch.sum(token_log_probs * target_mask, dim=1).view(num_samples) # N | |
best_idx = torch.argmax(utterance_log_probs).item() | |
return [speaker_outputs[best_idx]] | |
else: | |
utterance_log_probs = torch.sum(token_log_probs * target_mask, dim=1).view(B, num_samples) # N | |
best_indices = torch.argmax(utterance_log_probs, dim=1) | |
choices = [] | |
for i in range(B): | |
curr_index = num_samples * i + best_indices[i].item() | |
choices.append(speaker_outputs[curr_index]) | |
return choices | |
def generate(self, images, s_input_tokens, s_attn_mask, s_image_attn_mask, label, | |
image_paths, processor, image_dir, index_to_token, | |
max_steps=25, sampling_type="nucleus", temperature=1.0, top_k=40, | |
top_p=0.9, repetition_penalty=1, num_samples=10): | |
# Get the repeated image embeddings; assume parameter sharing | |
image_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "speaker") | |
# Sample utterances from the speaker | |
speaker_utterance_log_probs, speaker_utterances = self.generate_speaker_side(processor, images, s_input_tokens, | |
s_attn_mask, s_image_attn_mask, max_steps, | |
sampling_type, temperature, | |
top_k, top_p, repetition_penalty, | |
num_samples) # BxN, BN list | |
# Get probabilities for the utterances from the listener | |
listener_log_probs = self.generate_listener_side(image_embeddings, speaker_utterances, label, image_paths, processor, | |
image_dir, index_to_token, num_samples) | |
# Reranked selection | |
utterance_weights = self.s_lambda*speaker_utterance_log_probs + (1-self.s_lambda)*listener_log_probs | |
chosen_indices = torch.argmax(utterance_weights, dim=1) | |
choices = [] | |
for i in range(speaker_utterance_log_probs.shape[0]): | |
curr_index = num_samples * i + chosen_indices[i].item() | |
choices.append(speaker_utterances[curr_index]) | |
return choices, speaker_utterances, listener_log_probs, speaker_utterance_log_probs, utterance_weights | |
def generate_speaker_side(self, processor, images, s_input_tokens, s_attn_mask, s_image_attn_mask, max_steps, | |
sampling_type, temperature, top_k, top_p, repetition_penalty, num_samples): | |
# (1) Perform generation | |
speaker = self.get_speaker() | |
generation_config = GenerationConfig( | |
max_new_tokens=max_steps, | |
min_new_tokens=1, | |
do_sample=True, | |
temperature=temperature, | |
top_k=top_k, top_p=top_p, | |
repetition_penalty=repetition_penalty, | |
num_return_sequences=num_samples, | |
output_hidden_states=True, | |
return_dict_in_generate=True | |
) | |
outputs = speaker.generate( | |
input_ids=s_input_tokens, | |
attention_mask=s_attn_mask, | |
pixel_values=images, | |
pixel_attention_mask=s_image_attn_mask, | |
generation_config=generation_config, | |
use_cache=True | |
) | |
# (2) Get the speaker captions | |
B = s_input_tokens.shape[0] | |
observed_steps = len(outputs['hidden_states']) | |
filtered_seqs = [] | |
for seq in outputs['sequences']: | |
filtered_seqs.append(seq[-observed_steps:]) | |
speaker_outputs = processor.batch_decode(filtered_seqs, skip_special_tokens=True) | |
# (3) Get the speaker log probabilities | |
target_outputs = torch.stack(filtered_seqs, dim=0) # BNxT | |
target_mask = target_outputs != 0 | |
final_states = torch.stack([outputs['hidden_states'][i][-1][:, -1] for i in range(observed_steps)], dim=1) # BNxTxD | |
token_logits = speaker.lm_head(final_states) # BNxTxV | |
token_log_probs = F.log_softmax(token_logits, dim=2) | |
token_log_probs = torch.gather(token_log_probs, 2, target_outputs.unsqueeze(2)).squeeze(2) | |
utterance_log_probs = torch.sum(token_log_probs * target_mask, dim=1).view(B, num_samples) # BxN | |
return utterance_log_probs, speaker_outputs | |
def generate_listener_side(self, image_embeddings, speaker_utterances, label, image_paths, processor, | |
image_dir, index_to_token, num_samples): | |
# Construct the inputs | |
B = label.shape[0] | |
embed_shape = image_embeddings.shape | |
image_embeddings = image_embeddings.view(B, -1, *embed_shape[1:]) | |
image_embeddings = image_embeddings.unsqueeze(1).repeat(1, num_samples, 1, 1, 1).view(-1, *embed_shape[1:]) | |
l_batch = process_idefics_listener_generation_input(image_paths, speaker_utterances, processor, | |
image_dir, num_samples, image_embeddings.device) | |
l_input_tokens, l_attn_mask, _, l_image_attn_mask = l_batch | |
label = label.unsqueeze(1).repeat(1, num_samples).view(-1).unsqueeze(1) | |
# Forward pass | |
listener = self.get_listener() | |
all_logits = listener( | |
input_ids=l_input_tokens, | |
attention_mask=l_attn_mask, | |
image_hidden_states=image_embeddings, | |
pixel_attention_mask=l_image_attn_mask | |
)['logits'] | |
target_logits = filter_targets(all_logits[:, -1], index_to_token) | |
listener_log_probs = F.log_softmax(target_logits, dim=1) | |
utterance_log_probs = torch.gather(listener_log_probs, 1, label).squeeze(1).view(B, num_samples) | |
return utterance_log_probs | |