# This file has a copy of the dam library. Please refer to the dam library for documents and licenses. Licenses in this file are carried over from files in the dam library.
# NOTE: several names below are imported more than once; Python keeps the
# binding from the *last* import. In particular ``logging`` ends up bound to
# ``transformers.utils.logging`` (imported further down), not the stdlib module.
import base64
import dataclasses
import logging
import math
import os
import os.path as osp
import re
import string
import tempfile
import warnings
from abc import ABC
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum, auto
from io import BytesIO
from shutil import copyfile
from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate.hooks import add_hook_to_module
from huggingface_hub import repo_exists, snapshot_download
from huggingface_hub.utils import HFValidationError
from PIL import Image
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    StoppingCriteria,
    TextIteratorStreamer,
)
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.convert_slow_tokenizer import import_protobuf
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import (
    convert_to_rgb,
    get_channel_dimension_axis,
    get_resize_output_image_size,
    normalize,
    pad,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
)
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutputWithPast
from transformers.modeling_utils import ContextManagers, PreTrainedModel, no_init_weights
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import (
    AddedToken,
    PaddingStrategy,
    PreTokenizedInput,
    TextInput,
    TruncationStrategy,
)
# This import rebinds ``logging`` (see the NOTE at the top of the file).
from transformers.utils import (
    ModelOutput,
    TensorType,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_tf_available,
    is_torch_available,
    is_torchvision_available,
    is_vision_available,
    logging,
    replace_return_docstrings,
    requires_backends,
)

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# This file is modified from https://github.com/haotian-liu/LLaVA/

# Controller/worker heartbeat timing carried over from LLaVA's serving code
# (units presumably seconds -- TODO confirm against the serving code).
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

# Directory used for log output.
LOGDIR = "."
# Model Constants
# Label value presumably used to mask positions out of the training loss.
IGNORE_INDEX = -100
# Sentinel token id that stands in for an image placeholder in tokenized prompts.
IMAGE_TOKEN_INDEX = -200
MASK_TOKEN_INDEX = -300
# NOTE(review): the placeholder-token strings below are empty. An empty token
# cannot be located with str.split / str.replace, so confirm the intended values.
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_IMAGE_PATCH_TOKEN = ""
DEFAULT_IM_START_TOKEN = ""
DEFAULT_IM_END_TOKEN = ""
IMAGE_PLACEHOLDER = ""


class LlavaConfig(PretrainedConfig):
    """Hugging Face config for the LLaVA-style vision-language model.

    Each ``*_cfg`` argument describes one sub-module (LLM, vision tower,
    multimodal projector, optional mask encoder, optional context provider).
    The remaining arguments are stored verbatim as attributes.

    NOTE(review): ``**kwargs`` is accepted but *not* forwarded to
    ``PretrainedConfig.__init__``. Any keyword not listed explicitly below is
    therefore discarded and not stored on the config. Confirm this is intended.
    """

    model_type = "llava"

    def __init__(
        self,
        llm_cfg=None,
        vision_tower_cfg=None,
        mm_projector_cfg=None,
        mask_encoder_cfg=None,
        context_provider_cfg=None,
        architectures=None,
        resume_path=None,
        hidden_size=None,
        mm_hidden_size=None,
        image_aspect_ratio=None,
        num_video_frames=None,
        mm_vision_select_layer=None,
        mm_vision_select_feature=None,
        mm_use_im_start_end=False,
        mm_use_im_patch_token=True,
        mm_projector_lr=None,
        vision_resolution=None,
        interpolate_mode=None,
        s2=None,
        s2_scales=None,
        s2_max_split_size=None,
        **kwargs,
    ):
        super().__init__()
        self.architectures = architectures

        # Sub-module configurations.
        self.llm_cfg = llm_cfg
        self.vision_tower_cfg = vision_tower_cfg
        self.mm_projector_cfg = mm_projector_cfg
        self.mask_encoder_cfg = mask_encoder_cfg
        self.context_provider_cfg = context_provider_cfg
        self.resume_path = resume_path

        # Model dimensions and multimodal options.
        self.hidden_size = hidden_size
        self.mm_hidden_size = mm_hidden_size
        self.image_aspect_ratio = image_aspect_ratio
        self.num_video_frames = num_video_frames
        self.mm_vision_select_layer = mm_vision_select_layer
        self.mm_vision_select_feature = mm_vision_select_feature
        self.mm_use_im_start_end = mm_use_im_start_end
        self.mm_use_im_patch_token = mm_use_im_patch_token
        self.mm_projector_lr = mm_projector_lr
        self.vision_resolution = vision_resolution
        self.interpolate_mode = interpolate_mode

        # Multi-scale (S2) settings.
        self.s2 = s2
        self.s2_scales = s2_scales
        self.s2_max_split_size = s2_max_split_size


# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# This file is modified from https://github.com/haotian-liu/LLaVA/


class SeparatorStyle(Enum):
    """Prompt layouts understood by ``Conversation.get_prompt``."""

    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    MISTRAL = auto()
    LLAMA_3 = auto()


def _conversation_display_size(image):
    """Return the ``(width, height)`` a conversation image is shown at.

    The aspect ratio is kept. The shorter edge is capped at 400 px and is never
    enlarged, and the longer edge is capped at roughly 800 px. Both edges are
    truncated to ints.
    """
    max_hw, min_hw = max(image.size), min(image.size)
    aspect_ratio = max_hw / min_hw
    max_len, min_len = 800, 400
    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
    longest_edge = int(shortest_edge * aspect_ratio)
    width, height = image.size
    if height > width:
        return shortest_edge, longest_edge
    return longest_edge, shortest_edge


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Attributes:
        system: Text placed at the start of every rendered prompt.
        roles: ``(user_role, assistant_role)`` labels. ``roles[0]`` must
            author the first turn in the LLAMA_2 / MISTRAL styles.
        messages: ``[role, message]`` turns. A message may be a tuple
            ``(text, image, image_process_mode)`` for a turn carrying an image.
            The module-level templates store ``messages`` as a tuple, so call
            ``copy()`` (which converts it to a list) before ``append_message``.
        offset: Number of leading turns skipped by ``get_images`` and
            ``to_gradio_chatbot``.
        sep_style, sep, sep2: How ``get_prompt`` joins turns.
        version: Template identifier. If it contains ``"mmtag"``, the first
            image turn is rendered differently in ``get_prompt``.
        skip_next: Flag stored on the instance; not read by this class.
    """

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: Optional[str] = None
    version: str = "Unknown"
    skip_next: bool = False

    def get_prompt(self):
        """Render the conversation into one prompt string according to ``sep_style``.

        If the first turn carries an image tuple, only its text is kept.
        ``self.messages`` itself is not modified.

        Raises:
            ValueError: if ``sep_style`` is not a known ``SeparatorStyle``.
        """
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            # Work on copies so the stored history is left untouched.
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            # NOTE(review): both literals in this replace() are empty, so the
            # call is a no-op -- confirm the intended placeholder token.
            init_msg = init_msg[0].replace("", "").strip()
            if "mmtag" in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], ""))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    # An empty message leaves an open slot for the model to fill.
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            # User turns end with ``sep``; assistant turns end with ``sep2``.
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.LLAMA_3:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif (
            self.sep_style == SeparatorStyle.LLAMA_2
            or self.sep_style == SeparatorStyle.MISTRAL
        ):
            if self.sep_style == SeparatorStyle.LLAMA_2:

                def wrap_sys(msg):
                    return f"<>\n{msg}\n<>\n\n"

            else:

                def wrap_sys(msg):
                    return f"{msg}" + ("\n" if msg else "")

            def wrap_inst(msg):
                return f"[INST] {msg} [/INST]"

            ret = ""
            if self.sep_style == SeparatorStyle.MISTRAL:
                ret += ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        # The system prompt is folded into the first user turn.
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        if self.sep_style == SeparatorStyle.LLAMA_2:
                            ret += " " + message + " " + self.sep2
                        else:
                            ret += message + self.sep2
                else:
                    ret += ""
            # str.lstrip removes any leading *characters* found in ``sep``.
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append a ``[role, message]`` turn. ``messages`` must be a list."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect the images attached to user turns (from ``offset`` onward).

        Each image is adjusted according to its ``image_process_mode``
        (``"Pad"``, ``"Default"``, ``"Crop"`` or ``"Resize"``), then downscaled
        for display.

        Args:
            return_pil: If True, return PIL images. Otherwise return
                base64-encoded PNG strings.

        Raises:
            ValueError: on an unknown ``image_process_mode``.
        """
        images = []
        for i, (_role, msg) in enumerate(self.messages[self.offset :]):
            # Only user turns (even positions after ``offset``) can carry images.
            if i % 2 != 0 or type(msg) is not tuple:
                continue
            msg, image, image_process_mode = msg
            if image_process_mode == "Pad":
                image = expand2square(image, (122, 116, 104))
            elif image_process_mode in ["Default", "Crop"]:
                pass
            elif image_process_mode == "Resize":
                image = image.resize((336, 336))
            else:
                raise ValueError(f"Invalid image_process_mode: {image_process_mode}")

            target_size = _conversation_display_size(image)
            # Resize only when the longest edge actually changes.
            if max(target_size) != max(image.size):
                image = image.resize(target_size)

            if return_pil:
                images.append(image)
            else:
                buffered = BytesIO()
                image.save(buffered, format="PNG")
                images.append(base64.b64encode(buffered.getvalue()).decode())
        return images

    def to_gradio_chatbot(self):
        """Convert the history (from ``offset`` onward) to Gradio ``[user, bot]`` pairs."""
        ret = []
        for i, (_role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, _image_process_mode = msg
                    image = image.resize(_conversation_display_size(image))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    # NOTE(review): img_b64_str is not interpolated into img_str,
                    # so the encoded image is currently unused -- confirm the
                    # intended markup.
                    img_str = f'user upload image'
                    msg = img_str + msg.replace("", "").strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                # Assistant turn: fill the reply slot of the preceding pair.
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy whose ``messages`` is a fresh list of ``[role, message]`` lists.

        ``skip_next`` is not copied; it takes its default value.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )

    def dict(self):
        """Return a plain-dict summary. Image tuples are reduced to their text."""
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [
                    [x, y[0] if type(y) is tuple else y] for x, y in self.messages
                ],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)

# kentang-mit@: This conversation template is designed for SFT on VFLAN.
conv_vicuna_v1_nosys = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="v1_nosys",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)

conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="",
)

conv_mistral = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="mistral",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MISTRAL,
    sep="",
    sep2="",
)

conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="",
)

# NOTE(review): ChatML usually puts a newline after "<|im_start|>system";
# this literal has a space there -- confirm the intended whitespace.
conv_mpt = Conversation(
    system="""<|im_start|>system A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: visual content.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)

conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: visual content.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
    version="v1_mmtag",
)

hermes_2 = Conversation(
    system="<|im_start|>system\nAnswer the questions.",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
    messages=(),
    offset=0,
    version="hermes-2",
)

# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
# NOTE(review): the second (assistant) role uses the "system" header; checkpoints
# trained with this template may depend on it, so confirm before changing.
llama_3_chat = Conversation(
    system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=(
        "<|start_header_id|>user<|end_header_id|>\n\n",
        "<|start_header_id|>system<|end_header_id|>\n\n",
    ),
    version="llama_v3",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep="<|end_of_text|>",
)

default_conversation = conv_vicuna_v1
# Lookup table from template name to prototype Conversation; call .copy() before use.
conv_templates = {
    "default": conv_vicuna_v0,
    "hermes-2": hermes_2,
    "llama_3": llama_3_chat,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "vicuna_v1_nosys": conv_vicuna_v1_nosys,
    "llama_2": conv_llama_2,
    "mistral": conv_mistral,
    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,
    "mpt": conv_mpt,
}

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


def _bgr_frame_to_pil(frame):
    """Convert an OpenCV BGR frame array into an RGB PIL image."""
    import cv2

    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))


def _read_sampled_frames(vidcap, frame_count, num_frames):
    """Read ``num_frames`` frames at evenly spaced indices in ``[0, frame_count - 2]``.

    An index that ``np.linspace`` repeats (possible when ``frame_count`` is
    close to ``num_frames``) yields one copy of that frame per occurrence.
    Returns the frames, or ``None`` if the stream ends before all are read.
    """
    frame_indices = np.linspace(0, frame_count - 2, num_frames, dtype=int)
    images = []
    count = 0
    while True:
        success, frame = vidcap.read()
        if not success:
            return None
        repeats = int(np.count_nonzero(frame_indices == count))
        if repeats:
            images.extend([_bgr_frame_to_pil(frame)] * repeats)
            if len(images) >= num_frames:
                return images
        count += 1


def _read_frames_left_padded(vidcap, num_frames):
    """Read every frame, then left-pad with black frames up to ``num_frames``.

    The padding frames use the size of the last frame read. Returns ``None``
    if not even the first frame can be read.
    """
    images = []
    while True:
        success, frame = vidcap.read()
        if not success:
            break
        images.append(_bgr_frame_to_pil(frame))
    if not images:
        return None
    num_padding = max(num_frames - len(images), 0)
    width, height = images[-1].size
    print("padding frames:", num_padding)
    return [Image.new("RGB", (width, height))] * num_padding + images


def get_frame_from_vcap(vidcap, num_frames=10, fps=None, frame_count=None):
    """Sample ``num_frames`` RGB PIL frames from an opened ``cv2.VideoCapture``.

    Args:
        vidcap: The capture to read from.
        num_frames: Number of frames to return.
        fps, frame_count: Stream metadata. If either is ``None``, both are
            re-read from ``vidcap``.

    Returns:
        A list of PIL images. If the metadata reports an empty stream, or a
        single frame when more than one is requested, the list holds
        ``num_frames`` black 720x720 placeholders. For short videos
        (``frame_count < num_frames``), all frames are read and black frames
        are prepended.

    Raises:
        ValueError: if the stream yields no usable frames.
    """
    import cv2

    if fps is None or frame_count is None:
        # if one of fps or frame_count is None, still recompute
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    if fps == 0 or frame_count == 0:
        print("Video file not found. return empty images.")
        return [Image.new("RGB", (720, 720))] * num_frames

    frame_interval = frame_count // num_frames
    if frame_interval == 0 and frame_count <= 1:
        print("frame_interval is equal to 0. return empty image.")
        return [Image.new("RGB", (720, 720))] * num_frames

    if frame_count >= num_frames:
        images = _read_sampled_frames(vidcap, frame_count, num_frames)
    else:
        images = _read_frames_left_padded(vidcap, num_frames)
    if images is None:
        raise ValueError("Did not find enough frames in the video. return empty image.")
    return images


def _frames_from_video_path(path, frames, fps, frame_count):
    """Open ``path`` with OpenCV, sample frames, and always release the capture."""
    import cv2

    vidcap = cv2.VideoCapture(path)
    try:
        return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
    finally:
        vidcap.release()


def opencv_extract_frames(vpath_or_bytesio, frames=6, fps=None, frame_count=None):
    """
    Extract frames from a video using OpenCV.

    Args:
        vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
        frames (int): Number of frames to extract from the video.
        fps, frame_count: Optional stream metadata forwarded to ``get_frame_from_vcap``.

    Returns:
        list: List of PIL Images extracted from the video.

    Raises:
        NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
    """
    if isinstance(vpath_or_bytesio, str):
        return _frames_from_video_path(vpath_or_bytesio, frames, fps, frame_count)
    if isinstance(vpath_or_bytesio, BytesIO):
        # assuming mp4
        # NOTE: on Windows a NamedTemporaryFile(delete=True) may not be
        # reopenable by name while it is still open.
        with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
            temp_video.write(vpath_or_bytesio.read())
            # OpenCV reopens the file by name, so buffered bytes must hit disk first.
            temp_video.flush()
            return _frames_from_video_path(temp_video.name, frames, fps, frame_count)
    raise NotImplementedError(type(vpath_or_bytesio))


def load_image_from_base64(image):
    """Decode a base64-encoded image string into a PIL image."""
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    """
    Expand the given PIL image to a square shape by adding padding.

    Parameters:
    - pil_img: The PIL image to be expanded.
    - background_color: The color of the padding to be added. For mode "L"
      images only its first component is used.

    Returns:
    - The expanded PIL image.

    If the image is already square, it is returned as is.
    If the image is wider than it is tall, padding is added to the top and bottom.
    If the image is taller than it is wide, padding is added to the left and right.
    """
    width, height = pil_img.size
    if pil_img.mode == "L":
        background_color = background_color[0]
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def process_image(image_file, data_args, image_folder, pil_preprocess_fn=None):
    """Load an image and turn it into the vision processor's pixel tensor.

    Args:
        image_file: A path (joined with ``image_folder`` when that is not
            ``None``) or an already-loaded PIL image.
        data_args: Object providing ``image_processor`` and ``image_aspect_ratio``.
        image_folder: Optional directory prefix for path inputs.
        pil_preprocess_fn: Optional callable applied to the RGB PIL image. It
            may return ``image`` or ``(image, info)``.

    Returns:
        ``pixel_values[0]`` from the processor, or ``(pixel_values[0], info)``
        when ``pil_preprocess_fn`` returned extra info.
    """
    processor = data_args.image_processor
    if isinstance(image_file, str):
        if image_folder is not None:
            image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
        else:
            image = Image.open(image_file).convert("RGB")
    else:
        # Already an in-memory PIL image.
        image = image_file.convert("RGB")

    info = None
    if pil_preprocess_fn is not None:
        image = pil_preprocess_fn(image)
        if isinstance(image, tuple):
            image, info = image

    if data_args.image_aspect_ratio == "resize":
        if hasattr(data_args.image_processor, "crop_size"):
            # CLIP vision tower
            crop_size = data_args.image_processor.crop_size
        else:
            # SIGLIP vision tower
            assert hasattr(data_args.image_processor, "size")
            crop_size = data_args.image_processor.size
        # PIL's resize takes (width, height).
        image = image.resize((crop_size["width"], crop_size["height"]))
    elif data_args.image_aspect_ratio == "pad":
        # Pad with the processor's mean color so normalization maps padding to ~0.
        image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
    # Otherwise use the vision encoder's default behavior:
    # For CLIP, default is central crop
    # For Radio, default is central crop
    # For Siglip, default is resize
    # For InternVIT, default is resize
    image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
    if info is not None:
        return image, info
    return image


def process_images(images, image_processor, model_cfg):
    """Run ``process_image`` on each image and stack the results when shapes match.

    Side effect: sets ``model_cfg.image_processor = image_processor``.
    Returns a stacked tensor when all shapes agree, otherwise a list (an empty
    list for empty input).
    """
    model_cfg.image_processor = image_processor
    new_images = [process_image(image, model_cfg, None) for image in images]
    if new_images and all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


# Note that newer VILA codebase adds an lstrip option that defaults to False, and the functionality is the same by default
def tokenizer_image_token(
    prompt,
    tokenizer,
    image_token_index=IMAGE_TOKEN_INDEX,
    return_tensors=None,
    image_token="<image>",
):
    """Tokenize ``prompt``, replacing each ``image_token`` with ``image_token_index``.

    Args:
        prompt: Text that may contain ``image_token`` placeholders.
        tokenizer: Callable returning an object with ``input_ids``; must expose
            ``bos_token_id``.
        image_token_index: Id inserted in place of each placeholder.
        return_tensors: ``None`` for a Python list, or ``"pt"`` for a 1-D
            ``torch.long`` tensor.
        image_token: Non-empty placeholder text (the LLaVA convention is ``"<image>"``).

    Raises:
        ValueError: for any other ``return_tensors`` value, or if
            ``image_token`` is empty (raised by ``str.split``).
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(image_token)]

    def insert_separator(X, sep):
        # [a, b, c] -> [a, sep, b, sep, c]
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if (
        len(prompt_chunks) > 0
        and len(prompt_chunks[0]) > 0
        and prompt_chunks[0][0] == tokenizer.bos_token_id
    ):
        # Keep one BOS up front. x[offset:] below drops the first id of every
        # chunk, which assumes each chunk also starts with BOS.
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    # The separator carries offset + 1 image ids, so slicing leaves exactly one.
    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids


