import os
import os.path as osp
from collections import defaultdict
from typing import List, Union

from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

from .constants import DEFAULT_IMAGE_TOKEN, MEDIA_TOKENS
from .media import Image, Video, extract_media
from .mm_utils import process_image, process_images
from .tokenizer_utils import tokenize_conversation


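# Note: `_defaults` follows the standard `ProcessingKwargs` pattern, where per-modality defaults
# (here, `padding=False` for text) can be merged with user-supplied kwargs by `ProcessorMixin`.
# `VILAProcessor.__call__` below does not currently merge these defaults itself.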
class VILAProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
    }


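# `attributes` and `valid_kwargs` are intentionally left empty: the tokenizer and image processor
# live in the `llm/` and `vision_tower/` subfolders of the checkpoint rather than in the standard
# processor layout, so `from_pretrained` below wires the components up manually instead of relying
# on `ProcessorMixin`'s automatic loading.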
class VILAProcessor(ProcessorMixin):
    attributes = []
    valid_kwargs = []

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, config=None, **kwargs):
        self.image_token = MEDIA_TOKENS["image"]
        self.video_token = MEDIA_TOKENS["video"]
        self.config = config
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

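    # The checkpoint is expected to contain a `vision_tower/` subfolder (image processor), an
    # `llm/` subfolder (tokenizer), and a top-level config; a non-local name is first resolved to
    # a local snapshot via `snapshot_download`.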
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        if not osp.isdir(pretrained_model_name_or_path):
            print(f"{pretrained_model_name_or_path} is not a local directory, downloading from the Hub")
            from huggingface_hub import snapshot_download

            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)

        image_processor = AutoImageProcessor.from_pretrained(
            osp.join(pretrained_model_name_or_path, "vision_tower"), trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(
            osp.join(pretrained_model_name_or_path, "llm"), trust_remote_code=True
        )
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)

        return cls(image_processor=image_processor, tokenizer=tokenizer, config=config)

    def __repr__(self):
        return (
            f"VILAProcessor(image_processor={self.image_processor}, tokenizer={self.tokenizer}, config={self.config})"
        )

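    # `conversation` is expected in the VILA-native format: a list of turns like
    # {"from": "human" | "gpt", "value": [str | Image | Video, ...]} (see `apply_chat_template`
    # for conversion from the generic role/content chat format). Returns a `BatchFeature` with
    # `input_ids` plus the processed media tensors.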
    def __call__(
        self,
        conversation,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> BatchFeature:
        media = extract_media(conversation, self.config)

        # Convert each media type into model-ready tensors; `media_config` carries extra
        # per-modality information (e.g. block sizes for dynamic S2).
        media_config = defaultdict(dict)
        for name in media:
            if name == "image":
                if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
                    self.config.image_processor = self.image_processor
                    if self.config.image_aspect_ratio == "dynamic":
                        images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
                        # Expand the single image token to one token per dynamic-resolution tile.
                        conversation[0]["value"] = conversation[0]["value"].replace(
                            DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
                        )
                    else:
                        if isinstance(self.config.s2_scales, str):
                            self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                        images, block_sizes = process_image(
                            media["image"][0], self.config, None, enable_dynamic_s2=True
                        )
                        images = images.half()
                        media_config[name]["block_sizes"] = [block_sizes]
                else:
                    images = process_images(media["image"], self.image_processor, self.config).half()
                media[name] = [image for image in images]
            elif name == "video":
                media[name] = [
                    process_images(images, self.image_processor, self.config).half()
                    for images in media[name]
                ]
            else:
                raise ValueError(f"Unsupported media type: {name}")

        input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).cuda().unsqueeze(0)

        return BatchFeature(data={"input_ids": input_ids, **media})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the underlying tokenizer's
        [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the underlying tokenizer's [`~PreTrainedTokenizer.decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape
                `(batch_size, sequence_length)` or `(sequence_length,)`.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

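    # Unlike the base `ProcessorMixin.apply_chat_template`, this converts a generic chat in the
    # role/content format into the VILA-native from/value format and then runs the full
    # preprocessing pipeline via `self(...)`, so it returns model-ready tensors rather than a
    # prompt string.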
    def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
        vila_conv = []

        for chat in conversation:
            vila_chat = {"from": "", "value": []}
            if chat["role"] == "user":
                vila_chat["from"] = "human"
                for content in chat["content"]:
                    if content["type"] == "image":
                        vila_chat["value"].append(Image(content["path"]))
                    elif content["type"] == "text":
                        vila_chat["value"].append(content["text"])
                    else:
                        raise ValueError(f"Unsupported content type: {content['type']}")
            elif chat["role"] == "assistant":
                vila_chat["from"] = "gpt"
                for content in chat["content"]:
                    assert content["type"] == "text", f"Unsupported content type: {content['type']}"
                    vila_chat["value"].append(content["text"])
            vila_conv.append(vila_chat)

        return self(vila_conv)


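# Minimal smoke test: exercises both conversation formats against a local or Hub checkpoint.
# Requires the demo image under `demo_images/` and a CUDA-capable GPU (inputs are moved to CUDA
# in `__call__`, and the model is loaded with `device_map="auto"`).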
if __name__ == "__main__":
    import PIL.Image  # local import: only needed for the demo conversation below

    model_path = "NVILA-Lite-2B-hf-preview"

    # Generic HF chat format (role/content entries), handled by `apply_chat_template`.
    gpt_conv = [
        {
            "role": "user",
            "content": [
                {"type": "image", "path": "demo_images/demo_img_1.png"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]

    # VILA-native conversation format (from/value entries), handled by `__call__` directly.
    llavaconv = [
        {
            "from": "human",
            "value": [
                PIL.Image.open("demo_images/demo_img_1.png"),
                "Describe this image.",
            ],
        }
    ]

    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")

    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")

    # Build a processor directly from the loaded model's own components.
    processor = VILAProcessor(
        config=model.config,
        image_processor=model.vision_tower.image_processor,
        tokenizer=model.tokenizer,
    )

    inputs = processor(conversation=llavaconv, padding=True, return_tensors="pt")
    print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
    print("vila conv pass")

    inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
    print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
    print("gpt conv pass")

    output_ids = model.generate(
        input_ids=inputs.input_ids,
        media={
            "image": inputs.image,
        },
        media_config={"image": {}},
        generation_config=model.generation_config,
        max_new_tokens=100,
    )
    print(output_ids)
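
    # Optional: decode the generated ids back to text using the processor helper defined above
    # (this assumes `output_ids` has shape (batch_size, sequence_length)).
    print(processor.post_process_image_text_to_text(output_ids))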