# MIT License
# Copyright (c) 2024 Hustcw
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
import torch.utils.checkpoint
from torch import nn
from typing import Optional
import torch.nn.functional as F
from transformers.models.roformer.modeling_roformer import (
RoFormerEmbeddings,
RoFormerModel,
RoFormerEncoder,
RoFormerLayer,
RoFormerAttention,
RoFormerIntermediate,
RoFormerOutput,
RoFormerSelfAttention,
RoFormerPreTrainedModel
)
from transformers.models.mpnet.modeling_mpnet import MPNetModel
from transformers import MPNetTokenizerFast, BatchEncoding
class AsmTokenizer(MPNetTokenizerFast):
@property
    def pad_token_type_id(self) -> int:
        """
        `int`: Id used to pad `token_type_ids`. Token types here are instruction-index
        tokens (`INSTR<k>`) drawn from the main vocabulary, so padding reuses the
        regular pad token id.
        """
        return self.pad_token_id
def tokenize_function(self, function):
total_len = 0
tokenized_functions = {"token": [], "instr": []}
for key, value in function.items():
            tokens = self.tokenize(value.replace(',', ''), max_length=20, truncation=True, add_special_tokens=False)  # cap each instruction at 20 tokens
instr_index = "INSTR" + key
instructions = [instr_index] * len(tokens)
tokenized_functions["token"].extend(tokens)
tokenized_functions["instr"].extend(instructions)
total_len += len(tokens)
if total_len > self.model_max_length:
tokenized_functions['token'] = tokenized_functions['token'][:self.model_max_length]
tokenized_functions['instr'] = tokenized_functions['instr'][:self.model_max_length]
break
return tokenized_functions
def encode_function(self, function):
tokenized_functions = self.tokenize_function(function)
token_ids = self.convert_tokens_to_ids(tokenized_functions["token"])
instr_ids = self.convert_tokens_to_ids(tokenized_functions["instr"])
return BatchEncoding({
"input_ids": token_ids,
"attention_mask": [1] * len(token_ids),
"token_type_ids": instr_ids,
})
def __call__(self, functions, **kwargs):
if len(functions) == 0:
return BatchEncoding({
"input_ids": [],
"attention_mask": [],
"token_type_ids": [],
})
        if not isinstance(functions, list):
            raise ValueError("functions must be a list of dicts")
        elif not isinstance(functions[0], dict):
            raise ValueError("functions must be a list of dicts")
else:
batch_encode_result = {
"input_ids": [],
"attention_mask": [],
"token_type_ids": [],
}
for function in functions:
tokenized_functions = self.tokenize_function(function)
token_ids = self.convert_tokens_to_ids(tokenized_functions["token"])
instr_ids = self.convert_tokens_to_ids(tokenized_functions["instr"])
attention_mask = [1] * len(token_ids)
batch_encode_result["input_ids"].append(token_ids)
batch_encode_result["attention_mask"].append(attention_mask)
batch_encode_result["token_type_ids"].append(instr_ids)
batch_encoding = BatchEncoding(batch_encode_result)
return self.pad(batch_encoding, **kwargs)
@property
def vocab_size(self) -> int:
return len(self.vocab)
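
# --- Illustrative usage sketch (not part of the original file) -----------------------
# How AsmTokenizer expects its input: each function is a dict mapping an instruction
# index (string key, e.g. "1", "2", ...) to the instruction text. The key is prefixed
# with "INSTR" and repeated once per token of that instruction, so token_type_ids record
# which instruction each token came from. The checkpoint id below is an assumption used
# only for illustration.
def _example_asm_tokenization():
    tokenizer = AsmTokenizer.from_pretrained("hustcw/clap-asm")  # assumed checkpoint id
    function = {
        "1": "push rbp",
        "2": "mov rbp, rsp",
        "3": "xor eax, eax",
        "4": "ret",
    }
    # __call__ accepts a list of such dicts and pads input_ids, attention_mask and
    # token_type_ids to a common length.
    return tokenizer([function], padding=True, return_tensors="pt")
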
class JRoFormerEmbeddings(RoFormerEmbeddings):
    """Construct the embeddings from word and token_type embeddings.

    The token_type embedding table is tied to the word embedding table, since token
    types are instruction-index tokens (`INSTR<k>`) taken from the same vocabulary.
    """
    def __init__(self, config):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
        )
        # Share the word embedding table for token types (instruction indices).
        self.token_type_embeddings = self.word_embeddings
class JRoFormerSelfAttention(RoFormerSelfAttention):
def __init__(self, config):
super().__init__(config)
        # Rebuild the query/key/value projections so the bias term follows config.use_bias.
        self.query = nn.Linear(
config.hidden_size, self.all_head_size, bias=config.use_bias
)
self.key = nn.Linear(
config.hidden_size, self.all_head_size, bias=config.use_bias
)
self.value = nn.Linear(
config.hidden_size, self.all_head_size, bias=config.use_bias
)
class JRoFormerAttention(RoFormerAttention):
def __init__(self, config):
super().__init__(config)
self.self = JRoFormerSelfAttention(config)
class JRoFormerLayer(RoFormerLayer):
def __init__(self, config):
super().__init__(config)
self.attention = JRoFormerAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(
f"{self} should be used as a decoder model if cross attention is added"
)
self.crossattention = RoFormerAttention(config)
self.intermediate = RoFormerIntermediate(config)
self.output = RoFormerOutput(config)
class JRoFormerEncoder(RoFormerEncoder):
def __init__(self, config):
super().__init__(config)
self.layer = nn.ModuleList(
[JRoFormerLayer(config) for _ in range(config.num_hidden_layers)]
)
class JRoFormerModel(RoFormerModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.embeddings = JRoFormerEmbeddings(config)
if config.embedding_size != config.hidden_size:
self.embeddings_project = nn.Linear(
config.embedding_size, config.hidden_size
)
self.encoder = JRoFormerEncoder(config)
# Initialize weights and apply final processing
self.post_init()
class AsmEncoder(RoFormerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.jroformer = JRoFormerModel(config)
self.projection = nn.Linear(config.hidden_size, config.hidden_size)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.jroformer(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
token_embeddings = outputs[0]
        # Mean-pool the token states over the attention mask, then project and L2-normalize.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
asm_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
asm_embedding = self.projection(asm_embedding)
asm_embedding = F.normalize(asm_embedding, p=2, dim=1)
return asm_embedding
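
# --- Illustrative usage sketch (not part of the original file) -----------------------
# AsmEncoder mean-pools the JRoFormer token states under the attention mask, applies a
# linear projection, and L2-normalizes, so forward() returns one unit-length vector per
# function. The checkpoint id is an assumption used only for illustration.
def _example_asm_embedding():
    tokenizer = AsmTokenizer.from_pretrained("hustcw/clap-asm")  # assumed checkpoint id
    encoder = AsmEncoder.from_pretrained("hustcw/clap-asm")      # assumed checkpoint id
    batch = tokenizer([{"1": "mov eax, 1", "2": "ret"}], padding=True, return_tensors="pt")
    with torch.no_grad():
        asm_embedding = encoder(**batch)  # shape: (batch_size, hidden_size), unit norm
    return asm_embedding
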
class TextEncoder(MPNetModel):
def __init__(self, config, add_pooling_layer=True):
super().__init__(config, add_pooling_layer=add_pooling_layer)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
):
output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
token_embeddings = output[0]
        # Mean-pool the token states over the attention mask, then L2-normalize.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
text_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
text_embedding = F.normalize(text_embedding, p=2, dim=1)
return text_embedding
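
# --- Illustrative usage sketch (not part of the original file) -----------------------
# TextEncoder mean-pools MPNet token states and L2-normalizes them, so assembly/text
# similarity reduces to a dot product between the two unit vectors. The text-side
# checkpoint id and the use of AutoTokenizer are assumptions used only for illustration.
def _example_asm_text_similarity():
    from transformers import AutoTokenizer  # assumed tokenizer class for the text side

    asm_tokenizer = AsmTokenizer.from_pretrained("hustcw/clap-asm")     # assumed id
    asm_encoder = AsmEncoder.from_pretrained("hustcw/clap-asm")         # assumed id
    text_tokenizer = AutoTokenizer.from_pretrained("hustcw/clap-text")  # assumed id
    text_encoder = TextEncoder.from_pretrained("hustcw/clap-text")      # assumed id

    asm_batch = asm_tokenizer(
        [{"1": "xor eax, eax", "2": "ret"}], padding=True, return_tensors="pt"
    )
    text_batch = text_tokenizer(
        ["return zero", "compute a checksum"], padding=True, return_tensors="pt"
    )

    with torch.no_grad():
        asm_embedding = asm_encoder(**asm_batch)     # (1, hidden_size)
        text_embedding = text_encoder(**text_batch)  # (2, hidden_size)

    # Both embeddings are already L2-normalized, so the dot product is cosine similarity.
    return asm_embedding @ text_embedding.T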