File size: 4,542 Bytes
			
			| 8a25c0b 03493f1 8a25c0b 03493f1 8a25c0b 03493f1 8a25c0b 03493f1 5c4f06c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel
# ---------------- CONFIG ---------------- #
class BlaserConfig(PretrainedConfig):
    """Configuration for :class:`BlaserModel`.

    Every field below is read by ``BlaserModel.__init__`` and forwarded to
    ``BlaserCore`` to build the scoring MLP.

    Args:
        embedding_dim: Dimension of a single input sentence embedding.
        output_dim: Number of outputs produced by the final linear layer.
        hidden_dims: Hidden layer widths; ``None`` means ``[3072, 1536]``.
        dropout: Dropout probability used inside the MLP.
        activation: Hidden-layer activation name, ``"TANH"`` or ``"RELU"``.
        input_form: Feature layout, ``"QE"`` or ``"COMET"``.
        norm_emb: Whether input embeddings are L2-normalized before use.
        output_act: Whether a Tanh follows the final linear layer (only
            applied by ``BlaserCore`` when ``hidden_dims`` is non-empty).
        **kwargs: Passed unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "blaser"

    def __init__(
        self,
        embedding_dim=1024,
        output_dim=1,
        hidden_dims=None,
        dropout=0.1,
        activation="TANH",
        input_form="QE",
        norm_emb=True,
        output_act=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Build the default list per instance rather than sharing a
        # mutable default argument.
        if hidden_dims is None:
            hidden_dims = [3072, 1536]

        # Shape of the MLP.
        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        # Regularization and nonlinearity.
        self.dropout = dropout
        self.activation = activation
        # How the input embeddings are preprocessed and combined.
        self.input_form = input_form
        self.norm_emb = norm_emb
        self.output_act = output_act
# ---------------- CORE MODEL ---------------- #
# Maps the ``activation`` name accepted by BlaserCore to an nn.Module class.
ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}


class BlaserCore(nn.Module):
    """MLP head that scores a translation from sentence embeddings.

    The source, machine-translation and (for ``"COMET"``) reference
    embeddings are optionally L2-normalized, combined into one feature
    vector by :meth:`_featurize`, and passed through :attr:`mlp`.

    Args:
        embedding_dim: Size of one input embedding (its last dimension).
        output_dim: Size of the MLP output.
        hidden_dims: Hidden layer widths. If empty, the MLP is a single
            linear layer.
        dropout: Dropout probability. No dropout layers are added when it
            is not positive.
        activation: Key into ``ACTIVATIONS`` (``"TANH"`` or ``"RELU"``).
        input_form: ``"QE"`` (4 concatenated feature blocks) or ``"COMET"``
            (6 concatenated feature blocks, requires a reference).
        norm_emb: If true, L2-normalize every input embedding first.
        output_act: If true and ``hidden_dims`` is non-empty, append a Tanh
            after the final linear layer.

    Raises:
        ValueError: If ``input_form`` or ``activation`` is not recognized.
    """

    def __init__(
        self,
        embedding_dim: int,
        output_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        input_form: str,
        norm_emb: bool,
        output_act: bool,
    ):
        super().__init__()
        self.input_form = input_form
        self.norm_emb = norm_emb
        # The MLP input is the concatenation built in _featurize, so its
        # width is a multiple of the single-embedding size.
        if input_form == "COMET":
            embedding_dim *= 6
        elif input_form == "QE":
            embedding_dim *= 4
        else:
            raise ValueError(f"Unrecognized input_form: {input_form}")
        if activation not in ACTIVATIONS:
            raise ValueError(f"Unrecognized activation: {activation}")
        modules: List[nn.Module] = []
        if hidden_dims:
            # Dropout on the input features, then after every hidden layer.
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            nprev = embedding_dim
            for h in hidden_dims:
                modules.append(nn.Linear(nprev, h))
                modules.append(ACTIVATIONS[activation]())
                if dropout > 0:
                    modules.append(nn.Dropout(p=dropout))
                nprev = h
            modules.append(nn.Linear(nprev, output_dim))
            if output_act:
                modules.append(nn.Tanh())
        else:
            modules.append(nn.Linear(embedding_dim, output_dim))
        self.mlp = nn.Sequential(*modules)

    def forward(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        """Score ``mt`` against ``src`` (and ``ref`` when ``input_form`` is COMET).

        Args:
            src: Source-sentence embeddings.
            mt: Machine-translation embeddings, same shape as ``src``.
            ref: Reference embeddings; required for COMET, ignored for QE.

        Returns:
            ``self.mlp`` applied to the featurized inputs; the last
            dimension has size ``output_dim``.

        Raises:
            ValueError: If ``input_form`` is COMET and ``ref`` is None.
        """
        features = self._featurize(
            src=self._norm(src),
            mt=self._norm(mt),
            ref=self._norm(ref),
        )
        return self.mlp(features)

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        """L2-normalize ``emb`` when ``norm_emb`` is set; ``None`` passes through.

        Uses ``F.normalize`` defaults (p=2, dim=1), i.e. the feature axis for
        2-D inputs. NOTE(review): inputs are presumably (batch, dim).
        """
        return F.normalize(emb) if (emb is not None and self.norm_emb) else emb

    def _featurize(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        """Concatenate the embeddings and their interactions along the last dim.

        Raises:
            ValueError: If COMET is used without ``ref``, or if ``input_form``
                is not recognized (e.g. mutated after construction).
        """
        if self.input_form == "COMET":
            if ref is None:
                raise ValueError("COMET input_form requires reference embedding")
            return torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        if self.input_form == "QE":
            return torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)
        # Previously fell through and returned None silently.
        raise ValueError(f"Unrecognized input_form: {self.input_form}")
# ---------------- HF MODEL WRAPPER ---------------- #
class BlaserModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper for the BLASER MLP scorer.

    The network is built by ``BlaserCore`` from the fields of a
    ``BlaserConfig``. Only the core's ``mlp`` is kept, registered on this
    model as the submodule ``mlp``. :meth:`forward` applies the same
    normalization and featurization as ``BlaserCore``.
    """

    config_class = BlaserConfig

    def __init__(self, config: BlaserConfig):
        super().__init__(config)
        # BlaserCore validates input_form/activation and builds the MLP.
        core = BlaserCore(
            embedding_dim=config.embedding_dim,
            output_dim=config.output_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            input_form=config.input_form,
            norm_emb=config.norm_emb,
            output_act=config.output_act,
        )
        # Keep the Sequential directly on self.mlp. The attribute name
        # determines the "mlp.*" parameter keys in the state dict, so renaming
        # it would break loading existing checkpoints.
        self.mlp = core.mlp
        self.input_form = core.input_form
        self.norm_emb = core.norm_emb

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        """L2-normalize ``emb`` (``F.normalize`` defaults: p=2, dim=1) when enabled.

        ``None`` is passed through unchanged.
        """
        return F.normalize(emb) if (emb is not None and self.norm_emb) else emb

    def _featurize(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        """Concatenate embeddings and their interactions along the last dim.

        Raises:
            ValueError: If COMET is used without ``ref``, or ``input_form`` is
                neither ``"COMET"`` nor ``"QE"``.
        """
        if self.input_form == "COMET":
            if ref is None:
                raise ValueError("COMET input_form requires reference embedding")
            return torch.cat(
                [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
                dim=-1,
            )
        if self.input_form == "QE":
            return torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)
        # Previously any non-COMET value was silently treated as QE.
        raise ValueError(f"Unrecognized input_form: {self.input_form}")

    def forward(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
        """Score ``mt`` against ``src`` (and ``ref`` for COMET).

        Args:
            src: Source-sentence embeddings.
            mt: Machine-translation embeddings, same shape as ``src``.
            ref: Reference embeddings; required for COMET, ignored for QE.

        Returns:
            ``self.mlp`` applied to the featurized inputs; the last dimension
            has size ``config.output_dim``.

        Raises:
            ValueError: If COMET is used without ``ref``.
        """
        features = self._featurize(self._norm(src), self._norm(mt), self._norm(ref))
        return self.mlp(features)