import torch
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig

from .MemoryCell import MemoryCell
from .RecurrentWrapper import RecurrentWrapper
from .PreTrainedRMTConfig import PreTrainedRMTConfig


class RecurrentMemoryTransformer(PreTrainedModel):
    """
    Recurrent Memory Transformer (RMT) model.

    A transformer model that processes a long context in segments and
    retains information across segments through learned memory tokens.
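
    Example (illustrative; the checkpoint path is a placeholder, and this
    loading path relies on the ``auto_map``/auto-class registration at the
    bottom of this module)::

        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            "your-org/rmt-checkpoint",  # hypothetical checkpoint
            trust_remote_code=True,
        )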
    """

    config_class = PreTrainedRMTConfig
    auto_model_class = "AutoModelForCausalLM"

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    # Intended to match the ``auto_map`` entry used when this class is
    # shipped as remote code, so AutoModelForCausalLM can resolve it.
    AUTO_MAP = {
        "AutoModelForCausalLM": "RecurrentMemoryTransformer",
    }

    def __init__(self, config, base_model=None):
        """
        Initialization.

        Parameters
        ----------
        config : PreTrainedRMTConfig
            Model configuration.
        base_model : PreTrainedModel, optional
            Base transformer model; built from ``config.base_model_type``
            when omitted.
        """
        super().__init__(config)

        if base_model is None:
            if not hasattr(config, "base_model_type"):
                raise ValueError(
                    "base_model_type is not set on the config; an RMT config "
                    "must specify the base model type."
                )
            base_model_type = config.base_model_type

            base_config = AutoConfig.from_pretrained(base_model_type)

            # Copy every non-RMT-specific, non-private key from the RMT config
            # onto the base model config.
            rmt_specific_params = ['model_type', 'is_memory_all', 'max_n_segments', 'input_seg_len',
                                   'output_seg_len', 'align', 'num_mem_tokens', 'base_model_type']
            for key, value in config.__dict__.items():
                if key not in rmt_specific_params and not key.startswith('_'):
                    setattr(base_config, key, value)

            base_model = AutoModelForCausalLM.from_config(base_config)

        # Wrap the base model with memory tokens and segment-level recurrence.
        memory_cell = MemoryCell(base_model, config.num_mem_tokens)
        self.recurrent_wrapper = RecurrentWrapper(
            memory_cell=memory_cell,
            is_memory_all=config.is_memory_all,
            max_n_segments=config.max_n_segments,
            input_seg_len=config.input_seg_len,
            output_seg_len=config.output_seg_len,
            align=config.align,
        )

    def get_base_model(self):
        """
        Get the base transformer model.
        """
        return self.recurrent_wrapper.memory_cell.model

    def forward(self, input_ids=None, attention_mask=None, labels=None, labels_mask=None,
                inputs_embeds=None, output_attentions=None, output_hidden_states=None):
        """
        Forward pass of the model.

        Parameters
        ----------
        input_ids : torch.Tensor, optional
            Input token IDs.
        attention_mask : torch.Tensor, optional
            Attention mask.
        labels : torch.Tensor, optional
            Label tensor for computing the loss.
        labels_mask : torch.Tensor, optional
            Mask over the labels.
        inputs_embeds : torch.Tensor, optional
            Precomputed input embeddings, used instead of ``input_ids``.
        output_attentions : bool, optional
            Whether to output attention weights.
        output_hidden_states : bool, optional
            Whether to output hidden states.
        """
        # Forward only the arguments that were explicitly provided so the
        # wrapper's defaults apply to the rest.
        forward_kwargs = {}
        if input_ids is not None:
            forward_kwargs["input_ids"] = input_ids
        if labels is not None:
            forward_kwargs["labels"] = labels
        if attention_mask is not None:
            forward_kwargs["attention_mask"] = attention_mask
        if labels_mask is not None:
            forward_kwargs["labels_mask"] = labels_mask
        if inputs_embeds is not None:
            forward_kwargs["inputs_embeds"] = inputs_embeds
        if output_attentions is not None:
            forward_kwargs["output_attentions"] = output_attentions
        if output_hidden_states is not None:
            forward_kwargs["output_hidden_states"] = output_hidden_states

        out = self.recurrent_wrapper.forward(**forward_kwargs)

        # Disabled debug/experimental code, kept for reference:
        #
        # print(out["loss"])
        #
        # To avoid double-counting the loss in a distributed run, divide it by
        # the world size. This is unnecessary if the framework already handles
        # it (DeepSpeed may), so verify first; it could also be gated behind an
        # environment variable. Added temporarily for testing and needs
        # adjusting to the actual environment:
        #
        # if torch.distributed.is_initialized() and "loss" in out and out["loss"] is not None:
        #     world_size = torch.distributed.get_world_size()
        #     out["loss"] = out["loss"] / world_size

        return out

    def generate(self, **kwargs):
        """
        Text generation.
        """
        return self.recurrent_wrapper.generate(**kwargs)

    def generate_with_tokenizer(self, tokenizer, input_text, **kwargs):
        """
        Text generation from a raw string using the given tokenizer.
        """
        return self.recurrent_wrapper.generate_with_tokenizer(tokenizer, input_text, **kwargs)

    def get_input_embeddings(self):
        """
        Get the input embeddings of the base model.
        """
        return self.get_base_model().get_input_embeddings()

    def set_input_embeddings(self, embeddings):
        """
        Set the input embeddings of the base model.
        """
        self.get_base_model().set_input_embeddings(embeddings)

    def get_output_embeddings(self):
        """
        Get the output embeddings of the base model.
        """
        return self.get_base_model().get_output_embeddings()

    def resize_token_embeddings(self, new_num_tokens):
        """
        Resize the base model's token embeddings and return the new
        input embeddings.
        """
        self.get_base_model().resize_token_embeddings(new_num_tokens)
        return self.get_input_embeddings()


RecurrentMemoryTransformer.register_for_auto_class("AutoModelForCausalLM")
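

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API).
# Assumptions: PreTrainedRMTConfig accepts the fields read in __init__ above
# as keyword arguments, "gpt2" stands in for any causal LM checkpoint, and
# the is_memory_all/align values are placeholders (see RecurrentWrapper).
# Because this module uses relative imports, run it as a module, e.g.
# ``python -m <package>.<this_module>``.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = PreTrainedRMTConfig(
        base_model_type="gpt2",   # assumed checkpoint name
        num_mem_tokens=16,
        is_memory_all=False,      # placeholder value
        max_n_segments=4,
        input_seg_len=512,
        output_seg_len=512,
        align="left",             # placeholder value
    )
    model = RecurrentMemoryTransformer(config)

    # A sequence longer than one segment; RecurrentWrapper splits it internally.
    input_ids = torch.randint(0, 50257, (1, 1024))  # 50257 = GPT-2 vocab size
    out = model(input_ids=input_ids, labels=input_ids)
    print(out["loss"])

    # Tokenizer-based generation helper defined above.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    print(model.generate_with_tokenizer(tokenizer, "A long input text ..."))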