import torch


def expand_inputs_for_generation(
    input_ids,
    expand_size=1,
    is_encoder_decoder=False,
    attention_mask=None,
    encoder_outputs=None,
    **model_kwargs,
):
    """Repeat every batch row of the generation inputs ``expand_size`` times.

    Rows are repeated consecutively (row 0 ``expand_size`` times, then row 1,
    and so on) along dim 0.

    Args:
        input_ids: Tensor whose dim 0 is the batch dimension.
        expand_size: Number of copies to make of each batch row.
        is_encoder_decoder: When true, ``encoder_outputs.last_hidden_state``
            is expanded as well.
        attention_mask: Optional tensor expanded along dim 0 and stored under
            ``model_kwargs["attention_mask"]``.
        encoder_outputs: Object exposing ``last_hidden_state`` as an attribute
            and supporting item assignment; required when
            ``is_encoder_decoder`` is true.
        **model_kwargs: Extra model inputs. ``token_type_ids``,
            ``image_attention_mask`` and ``pixel_values`` are expanded along
            dim 0 when present; any other entries are passed through untouched.

    Returns:
        A ``(input_ids, model_kwargs)`` tuple holding the expanded tensors.

    Raises:
        ValueError: If ``is_encoder_decoder`` is true and ``encoder_outputs``
            is None.
    """
    # Row indices [0, 0, ..., 1, 1, ...] built directly on the target device.
    expanded_return_idx = torch.arange(
        input_ids.shape[0], device=input_ids.device
    ).repeat_interleave(expand_size)
    input_ids = input_ids.index_select(0, expanded_return_idx)

    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

    # The image inputs are expanded independently of `attention_mask`: tying
    # them to it left them at the original batch size whenever no attention
    # mask was passed, and raised KeyError when a mask was passed without them.
    for image_key in ("image_attention_mask", "pixel_values"):
        image_tensor = model_kwargs.get(image_key)
        if image_tensor is not None:
            model_kwargs[image_key] = image_tensor.index_select(0, expanded_return_idx)

    if is_encoder_decoder:
        if encoder_outputs is None:
            raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
        # The encoder states may live on a different device than `input_ids`.
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
            0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
        )
        model_kwargs["encoder_outputs"] = encoder_outputs
    return input_ids, model_kwargs


def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
    """Update ``model_kwargs`` in place after one generation step.

    Args:
        outputs: Model outputs of the step. Must support ``in`` for key lookup
            and attribute access for whichever cache entry it contains.
        model_kwargs: Mutable mapping of model inputs; modified in place.
        is_encoder_decoder: When true, the attention masks are left unchanged.

    Returns:
        The same ``model_kwargs`` object that was passed in.
    """
    # must have this key set to at least None
    model_kwargs.setdefault("past_key_values", None)

    # update past: the first cache entry found in `outputs` wins, else None
    model_kwargs["past"] = None
    for cache_name in ("past_key_values", "mems", "past_buckets_states"):
        if cache_name in outputs:
            model_kwargs["past"] = getattr(outputs, cache_name)
            break

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention masks
    if not is_encoder_decoder:
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            # Append one column of ones for the newly generated token.
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )
        if "image_attention_mask" in model_kwargs:
            image_attention_mask = model_kwargs["image_attention_mask"]
            # Keep only the last entry along dim 1 (presumably the sequence
            # axis -- TODO confirm), restoring that dim with size 1.
            last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
            model_kwargs["image_attention_mask"] = last_mask

    return model_kwargs


def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
    """Assemble the dictionary of model inputs for one generation step.

    Args:
        input_ids: Tensor indexed as ``[batch, position]``.
        past: Cache from earlier steps. When truthy, only the last position of
            ``input_ids``, ``token_type_ids`` and of position ids computed
            here is kept.
        **kwargs: May contain ``token_type_ids``, ``attention_mask``,
            ``position_ids`` and ``use_cache``; must contain non-None
            ``pixel_values`` and ``image_attention_mask``.

    Returns:
        A dict with the keys ``input_ids``, ``past_key_values``, ``use_cache``,
        ``position_ids``, ``attention_mask``, ``token_type_ids``,
        ``pixel_values`` and ``image_attention_mask``.

    Raises:
        ValueError: If ``pixel_values`` or ``image_attention_mask`` is missing
            or None.
    """
    token_type_ids = kwargs.get("token_type_ids", None)
    # only last token for inputs_ids if past is defined in kwargs
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        # Positions where the mask is 0 get the placeholder value 1.
        position_ids.masked_fill_(attention_mask == 0, 1)
        # Note: position_ids supplied by the caller are never sliced here.
        if past:
            position_ids = position_ids[:, -1].unsqueeze(-1)

    pixel_values = kwargs.get("pixel_values", None)
    image_attention_mask = kwargs.get("image_attention_mask", None)

    if pixel_values is None or image_attention_mask is None:
        raise ValueError("pixel values and image attention mask cannot be None")

    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "pixel_values": pixel_values,
        "image_attention_mask": image_attention_mask,
    }