import torch
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoConfig
from typing import Dict, List, Tuple, Union, Optional


class FasterChatGLM(PreTrainedModel):
    def __init__(self, model_dir, kernel, *inputs, **kwargs):
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        config.n_head = config.num_attention_heads
        config.n_embd = config.hidden_size
        config.n_layer = config.num_layers
        super().__init__(config, *inputs, **kwargs)
        self.kernel = kernel
        self.fake_reg = torch.nn.Linear(2, 2)
        self.position_encoding_2d = True

    def forward(self, input_ids, position_ids, attention_mask, past_key_values, *args, **kwargs):
        inputs_values = [input_ids, position_ids, attention_mask]
        if past_key_values is not None:
            inputs_values = inputs_values + list(past_key_values)

        # The fused kernel returns the logits first, followed by the present
        # key/value tensors when caching is enabled.
        computed = self.kernel.infer(inputs_values)
        logits = computed[0]
        if len(computed) == 1:
            present_key_values = None
        else:
            present_key_values = computed[1:]

        return CausalLMOutputWithPast(logits=logits, past_key_values=present_key_values)

    def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
        # Causal mask in ChatGLM style: the prompt is fully visible,
        # only the final position is restricted to a lower-triangular view.
        attention_mask = torch.ones((1, context_length, context_length), device=device)
        attention_mask.tril_()
        attention_mask[..., :context_length - 1] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()

        if self.position_encoding_2d:
            # 2D positional encoding: absolute positions up to the <bos> token
            # (id 150004), plus block positions counted from <bos> onward.
            seq_length = seq.index(150004)
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[seq_length:] = mask_position
            block_position_ids = torch.cat((
                torch.zeros(seq_length, dtype=torch.long, device=device),
                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
            ))
            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
        else:
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[context_length - 1:] = mask_position

        position_ids = position_ids.unsqueeze(0)
        return attention_mask, position_ids

    def prepare_one_sample(self, input_id, mask_token, past, past_key_values, use_gmask):
        seq = input_id.tolist()
        if mask_token not in seq:
            raise ValueError("You have to add either [MASK] or [gMASK] in your input")
        mask_position = seq.index(mask_token)

        # only last token for input_ids if past is not None
        if past is not None or past_key_values is not None:
            context_length = seq.index(150004)
            last_token = input_id[-1].unsqueeze(-1).unsqueeze(0)  # 2 dim
            proc_input_id = last_token
            if self.position_encoding_2d:
                position_ids = torch.tensor(
                    [[[mask_position], [len(seq) - context_length]]],
                    dtype=torch.long, device=input_id.device)
            else:
                position_ids = torch.tensor([[mask_position]], dtype=torch.long,
                                            device=input_id.device)
            attention_mask = torch.zeros(1, 1, 1, 1, device=input_id.device)
        else:
            proc_input_id = input_id.unsqueeze(0)
            attention_mask, position_ids = self.get_masks_and_position_ids(
                seq=seq,
                mask_position=mask_position,
                context_length=len(seq),
                device=input_id.device,
                gmask=use_gmask
            )

        return (proc_input_id.to(torch.int32),
                position_ids.to(torch.int32),
                attention_mask.to(torch.bool))

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past: Optional[torch.Tensor] = None,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            use_cache: bool = None,
            **kwargs
    ) -> dict:
        MASK, gMASK = 150000, 150001
        mask_token = MASK if MASK in input_ids else gMASK
        use_gmask = MASK not in input_ids

        # Build the batch sample by sample, since each sequence may place its
        # mask token at a different position.
        batch_input_ids, batch_position_ids, batch_attention_mask = [], [], []
        for input_id in input_ids:
            proc_input_id, position_id, attention_mask = self.prepare_one_sample(
                input_id, mask_token, past, past_key_values, use_gmask)
            batch_input_ids.append(proc_input_id)
            batch_position_ids.append(position_id)
            batch_attention_mask.append(attention_mask)

        batch_input_ids = torch.vstack(batch_input_ids)
        batch_position_ids = torch.vstack(batch_position_ids)
        batch_attention_mask = torch.vstack(batch_attention_mask)

        if past is None:
            past = past_key_values

        # Once a KV cache exists, switch the kernel out of context (prefill) mode.
        if past is not None or past_key_values is not None:
            self.kernel.set_context_mode(False)
        else:
            self.kernel.set_context_mode(self.config.use_cache)

        return {
            "input_ids": batch_input_ids,
            "past_key_values": past_key_values,
            "position_ids": batch_position_ids,
            "attention_mask": batch_attention_mask
        }
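# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the class above). It assumes a
# `kernel` object exposing infer() and set_context_mode() has already been
# built by the accelerated runtime; `load_kernel` below is a hypothetical
# placeholder for that loader, and `model_dir` is assumed to hold a ChatGLM-6B
# checkpoint. Because FasterChatGLM subclasses PreTrainedModel and implements
# prepare_inputs_for_generation(), the standard generate() loop should drive it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_dir = "THUDM/chatglm-6b"
    # Hypothetical: replace with the project's actual kernel/engine loader.
    kernel = load_kernel(model_dir)  # noqa: F821

    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    model = FasterChatGLM(model_dir, kernel).half().cuda()

    inputs = tokenizer("Hello, who are you?", return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(outputs[0]))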