optrans / optrans_modeling.py
sandspeare's picture
update
39a3276
raw
history blame contribute delete
No virus
4.83 kB
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from typing import Optional, Tuple
import torch.nn.functional as F
from transformers import BatchEncoding
from transformers import MPNetTokenizerFast
from transformers.models.roformer.modeling_roformer import (
RoFormerEmbeddings,
RoFormerModel,
RoFormerEncoder,
RoFormerLayer,
RoFormerAttention,
RoFormerIntermediate,
RoFormerOutput,
RoFormerSelfAttention,
RoFormerPreTrainedModel
)
from transformers.models.mpnet.modeling_mpnet import MPNetModel
class JRoFormerEmbeddings(RoFormerEmbeddings):
    """RoFormer embeddings whose token-type table is the word-embedding table.

    After the parent constructor has run, ``word_embeddings`` is replaced by a
    new ``nn.Embedding`` of shape ``(config.vocab_size, config.embedding_size)``
    with ``config.pad_token_id`` as padding index, and
    ``token_type_embeddings`` is bound to that very same module object, so
    both attribute names refer to a single shared weight.
    """

    def __init__(self, config):
        super().__init__(config)
        # One embedding module, registered under both attribute names.
        shared_table = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.embedding_size,
            padding_idx=config.pad_token_id,
        )
        self.word_embeddings = shared_table
        self.token_type_embeddings = shared_table
class JRoFormerSelfAttention(RoFormerSelfAttention):
    """Self-attention whose Q/K/V projections take their bias flag from config.

    The parent constructor runs first; ``query``, ``key`` and ``value`` are
    then replaced by linear layers mapping ``config.hidden_size`` to
    ``self.all_head_size`` with ``bias=config.use_bias``.
    """

    def __init__(self, config):
        super().__init__(config)

        def build_projection():
            # ``all_head_size`` is read from ``self`` after the parent
            # constructor has run.
            return nn.Linear(
                config.hidden_size,
                self.all_head_size,
                bias=config.use_bias,
            )

        # Created in query -> key -> value order, as in the original layout.
        self.query = build_projection()
        self.key = build_projection()
        self.value = build_projection()
class JRoFormerAttention(RoFormerAttention):
    """RoFormer attention block that uses ``JRoFormerSelfAttention``.

    Everything else is whatever the parent constructor set up; only the
    ``self`` sub-module is replaced.
    """

    def __init__(self, config):
        super().__init__(config)
        self_attention = JRoFormerSelfAttention(config)
        self.self = self_attention
class JRoFormerLayer(RoFormerLayer):
    """RoFormer layer whose self-attention is a ``JRoFormerAttention``.

    After the parent constructor runs, the attention, optional
    cross-attention, intermediate and output sub-modules are (re)assigned
    here.

    Raises:
        ValueError: if ``config.add_cross_attention`` is set while
            ``config.is_decoder`` is not.
    """

    def __init__(self, config):
        super().__init__(config)
        self.attention = JRoFormerAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention

        # Cross-attention is only valid for decoder configurations.
        if self.add_cross_attention and not self.is_decoder:
            raise ValueError(
                f"{self} should be used as a decoder model if cross attention is added"
            )
        if self.add_cross_attention:
            # NOTE(review): cross-attention uses the stock RoFormerAttention,
            # not JRoFormerAttention, so config.use_bias is not applied here
            # by this class - confirm that this is intended.
            self.crossattention = RoFormerAttention(config)

        self.intermediate = RoFormerIntermediate(config)
        self.output = RoFormerOutput(config)
class JRoFormerEncoder(RoFormerEncoder):
    """RoFormer encoder built from ``JRoFormerLayer`` blocks.

    Replaces ``self.layer`` with a ``ModuleList`` holding
    ``config.num_hidden_layers`` freshly constructed ``JRoFormerLayer``
    instances.
    """

    def __init__(self, config):
        super().__init__(config)
        stack = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            stack.append(JRoFormerLayer(config))
        self.layer = stack
class JRoFormerModel(RoFormerModel):
    """RoFormer backbone assembled from the ``JRoFormer*`` components.

    Swaps in ``JRoFormerEmbeddings`` and ``JRoFormerEncoder`` after the parent
    constructor has run, and adds a linear ``embeddings_project`` layer when
    the embedding width differs from the hidden width.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = JRoFormerEmbeddings(config)

        # Bridge the embedding width to the hidden width only when they differ.
        widths_differ = config.embedding_size != config.hidden_size
        if widths_differ:
            self.embeddings_project = nn.Linear(
                in_features=config.embedding_size,
                out_features=config.hidden_size,
            )

        self.encoder = JRoFormerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()
class AsmEncoder(RoFormerPreTrainedModel):
    """Encode a token sequence into a single L2-normalised embedding vector.

    The model runs a ``JRoFormerModel`` backbone, mean-pools its first output
    over the positions selected by the attention mask, projects the pooled
    vector with a linear layer to ``config.bla_dim`` features and normalises
    the result to unit L2 norm.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.roformer = JRoFormerModel(config)
        # Maps the pooled hidden state to the final embedding width.
        self.projection = nn.Linear(config.hidden_size, config.bla_dim)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        """Return one unit-norm embedding per input sequence.

        All arguments are forwarded unchanged to the ``JRoFormerModel``
        backbone. ``attention_mask`` is additionally used as the pooling
        weight; when it is ``None`` every position is pooled with weight 1.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Per-position mask; also the weights of the mean
                pooling. Assumed to be ``(batch, seq_len)`` - it is expanded
                along a new last axis to the shape of the token embeddings.
            token_type_ids: Passed to the backbone.
            head_mask: Passed to the backbone.
            inputs_embeds: Passed to the backbone.
            encoder_hidden_states: Passed to the backbone.
            encoder_attention_mask: Passed to the backbone.
            output_attentions: Passed to the backbone.
            output_hidden_states: Passed to the backbone.
            return_dict: Passed to the backbone; defaults to
                ``config.use_return_dict``. It does not change the return
                type of this method.

        Returns:
            A tensor holding, for each sequence, the projected mean-pooled
            embedding (``config.bla_dim`` features) normalised to unit L2
            norm along dimension 1.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 works for both tuple and ModelOutput return styles.
        token_embeddings = outputs[0]

        if attention_mask is None:
            # The parameter defaults to None; without this fallback the
            # pooling below would fail on ``None.unsqueeze``. Pool over every
            # position instead (same device/dtype as the embeddings).
            attention_mask = token_embeddings.new_ones(token_embeddings.shape[:-1])

        input_mask_expanded = (
            attention_mask.unsqueeze(-1)
            .expand(token_embeddings.size())
            .to(token_embeddings.dtype)
        )
        # Masked mean over dimension 1; the clamp avoids division by zero for
        # rows whose mask is entirely zero.
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        asm_embedding = summed / counts

        asm_embedding = self.projection(asm_embedding)
        asm_embedding = F.normalize(asm_embedding, p=2, dim=1)
        return asm_embedding