"""
Attention conversion helpers
"""
|
from functools import partial
from typing import Any

import torch.nn as nn
from tqdm import tqdm
|
|
def convert_attention(model: nn.Module,
                      attention_config: dict,
                      train_attention: bool = False,
                      remove_base_attn: bool = True):
    """
    Call to convert all attention layers as specified by attention_config.
    Layer indices listed in attention_config['softmax_attentions'] keep their
    softmax attention and are frozen instead of converted.
    Note: attention_config is read with both item- and attribute-style access,
    so it should support both (e.g., an OmegaConf DictConfig).
    """
    softmax_attns = []
    if 'softmax_attentions' in attention_config:
        softmax_attns = attention_config['softmax_attentions']
    if attention_config.attention_type != 'softmax':
        layers = traverse_layers(model)
        for layer_idx, layer in enumerate(tqdm(layers, desc='Converting attentions...')):
            if layer_idx not in softmax_attns:
                layer.self_attn = convert_llama_attention(
                    layer, attention_config, layers, train_attention, remove_base_attn,
                )
                layer.self_attn.converted = True
            else:
                # Freeze any layers we keep as softmax attention
                for p in layer.parameters():
                    p.requires_grad = False
    else:
        print(f'-> attention_config.attention_type is {attention_config.attention_type}; not converting attentions')
    return model
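# Usage sketch (not from this file): the intended flow, assuming `model` is a loaded
# Hugging Face Llama model and `attention_config` is an attribute-accessible config
# (e.g., an OmegaConf DictConfig) carrying whatever fields the chosen attention class
# expects. Variable names are illustrative; distillation/finetuning live elsewhere.
#
#   model = convert_attention(model, attention_config, train_attention=True)
#   ...                                            # distill attentions against the base model
#   model = toggle_attention(model, train=False)   # stop training attentions before finetuning
#   model = remove_base_attention(model)           # drop any kept softmax "teacher" attentions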
|
|
def toggle_attention(llama_model: nn.Module, train: bool = False):
    """
    Make attention layers trainable if train is True
    -> Set train = False when finetuning
    """
    for layer in traverse_layers(llama_model):
        layer.self_attn.train_attention = train
    return llama_model
|
|
def remove_base_attention(llama_model: nn.Module):
    """
    Remove the teacher (base) softmax attention after distillation, if it was kept
    """
    for layer in traverse_layers(llama_model):
        if getattr(layer.self_attn, 'base_attn', False):
            del layer.self_attn.base_attn
    return llama_model
|
|
def traverse_layers(model: nn.Module, verbose: bool = False):
    """
    Return the model's list of decoder layers, handling common wrappers
    """
    try:
        # Standard Hugging Face causal LM: model.model.layers
        layers = model.model.layers
        if verbose:
            print('-> Loading from model.model.layers')
    except AttributeError as e:
        if verbose:
            print(e)
        try:
            # Bare decoder model: model.layers
            layers = model.layers
            if verbose:
                print('-> Loading from model.layers')
        except AttributeError as e1:
            if verbose:
                print(e1)
            # Further-wrapped model (e.g., PEFT/LoRA): model.base_model.model.model.layers
            layers = model.base_model.model.model.layers
            if verbose:
                print('-> Loading from model.base_model.model.model.layers')
    return layers
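# Usage sketch: traverse_layers() abstracts over the wrappers above, so callers can
# inspect or modify layers uniformly, e.g. checking which attentions were converted
# (the `converted` flag is set in convert_attention above). `model` is assumed loaded.
#
#   for idx, layer in enumerate(traverse_layers(model, verbose=True)):
#       print(idx, type(layer.self_attn).__name__,
#             getattr(layer.self_attn, 'converted', False))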
|
|
def convert_llama_attention(layer: nn.Module,
                            attention_config: dict,
                            layers: list[nn.Module],
                            train_attention: bool = False,
                            remove_base_attn: bool = True):
    """
    Convert a single layer's attention as specified by attention_config
    """
    # get_attention() returns the attention class with the config kwargs bound;
    # instantiate it around the layer's existing (base) attention
    return get_attention(**attention_config)(
        base_attn=layer.self_attn,
        layer_idx=layer.self_attn.layer_idx,
        max_layer_idx=len(layers) - 1,
        train_attention=train_attention,
        remove_base_attn=remove_base_attn,
    )
|
|
def get_attention(attention_type: str, **kwargs: Any):
    """
    Get the linear attention class; either purely linear or linear with sliding window
    -> 'linear' == 'lolcats_llama'
    -> 'linear + sliding window' == 'lolcats_llama_window_*'
    """
    kwargs['attention_type'] = attention_type

    if attention_type == 'lolcats_llama':
        from .linear_attention import LolcatsLinearAttention
        return partial(LolcatsLinearAttention, **kwargs)

    elif attention_type == 'lolcats_llama_window_tk':
        from .linear_attention import LolcatsTKWindowAttention
        return partial(LolcatsTKWindowAttention, **kwargs)

    elif attention_type == 'lolcats_llama_window_sw':
        from .linear_attention import LolcatsSlidingWindowAttention
        return partial(LolcatsSlidingWindowAttention, **kwargs)

    elif attention_type == 'lolcats_llama_window_sw_linear':
        from .linear_attention.linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
        return partial(LolcatsLinearSlidingWindowAttention, **kwargs)

    # Long-context variants
    elif attention_type == 'lolcats_long_llama_window_tk':
        from .linear_attention import LolcatsTKWindowLongAttention
        return partial(LolcatsTKWindowLongAttention, **kwargs)

    elif attention_type == 'lolcats_long_llama_window_sw':
        from .linear_attention import LolcatsSlidingWindowLongAttention
        return partial(LolcatsSlidingWindowLongAttention, **kwargs)

    # Generation-time variant
    elif attention_type == 'lolcats_llama_window_tk_gen':
        from .linear_attention import LolcatsWindowAttentionTKGen
        return partial(LolcatsWindowAttentionTKGen, **kwargs)

    else:
        print(f'-> attention_type {attention_type} not handled... returning None')
        return None
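# Sketch: the returned partial already carries the config kwargs, so only the per-layer
# arguments remain when instantiating (this mirrors convert_llama_attention above).
# `other_config_kwargs`, `layer`, and `num_layers` are placeholders; real configs supply
# whatever fields the chosen attention class expects.
#
#   attn_cls = get_attention('lolcats_llama', **other_config_kwargs)
#   new_attn = attn_cls(base_attn=layer.self_attn,
#                       layer_idx=layer.self_attn.layer_idx,
#                       max_layer_idx=num_layers - 1,
#                       train_attention=True,
#                       remove_base_attn=True)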
|
|
def get_attention_cache(attention_type: str, past_key_values: Any = None):
    """
    Determine how we store past keys and values when generating
    """
    if attention_type is None:
        return past_key_values

    elif 'lolcats_llama_window_tk_gen' in attention_type:
        from .linear_attention import LinearAttentionTKWindowGenerationCache
        return LinearAttentionTKWindowGenerationCache()

    elif 'llama_window_tk' in attention_type:
        from .linear_attention import LinearAttentionTKWindowCache
        return LinearAttentionTKWindowCache()

    elif 'llama_window_sw' in attention_type:
        # Also covers 'lolcats_llama_window_sw_linear', which uses the same cache
        from .linear_attention import LinearAttentionSlidingWindowCache
        return LinearAttentionSlidingWindowCache()

    elif 'softmax' in attention_type:
        return past_key_values

    else:
        from .linear_attention import LinearAttentionState
        return LinearAttentionState()
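# Sketch of one possible wiring during generation; the exact usage depends on the caller's
# generation loop (`model`, `input_ids`, and `attention_config` are assumed here, and the
# converted attentions are expected to consume the returned cache object).
#
#   past_key_values = get_attention_cache(attention_config.attention_type)
#   outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)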
|
|