def is_gemma_tokenizer(tokenizer):
    """Return True if the tokenizer's class name contains "gemma" (case-insensitive)."""
    return "gemma" in tokenizer.__class__.__name__.lower()


def get_model_name_from_path(model_path):
    """Derive a display name from a model path.

    ``.../name/checkpoint-N`` becomes ``"name_checkpoint-N"``; otherwise the
    last path component is used. An empty or ``None`` path yields
    ``"describe_anything_model"``.
    """
    if not model_path:
        return "describe_anything_model"
    model_paths = model_path.strip("/").split("/")
    if model_paths[-1].startswith("checkpoint-") and len(model_paths) > 1:
        return model_paths[-2] + "_" + model_paths[-1]
    return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once every sequence in the batch has produced a keyword.

    A sequence matches when its trailing token ids equal a keyword's ids, or
    when the decoded text of its last few generated tokens contains a keyword.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS so ids can match mid-sequence.
            if (
                len(cur_keyword_ids) > 1
                and cur_keyword_ids[0] == tokenizer.bos_token_id
            ):
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        # Prompt length; tokens past this index were generated.
        self.start_len = input_ids.shape[1]

    def call_for_batch(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return True if the single sequence in ``output_ids`` (shape ``(1, seq_len)``) hit a keyword."""
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [
            keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
        ]
        seq_len = output_ids.shape[1]
        for keyword_id in self.keyword_ids:
            keyword_len = keyword_id.shape[0]
            # Comparing tensors of different lengths would raise, so skip
            # keywords that cannot fit in the sequence.
            if 0 < keyword_len <= seq_len and (
                output_ids[0, -keyword_len:] == keyword_id
            ).all():
                return True
        if offset <= 0:
            # Nothing generated yet: slicing with -0 would decode the whole prompt.
            return False
        outputs = self.tokenizer.batch_decode(
            output_ids[:, -offset:], skip_special_tokens=True
        )[0]
        return any(keyword in outputs for keyword in self.keywords)

    def __call__(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return True only when every sequence in the batch has hit a keyword."""
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)


# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# This file is modified from https://github.com/haotian-liu/LLaVA/


def get_model_config(config):
    """Resolve where each VLM sub-module should be loaded from.

    For every key in ``default_keys`` the attribute on ``config`` is inspected:

    * ``dict`` / ``PretrainedConfig`` -> ``<root_path>/<key without "_cfg">``
    * ``str`` -> returned unchanged (a path or a type name)
    * ``None`` / missing -> ``None`` (the slot is always present)

    ``root_path`` is ``config._name_or_path`` when it looks set, otherwise
    ``config.resume_path``; when it is not a local path but is a valid
    Hugging Face repo id, the repo is downloaded first.

    Returns:
        A list with exactly one entry per key, in ``default_keys`` order.

    Raises:
        ValueError: if a ``dict`` cfg is present but no root path is known, or
            if a cfg has an unsupported type.
    """
    # `mask_encoder_cfg` and `context_provider_cfg` are optional
    default_keys = [
        "llm_cfg",
        "vision_tower_cfg",
        "mm_projector_cfg",
        "mask_encoder_cfg",
        "context_provider_cfg",
    ]

    if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
        root_path = config._name_or_path
    else:
        root_path = config.resume_path

    # download from huggingface
    if root_path is not None and not osp.exists(root_path):
        try:
            valid_hf_repo = repo_exists(root_path)
        except HFValidationError:
            valid_hf_repo = False
        if valid_hf_repo:
            root_path = snapshot_download(root_path)

    return_list = []
    for key in default_keys:
        cfg = getattr(config, key, None)
        # key[:-4] strips the "_cfg" suffix, e.g. "llm_cfg" -> "llm".
        if isinstance(cfg, dict):
            try:
                return_list.append(os.path.join(root_path, key[:-4]))
            except TypeError as err:  # root_path is None
                raise ValueError(
                    f"Cannot find resume path in config for {key}!"
                ) from err
        elif isinstance(cfg, PretrainedConfig):
            return_list.append(os.path.join(root_path, key[:-4]))
        elif isinstance(cfg, str):
            return_list.append(cfg)
        elif cfg is None:
            # We still return even if the cfg is None or does not exist
            return_list.append(cfg)
        else:
            # Previously such entries were silently dropped, which made the
            # caller's 5-way unpacking fail with an unrelated message.
            raise ValueError(
                f"Unsupported type {type(cfg).__name__} for `{key}` in the config."
            )
    return return_list


def is_mm_model(model_path):
    """
    Check if the model at the given path is a visual language model.

    Args:
        model_path (str): The path to the model.

    Returns:
        bool: True if the model is an MM model, False otherwise.
    """
    config = AutoConfig.from_pretrained(model_path)
    # `architectures` may be None in configs that never recorded it.
    architectures = config.architectures or []
    return any("llava" in architecture.lower() for architecture in architectures)


def auto_upgrade(config):
    """Interactively upgrade a LLaVA v0 checkpoint at path ``config`` in place.

    Only acts when the path mentions "llava" but the stored ``model_type`` does
    not. Asks for confirmation on stdin and terminates the process with exit
    status 1 if the user declines.
    """
    import sys

    cfg = AutoConfig.from_pretrained(config)
    if "llava" in config and "llava" not in cfg.model_type:
        assert cfg.model_type == "llama"
        print(
            "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
        )
        print(
            "You must upgrade the checkpoint to the new code base (this can be done automatically)."
        )
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = "LlavaLlamaForCausalLM"
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            # `exit` is injected by the `site` module and is not always
            # available; `sys.exit` raises the same SystemExit(1).
            sys.exit(1)


# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# TODO decide whether should we use metaclass
class LlavaMetaModel(ABC):
    """Mixin that owns the VLM sub-modules (llm, vision tower, projector, context provider)."""

    @staticmethod
    def _prepare_model_cfgs(config):
        """Default ``config.model_dtype`` and resolve the five sub-module configs.

        Returns:
            ``(llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg,
            context_provider_cfg)`` as produced by ``get_model_config``.

        Raises:
            ValueError: if any of the first three (required) configs is missing.
        """
        if not hasattr(config, "model_dtype"):
            warnings.warn(
                "model_dtype not found in config, defaulting to torch.float16."
            )
            config.model_dtype = "torch.float16"

        # Only the first three are required. Others are optional.
        (
            llm_cfg,
            vision_tower_cfg,
            mm_projector_cfg,
            mask_encoder_cfg,
            context_provider_cfg,
        ) = get_model_config(config)
        if llm_cfg is None or vision_tower_cfg is None or mm_projector_cfg is None:
            raise ValueError(
                "`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config."
            )
        return (
            llm_cfg,
            vision_tower_cfg,
            mm_projector_cfg,
            mask_encoder_cfg,
            context_provider_cfg,
        )

    def init_vlm(self, config: PretrainedConfig = None, *args, **kwargs):
        """Build all sub-modules from ``config`` unless any already exists.

        Extra ``*args``/``**kwargs`` are forwarded to ``build_llm_and_tokenizer``.
        A configured mask encoder is currently ignored here.
        """
        # TODO(ligeng): figure out how from_config and from_pretrained works in HF implementation.
        if (
            hasattr(self, "llm")
            or hasattr(self, "vision_tower")
            or hasattr(self, "mm_projector")
        ):
            # already initialized, skipped
            return

        (
            llm_cfg,
            vision_tower_cfg,
            mm_projector_cfg,
            _mask_encoder_cfg,
            context_provider_cfg,
        ) = self._prepare_model_cfgs(config)

        self.llm, self.tokenizer = build_llm_and_tokenizer(
            llm_cfg, config, *args, **kwargs
        )
        self.vision_tower = build_vision_tower(vision_tower_cfg, config)
        self.mm_projector = build_mm_projector(mm_projector_cfg, config)
        self.context_provider = (
            build_context_provider(context_provider_cfg, config)
            if context_provider_cfg is not None
            else None
        )

        self.post_config()
        self.is_loaded = True

        assert (
            self.llm is not None
            or self.vision_tower is not None
            or self.mm_projector is not None
        ), "At least one of the components must be instantiated."

    @classmethod
    def load_from_config(cls, model_path_or_config, *args, **kwargs):
        pass  # FIXME we will use this function to load model in the future

    @classmethod
    def load_pretrained(cls, model_path_or_config, *args, **kwargs):
        """Instantiate ``cls`` without weight init and build its sub-modules.

        Args:
            model_path_or_config: a checkpoint path/repo id or a ``LlavaConfig``.
                Ignored when a ``config=`` keyword is supplied.

        Returns:
            The model instance (returned early if the constructor already
            loaded everything).

        Raises:
            NotImplementedError: for an unsupported argument type, or when the
                config requests a mask encoder.
            ValueError: if required sub-module configs are missing.
        """
        config = kwargs.pop("config", None)
        if config is None:
            if isinstance(model_path_or_config, str):
                config = AutoConfig.from_pretrained(
                    model_path_or_config, trust_remote_code=True
                )
            elif isinstance(model_path_or_config, LlavaConfig):
                config = model_path_or_config
            else:
                raise NotImplementedError(
                    f"wrong type, {type(model_path_or_config)} "
                    f"{isinstance(model_path_or_config, LlavaConfig)}"
                )

        (
            llm_cfg,
            vision_tower_cfg,
            mm_projector_cfg,
            mask_encoder_cfg,
            context_provider_cfg,
        ) = cls._prepare_model_cfgs(config)

        with ContextManagers([no_init_weights(_enable=True)]):
            vlm = cls(config, *args, **kwargs)

        already_built = (
            hasattr(vlm, "llm")
            or hasattr(vlm, "vision_tower")
            or hasattr(vlm, "mm_projector")
        )
        if already_built and getattr(vlm, "is_loaded", False):
            return vlm

        # Fail before the expensive builds below.
        if mask_encoder_cfg is not None:
            raise NotImplementedError("Mask encoder is not supported.")

        vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(
            llm_cfg, config, *args, **kwargs
        )
        vlm.vision_tower = build_vision_tower(vision_tower_cfg, config)
        vlm.mm_projector = build_mm_projector(mm_projector_cfg, config)
        vlm.context_provider = (
            build_context_provider(context_provider_cfg, config)
            if context_provider_cfg is not None
            else None
        )

        # This is a classmethod: the instance is `vlm` (there is no `self`).
        vlm.post_config()
        vlm.is_loaded = True

        # FIXME(ligeng, yunhao): llm should never be none here.
        assert (
            vlm.llm is not None
            or vlm.vision_tower is not None
            or vlm.mm_projector is not None
        ), "At least one of the components must be instantiated."
        return vlm

    @staticmethod
    def _sub_state_dict(state_dict, name, prefix):
        """Keep entries whose key contains ``name``; strip keys up to the last ``prefix``."""
        return OrderedDict(
            (k.split(prefix)[-1], v) for k, v in state_dict.items() if name in k
        )

    # FIXME we will use this function to save the model in the future
    def save_pretrained(self, output_dir, state_dict=None):
        """Save each sub-module into its own sub-directory of ``output_dir``.

        Writes ``llm/`` (plus tokenizer), ``vision_tower/`` (skipped for RADIO
        towers), ``mm_projector/`` and ``context_provider/`` when present, then
        points the top-level config at them and saves it to ``output_dir``.
        """
        if state_dict is None:
            # other wise fetch from deepspeed
            # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
            state_dict = self.state_dict()

        if getattr(self, "tokenizer", None):
            self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

        if self.get_llm() is not None:
            print(f"saving llm to {osp.join(output_dir, 'llm')}")
            self.llm.config._name_or_path = osp.join(output_dir, "llm")
            llm_state_dict = self._sub_state_dict(state_dict, "llm", "llm.")
            self.llm.save_pretrained(
                os.path.join(output_dir, "llm"), state_dict=llm_state_dict
            )
            self.config.llm_cfg = self.llm.config

        vision_tower = self.get_vision_tower()
        if (
            vision_tower is not None
            and "radio" not in vision_tower.__class__.__name__.lower()
        ):
            print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
            self.vision_tower.config._name_or_path = osp.join(
                output_dir, "vision_tower"
            )
            vision_tower_state_dict = self._sub_state_dict(
                state_dict, "vision_tower", "vision_tower.vision_tower."
            )
            self.vision_tower.vision_tower.save_pretrained(
                os.path.join(output_dir, "vision_tower"),
                state_dict=vision_tower_state_dict,
            )
            self.vision_tower.image_processor.save_pretrained(
                os.path.join(output_dir, "vision_tower")
            )
            self.config.vision_tower_cfg = self.vision_tower.config
            if hasattr(self.config.vision_tower_cfg, "auto_map"):
                delattr(self.config.vision_tower_cfg, "auto_map")

        if self.get_mm_projector() is not None:
            print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
            self.mm_projector.config._name_or_path = osp.join(
                output_dir, "mm_projector"
            )
            mm_projector_state_dict = self._sub_state_dict(
                state_dict, "mm_projector", "mm_projector."
            )
            self.mm_projector.save_pretrained(
                os.path.join(output_dir, "mm_projector"),
                state_dict=mm_projector_state_dict,
            )
            self.config.mm_projector_cfg = self.mm_projector.config

        if self.get_context_provider() is not None:
            print(
                f"saving context_provider to {osp.join(output_dir, 'context_provider')}"
            )
            self.context_provider.config._name_or_path = osp.join(
                output_dir, "context_provider"
            )
            context_provider_state_dict = self._sub_state_dict(
                state_dict, "context_provider", "context_provider."
            )
            self.context_provider.save_pretrained(
                os.path.join(output_dir, "context_provider"),
                state_dict=context_provider_state_dict,
            )
            self.config.context_provider_cfg = self.context_provider.config

        # update and save top-level config
        self.config._name_or_path = output_dir
        self.config.architectures = [self.__class__.__name__]
        self.config.save_pretrained(output_dir)

    def get_llm(self):
        """Return the language model (unwrapping a single-element list), or None."""
        llm = getattr(self, "llm", None)
        if isinstance(llm, list):
            llm = llm[0]
        return llm

    def get_lm_head(self):
        """Return the llm's ``lm_head`` attribute, or None."""
        return getattr(self.get_llm(), "lm_head", None)

    def get_vision_tower(self):
        """Return the vision tower (unwrapping a single-element list), or None."""
        vision_tower = getattr(self, "vision_tower", None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    def get_mm_projector(self):
        """Return the multimodal projector (unwrapping a single-element list), or None."""
        mm_projector = getattr(self, "mm_projector", None)
        if isinstance(mm_projector, list):
            mm_projector = mm_projector[0]
        return mm_projector

    def get_context_provider(self):
        """Return the context provider, or None."""
        return getattr(self, "context_provider", None)

    def post_config(self):
        """Sync ``training`` from the llm and fill any missing ``*_cfg`` from the built modules."""
        self.training = self.get_llm().training
        # configuration
        if getattr(self.config, "llm_cfg", None) is None:
            self.config.llm_cfg = self.llm.config
        if getattr(self.config, "vision_tower_cfg", None) is None:
            self.config.vision_tower_cfg = self.vision_tower.config
        if getattr(self.config, "mm_projector_cfg", None) is None:
            self.config.mm_projector_cfg = self.mm_projector.config
        context_provider = self.get_context_provider()
        if (
            getattr(self.config, "context_provider_cfg", None) is None
            and context_provider is not None
        ):
            self.config.context_provider_cfg = context_provider.config

    def freezed_module_patch(self):
        """
        Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules.
        """
        if self.training:
            if self.get_llm() and not getattr(
                self.config, "tune_language_model", False
            ):
                # `logging` in this module is transformers.utils.logging (see the
                # imports), which has no module-level `warning`; use `warnings`.
                warnings.warn(
                    "Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations."
                )
            if self.get_vision_tower() and not getattr(
                self.config, "tune_vision_tower", False
            ):
                self.get_vision_tower().eval()
            if self.get_mm_projector() and not getattr(
                self.config, "tune_mm_projector", False
            ):
                self.get_mm_projector().eval()
            if self.get_context_provider() and not getattr(
                self.config, "tune_context_provider", False
            ):
                self.get_context_provider().eval()

    def encode_images(self, images):
        """Run images through the vision tower, then the multimodal projector."""
        image_features = self.get_vision_tower()(images)
        image_features = self.get_mm_projector()(image_features)
        return image_features

    def encode_images_with_context(self, images):
        """Encode 8-channel inputs (two 4-channel halves) using the context provider.

        Channels ``[:4]`` and ``[4:]`` are compared per sample; samples whose
        halves differ (or all samples, with ``treat_image_as_cimage``) are
        "cimages" and get their features replaced by the context provider output.
        Everything is then projected with the mm projector.
        """
        context_provider = self.get_context_provider()
        # A sample is a cimage (image with context) when its two 4-channel
        # halves differ in at least one element; identical halves = plain image.
        cimage_mask = torch.any(
            (images[:, :4, ...] != images[:, 4:, ...]).flatten(start_dim=1), dim=1
        )

        if context_provider.treat_image_as_cimage:
            # If the context provider treats the image as cimage, then all images are cimage.
            cimage_mask[:] = True

        if context_provider.context_image_as_queries:
            # Swap the crop image and full image since the model uses the full image as queries by default
            images = torch.cat((images[:, 4:, ...], images[:, :4, ...]), dim=1)

        # Process the first 4 channels for all images: for image it's the image, for cimage it's the full image
        vision_tower = self.get_vision_tower()
        # Encode context images (full images)
        image_features = vision_tower(images[:, :4, ...]).to(self.device)
        # Each cimage has 8 channels (full and crop concatenated)
        cimage_concatenated = images[cimage_mask]
        cimage_full_features = image_features[cimage_mask]
        if context_provider.context_provider_type == "cross_attn_end_to_all":
            cimage_features = self.context_provider(
                cimage_full_features=cimage_full_features,
                cimage_concatenated=cimage_concatenated,
                vision_tower=vision_tower,
            ).to(self.device)
        elif context_provider.context_provider_type == "concat":
            # Full features of cimages are computed but not used.
            cimage_features = self.context_provider(
                cimage_concatenated=cimage_concatenated, vision_tower=vision_tower
            ).to(self.device)
        else:
            raise NotImplementedError(
                f"Context provider type {context_provider.context_provider_type} not implemented."
            )

        # Put cimage_features into image_features
        image_features[cimage_mask] = cimage_features

        # Project to the llm space
        image_features = self.get_mm_projector()(image_features)

        return image_features

    # @yunhao: is there a better way to handle function call and attributes for llm?
    # support beam search
    def _temporary_reorder_cache(self, past_key_values, sorted_idx):
        return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx)

    def get_input_embeddings(self):
        return self.get_llm().get_input_embeddings()

    def get_output_embeddings(self):
        return self.get_llm().get_output_embeddings()

    def resize_token_embeddings(self, embed_size):
        self.get_llm().resize_token_embeddings(embed_size)


class LlavaMetaForCausalLM(ABC):
    """This class is originally implemented by the LLaVA team and
    modified by Haotian Tang and Jason Lu based on Ji Lin's implementation
    to support multiple images and input packing."""

    # TODO move the forward function here if there is no need to override it
    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images
    ):
        """Splice image features into the token embeddings at IMAGE_TOKEN_INDEX positions.

        Returns ``(input_ids, position_ids, attention_mask, past_key_values,
        inputs_embeds, labels)``. On the text-only / single-token decode path
        the ids are passed through and ``inputs_embeds`` is None; otherwise
        ``input_ids`` is None and ``inputs_embeds`` is the padded, spliced batch.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if (
                past_key_values is not None
                and vision_tower is not None
                and images is not None
                and input_ids.shape[1] == 1
            ):
                # Incremental decoding: extend the mask to cover the cached keys.
                target_shape = past_key_values[-1][-1].shape[-2] + 1
                attention_mask = torch.cat(
                    (
                        attention_mask,
                        torch.ones(
                            (
                                attention_mask.shape[0],
                                target_shape - attention_mask.shape[1],
                            ),
                            dtype=attention_mask.dtype,
                            device=attention_mask.device,
                        ),
                    ),
                    dim=1,
                )
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
            return (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                None,
                labels,
            )

        # handle different image dtypes for packing
        if isinstance(images, list):
            images = torch.cat(images, dim=0)
        elif images.ndim == 5:  # batch_size x seq_len x image_channels
            images = images.flatten(0, 1)

        if getattr(self, "context_provider", None) is not None:
            image_features = self.encode_images_with_context(images)
        else:
            # Since we slice it with index below, turning it into a list splits things by the first index which does not result in data copy or degrade performance.
            # Example dimension: [1, 196, 2560]
            assert (
                images.shape[1] <= 4
            ), "images have more than 4 channels, but context provider is not included"
            image_features = self.encode_images(images).to(self.device)

        # Note (kentang-mit@): image start / end is not implemented here to support pretraining.
        # NOTE(review): "turn_mm_projector" looks like a typo of "tune_mm_projector";
        # kept as-is because fixing it would change which configs raise here.
        if getattr(self.config, "turn_mm_projector", False) and getattr(
            self.config, "mm_use_im_start_end", False
        ):
            raise NotImplementedError

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(
                0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
            )
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask
        input_ids_copy = input_ids.clone()
        # kentang-mit@: Otherwise tokenizer out of bounds. Embeddings of image tokens will not be used.
        input_ids_copy[input_ids_copy == IMAGE_TOKEN_INDEX] = 0
        input_embeds = self.llm.model.embed_tokens(input_ids_copy)

        input_ids = [
            cur_input_ids[cur_attention_mask]
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        input_embeds_1 = [
            cur_input_embeds[cur_attention_mask]
            for cur_input_embeds, cur_attention_mask in zip(
                input_embeds, attention_mask
            )
        ]
        labels = [
            cur_labels[cur_attention_mask]
            for cur_labels, cur_attention_mask in zip(labels, attention_mask)
        ]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        # kentang-mit@: If some part of the model is executed in the loop, the loop length needs to be a constant.
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = int((cur_input_ids == IMAGE_TOKEN_INDEX).sum())
            if num_images == 0:
                cur_image_features = image_features[0]
                cur_input_embeds_1 = input_embeds_1[batch_idx]
                # Concatenating an empty slice keeps image_features in the graph.
                cur_input_embeds = torch.cat(
                    [cur_input_embeds_1, cur_image_features[0:0]], dim=0
                )
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                # kentang-mit@: we do not have placeholder image for text-only data now.
                # cur_image_idx += 1
                continue

            cur_input_embeds = input_embeds_1[batch_idx]
            # Sentinels -1 and len() let consecutive pairs bound the text segments.
            image_token_indices = (
                [-1]
                + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
                + [cur_input_ids.shape[0]]
            )
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            cur_input_embeds_no_im = []
            for i in range(len(image_token_indices) - 1):
                start = image_token_indices[i] + 1
                end = image_token_indices[i + 1]
                cur_labels_noim.append(cur_labels[start:end])
                cur_input_embeds_no_im.append(cur_input_embeds[start:end])

            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )

            new_input_embeds.append(torch.cat(cur_new_input_embeds))
            new_labels.append(torch.cat(cur_new_labels))

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(
            self.llm.config, "tokenizer_model_max_length", None
        )
        if tokenizer_model_max_length is not None:
            if any(len(x) > tokenizer_model_max_length for x in new_input_embeds):
                warnings.warn("Inputs truncated!")
            new_input_embeds = [
                x[:tokenizer_model_max_length] for x in new_input_embeds
            ]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len),
            IGNORE_INDEX,
            dtype=new_labels[0].dtype,
            device=new_labels[0].device,
        )
        attention_mask = torch.zeros(
            (batch_size, max_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        position_ids = torch.zeros(
            (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
        )

        pad_left = (
            getattr(self.llm.config, "tokenizer_padding_side", "right") == "left"
        )
        for i, (cur_new_embed, cur_new_labels) in enumerate(
            zip(new_input_embeds, new_labels)
        ):
            cur_len = cur_new_embed.shape[0]
            padding = torch.zeros(
                (max_len - cur_len, cur_new_embed.shape[1]),
                dtype=cur_new_embed.dtype,
                device=cur_new_embed.device,
            )
            if pad_left:
                new_input_embeds_padded.append(
                    torch.cat((padding, cur_new_embed), dim=0)
                )
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(
                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                    )
            else:
                new_input_embeds_padded.append(
                    torch.cat((cur_new_embed, padding), dim=0)
                )
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(
                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                    )

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # Restore the caller's "None" for inputs that were not provided.
        new_labels = None if _labels is None else new_labels_padded
        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
        if _position_ids is None:
            position_ids = None

        return (
            None,
            position_ids,
            attention_mask,
            past_key_values,
            new_input_embeds,
            new_labels,
        )

    def repack_multimodal_data(
        self,
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        inputs_embeds,
        labels,
    ):
        """Greedily pack unpadded samples (longest first) into rows of ``inputs_embeds.shape[1]``.

        Position ids restart at 0 for every packed sample; padded positions get
        position id -1, which also defines the returned attention mask.

        Returns:
            ``(None, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels, sorted_seqlens_in_batch)``.
        """
        # kentang-mit@: reorder and repack (reduce computation overhead)
        # requires transformers replacement.
        new_inputs_embeds = []
        new_position_ids = []
        new_labels = []
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        sorted_seqlens_in_batch, sorted_idx = torch.sort(
            seqlens_in_batch, descending=True
        )
        max_seqlen = inputs_embeds.shape[1]

        cur_inputs_embeds = []
        cur_position_ids = []
        cur_labels = []
        cur_batch_len = 0
        for i in range(len(sorted_seqlens_in_batch)):
            cur_seqlen = sorted_seqlens_in_batch[i].item()
            sample_idx = sorted_idx[i]
            sample_mask = attention_mask[sample_idx]
            # remove padding on-the-fly
            sample_embeds = inputs_embeds[sample_idx][sample_mask]  # num_tokens x num_channels
            sample_labels = labels[sample_idx][sample_mask]  # num_tokens
            sample_positions = torch.arange(
                sample_embeds.shape[0], device=sample_embeds.device
            )  # num_tokens
            if cur_seqlen + cur_batch_len <= max_seqlen:
                cur_batch_len += cur_seqlen
                cur_inputs_embeds.append(sample_embeds)
                cur_position_ids.append(sample_positions)
                cur_labels.append(sample_labels)
            else:
                new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0))
                new_position_ids.append(torch.cat(cur_position_ids, 0))
                new_labels.append(torch.cat(cur_labels, 0))
                # The current batch is too long. We will start a new batch.
                cur_batch_len = cur_seqlen
                cur_inputs_embeds = [sample_embeds]
                cur_position_ids = [sample_positions]
                cur_labels = [sample_labels]

        if len(cur_inputs_embeds):
            new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0))
            new_position_ids.append(torch.cat(cur_position_ids, 0))
            new_labels.append(torch.cat(cur_labels, 0))

        new_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
            new_inputs_embeds, batch_first=True, padding_value=self.llm.pad_token_id
        )
        new_position_ids = torch.nn.utils.rnn.pad_sequence(
            new_position_ids, batch_first=True, padding_value=-1
        )
        new_labels = torch.nn.utils.rnn.pad_sequence(
            new_labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        # yunhao: it's currently a workaround to avoid errors for seq_len < 100
        new_attention_mask = new_position_ids.ne(-1)
        # sanity check
        assert new_attention_mask.sum() == attention_mask.sum()

        return (
            None,
            new_position_ids,
            new_attention_mask,
            past_key_values,
            new_inputs_embeds,
            new_labels,
            sorted_seqlens_in_batch,
        )

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add image special tokens to ``tokenizer`` and resize/initialize the embeddings.

        New start/end token embeddings are initialized to the mean of the
        existing rows, or copied from ``model_args.pretrain_mm_mlp_adapter``.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True
                )
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True
                )

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg
            # TODO yunhao: handle cases for
            if model_args.pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
                mm_projector_weights = torch.load(
                    model_args.pretrain_mm_mlp_adapter, map_location="cpu"
                )
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[
                        -num_new_tokens:
                    ]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(
                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
                    )
        elif model_args.mm_use_im_patch_token:
            if model_args.mm_projector:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False


# This file is modified from https://github.com/haotian-liu/LLaVA/

import torch  # noqa


def _resolve_torch_dtype(dtype):
    """Convert a config dtype spec (e.g. ``"torch.float16"``) to a ``torch.dtype``.

    Replaces ``eval(config.model_dtype)``: configs may come from the Hugging
    Face Hub, so the name is looked up on ``torch`` instead of executed.
    Accepts a ``torch.dtype``, ``"torch.<name>"`` or ``"<name>"``.

    Raises:
        ValueError: if the spec does not name a ``torch.dtype``.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    name = str(dtype).strip()
    if name.startswith("torch."):
        name = name[len("torch.") :]
    resolved = getattr(torch, name, None)
    if not isinstance(resolved, torch.dtype):
        raise ValueError(f"Unsupported model_dtype: {dtype!r}")
    return resolved


