import math
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast

# repeat_kv, recycle_vram, and the Exaone* base classes come from these star imports.
from .modeling_exaone import *
from beagle.mixin import *


class ExaoneBeagleAttention_(ExaoneSelfAttention, BeagleAttentionMixin):
    """Eager attention for the Beagle draft layer: identical to the stock
    ExaoneSelfAttention math, except that Q/K/V are obtained through the
    mixin's qkv_transform."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        # Project hidden states to query/key/value states; the KV cache and rotary
        # position embeddings are handed to qkv_transform along the way.
        query_states, key_states, value_states = self.qkv_transform(
            hidden_states, past_key_value, use_cache, position_embeddings, **kwargs)

        # Expand KV heads so grouped-query attention matches the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Scaled dot-product attention logits.
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Slice the causal mask to the current key length before adding it.
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # Softmax in fp32 for numerical stability, then cast back to the query dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout_rate, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"Attention outputs should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back into the embedding dimension and apply the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class ExaoneBeagleAttention(ExaoneAttention):
    """Thin wrapper that swaps the inner attention module for ExaoneBeagleAttention_."""

    def __init__(self, config, layer_id=0):
        super().__init__(config, layer_id)
        self.attention = ExaoneBeagleAttention_(config, self.layer_id)


class ExaoneBeagleLayer(ExaoneBlock):
    """Decoder block used as the Beagle speculative (draft) layer."""

    def __init__(self, config, layer_id):
        super().__init__(config, layer_id)

        if not config.beagle_use_fc_eagle:
            # Replace the block's stock attention with the Beagle variant,
            # releasing the original module's memory first.
            delattr(self, 'attn')
            recycle_vram()
            self.attn = ExaoneBeagleAttention(
                config=config, layer_id=0
            )


class ExaoneForSpeculativeCausalLM(ExaoneForCausalLM, BeagleMixin):
    """EXAONE causal LM extended with a Beagle draft layer for speculative decoding."""

    _no_split_modules = ["ExaoneBlock", "ExaoneBeagleLayer"]

    def __init__(self, config):
        super().__init__(config)

        BeagleMixin.__init__(self, config)
        self.speculative_decoder = ExaoneBeagleLayer(config, layer_id=0)

        self.post_init()

    def forward(self, *args, **kwargs) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        # All forward logic, including drafting through the speculative decoder,
        # is delegated to BeagleMixin.beagle_forward.
        return self.beagle_forward(*args, **kwargs)
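

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal sketch of wiring up the speculative model. The checkpoint name and
# the `beagle_use_fc_eagle` value below are assumptions for illustration, and
# it is likewise assumed that beagle_forward accepts the standard CausalLM
# keyword arguments (input_ids, attention_mask, ...); the real config fields
# and generation entry point are defined by the Beagle mixin, not by this file.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    model_name = "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"  # hypothetical base checkpoint

    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.beagle_use_fc_eagle = False  # assumed flag value; in practice take it from the Beagle config

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = ExaoneForSpeculativeCausalLM.from_pretrained(
        model_name, config=config, trust_remote_code=True
    )

    inputs = tokenizer("Speculative decoding lets a draft layer propose tokens.", return_tensors="pt")
    outputs = model(**inputs)  # dispatched to BeagleMixin.beagle_forward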