File size: 4,871 Bytes
6e6e48e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# ablangrbd_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoConfig
from safetensors.torch import load_file

import typing as T


class Mixer(nn.Module):
    """Fixed-width MLP: six ``in_d -> in_d`` linear layers with ReLU in between.

    The final linear layer has no activation, so the output is unbounded;
    apply softmax/sigmoid/normalisation externally if the loss requires it.
    The layers live in ``self.layers`` (an ``nn.Sequential``) with linear
    layers at even indices and ReLUs at odd indices.
    """

    def __init__(self, in_d: int = 1536):
        super().__init__()

        # Five (Linear, ReLU) pairs followed by the bare output projection.
        stack = []
        for _ in range(5):
            stack.append(nn.Linear(in_d, in_d))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(in_d, in_d))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the whole stack and return the result."""
        return self.layers(x)
    

def get_sequence_embeddings(mask, model_output):
    """Mean-pool token embeddings, ignoring the first and last attended tokens.

    For every row of ``mask`` the position at index 0 (CLS token) and the last
    position whose mask value is non-zero (SEP token) are excluded; the
    remaining attended positions of ``model_output.last_hidden_state`` are
    averaged.

    Args:
        mask: 2-D attention mask indexed as ``[sequence, position]``; non-zero
            marks an attended position. It is never modified.
        model_output: Object exposing ``last_hidden_state``; the mask is
            expanded to that tensor's size, so it is presumably
            ``(batch, seq_len, hidden)`` -- confirm against the encoder used.

    Returns:
        Tensor holding one pooled vector per sequence. Rows with no remaining
        attended position pool to zeros (the divisor is clamped to 1e-9).
    """
    hidden = model_output.last_hidden_state

    # ``mask.float()`` hands back the very same tensor when the mask is already
    # float32, so the in-place edits below would leak into the caller's
    # attention mask. Always work on a private copy.
    weights = mask.to(dtype=torch.float32, copy=True)

    # Index of the last non-zero entry per row, computed on-device instead of
    # via a host-side dict built from ``torch.nonzero(...).cpu().numpy()``.
    positions = torch.arange(weights.size(1), device=weights.device)
    last_attended = ((weights != 0) * positions).max(dim=1).values
    rows = torch.arange(weights.size(0), device=weights.device)

    # Hide the SEP token. For an all-zero row this touches column 0, which is
    # cleared on the next line anyway.
    weights[rows, last_attended] = 0.0
    # Hide the CLS token.
    weights[:, 0] = 0.0

    weights = weights.unsqueeze(-1).expand(hidden.size())
    sum_embeddings = torch.sum(hidden * weights, 1)
    # Number of positions that contributed; clamped to avoid division by zero.
    sum_mask = torch.clamp(weights.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


class AbLangPairedConfig(PretrainedConfig):
    """Settings needed to assemble an AbLang paired (heavy + light) model.

    Args:
        checkpoint_filename: Path to the trained weights file. Defaults to
            ``None`` so the class can be constructed without arguments, which
            the Hugging Face config serialisation path (used by ``repr`` and
            ``save_pretrained``) does internally; a real path must be supplied
            before the config is used to load weights.
        heavy_model_id: Hub repository id of the heavy-chain encoder.
        heavy_revision: Commit hash pinning the heavy-chain encoder.
        light_model_id: Hub repository id of the light-chain encoder.
        light_revision: Commit hash pinning the light-chain encoder.
        mixer_hidden_dim: Width stored for the mixer MLP (1536 by default).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "ablang_paired"

    def __init__(
        self,
        checkpoint_filename: T.Optional[str] = None,
        heavy_model_id='qilowoq/AbLang_heavy',
        heavy_revision='ecac793b0493f76590ce26d48f7aac4912de8717',
        light_model_id='qilowoq/AbLang_light',
        light_revision='ce0637166f5e6e271e906d29a8415d9fdc30e377',
        mixer_hidden_dim: int = 1536,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.checkpoint_filename = checkpoint_filename
        self.heavy_model_id = heavy_model_id
        self.heavy_revision = heavy_revision
        self.light_model_id = light_model_id
        self.light_revision = light_revision
        self.mixer_hidden_dim = mixer_hidden_dim


class AbLangPaired(PreTrainedModel):
    """Two-stream antibody encoder producing one embedding per heavy/light pair.

    Heavy and light chains are encoded by separate pretrained encoders, each
    output is pooled with ``get_sequence_embeddings``, the two pooled vectors
    are concatenated, passed through ``Mixer`` and L2-normalised.
    """

    def __init__(self, personal_config: AbLangPairedConfig, device: T.Union[str, torch.device] = "cpu"):
        """Build both encoders and the mixer, then load the trained checkpoint.

        Args:
            personal_config: Supplies the encoder ids/revisions, the mixer
                width and the checkpoint path.
            device: Device the model is moved to after loading; also used as
                ``map_location`` for non-safetensors checkpoints.

        Raises:
            ValueError: If ``personal_config.checkpoint_filename`` is empty or
                ``None``.
        """
        import warnings  # local import: only needed for the checkpoint report

        checkpoint_path = personal_config.checkpoint_filename
        if not checkpoint_path:
            raise ValueError(
                "personal_config.checkpoint_filename must point to a checkpoint file"
            )

        # During training the AbLang_heavy config was used as AbLangPaired's
        # own config, so the same config is handed to PreTrainedModel here.
        # (This mismatch is likely why AutoModel integration is awkward.)
        backbone_config = AutoConfig.from_pretrained(
            personal_config.heavy_model_id,
            revision=personal_config.heavy_revision,
        )
        # Pass the config to the base initialiser instead of assigning
        # ``self.config`` before ``nn.Module`` has been initialised.
        super().__init__(backbone_config)

        self.roberta_heavy = AutoModel.from_pretrained(
            personal_config.heavy_model_id,
            revision=personal_config.heavy_revision,  # Specific commit hash
            trust_remote_code=True
        )

        self.roberta_light = AutoModel.from_pretrained(
            personal_config.light_model_id,
            revision=personal_config.light_revision,  # Specific commit hash
            trust_remote_code=True
        )

        # Honour the configured width (defaults to 1536) rather than a literal.
        self.mixer = Mixer(in_d=personal_config.mixer_hidden_dim)

        # Load either a safetensors file or a torch-pickled state dict.
        if checkpoint_path.endswith('.safetensors'):
            state_dict = load_file(checkpoint_path)
        else:
            # NOTE(security): torch.load unpickles arbitrary Python objects.
            # Only load checkpoints from trusted sources, or prefer the
            # .safetensors format.
            state_dict = torch.load(checkpoint_path, map_location=device)

        # strict=False tolerates key differences, so report them explicitly:
        # otherwise a wrong checkpoint silently leaves weights untrained.
        load_result = self.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            warnings.warn(
                f"Checkpoint {checkpoint_path!r} did not match the model exactly: "
                f"missing keys={list(load_result.missing_keys)}, "
                f"unexpected keys={list(load_result.unexpected_keys)}",
                stacklevel=2,
            )

        self.to(device)
        self.eval()

    def forward(self, h_input_ids, h_attention_mask, l_input_ids, l_attention_mask, **kwargs):
        """Embed a batch of paired heavy/light chains.

        Args:
            h_input_ids: Heavy-chain token ids (cast to int64 before use).
            h_attention_mask: Attention mask for the heavy-chain tokens.
            l_input_ids: Light-chain token ids (cast to int64 before use).
            l_attention_mask: Attention mask for the light-chain tokens.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of embeddings, L2-normalised along dimension 1.
        """
        # Run chains through separate streams
        outputs_h = self.roberta_heavy(input_ids=h_input_ids.to(torch.int64), attention_mask=h_attention_mask)
        outputs_l = self.roberta_light(input_ids=l_input_ids.to(torch.int64), attention_mask=l_attention_mask)

        # Mean pool
        pooled_output_h = get_sequence_embeddings(h_attention_mask, outputs_h)
        pooled_output_l = get_sequence_embeddings(l_attention_mask, outputs_l)

        # Concatenate, then let the 6-layer mixer pick up cross-chain features
        pooled_output = torch.cat([pooled_output_h, pooled_output_l], dim=1)
        pooled_output = self.mixer(pooled_output)
        embedding = F.normalize(pooled_output, p=2, dim=1)
        return embedding