| from typing import Optional |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| import numpy as np |
| from dataclasses import dataclass |
|
|
| from .transformer import ( |
| LayerNormFp32, |
| LayerNorm, |
| QuickGELU, |
| MultimodalTransformer, |
| ) |
| from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower |
|
|
| try: |
| from transformers import ( |
| BeamSearchScorer, |
| LogitsProcessorList, |
| TopPLogitsWarper, |
| TopKLogitsWarper, |
| RepetitionPenaltyLogitsProcessor, |
| MinLengthLogitsProcessor, |
| MaxLengthCriteria, |
| StoppingCriteriaList |
| ) |
|
|
| GENERATION_TYPES = { |
| "top_k": TopKLogitsWarper, |
| "top_p": TopPLogitsWarper, |
| "beam_search": "beam_search" |
| } |
| _has_transformers = True |
| except ImportError as e: |
| GENERATION_TYPES = { |
| "top_k": None, |
| "top_p": None, |
| "beam_search": "beam_search" |
| } |
| _has_transformers = False |
|
|
|
|
@dataclass
class MultimodalCfg(CLIPTextCfg):
    """Configuration for the multimodal text-decoder tower.

    Extends CLIPTextCfg with decoder-specific fields; only the additions are
    declared here, everything else (width, layers, context_length, ...) is
    inherited and consumed by `_build_text_decoder_tower`.
    """
    mlp_ratio: int = 4           # MLP hidden-dim expansion ratio in transformer blocks
    dim_head: int = 64           # per-attention-head dimensionality
    heads: int = 8               # number of attention heads (used by the decoder builder)
    n_queries: int = 256         # NOTE(review): not referenced in this file — presumably
                                 # learned query count for an attentional pooler; confirm
    attn_pooler_heads: int = 8   # NOTE(review): not referenced in this file — confirm usage
