# COCOM_disabled_flash_attn / modeling_cocom.py
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel
# Assumes the lucidrains `linformer` package, which exports LinformerSelfAttention at the top level
from linformer import LinformerSelfAttention
import torch
import math
from peft import get_peft_model, LoraConfig, TaskType
import os
# Utility: freeze all parameters of a module so they receive no gradient updates
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False
# BERT_Compressor is unchanged in the Linformer variant: it compresses context token
# representations into a smaller set of embeddings projected into the decoder's hidden space
class BERT_Compressor(torch.nn.Module):
def __init__(self, compr_model_name, compr_rate, compr_linear_type, decoder_hidden_size):
super().__init__()
self.model_name = compr_model_name
self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.float16)
self.tokenizer = AutoTokenizer.from_pretrained(compr_model_name, use_fast=True)
self.compr_rate = compr_rate
self.compressing_mode = compr_linear_type
if self.compressing_mode == 'concat':
self.linear = torch.nn.Linear(self.model.config.hidden_size*self.compr_rate, decoder_hidden_size)
elif self.compressing_mode == 'mean':
self.linear = torch.nn.Linear(self.model.config.hidden_size, decoder_hidden_size)
        # Cast the projection to fp16 to match the encoder and decoder weights
        self.linear = self.linear.to(torch.float16)
def forward(self, input_ids, attention_mask):
segment_compress_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
num_embs = math.ceil(input_ids.size(1) / self.compr_rate)
all_hidden_states_emb = list()
if self.compressing_mode == 'concat':
for segment_idx in range(num_embs):
start_idx = segment_idx * self.compr_rate
end_idx = (segment_idx + 1) * self.compr_rate
hidden_state = segment_compress_outputs.hidden_states[-1][:, start_idx:end_idx, :]
hidden_state_concat = torch.flatten(hidden_state, start_dim=1) #batch_size, hidden_state_dim * compression_rate
all_hidden_states_emb.append(hidden_state_concat)
elif self.compressing_mode == "mean":
for segment_idx in range(num_embs):
start_idx = segment_idx * self.compr_rate
end_idx = (segment_idx + 1) * self.compr_rate
hidden_state = segment_compress_outputs.hidden_states[-1][:, start_idx:end_idx, :]
all_hidden_states_emb.append(hidden_state)
all_hidden_states_emb_cat = torch.stack(all_hidden_states_emb, dim=1)
transformed_embeds = self.linear(all_hidden_states_emb_cat)
if self.compressing_mode == "mean":
transformed_embeds = torch.mean(transformed_embeds, dim=2)
return transformed_embeds
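# Hedged shape summary for BERT_Compressor.forward (the example numbers, e.g. 768 for
# bert-base-uncased and 4096 for Llama-2-7b hidden sizes, are assumptions for illustration):
#   input_ids                 (batch, ctx_len)               e.g. (2, 128)
#   encoder last hidden state (batch, ctx_len, enc_hidden)   e.g. (2, 128, 768)
#   num_embs = ceil(ctx_len / compr_rate) = ceil(128 / 64) = 2 segments
#   'concat': each segment is flattened to enc_hidden * compr_rate, then projected
#   'mean':   each segment is projected token-wise, then averaged over dim=2
#   returned embeddings       (batch, num_embs, dec_hidden)  e.g. (2, 2, 4096)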
# COCOMConfig carries the COCOM hyper-parameters; this variant defaults the attention implementation to Linformer
class COCOMConfig(PretrainedConfig):
model_type = "COCOM"
def __init__(self,
decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
quantization = 'no',
generation_top_k = 1,
sep = False,
compr_model_name = "bert-base-uncased",
compr_rate = 64,
compr_linear_type = 'concat',
lora = False,
training_form="both",
lora_r=16,
                 attn_implementation="linformer",  # Linformer is the default attention implementation in this variant
device_map = "cuda",
**kwargs):
super().__init__(**kwargs)
self.decoder_model_name = decoder_model_name
self.quantization = quantization
self.generation_top_k = generation_top_k
self.sep = sep
self.compr_model_name = compr_model_name
self.compr_rate = compr_rate
self.compr_linear_type = compr_linear_type
self.lora = lora
self.training_form = training_form
self.lora_r = lora_r
self.attn_implementation = attn_implementation
self.device_map = device_map
# COCOM: context compressor + causal decoder; optionally swaps the decoder's attention for Linformer
class COCOM(PreTrainedModel):
config_class = COCOMConfig
def __init__(self, cfg):
super().__init__(cfg)
attn_impl = cfg.attn_implementation
        # Load the decoder in fp16; cfg.quantization is stored in the config but not applied in this variant
self.decoder = AutoModelForCausalLM.from_pretrained(
cfg.decoder_model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map=cfg.device_map
)
# Replace decoder's attention mechanism with LinformerSelfAttention if configured
if attn_impl == 'linformer':
self._replace_attention_with_linformer()
# Initialize other parts of the model (compression, LoRA, etc.)
self.compr = BERT_Compressor(cfg.compr_model_name, cfg.compr_rate, cfg.compr_linear_type, self.decoder.config.hidden_size)
if cfg.lora:
self._apply_lora(cfg.lora_r)
        self.decoder_tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
        self.decoder_tokenizer.add_special_tokens({'additional_special_tokens': ['<MEM>', '<AE>', '<ENC>', '<SEP>']})
        # The newly added special tokens need embedding rows in the decoder
        self.decoder.resize_token_embeddings(len(self.decoder_tokenizer))
    def _replace_attention_with_linformer(self):
        # Swap every decoder self-attention module for LinformerSelfAttention.
        # Llama-style decoders expose their layers as model.layers / layer.self_attn
        # (transformer.h / layer.attn is the GPT-2 layout and does not exist here).
        # Constructor arguments follow the lucidrains `linformer` package; the
        # seq_len and k values below are assumptions, not tuned settings.
        # LinformerSelfAttention does not share the Llama attention call signature,
        # so a thin wrapper is needed at call time (see the adapter sketch below).
        cfg = self.decoder.config
        for layer in self.decoder.model.layers:
            layer.self_attn = LinformerSelfAttention(
                dim=cfg.hidden_size,
                seq_len=cfg.max_position_embeddings,
                k=256,
                heads=cfg.num_attention_heads,
                dropout=0.1,
            )
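    # Hedged adapter sketch (not in the original file): LinformerSelfAttention takes
    # only the hidden states and returns a tensor, while Llama attention modules are
    # called with extra keyword arguments (attention_mask, position_ids, ...) and are
    # expected to return a tuple. Something like the wrapper below would be needed to
    # make the swapped-in modules callable in place; the exact return-tuple layout
    # depends on the installed transformers version and is an assumption here.
    class _LinformerAttnAdapter(torch.nn.Module):
        def __init__(self, linformer_attn):
            super().__init__()
            self.inner = linformer_attn
        def forward(self, hidden_states, **kwargs):
            # attention_mask / position_ids are ignored: Linformer attends over the
            # whole fixed-length sequence through its low-rank projections
            return self.inner(hidden_states), None, None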
def _apply_lora(self, lora_r):
        # Wrap the decoder's linear layers with LoRA adapters
peft_config = LoraConfig(
task_type="CAUSAL_LM",
r=lora_r,
lora_alpha=2 * lora_r,
target_modules='all-linear',
lora_dropout=0.1,
)
self.decoder = get_peft_model(self.decoder, peft_config)
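    # Hedged sketch: compress_and_replace_emb is called by forward() and generate()
    # below but is not defined in this file. The body follows the published COCOM
    # design (compress the context, then splice the compressed embeddings into the
    # decoder input at the <MEM> placeholder positions); the implementation details
    # here are assumptions rather than code recovered from this repository.
    def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids):
        # Token embeddings the decoder would normally receive
        inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
        # Compressed context embeddings: (enc_batch, num_mem, decoder_hidden)
        compressed = self.compr(enc_input_ids, enc_attention_mask)
        mem_token_id = self.decoder_tokenizer.convert_tokens_to_ids('<MEM>')
        mem_positions = dec_input_ids == mem_token_id
        # Overwrite each <MEM> position with the corresponding compressed embedding;
        # this assumes the number of <MEM> tokens per row matches the number of
        # compressed embeddings produced for that row.
        inputs_embeds[mem_positions] = compressed.reshape(-1, compressed.size(-1)).to(inputs_embeds.dtype)
        return inputs_embeds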
def forward(self, enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask, labels):
inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)
return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
def generate(self, model_input, max_new_tokens=128):
device = self.decoder.device
        enc_input_ids = model_input['enc_input_ids']
        enc_attention_mask = model_input['enc_attention_mask']
        dec_input_ids = model_input['dec_input_ids']
        dec_attention_mask = model_input['dec_attention_mask']
inputs_embeds = self.compress_and_replace_emb(enc_input_ids.to(device), enc_attention_mask.to(device), dec_input_ids.to(device))
output_ids = self.decoder.generate(
inputs_embeds=inputs_embeds.to(device),
attention_mask=dec_attention_mask.to(device),
do_sample=False,
top_p=None,
max_new_tokens=min(max_new_tokens, 4096)
)
decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
return decoded
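# Hedged usage sketch (not part of the original file). It shows how the pieces are
# meant to be wired together; the prompt layout, the number of <MEM> placeholders,
# and the choice of checkpoints are assumptions, and running it requires access to
# the gated Llama-2 weights.
if __name__ == "__main__":
    cfg = COCOMConfig(
        decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
        compr_model_name="bert-base-uncased",
        compr_rate=64,
        attn_implementation="linformer",
    )
    model = COCOM(cfg)
    context = "Retrieved passage that should be compressed before generation."
    enc = model.compr.tokenizer([context], return_tensors="pt", padding=True)
    # One <MEM> placeholder per compressed segment produced by BERT_Compressor
    num_mem = math.ceil(enc["input_ids"].size(1) / cfg.compr_rate)
    prompt = "<MEM>" * num_mem + " Question: what does the passage say? Answer:"
    dec = model.decoder_tokenizer([prompt], return_tensors="pt")
    print(model.generate({
        "enc_input_ids": enc["input_ids"],
        "enc_attention_mask": enc["attention_mask"],
        "dec_input_ids": dec["input_ids"],
        "dec_attention_mask": dec["attention_mask"],
    }))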