from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from transformers import BatchEncoding, MPNetTokenizerFast
from transformers.models.mpnet.modeling_mpnet import MPNetModel
from transformers.models.roformer.modeling_roformer import (
    RoFormerAttention,
    RoFormerEmbeddings,
    RoFormerEncoder,
    RoFormerIntermediate,
    RoFormerLayer,
    RoFormerModel,
    RoFormerOutput,
    RoFormerPreTrainedModel,
    RoFormerSelfAttention,
)


class JRoFormerEmbeddings(RoFormerEmbeddings):
    """Construct the embeddings from word and token_type embeddings."""

    def __init__(self, config):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
        )
        # Token type embeddings share the word-embedding table, so
        # `token_type_ids` index the same vocabulary as `input_ids`.
        self.token_type_embeddings = self.word_embeddings


class JRoFormerSelfAttention(RoFormerSelfAttention):
    def __init__(self, config):
        super().__init__(config)
        # Same as RoFormerSelfAttention, except that the bias on the
        # query/key/value projections is controlled by `config.use_bias`.
        self.query = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )
        self.key = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )
        self.value = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )


class JRoFormerAttention(RoFormerAttention):
    def __init__(self, config):
        super().__init__(config)
        self.self = JRoFormerSelfAttention(config)


class JRoFormerLayer(RoFormerLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = JRoFormerAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(
                    f"{self} should be used as a decoder model if cross attention is added"
                )
            self.crossattention = RoFormerAttention(config)
        self.intermediate = RoFormerIntermediate(config)
        self.output = RoFormerOutput(config)


class JRoFormerEncoder(RoFormerEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList(
            [JRoFormerLayer(config) for _ in range(config.num_hidden_layers)]
        )


class JRoFormerModel(RoFormerModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = JRoFormerEmbeddings(config)

        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(
                config.embedding_size, config.hidden_size
            )

        self.encoder = JRoFormerEncoder(config)

        # Re-initialise weights and apply final processing after swapping in
        # the custom embeddings and encoder.
        self.post_init()


class AsmEncoder(RoFormerPreTrainedModel):
    """RoFormer-based assembly encoder producing one L2-normalised vector per sequence."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.roformer = JRoFormerModel(config)
        # Projection head mapping the pooled hidden state to the final embedding size.
        self.projection = nn.Linear(config.hidden_size, config.bla_dim)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Masked mean pooling over the token dimension: padding positions are
        # zeroed out via `attention_mask`, which must therefore be provided.
        token_embeddings = outputs[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1)
            .expand(token_embeddings.size())
            .to(token_embeddings.dtype)
        )
        asm_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

        # Project to the target dimension and L2-normalise so that dot products
        # between embeddings are cosine similarities.
        asm_embedding = self.projection(asm_embedding)
        asm_embedding = F.normalize(asm_embedding, p=2, dim=1)

        return asm_embedding
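

# Minimal usage sketch (not part of the model definition). The checkpoint path
# "path/to/asm-encoder" and the choice of MPNetTokenizerFast are illustrative
# assumptions: any tokenizer/checkpoint pair trained for this encoder works,
# provided the config defines `use_bias` and `bla_dim` as the classes above expect.
if __name__ == "__main__":
    tokenizer = MPNetTokenizerFast.from_pretrained("path/to/asm-encoder")
    model = AsmEncoder.from_pretrained("path/to/asm-encoder")
    model.eval()

    batch: BatchEncoding = tokenizer(
        ["push rbp ; mov rbp , rsp", "xor eax , eax ; ret"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeddings = model(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
    # Each row is an L2-normalised vector of size config.bla_dim.
    print(embeddings.shape)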
|