| | |
| | |
| |
|
| | from collections.abc import Iterable |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from vllm.config import VllmConfig |
| | from vllm.logger import init_logger |
| | from vllm.model_executor.layers.layernorm import RMSNorm |
| | from vllm.model_executor.layers.logits_processor import LogitsProcessor |
| | from vllm.model_executor.layers.vocab_parallel_embedding import ( |
| | DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) |
| | from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
| | from vllm.model_executor.models import ModelRegistry |
| | from vllm.model_executor.sampling_metadata import SamplingMetadata |
| | from vllm.sequence import IntermediateTensors |
| |
|
| | from .utils import maybe_prefix |
| |
|
| | logger = init_logger(__name__) |
| |
|
| |
|
class DummyInputLayerNorm(nn.Module):
    """Identity stand-in for a decoder layer's input layernorm.

    The original norm's ``weight``/``bias`` tensors (when supplied) are
    retained as parameters so they remain addressable by name, but
    ``forward`` applies no normalization at all.
    """

    def __init__(self, weight=None, bias=None):
        super().__init__()
        # Re-wrap any provided tensors as parameters; absent ones stay None.
        if weight is None:
            self.weight = None
        else:
            self.weight = nn.Parameter(weight)
        if bias is None:
            self.bias = None
        else:
            self.bias = nn.Parameter(bias)

    def forward(self, x):
        # Pure pass-through.
        return x
| |
|
| |
|
class DummyOutputNorm(nn.Module):
    """Stand-in for the model's final norm that skips normalization.

    Follows the two-argument ``(x, residual)`` calling convention: with no
    residual, ``x`` is returned unchanged; with a residual, it is added
    into ``x`` and a ``None`` residual is returned alongside the sum.
    """

    def forward(self, x, residual):
        # Fold the residual in (no normalization); None residual signals
        # the addition has already happened.
        return x if residual is None else (x + residual, None)