def build_mm_projector(
    model_type_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
    """Load (``config.resume_path`` set) or create the multimodal projector.

    ``model_type_or_path`` is a checkpoint directory when resuming, otherwise a
    projector type such as ``"linear"`` or ``"mlp2x_gelu"``. Returns None when
    it is None.
    """
    if model_type_or_path is None:
        return None

    # load from pretrained model
    if config.resume_path:
        assert os.path.exists(
            model_type_or_path
        ), f"Resume mm projector path {model_type_or_path} does not exist!"
        return MultimodalProjector.from_pretrained(
            model_type_or_path,
            config,
            torch_dtype=_resolve_torch_dtype(config.model_dtype),
        )
    # build from scratch
    mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
    mm_projector = MultimodalProjector(mm_projector_cfg, config).to(
        _resolve_torch_dtype(config.model_dtype)
    )
    return mm_projector


class IdentityMap(nn.Module):
    """Projector that returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": "identity"}


class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual Linear-GELU-Linear block."""

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


class DownSampleBlock(nn.Module):
    """Merge each 2x2 neighbourhood of a square token grid into one token (4x channels)."""

    def forward(self, x):
        vit_embeds = x
        # Assumes the token count is a perfect square (h == w).
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.flat_square(vit_embeds)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        return vit_embeds

    def flat_square(self, x):
        """Zero-pad odd spatial dims to even, then fold 2x2 patches into channels."""
        n, w, h, c = x.size()
        if w % 2 == 1:
            x = torch.concat(
                [x, torch.zeros((n, 1, h, c), dtype=x.dtype, device=x.device)], dim=1
            ).contiguous()
            n, w, h, c = x.size()
        if h % 2 == 1:
            x = torch.concat(
                [x, torch.zeros((n, w, 1, c), dtype=x.dtype, device=x.device)], dim=2
            ).contiguous()
            n, w, h, c = x.size()
        x = x.view(n, w, h // 2, c * 2)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, h // 2, w // 2, c * 4)
        return x


class MultimodalProjectorConfig(PretrainedConfig):
    """Config holding only the projector type; sizes come from the parent VLM config."""

    model_type = "v2l_projector"

    def __init__(self, mm_projector_type: str = None, **kwargs):
        # Forward kwargs so standard PretrainedConfig fields (e.g. loaded from
        # config.json) are kept instead of silently discarded.
        super().__init__(**kwargs)
        self.mm_projector_type = mm_projector_type


class MultimodalProjector(PreTrainedModel):
    """Maps vision features (``mm_hidden_size``) into the llm space (``hidden_size``).

    Supported types: ``identity``, ``linear``, ``mlp_downsample`` and
    ``mlp<N>x_gelu``.
    """

    config_class = MultimodalProjectorConfig

    def __init__(
        self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig
    ):
        super().__init__(mm_projector_cfg)
        mm_projector_type = mm_projector_cfg.mm_projector_type
        if mm_projector_type == "identity":
            self.layers = IdentityMap()
        elif mm_projector_type == "linear":
            self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
        elif mm_projector_type == "mlp_downsample":
            self.layers = nn.Sequential(
                DownSampleBlock(),
                nn.LayerNorm(config.mm_hidden_size * 4),
                nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
        else:
            mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
            if not mlp_gelu_match:
                raise ValueError(f"Unknown projector type: {mm_projector_type}")
            mlp_depth = int(mlp_gelu_match.group(1))
            modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(config.hidden_size, config.hidden_size))
            self.layers = nn.Sequential(*modules)

    def forward(self, x, *args, **kwargs):
        return self.layers(x)


AutoConfig.register("v2l_projector", MultimodalProjectorConfig)
# This file is modified from https://github.com/haotian-liu/LLaVA/
AutoModel.register(MultimodalProjectorConfig, MultimodalProjector)

import torch  # noqa


def build_vision_tower(
    model_name_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
    """Instantiate the vision tower and record its width in ``config.mm_hidden_size``.

    When resuming (and not a RADIO tower), the architecture is read from the
    checkpoint's config; otherwise ``model_name_or_path`` itself is matched.
    Only SigLIP towers are supported.

    Raises:
        ValueError: for an unknown vision tower.
    """
    # skip vision tower instantiation
    if model_name_or_path is None:
        return None

    vision_tower_arch = None
    if config.resume_path and "radio" not in model_name_or_path:
        assert os.path.exists(
            model_name_or_path
        ), f"Resume vision tower path {model_name_or_path} does not exist!"
        vision_tower_cfg = AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        vision_tower_arch = vision_tower_cfg.architectures[0].lower()
    vision_tower_name = (
        vision_tower_arch if vision_tower_arch is not None else model_name_or_path
    )

    if "siglip" in vision_tower_name:
        vision_tower = SiglipVisionTower(model_name_or_path, config)
    else:
        raise ValueError(f"Unknown vision tower: {model_name_or_path}")

    config.mm_hidden_size = vision_tower.config.hidden_size
    return vision_tower


def build_context_provider(
    model_type_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
    """Load (``config.resume_path`` set) or create the context provider; None if no cfg."""
    if model_type_or_path is None:
        return None

    # load from pretrained model
    if config.resume_path:
        assert os.path.exists(
            model_type_or_path
        ), f"Resume context provider path {model_type_or_path} does not exist!"
        return ContextProvider.from_pretrained(
            model_type_or_path,
            config,
            torch_dtype=_resolve_torch_dtype(config.model_dtype),
        )
    # build from scratch
    context_provider_cfg = ContextProviderConfig(model_type_or_path)
    context_provider = ContextProvider(context_provider_cfg, config).to(
        _resolve_torch_dtype(config.model_dtype)
    )
    return context_provider


# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# import deepspeed


class ContextProviderConfig(PretrainedConfig):
    """Configuration for the context provider that fuses crop and full-image features.

    Unrecognised keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "context_provider"

    def __init__(
        self,
        context_provider_type: str = None,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        num_mask_channels=0,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        zero_init_output=True,
        residual_dropout=0.0,
        context_image_as_queries=False,
        context_provider_layer_indices=None,
        masked_cross_attn=False,
        crop_position_single_embedding=False,
        trainable_crop_position_embedding=True,
        crop_embedding_mode="add",
        treat_image_as_cimage=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.context_provider_type = context_provider_type

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.num_mask_channels = num_mask_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.zero_init_output = zero_init_output
        self.residual_dropout = residual_dropout
        self.context_image_as_queries = context_image_as_queries

        # cross_attn_end_to_all
        # the `num_hidden_layers` should be the same as the one in the vision tower
        self.num_hidden_layers = num_hidden_layers
        self.context_provider_layer_indices = context_provider_layer_indices

        self.masked_cross_attn = masked_cross_attn
        # If enabled, crop_position_embedding (delta to full pos) will be updated during training.
        self.trainable_crop_position_embedding = trainable_crop_position_embedding
        # If enabled, crop_position_embedding (delta to full pos) will be a single embedding for all positions.
        self.crop_position_single_embedding = crop_position_single_embedding
        # add: delta. replace: do not add the original positional embedding
        self.crop_embedding_mode = crop_embedding_mode

        # If True, the input image will be treated as a cimage (with mask as full 1s)
        self.treat_image_as_cimage = treat_image_as_cimage


# Context Provider


class ContextProviderCrossAttention(nn.Module):
    """Multi-headed cross-attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Input shape: Batch x Time x Channel

        Queries come from ``hidden_states`` (B, Q, C); keys/values from
        ``encoder_hidden_states`` (B, KV, C). ``attention_mask``, if given, is
        an additive bias of shape (B, 1, Q, KV).

        Returns:
            ``(attn_output, attn_weights)`` with shapes (B, Q, C) and
            (B, heads, Q, KV). The weights are always returned;
            ``output_attentions`` is accepted for API compatibility but unused.

        Raises:
            ValueError: on unexpected attention weight, mask or output shapes.
        """
        batch_size, q_len, _ = hidden_states.size()
        batch_size, kv_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(encoder_hidden_states)
        value_states = self.v_proj(encoder_hidden_states)

        # (B, T, C) -> (B, heads, T, head_dim)
        query_states = query_states.view(
            batch_size, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, kv_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, kv_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        k_v_seq_len = key_states.shape[-2]
        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        )

        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
                )
            # Additive bias: -inf entries are excluded by the softmax below.
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)

        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights


class ContextProviderMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation (``config.hidden_act``) -> fc2."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


def get_token_mask_bias(mask, patch_size):
    """Turn a pixel mask into a per-token additive attention bias.

    Each non-overlapping ``patch_size`` x ``patch_size`` patch is summed; tokens
    whose patch sum is below 1e-5 get ``-inf``, all others get 0.

    Args:
        mask: pixel mask with values in {0, 1}; presumably (B, H, W), since a
            channel dim is inserted at index 1 before the convolution.
        patch_size: patch edge length (used as both kernel size and stride).

    Returns:
        Tensor of shape (B, num_tokens), e.g. (1, 729) for a 27x27 token grid.
    """
    # Note: mask should be (0, 1)
    with torch.no_grad():
        # Add a channel dimension and perform conv
        # mask_tokens_after_conv: (B, 1, H, W), example dimension: [1, 1, 27, 27]
        mask_tokens_after_conv = F.conv2d(
            input=mask[:, None],
            weight=torch.ones(
                (1, 1, patch_size, patch_size), device=mask.device, dtype=mask.dtype
            ),
            bias=None,
            stride=(patch_size, patch_size),
            padding="valid",
        )

        token_mask_bias = torch.zeros_like(mask_tokens_after_conv)
        token_mask_bias.masked_fill_(mask_tokens_after_conv < 1e-5, float("-inf"))
        token_mask_bias = token_mask_bias.flatten(1)

        # Flattened dimension: (1, 729)
        return token_mask_bias


def attn_mask_from_cimage_concatenated(cimage_concatenated, patch_size):
    """Build a (B, 1, T, T) additive attention mask from channel 3 of a cimage.

    Channel 3 is mapped with ``(x + 1) / 2`` before patching. This presumably
    undoes a [-1, 1] normalization of a {0, 1} mask; TODO confirm against the
    image processor. Every query row shares the same key bias: 0 inside the
    mask, ``-inf`` outside.
    """
    # Use the mask from input image (4th channel)
    mask_normalized = cimage_concatenated[:, 3]
    mask_unnormalized = (mask_normalized + 1) / 2
    # (1, 729)
    token_mask_bias = get_token_mask_bias(mask_unnormalized, patch_size=patch_size)

    # attn_mask: (B, 1, Q, KV)
    # Obtain token mask in the bias format: in mask 0, out of mask -inf
    q_kv = token_mask_bias.shape[-1]
    attn_mask_bias = token_mask_bias[:, None, None, :].repeat(1, 1, q_kv, 1)

    return attn_mask_bias


# From SiglipEncoderLayer. We would like to modify this to cross-attention.
class CrossAttnEncoderLayer(nn.Module): def __init__(self, config: ContextProviderConfig): super().__init__() self.embed_dim = config.hidden_size self.cross_attn = ContextProviderCrossAttention(config) self.residual_dropout = nn.Dropout(config.residual_dropout) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = ContextProviderMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) if config.zero_init_output: # TODO: alternatively, we could parameterize with an MLP # These factors are initialized with 0 (so only residual passes through) if config.context_provider_type != "cross_attn_at_the_end": self.register_parameter("attn_factor", nn.Parameter(torch.zeros((1,)))) self.register_parameter("mlp_factor", nn.Parameter(torch.zeros((1,)))) else: # Use scalar tensor for compatibility self.register_parameter( "attn_factor", nn.Parameter(torch.zeros((1,)).view(())) ) self.register_parameter( "mlp_factor", nn.Parameter(torch.zeros((1,)).view(())) ) else: self.attn_factor = 1.0 self.mlp_factor = 1.0 # Ignore copy def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor]: """ Args: hidden_states (`torch.FloatTensor`): Input to the layer of shape `(batch, seq_len, embed_dim)`. attention_mask (`torch.FloatTensor`): Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
""" residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.cross_attn( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) # Dropping the residual: let the model leverage more on the context hidden_states = ( self.residual_dropout(residual) + self.attn_factor * hidden_states ) residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + self.mlp_factor * hidden_states outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class CrossAttnContextProviderEndToAll(nn.Module): def __init__(self, config: ContextProviderConfig): super().__init__() self.layers = nn.ModuleList( [ CrossAttnEncoderLayer(config) for i in enumerate(range(config.num_hidden_layers)) if config.context_provider_layer_indices is None or i in config.context_provider_layer_indices ] ) self.patch_size = config.patch_size self.masked_cross_attn = config.masked_cross_attn def forward(self, context_image_features, cimage_concatenated, vision_tower): # Use the mask from input image (4th channel) if self.masked_cross_attn: attn_mask = attn_mask_from_cimage_concatenated( cimage_concatenated, patch_size=self.patch_size ) else: attn_mask = None detail_raw_image = cimage_concatenated[:, 4:, ...] 
# NOTE: when using context image as queries, the context image was swapped with the detail image before passing into the context provider outputs = vision_tower( detail_raw_image, context_provider_layers=self.layers, contexts=context_image_features, cross_attention_mask=attn_mask, ) return outputs class ContextProvider(PreTrainedModel): config_class = ContextProviderConfig def __init__( self, context_provider_cfg: ContextProviderConfig, config: PretrainedConfig ): super().__init__(context_provider_cfg) self.context_image_as_queries = context_provider_cfg.context_image_as_queries self.context_provider_type = context_provider_type = ( context_provider_cfg.context_provider_type ) self.treat_image_as_cimage = context_provider_cfg.treat_image_as_cimage if self.context_image_as_queries: assert ( not context_provider_cfg.masked_cross_attn ), "Masked cross-attention not implemented when using context image as queries." assert ( "concat" not in context_provider_type ), "Concat not implemented when using context image as queries." 
if context_provider_type == "cross_attn_end_to_all": # Information flow: end of context features -> all detail features self.context_provider_module = CrossAttnContextProviderEndToAll( context_provider_cfg ) else: raise ValueError(f"Unknown context provider type: {context_provider_type}") def forward( self, cimage_full_features=None, cimage_crop_features=None, cimage_concatenated=None, vision_tower=None, ): if self.context_provider_type == "cross_attn_end_to_all": assert ( cimage_full_features.shape[0] == cimage_concatenated.shape[0] ), f"shape mismatches: {cimage_full_features.shape[0]} != {cimage_concatenated.shape[0]}" return self.context_provider_module( context_image_features=cimage_full_features, cimage_concatenated=cimage_concatenated, vision_tower=vision_tower, ) else: raise ValueError(f"Unknown context provider type: {context_provider_type}") AutoConfig.register("context_provider", ContextProviderConfig) AutoModel.register(ContextProviderConfig, ContextProvider) # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Image processor class for RADIO.""" if is_torch_available(): import torch if is_torchvision_available(): pass if is_tf_available(): pass logger = logging.get_logger(__name__) def rank_print(s): rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 print(f"[Rank {rank}] {s}") class ImageProcessor(BaseImageProcessor): r""" Constructs an image processor. 
Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the `do_resize` parameter in the `preprocess` method. size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`): Size of the output image after resizing. If "longest_edge" is specified, resizes the longest edge of the image to match `size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image to that size, possibly changing the aspect ratio. Can be overridden by the `size` parameter in the `preprocess` method. resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the `preprocess` method. do_rescale (`bool`, *optional*, defaults to `True`): Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` parameter in the `preprocess` method. rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be overridden by the `rescale_factor` parameter in the `preprocess` method. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): Mean to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be overridden by the `image_mean` parameter in the `preprocess` method. 
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. Can be overridden by the `image_std` parameter in the `preprocess` method. do_pad (`bool`, *optional*, defaults to `True`): Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the `preprocess` method. pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`): Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess` method. pad_value (`float` or `Iterable[float]`, *optional*, defaults to `0.`): Value of padded pixels. pad_multiple (`int`, *optional*, defaults to `None`): Pad to a multiple of specified number. do_convert_rgb (`bool`, *optional*, defaults to `True`): Whether to convert the image to RGB. """ model_input_names = ["pixel_values"] def __init__( self, do_resize: bool = True, size: Dict[str, int] = None, resample: PILImageResampling = PILImageResampling.BILINEAR, do_rescale: bool = True, rescale_factor: Union[int, float] = 1 / 255, do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_pad: bool = True, pad_size: int = None, pad_multiple: int = None, pad_value: Optional[Union[float, List[float]]] = 0.0, do_convert_rgb: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) size = size if size is not None else {"longest_edge": 1024} size = ( get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size ) if pad_size is not None and pad_multiple is not None: raise ValueError( "pad_size and pad_multiple should not be set at the same time." 
) pad_size = ( pad_size if pad_size is not None else {"height": 1024, "width": 1024} if pad_multiple is not None else None ) if do_pad: pad_size = get_size_dict(pad_size, default_to_square=True) self.do_resize = do_resize self.size = size self.resample = resample self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_normalize = do_normalize self.image_mean = ( image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN ) self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.do_pad = do_pad self.pad_multiple = pad_multiple self.pad_size = pad_size self.pad_value = tuple(pad_value) if isinstance(pad_value, list) else pad_value self.do_convert_rgb = do_convert_rgb self._valid_processor_keys = [ "images", "segmentation_maps", "do_resize", "size", "resample", "do_rescale", "rescale_factor", "do_normalize", "image_mean", "image_std", "do_pad", "pad_size", "do_convert_rgb", "return_tensors", "data_format", "input_data_format", ] def pad_image( self, image: np.ndarray, pad_size: Dict[str, int], data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs, ) -> np.ndarray: """ Pad an image to `(pad_size["height"], pad_size["width"])` to the right and bottom. Args: image (`np.ndarray`): Image to pad. pad_size (`Dict[str, int]`): Size of the output image after padding. data_format (`str` or `ChannelDimension`, *optional*): The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the `data_format` of the `image` will be used. input_data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. 
""" output_height, output_width = pad_size["height"], pad_size["width"] input_height, input_width = get_image_size(image, channel_dim=input_data_format) pad_width = output_width - input_width pad_height = output_height - input_height padded_image = pad( image, ((0, pad_height), (0, pad_width)), data_format=data_format, input_data_format=input_data_format, constant_values=self.pad_value, **kwargs, ) return padded_image def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int): """ Compute the output size given input size and target long side length. """ oldh, oldw = old_shape scale = longest_edge * 1.0 / max(oldh, oldw) newh, neww = oldh * scale, oldw * scale newh = int(newh + 0.5) neww = int(neww + 0.5) return (newh, neww) def resize( self, image: np.ndarray, size: Dict[str, int], resample: PILImageResampling = PILImageResampling.BICUBIC, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs, ) -> np.ndarray: """ Resize an image to `(size["height"], size["width"])`. Args: image (`np.ndarray`): Image to resize. size (`Dict[str, int]`): Dictionary in the format `{"longest_edge": int}` or `{"width": int, "height": int}` specifying the size of the output image. If "longest_edge" is specified, resizes the longest edge of the image to match `size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image to that size, possibly changing the aspect ratio. resample: `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format for the output image. If unset, the channel dimension format of the input image is used. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. Returns: `np.ndarray`: The resized image. """ size = get_size_dict(size) if "longest_edge" not in size: if "width" not in size or "height" not in size: raise ValueError( f"The `size` dictionary must contain the key `longest_edge`, or `width` and `height`. Got {size.keys()}" ) input_size = get_image_size(image, channel_dim=input_data_format) if "longest_edge" in size: output_height, output_width = self._get_preprocess_shape( input_size, size["longest_edge"] ) else: output_height, output_width = size["height"], size["width"] return resize( image, size=(output_height, output_width), resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs, ) def _preprocess( self, image: ImageInput, do_resize: bool, do_rescale: bool, do_normalize: bool, size: Optional[Dict[str, int]] = None, resample: PILImageResampling = None, rescale_factor: Optional[float] = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_pad: Optional[bool] = None, pad_size: Optional[Dict[str, int]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, ): if do_resize: image = self.resize( image=image, size=size, resample=resample, input_data_format=input_data_format, ) reshaped_input_size = get_image_size(image, channel_dim=input_data_format) if do_rescale: image = self.rescale( image=image, scale=rescale_factor, input_data_format=input_data_format ) if do_normalize: image = self.normalize( image=image, 
mean=image_mean, std=image_std, input_data_format=input_data_format, ) if do_pad: if self.pad_multiple: h, w = get_image_size(image, channel_dim=input_data_format) pad_size = { "height": math.ceil(h / self.pad_multiple) * self.pad_multiple, "width": math.ceil(w / self.pad_multiple) * self.pad_multiple, } image = self.pad_image( image=image, pad_size=pad_size, input_data_format=input_data_format ) return image, reshaped_input_size def _preprocess_image( self, image: ImageInput, do_resize: Optional[bool] = None, size: Dict[str, int] = None, resample: PILImageResampling = None, do_rescale: bool = None, rescale_factor: Optional[float] = None, do_normalize: Optional[bool] = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_pad: Optional[bool] = None, pad_size: Optional[Dict[str, int]] = None, do_convert_rgb: Optional[bool] = None, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]: # image = to_numpy_array(image) # import time # if int(time.time()*1000) % 10 == 0: # # create an PIL image of size 1x1 # image = PIL.Image.new('RGB', (1, 1)) if isinstance(image, Image.Image): # PIL always uses Channels Last. input_data_format = ChannelDimension.LAST # PIL RGBA images are converted to RGB # mode_before = image.mode if do_convert_rgb: image = convert_to_rgb(image) # All transformations expect numpy arrays. 
image = to_numpy_array(image) # if isinstance(image_, np.ndarray): # rank_print(f"preprocess image type={type(image_)} shape={image_.shape} array shape={image.shape}") # elif isinstance(image_, Image.Image): # rank_print(f"preprocessimage type={type(image_)} size={image_.size} mode={image_.mode} array shape={image.shape}") # else: # rank_print(f"preprocess unknown image type={type(image_)} array shape={image.shape}") if len(image.shape) == 2: h, w = image.shape ret = np.empty((h, w, 3), dtype=np.uint8) ret[:, :, 0] = image ret[:, :, 1] = image ret[:, :, 2] = image image = ret rank_print(f"preprocess new image shape={image.shape}") elif len(image.shape) == 3 and image.shape[-1] == 1: ret = np.empty((h, w, 3), dtype=np.uint8) ret[:, :, 0] = image[:, :, 0] ret[:, :, 1] = image[:, :, 0] ret[:, :, 2] = image[:, :, 0] image = ret rank_print(f"preprocess new image shape={image.shape}") if is_scaled_image(image) and do_rescale: logger.warning_once( "It looks like you are trying to rescale already rescaled images. If the input" " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 
) if input_data_format is None: input_data_format = infer_channel_dimension_format(image) original_size = get_image_size(image, channel_dim=input_data_format) image, reshaped_input_size = self._preprocess( image=image, do_resize=do_resize, size=size, resample=resample, do_rescale=do_rescale, rescale_factor=rescale_factor, do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, do_pad=do_pad, pad_size=pad_size, input_data_format=input_data_format, ) if data_format is not None: image = to_channel_dimension_format( image, data_format, input_channel_dim=input_data_format ) # rank_print(f"preprocess original_size={original_size} reshaped_input_size={reshaped_input_size} image shape={image.shape} type={type(image)}") # if image is a single channel convert to rgb if do_convert_rgb and image.shape[0] == 1: c, h, w = image.shape ret = np.empty((3, h, w), dtype=np.uint8) ret[0, :, :] = image[0, :, :] ret[1, :, :] = image[0, :, :] ret[2, :, :] = image[0, :, :] image = ret rank_print(f"preprocess final: {image.shape}") return image, original_size, reshaped_input_size def preprocess( self, images: ImageInput, do_resize: Optional[bool] = None, size: Optional[Dict[str, int]] = None, resample: Optional["PILImageResampling"] = None, do_rescale: Optional[bool] = None, rescale_factor: Optional[Union[int, float]] = None, do_normalize: Optional[bool] = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_pad: Optional[bool] = None, pad_size: Optional[Dict[str, int]] = None, do_convert_rgb: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs, ): """ Preprocess an image or batch of images. Args: images (`ImageInput`): Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. 
If passing in images with pixel values between 0 and 1, set `do_rescale=False`. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. size (`Dict[str, int]`, *optional*, defaults to `self.size`): Controls the size of the image after `resize`. The longest edge of the image is resized to `size["longest_edge"]` whilst preserving the aspect ratio. resample (`PILImageResampling`, *optional*, defaults to `self.resample`): `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): Whether to rescale the image pixel values by rescaling factor. rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`): Rescale factor to apply to the image pixel values. do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): Whether to normalize the image. image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): Image mean to normalize the image by if `do_normalize` is set to `True`. image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): Image standard deviation to normalize the image by if `do_normalize` is set to `True`. do_pad (`bool`, *optional*, defaults to `self.do_pad`): Whether to pad the image. pad_size (`Dict[str, int]`, *optional*, defaults to `self.pad_size`): Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and `pad_size["width"]` if `do_pad` is set to `True`. do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): Whether to convert the image to RGB. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - Unset: Use the channel dimension format of the input image. input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
""" do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size size = ( get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size ) resample = resample if resample is not None else self.resample do_rescale = do_rescale if do_rescale is not None else self.do_rescale rescale_factor = ( rescale_factor if rescale_factor is not None else self.rescale_factor ) do_normalize = do_normalize if do_normalize is not None else self.do_normalize image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std do_pad = do_pad if do_pad is not None else self.do_pad pad_size = pad_size if pad_size is not None else self.pad_size if do_pad: pad_size = get_size_dict(pad_size, default_to_square=True) do_convert_rgb = ( do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb ) images = make_list_of_images(images) if not valid_images(images): raise ValueError( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) images, original_sizes, reshaped_input_sizes = zip( *( self._preprocess_image( image=img, do_resize=do_resize, size=size, resample=resample, do_rescale=do_rescale, rescale_factor=rescale_factor, do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, do_pad=do_pad, pad_size=pad_size, do_convert_rgb=do_convert_rgb, data_format=data_format, input_data_format=input_data_format, ) for img in images ) ) data = { "pixel_values": images, "original_sizes": original_sizes, "reshaped_input_sizes": reshaped_input_sizes, } return BatchFeature(data=data, tensor_type=return_tensors) # This file is modified from https://github.com/haotian-liu/LLaVA/ class VisionTower(nn.Module): def __init__(self, vision_tower, args, delay_load=False): super().__init__() self.is_loaded = False self.vision_tower_name = vision_tower self.select_layer = getattr(args, "mm_vision_select_layer", -2) self.select_feature = getattr(args, "mm_vision_select_feature", "patch") self.cfg_only = None def feature_select(self, image_forward_outs): image_features = image_forward_outs.hidden_states[self.select_layer] if self.select_feature == "patch": image_features = image_features[:, 1:] elif self.select_feature == "cls_patch": image_features = image_features else: raise ValueError(f"Unexpected select feature: {self.select_feature}") return image_features def _maybe_resize_pos_embeds( self, model: PreTrainedModel, image_processor: BaseImageProcessor, resolution: int = -1, interpolate_mode: str = "linear", ): if resolution in [model.config.image_size, -1]: return print( f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..." 
) embeddings = model.vision_model.embeddings patch_size = embeddings.patch_size num_new_tokens = int((resolution // patch_size) ** 2) old_embeddings = embeddings.position_embedding match interpolate_mode: case "linear": # Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M # Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)] # At inference time, we assume deepspeed zero3 is not enabled. # import deepspeed # with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None): # old_num_tokens, old_embedding_dim = old_embeddings.weight.size() old_num_tokens, old_embedding_dim = old_embeddings.weight.size() new_embeddings = nn.Embedding( num_new_tokens, old_embedding_dim, dtype=old_embeddings.weight.dtype, device=old_embeddings.weight.device, ) mapped_indices = ( torch.arange(num_new_tokens).to(old_embeddings.weight.device) / (num_new_tokens - 1) * (old_num_tokens - 1) ) floor_indices = torch.clamp( mapped_indices.floor().long(), min=0, max=old_num_tokens - 1 ) ceil_indices = torch.clamp( mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1 ) # At inference time, we assume deepspeed zero3 is not enabled. 
# params = [old_embeddings.weight, new_embeddings.weight] # with deepspeed.zero.GatheredParameters(params, modifier_rank=0): # interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[ # ceil_indices, : # ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :] interpolated_embeds = (mapped_indices - floor_indices)[ :, None ] * old_embeddings.weight.data[ceil_indices, :] + ( ceil_indices - mapped_indices )[ :, None ] * old_embeddings.weight.data[ floor_indices, : ] new_embeddings.weight.data = interpolated_embeds case _: raise NotImplementedError if hasattr(old_embeddings, "_hf_hook"): hook = old_embeddings._hf_hook add_hook_to_module(new_embeddings, hook) new_embeddings.requires_grad_(old_embeddings.weight.requires_grad) # update vision encoder's configurations model.config.image_size = resolution if hasattr(image_processor, "crop_size"): # CLIP vision tower image_processor.crop_size = resolution else: # SIGLIP vision tower assert hasattr(image_processor, "size") image_processor.size = {"height": resolution, "width": resolution} # TODO define a '_reinitialize' method for VisionTower embeddings.position_embedding = new_embeddings embeddings.image_size = resolution embeddings.num_patches = embeddings.num_positions = num_new_tokens embeddings.position_ids = ( torch.arange(embeddings.num_positions) .expand((1, -1)) .to(old_embeddings.weight.device) ) def forward(self, images, **kwargs): if type(images) is list: image_features = [] for image in images: image_forward_out = self.vision_tower( image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True, **kwargs, ) image_feature = self.feature_select(image_forward_out).to(image.dtype) image_features.append(image_feature) else: image_forward_outs = self.vision_tower( images.to(device=self.device, dtype=self.dtype), output_hidden_states=True, **kwargs, ) image_features = self.feature_select(image_forward_outs).to(images.dtype) 
return image_features @property def dummy_feature(self): return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) @property def dtype(self): return self.vision_tower.dtype @property def device(self): return self.vision_tower.device @property def config(self): if self.is_loaded: return self.vision_tower.config else: return self.cfg_only @property def hidden_size(self): return self.config.hidden_size @property def num_patches(self): return (self.config.image_size // self.config.patch_size) ** 2 # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Siglip model configuration""" logger = logging.get_logger(__name__) SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = { "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json", } class SiglipTextConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`SiglipModel`]. hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. intermediate_size (`int`, *optional*, defaults to 3072): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. num_hidden_layers (`int`, *optional*, defaults to 12): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. max_position_embeddings (`int`, *optional*, defaults to 64): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. layer_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the layer normalization layers. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. pad_token_id (`int`, *optional*, defaults to 1): The id of the padding token in the vocabulary. bos_token_id (`int`, *optional*, defaults to 49406): The id of the beginning-of-sequence token in the vocabulary. eos_token_id (`int`, *optional*, defaults to 49407): The id of the end-of-sequence token in the vocabulary. 
Example: ```python >>> from transformers import SiglipTextConfig, SiglipTextModel >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration >>> configuration = SiglipTextConfig() >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration >>> model = SiglipTextModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "siglip_text_model" def __init__( self, vocab_size=32000, hidden_size=768, intermediate_size=3072, num_hidden_layers=12, num_attention_heads=12, max_position_embeddings=64, hidden_act="gelu_pytorch_tanh", layer_norm_eps=1e-6, attention_dropout=0.0, # This differs from `CLIPTokenizer`'s default and from openai/siglip # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 pad_token_id=1, bos_token_id=49406, eos_token_id=49407, **kwargs, ): super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs, ) self.vocab_size = vocab_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.max_position_embeddings = max_position_embeddings self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act self.attention_dropout = attention_dropout @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs ) -> "PretrainedConfig": # cls._set_token_in_kwargs(kwargs) config_dict, kwargs = cls.get_config_dict( pretrained_model_name_or_path, **kwargs ) # get the text config dict if we are loading from SiglipConfig if config_dict.get("model_type") == "siglip": config_dict = config_dict["text_config"] if ( "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type ): logger.warning( f"You are using a model of type {config_dict['model_type']} 
to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." ) return cls.from_dict(config_dict, **kwargs) class SiglipVisionConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. intermediate_size (`int`, *optional*, defaults to 3072): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. num_hidden_layers (`int`, *optional*, defaults to 12): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. num_channels (`int`, *optional*, defaults to 3): Number of channels in the input images. image_size (`int`, *optional*, defaults to 224): The size (resolution) of each image. patch_size (`int`, *optional*, defaults to 16): The size (resolution) of each patch. hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. layer_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the layer normalization layers. 
attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. num_mask_channels (`int`, *optional*, defaults to 0): Number of mask channels in the input images. Example: ```python >>> from transformers import SiglipVisionConfig, SiglipVisionModel >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration >>> configuration = SiglipVisionConfig() >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration >>> model = SiglipVisionModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "siglip_vision_model" def __init__( self, hidden_size=768, intermediate_size=3072, num_hidden_layers=12, num_attention_heads=12, num_channels=3, image_size=224, patch_size=16, hidden_act="gelu_pytorch_tanh", layer_norm_eps=1e-6, attention_dropout=0.0, num_mask_channels=0, **kwargs, ): super().__init__(**kwargs) self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.num_channels = num_channels self.patch_size = patch_size self.image_size = image_size self.attention_dropout = attention_dropout self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act self.num_mask_channels = num_mask_channels @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs ) -> "PretrainedConfig": # cls._set_token_in_kwargs(kwargs) config_dict, kwargs = cls.get_config_dict( pretrained_model_name_or_path, **kwargs ) # get the vision config dict if we are loading from SiglipConfig if config_dict.get("model_type") == "siglip": config_dict = config_dict["vision_config"] if ( "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type ): logger.warning( f"You are using a model of type 
{config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." ) return cls.from_dict(config_dict, **kwargs) class SiglipConfig(PretrainedConfig): r""" [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: text_config (`dict`, *optional*): Dictionary of configuration options used to initialize [`SiglipTextConfig`]. vision_config (`dict`, *optional*): Dictionary of configuration options used to initialize [`SiglipVisionConfig`]. kwargs (*optional*): Dictionary of keyword arguments. 
Example: ```python >>> from transformers import SiglipConfig, SiglipModel >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration >>> configuration = SiglipConfig() >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration >>> model = SiglipModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig >>> from transformers import SiglipTextConfig, SiglipVisionConfig >>> # Initializing a SiglipText and SiglipVision configuration >>> config_text = SiglipTextConfig() >>> config_vision = SiglipVisionConfig() >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision) ```""" model_type = "siglip" def __init__(self, text_config=None, vision_config=None, **kwargs): super().__init__(**kwargs) if text_config is None: text_config = {} logger.info( "`text_config` is `None`. Initializing the `SiglipTextConfig` with default values." ) if vision_config is None: vision_config = {} logger.info( "`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values." ) self.text_config = SiglipTextConfig(**text_config) self.vision_config = SiglipVisionConfig(**vision_config) self.initializer_factor = 1.0 @classmethod def from_text_vision_configs( cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs ): r""" Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision model configuration. Returns: [`SiglipConfig`]: An instance of a configuration object """ return cls( text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs, ) # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for SigLIP."""

logger = logging.get_logger(__name__)


def is_scaled_image(image: np.ndarray) -> bool:
    """Return whether ``image`` already appears to be rescaled to [0, 1].

    A ``uint8`` array is always treated as unscaled. For any other dtype the check is
    purely value-based: the image counts as scaled when all pixels lie in [0, 1].

    Args:
        image: The image as a numpy array.

    Returns:
        Truthy if the pixel values look already rescaled, falsy otherwise.
    """
    if image.dtype == np.uint8:
        return False

    # It's possible the image has pixel values in [0, 255] but is of floating type
    return np.min(image) >= 0 and np.max(image) <= 1


if is_vision_available():
    import PIL


class SiglipImageProcessor(BaseImageProcessor):
    r"""
    Constructs a SigLIP image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in the `preprocess` method.
        size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 384}`):
            Size of the image after resizing. Either `{"shortest_edge": int}` (aspect ratio preserved) or
            `{"height": int, "width": int}`. Can be overridden by `size` in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
            `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert (PIL) images to RGB before the other steps. Can be overridden by `do_convert_rgb` in
            the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        # The IMAGENET_STANDARD_* constants are not in this file's header imports, so bring
        # them in locally; otherwise omitting image_mean/image_std raises NameError.
        from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD

        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 384}
        size = get_size_dict(size, default_to_square=False)
        image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image.

        With `{"shortest_edge": s}` the shortest edge is resized to `s` and the longest edge is scaled to keep the
        input aspect ratio. With `{"height": h, "width": w}` the image is resized to `(h, w)`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.

        Returns:
            `np.ndarray`: The resized image.

        Raises:
            ValueError: If `size` has neither `shortest_edge` nor both `height` and `width`.
        """
        default_to_square = True
        if "shortest_edge" in size:
            size = size["shortest_edge"]
            default_to_square = False
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:
            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")

        output_size = get_resize_output_image_size(image, size=size, default_to_square=default_to_square)
        # `resize` here is the module-level transformers helper, not this method.
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            **kwargs,
        )

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image. Before normalizing, channels-first images are transposed to
                channels-last and single-channel images are replicated to three channels.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the images to RGB before any other step.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `None`: no final conversion; images keep the layout produced by the preceding steps
                  (channels-last if normalization ran).
            input_data_format (`ChannelDimension` or `str`, *optional*):
                Accepted for API compatibility but currently unused; the input channel layout is always inferred.

        Returns:
            `BatchFeature`: A batch with key `"pixel_values"` holding the processed images.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, param_name="size", default_to_square=False)
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and size is None:
            raise ValueError("Size must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # PIL RGBA images are converted to RGB
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if do_resize:
            images = [self.resize(image=image, size=size, resample=resample) for image in images]

        if do_rescale:
            images = [rescale(image=image, scale=rescale_factor) for image in images]

        if do_normalize:
            output_images = []
            for image in images:
                # Normalize expects channels-last with as many channels as image_mean (RGB).
                if get_channel_dimension_axis(image) == 0:
                    image = image.transpose((1, 2, 0))
                if image.shape[-1] == 1:
                    image = np.dstack((image, image, image))
                output_images.append(image)
            images = [normalize(image=image, mean=image_mean, std=image_std) for image in output_images]

        # `to_channel_dimension_format` rejects None, so only convert when a target is given.
        if data_format is not None:
            images = [to_channel_dimension_format(image, data_format) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)


# coding=utf-8
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Siglip model.""" # from ...modeling_attn_mask_utils import _prepare_4d_attention_mask logger = logging.get_logger(__name__) # _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" # SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ # "google/siglip-base-patch16-224", # # See all SigLIP models at https://huggingface.co/models?filter=siglip # ] def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): warnings.warn( "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2, ) # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. 
tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range tensor.clamp_(min=a, max=b) def trunc_normal_tf_( tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0, ) -> torch.Tensor: """Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \\leq \text{mean} \\leq b`. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 and the result is subsquently scaled and shifted by the mean and std args. 
Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution std: the standard deviation of the normal distribution a: the minimum cutoff value b: the maximum cutoff value """ with torch.no_grad(): _trunc_normal_(tensor, 0, 1.0, a, b) tensor.mul_(std).add_(mean) def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": denom = fan_in elif mode == "fan_out": denom = fan_out elif mode == "fan_avg": denom = (fan_in + fan_out) / 2 variance = scale / denom if distribution == "truncated_normal": # constant is stddev of standard normal truncated to (-2, 2) trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) elif distribution == "normal": with torch.no_grad(): tensor.normal_(std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) with torch.no_grad(): tensor.uniform_(-bound, bound) else: raise ValueError(f"invalid distribution {distribution}") def lecun_normal_(tensor): variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") def default_flax_embed_init(tensor): variance_scaling_(tensor, mode="fan_in", distribution="normal") @dataclass # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip class SiglipVisionModelOutput(ModelOutput): """ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. Args: image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): The image embeddings obtained by applying the projection layer to the pooler_output. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. 
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ image_embeds: Optional[torch.FloatTensor] = None last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @dataclass # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip class SiglipTextModelOutput(ModelOutput): """ Base class for text model's outputs that also contains a pooling of the last hidden states. Args: text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): The text embeddings obtained by applying the projection layer to the pooler_output. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. 
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ text_embeds: Optional[torch.FloatTensor] = None last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @dataclass # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip class SiglipOutput(ModelOutput): """ Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): Contrastive loss for image-text similarity. logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text similarity scores. logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image similarity scores. text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. 
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. text_model_output(`BaseModelOutputWithPooling`): The output of the [`SiglipTextModel`]. vision_model_output(`BaseModelOutputWithPooling`): The output of the [`SiglipVisionModel`]. """ loss: Optional[torch.FloatTensor] = None logits_per_image: torch.FloatTensor = None logits_per_text: torch.FloatTensor = None text_embeds: torch.FloatTensor = None image_embeds: torch.FloatTensor = None text_model_output: BaseModelOutputWithPooling = None vision_model_output: BaseModelOutputWithPooling = None def to_tuple(self) -> Tuple[Any]: return tuple( ( self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() ) for k in self.keys() ) class SiglipVisionEmbeddings(nn.Module): def __init__(self, config: SiglipVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", ) print(f"Number of mask channels: {config.num_mask_channels}") if config.num_mask_channels: # Mask should have the same output shape to be added. # Currently we have bias in this embedding (so that mask vs no mask are different). 
self.mask_patch_embedding = nn.Conv2d( in_channels=config.num_mask_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", ) self.mask_patch_embedding.use_zero_init = True else: self.mask_patch_embedding = None self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer( "position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False, ) def forward( self, pixel_values: torch.FloatTensor, additional_position_embedding: Optional[torch.Tensor] = None, additional_embedding_mode: Optional[str] = None, ) -> torch.Tensor: if self.mask_patch_embedding is None: patch_embeds = self.patch_embedding( pixel_values ) # shape = [*, width, grid, grid] else: # Comment this out if you want to encode both images without mask channel and with mask channel. # However, if different samples in the batch have different number of channels, this is not applicable. # assert pixel_values.size(1) == 4, f"Input does not have a mask channel, shape: {pixel_values.shape}" patch_embeds = self.patch_embedding( pixel_values[:, :3, ...] ) # shape = [*, width, grid, grid] if pixel_values.size(1) == 4: patch_embeds = patch_embeds + self.mask_patch_embedding( pixel_values[:, 3:4, ...] 
) embeddings = patch_embeds.flatten(2).transpose(1, 2) if additional_position_embedding is not None: if additional_embedding_mode == "add": embeddings = embeddings + self.position_embedding(self.position_ids) embeddings = embeddings + additional_position_embedding elif additional_embedding_mode == "replace": # The original positional embedding is not used (multiplied by zero to ensure all parameters are used to be safe) embeddings = ( embeddings + self.position_embedding(self.position_ids) * 0.0 ) embeddings = embeddings + additional_position_embedding else: raise ValueError( f"additional_embedding_mode should be either 'add' or 'replace', got {additional_embedding_mode}" ) else: # Without additional position embedding embeddings = embeddings + self.position_embedding(self.position_ids) # print("No additional position embedding") return embeddings # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip class SiglipTextEmbeddings(nn.Module): def __init__(self, config: SiglipTextConfig): super().__init__() embed_dim = config.hidden_size self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) self.position_embedding = nn.Embedding( config.max_position_embeddings, embed_dim ) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: seq_length = ( input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] ) if position_ids is None: position_ids = self.position_ids[:, :seq_length] if inputs_embeds is None: inputs_embeds = self.token_embedding(input_ids) position_embeddings = self.position_embedding(position_ids) embeddings = inputs_embeds + position_embeddings return embeddings 
class SiglipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ def __init__(self, config): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" batch_size, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view( batch_size, q_len, self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( batch_size, q_len, self.num_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( batch_size, q_len, self.num_heads, self.head_dim ).transpose(1, 2) k_v_seq_len = key_states.shape[-2] attn_weights = ( torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale ) if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): raise ValueError( f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != 
(batch_size, 1, q_len, k_v_seq_len): raise ValueError( f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query_states.dtype) attn_weights = nn.functional.dropout( attn_weights, p=self.dropout, training=self.training ) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip class SiglipMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip class SiglipEncoderLayer(nn.Module): def __init__(self, config: SiglipConfig): super().__init__() self.embed_dim = config.hidden_size self.self_attn = SiglipAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) # Ignore copy def forward( self, hidden_states: torch.Tensor, attention_mask: 
torch.Tensor, output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor]: """ Args: hidden_states (`torch.FloatTensor`): Input to the layer of shape `(batch, seq_len, embed_dim)`. attention_mask (`torch.FloatTensor`): Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class SiglipPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. 
""" config_class = SiglipConfig base_model_prefix = "siglip" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, SiglipVisionEmbeddings): width = ( self.config.vision_config.hidden_size if isinstance(self.config, SiglipConfig) else self.config.hidden_size ) nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, SiglipAttention): nn.init.xavier_uniform_(module.q_proj.weight) nn.init.xavier_uniform_(module.k_proj.weight) nn.init.xavier_uniform_(module.v_proj.weight) nn.init.xavier_uniform_(module.out_proj.weight) nn.init.zeros_(module.q_proj.bias) nn.init.zeros_(module.k_proj.bias) nn.init.zeros_(module.v_proj.bias) nn.init.zeros_(module.out_proj.bias) elif isinstance(module, SiglipMLP): nn.init.xavier_uniform_(module.fc1.weight) nn.init.xavier_uniform_(module.fc2.weight) nn.init.normal_(module.fc1.bias, std=1e-6) nn.init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, SiglipMultiheadAttentionPoolingHead): nn.init.xavier_uniform_(module.probe.data) nn.init.xavier_uniform_(module.attention.in_proj_weight.data) nn.init.zeros_(module.attention.in_proj_bias.data) elif isinstance(module, SiglipModel): logit_scale_init = torch.log(torch.tensor(1.0)) module.logit_scale.data.fill_(logit_scale_init) module.logit_bias.data.zero_() elif isinstance(module, nn.Conv2d) and getattr(module, "use_zero_init", False): param_list = [module.weight] if module.bias is not None: param_list += [module.bias] # This is used in mask patch embedding # # with deepspeed.zero.GatheredParameters(param_list, modifier_rank=0): # for param in param_list: # nn.init.zeros_(param) for param in param_list: nn.init.zeros_(param) elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): 
module.bias.data.zero_() module.weight.data.fill_(1.0) SIGLIP_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`SiglipConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ SIGLIP_TEXT_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. [What are position IDs?](../glossary#position-ids) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ SIGLIP_VISION_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ SIGLIP_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0, config.max_position_embeddings - 1]`. [What are position IDs?](../glossary#position-ids) pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. return_loss (`bool`, *optional*): Whether or not to return the contrastive loss. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip class SiglipEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`SiglipEncoderLayer`]. 
Args: config: SiglipConfig """ def __init__(self, config: SiglipConfig): super().__init__() self.config = config self.layers = nn.ModuleList( [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False # Ignore copy def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, context_provider_layers: Optional[nn.ModuleList] = None, contexts: Optional[List[torch.Tensor]] = None, cross_attention_mask: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutput]: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. context_provider_layers (nn.ModuleList): ModuleList of context provider layers. contexts: List of torch.Tensor for context (for KV in cross-attention). cross_attention_mask (`torch.Tensor` of shape `(batch_size, q_sequence_length, kv_sequence_length)`, *optional*): mask for cross-attention. 
return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None hidden_states = inputs_embeds for layer_index, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if context_provider_layers: # Right now contexts is passed as the encoder_hidden_states (the output hidden_states of the context ViT). 
context_provider_layer = context_provider_layers[layer_index] if context_provider_layer is not None: if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( context_provider_layer.__call__, hidden_states, contexts, cross_attention_mask, output_attentions, ) else: layer_outputs = context_provider_layer( hidden_states, contexts, cross_attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if not return_dict: return tuple( v for v in [hidden_states, encoder_states, all_attentions] if v is not None ) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, ) class SiglipTextTransformer(nn.Module): def __init__(self, config: SiglipTextConfig): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = SiglipTextEmbeddings(config) self.encoder = SiglipEncoder(config) self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.head = nn.Linear(embed_dim, embed_dim) @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig ) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: """ output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if input_ids is None: 
raise ValueError("You have to specify input_ids") input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. # expand attention_mask # if attention_mask is not None: # # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] # attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.final_layer_norm(last_hidden_state) # Assuming "sticky" EOS tokenization, last token is always EOS. pooled_output = last_hidden_state[:, -1, :] pooled_output = self.head(pooled_output) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) @add_start_docstrings( """The text model from SigLIP without any head or projection on top.""", SIGLIP_START_DOCSTRING, ) class SiglipTextModel(SiglipPreTrainedModel): config_class = SiglipTextConfig _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"] def __init__(self, config: SiglipTextConfig): super().__init__(config) self.text_model = SiglipTextTransformer(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.text_model.embeddings.token_embedding def set_input_embeddings(self, value): self.text_model.embeddings.token_embedding = value @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings( 
output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig ) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: Examples: ```python >>> from transformers import AutoTokenizer, SiglipTextModel >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") >>> # important: make sure to set padding="max_length" as that's how the model was trained >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled (EOS token) states ```""" return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) return self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) class SiglipVisionTransformer(nn.Module): def __init__(self, config: SiglipVisionConfig): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = SiglipVisionEmbeddings(config) self.encoder = SiglipEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.head = SiglipMultiheadAttentionPoolingHead(config) @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig ) def forward( self, pixel_values, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] 
= None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: """ output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) hidden_states = self.embeddings(pixel_values) encoder_outputs = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.post_layernorm(last_hidden_state) pooled_output = self.head(last_hidden_state) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) class SiglipMultiheadAttentionPoolingHead(nn.Module): """Multihead Attention Pooling.""" def __init__(self, config: SiglipVisionConfig): super().__init__() self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.attention = torch.nn.MultiheadAttention( config.hidden_size, config.num_attention_heads, batch_first=True ) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) def forward(self, hidden_state): batch_size = hidden_state.shape[0] probe = self.probe.repeat(batch_size, 1, 1) hidden_state = self.attention(probe, hidden_state, hidden_state)[0] residual = hidden_state hidden_state = self.layernorm(hidden_state) hidden_state = residual + self.mlp(hidden_state) return hidden_state[:, 0] @add_start_docstrings( """The vision model from SigLIP without any head or projection on top.""", SIGLIP_START_DOCSTRING, ) class SiglipVisionModel(SiglipPreTrainedModel): config_class = 
SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        # Image-only wrapper: holds just the SigLIP vision transformer.
        super().__init__(config)

        self.vision_model = SiglipVisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the wrapped vision transformer."""
        return self.vision_model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig
    )
    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Run the vision transformer on `pixel_values`. Any extra keyword arguments
        are forwarded unchanged to `SiglipVisionTransformer`.

        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, SiglipVisionModel

        >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""
        # Only `return_dict` falls back to the config here; the two `output_*`
        # flags are passed through exactly as given (possibly None).
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )


@add_start_docstrings(SIGLIP_START_DOCSTRING)
class SiglipModel(SiglipPreTrainedModel):
    # Dual-tower SigLIP model: a text transformer and a vision transformer whose
    # pooled outputs are L2-normalised and compared with a learned scale and bias.
    config_class = SiglipConfig

    def __init__(self, config: SiglipConfig):
        super().__init__(config)

        # Validate the nested sub-configs before building the towers.
        if not isinstance(config.text_config, SiglipTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type SiglipTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, SiglipVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.text_model = SiglipTextTransformer(text_config)
        self.vision_model = SiglipVisionTransformer(vision_config)

        # Learned scalars: `forward` multiplies the similarities by
        # logit_scale.exp() and then adds logit_bias.
        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipTextModel`]. This is element `[1]` of the
            text transformer's output.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output, for both tuple and ModelOutput returns.
        pooled_output = text_outputs[1]

        return pooled_output

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipVisionModel`]. This is element `[1]` of the
            vision transformer's output.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```"""
        # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output, for both tuple and ModelOutput returns.
        pooled_output = vision_outputs[1]

        return pooled_output

    @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SiglipOutput]:
        r"""
        Encode both modalities and score every text against every image. Requesting
        `return_loss=True` raises `NotImplementedError`.

        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> inputs = processor(text=texts, images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```"""
        # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pooled embeddings (index 1 of each tower's output).
        image_embeds = vision_outputs[1]
        text_embeds = text_outputs[1]

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp()
            + self.logit_bias
        )
        # The image-major view is just the transpose of the text-major one.
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            raise NotImplementedError("SigLIP loss to be implemented")

        if not return_dict:
            output = (
                logits_per_image,
                logits_per_text,
                text_embeds,
                image_embeds,
                text_outputs,
                vision_outputs,
            )
            return ((loss,) + output) if loss is not None else output

        return SiglipOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )


# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Image/Text processor class for SigLIP.
"""


class SiglipProcessor(ProcessorMixin):
    r"""
    Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.

    [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
    [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.

    Args:
        image_processor ([`SiglipImageProcessor`]):
            The image processor is a required input.
        tokenizer ([`SiglipTokenizer`]):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "SiglipImageProcessor"
    tokenizer_class = "SiglipTokenizer"

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        images: ImageInput = None,
        padding: Union[bool, str, PaddingStrategy] = "max_length",
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length=None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model.

        If `text` is not `None`, it is encoded by SiglipTokenizer's [`~SiglipTokenizer.__call__`]. Only `padding`,
        `truncation`, `max_length` and `return_tensors` are forwarded to the tokenizer. If `images` is not `None`,
        they are prepared by SiglipImageProcessor's [`~SiglipImageProcessor.__call__`], which receives only
        `return_tensors`. Please refer to the docstrings of those two methods for more information.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). Note that `is_split_into_words` is not forwarded by this processor.
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is
                the number of channels, and H and W are the image height and width.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'` (default): Pad to a maximum length specified with the argument `max_length` or to
                  the maximum acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to PyTorch):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.

        Raises:
            ValueError: If both `text` and `images` are `None`.
        """
        if text is None and images is None:
            raise ValueError(
                "You have to specify either text or images. Both cannot be none."
            )

        if text is not None:
            encoding = self.tokenizer(
                text,
                return_tensors=return_tensors,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
            )

        if images is not None:
            image_features = self.image_processor(images, return_tensors=return_tensors)

        if text is not None and images is not None:
            # Merge the image tensors into the tokenizer's encoding.
            encoding["pixel_values"] = image_features.pixel_values
            return encoding
        elif text is not None:
            return encoding
        else:
            return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
    def model_input_names(self):
        # De-duplicate while keeping the tokenizer's names first.
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))


# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for SigLIP model."""


