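"""CubeLM: a GPT-2 causal language model with optional auxiliary "cube state"
classification heads attached to the transformer hidden states."""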
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from .CubeConfig import CubeConfig
# Label value skipped by CrossEntropyLoss (standard Hugging Face convention).
IGNORE_INDEX = -100


@dataclass
class CubeLMOutput(CausalLMOutputWithCrossAttentions):
    """Adds cube-state loss/logits to the standard causal-LM output.

    `loss` and `logits` are inherited from CausalLMOutputWithCrossAttentions.
    """

    lm_loss: Optional[torch.FloatTensor] = None
    cube_loss: Optional[torch.FloatTensor] = None
    cube_logits: Optional[torch.FloatTensor] = None
class CubeLM(GPT2LMHeadModel):
    """GPT-2 LM with optional per-position cube-state classification heads.

    task:
        "sft"      - next-token LM loss only.
        "pretrain" - cube-state classification loss only.
        "joint"    - weighted sum of both (weight `config.alpha`).
    """

    def __init__(self, config, task="sft", num_heads=24, num_classes=6):
        super().__init__(config)
        assert task in ["sft", "pretrain", "joint"]
        self.task = task

        # Weight on the cube loss when training jointly with the LM loss.
        self.alpha = getattr(config, "alpha", None)
        self.vocab_size = config.vocab_size

        # One linear classifier producing `num_classes` logits for each of
        # `num_heads` cube-state slots at every sequence position.
        self.cube_heads = None
        if task in ["pretrain", "joint"]:
            self.cube_heads = nn.Linear(
                config.n_embd, num_heads * num_classes, bias=False
            )
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.config = config
    def forward(
        self, input_ids, attention_mask=None, labels=None, cube_states=None, **kwargs
    ):
        """labels: [batch, seq] token ids, with IGNORE_INDEX at masked positions.
        cube_states: [batch, seq, num_heads] class indices for the cube heads.
        """
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
        )
        # [batch, seq, d_model]
        hidden_states = outputs.last_hidden_state

        lm_logits = None
        lm_loss = None
        if self.task in ["sft", "joint"]:
            # [batch, seq, vocab_size]
            lm_logits = self.lm_head(hidden_states)
            if labels is not None:
                # Standard causal-LM shift: predict token t+1 from position t.
                # Use `labels` (which may carry IGNORE_INDEX at masked
                # positions) rather than the raw input_ids.
                shift_logits = lm_logits[:, :-1, :].contiguous()
                shift_labels = labels[:, 1:].contiguous()
                loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
                lm_loss = loss_fn(
                    shift_logits.view(-1, self.vocab_size),
                    shift_labels.view(-1),
                )
        cube_logits = None
        cube_loss = None
        if self.cube_heads is not None:
            # [batch, seq, n_heads, n_classes]
            cube_logits = self.cube_heads(hidden_states).view(
                hidden_states.size(0),
                hidden_states.size(1),
                self.num_heads,
                self.num_classes,
            )
            if cube_states is not None:
                # [batch * seq, n_heads, n_classes]
                _logits = cube_logits.view(-1, self.num_heads, self.num_classes)
                # [batch * seq, n_heads]
                _labels = cube_states.view(-1, self.num_heads)
                loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
                # Average the per-head cross-entropy losses.
                losses = [
                    loss_fn(_logits[:, head_idx, :], _labels[:, head_idx])
                    for head_idx in range(self.num_heads)
                ]
                cube_loss = sum(losses) / self.num_heads
        # Combine whichever losses were computed; "joint" training weights the
        # cube loss by `alpha`.
        total_loss = None
        if lm_loss is not None and cube_loss is not None:
            assert self.alpha is not None
            total_loss = lm_loss + self.alpha * cube_loss
        elif lm_loss is not None:
            total_loss = lm_loss
        elif cube_loss is not None:
            total_loss = cube_loss

        return CubeLMOutput(
            loss=total_loss,
            lm_loss=lm_loss,
            cube_loss=cube_loss,
            logits=lm_logits,
            cube_logits=cube_logits,
        )
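

# --- Usage sketch (illustrative only) --------------------------------------
# A minimal smoke test showing how CubeLM is instantiated and called in
# "joint" mode. It assumes CubeConfig accepts GPT2Config-style keyword
# arguments plus an `alpha` attribute (the joint-loss weight read in
# __init__); adjust to the real CubeConfig signature if it differs.
if __name__ == "__main__":
    config = CubeConfig(
        vocab_size=64,
        n_positions=32,
        n_embd=128,
        n_layer=2,
        n_head=4,
        alpha=0.5,
    )
    model = CubeLM(config, task="joint", num_heads=24, num_classes=6)

    batch, seq = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch, seq))
    labels = input_ids.clone()  # SFT labels; may contain IGNORE_INDEX
    cube_states = torch.randint(0, 6, (batch, seq, 24))  # one class per head

    out = model(input_ids, labels=labels, cube_states=cube_states)
    print(out.loss, out.lm_loss, out.cube_loss)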