| |
|
| |
|
class EAGLE(nn.Module):
    """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
    Reference implementation: https://github.com/SafeAILab/EAGLE

    Differences from reference implementation:
    1. In reference, LlamaDecoderLayer implementation doesn't have
       input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
       Following this approach, our implementation also disables
       the input_layernorm for the first decoder layer.
    2. We allow any decoder layer to be used in EAGLE whereas in reference
       decoder layer is fixed to be LlamaDecoderLayer.
    3. We have an optional token_map which reduces draft vocab to most
       frequently used tokens to give some additional speed-up by reducing
       sampling overhead. This is disabled unless the checkpoint file has
       explicit token_map tensor and config has an optional attribute
       truncated_vocab_size < vocab_size. To use this technique, one has to find
       the top-k most frequent tokens in target dataset and add that as a tensor
       in the draft checkpoint (using key token_map). Also, the draft config
       needs to have truncated_vocab_size (=k) as an attribute.
    4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
       module with regards to the use of additional RMS norms. The original
       EAGLE architecture 1) skips the pre-attention norm in its first
       transformer block, and 2) skips the final output norm, both of which we
       found to be suboptimal. We also add the support for separate norms
       applying to both the token embedding and hidden states before projection
       as in DeepSeek MTP, which we found to improve performance as well.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the draft model, the embedding/hidden fusion layer, the
        optional extra norms, and the (possibly truncated) LM head.
        """
        super().__init__()
        # hf_config is the draft checkpoint's config; its `.model` attribute
        # holds the nested config of the wrapped decoder architecture.
        config = vllm_config.model_config.hf_config
        self.config = config

        # Resolve the concrete decoder class from the nested model config so
        # any registered architecture can serve as the EAGLE backbone
        # (docstring point 2).
        architectures = getattr(self.config.model, "architectures", [])
        model_cls, _ = ModelRegistry.resolve_model_cls(architectures)

        self.model = model_cls(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        # Fuses [token embedding ; previous hidden state] (2 * hidden_size,
        # concatenated in forward()) back down to hidden_size.
        self.fc = nn.Linear(config.model.hidden_size * 2,
                            config.model.hidden_size,
                            bias=getattr(self.config, "eagle_fc_bias", False))

        # Replace the first decoder layer's input layernorm with an identity
        # (docstring point 1). Note the default when the attribute is absent
        # is to SKIP the norm; the original weight is kept so checkpoint
        # loading still finds it.
        if not hasattr(self.config.model,
                       "skip_prenorm") or self.config.model.skip_prenorm:
            self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
                weight=self.model.model.layers[0].input_layernorm.weight)

        # Likewise, the final output norm defaults to being skipped and is
        # replaced by a residual-folding identity (docstring point 4).
        if not hasattr(
                self.config.model,
                "skip_output_norm") or self.config.model.skip_output_norm:
            self.model.model.norm = DummyOutputNorm()

        # Optional DeepSeek-MTP-style separate RMS norms applied to the token
        # embedding (enorm) and previous hidden states (hnorm) before the
        # fusion projection (docstring point 4). Disabled by default.
        self.add_para_norm = False
        if hasattr(self.config.model,
                   "add_para_norm") and self.config.model.add_para_norm:
            self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.add_para_norm = True

        # Vocab bookkeeping for the optional token_map truncation
        # (docstring point 3). truncated_vocab_size == vocab_size when
        # truncation is not used.
        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=self.truncated_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Populated by load_weights() when the checkpoint carries a
        # "token_map" tensor and the config enables truncation; maps
        # truncated-vocab indices back into the original vocab.
        self.token_map: Optional[nn.Parameter] = None

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Delegate to the wrapped decoder's embedding table.
        return self.model.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run one draft step: fuse token embeddings with the target model's
        previous hidden states, then run the wrapped decoder.
        """

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        # If previous_hidden_states is empty or its leading dim does not
        # match the embeddings' leading dim, substitute zeros of the right
        # shape. NOTE(review): presumably this covers dummy/profiling runs
        # where no target hidden states exist yet — confirm with callers.
        batch_size = inputs_embeds.size(0)
        if previous_hidden_states.size(0) == 0 or \
            previous_hidden_states.size(0) != batch_size:
            hidden_dim = self.config.model.hidden_size
            device = inputs_embeds.device

            previous_hidden_states = \
                torch.zeros(batch_size, hidden_dim, device=device)

        # Concatenate embedding and hidden state along the feature dim,
        # optionally normalizing each stream first (DeepSeek-MTP style).
        if self.add_para_norm:
            inputs_embeds = torch.cat([
                self.enorm(inputs_embeds),
                self.hnorm(previous_hidden_states)
            ],
                                      dim=-1)
        else:
            inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
                                      dim=-1)

        # Project the fused 2*hidden vector back to hidden_size.
        inputs_embeds = self.fc(inputs_embeds)

        # Zero out fused inputs at position 0. NOTE(review): at the first
        # position there is no valid previous hidden state to condition on —
        # verify this matches the reference EAGLE behavior.
        inputs_embeds[positions == 0] = 0

        hidden_states = self.model.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits over the ORIGINAL vocab.

        With a token_map active, the head produces truncated-vocab logits
        which are scattered back into a full-vocab tensor; all positions
        outside the map get -inf so they can never be sampled.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        if self.token_map is not None:
            _logits = logits
            # Start from -inf everywhere in the original vocab ...
            logits = -torch.inf * torch.ones(
                size=(*_logits.shape[:-1], self.orig_vocab_size),
                device=_logits.device,
                dtype=_logits.dtype)

            # ... then place the truncated-vocab logits at their original
            # vocab indices.
            logits[..., self.token_map] = _logits

        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load draft-checkpoint weights, routing special tensors
        (token_map, fc, enorm/hnorm, lm_head) to this wrapper and
        forwarding everything else to the wrapped decoder.
        """
        model_weights = {}
        for name, loaded_weight in weights:
            if name == "token_map":
                # Only honor the token_map when the config actually enables
                # vocab truncation (docstring point 3).
                if self.config.truncated_vocab_size < self.config.vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name.startswith("fc.weight"):
                weight_loader = getattr(self.fc.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.fc.weight, loaded_weight)
            elif name.startswith("fc.bias"):
                if self.fc.bias is not None:
                    weight_loader = getattr(self.fc.bias, "weight_loader",
                                            default_weight_loader)
                    weight_loader(self.fc.bias, loaded_weight)
                else:
                    # Checkpoint has an fc bias but eagle_fc_bias is off:
                    # warn once and drop it rather than fail.
                    logger.warning_once("Found bias in the loaded weights but "
                                        "the model config doesn't have bias.")
            elif name.startswith("enorm.weight"):
                weight_loader = getattr(self.enorm.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.enorm.weight, loaded_weight)
            elif name.startswith("hnorm.weight"):
                weight_loader = getattr(self.hnorm.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.hnorm.weight, loaded_weight)
            elif name.startswith("model.lm_head.") or name.startswith(
                    "model.model."):
                # Strip one leading "model." so names line up with the
                # wrapped decoder's own parameter names.
                model_weights[name.split("model.", 1)[-1]] = loaded_weight
            elif name.startswith("lm_head.") or name.startswith("model."):
                model_weights[name] = loaded_weight
            else:
                # Bare decoder-layer names: nest them under "model.".
                model_weights[f"model.{name}"] = loaded_weight

        if "lm_head.weight" in model_weights:
            lm_head_weight = model_weights.pop("lm_head.weight")

            # If the checkpoint ships a full-vocab head but a token_map is
            # active, keep only the mapped rows.
            if self.token_map is not None and\
                lm_head_weight.shape[0] > self.token_map.shape[0]:

                lm_head_weight = lm_head_weight[self.token_map]

        else:
            # Checkpoint has no lm_head weight; fall back to zeros.
            # NOTE(review): presumably the real head is shared/overwritten
            # elsewhere — confirm, since a zero head yields uniform logits.
            lm_head_weight = torch.zeros(
                self.lm_head.org_vocab_size,
                self.lm_head.embedding_dim,
                dtype=self.config.torch_dtype,
            )

        weight_loader = getattr(self.lm_head.weight, "weight_loader",
                                default_weight_loader)
        weight_loader(self.lm_head.weight, lm_head_weight)

        self.model.load_weights(model_weights.items())
| |
|