if TYPE_CHECKING:
    from transformers.tokenization_utils_base import TextInput

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/spiece.model",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/siglip-base-patch16-224": 256,
}

SPIECE_UNDERLINE = "▁"


class SiglipTokenizer(PreTrainedTokenizer):
    """
    Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        eos_token (`str`, *optional*, defaults to `""`):
            The end of sequence token.
        unk_token (`str`, *optional*, defaults to `""`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `""`):
            The token used for padding, for example when batching sequences of different lengths.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
        model_max_length (`int`, *optional*, defaults to 64):
            The maximum length (in number of tokens) for model inputs.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.

    NOTE(review): upstream SigLIP uses the SentencePiece end-of-sequence / unknown pieces as the default
    eos/unk/pad tokens. The empty-string defaults here may be an extraction artifact, so confirm them. In practice
    the values usually come from `tokenizer_config.json`.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    # Translation table that deletes every ASCII punctuation character; built once.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    def __init__(
        self,
        vocab_file,
        eos_token="",
        unk_token="",
        pad_token="",
        additional_special_tokens=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        model_max_length=64,
        do_lower_case=True,
        **kwargs,
    ) -> None:
        requires_backends(self, "protobuf")

        def _as_added_token(token):
            # Plain strings become non-normalized special tokens that strip the
            # surrounding whitespace; pre-built AddedToken objects pass through.
            if isinstance(token, str):
                return AddedToken(
                    token, rstrip=True, lstrip=True, normalized=False, special=True
                )
            return token

        pad_token = _as_added_token(pad_token)
        unk_token = _as_added_token(unk_token)
        eos_token = _as_added_token(eos_token)

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        self.do_lower_case = do_lower_case
        self.vocab_file = vocab_file

        # The SentencePiece model is loaded before super().__init__(), presumably
        # because the base constructor may already look up the vocabulary. Keep this order.
        self.sp_model = self.get_spm_processor()

        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=self.sp_model_kwargs,
            model_max_length=model_max_length,
            do_lower_case=do_lower_case,
            **kwargs,
        )

    def get_spm_processor(self):
        """Load `self.vocab_file` into a SentencePiece processor with `add_dummy_prefix` disabled.

        The dummy prefix is switched off because `tokenize()` prepends SPIECE_UNDERLINE itself.
        Both `__init__` and `__setstate__` build the model through this method so that they
        tokenize identically.
        """
        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        with open(self.vocab_file, "rb") as f:
            sp_model = f.read()
        model_pb2 = import_protobuf()
        model = model_pb2.ModelProto.FromString(sp_model)
        normalizer_spec = model_pb2.NormalizerSpec()
        normalizer_spec.add_dummy_prefix = False
        model.normalizer_spec.MergeFrom(normalizer_spec)
        sp_model = model.SerializeToString()
        tokenizer.LoadFromSerializedProto(sp_model)
        return tokenizer

    @property
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
    def vocab_size(self):
        return self.sp_model.get_piece_size()

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        # normal case: some special tokens (one trailing EOS per sequence)
        if token_ids_1 is None:
            return ([0] * len(token_ids_0)) + [1]
        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
    def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
        """Do not add eos again if user already added it."""
        if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
            warnings.warn(
                f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
                " eos tokens being added."
            )
            return token_ids
        else:
            return token_ids + [self.eos_token_id]

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
        use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros (one per token, including the EOS appended to each sequence).
        """
        eos = [self.eos_token_id]

        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A sequence has the following format:

        - single sequence: `X` followed by the EOS token
        - pair of sequences: `A` + EOS + `B` + EOS

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        token_ids_0 = self._add_eos_if_not_present(token_ids_0)
        if token_ids_1 is None:
            return token_ids_0
        else:
            token_ids_1 = self._add_eos_if_not_present(token_ids_1)
            return token_ids_0 + token_ids_1

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
    def __getstate__(self):
        # The SentencePiece processor is not picklable; it is rebuilt in __setstate__.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # Rebuild through get_spm_processor() so that an unpickled tokenizer gets the
        # same add_dummy_prefix=False normalizer as a freshly constructed one. Loading
        # the raw .spm file here would silently change tokenization.
        self.sp_model = self.get_spm_processor()

    def remove_punctuation(self, text: str) -> str:
        """Return `text` with every character in `string.punctuation` deleted."""
        return text.translate(self._PUNCTUATION_TABLE)

    # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
    def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
        """Returns canonicalized `text` (punctuation removed, whitespace runs collapsed, ends stripped).

        Args:
            text (`str`):
                String to be canonicalized.
            keep_punctuation_exact_string (`str`, *optional*):
                If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}'
                (but will still remove '{' and '}' that appear separately).
        """
        if keep_punctuation_exact_string:
            text = keep_punctuation_exact_string.join(
                self.remove_punctuation(part)
                for part in text.split(keep_punctuation_exact_string)
            )
        else:
            text = self.remove_punctuation(text)
        text = re.sub(r"\s+", " ", text)
        text = text.strip()

        return text

    def tokenize(
        self, text: "TextInput", add_special_tokens=False, **kwargs
    ) -> List[str]:
        """
        Converts a string to a list of tokens.

        Because `add_dummy_prefix` is disabled in the SentencePiece model, a SPIECE_UNDERLINE is prepended here, and
        literal SPIECE_UNDERLINE characters in `text` are first replaced by spaces. `add_special_tokens` is accepted
        for signature compatibility but is not used.
        """
        tokens = super().tokenize(
            SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs
        )

        # Drop the lone prefix marker when it is directly followed by a special token.
        if (
            len(tokens) > 1
            and tokens[0] == SPIECE_UNDERLINE
            and tokens[1] in self.all_special_tokens
        ):
            tokens = tokens[1:]
        return tokens

    @property
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
    def unk_token_length(self):
        return len(self.sp_model.encode(str(self.unk_token)))

    def _tokenize(self, text, **kwargs):
        """
        Returns a tokenized string.

        Because `add_dummy_prefix` is disabled, SentencePiece strips a leading SPIECE_UNDERLINE. For example,
        `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type=str)` gives `['H', 'e', 'y']` instead of
        `['▁He', 'y']`. To preserve the word boundary, the text is encoded with `self.unk_token` prepended, and the
        first `self.unk_token_length` pieces (those produced by the unk token itself) are then dropped.
        """
        text = self.canonicalize_text(text, keep_punctuation_exact_string=None)

        # 1. Encode unk_token + text so the leading word boundary survives.
        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
        # 2. Remove the pieces that came from self.unk_token.
        return (
            tokens[self.unk_token_length :]
            if len(tokens) >= self.unk_token_length
            else tokens
        )

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string.

        Special tokens are emitted verbatim rather than decoded by SentencePiece. The caller's sequence is not
        modified, and an empty sequence yields an empty string.
        """
        if not tokens:
            return ""
        # since we manually add the prefix space, we have to remove it
        # (on a copy, so the caller's list is left untouched)
        tokens = [tokens[0].lstrip(SPIECE_UNDERLINE), *tokens[1:]]
        special_tokens = set(self.all_special_tokens)

        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Write the SentencePiece model into `save_directory` and return the written path as a 1-tuple.

        The original vocab file is copied when it exists on disk. Otherwise the in-memory model proto is
        serialized. Logs an error and returns None when `save_directory` is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)


import torch  # noqa


class SiglipVisionTower(VisionTower):
    def __init__(
        self, model_name_or_path: str, config: PretrainedConfig, state_dict=None
    ):
        """Load the SigLIP image processor and vision model from `model_name_or_path`.

        The vision model is loaded with the dtype named by `config.model_dtype`.
        """
        super().__init__(model_name_or_path, config)
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            # TODO(ligeng): why pass config here leading to errors?
            model_name_or_path,
            # NOTE(review): `eval` on a string read from a (possibly downloaded) config
            # executes arbitrary code. It expects something like "torch.float16".
            # Consider mapping the name to a torch.dtype explicitly.
            torch_dtype=eval(config.model_dtype),
            state_dict=state_dict,
        )
        self.is_loaded = True


AutoConfig.register("siglip_vision_model", SiglipVisionConfig, exist_ok=True)
AutoModel.register(SiglipVisionConfig, SiglipVisionModel, exist_ok=True)
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is modified from https://github.com/haotian-liu/LLaVA/ class LlavaLlamaConfig(LlavaConfig): model_type = "llava_llama" # FIXME we will follow the convention to add a new class for CausalLM in the future class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel): config_class = LlavaLlamaConfig main_input_name = "input_embeds" supports_gradient_checkpointing = True tokenizer_image_token = staticmethod(tokenizer_image_token) def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None: super().__init__(config) self.dam_model = None self.pretrained_model_name_or_path = None self.init_vlm(config=config, *args, **kwargs) @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None, ignore_mismatched_sizes: bool = False, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = "main", use_safetensors: bool = None, torch_dtype: Optional[Union[str, torch.dtype]] = torch.float16, init_dam: bool = False, # conv_mode and prompt_mode are only used by `init_dam` in `from_pretrained` if `init_dam` is set to True conv_mode: str = "v1", prompt_mode: str = "full+focal_crop", **kwargs, ): if torch_dtype: config.model_dtype = str(torch_dtype) if hasattr(cls, "load_pretrained"): obj = cls.load_pretrained( pretrained_model_name_or_path, *model_args, config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes, force_download=force_download, local_files_only=local_files_only, token=token, revision=revision, use_safetensors=use_safetensors, **kwargs, ) else: obj = super(LlavaLlamaModel).from_pretrained( pretrained_model_name_or_path, *model_args, config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes, force_download=force_download, 
local_files_only=local_files_only, token=token, revision=revision, use_safetensors=use_safetensors, **kwargs, ) obj.pretrained_model_name_or_path = pretrained_model_name_or_path # `init_dam` is used to initialize a `DescribeAnythingModel` object in a `LlavaLlamaModel` in DAM. If you initialize `DescribeAnythingModel` on your own outside, then you don't have to use this option. # This is very useful if you use `from_pretrained` with remote code execution and don't want to put implementation for `DescribeAnythingModel` class in your codebase. if init_dam: obj.init_dam(conv_mode, prompt_mode) return obj def forward( self, input_ids: torch.LongTensor = None, images: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: self.freezed_module_patch() if inputs_embeds is None: ( input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels, ) = self.prepare_inputs_labels_for_multimodal( input_ids, position_ids, attention_mask, past_key_values, labels, images ) # Note (kentang-mit@): we have a unit test for this function. 
if self.training: ( _, new_position_ids, new_attention_mask, _, new_inputs_embeds, new_labels, sorted_seqlens_in_batch, ) = self.repack_multimodal_data( input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels, ) new_input_ids = None past_key_values = None else: new_attention_mask = attention_mask new_position_ids = position_ids new_inputs_embeds = inputs_embeds new_labels = labels sorted_seqlens_in_batch = attention_mask.sum(-1).int() new_input_ids = input_ids outputs = self.llm.forward( input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids, past_key_values=past_key_values, inputs_embeds=new_inputs_embeds, labels=new_labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, seqlens_in_batch=sorted_seqlens_in_batch, ) return outputs @torch.no_grad() def generate( self, input_ids: Optional[torch.FloatTensor] = None, images: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, **generation_kwargs, ): if images is not None: ( _, _, attention_mask, _, inputs_embeds, _, ) = self.prepare_inputs_labels_for_multimodal( input_ids, None, attention_mask, None, None, images ) else: inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = inputs_embeds.to(self.dtype) outputs = self.llm.generate( inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs, ) return outputs def init_dam(self, conv_mode, prompt_mode): model_name = get_model_name_from_path(self.pretrained_model_name_or_path) self.dam_model = DescribeAnythingModel( model_path=dict( model=self, tokenizer=self.tokenizer, model_name=model_name ), conv_mode=conv_mode, prompt_mode=prompt_mode, ) return self.dam_model @property def dam(self): if self.dam_model is None: self.init_dam() return self.dam_model AutoConfig.register("llava_llama", LlavaLlamaConfig) AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel) import 
torch  # noqa


def _tokenizer_files_present(exists) -> bool:
    """Return True if ``exists(filename)`` reports a complete tokenizer file set.

    A complete set is ``special_tokens_map.json`` + ``tokenizer_config.json`` plus at
    least one of ``tokenizer.model`` / ``tokenizer.json``. Files are probed in that
    order and probing stops as soon as the answer is known.
    """
    required = ("special_tokens_map.json", "tokenizer_config.json")
    model_files = ("tokenizer.model", "tokenizer.json")
    return all(exists(name) for name in required) and any(
        exists(name) for name in model_files
    )


def has_tokenizer(path):
    """Return True if ``path`` holds a complete tokenizer.

    ``path`` is first checked as a local directory. If the files are not found there,
    it is treated as a Hugging Face Hub repo id and the same files are looked up on
    the Hub.
    """
    if _tokenizer_files_present(lambda name: osp.exists(osp.join(path, name))):
        return True

    from huggingface_hub import HfApi, file_exists
    from huggingface_hub.utils import HFValidationError

    try:
        valid_hf_repo = HfApi().repo_exists(path)
    except HFValidationError:
        # Not even a syntactically valid repo id.
        valid_hf_repo = False

    return bool(
        valid_hf_repo
        and _tokenizer_files_present(lambda name: file_exists(path, name))
    )


def context_length_extension(config):
    """Enable linear RoPE scaling when the requested context exceeds the native one.

    If ``config.model_max_length`` is larger than ``config.max_position_embeddings``,
    ``config.rope_scaling`` is set to ``{"type": "linear", "factor": ceil(ratio)}``.
    Missing or ``None`` values leave the config untouched.

    Returns:
        The same (possibly mutated) ``config`` object.
    """
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    if orig_ctx_len and model_max_length and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config


def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: "PretrainedConfig",
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> "Tuple[PreTrainedModel, Any]":
    """Load the language model and its tokenizer.

    Args:
        model_name_or_path: LLM checkpoint directory or Hub repo id.
        config: VLM config. ``config.model_dtype`` selects the LLM dtype, and
            ``config.hidden_size`` is overwritten with the LLM's hidden size.
        attn_implementation: stored as ``_attn_implementation`` on the LLM config.
        model_max_length: stored on the LLM config. When not None, it is also used to
            extend RoPE via ``context_length_extension``.
        *args, **kwargs: forwarded to ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        ``(llm, tokenizer)``.
    """
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # SECURITY NOTE(review): `eval` executes `config.model_dtype` as Python code.
    # prepare_config_for_eval writes strings such as "torch.float16" here. If this
    # value can ever come from an untrusted config.json, parse it instead of eval'ing.
    llm = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        config=llm_cfg,
        torch_dtype=eval(config.model_dtype),
        *args,
        **kwargs,
    )

    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        # NOTE(review): the warning text reads as if the tokenizer *was* found in the
        # root folder, but this branch runs when it was not; kept verbatim.
        warnings.warn(
            "tokenizer found in VLM root folder. Move to ./{VILA}/llm in the future."
        )
        llm_path = osp.join(llm_path, "llm")

    # TODO(ligeng): use LLM class to judge to better compability.
    tokenizer_kwargs = dict(
        model_max_length=llm_cfg.model_max_length,
        padding_side="right",
    )
    if "mpt" not in model_name_or_path:
        tokenizer_kwargs["use_fast"] = False
        if "yi" not in model_name_or_path.lower():
            tokenizer_kwargs["legacy"] = False
    tokenizer = AutoTokenizer.from_pretrained(llm_path, **tokenizer_kwargs)

    # TODO(ligeng): is this necessary for llava?
    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer


# This file is modified from https://github.com/haotian-liu/LLaVA/ and https://github.com/NVlabs/VILA/
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# TODO: we may move LlavaConfig to configuration_llava.py
# from model.configuration_llava import LlavaConfig


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.

    Replaces ``reset_parameters`` on ``torch.nn.Linear`` and ``torch.nn.LayerNorm``
    with a no-op. This is a process-wide monkey patch.
    """
    import torch

    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def load_pretrained_model(
    model_path,
    model_name,
    model_base=None,
    load_8bit=False,
    load_4bit=False,
    device_map="auto",
    device="cuda",
    **kwargs,
):
    """Load a LlavaLlamaModel checkpoint for inference.

    Args:
        model_path: checkpoint path or Hub id passed to ``AutoConfig.from_pretrained``.
        model_name: not used by this function (kept for API compatibility).
        model_base: not used by this function (kept for API compatibility).
        load_8bit: request 8-bit loading (takes precedence over ``load_4bit``).
        load_4bit: request 4-bit NF4 loading through ``BitsAndBytesConfig``.
        device_map: forwarded as ``device_map`` unless ``device`` is not "cuda", in
            which case everything is mapped to ``device``.
        device: device the vision tower, projector and context provider are moved to.
        **kwargs: extra keyword arguments forwarded to ``LlavaLlamaModel``.

    Returns:
        ``(tokenizer, model, image_processor, context_len)``. ``context_len`` falls back
        to 2048 when the LLM config has no ``max_sequence_length``.
    """
    kwargs = {"device_map": device_map, **kwargs}
    if device != "cuda":
        kwargs["device_map"] = {"": device}

    if load_8bit:
        kwargs["load_in_8bit"] = True
    elif load_4bit:
        # Imported locally: it is only needed on the 4-bit path.
        from transformers import BitsAndBytesConfig

        kwargs["load_in_4bit"] = True
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        kwargs["torch_dtype"] = torch.float16

    config = AutoConfig.from_pretrained(model_path)
    config.resume_path = model_path
    # Mutates both `config` and `kwargs` (pops torch_dtype, may override device_map).
    prepare_config_for_eval(config, kwargs)

    model = LlavaLlamaModel(config=config, low_cpu_mem_usage=True, **kwargs)
    tokenizer = model.tokenizer
    model.eval()
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    vision_tower.to(device=device, dtype=torch.float16)
    mm_projector = model.get_mm_projector()
    mm_projector.to(device=device, dtype=torch.float16)
    context_provider = model.get_context_provider()
    if context_provider is not None:
        context_provider.to(device=device, dtype=torch.float16)
    image_processor = vision_tower.image_processor

    # Read the value from the same config object that is checked for it.
    context_len = getattr(model.llm.config, "max_sequence_length", 2048)

    return tokenizer, model, image_processor, context_len


