import math

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from .PreTrainedRMTConfig import PreTrainedRMTConfig


class MemoryCell(torch.nn.Module):
    """Holds the memory tokens for a recurrent memory transformer cell.

    Replicates the memory tensor for each batch size, adds memory tokens to
    the input tensor, and processes the model output to return a new memory
    state.

    Parameters
    ----------
    base_model : PreTrainedModel
        Backbone transformer that processes each segment.
    num_mem_tokens : int
        Number of memory tokens added before and after each segment.
    """

    def __init__(self, base_model, num_mem_tokens):
        super().__init__()
        self.model = base_model
        self.create_memory(num_mem_tokens)
        self.config = base_model.config

    def create_memory(self, num_mem_tokens):
        """Randomly initializes an embedding matrix for the memory tokens and registers it as a trainable parameter.

        Also sets the read and write positions for the memory tokens.

        Parameters
        ----------
        num_mem_tokens : int
            Number of memory tokens.
        """
        self.read_memory_position = range(num_mem_tokens)
        self.write_memory_position = range(-num_mem_tokens, 0)

        self.num_mem_tokens = num_mem_tokens
        # GPT-2 style configs call the embedding width `n_embd`; most other
        # models use `hidden_size`.
        memory_dim = getattr(self.model.config, "n_embd", self.model.config.hidden_size)
        memory_weights = torch.randn((num_mem_tokens, memory_dim))
        self.register_parameter(
            "memory", torch.nn.Parameter(memory_weights, requires_grad=True)
        )

    def set_memory(self, input_shape):
        """Replicates the memory tensor for each batch size.

        Parameters
        ----------
        input_shape : torch.Size
            Shape of the input batch; only the batch size (the first
            element) is used.

        Returns
        -------
        torch.Tensor
            Replicated memory tensor of shape (batch_size, num_mem_tokens, memory_dim).
        """
        memory = self.memory.repeat(input_shape[0], 1, 1)
        return memory
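
    # For example, with num_mem_tokens=10 and an embedding width of 768, an
    # input batch of shape (4, 128) yields a memory of shape (4, 10, 768).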

    def forward(self, input_ids, memory_state=None, **kwargs):
        """Runs one segment through the backbone with memory tokens attached.

        Parameters
        ----------
        input_ids : torch.Tensor
            Input token ids of shape (batch_size, seq_len).
        memory_state : torch.Tensor, optional
            Memory tensor of shape (batch_size, num_mem_tokens, memory_dim);
            by default None, in which case a fresh memory is replicated from
            `self.memory`.

        Returns
        -------
        tuple(CausalLMOutputWithCrossAttentions, torch.Tensor)
            out : model output with the memory positions stripped.
            new_memory_state : updated memory state for the next segment.
        """
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        seg_kwargs = self.process_input(input_ids, memory_state, **kwargs)
        out = self.model(**seg_kwargs)
        out, new_memory_state = self.process_output(out, **kwargs)
        return out, new_memory_state

    def process_input(self, input_ids, memory_state, **kwargs):
        """Adds memory tokens to the input embeddings and builds the keyword arguments for the backbone.

        Parameters
        ----------
        input_ids : torch.Tensor
            Input token ids of shape (batch_size, seq_len).
        memory_state : torch.Tensor
            Memory tensor of shape (batch_size, num_mem_tokens, memory_dim).

        Returns
        -------
        dict
            Keyword arguments for the backbone, with `inputs_embeds` of shape
            (batch_size, num_mem_tokens + seq_len + num_mem_tokens, hidden_size).
        """
        seg_kwargs = dict(**kwargs)

        inputs_embeds = kwargs.get("inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if inputs_embeds.shape[0] != memory_state.shape[0]:
            memory_state = self.set_memory(inputs_embeds.shape)

        # Layout: [read memory | segment tokens | write memory].
        inputs_embeds = torch.cat(
            [memory_state, inputs_embeds, memory_state], dim=1
        ).to(input_ids.device)
        """
        # Generate token_type_ids
        token_type_ids = torch.zeros_like(inputs_embeds[:, :, 0], dtype=torch.long)
        token_type_ids[:, self.num_mem_tokens:-self.num_mem_tokens] = 1

        # Add token_type_embeddings and update the input
        token_type_embeds = self.token_type_embeddings(token_type_ids)
        inputs_embeds = inputs_embeds + token_type_embeds
        """

        seg_kwargs["input_ids"] = None
        seg_kwargs["inputs_embeds"] = inputs_embeds
        if kwargs.get("attention_mask") is not None:
            seg_kwargs["attention_mask"] = self.pad_attention_mask(
                kwargs["attention_mask"], inputs_embeds.shape
            )
        seg_kwargs["output_hidden_states"] = True

        # Position ids: the read memory takes positions 0..m-1 and the write
        # memory m..2m-1, while the segment tokens start at 2m, even though
        # the write memory physically sits at the end of the sequence.
        pos_mem1 = torch.arange(self.num_mem_tokens, device=input_ids.device)
        pos_mem2 = torch.arange(self.num_mem_tokens, self.num_mem_tokens * 2, device=input_ids.device)
        pos_seg = torch.arange(
            self.num_mem_tokens * 2,
            self.num_mem_tokens * 2 + input_ids.shape[1],
            device=input_ids.device,
        )
        pos = torch.cat([pos_mem1, pos_seg, pos_mem2], dim=0)
        pos = pos.unsqueeze(0).expand(input_ids.shape[0], -1)
        seg_kwargs["position_ids"] = pos

        return seg_kwargs

    def pad_attention_mask(self, attention_mask, shape):
        """Extends the attention mask with ones over the memory tokens on both sides."""
        if self.num_mem_tokens in {0, None}:
            return attention_mask
        else:
            attention_mask = torch.cat(
                [
                    torch.ones(
                        shape[0], self.num_mem_tokens, device=attention_mask.device
                    ),
                    attention_mask,
                    torch.ones(
                        shape[0], self.num_mem_tokens, device=attention_mask.device
                    ),
                ],
                dim=1,
            )
            return attention_mask

    @staticmethod
    def compute_logpi(mean, stddev, action):
        """Log-density of a diagonal Gaussian with the given mean and stddev, evaluated at `action`."""
        # log N(a; mu, sigma) = -0.5*log(2*pi) - log(sigma) - 0.5*((a - mu)/sigma)**2.
        # The original used `torch.fill(stddev.shape, math.pi)`, which is not a
        # valid call; the scalar constant below broadcasts to the same result.
        a1 = -0.5 * math.log(2 * math.pi)
        a2 = -torch.log(stddev)
        a3 = -0.5 * (((action - mean) / stddev) ** 2)
        return a1 + a2 + a3
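
    # Sanity check (illustrative, not part of the class API): for tensors of
    # matching shape, the value above should agree element-wise with
    #     torch.distributions.Normal(mean, stddev).log_prob(action)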

    def process_output(self, model_outputs, **kwargs):
        """Strips the memory positions from the model output and extracts the new memory state."""
        if self.num_mem_tokens not in {0, None}:
            out = CausalLMOutputWithCrossAttentions()
            # The new memory is read from the last hidden layer at the
            # trailing (write) memory positions.
            memory_state = model_outputs.hidden_states[-1][:, -self.num_mem_tokens :]
            out["logits"] = model_outputs.logits[
                :, self.num_mem_tokens : -self.num_mem_tokens
            ]

            if kwargs.get("output_hidden_states"):
                out["hidden_states"] = [
                    lh[:, self.num_mem_tokens : -self.num_mem_tokens]
                    for lh in model_outputs.hidden_states
                ]
            if kwargs.get("output_attentions"):
                out["attentions"] = model_outputs["attentions"]
        else:
            memory_state = None
            out = model_outputs

        return out, memory_state

    def generate(self, input_ids, memory_state, attention_mask, **generate_kwargs):
        """Generates from the backbone with memory tokens attached to the prompt."""
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        seg_kwargs = self.process_input(input_ids, memory_state, attention_mask=attention_mask)
        out = self.model.generate(
            inputs_embeds=seg_kwargs["inputs_embeds"],
            attention_mask=seg_kwargs["attention_mask"],
            **generate_kwargs,
        )
        return out
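

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # drive the cell segment by segment, carrying `memory_state` across
    # segments as in the recurrent memory transformer recurrence. "gpt2",
    # the 128-token segment length, and the 10-token memory size are
    # arbitrary choices for this demo. Because of the relative import above,
    # run it as a module from the package root (python -m <package>.<module>).
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    cell = MemoryCell(base_model, num_mem_tokens=10)

    text = "a long document " * 200
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    segments = input_ids.split(128, dim=1)  # fixed-size segments

    memory_state = None
    with torch.no_grad():
        for segment in segments:
            out, memory_state = cell(segment, memory_state=memory_state)

    # `out.logits` covers the last segment; `memory_state` summarizes the rest.
    print(out.logits.shape, memory_state.shape)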