|
|
|
|
def _build_text_decoder_tower(
    embed_dim,
    multimodal_cfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    """Build the MultimodalTransformer text decoder from a MultimodalCfg.

    `multimodal_cfg` may be a plain dict (e.g. parsed from JSON); it is
    promoted to a MultimodalCfg first. `quick_gelu` swaps in CLIP's QuickGELU
    activation, and half-precision cast dtypes select the fp32-stable
    LayerNorm variant.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    activation = QuickGELU if quick_gelu else nn.GELU
    if cast_dtype in (torch.float16, torch.bfloat16):
        # LayerNormFp32 keeps the normalization math in fp32 for stability.
        normalization = LayerNormFp32
    else:
        normalization = LayerNorm

    return MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=activation,
        norm_layer=normalization,
    )
|
|
|
|
class CoCa(nn.Module):
    """Contrastive Captioner (CoCa): CLIP-style image/text towers plus a
    multimodal transformer decoder that autoregressively captions images.

    `forward` returns both the contrastive latents (with logit scale) and the
    decoder logits/labels for the captioning loss; `generate` produces caption
    token ids via sampling (top-k / top-p) or group beam search.
    """

    def __init__(
        self,
        embed_dim,
        multimodal_cfg: MultimodalCfg,
        text_cfg: CLIPTextCfg,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        pad_id: int = 0,
    ):
        super().__init__()
        # Accept plain dicts (e.g. parsed JSON configs) as well as dataclasses.
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # FIX: the original conditional selected text_cfg.vocab_size in BOTH
        # branches (the hasattr(text_cfg, "hf_model_name") arm was dead code),
        # so it is collapsed to the value it always produced.
        vocab_size = text_cfg.vocab_size

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # The decoder projects to vocab_size to produce per-token logits.
        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # Learnable contrastive temperature, initialized to ln(1/0.07) as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.pad_id = pad_id

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on all three towers."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize=True):
        """Return (pooled image latent, per-token image embeddings)."""
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize=True, embed_cls=True):
        """Return (pooled text latent, per-token text embeddings).

        With embed_cls=True the final token position is dropped so the text
        tower can place its CLS embedding there — TODO confirm against the
        text tower implementation.
        """
        text = text[:, :-1] if embed_cls else text
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize=True):
        """Pooled (optionally L2-normalized) image features only."""
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize=True, embed_cls=True):
        """Pooled (optionally L2-normalized) text features only."""
        text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
        return text_latent

    def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
        """Run both towers plus the decoder.

        Precomputed `image_latent`/`image_embs` may be supplied to skip
        re-encoding the image (used by `generate` inside its decoding loop).
        Returns a dict with contrastive features, decoder logits, the label
        token ids, and the exponentiated logit scale.
        """
        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        # Labels are the trailing slice of `text` whose length matches the
        # decoder's token sequence (positions dropped by _encode_text excluded).
        labels = text[:, -token_embs.shape[1]:]

        logits = self.text_decoder(image_embs, token_embs)
        return {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "labels": labels,
            "logit_scale": self.logit_scale.exp()
        }

    def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,
        top_k=1,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False
    ):
        """Autoregressively generate caption token ids for `image`.

        `generation_type` selects "beam_search" (delegates to
        `_generate_beamsearch`) or sampling with a "top_p"/"top_k" warper.
        With fixed_output_length=True, outputs are forced/padded to exactly
        `seq_len` tokens. Requires the `transformers` package for the logits
        processors and stopping criteria.
        """
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"

        with torch.no_grad():
            # Defaults are the OpenAI CLIP BPE special-token ids.
            sot_token_id = 49406 if sot_token_id is None else sot_token_id
            eos_token_id = 49407 if eos_token_id is None else eos_token_id
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]

            stopping_criteria = StoppingCriteriaList(
                stopping_criteria
            )

            device = image.device

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    # Right-pad short beam-search outputs up to seq_len.
                    return torch.cat(
                        (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            # Encode the image once; the latent/embeddings are reused for
            # every decoding step below.
            image_latent, image_embs = self._encode_image(image)

            if text is None:
                # Start every sequence from the start-of-text token.
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                # Promote a single unbatched prompt to batch size 1.
                text = text[None, :]

            cur_len = text.shape[1]
            self.eval()
            out = text

            while True:
                x = out[:, -max_seq_len:]  # decoder context window
                cur_len = x.shape[1]
                logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
                # Rows already terminated (last token is eos or pad) keep
                # emitting pad tokens from here on.
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    if (cur_len + 1 == seq_len):
                        # Force eos on the final step so all sequences terminate.
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if stopping_criteria(out, None):
                    break

            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out

    def _generate_beamsearch(
        self,
        image_inputs,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        logit_processor=None,
        logit_warper=None,  # NOTE(review): accepted but never used in this body
    ):
        """Diverse (group) beam search decoding, adapted from HuggingFace
        transformers' group_beam_search loop."""
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        # Expand the image batch so each beam scores its own copy.
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        # Every beam starts from a single start-of-text token.
        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )

        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # Initialize so only the first sub-beam of each group is live; otherwise
        # every beam in a group would produce the exact same token ranking.
        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # Tokens chosen at this step; filled in group by group below.
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # Beam-reordering indices for a KV cache; computed but not consumed
            # here since this model keeps no past key/values.
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # One forward pass scores all beams of all groups at once.
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                embed_cls=False,
                image_latent=image_latent,
                image_embs=image_embs
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # Flat indices of this group's beams across the whole batch.
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # Only the final position's logits matter for the next token.
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # Merge beam and vocab axes to rank all candidates jointly.
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                # Keep 2*group_size candidates so finished (eos) beams can be
                # replaced while still filling the group.
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                # Split flat candidate index into (beam index, token id).
                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # beam_indices is never assigned in this loop, so this is
                # always None here (kept for parity with the HF original).
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # Map group-local beam_idx back to global beam positions.
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']
|
|
|
|
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    """Assemble the model-input dict for one generation step.

    When `past` (cached key/values) is truthy only the newest token is fed.
    Position ids are derived from `attention_mask` when one is supplied and
    no explicit `position_ids` kwarg is given; otherwise they are None.
    """
    if past:
        # With a cache, only the most recent token needs to be processed.
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is None or position_ids is not None:
        # Nothing to derive positions from, or caller-supplied ids are
        # deliberately discarded (mirrors the upstream HF helper).
        position_ids = None
    else:
        # Cumulative count of attended positions; masked-out slots pinned to 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
|
|