# NOTE(review): the following three lines were non-Python build-log residue
# ("Spaces:" / "Build error" / "Build error") prepended to the file, commented
# out so the module parses; delete once the paste/extraction issue is resolved.
import torch
def expand_inputs_for_generation(
    input_ids,
    expand_size=1,
    is_encoder_decoder=False,
    attention_mask=None,
    encoder_outputs=None,
    **model_kwargs,
):
    """Repeat each batch row `expand_size` times for beam-search / sampling.

    Builds the row index [0,0,...,1,1,...] and `index_select`s every
    batch-first tensor with it so all generation inputs stay row-aligned.

    Args:
        input_ids: 2-D batch-first tensor of token ids.
        expand_size: number of copies of each batch row (e.g. ``num_beams``).
        is_encoder_decoder: when True, ``encoder_outputs`` must be provided and
            its ``last_hidden_state`` is expanded as well.
        attention_mask: optional mask, expanded alongside ``input_ids``.
        encoder_outputs: model output object exposing ``last_hidden_state``.
        **model_kwargs: extra tensors; ``token_type_ids``,
            ``image_attention_mask`` and ``pixel_values`` are expanded when
            present and not None.

    Returns:
        Tuple of (expanded ``input_ids``, updated ``model_kwargs``).

    Raises:
        ValueError: if ``is_encoder_decoder`` is True but ``encoder_outputs``
            is None.
    """
    expanded_return_idx = (
        torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
    )
    input_ids = input_ids.index_select(0, expanded_return_idx)

    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

    # Fix: the vision inputs were indexed unconditionally, raising KeyError
    # (or AttributeError on None) for calls that do not pass them. Guard so
    # text-only expansion works; behavior is unchanged when they are present.
    if model_kwargs.get("image_attention_mask") is not None:
        model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
            0, expanded_return_idx
        )
    if model_kwargs.get("pixel_values") is not None:
        model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)

    if is_encoder_decoder:
        if encoder_outputs is None:
            raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
        # Expand on the encoder's device in case it differs from input_ids'.
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
            0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
        )
        model_kwargs["encoder_outputs"] = encoder_outputs
    return input_ids, model_kwargs
def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
    """Carry generation state forward from one decoding step to the next.

    Stores the model's cache under ``model_kwargs["past"]`` (whichever of the
    known output attributes holds it), appends one step to ``token_type_ids``
    and ``attention_mask``, and keeps only the latest slice of
    ``image_attention_mask``.

    Returns the mutated ``model_kwargs`` dict.
    """
    # Downstream code expects this key to exist, even if only as None.
    model_kwargs.setdefault("past_key_values", None)

    # Different model families name their cache differently; take the first
    # one present on the outputs, else record that there is no cache.
    for cache_attr in ("past_key_values", "mems", "past_buckets_states"):
        if cache_attr in outputs:
            model_kwargs["past"] = getattr(outputs, cache_attr)
            break
    else:
        model_kwargs["past"] = None

    # Extend token_type_ids by repeating each row's last value once.
    if "token_type_ids" in model_kwargs:
        ttids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat((ttids, ttids[:, -1].unsqueeze(-1)), dim=-1)

    if not is_encoder_decoder:
        # One more generated token -> one more attended position.
        if "attention_mask" in model_kwargs:
            mask = model_kwargs["attention_mask"]
            new_col = mask.new_ones((mask.shape[0], 1))
            model_kwargs["attention_mask"] = torch.cat((mask, new_col), dim=-1)
        # Only the most recent step's image mask is needed from here on.
        if "image_attention_mask" in model_kwargs:
            img_mask = model_kwargs["image_attention_mask"]
            model_kwargs["image_attention_mask"] = img_mask[:, -1, :].unsqueeze(1)
    return model_kwargs
def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
    """Assemble the per-step model input dict for autoregressive generation.

    When a cache (``past``) is present, only the newest token — and the
    matching last ``token_type_ids`` / ``position_ids`` entries — is fed to
    the model; the cache covers everything earlier.

    Raises:
        ValueError: if ``pixel_values`` or ``image_attention_mask`` is
            missing from ``kwargs`` (or is None).
    """
    token_type_ids = kwargs.get("token_type_ids")
    if past:
        # Cache present: the model only needs the last token of each row.
        input_ids = input_ids[:, -1:]
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1:]

    attention_mask = kwargs.get("attention_mask")
    position_ids = kwargs.get("position_ids")
    if position_ids is None and attention_mask is not None:
        # Derive positions from the mask so left-padded batches stay aligned;
        # padded slots get the dummy position 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past:
            position_ids = position_ids[:, -1:]

    pixel_values = kwargs.get("pixel_values")
    image_attention_mask = kwargs.get("image_attention_mask")
    if pixel_values is None or image_attention_mask is None:
        raise ValueError("pixel values and image attention mask cannot be None")

    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "pixel_values": pixel_values,
        "image_attention_mask": image_attention_mask,
    }