# optrans/optrans_modeling.py
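# This module defines "J" variants of the Hugging Face RoFormer building blocks
# (embeddings, self-attention, attention, layer, encoder, model) plus AsmEncoder,
# which mean-pools the encoder output, projects it to config.bla_dim, and
# L2-normalizes the result into a single fixed-size embedding per input sequence.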
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from typing import Optional, Tuple
import torch.nn.functional as F
from transformers import BatchEncoding
from transformers import MPNetTokenizerFast
from transformers.models.roformer.modeling_roformer import (
    RoFormerEmbeddings,
    RoFormerModel,
    RoFormerEncoder,
    RoFormerLayer,
    RoFormerAttention,
    RoFormerIntermediate,
    RoFormerOutput,
    RoFormerSelfAttention,
    RoFormerPreTrainedModel
)
from transformers.models.mpnet.modeling_mpnet import MPNetModel


class JRoFormerEmbeddings(RoFormerEmbeddings):
    """Construct the embeddings from word and token_type embeddings."""

    def __init__(self, config):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
        )
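        # Token-type embeddings share the word-embedding table (tied weights).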
        self.token_type_embeddings = self.word_embeddings


class JRoFormerSelfAttention(RoFormerSelfAttention):
    def __init__(self, config):
        super().__init__(config)
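        # Re-create the query/key/value projections so the bias term is
        # controlled by config.use_bias instead of being always enabled.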
        self.query = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )
        self.key = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )
        self.value = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.use_bias
        )


class JRoFormerAttention(RoFormerAttention):
    def __init__(self, config):
        super().__init__(config)
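        # Swap in the self-attention variant with configurable Q/K/V bias.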
        self.self = JRoFormerSelfAttention(config)


class JRoFormerLayer(RoFormerLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = JRoFormerAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
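        # Cross-attention (decoder configurations only) keeps the stock RoFormerAttention.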
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(
                    f"{self} should be used as a decoder model if cross attention is added"
                )
            self.crossattention = RoFormerAttention(config)
        self.intermediate = RoFormerIntermediate(config)
        self.output = RoFormerOutput(config)
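

# Encoder stack that mirrors RoFormerEncoder but is built from JRoFormerLayer.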
class JRoFormerEncoder(RoFormerEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList(
            [JRoFormerLayer(config) for _ in range(config.num_hidden_layers)]
        )


class JRoFormerModel(RoFormerModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = JRoFormerEmbeddings(config)
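        # Project token embeddings to hidden_size when embedding_size differs from it.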
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(
                config.embedding_size, config.hidden_size
            )

        self.encoder = JRoFormerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()


class AsmEncoder(RoFormerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.roformer = JRoFormerModel(config)
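        # Projection head from the pooled hidden state to the final
        # embedding dimension (config.bla_dim).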
        self.projection = nn.Linear(config.hidden_size, config.bla_dim)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
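        # Mean-pool the token embeddings over non-padding positions,
        # weighting by the attention mask.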
        token_embeddings = outputs[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
        asm_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        asm_embedding = self.projection(asm_embedding)
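        # L2-normalize so embeddings can be compared with a plain dot product
        # (cosine similarity).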
        asm_embedding = F.normalize(asm_embedding, p=2, dim=1)

        return asm_embedding
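

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# Assumes the checkpoint's RoFormerConfig carries two extra fields that the
# classes above read: `use_bias` (Q/K/V bias switch) and `bla_dim` (output
# embedding size). All values below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import RoFormerConfig

    config = RoFormerConfig(
        vocab_size=1024,
        embedding_size=256,
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
    )
    config.use_bias = True  # assumed extra config field
    config.bla_dim = 128    # assumed extra config field

    model = AsmEncoder(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    embeddings = model(input_ids=input_ids, attention_mask=attention_mask)
    print(embeddings.shape)  # torch.Size([2, 128]); rows are L2-normalized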