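"""Hugging Face-compatible wrapper around a BLASER-style MT evaluation head.

The model scores machine translations from pre-computed sentence embeddings,
either reference-free ("QE") or reference-based ("COMET"), using a small MLP.
"""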
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel
# ---------------- CONFIG ---------------- #
class BlaserConfig(PretrainedConfig):
    model_type = "blaser"

    def __init__(
        self,
        embedding_dim=1024,
        output_dim=1,
        hidden_dims=None,
        dropout=0.1,
        activation="TANH",
        input_form="QE",
        norm_emb=True,
        output_act=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims if hidden_dims is not None else [3072, 1536]
        self.dropout = dropout
        self.activation = activation
        self.input_form = input_form
        self.norm_emb = norm_emb
        self.output_act = output_act

# ---------------- CORE MODEL ---------------- #
ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}
class BlaserCore(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        output_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        input_form: str,
        norm_emb: bool,
        output_act: bool,
    ):
        super().__init__()
        self.input_form = input_form
        self.norm_emb = norm_emb

        if input_form == "COMET":
            embedding_dim *= 6
        elif input_form == "QE":
            embedding_dim *= 4
        else:
            raise ValueError(f"Unrecognized input_form: {input_form}")

        if activation not in ACTIVATIONS:
            raise ValueError(f"Unrecognized activation: {activation}")

        modules: List[nn.Module] = []
        if hidden_dims:
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            nprev = embedding_dim
            for h in hidden_dims:
                modules.append(nn.Linear(nprev, h))
                modules.append(ACTIVATIONS[activation]())
                if dropout > 0:
                    modules.append(nn.Dropout(p=dropout))
                nprev = h
            modules.append(nn.Linear(nprev, output_dim))
            if output_act:
                modules.append(nn.Tanh())
        else:
            modules.append(nn.Linear(embedding_dim, output_dim))
        self.mlp = nn.Sequential(*modules)

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        return F.normalize(emb) if (emb is not None and self.norm_emb) else emb

    def _featurize(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        if self.input_form == "COMET":
            if ref is None:
                raise ValueError("COMET input_form requires reference embedding")
            return torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        elif self.input_form == "QE":
            return torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)

# ---------------- HF MODEL WRAPPER ---------------- #
class BlaserModel(PreTrainedModel):
    config_class = BlaserConfig

    def __init__(self, config: BlaserConfig):
        super().__init__(config)
        # Directly assign the core's Sequential MLP to self.mlp
        core = BlaserCore(
            embedding_dim=config.embedding_dim,
            output_dim=config.output_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            input_form=config.input_form,
            norm_emb=config.norm_emb,
            output_act=config.output_act,
        )
        self.mlp = core.mlp
        self.input_form = core.input_form
        self.norm_emb = core.norm_emb

    def forward(self, src, mt, ref=None):
        # Use the same normalization and featurization as in BlaserCore
        src = F.normalize(src) if self.norm_emb else src
        mt = F.normalize(mt) if self.norm_emb else mt
        ref = F.normalize(ref) if (ref is not None and self.norm_emb) else ref
        if self.input_form == "COMET":
            if ref is None:
                raise ValueError("COMET input_form requires reference embedding")
            proc = torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        else:  # QE
            proc = torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)
        return self.mlp(proc)
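# ---------------- USAGE SKETCH ---------------- #
# A minimal smoke test, not part of the checkpoint: it builds a randomly
# initialized model from the default config and scores dummy embeddings.
# Real usage would load pretrained weights (e.g. BlaserModel.from_pretrained(...))
# and pass sentence embeddings of size `embedding_dim` from an external encoder.
if __name__ == "__main__":
    config = BlaserConfig()  # QE input form, 1024-dim embeddings by default
    model = BlaserModel(config).eval()

    batch, dim = 2, config.embedding_dim
    src = torch.randn(batch, dim)  # source-sentence embeddings
    mt = torch.randn(batch, dim)   # machine-translation embeddings

    with torch.no_grad():
        scores = model(src=src, mt=mt)  # shape: (batch, output_dim) == (2, 1)
    print(scores.shape)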