import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

from .model import Model
from .utils import sample_token, get_last_attn

device = 'cuda' if torch.cuda.is_available() else 'cpu'


class AttentionModel(Model):
    def __init__(self, config):
        super().__init__(config)
        self.name = config["model_info"]["name"]
        self.max_output_tokens = int(config["params"]["max_output_tokens"])
        model_id = config["model_info"]["model_id"]
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Eager attention is required so that output_attentions returns the full maps.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map=device,
            attn_implementation="eager"
        ).eval()

        # Track every (layer, head) pair unless a specific subset is configured.
        if config["params"]["important_heads"] == "all":
            attn_size = self.get_map_dim()
            self.important_heads = [[i, j] for i in range(attn_size[0])
                                    for j in range(attn_size[1])]
        else:
            self.important_heads = config["params"]["important_heads"]

        self.top_k = 50
        self.top_p = None

    def get_map_dim(self):
        # Run a throwaway inference to discover (num_layers, num_heads) of the attention maps.
        _, _, attention_maps, _, _, _ = self.inference("print hi", "")
        attention_map = attention_maps[0]
        return len(attention_map), attention_map[0].shape[1]

    # def query(self, msg, return_type="normal", max_output_tokens=None):
    #     text_split = msg.split('\nText: ')
    #     instruction, data = text_split[0], text_split[1]
    #     response, output_tokens, attention_maps, tokens, input_range, generated_probs = self.inference(
    #         instruction, data, max_output_tokens=max_output_tokens)
    #     if return_type == "attention":
    #         return response, output_tokens, attention_maps, tokens, input_range, generated_probs
    #     else:
    #         return response

    def inference(self, instruction, data, max_output_tokens=None):
        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": "\nText: " + data}
        ]

        # Build the chat-formatted prompt as plain text; tokenization happens below.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        instruction_len = len(self.tokenizer.encode(instruction))
        data_len = len(self.tokenizer.encode(data))

        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        input_tokens = self.tokenizer.convert_ids_to_tokens(model_inputs['input_ids'][0])
if "qwen-attn" in self.name: | |
data_range = ((3, 3+instruction_len), (-5-data_len, -5)) | |
elif "phi3-attn" in self.name: | |
data_range = ((1, 1+instruction_len), (-2-data_len, -2)) | |
elif "llama2-13b" in self.name or "llama3-8b" in self.name: | |
data_range = ((5, 5+instruction_len), (-5-data_len, -5)) | |
else: | |
raise NotImplementedError | |

        generated_tokens = []
        generated_probs = []
        input_ids = model_inputs.input_ids
        attention_mask = model_inputs.attention_mask
        attention_maps = []

        n_tokens = max_output_tokens if max_output_tokens is not None else self.max_output_tokens

        with torch.no_grad():
            # Decode one token at a time so the attention maps of every step can be collected.
            for _ in range(n_tokens):
                output = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_attentions=True
                )

                logits = output.logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                # next_token_id = logits.argmax(dim=-1).squeeze()
                next_token_id = sample_token(
                    logits[0], top_k=self.top_k, top_p=self.top_p, temperature=1.0)[0]

                generated_probs.append(probs[0, next_token_id.item()].item())
                generated_tokens.append(next_token_id.item())

                if next_token_id.item() == self.tokenizer.eos_token_id:
                    break

                # Append the sampled token and extend the attention mask for the next step.
                input_ids = torch.cat(
                    (input_ids, next_token_id.unsqueeze(0).unsqueeze(0)), dim=-1)
                attention_mask = torch.cat(
                    (attention_mask, torch.tensor([[1]], device=input_ids.device)), dim=-1)

                # Move attentions to CPU in half precision, zero out NaNs, and keep only
                # the last position's attention via get_last_attn.
                attention_map = [attention.detach().cpu().half()
                                 for attention in output['attentions']]
                attention_map = [torch.nan_to_num(attention, nan=0.0)
                                 for attention in attention_map]
                attention_map = get_last_attn(attention_map)
                attention_maps.append(attention_map)

        output_tokens = [self.tokenizer.decode(token, skip_special_tokens=True)
                         for token in generated_tokens]
        generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        return generated_text, output_tokens, attention_maps, input_tokens, data_range, generated_probs
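

# Example usage -- a minimal sketch. The config layout mirrors the keys read in
# __init__; the concrete model_id, name substring, and prompt below are illustrative
# assumptions, not values taken from this repository.
#
# config = {
#     "model_info": {"name": "qwen-attn", "model_id": "Qwen/Qwen2-7B-Instruct"},
#     "params": {"max_output_tokens": 128, "important_heads": "all"},
# }
# model = AttentionModel(config)
# text, out_tokens, attn_maps, in_tokens, data_range, probs = model.inference(
#     "Summarize the following text.", "Attention maps show where the model attends.")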