def parse_model_name_or_path(config: "PretrainedConfig", model_name="llm", suffix="_cfg"):
    """Return the name/path stored in ``config.<model_name><suffix>``.

    A string attribute is returned as-is. A dict attribute yields its first
    ``"architectures"`` entry. Anything else raises ``ValueError``.
    """
    target_model = f"{model_name}{suffix}"
    target_cfg = getattr(config, target_model, None)

    if isinstance(target_cfg, str):
        return target_cfg
    elif isinstance(target_cfg, dict):
        return target_cfg["architectures"][0]
    else:
        raise ValueError(f"Invalid {target_model} configuration!")


def prepare_config_for_eval(config: "PretrainedConfig", kwargs: dict):
    """Normalize ``config`` and model-loading ``kwargs`` for evaluation, in place.

    - Copies the deprecated ``mm_vision_tower`` into ``vision_tower_cfg`` when the
      latter is missing or None.
    - Pops ``torch_dtype`` from ``kwargs`` and records ``str(dtype)`` in
      ``config.model_dtype``. It defaults to ``torch.float16`` when absent, as on the
      8/4-bit paths of ``load_pretrained_model``, which also compute in float16.
    - Forces ``kwargs["device_map"] = "cuda"`` for SigLIP vision towers.

    Raises:
        ValueError: if no vision tower can be found in ``config``.
    """
    try:
        # compatible with deprecated config convention
        if getattr(config, "vision_tower_cfg", None) is None:
            config.vision_tower_cfg = config.mm_vision_tower
    except AttributeError as err:
        raise ValueError(
            f"Invalid configuration! Cannot find vision_tower in config:\n{config}"
        ) from err

    config.model_dtype = str(kwargs.pop("torch_dtype", torch.float16))
    # siglip does not support device_map = "auto"
    vision_tower_name = parse_model_name_or_path(config, "vision_tower")
    if "siglip" in vision_tower_name.lower():
        kwargs["device_map"] = "cuda"


