# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Llava.
"""
from typing import List, Optional, Union

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import (
    PaddingStrategy,
    PreTokenizedInput,
    TextInput,
    TruncationStrategy,
)
from transformers.utils import TensorType
import torch
from open_clip.transform import PreprocessCfg, image_transform_v2
from modeling_llava import LlavaForConditionalGeneration
from PIL import Image
import math


class OpenCLIPImageProcessor:
    """Turn one PIL image into a stack of OpenCLIP-preprocessed views.

    The first view is always the whole image. When the image is larger than
    ``crop_size`` in either dimension, overlapping tiles (stride of half the
    tile size) are appended after it. The tile size is grown in steps of 10
    pixels until the estimated tile count is at most ``max_tokens``.
    """

    def __init__(self, config, crop_size=384, max_tokens=100):
        """
        Args:
            config: Keyword arguments forwarded to
                ``open_clip.transform.PreprocessCfg``.
            crop_size: Initial tile edge length in pixels. Images no larger
                than this in both dimensions are not tiled.
            max_tokens: Upper bound on the estimated number of tiles.

        Raises:
            ValueError: If ``max_tokens`` is negative. The crop-size search
                cannot terminate for square images in that case, because its
                floored estimate never drops below zero.
        """
        if max_tokens < 0:
            raise ValueError(f"max_tokens must be non-negative, got {max_tokens}")
        cfg = PreprocessCfg(**config)
        self.transform = image_transform_v2(cfg=cfg, is_train=False)
        self.crop_size = crop_size
        self.max_tokens = max_tokens

    def __call__(self, image: Image.Image):
        """Return ``{"pixel_values": ...}`` holding the output of ``transform_func``."""
        output = self.transform_func(image)
        return {
            "pixel_values": output,
        }

    def transform_func(self, image: Image.Image):
        """Transform the full image plus any overlapping tiles and stack them.

        Args:
            image: A single PIL image.

        Returns:
            The transformed views stacked with ``torch.stack`` along a new
            leading dimension. Index 0 is the whole image. Tiles follow,
            ordered by x span first and then by y span within each x span.
        """
        views = [self.transform(image)]
        width, height = image.size
        if width <= self.crop_size and height <= self.crop_size:
            # Small enough: only the global view is produced.
            return torch.stack(views, dim=0)

        crop_size = self._fit_crop_size(width, height)
        stride = crop_size // 2
        x_spans = self._axis_spans(width, crop_size, stride)
        y_spans = self._axis_spans(height, crop_size, stride)
        # PIL crop boxes are (left, upper, right, lower); x is the outer loop.
        for left, right in x_spans:
            for upper, lower in y_spans:
                views.append(self.transform(image.crop((left, upper, right, lower))))
        return torch.stack(views, dim=0)

    def _fit_crop_size(self, width, height):
        """Grow ``self.crop_size`` by 10 until the tile estimate is <= ``max_tokens``."""
        crop_size = self.crop_size
        while self._estimate_tiles(width, height, crop_size) > self.max_tokens:
            crop_size += 10
        return crop_size

    @staticmethod
    def _estimate_tiles(width, height, crop_size):
        """Floor of the product of the unrounded per-axis tile counts at half-tile stride."""
        # Expression kept in the original evaluation order so float results match.
        return math.floor(
            (2 * width - crop_size) / crop_size * (2 * height - crop_size) / crop_size
        )

    @staticmethod
    def _axis_spans(length, crop_size, stride):
        """Return ``[start, end]`` pixel spans along one axis.

        Always returns at least one span. The last span's end is snapped to
        ``length`` so the tiles always reach the image edge.
        """
        steps = max(1, int(round((2 * length - crop_size) / crop_size)))
        spans = [[i * stride, i * stride + crop_size] for i in range(steps)]
        spans[-1][1] = length
        return spans

    @property
    def model_input_names(self):
        return ["pixel_values"]


class LlavaProcessor:
    """Pair an ``OpenCLIPImageProcessor`` with a tokenizer.

    When images are given, this processor also runs the model's
    ``vision_model`` and ``multi_modal_projector``. It returns the projected
    image features next to the tokenized text.
    """

    def __init__(self, image_processor: OpenCLIPImageProcessor, tokenizer):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    def __call__(
        self,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        images: ImageInput = None,
        model: LlavaForConditionalGeneration = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length=None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """Tokenize ``text`` and, if ``images`` is given, encode it with ``model``.

        Args:
            text: Passed unchanged to the tokenizer.
            images: Passed unchanged to ``self.image_processor``. That processor
                reads ``.size`` and calls ``.crop``, so this must be a single PIL
                image. NOTE(review): the ``ImageInput`` annotation suggests lists
                may be accepted, but they are not handled here.
            model: Supplies ``device``, ``dtype``, ``vision_model`` and
                ``multi_modal_projector``. Required when ``images`` is given.
            padding, truncation, max_length, return_tensors: Forwarded to the
                tokenizer.

        Returns:
            A ``BatchFeature`` with the tokenizer outputs plus
            ``"image_features"``. That entry is the projected features with an
            extra leading dimension of size 1, or ``None`` when there are no
            images.

        Raises:
            ValueError: If ``images`` is given without ``model``.
        """
        if images is not None:
            if model is None:
                raise ValueError("`model` is required when `images` are provided.")
            pixel_values = self.image_processor(images)["pixel_values"]
            pixel_values = pixel_values.to(model.device).to(model.dtype)
            image_outputs = model.vision_model(pixel_values)
            image_features = model.multi_modal_projector(image_outputs)
            image_features = image_features.unsqueeze(0)
        else:
            image_features = None
        text_inputs = self.tokenizer(
            text,
            return_tensors=return_tensors,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
        )
        return BatchFeature(data={**text_inputs, "image_features": image_features})

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Tokenizer input names followed by image-processor names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))