|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from dataclasses import dataclass |
|
from transformers import PreTrainedModel |
|
from transformers.utils import ModelOutput |
|
|
|
from .configuration_erbb1_mlp import Erbb1MlpConfig |
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """SiLU-gated feed-forward block.

    A single input projection produces two halves: a gate and a value.
    The gate passes through SiLU, multiplies the value elementwise, and
    the product is mapped back through an output projection.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # One projection emits both halves (gate | value), each of width d_out.
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        gate, value = self.fc1(x).chunk(2, dim=-1)
        hidden = F.silu(gate) * value
        return self.fc2(hidden)
|
|
|
|
|
class FeedForwardLayer(nn.Module):
    """Residual feed-forward layer: dropout -> gated FF, plus a skip path,
    followed by LayerNorm.

    The skip connection is a linear projection when the input and output
    widths differ, and an identity otherwise. Passing a falsy
    ``layer_norm_eps`` disables normalization entirely.
    """

    def __init__(self, d_in, d_out, dropout=0.1, layer_norm_eps=1e-12):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        # Match the residual width to the output width only when needed.
        if d_in != d_out:
            self.skip = nn.Linear(d_in, d_out)
        else:
            self.skip = nn.Identity()
        self.dropout = nn.Dropout(dropout)
        # A falsy eps turns normalization off (Identity).
        if layer_norm_eps:
            self.norm = nn.LayerNorm(d_out, eps=layer_norm_eps)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        transformed = self.ff(self.dropout(x))
        residual = self.skip(x)
        return self.norm(transformed + residual)
|
|
|
|
|
|
|
@dataclass
class Erbb1MlpOutput(ModelOutput):
    """Output container for `Erbb1MlpModel.forward`."""

    # MSE loss computed in normalized target space; None when no labels
    # were provided to `forward`.
    loss: torch.FloatTensor | None = None
    # Prediction mapped back to the original target scale
    # (pred_norm * target_std + target_mean).
    prediction: torch.FloatTensor | None = None
    # Raw model output before de-normalization.
    prediction_norm: torch.FloatTensor | None = None
|
|
|
|
|
class Erbb1MlpModel(PreTrainedModel):
    """MLP regression head producing one scalar prediction per embedding.

    A stack of residual `FeedForwardLayer`s followed by a linear projection
    to a single output. Predictions are de-normalized with dataset mean/std
    buffers taken from the config.
    """

    config_class = Erbb1MlpConfig

    def __init__(self, config: Erbb1MlpConfig):
        super().__init__(config)

        # First layer maps the embedding into hidden space without input
        # dropout; the remaining n_layers - 1 layers are hidden-to-hidden
        # with the configured dropout.
        blocks = [FeedForwardLayer(config.d_in, config.d_hidden, 0.0, config.layer_norm_eps)]
        for _ in range(config.n_layers - 1):
            blocks.append(
                FeedForwardLayer(config.d_hidden, config.d_hidden, config.dropout, config.layer_norm_eps)
            )
        self.body = nn.Sequential(*blocks)
        self.out_proj = nn.Linear(config.d_hidden, 1)

        # Target de-normalization statistics, kept as persistent buffers so
        # they are saved/loaded with the state dict. Falsy config values
        # fall back to the identity transform (mean 0, std 1).
        self.register_buffer(
            "target_mean",
            torch.tensor(config.dataset_mean or 0.0, dtype=torch.float32),
            persistent=True,
        )
        self.register_buffer(
            "target_std",
            torch.tensor(config.dataset_std or 1.0, dtype=torch.float32),
            persistent=True,
        )

        self.post_init()

    def forward(self, embedding, labels=None, return_dict=True):
        hidden = self.body(embedding)
        # Drop the trailing singleton feature dim: one scalar per row.
        pred_norm = self.out_proj(hidden).squeeze(-1)
        # Undo dataset normalization to report on the original target scale.
        pred = pred_norm * self.target_std + self.target_mean

        # NOTE(review): the loss compares the *normalized* prediction with
        # `labels`, so labels are presumably already normalized targets —
        # confirm against the training pipeline.
        loss = F.mse_loss(pred_norm, labels) if labels is not None else None

        if return_dict:
            return Erbb1MlpOutput(
                loss=loss,
                prediction=pred,
                prediction_norm=pred_norm,
            )
        return (loss, pred, pred_norm)
|
|