# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import math
import os
import os.path as osp
import warnings
from dataclasses import asdict
from typing import Tuple

import torch
from huggingface_hub import file_exists, repo_exists
from huggingface_hub.utils import HFValidationError
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from llava.constants import MEDIA_TOKENS
from llava.model.utils import packing
from llava.utils.logging import logger
from llava.utils.tokenizer import infer_stop_tokens


def has_tokenizer(repo_id_or_path: str) -> bool:
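    """Return True if `repo_id_or_path` provides a tokenizer_config.json, either locally or on the Hugging Face Hub."""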
    # Check if the tokenizer is in a local directory
    if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
        return True

    # Check if the tokenizer is in a Hugging Face Hub repo
    try:
        return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
    except HFValidationError:
        return False


def context_length_extension(config):
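    """Enable linear RoPE scaling on `config` when model_max_length exceeds the pretrained context length."""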
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    if orig_ctx_len and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config


def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
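    """Build the language model and its tokenizer from `model_name_or_path`.

    Applies the requested attention implementation, optional RoPE context-length extension,
    and (when `quantize_model_class` is passed via kwargs) one of the quantized/FP8 model
    configs before loading the model. The tokenizer is then loaded together with the chat
    template, stop tokens, and media tokens.
    """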
    # print(model_name_or_path)
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)
    # Quantization related
    quantization_restore_from_checkpoint = False
    if kwargs.get("quantize_model_class") is not None:
        assert kwargs.get("model_args") is not None
        quantize_model_class = kwargs.pop("quantize_model_class", None)
        model_args = kwargs.pop("model_args", None)

        if quantize_model_class == "QLlamaForCausalLM":  # TODO: Also change the name of this class
            from .qllama import QLlamaConfig

            llm_cfg.architectures = "QLlamaForCausalLM"
            _attn_implementation = llm_cfg._attn_implementation
            llm_cfg = QLlamaConfig(**llm_cfg.to_dict())
            llm_cfg._attn_implementation = _attn_implementation
        elif quantize_model_class == "QMemLlamaForCausalLM":  # TODO: Also change the name of this class
            from .qmemllama import QMemLlamaConfig

            llm_cfg.architectures = "QMemLlamaForCausalLM"
            llm_cfg = QMemLlamaConfig(**llm_cfg.to_dict())
        elif quantize_model_class == "FP8LinearQwen2ForCausalLM":
            from .configuration_quantize import QuantizationConfig
            from .fp8linearqwen2 import FP8LinearQwen2Config

            llm_cfg.architectures = "FP8LinearQwen2ForCausalLM"
            coat_fp8_args = QuantizationConfig(**asdict(model_args))

            # Remove the quantization args from the model args and keep them as an independent config
            model_args_dict = asdict(model_args)
            for key in asdict(coat_fp8_args).keys():
                model_args_dict.pop(key, None)
            llm_cfg.coat_fp8_args = asdict(coat_fp8_args)

            _attn_implementation = llm_cfg._attn_implementation
            llm_cfg = FP8LinearQwen2Config(**llm_cfg.to_dict())
            llm_cfg._attn_implementation = _attn_implementation
        elif quantize_model_class == "FP8ActivationQwen2ForCausalLM":
            from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
            from .fp8activationqwen2 import FP8ActivationQwen2Config

            quantization_restore_from_checkpoint = True
            llm_cfg.architectures = "FP8ActivationQwen2ForCausalLM"
            coat_fp8_args = QuantizationConfig(**asdict(model_args))

            # Remove the quantization args from the model args and keep them as an independent config
            model_args_dict = asdict(model_args)
            for key in asdict(coat_fp8_args).keys():
                model_args_dict.pop(key, None)
            llm_cfg.coat_fp8_args = asdict(coat_fp8_args)

            _attn_implementation = llm_cfg._attn_implementation
            llm_cfg = FP8ActivationQwen2Config(**llm_cfg.to_dict())
            llm_cfg._attn_implementation = _attn_implementation
        elif quantize_model_class == "FP8ActivationResidualQwen2ForCausalLM":
            from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
            from .fp8activationresidualqwen2 import FP8ActivationResidualQwen2Config

            quantization_restore_from_checkpoint = True
            llm_cfg.architectures = "FP8ActivationResidualQwen2ForCausalLM"
            coat_fp8_args = QuantizationConfig(**asdict(model_args))

            # Remove the quantization args from the model args and keep them as an independent config
            model_args_dict = asdict(model_args)
            for key in asdict(coat_fp8_args).keys():
                model_args_dict.pop(key, None)
            llm_cfg.coat_fp8_args = asdict(coat_fp8_args)

            _attn_implementation = llm_cfg._attn_implementation
            llm_cfg = FP8ActivationResidualQwen2Config(**llm_cfg.to_dict())
            llm_cfg._attn_implementation = _attn_implementation
        else:
            raise ValueError(f"{quantize_model_class} is not a supported quantize_model_class.")

        kwargs.pop("quantize_model_class", None)

        # For the FP8 variants, llm_cfg is updated with the model args minus the quantization keys
        if quantize_model_class in [
            "FP8LinearQwen2ForCausalLM",
            "FP8ActivationQwen2ForCausalLM",
            "FP8ActivationResidualQwen2ForCausalLM",
        ]:
            llm_cfg.update(model_args_dict)
        else:
            llm_cfg.update(asdict(model_args))

        # print(model_args)
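    # NOTE: both branches below call `from_pretrained` identically; the FP8 restore path only pops
    # "fp8_llm_cfg" from kwargs so that it is not forwarded to `from_pretrained`.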
    if quantization_restore_from_checkpoint:
        fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
        llm = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )
    else:
        llm = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )
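    # Patch the LLM for sequence packing (see llava.model.utils.packing).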
    packing.patch(llm)

    # Locate the tokenizer.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
    if model_max_length is not None:
        tokenizer.model_max_length = model_max_length
    # Load chat template if specified.
    if getattr(config, "chat_template", None) is not None:
        logger.info(f"Using chat template: {config.chat_template}")
        fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
        with open(fpath) as fd:
            chat_template = fd.read()
        # Strip the indentation and newlines that only exist for readability of the template file.
        tokenizer.chat_template = chat_template.replace("    ", "").replace("\n", "")
    # Set stop tokens for the tokenizer
    tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
    tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)

    # Add media tokens to the tokenizer
    tokenizer.media_tokens = MEDIA_TOKENS
    tokenizer.media_token_ids = {}
    for name, token in MEDIA_TOKENS.items():
        tokenizer.add_tokens([token], special_tokens=True)
        tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)

    # TODO(ligeng): is this necessary for llava?
    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer