sandspeare committed on
Commit
39a3276
•
1 Parent(s): 18a319d
encoder/config.json → config.json RENAMED
File without changes
optrans_modeling.py ADDED
@@ -0,0 +1,138 @@
+ import numpy as np
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from typing import Optional, Tuple
+ import torch.nn.functional as F
+ from transformers import BatchEncoding
+ from transformers import MPNetTokenizerFast
+
+
+ from transformers.models.roformer.modeling_roformer import (
+     RoFormerEmbeddings,
+     RoFormerModel,
+     RoFormerEncoder,
+     RoFormerLayer,
+     RoFormerAttention,
+     RoFormerIntermediate,
+     RoFormerOutput,
+     RoFormerSelfAttention,
+     RoFormerPreTrainedModel
+ )
+
+ from transformers.models.mpnet.modeling_mpnet import MPNetModel
+
+
+ class JRoFormerEmbeddings(RoFormerEmbeddings):
+     """Construct the embeddings from word and token_type embeddings."""
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.word_embeddings = nn.Embedding(
+             config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
+         )
+         self.token_type_embeddings = self.word_embeddings
+
+
+ class JRoFormerSelfAttention(RoFormerSelfAttention):
+     def __init__(self, config):
+         super().__init__(config)
+         self.query = nn.Linear(
+             config.hidden_size, self.all_head_size, bias=config.use_bias
+         )
+         self.key = nn.Linear(
+             config.hidden_size, self.all_head_size, bias=config.use_bias
+         )
+         self.value = nn.Linear(
+             config.hidden_size, self.all_head_size, bias=config.use_bias
+         )
+
+
+ class JRoFormerAttention(RoFormerAttention):
+     def __init__(self, config):
+         super().__init__(config)
+         self.self = JRoFormerSelfAttention(config)
+
+
+ class JRoFormerLayer(RoFormerLayer):
+     def __init__(self, config):
+         super().__init__(config)
+         self.attention = JRoFormerAttention(config)
+         self.is_decoder = config.is_decoder
+         self.add_cross_attention = config.add_cross_attention
+         if self.add_cross_attention:
+             if not self.is_decoder:
+                 raise ValueError(
+                     f"{self} should be used as a decoder model if cross attention is added"
+                 )
+             self.crossattention = RoFormerAttention(config)
+         self.intermediate = RoFormerIntermediate(config)
+         self.output = RoFormerOutput(config)
+
+
+ class JRoFormerEncoder(RoFormerEncoder):
+     def __init__(self, config):
+         super().__init__(config)
+         self.layer = nn.ModuleList(
+             [JRoFormerLayer(config) for _ in range(config.num_hidden_layers)]
+         )
+
+
+ class JRoFormerModel(RoFormerModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.embeddings = JRoFormerEmbeddings(config)
+
+         if config.embedding_size != config.hidden_size:
+             self.embeddings_project = nn.Linear(
+                 config.embedding_size, config.hidden_size
+             )
+
+         self.encoder = JRoFormerEncoder(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+ class AsmEncoder(RoFormerPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.roformer = JRoFormerModel(config)
+         self.projection = nn.Linear(config.hidden_size, config.bla_dim)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.roformer(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         token_embeddings = outputs[0]
+         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
+         asm_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+         asm_embedding = self.projection(asm_embedding)
+         asm_embedding = F.normalize(asm_embedding, p=2, dim=1)
+
+         return asm_embedding
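For context: AsmEncoder.forward() mean-pools the RoFormer token embeddings over non-padding positions (using attention_mask), projects the pooled vector to config.bla_dim, and L2-normalizes it. Below is a minimal usage sketch, not part of this commit. It assumes the checkpoint and tokenizer files moved to the repo root in this commit load directly via from_pretrained, that the tokenizer is MPNet-compatible (the module imports MPNetTokenizerFast), that the shipped config defines use_bias and bla_dim, and that the toy assembly strings stand in for whatever instruction serialization the model was actually trained on.

import torch
from transformers import MPNetTokenizerFast
from optrans_modeling import AsmEncoder

# Hypothetical local path: after this commit the repo root holds
# config.json, pytorch_model.bin, and the tokenizer files.
checkpoint = "."

tokenizer = MPNetTokenizerFast.from_pretrained(checkpoint)
model = AsmEncoder.from_pretrained(checkpoint)
model.eval()

# Toy assembly snippets; the real tokenization scheme is an assumption here.
asm = [
    "push rbp ; mov rbp , rsp ; pop rbp ; ret",
    "xor eax , eax ; ret",
]
inputs = tokenizer(asm, return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
    # attention_mask is required: forward() uses it for masked mean pooling.
    emb = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Embeddings come back L2-normalized, so pairwise cosine similarity
# reduces to a plain dot product.
sim = emb @ emb.T
print(emb.shape, sim)

Because forward() already normalizes its output, downstream similarity search needs no extra normalization step.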
encoder/pytorch_model.bin → pytorch_model.bin RENAMED
File without changes
tokenizer/special_tokens_map.json → special_tokens_map.json RENAMED
File without changes
tokenizer/tokenizer.json → tokenizer.json RENAMED
File without changes
tokenizer/tokenizer_config.json → tokenizer_config.json RENAMED
File without changes