import argparse
import os
import sys
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput

from .build import load_sd_model, load_Florence2_model
from .vlv_utils import initiate_time_steps, normalize, process_caption
from .VLV_stage1 import SDModel, SDConfig
from .configuration_vlv import VLV_Config
					
						
def handle_module_prefix(state_dict):
    """Strip the 'module.' prefix that DistributedDataParallel adds to state dict keys."""
    if any(k.startswith('module.') for k in state_dict.keys()):
        return {k.replace('module.', ''): v for k, v in state_dict.items()}
    return state_dict


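# Illustrative sketch (not part of the original module): a checkpoint saved from a
# DistributedDataParallel-wrapped model prefixes every key with 'module.';
# handle_module_prefix strips it so the weights load into the bare model.
def _example_handle_module_prefix():
    ddp_style = {"module.mlp.layers.0.weight": torch.zeros(4, 4)}  # hypothetical key
    plain = handle_module_prefix(ddp_style)
    assert "mlp.layers.0.weight" in plain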
					
						
def create_model_args(args):
    """Create the argument namespace expected by SDModel."""
    model_args = argparse.Namespace()
    model_args.use_text_encoder = args.use_text_encoder
    model_args.batch_size = args.batch_size
    model_args.eval_batch_size = args.batch_size
    model_args.distributed_strategy = 'none'
    model_args.fp32 = args.fp32
    model_args.learnable_token_length = args.learnable_token_length
    model_args.num_inference_steps = args.num_inference_steps
    model_args.image_size = args.image_size
    model_args.guidance_scale = args.guidance_scale
    # Keep every part of Florence-2 frozen for inference.
    model_args.unfreeze_florence2_all = False
    model_args.unfreeze_florence2_language_model = False
    model_args.unfreeze_florence2_language_model_decoder = False
    return model_args


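# Sketch (field values are placeholders): create_model_args only reads the attributes
# below, so any object exposing them — a plain argparse.Namespace or the VLV_Config
# used by VLV_MODEL — can be handed to initialize_diffusion_model further down.
def _example_diffusion_args():
    return argparse.Namespace(
        use_text_encoder=True,
        batch_size=1,
        fp32=False,
        learnable_token_length=77,
        num_inference_steps=50,
        image_size=384,
        guidance_scale=7.5,
    )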
					
						
def load_model_checkpoint(model, model_path, device):
    """Load a training checkpoint into the model, tolerating several checkpoint layouts."""
    try:
        checkpoint = torch.load(model_path, map_location="cpu")

        # Accept {'model_state_dict': ...}, {'state_dict': ...}, or a bare state dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        state_dict = handle_module_prefix(state_dict)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        if missing_keys:
            print(f"Missing keys: {missing_keys[:10]}...")
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys[:10]}...")

        print(f"Successfully loaded model from {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    return model


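# Sketch (hypothetical file name): a checkpoint written with the 'model_state_dict'
# wrapper below matches the first branch of the loader above; the other two layouts
# load the same way.
def _example_save_compatible_checkpoint(module: nn.Module, path: str = "model.pt"):
    torch.save({"model_state_dict": module.state_dict()}, path)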
					
						
def initialize_diffusion_model(args):
    """Initialize the stage-1 VLV diffusion model (SDModel)."""
    config = SDConfig()
    diffusion_model_args = create_model_args(args)
    diffusion_model = SDModel(config, diffusion_model_args)
    _dtype = torch.float32 if diffusion_model_args.fp32 else torch.bfloat16

    # The VAE and U-Net are only needed for image generation; drop them to save memory.
    if hasattr(diffusion_model, 'vae'):
        del diffusion_model.vae
    if hasattr(diffusion_model, 'unet'):
        del diffusion_model.unet

    torch.cuda.empty_cache()

    diffusion_model = diffusion_model.to(_dtype)

    # Freeze the stage-1 projection and query embeddings; they are not trained here.
    for param in diffusion_model.language_proj.parameters():
        param.requires_grad = False
    diffusion_model.query_embed.requires_grad = False

    return diffusion_model
					
						
class MLP(nn.Module):
    """Two-layer projection from the CLIP text-encoder width to the Qwen2 hidden size."""

    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x):
        return self.layers(x)


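# Minimal usage sketch (the 1536 hidden size is a placeholder, e.g. a Qwen2-1.5B-sized
# model): the MLP bridges 1024-dim CLIP text-encoder features to the language model width.
def _example_mlp_bridge():
    bridge = MLP(input_dim=1024, output_dim=1536)
    clip_features = torch.randn(2, 77, 1024)   # (batch, num_queries, clip_dim)
    qwen_inputs = bridge(clip_features)        # (batch, num_queries, hidden_size)
    assert qwen_inputs.shape == (2, 77, 1536)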
					
						
@dataclass
class CLIPDecoderOutput(ModelOutput):
    """
    Output class for the CLIP decoder model.
    """
    last_hidden_state: Optional[torch.FloatTensor] = None
    generated_ids: Optional[torch.LongTensor] = None
    generated_text: Optional[list] = None
					
						
class CLIPDecoder(nn.Module):

    def __init__(
        self,
        language_model: str,
        VLV_model: SDModel,
        device: torch.device,
        bf16: str,
        qwen2_config: dict = None,
        args: argparse.Namespace = None
    ):
        """
        Initialize the CLIP decoder model.

        Args:
            language_model: Path or hub id of the Qwen2 language model
            VLV_model: The stage-1 VLV (SDModel) instance
            device: The device to run the model on
            bf16: Precision flag; "bf16" selects bfloat16, anything else float32
            qwen2_config: Optional Qwen2 configuration dict
            args: Optional extra arguments (currently unused)
        """
        super(CLIPDecoder, self).__init__()

        self._dtype = torch.bfloat16 if bf16 == "bf16" else torch.float32
        self.qwen2_tokenizer = AutoTokenizer.from_pretrained(language_model)

        self.qwen2_config = AutoConfig.from_pretrained(language_model)
        self.qwen2_model = AutoModelForCausalLM.from_pretrained(
            language_model,
            torch_dtype=self._dtype,
            device_map=None,
            low_cpu_mem_usage=True
        )

        self.VLV_model = VLV_model
        self.device = device
        # Bridge the 1024-dim CLIP text features into the Qwen2 embedding space.
        self.mlp = MLP(input_dim=1024, output_dim=self.qwen2_model.config.hidden_size)
        self.ignore_token_id = -100
					
						
    def get_conditional_context(self, images, batch_size):
        """
        Get conditional context from images using the diffusion model.

        Args:
            images: Input images
            batch_size: Batch size

        Returns:
            Decoder hidden states from the Florence-2 language model
        """
        prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
        inputs = self.VLV_model.processor(text=prompt, images=images, return_tensors="pt").to(self.device).to(self._dtype)

        # Make sure every sub-module sits on the same device as the processed inputs.
        self.VLV_model = self.VLV_model.to(inputs["input_ids"].device)
        self.qwen2_model = self.qwen2_model.to(inputs["input_ids"].device)
        self.mlp = self.mlp.to(inputs["input_ids"].device)
        self.VLV_model.model.language_model.model = self.VLV_model.model.language_model.model.to(inputs["input_ids"].device)

        if inputs["input_ids"] is not None:
            inputs_embeds = self.VLV_model.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self.device)

        if inputs["pixel_values"] is not None:
            image_features = self.VLV_model.model._encode_image(inputs["pixel_values"]).to(self.device)
            inputs_embeds, attention_mask = self.VLV_model.model._merge_input_ids_with_image_features(
                image_features, inputs_embeds
            )

        if inputs_embeds is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)

        # Encode the merged prompt + image sequence with the Florence-2 encoder.
        encoder_outputs = self.VLV_model.model.language_model.model.encoder(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        # The learnable queries attend to the encoder states through the decoder.
        decoder_inputs_embeds = self.VLV_model.query_embed.expand(batch_size, -1, -1)
        decoder_attention_mask = torch.ones(
            (batch_size, self.VLV_model.num_queries),
            dtype=self._dtype,
            device=self.device
        )

        encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
        decoder_input_embeds = decoder_inputs_embeds.to(self._dtype)
        attention_mask = attention_mask.to(self._dtype)

        decoder_outputs = self.VLV_model.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        return decoder_outputs.last_hidden_state
					
						
    def process_image(self, images, batch_size):
        """
        Process images into CLIP text embeddings projected to the Qwen2 space.

        Args:
            images: Input images
            batch_size: Batch size

        Returns:
            Projected CLIP text embeddings and their attention mask
        """
        decoder_hidden_states = self.get_conditional_context(images, batch_size)
        context_embeds = self.VLV_model.language_proj(decoder_hidden_states)
        clip_text_embeds = self.VLV_model.text_encoder(inputs_embeds=context_embeds).last_hidden_state
        clip_text_embeds = self.mlp(clip_text_embeds)
        clip_text_embeds_attention_mask = torch.ones(
            (batch_size, self.VLV_model.num_queries),
            dtype=torch.long,
            device=self.device
        )

        return clip_text_embeds, clip_text_embeds_attention_mask
					
						
    def prepare_generation_inputs(self, clip_text_embeds, clip_text_attention_mask=None):
        """
        Prepare inputs for text generation.

        Args:
            clip_text_embeds: Processed CLIP text embeddings
            clip_text_attention_mask: Attention mask for the CLIP text embeddings

        Returns:
            Dictionary of generation inputs
        """
        if clip_text_attention_mask is None:
            clip_text_attention_mask = torch.ones(
                (clip_text_embeds.shape[0], clip_text_embeds.shape[1]),
                dtype=torch.long,
                device=clip_text_embeds.device
            )

        return {
            "inputs_embeds": clip_text_embeds,
            "attention_mask": clip_text_attention_mask
        }
					
						
    def generate(self, images, max_new_tokens=300, num_beams=4, early_stopping=True):
        """
        Generate text from images.

        Args:
            images: Input images
            max_new_tokens: Maximum number of tokens to generate
            num_beams: Number of beams for beam search
            early_stopping: Whether to stop beam search early

        Returns:
            CLIPDecoderOutput with generated ids and text
        """
        batch_size = len(images)
        clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)
        generation_inputs = self.prepare_generation_inputs(clip_text_embeds, clip_text_attention_mask)

        # Cast the visual prefix (and its mask) to the working dtype before decoding.
        generation_inputs["inputs_embeds"] = generation_inputs["inputs_embeds"].to(self._dtype)
        generation_inputs["attention_mask"] = generation_inputs["attention_mask"].to(self._dtype)

        generated_ids = self.qwen2_model.generate(
            inputs_embeds=generation_inputs["inputs_embeds"],
            attention_mask=generation_inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            early_stopping=early_stopping
        )

        generated_text = self.qwen2_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        processed_generated_text = [process_caption(text) for text in generated_text]

        return CLIPDecoderOutput(
            generated_ids=generated_ids,
            generated_text=processed_generated_text
        )
					
						
    def forward(self, images, captions=None):
        """
        Forward pass for training.

        Args:
            images: Input images
            captions: Target captions (optional; required to compute the loss)

        Returns:
            CLIPDecoderOutput with the projected embeddings when no captions are given,
            otherwise the Qwen2 output including the language-modeling loss
        """
        batch_size = images.shape[0]

        clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)

        if captions is None:
            return CLIPDecoderOutput(
                last_hidden_state=clip_text_embeds
            )

        assert len(captions) == batch_size

        processed_captions = [process_caption(caption) for caption in captions]
        qwen_input_ids = self.qwen2_tokenizer(
            text=processed_captions,
            truncation=True,
            return_tensors="pt",
            padding="max_length",
            max_length=300,
            return_token_type_ids=False,
        ).input_ids

        qwen_attention_mask = qwen_input_ids.ne(self.qwen2_tokenizer.pad_token_id).to(torch.long).to(self.device)

        # Padding positions are excluded from the loss.
        labels = qwen_input_ids.clone()
        labels[labels == self.qwen2_tokenizer.pad_token_id] = self.ignore_token_id
        labels = labels.to(self.device)

        # Re-insert pad ids so the embedding lookup never sees the ignore index.
        labels_for_embeddings = labels.clone()
        labels_for_embeddings[labels_for_embeddings == self.ignore_token_id] = self.qwen2_tokenizer.pad_token_id
        clip_text_embeds_qwen = self.qwen2_model.get_input_embeddings()(labels_for_embeddings)

        # Prepend the visual prefix to the caption embeddings and mask it out of the loss.
        inputs_embeds = torch.cat((clip_text_embeds, clip_text_embeds_qwen), dim=1)
        clip_seq_len = clip_text_embeds.shape[1]
        clip_ignore_labels = torch.full((labels.shape[0], clip_seq_len), self.ignore_token_id).to(labels)
        combined_labels = torch.cat((clip_ignore_labels, labels), dim=1)

        attention_mask = torch.cat((
            clip_text_attention_mask,
            qwen_attention_mask
        ), dim=1)

        outputs = self.qwen2_model(
            inputs_embeds=inputs_embeds,
            labels=combined_labels,
            attention_mask=attention_mask,
            use_cache=False
        )
        return outputs


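# Sketch of the label layout built in CLIPDecoder.forward (token ids are hypothetical):
# the visual prefix is filled with -100 so cross-entropy only covers caption tokens,
# and padding inside the caption is masked the same way.
def _example_label_layout(pad_token_id: int = 0):
    caption_ids = torch.tensor([[11, 12, 13, pad_token_id]])
    labels = caption_ids.clone()
    labels[labels == pad_token_id] = -100          # ignore padding
    prefix = torch.full((1, 2), -100)              # two visual-prefix positions, all ignored
    return torch.cat((prefix, labels), dim=1)      # (1, prefix_len + caption_len)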
					
						
class VLV_MODEL(PreTrainedModel):
    config_class = VLV_Config
    model_type = "VLV_decoder"

    def __init__(self, config):
        """Build the CLIPDecoder captioning model from a VLV_Config."""
        super().__init__(config)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        de_diffusion_model = initialize_diffusion_model(config)
        clip_decoder_model = CLIPDecoder(
            language_model=config.qwen_model,
            VLV_model=de_diffusion_model,
            device=device,
            bf16=config.mixed_precision,
            qwen2_config=config.qwen2_config
        )

        clip_decoder_model.eval()

        # Register the sub-modules on the wrapper as well (shared with _clip_decoder_model).
        self.VLV_model = clip_decoder_model.VLV_model
        self.qwen2_model = clip_decoder_model.qwen2_model
        self.mlp = clip_decoder_model.mlp

        self._clip_decoder_model = clip_decoder_model
        self.max_new_tokens = config.max_length
        self.num_beams = config.num_beams
        self.transform = self.get_transform(config.image_size)
					
						
    def get_transform(self, image_size):
        """Transformation pipeline for input images."""
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop((image_size, image_size)),
            transforms.PILToTensor(),
        ])
					
						
    @classmethod
    def from_checkpoint(cls, checkpoint_path, config=None, **kwargs):
        """
        Load the model from an original training checkpoint.

        Args:
            checkpoint_path: Path to the original model.pt checkpoint
            config: Optional VLV_Config; a default one is created if None
            **kwargs: Additional arguments forwarded to the default VLV_Config

        Returns:
            A VLV_MODEL with the checkpoint weights loaded into its CLIPDecoder
        """
        if config is None:
            config = VLV_Config(
                image_size=384,
                guidance_scale=7.5,
                learnable_token_length=77,
                max_length=300,
                num_beams=4,
                **kwargs
            )

        model = cls(config)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        load_model_checkpoint(model._clip_decoder_model, checkpoint_path, device)

        return model
					
						
    def forward(self, valid_images, max_length):
        """Transform a list of PIL images and generate captions with beam search."""
        valid_images = [self.transform(img) for img in valid_images]
        # Support both bare and DistributedDataParallel-wrapped decoders.
        if hasattr(self._clip_decoder_model, 'module'):
            outputs = self._clip_decoder_model.module.generate(
                valid_images,
                max_new_tokens=max_length,
                num_beams=self.num_beams,
                early_stopping=True
            )
        else:
            outputs = self._clip_decoder_model.generate(
                valid_images,
                max_new_tokens=max_length,
                num_beams=self.num_beams,
                early_stopping=True
            )
        return outputs


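# Usage sketch — the checkpoint path and image file below are placeholders, and
# VLV_Config is assumed to provide defaults for the extra fields read in
# initialize_diffusion_model and CLIPDecoder (qwen_model, mixed_precision,
# use_text_encoder, fp32, num_inference_steps, qwen2_config, ...).
if __name__ == "__main__":
    from PIL import Image

    model = VLV_MODEL.from_checkpoint("checkpoints/model.pt")   # placeholder path
    pil_images = [Image.open("example.jpg").convert("RGB")]     # placeholder image
    with torch.no_grad():
        result = model(pil_images, max_length=300)
    print(result.generated_text)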