class DescribeAnythingModel(nn.Module):
    """Region-description wrapper around a LLaVA-style model.

    It crops an image around a mask, builds image+mask tensors, and generates a text
    description for a prompt containing the image token.
    """

    def __init__(self, model_path, conv_mode, prompt_mode, **kwargs):
        """
        Args:
            model_path: either a checkpoint path (str), loaded with
                ``load_pretrained_model(model_path, None, None, **kwargs)``, or a dict
                with keys "model", "tokenizer" and optionally "model_name". In the dict
                case ``kwargs`` is ignored.
            conv_mode: key into ``conv_templates`` used by ``get_prompt``.
            prompt_mode: "<crop_mode>+<crop_mode2>" string, split in
                ``get_description_from_prompt_iterator``.
        """
        super().__init__()

        self.model_path = model_path
        self.conv_mode = conv_mode
        self.prompt_mode = prompt_mode

        if isinstance(model_path, str):
            self.tokenizer, self.model, _, _ = load_pretrained_model(
                model_path, None, None, **kwargs
            )
            self.model_name = get_model_name_from_path(model_path)
        else:
            # model_path is actually a dict with model, tokenizer, and (optionally) model_name
            self.model = model_path["model"]
            self.tokenizer = model_path["tokenizer"]
            self.model_name = model_path.get("model_name", None)

        # Expose the vision tower's image processor on the model config, which is the
        # object get_image_tensor hands to process_image.
        image_processor = self.model.vision_tower.image_processor
        self.model.config.image_processor = image_processor

    def get_prompt(self, qs):
        """Wrap query ``qs`` in the ``conv_mode`` template.

        Returns:
            ``(prompt, conv)``.

        Raises:
            ValueError: if ``qs`` lacks ``DEFAULT_IMAGE_TOKEN``.
        """
        if DEFAULT_IMAGE_TOKEN not in qs:
            raise ValueError("no tag found in input.")

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        return prompt, conv

    @staticmethod
    def mask_to_box(mask_np):
        """Return the tight box ``(x0, y0, w, h)`` around the non-zero pixels of a 2-D mask.

        Raises:
            ValueError: if the mask has no non-zero pixel.
        """
        mask_coords = np.argwhere(mask_np)
        if mask_coords.size == 0:
            raise ValueError("mask is empty; cannot compute a bounding box")
        y0, x0 = mask_coords.min(axis=0)
        y1, x1 = mask_coords.max(axis=0) + 1

        h = y1 - y0
        w = x1 - x0

        return x0, y0, w, h

    @classmethod
    def crop_image(cls, pil_img, mask_np, crop_mode, min_box_w=48, min_box_h=48):
        """Crop ``pil_img`` (and ``mask_np`` identically) around the mask.

        Modes:
            "full": no crop; the inputs are returned unchanged.
            "crop": tight mask bounding box.
            "context_crop": the box extended by its own width/height on every side,
                clipped to the image.
            "focal_crop": like "context_crop", but the box is first enlarged around
                its center to at least ``min_box_w`` x ``min_box_h``.
            "crop_mask": tight box, with the image multiplied by the cropped mask
                (zeroes outside the mask; assumes a 0/1 mask and an HxWxC image).

        Returns:
            ``(image, {"mask_np": cropped_mask})``. The image is a new PIL image
            except in "full" mode.

        Raises:
            ValueError: for an unsupported ``crop_mode``.
        """
        if crop_mode == "full":
            # no crop
            info = dict(mask_np=mask_np)
            return pil_img, info

        if crop_mode not in ("crop", "context_crop", "focal_crop", "crop_mask"):
            raise ValueError(f"Unsupported crop_mode: {crop_mode}")

        x0, y0, w, h = cls.mask_to_box(mask_np)
        img_np = np.asarray(pil_img)
        assert (
            img_np.shape[:2] == mask_np.shape
        ), f"image shape mismatches with mask shape: {img_np.shape}, {mask_np.shape}"
        img_h, img_w = img_np.shape[:2]

        if crop_mode == "focal_crop":
            xc, yc = x0 + w / 2, y0 + h / 2
            # focal_crop: need to have at least min_box_w and min_box_h pixels, otherwise resizing to (384, 384) leads to artifacts that may be OOD
            w, h = max(w, min_box_w), max(h, min_box_h)
            x0, y0 = int(xc - w / 2), int(yc - h / 2)

        if crop_mode in ("crop", "crop_mask"):
            rows = slice(y0, y0 + h)
            cols = slice(x0, x0 + w)
        else:
            # context_crop / focal_crop: one extra box size of context on each side.
            rows = slice(max(y0 - h, 0), min(y0 + 2 * h, img_h))
            cols = slice(max(x0 - w, 0), min(x0 + 2 * w, img_w))

        cropped_mask_np = mask_np[rows, cols]
        cropped_img_np = img_np[rows, cols]
        if crop_mode == "crop_mask":
            # Mask the image
            cropped_img_np = cropped_img_np * cropped_mask_np[..., None]

        info = dict(mask_np=cropped_mask_np)
        return Image.fromarray(cropped_img_np), info

    def get_description(
        self,
        image_pil,
        mask_pil,
        query,
        streaming=False,
        temperature=0.2,
        top_p=0.5,
        num_beams=1,
        max_new_tokens=512,
        **kwargs,
    ):
        """Describe the masked region(s) of one image or a list of images.

        ``image_pil``/``mask_pil`` must both be single items or both be lists/tuples
        of equal length. Extra ``kwargs`` are passed to generation:
        https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig

        Returns:
            A generator of text chunks if ``streaming``, otherwise the description
            string.
        """
        prompt, conv = self.get_prompt(query)
        pair_msg = "image_pil and mask_pil must be both list or tuple or not list or tuple."
        if isinstance(image_pil, (list, tuple)):
            assert isinstance(mask_pil, (list, tuple)), pair_msg
            image_pils = image_pil
            mask_pils = mask_pil
        else:
            assert not isinstance(mask_pil, (list, tuple)), pair_msg
            image_pils = [image_pil]
            mask_pils = [mask_pil]

        description = self.get_description_from_prompt(
            image_pils,
            mask_pils,
            prompt,
            conv,
            streaming=streaming,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )

        return description

    def _encode_crop(self, image_pil, mask_np, crop_mode):
        """Crop with ``crop_mode``, preprocess, and append one channel of the cropped mask."""
        images_tensor, image_info = process_image(
            image_pil,
            self.model.config,
            None,
            pil_preprocess_fn=lambda pil_img: self.crop_image(
                pil_img, mask_np=mask_np, crop_mode=crop_mode
            ),
        )
        images_tensor = images_tensor[None].to(self.model.device, dtype=torch.float16)

        cropped_mask_pil = Image.fromarray(image_info["mask_np"] * 255)
        masks_tensor = process_image(cropped_mask_pil, self.model.config, None)
        masks_tensor = masks_tensor[None].to(self.model.device, dtype=torch.float16)

        return torch.cat((images_tensor, masks_tensor[:, :1, ...]), dim=1)

    def get_image_tensor(self, image_pil, mask_pil, crop_mode, crop_mode2):
        """Build the visual input for one image/mask pair.

        Each crop contributes its processed image channels followed by the first
        channel of its processed cropped mask. When ``crop_mode2`` is not None, the
        second crop is concatenated after the first along dim 1. Both crops start from
        the original (uncropped) mask.

        Returns:
            A float16 tensor on ``self.model.device`` with a leading batch dim of 1
            (presumably ``(1, C, H, W)`` — depends on process_image).
        """
        # Any non-zero mask pixel counts as foreground.
        mask_np = (np.asarray(mask_pil) > 0).astype(np.uint8)

        first = self._encode_crop(image_pil, mask_np, crop_mode)
        if crop_mode2 is None:
            return first
        second = self._encode_crop(image_pil, mask_np, crop_mode2)
        return torch.cat((first, second), dim=1)

    def get_description_from_prompt(
        self,
        image_pils,
        mask_pils,
        prompt,
        conv,
        streaming=False,
        temperature=0.2,
        top_p=0.5,
        num_beams=1,
        max_new_tokens=512,
        **kwargs,
    ):
        """Run generation for an already-built prompt.

        Returns:
            The chunk generator when ``streaming``, otherwise the single description
            string it yields.
        """
        outputs = self.get_description_from_prompt_iterator(
            image_pils,
            mask_pils,
            prompt,
            conv,
            streaming=bool(streaming),
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )
        if streaming:
            return outputs
        # If streaming is False, there will be only one output
        return next(outputs)

    def get_description_from_prompt_iterator(
        self,
        image_pils,
        mask_pils,
        prompt,
        conv,
        streaming=False,
        temperature=0.2,
        top_p=0.5,
        num_beams=1,
        max_new_tokens=512,
        **kwargs,
    ):
        """Generate a description and yield text.

        Streaming: ``self.model.generate`` runs in a background thread, and decoded
        chunks are yielded until the conversation stop string appears. Text before
        the stop string is still yielded.

        Non-streaming: yields exactly one decoded, stripped string with a trailing
        stop string removed.

        Sampling is enabled iff ``temperature > 0``. Extra ``kwargs`` go to generate.

        Raises:
            ValueError: if ``self.prompt_mode`` is not "<crop_mode>+<crop_mode2>".
        """
        modes = self.prompt_mode.split("+")
        if len(modes) != 2:
            raise ValueError(
                f"prompt_mode must look like '<crop_mode>+<crop_mode2>', got {self.prompt_mode!r}"
            )
        crop_mode, crop_mode2 = modes
        assert (
            crop_mode == "full"
        ), "Current prompt only supports first crop as full (non-cropped). If you need other specifications, please update the prompt."

        assert len(image_pils) == len(
            mask_pils
        ), f"image_pils and mask_pils must have the same length. Got {len(image_pils)} and {len(mask_pils)}."
        image_tensors = [
            self.get_image_tensor(
                image_pil, mask_pil, crop_mode=crop_mode, crop_mode2=crop_mode2
            )
            for image_pil, mask_pil in zip(image_pils, mask_pils)
        ]

        # Same device the image tensors are placed on (instead of a hard-coded .cuda()).
        input_ids = (
            tokenizer_image_token(
                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            )
            .unsqueeze(0)
            .to(self.model.device)
        )

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(
            keywords, self.tokenizer, input_ids
        )

        streamer = (
            TextIteratorStreamer(
                self.tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            if streaming
            else None
        )
        generation_kwargs = dict(
            input_ids=input_ids,
            images=image_tensors,
            do_sample=temperature > 0,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            streamer=streamer,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )

        if streaming:
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            generated_text = ""
            for new_text in streamer:
                prev_len = len(generated_text)
                generated_text += new_text
                stop_idx = generated_text.find(stop_str)
                if stop_idx != -1:
                    # Emit the part of this chunk that precedes the stop string.
                    head = new_text[: max(stop_idx - prev_len, 0)]
                    if head:
                        yield head
                    break
                yield new_text

            thread.join()
        else:
            with torch.inference_mode():
                output_ids = self.model.generate(**generation_kwargs)

            outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[
                0
            ]
            outputs = outputs.strip()
            if outputs.endswith(stop_str):
                outputs = outputs[: -len(stop_str)]
            outputs = outputs.strip()

            yield outputs