import typing as T

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoConfig
from safetensors.torch import load_file


class Mixer(nn.Module):
    """Feed-forward MLP that mixes the concatenated heavy- and light-chain embeddings."""

    def __init__(self, in_d: int = 1536):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_d, in_d),
            nn.ReLU(),
            nn.Linear(in_d, in_d),
            nn.ReLU(),
            nn.Linear(in_d, in_d),
            nn.ReLU(),
            nn.Linear(in_d, in_d),
            nn.ReLU(),
            nn.Linear(in_d, in_d),
            nn.ReLU(),
            nn.Linear(in_d, in_d),
        )

    def forward(self, x):
        return self.layers(x)


def get_sequence_embeddings(mask, model_output):
    """Mean-pool the last hidden states over sequence positions.

    The start token (position 0) and the last non-padding token of each
    sequence are removed from the attention mask before averaging.
    """
    mask = mask.float()
    # torch.nonzero yields (row, col) pairs in row-major order, so this dict keeps
    # the largest non-zero column per row, i.e. the position of the end token.
    last_token = {row: col for row, col in torch.nonzero(mask).cpu().numpy()}
    for row, col in last_token.items():
        mask[row, col] = 0.0  # exclude the end token
    mask[:, 0] = 0.0  # exclude the start token
    mask = mask.unsqueeze(-1).expand(model_output.last_hidden_state.size())
    sum_embeddings = torch.sum(model_output.last_hidden_state * mask, 1)
    sum_mask = torch.clamp(mask.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


class AbLangPairedConfig(PretrainedConfig):
    model_type = "ablang_paired"

    def __init__(
        self,
        checkpoint_filename: str,
        heavy_model_id: str = 'qilowoq/AbLang_heavy',
        heavy_revision: str = 'ecac793b0493f76590ce26d48f7aac4912de8717',
        light_model_id: str = 'qilowoq/AbLang_light',
        light_revision: str = 'ce0637166f5e6e271e906d29a8415d9fdc30e377',
        mixer_hidden_dim: int = 1536,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.checkpoint_filename = checkpoint_filename
        self.heavy_model_id = heavy_model_id
        self.heavy_revision = heavy_revision
        self.light_model_id = light_model_id
        self.light_revision = light_revision
        self.mixer_hidden_dim = mixer_hidden_dim


class AbLangPaired(PreTrainedModel):
    def __init__(self, personal_config: AbLangPairedConfig, device: T.Union[str, torch.device] = "cpu"):
        # Reuse the heavy-chain encoder's config as the base PreTrainedModel config.
        base_config = AutoConfig.from_pretrained(
            personal_config.heavy_model_id, revision=personal_config.heavy_revision
        )
        super().__init__(base_config)

        self.roberta_heavy = AutoModel.from_pretrained(
            personal_config.heavy_model_id,
            revision=personal_config.heavy_revision,
            trust_remote_code=True,
        )
        self.roberta_light = AutoModel.from_pretrained(
            personal_config.light_model_id,
            revision=personal_config.light_revision,
            trust_remote_code=True,
        )
        self.mixer = Mixer(in_d=personal_config.mixer_hidden_dim)

        # Load the paired fine-tuned weights (safetensors or a regular torch checkpoint).
        if personal_config.checkpoint_filename.endswith('.safetensors'):
            state_dict = load_file(personal_config.checkpoint_filename)
        else:
            state_dict = torch.load(personal_config.checkpoint_filename, map_location=device)
        self.load_state_dict(state_dict, strict=False)

        self.to(device)
        self.eval()

    def forward(self, h_input_ids, h_attention_mask, l_input_ids, l_attention_mask, **kwargs):
        # Embed the heavy and light chains with their respective AbLang encoders.
        outputs_h = self.roberta_heavy(input_ids=h_input_ids.to(torch.int64), attention_mask=h_attention_mask)
        outputs_l = self.roberta_light(input_ids=l_input_ids.to(torch.int64), attention_mask=l_attention_mask)

        # Mean-pool each chain over its sequence positions.
        pooled_output_h = get_sequence_embeddings(h_attention_mask, outputs_h)
        pooled_output_l = get_sequence_embeddings(l_attention_mask, outputs_l)

        # Concatenate the two chain embeddings, mix them, and L2-normalize.
        pooled_output = torch.cat([pooled_output_h, pooled_output_l], dim=1)
        pooled_output = self.mixer(pooled_output)
        embedding = F.normalize(pooled_output, p=2, dim=1)
        return embedding
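

# Example usage sketch. Assumptions: the AbLang heavy/light tokenizers load via
# AutoTokenizer with trust_remote_code and accept space-separated residue strings;
# "paired_checkpoint.safetensors" is a hypothetical local checkpoint path and the
# sequences below are placeholders, not real antibody data.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = AbLangPairedConfig(checkpoint_filename="paired_checkpoint.safetensors")
    model = AbLangPaired(config, device="cpu")

    heavy_tok = AutoTokenizer.from_pretrained(
        config.heavy_model_id, revision=config.heavy_revision, trust_remote_code=True
    )
    light_tok = AutoTokenizer.from_pretrained(
        config.light_model_id, revision=config.light_revision, trust_remote_code=True
    )

    heavy_seq = " ".join("EVQLVESGGGLVQPGG")  # placeholder heavy-chain fragment
    light_seq = " ".join("DIQMTQSPSSLSASVG")  # placeholder light-chain fragment

    heavy_inputs = heavy_tok(heavy_seq, return_tensors="pt")
    light_inputs = light_tok(light_seq, return_tensors="pt")

    with torch.no_grad():
        embedding = model(
            h_input_ids=heavy_inputs["input_ids"],
            h_attention_mask=heavy_inputs["attention_mask"],
            l_input_ids=light_inputs["input_ids"],
            l_attention_mask=light_inputs["attention_mask"],
        )
    print(embedding.shape)  # expected: (1, 1536), L2-normalized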