# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example:
python scripts/vlm/gemma3vl_generate.py --local_model_path="path/to/converted_nemo_checkpoint"
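
If --local_model_path is not provided, the script imports the Hugging Face checkpoint
"google/gemma-3-4b-it" directly and converts it on the fly. For multi-GPU runs, set
--tp / --pp; a sketch of such an invocation (assuming a torchrun launch) is:
  torchrun --nproc_per_node=2 scripts/vlm/gemma3vl_generate.py --tp=2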
"""
import argparse
from pathlib import Path
import torch
import torch.distributed as dist
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
import nemo.lightning as nl
from nemo.collections.common.tokenizers import AutoTokenizer
from nemo.collections.vlm import Gemma3VLModel
from nemo.collections.vlm.inference.base import _setup_trainer_and_restore_model
from nemo.lightning import io
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.utils.get_rank import get_last_rank
class SingleBatchIterator:
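    """One-shot iterator that yields a single pre-built batch, as expected by Megatron's forward-backward function."""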
def __init__(self, pixel_values, input_ids, position_ids):
self.batch = dict(
pixel_values=pixel_values,
input_ids=input_ids,
position_ids=position_ids,
)
self._yielded = False
def __iter__(self):
return self
def __next__(self):
if self._yielded:
raise StopIteration
self._yielded = True
return self.batch
def gemma3_forward_step(data_iterator, model, **kwargs) -> torch.Tensor:
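    """Unpack one batch, run the Gemma3 VL model, and return the output together with an identity loss function (inference only)."""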
batch = next(data_iterator)
forward_args = {
"input_ids": batch["input_ids"],
"position_ids": batch["position_ids"],
"pixel_values": batch.get("pixel_values", None),
"loss_mask": batch.get("loss_mask", None),
"labels": batch.get("labels", None),
}
def loss_func(x, **kwargs):
return x
return model(**forward_args), loss_func
def main(args) -> None:
# pylint: disable=C0115,C0116,C0301
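    # Inference-only parallel setup: sequence parallelism is only enabled when TP > 1,
    # and no optimizer state is loaded from the checkpoint.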
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=args.tp,
pipeline_model_parallel_size=args.pp,
sequence_parallel=args.tp > 1,
ckpt_include_optimizer=False,
ckpt_load_strictness="log_all",
pipeline_dtype=torch.bfloat16,
)
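    # One GPU per model-parallel rank; jobs spanning more than 8 GPUs spill onto extra nodes.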
trainer = nl.Trainer(
devices=min(args.tp * args.pp, 8),
num_nodes=max(args.tp * args.pp // 8, 1),
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
enable_checkpointing=False,
)
if args.local_model_path:
path = Path(args.local_model_path)
model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
_setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)
else:
fabric = trainer.to_fabric()
model = fabric.import_model("hf://google/gemma-3-4b-it", Gemma3VLModel)
model = model.module.cuda()
model.eval()
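    # The HF processor formats the multimodal chat prompt (text + image) into tensors;
    # NeMo's AutoTokenizer wraps the matching HF tokenizer, reused below to decode the
    # generated token IDs back into text.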
from transformers import AutoProcessor
model_id = 'google/gemma-3-4b-it'
processor = AutoProcessor.from_pretrained(model_id)
gemma_tokenizer = AutoTokenizer(model_id)
hf_tokenizer = gemma_tokenizer.tokenizer
messages = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG",
},
{"type": "text", "text": "What animal is on the candy?"},
],
},
]
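    # apply_chat_template tokenizes the conversation, fetches and preprocesses the image,
    # and returns input_ids plus pixel_values ready for the model.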
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
input_ids = inputs["input_ids"].cuda()
    # unsqueeze a batch dimension so pixel_values has shape (B, N, C, H, W), then cast to bfloat16
pixel_values = inputs["pixel_values"].cuda().unsqueeze(0).to(dtype=torch.bfloat16)
position_ids = (
torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0).expand_as(input_ids)
)
generated_ids = input_ids.clone()
    stop_tokens = [1, 126]  # token IDs at which generation stops
# Greedy generation loop
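    # Each step re-runs a full forward pass over the growing sequence (no KV cache),
    # takes the argmax of the last position's logits, and generates at most 20 new tokens.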
for step in range(20):
with torch.no_grad():
if torch.distributed.get_rank() == 0:
print(step)
fwd_bwd_function = get_forward_backward_func()
iterator = SingleBatchIterator(pixel_values, input_ids, position_ids)
output = fwd_bwd_function(
forward_step_func=gemma3_forward_step,
data_iterator=iterator,
model=model,
num_microbatches=1,
forward_only=True,
seq_length=input_ids.size(1),
micro_batch_size=1,
collect_non_loss_data=True,
)
if isinstance(output, list) and len(output) > 0:
output = output[0]
if parallel_state.is_pipeline_last_stage():
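                # On the last pipeline stage the output logits are vocab-parallel: each
                # tensor-parallel rank holds a shard of the vocabulary dimension, so the
                # shards are gathered before taking the argmax.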
world_size = parallel_state.get_tensor_model_parallel_world_size()
gathered_tensors = [torch.zeros_like(output) for _ in range(world_size)]
# All-gather operation
dist.all_gather(gathered_tensors, output, group=parallel_state.get_tensor_model_parallel_group())
# Concatenate along last dimension (dim=2)
output = torch.cat(gathered_tensors, dim=2)
next_token_ids = torch.argmax(output[:, -1], dim=-1, keepdim=True)
else:
next_token_ids = torch.ones((1, 1), device=generated_ids.device, dtype=generated_ids.dtype)
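            # Non-last stages hold a placeholder token; the real token is broadcast from
            # the last pipeline rank so every rank appends the same token and stays in sync.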
torch.distributed.broadcast(next_token_ids, get_last_rank())
generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
input_ids = generated_ids
position_ids = (
torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
.unsqueeze(0)
.expand_as(input_ids)
)
# If the generated token is the end of sequence token, stop generating
if next_token_ids.item() in stop_tokens:
break
    generated_texts = hf_tokenizer.decode(generated_ids[0].tolist())
    if torch.distributed.get_rank() == 0:
        print("======== GENERATED TEXT OUTPUT ========")
        print(generated_texts)
        print("=======================================")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Gemma3 Multimodal Inference")
parser.add_argument(
"--local_model_path",
type=str,
default=None,
help="Local path to the model if not loading from Hugging Face.",
)
    parser.add_argument('--tp', type=int, default=1, help="Tensor model parallel size.")
    parser.add_argument('--pp', type=int, default=1, help="Pipeline model parallel size.")
args = parser.parse_args()
main(args)