# NVILA-Lite-2B-hf-preview / auto_processor.py
import os
import os.path as osp
from collections import defaultdict
from typing import List, Union
from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from .constants import DEFAULT_IMAGE_TOKEN, MEDIA_TOKENS
from .media import Image, Video, extract_media
from .mm_utils import process_image, process_images
from .tokenizer_utils import tokenize_conversation
class VILAProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
}
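
# VILAProcessorKwargs follows the standard transformers ProcessingKwargs pattern: it declares
# per-modality defaults (here, no padding for text) that callers can override at call time,
# e.g. (a sketch; the current VILAProcessor.__call__ accepts these kwargs for API compatibility
# but does not yet consume them):
#
#     processor(conversation=conv, padding=True, return_tensors="pt")
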
class VILAProcessor(ProcessorMixin):
# attributes = ["image_processor", "tokenizer"]
attributes = []
# valid_kwargs = ["chat_template"]
valid_kwargs = []
# image_processor_class = "VILAImageProcessor"
# tokenizer_class = ("VILATokenizer", "VILATokenizerFast")
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, config=None, **kwargs):
# self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
# self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
self.image_token = MEDIA_TOKENS["image"]
self.video_token = MEDIA_TOKENS["video"]
self.config = config
self.image_processor = image_processor
self.tokenizer = tokenizer
super().__init__(image_processor, tokenizer, chat_template=chat_template)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        if not os.path.isdir(pretrained_model_name_or_path):
            print(f"pretrained_model_name_or_path {pretrained_model_name_or_path} is not a directory, downloading from the Hugging Face Hub")
            from huggingface_hub import snapshot_download

            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)
image_processor = AutoImageProcessor.from_pretrained(
osp.join(pretrained_model_name_or_path, "vision_tower"), trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
osp.join(pretrained_model_name_or_path, "llm"), trust_remote_code=True
)
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
return cls(image_processor=image_processor, tokenizer=tokenizer, config=config)
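
    # Example (a minimal sketch; "NVILA-Lite-2B-hf-preview" is the local path / repo id used in the
    # __main__ block at the bottom of this file):
    #
    #     processor = VILAProcessor.from_pretrained("NVILA-Lite-2B-hf-preview")
    #
    # The checkpoint is expected to contain "vision_tower" and "llm" subfolders, from which the
    # image processor and tokenizer are loaded.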
def __repr__(self):
return (
f"VILAProcessor(image_processor={self.image_processor}, tokenizer={self.tokenizer}, config={self.config})"
)
def __call__(
self,
conversation,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
videos: VideoInput = None,
**kwargs: Unpack[VILAProcessorKwargs],
) -> BatchFeature:
# TODO: should be merged with llava_arch.py/generate_content()
# TODO (extract and preprocess should be done together, as the preprocess of image and video can be different, i.e. when dynamic res is used)
media = extract_media(conversation, self.config)
# Process media
media_config = defaultdict(dict)
for name in media:
            if name == "image":
                # A single image with dynamic aspect ratio is tiled into multiple crops; the image
                # token in the prompt is repeated once per crop (dynamic), or the tile block sizes
                # are recorded (dynamic_s2).
                if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
self.config.image_processor = self.image_processor
if self.config.image_aspect_ratio == "dynamic":
images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
conversation[0]["value"] = conversation[0]["value"].replace(
DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
)
                    else:
                        # Dynamic S2: parse the comma-separated scales once, tile at multiple
                        # scales, and record the resulting block sizes for the model.
                        if type(self.config.s2_scales) is str:
                            self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
images, block_sizes = process_image(
media["image"][0], self.config, None, enable_dynamic_s2=True
)
images = images.half()
media_config[name]["block_sizes"] = [block_sizes]
                else:
                    # Multiple images (or a fixed aspect ratio): process them as a plain batch.
                    images = process_images(media["image"], self.image_processor, self.config).half()
media[name] = [image for image in images]
            elif name == "video":
                # Videos arrive as lists of frames; process each clip's frames as a batch of images.
                media[name] = [
                    process_images(images, self.image_processor, self.config).half()
                    for images in media[name]
                ]
else:
raise ValueError(f"Unsupported media type: {name}")
        # Tokenize the full conversation and add a batch dimension; note that the token ids are
        # moved to the GPU here, so a CUDA device is required.
        input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).cuda().unsqueeze(0)
return BatchFeature(data={"input_ids": input_ids, **media})
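
    # Example (a minimal sketch; the image path is hypothetical, and a CUDA device is assumed since
    # the token ids are moved to the GPU above):
    #
    #     conv = [{"from": "human", "value": [Image("demo.png"), "Describe this image."]}]
    #     batch = processor(conversation=conv)
    #     batch.input_ids.shape  # -> (1, sequence_length)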
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
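
    # Example (a sketch; output_ids is whatever the model's generate call returned):
    #
    #     texts = processor.post_process_image_text_to_text(output_ids)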
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
# inputs = processor(conversation=llavaconv, padding=True, return_tensors="pt")
def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
vila_conv = []
for chat in conversation:
vila_chat = {"from": "", "value": []}
if chat["role"] == "user":
# user allows to input image and text
vila_chat["from"] = "human"
for content in chat["content"]:
if content["type"] == "image":
vila_chat["value"].append(Image(content["path"]))
elif content["type"] == "text":
vila_chat["value"].append(content["text"])
else:
raise ValueError(f"Unsupported content type: {content['type']}")
elif chat["role"] == "assistant":
vila_chat["from"] = "gpt"
for content in chat["content"]:
assert content["type"] == "text", f"Unsupported content type: {content['type']}"
vila_chat["value"].append(content["text"])
vila_conv.append(vila_chat)
return self(vila_conv)
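
    # Example (a minimal sketch of the HF-style conversation this method converts; the image path
    # is hypothetical):
    #
    #     conv = [
    #         {
    #             "role": "user",
    #             "content": [
    #                 {"type": "image", "path": "demo.png"},
    #                 {"type": "text", "text": "Describe this image."},
    #             ],
    #         }
    #     ]
    #     batch = processor.apply_chat_template(conv)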
if __name__ == "__main__":
    import PIL.Image

# gpt style: user, assistant
# vila style: human, gpt
gpt_conv = [
{
"role": "user",
"content": [
{"type": "image", "path": "demo_images/demo_img_1.png"},
{"type": "text", "text": "Describe this image."},
],
}
]
llavaconv = [
{
"from": "human",
"value": [
PIL.Image.open("demo_images/demo_img_1.png"),
"Describe this image.",
],
}
]
    model_path = "NVILA-Lite-2B-hf-preview"
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
    # model = llava.load("Efficient-Large-Model/qwen25_2B_3x3-sft").cuda()
    # print(model)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
processor = VILAProcessor(
config=model.config,
image_processor=model.vision_tower.image_processor,
tokenizer=model.tokenizer,
)
# TODO: add padding, return_tensors,
inputs = processor(conversation=llavaconv, padding=True, return_tensors="pt")
print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
print("vila conv pass")
inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
print("gpt conv pass")
output_ids = model.generate(
input_ids=inputs.input_ids,
media={
"image": inputs.image,
},
media_config={"image": {}},
generation_config=model.generation_config,
max_new_tokens=100,
)
print(output_ids)
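
    # Decode the generated ids back to text (a sketch; assumes the ids returned by generate have
    # shape (batch_size, sequence_length)):
    print(processor.post_process_image_text_to_text(output_ids))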