Amelie committed
Commit 73d1860 · 0 Parent(s)

Initial commit
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,99 @@
+ ---
+ license: apache-2.0
+ tags:
+ - biology
+ ---
+
+ # LucaOne
+ LucaOne: Generalized Biological Foundation Model with Unified Nucleic Acid and Protein Language.
+
+ GitHub page: https://github.com/LucaOne/LucaOne
+
+ This repo contains the weights (checkpoint=17600000) and core code (modified to fit the Hugging Face API; it may be unstable at this stage) of the LucaOne general-purpose language model (LucaOneGPLM).
+
+ To calculate the embedding of a nucleotide/protein sequence:
+ ```
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+
+ def gene_seq_replace(seq):
+     '''
+     Nucleic acid (gene) replacement: A->1, U/T->2, C->3, G->4, N->5
+     :param seq: raw nucleotide sequence
+     :return: sequence rewritten over the vocabulary {1, 2, 3, 4, 5}
+     '''
+     new_seq = ""
+     for ch in seq:
+         if ch in ["A", "a"]:
+             new_seq += "1"
+         elif ch in ["T", "U", "t", "u"]:
+             new_seq += "2"
+         elif ch in ["C", "c"]:
+             new_seq += "3"
+         elif ch in ["G", "g"]:
+             new_seq += "4"
+         else:  # unknown base
+             new_seq += "5"
+     return new_seq
+
+
+ model = AutoModel.from_pretrained("Yuanfei/LucaOne", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("Yuanfei/LucaOne", trust_remote_code=True)
+
+ # Test input
+ seq = "ATCGCGAGTAGCGAGNNNAGCGAT"
+ seq_type = "gene"  # or "prot"
+
+ if seq_type == "gene":
+     seq = gene_seq_replace(seq)
+
+ print("seq len: %d" % len(seq))
+
+ # Test run
+ seq_encoded = tokenizer.encode(seq)
+ input_ids = torch.tensor(seq_encoded, dtype=torch.int64).unsqueeze(0)
+
+ print("input_ids:")
+ print(input_ids)
+
+ # token type 0 for nucleic acid, 1 for protein
+ if seq_type == "gene":
+     token_type_ids = torch.zeros_like(input_ids)
+ else:
+     token_type_ids = torch.ones_like(input_ids)
+
+ encoding = {
+     "input_ids": input_ids,
+     "token_type_ids": token_type_ids,
+ }
+
+ # protein inputs go through the "_b" branch of the model
+ if seq_type == "prot":
+     new_encoding = {}
+     for key, value in encoding.items():
+         new_encoding[key + "_b"] = value
+     encoding = new_encoding
+
+ batch = encoding
+ batch["return_dict"] = True
+
+ res = model(**batch)
+
+ if seq_type == "prot":
+     embedding = res.hidden_states_b
+ else:
+     embedding = res.hidden_states
+
+ print("embedding matrix (includes [CLS] and [SEP]):")
+ print(embedding)
+ print(embedding.shape)
+
+ print("[CLS] embedding vector:")
+ cls_vec = embedding[0, 0, :]
+ print(cls_vec)
+ print(cls_vec.shape)
+ ```
+
+ If loading the tokenizer fails with "ValueError: Tokenizer class AlphabetTokenizer does not exist or is not currently imported.", run `alphabet.py` first, as sketched below.
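+ A minimal sketch of that workaround (the `hf_hub_download` step and the `runpy` call are assumptions for illustration, not official instructions; whether this is needed may depend on your `transformers` version):
+ ```
+ # Fetch alphabet.py from this repo and execute it before loading the tokenizer,
+ # as the note above suggests.
+ import runpy
+ from huggingface_hub import hf_hub_download
+ from transformers import AutoTokenizer
+
+ alphabet_path = hf_hub_download(repo_id="Yuanfei/LucaOne", filename="alphabet.py")
+ runpy.run_path(alphabet_path)  # defines Alphabet and AlphabetTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("Yuanfei/LucaOne", trust_remote_code=True)
+ ```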
alphabet.py ADDED
@@ -0,0 +1,194 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+
+ import os
+ import json
+ import itertools
+ from typing import Sequence, List
+ from transformers import PreTrainedTokenizer
+
+ gene_standard_toks = ['1', '2', '3', '4', '5', '.', '-', '*']
+
+ prot_standard_toks = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', 'J', '.', '-', '*']
+
+ gene_prot_standard_toks = ['1', '2', '3', '4', '5', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', 'J', '.', '-', '*']
+
+ gene_prot_prepend_toks = ['[PAD]', '[UNK]']
+
+ gene_prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
+
+
+ class Alphabet(object):
+     def __init__(
+             self,
+             standard_toks: Sequence[str] = gene_prot_standard_toks,
+             prepend_toks: Sequence[str] = gene_prot_prepend_toks,
+             append_toks: Sequence[str] = gene_prot_append_toks,
+             prepend_bos: bool = True,
+             append_eos: bool = True
+     ):
+         self.standard_toks = list(standard_toks)
+         self.prepend_toks = list(prepend_toks)
+         self.append_toks = list(append_toks)
+         self.prepend_bos = prepend_bos
+         self.append_eos = append_eos
+
+         # vocabulary order: prepend tokens, append tokens, then standard tokens
+         self.all_toks = list(self.prepend_toks)
+         self.all_toks.extend(self.append_toks)
+         self.all_toks.extend(self.standard_toks)
+
+         self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
+
+         self.unk_idx = self.tok_to_idx["[UNK]"]
+         self.padding_idx = self.get_idx("[PAD]")
+         self.pad_token_id = self.padding_idx
+         self.cls_idx = self.get_idx("[CLS]")
+         self.mask_idx = self.get_idx("[MASK]")
+         self.eos_idx = self.get_idx("[SEP]")
+         self.all_special_tokens = prepend_toks + append_toks
+         self.all_special_token_idx_list = [self.tok_to_idx[v] for v in self.all_special_tokens]
+         self.unique_no_split_tokens = self.all_toks
+         self.vocab_size = self.__len__()
+
+     def get_vocab(self):
+         """Returns the vocabulary as a dictionary, required by the transformers library."""
+         return self.tok_to_idx.copy()
+
+     def __len__(self):
+         return len(self.all_toks)
+
+     def get_idx(self, tok):
+         return self.tok_to_idx.get(tok, self.unk_idx)
+
+     def get_tok(self, ind):
+         return self.all_toks[ind]
+
+     def to_dict(self):
+         return self.tok_to_idx.copy()
+
+     @classmethod
+     def from_predefined(cls, name: str):
+         if name.lower() == "prot":
+             standard_toks = prot_standard_toks
+         elif name.lower() == "gene":
+             standard_toks = gene_standard_toks
+         elif name.lower() in ["gene_prot", "prot_gene"]:
+             standard_toks = gene_prot_standard_toks
+         else:
+             raise Exception("Unsupported tokenizer name: %s" % name)
+
+         prepend_toks = gene_prot_prepend_toks
+         append_toks = gene_prot_append_toks
+         prepend_bos = True
+         append_eos = True
+
+         return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos)
+
+     @classmethod
+     def from_pretrained(cls, dir_path):
+         import os, pickle
+         return pickle.load(open(os.path.join(dir_path, "alphabet.pkl"), "rb"))
+
+     def save_pretrained(self, save_dir):
+         import os, pickle
+         with open(os.path.join(save_dir, "alphabet.pkl"), 'wb') as outp:
+             pickle.dump(self, outp, pickle.HIGHEST_PROTOCOL)
+
+     def _tokenize(self, text) -> List[str]:
+         return text.split()
+
+     def tokenize(self, text, **kwargs) -> List[str]:
+         def split_on_token(tok, text):
+             result = []
+             split_text = text.split(tok)
+             for i, sub_text in enumerate(split_text):
+                 if i < len(split_text) - 1:
+                     sub_text = sub_text.rstrip()
+                 if i > 0:
+                     sub_text = sub_text.lstrip()
+
+                 if i == 0 and not sub_text:
+                     result.append(tok)
+                 elif i == len(split_text) - 1:
+                     if sub_text:
+                         result.append(sub_text)
+                     else:
+                         pass
+                 else:
+                     if sub_text:
+                         result.append(sub_text)
+                     result.append(tok)
+             return result
+
+         def split_on_tokens(tok_list, text):
+             if not text.strip():
+                 return []
+             tokenized_text = []
+             text_list = [text]
+             for tok in tok_list:
+                 tokenized_text = []
+                 for sub_text in text_list:
+                     if sub_text not in self.unique_no_split_tokens:
+                         tokenized_text.extend(split_on_token(tok, sub_text))
+                     else:
+                         tokenized_text.append(sub_text)
+                 text_list = tokenized_text
+
+             return list(
+                 itertools.chain.from_iterable(
+                     (
+                         self._tokenize(token)
+                         if token not in self.unique_no_split_tokens
+                         else [token]
+                         for token in tokenized_text
+                     )
+                 )
+             )
+
+         no_split_token = self.unique_no_split_tokens
+         tokenized_text = split_on_tokens(no_split_token, text)
+         return tokenized_text
+
+     def encode(self, text):
+         return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
+
+
+ class AlphabetTokenizer(PreTrainedTokenizer):
+     def __init__(
+             self,
+             alphabet: Alphabet = Alphabet(),
+             **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.alphabet = alphabet
+         self.pad_token = '[PAD]'
+         self.cls_token = '[CLS]'
+         self.sep_token = '[SEP]'
+         self.mask_token = '[MASK]'
+         self.unk_token = '[UNK]'
+
+     def _tokenize(self, text: str):
+         # Use the Alphabet class's tokenize method
+         return self.alphabet.tokenize(text)
+
+     def convert_tokens_to_ids(self, tokens):
+         # Use the Alphabet class's get_idx method
+         return [self.alphabet.get_idx(token) for token in tokens]
+
+     def convert_ids_to_tokens(self, ids):
+         # Use the Alphabet class's get_tok method
+         return [self.alphabet.get_tok(index) for index in ids]
+
+     def save_vocabulary(self, save_directory, filename_prefix=None):
+         # Save the tokenizer vocabulary, required by Hugging Face
+         vocab_file = os.path.join(save_directory, (filename_prefix or "") + "vocab.json")
+         with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+             json.dump(self.alphabet.to_dict(), vocab_writer, ensure_ascii=False)
+         return (vocab_file,)
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         # Add special tokens to input ids, if required
+         cls_token = [self.alphabet.cls_idx]
+         sep_token = [self.alphabet.eos_idx]
+         if token_ids_1:
+             return cls_token + token_ids_0 + sep_token + token_ids_1 + sep_token
+         return cls_token + token_ids_0 + sep_token
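A quick, hypothetical sanity check of the `Alphabet` vocabulary defined above (it assumes `alphabet.py` has been downloaded and is importable from the working directory; it is not one of the committed files):
```
from alphabet import Alphabet

alphabet = Alphabet()              # default gene + protein vocabulary
print(len(alphabet))               # 2 prepend + 3 append + 34 standard tokens = 39
print(alphabet.tokenize("1234"))   # ['1', '2', '3', '4'] (gene tokens after gene_seq_replace)
print(alphabet.encode("MKT"))      # indices of the protein residues M, K, T
```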
alphabet_atom.py ADDED
@@ -0,0 +1,132 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+
+ from rdkit import Chem
+ from rdkit.Chem import AllChem
+ from typing import Sequence, List
+
+ atom_standard_toks = ['C', 'N', 'O', 'S', 'H', 'Cl', 'F', 'Br', 'I',
+                       'Si', 'P', 'B', 'Na', 'K', 'Al', 'Ca', 'Sn', 'As',
+                       'Hg', 'Fe', 'Zn', 'Cr', 'Se', 'Gd', 'Au', 'Li'
+                       ]
+
+ atom_prepend_toks = ['[PAD]', '[UNK]', '[CLS]']
+
+ atom_append_toks = ['[SEP]', '[MASK]']
+
+
+ class AlphabetAtom(object):
+     def __init__(
+             self,
+             standard_toks: Sequence[str] = atom_standard_toks,
+             prepend_toks: Sequence[str] = atom_prepend_toks,
+             append_toks: Sequence[str] = atom_append_toks,
+             prepend_bos: bool = True,
+             append_eos: bool = True
+     ):
+         self.standard_toks = list(standard_toks)
+         self.prepend_toks = list(prepend_toks)
+         self.append_toks = list(append_toks)
+         self.prepend_bos = prepend_bos
+         self.append_eos = append_eos
+
+         # vocabulary order: prepend tokens, append tokens, then atom symbols
+         self.all_toks = list(self.prepend_toks)
+         self.all_toks.extend(self.append_toks)
+         self.all_toks.extend(self.standard_toks)
+
+         self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
+
+         self.unk_idx = self.tok_to_idx["[UNK]"]
+         self.padding_idx = self.get_idx("[PAD]")
+         self.pad_idx = self.get_idx("[PAD]")
+         self.pad_token_id = self.padding_idx
+         self.cls_idx = self.get_idx("[CLS]")
+         self.mask_idx = self.get_idx("[MASK]")
+         self.eos_idx = self.get_idx("[SEP]")
+         self.all_special_tokens = prepend_toks + append_toks
+         self.all_special_token_idx_list = [self.tok_to_idx[v] for v in self.all_special_tokens]
+         self.unique_no_split_tokens = self.all_toks
+         self.vocab_size = self.__len__()
+
+     def __len__(self):
+         return len(self.all_toks)
+
+     def get_idx(self, tok):
+         return self.tok_to_idx.get(tok, self.unk_idx)
+
+     def get_tok(self, ind):
+         return self.all_toks[ind]
+
+     def to_dict(self):
+         return self.tok_to_idx.copy()
+
+     def get_batch_converter(self, task_level_type, label_size, output_mode, no_position_embeddings,
+                             no_token_type_embeddings, truncation_seq_length: int = None, ignore_index: int = -100, mlm_probability=0.15):
+         '''
+         return BatchConverter(
+             task_level_type,
+             label_size,
+             output_mode,
+             seq_subword=False,
+             seq_tokenizer=self,
+             no_position_embeddings=no_position_embeddings,
+             no_token_type_embeddings=no_token_type_embeddings,
+             truncation_seq_length=truncation_seq_length,
+             truncation_matrix_length=truncation_seq_length,
+             ignore_index=ignore_index,
+             mlm_probability=mlm_probability,
+             prepend_bos=self.prepend_bos,
+             append_eos=self.append_eos)
+         '''
+         pass
+
+     @classmethod
+     def smiles_2_atom_seq(cls, smi):
+         mol = Chem.MolFromSmiles(smi)
+         mol = AllChem.AddHs(mol)
+         atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]  # after adding explicit hydrogens
+         return atoms
+
+     @classmethod
+     def from_predefined(cls, name: str = "atom_v1"):
+         if name.lower() == "atom_v1":
+             standard_toks = atom_standard_toks
+         else:
+             raise Exception("Unsupported tokenizer name: %s" % name)
+
+         prepend_toks = atom_prepend_toks
+         append_toks = atom_append_toks
+         prepend_bos = True
+         append_eos = True
+
+         return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos)
+
+     @classmethod
+     def from_pretrained(cls, dir_path):
+         import os, pickle
+         return pickle.load(open(os.path.join(dir_path, "alphabet_atom.pkl"), "rb"))
+
+     def save_pretrained(self, save_dir):
+         import os, pickle
+         with open(os.path.join(save_dir, "alphabet_atom.pkl"), 'wb') as outp:
+             pickle.dump(self, outp, pickle.HIGHEST_PROTOCOL)
+
+     def tokenize(self, smi, prepend_bos, append_eos) -> List[str]:
+         seq = AlphabetAtom.smiles_2_atom_seq(smi)
+         if prepend_bos:
+             seq = [self.get_tok(self.cls_idx)] + seq
+         if append_eos:
+             seq = seq + [self.get_tok(self.eos_idx)]
+         return seq
+
+     def encode(self, atom_list, prepend_bos, append_eos):
+         idx_list = [self.get_idx(tok) for tok in atom_list]
+         if prepend_bos:
+             idx_list = [self.cls_idx] + idx_list
+         if append_eos:
+             idx_list = idx_list + [self.eos_idx]
+         return idx_list
+
+     def encode_smi(self, smi, prepend_bos, append_eos):
+         atom_list = self.smiles_2_atom_seq(smi)
+         return self.encode(atom_list, prepend_bos, append_eos)
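A short, hypothetical usage example of the atom-level alphabet above (assumes `alphabet_atom.py` is importable from the working directory and RDKit is installed; it is not one of the committed files):
```
from alphabet_atom import AlphabetAtom

atom_alphabet = AlphabetAtom.from_predefined("atom_v1")
atoms = AlphabetAtom.smiles_2_atom_seq("CCO")   # ethanol, with explicit hydrogens added
print(atoms)                                    # e.g. ['C', 'C', 'O', 'H', 'H', ...]
print(atom_alphabet.encode_smi("CCO", prepend_bos=True, append_eos=True))
```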
batch_converter.py ADDED
@@ -0,0 +1,1365 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ import sys
5
+ import torch
6
+ from typing import Sequence
7
+
8
+ from .alphabet_atom import AlphabetAtom
9
+ from .utils import gene_seq_replace
10
+
11
+ class BatchConverter(object):
12
+
13
+ def __init__(self,
14
+ task_level_type,
15
+ label_size,
16
+ output_mode,
17
+ seq_subword,
18
+ seq_tokenizer,
19
+ no_position_embeddings,
20
+ no_token_type_embeddings,
21
+ truncation_seq_length: int = None,
22
+ truncation_matrix_length: int = None,
23
+ atom_tokenizer: AlphabetAtom = None,
24
+ atom_truncation_seq_length: int = None,
25
+ atom_truncation_matrix_length: int = None,
26
+ ignore_index: int = -100,
27
+ padding_idx: int = 0,
28
+ unk_idx: int = 1,
29
+ cls_idx: int = 2,
30
+ eos_idx: int = 3,
31
+ mask_idx: int = 4,
32
+ non_ignore: bool = False,
33
+ mlm_probability=0.15,
34
+ prepend_bos=None,
35
+ append_eos=None,
36
+ **kwargs):
37
+ print("------BatchConverter------")
38
+ print("BatchConverter, kwargs:")
39
+ print(kwargs)
40
+ self.task_level_type = task_level_type
41
+ self.label_size = label_size
42
+ self.output_mode = output_mode
43
+ self.seq_tokenizer = seq_tokenizer
44
+ self.seq_subword = seq_subword
45
+ self.ignore_index = ignore_index
46
+ self.non_ignore = non_ignore
47
+ self.mlm_probability = mlm_probability
48
+ self.truncation_seq_length = truncation_seq_length
49
+ self.truncation_matrix_length = truncation_matrix_length
50
+
51
+ # with a subword tokenizer, both special tokens are always included
52
+ if prepend_bos is None:
53
+ if seq_subword is not None:
54
+ self.prepend_bos = True
55
+ else:
56
+ self.prepend_bos = False
57
+ else:
58
+ self.prepend_bos = prepend_bos
59
+ if append_eos is None:
60
+ if seq_subword is not None:
61
+ self.append_eos = True
62
+ else:
63
+ self.append_eos = False
64
+ else:
65
+ self.append_eos = append_eos
66
+
67
+ self.padding_idx = padding_idx
68
+ self.unk_idx = unk_idx
69
+ self.cls_idx = cls_idx
70
+ self.eos_idx = eos_idx
71
+ self.mask_idx = mask_idx
72
+ if self.seq_tokenizer is None:
73
+ self.append_len = 0
74
+ else:
75
+ if hasattr(seq_tokenizer, "prepend_bos"):
76
+ self.prepend_bos = self.seq_tokenizer.prepend_bos
77
+ if hasattr(seq_tokenizer, "append_eos"):
78
+ self.append_eos = self.seq_tokenizer.append_eos
79
+ if hasattr(seq_tokenizer, "padding_idx"):
80
+ self.padding_idx = self.seq_tokenizer.padding_idx
81
+ if hasattr(seq_tokenizer, "unk_idx"):
82
+ self.unk_idx = self.seq_tokenizer.unk_idx
83
+ if hasattr(seq_tokenizer, "cls_idx"):
84
+ self.cls_idx = self.seq_tokenizer.cls_idx
85
+ if hasattr(seq_tokenizer, "eos_idx"):
86
+ self.eos_idx = self.seq_tokenizer.eos_idx
87
+ if hasattr(seq_tokenizer, "mask_idx"):
88
+ self.mask_idx = self.seq_tokenizer.mask_idx
89
+ if hasattr(seq_tokenizer, "all_special_token_idx_list"):
90
+ self.all_special_token_idx_list = self.seq_tokenizer.all_special_token_idx_list
91
+ else:
92
+ self.all_special_token_idx_list = [self.padding_idx, self.unk_idx, self.cls_idx, self.eos_idx, self.mask_idx]
93
+ self.append_len = int(self.prepend_bos) + int(self.append_eos)
94
+
95
+ # for atom
96
+ self.atom_tokenizer = atom_tokenizer
97
+ self.atom_truncation_seq_length = atom_truncation_seq_length
98
+ self.atom_truncation_matrix_length = atom_truncation_matrix_length
99
+ self.atom_prepend_bos = False
100
+ self.atom_append_eos = False
101
+ self.atom_padding_idx = padding_idx
102
+ self.atom_unk_idx = unk_idx
103
+ self.atom_cls_idx = cls_idx
104
+ self.atom_eos_idx = eos_idx
105
+ self.atom_mask_idx = mask_idx
106
+ if self.atom_tokenizer is None:
107
+ self.atom_append_len = 0
108
+ else:
109
+ if hasattr(seq_tokenizer, "padding_idx"):
110
+ self.padding_idx = self.seq_tokenizer.padding_idx
111
+ elif hasattr(seq_tokenizer, "pad_idx"):
112
+ self.padding_idx = self.seq_tokenizer.pad_idx
113
+ elif hasattr(seq_tokenizer, "pad_token_id"):
114
+ self.padding_idx = self.seq_tokenizer.pad_token_id
115
+
116
+ if hasattr(seq_tokenizer, "unk_idx"):
117
+ self.unk_idx = self.seq_tokenizer.unk_idx
118
+ elif hasattr(seq_tokenizer, "unk_token_id"):
119
+ self.unk_idx = self.seq_tokenizer.unk_token_id
120
+
121
+ if hasattr(seq_tokenizer, "cls_idx"):
122
+ self.cls_idx = self.seq_tokenizer.cls_idx
123
+ elif hasattr(seq_tokenizer, "cls_token_id"):
124
+ self.cls_idx = self.seq_tokenizer.cls_token_id
125
+ elif hasattr(seq_tokenizer, "bos_idx"):
126
+ self.cls_idx = self.seq_tokenizer.bos_idx
127
+ elif hasattr(seq_tokenizer, "bos_token_id"):
128
+ self.cls_idx = self.seq_tokenizer.bos_token_id
129
+
130
+ if hasattr(seq_tokenizer, "eos_idx"):
131
+ self.eos_idx = self.seq_tokenizer.eos_idx
132
+ elif hasattr(seq_tokenizer, "eos_token_id"):
133
+ self.eos_idx = self.seq_tokenizer.eos_token_id
134
+ elif hasattr(seq_tokenizer, "sep_token_id"):
135
+ self.eos_idx = self.seq_tokenizer.sep_token_id
136
+
137
+ if hasattr(seq_tokenizer, "mask_idx"):
138
+ self.mask_idx = self.seq_tokenizer.mask_idx
139
+ elif hasattr(seq_tokenizer, "mask_token_id"):
140
+ self.mask_idx = self.seq_tokenizer.mask_token_id
141
+ if hasattr(atom_tokenizer, "all_special_token_idx_list"):
142
+ self.atom_all_special_token_idx_list = self.atom_tokenizer.all_special_token_idx_list
143
+ else:
144
+ self.atom_all_special_token_idx_list = [self.padding_idx, self.unk_idx, self.cls_idx, self.eos_idx, self.mask_idx]
145
+ self.atom_append_len = int(self.atom_prepend_bos) + int(self.atom_append_eos)
146
+
147
+ print("BatchConverter: prepend_bos=%r, append_eos=%r" % (self.prepend_bos, self.append_eos))
148
+ print("BatchConverter: atom_prepend_bos=%r, atom_append_eos=%r" % (self.atom_prepend_bos, self.atom_append_eos))
149
+ self.matrix_add_special_token = False
150
+ if "matrix_add_special_token" in kwargs and kwargs["matrix_add_special_token"]:
151
+ self.matrix_add_special_token = kwargs["matrix_add_special_token"]
152
+ if self.matrix_add_special_token:
153
+ self.prepend_bos = True
154
+ self.append_eos = True
155
+ self.atom_prepend_bos = True
156
+ self.atom_append_eos = True
157
+ self.append_len = int(self.prepend_bos) + int(self.append_eos)
158
+ self.atom_append_len = int(self.atom_prepend_bos) + int(self.atom_append_eos)
159
+
160
+ # length after subtracting the special tokens
161
+ self.truncation_seq_length -= self.append_len
162
+ self.truncation_matrix_length -= self.append_len
163
+ # length after subtracting the special tokens
164
+ if self.atom_truncation_seq_length:
165
+ self.atom_truncation_seq_length -= self.atom_append_len
166
+ if self.atom_truncation_matrix_length:
167
+ self.atom_truncation_matrix_length -= self.atom_append_len
168
+
169
+ self.input_type = None
170
+ if "input_type" in kwargs and kwargs["input_type"]:
171
+ self.input_type = kwargs["input_type"]
172
+
173
+ if "max_sentence_length" in kwargs and kwargs["max_sentence_length"]:
174
+ self.max_sentence_length = kwargs["max_sentence_length"] - self.append_len
175
+ print("BatchConverter: self.max_sentence_length=%d" % self.max_sentence_length)
176
+ if atom_tokenizer is not None:
177
+ self.atom_max_sentence_length = kwargs["max_sentence_length"] - self.atom_append_len
178
+ print("BatchConverter: self.atom_max_sentence_length=%d" % self.atom_max_sentence_length)
179
+ if "max_sentences" in kwargs and kwargs["max_sentences"]:
180
+ self.max_sentences = kwargs["max_sentences"]
181
+ print("BatchConverter: self.max_sentences=%d" % self.max_sentences)
182
+ self.trunc_type = "right"
183
+ if "trunc_type" in kwargs and kwargs["trunc_type"]:
184
+ self.trunc_type = kwargs["trunc_type"]
185
+ print("BatchConverter: self.trunc_type=%s" % self.trunc_type)
186
+
187
+ self.no_position_embeddings = no_position_embeddings
188
+ self.no_token_type_embeddings = no_token_type_embeddings
189
+ print("BatchConverter: prepend_bos=%r, append_eos=%r" % (self.prepend_bos, self.append_eos))
190
+ print("BatchConverter: atom_prepend_bos=%r, atom_append_eos=%r" % (self.atom_prepend_bos, self.atom_append_eos))
191
+ print("-" * 50)
192
+
193
+ def __parse_label__(self, max_length, task_level_type, label_size, output_mode, label):
194
+ if isinstance(label, str):
195
+ label = eval(label)
196
+ '''
197
+ print("label:")
198
+ print(label)
199
+ '''
200
+ # this needs to be the padded length
201
+ cur_len = max_length
202
+ if task_level_type in ["token_level", "structure_level"]:
203
+ if output_mode in ["multi_label", "multi-label"]:
204
+ # N * seq_len * label_size
205
+ new_label = []
206
+ for _ in range(cur_len):
207
+ tmp = []
208
+ for _ in range(label_size):
209
+ tmp.append(0 if self.non_ignore else self.ignore_index)
210
+ new_label.append(tmp)
211
+ else:
212
+ # N * seq_len
213
+ new_label = []
214
+ for _ in range(cur_len):
215
+ new_label.append(0 if self.non_ignore else self.ignore_index)
216
+ if label is not None and len(label) > 0:
217
+ begin_idx = 0
218
+ end_idx = cur_len
219
+ if self.prepend_bos:
220
+ begin_idx = 1
221
+ if self.append_eos:
222
+ end_idx = cur_len - 1
223
+ for idx, item in enumerate(label):
224
+ idx += begin_idx
225
+ if idx >= end_idx:
226
+ break
227
+ if output_mode in ["multi_label", "multi-label"]:
228
+ for v in item:
229
+ new_label[idx][v] = 1
230
+ else:
231
+ new_label[idx] = item
232
+ elif task_level_type == "span_level":
233
+ if output_mode in ["multi_label", "multi-label"]:
234
+ # N * seq_len * label_size
235
+ new_label = []
236
+ for _ in range(cur_len):
237
+ tmp = []
238
+ for _ in range(label_size):
239
+ tmp.append(0 if self.non_ignore else self.ignore_index)
240
+ new_label.append(tmp)
241
+ else:
242
+ # N * seq_len
243
+ new_label = []
244
+ for _ in range(cur_len):
245
+ new_label.append(0 if self.non_ignore else self.ignore_index)
246
+ if label is not None and len(label) > 0:
247
+ begin_idx = 0
248
+ end_idx = cur_len
249
+ if self.prepend_bos:
250
+ begin_idx = 1
251
+ if self.append_eos:
252
+ end_idx = cur_len - 1
253
+ for item in label:
254
+ for idx in range(item[0], item[1] + 1, 1):
255
+ idx += begin_idx
256
+ if idx >= end_idx:
257
+ break
258
+ if output_mode in ["multi_label", "multi-label"]:
259
+ new_label[idx][item[2]] = 1
260
+ else:
261
+ new_label[idx] = item[2]
262
+ elif task_level_type in ["seq_level"]:
263
+ if output_mode in ["multi_label", "multi-label"]:
264
+ # N * label_size
265
+ new_label = []
266
+ for _ in range(label_size):
267
+ new_label.append(0 if self.non_ignore else self.ignore_index)
268
+ else:
269
+ # N * 1
270
+ new_label = [0 if self.non_ignore else self.ignore_index]
271
+ if output_mode in ["multi_label", "multi-label"]:
272
+ if label is not None and len(label) > 0:
273
+ for v in label:
274
+ new_label[int(v)] = 1
275
+ else:
276
+ if label is not None and len(str(label)) > 0:
277
+ if isinstance(label, str):
278
+ new_label = [int(label)]
279
+ elif isinstance(label, list):
280
+ new_label = [int(label[0])]
281
+ else:
282
+ new_label = [label]
283
+ else:
284
+ raise Exception("Not support task_level_type=%s" % task_level_type)
285
+ return new_label
286
+
287
+ def __atom_parse_label__(self, max_length, task_level_type, label_size, output_mode, label):
288
+ if isinstance(label, str):
289
+ label = eval(label)
290
+ '''
291
+ print("label:")
292
+ print(label)
293
+ '''
294
+ # this needs to be the padded length
295
+ cur_len = max_length
296
+ if task_level_type in ["token_level", "structure_level"]:
297
+ if output_mode in ["multi_label", "multi-label"]:
298
+ # N * seq_len * label_size
299
+ new_label = []
300
+ for _ in range(cur_len):
301
+ tmp = []
302
+ for _ in range(label_size):
303
+ tmp.append(0 if self.non_ignore else self.ignore_index)
304
+ new_label.append(tmp)
305
+ else:
306
+ # N * seq_len
307
+ new_label = []
308
+ for _ in range(cur_len):
309
+ new_label.append(0 if self.non_ignore else self.ignore_index)
310
+ if label is not None and len(label) > 0:
311
+ begin_idx = 0
312
+ end_idx = cur_len
313
+ if self.atom_prepend_bos:
314
+ begin_idx = 1
315
+ if self.atom_append_eos:
316
+ end_idx = cur_len - 1
317
+ for idx, item in enumerate(label):
318
+ idx += begin_idx
319
+ if idx >= end_idx:
320
+ break
321
+ if output_mode in ["multi_label", "multi-label"]:
322
+ for v in item:
323
+ new_label[idx][v] = 1
324
+ else:
325
+ new_label[idx] = item
326
+ elif task_level_type == "span_level":
327
+ if output_mode in ["multi_label", "multi-label"]:
328
+ # N * seq_len * label_size
329
+ new_label = []
330
+ for _ in range(cur_len):
331
+ tmp = []
332
+ for _ in range(label_size):
333
+ tmp.append(0 if self.non_ignore else self.ignore_index)
334
+ new_label.append(tmp)
335
+ else:
336
+ # N * seq_len
337
+ new_label = []
338
+ for _ in range(cur_len):
339
+ new_label.append(0 if self.non_ignore else self.ignore_index)
340
+ if label is not None and len(label) > 0:
341
+ begin_idx = 0
342
+ end_idx = cur_len
343
+ if self.atom_prepend_bos:
344
+ begin_idx = 1
345
+ if self.atom_append_eos:
346
+ end_idx = cur_len - 1
347
+ for item in label:
348
+ for idx in range(item[0], item[1] + 1, 1):
349
+ idx += begin_idx
350
+ if idx >= end_idx:
351
+ break
352
+ if output_mode in ["multi_label", "multi-label"]:
353
+ new_label[idx][item[2]] = 1
354
+ else:
355
+ new_label[idx] = item[2]
356
+ elif task_level_type in ["seq_level"]:
357
+ if output_mode in ["multi_label", "multi-label"]:
358
+ # N * label_size
359
+ new_label = []
360
+ for _ in range(label_size):
361
+ new_label.append(0 if self.non_ignore else self.ignore_index)
362
+ else:
363
+ # N * 1
364
+ new_label = [0 if self.non_ignore else self.ignore_index]
365
+ if output_mode in ["multi_label", "multi-label"]:
366
+ if label is not None and len(label) > 0:
367
+ for v in label:
368
+ new_label[int(v)] = 1
369
+ else:
370
+ if label is not None and len(str(label)) > 0:
371
+ if isinstance(label, str):
372
+ new_label = [int(label)]
373
+ elif isinstance(label, list):
374
+ new_label = [int(label[0])]
375
+ else:
376
+ new_label = [label]
377
+ else:
378
+ raise Exception("Not support task_level_type=%s" % task_level_type)
379
+
380
+ return new_label
381
+
382
+ def __mask_tokens__(self, input_ids):
383
+ labels = input_ids.clone()
384
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
385
+
386
+ # 1 at special-token positions
387
+ special_tokens_mask = [
388
+ 1 if v in self.all_special_token_idx_list else 0 for v in labels.tolist()
389
+ ]
390
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
391
+ # set special-token positions to 0.0
392
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
393
+
394
+ # positions of non-special tokens
395
+ masked_indices = torch.bernoulli(probability_matrix).bool()
396
+ # special-token positions get -100
397
+ labels[~masked_indices] = self.ignore_index # We only compute loss on masked tokens
398
+
399
+ # 80% of the time, we replace masked input tokens with alphabet.mask_token ([MASK])
400
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
401
+ input_ids[indices_replaced] = self.mask_idx
402
+
403
+ # 10% of the time, we replace masked input tokens with random word
404
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
405
+ random_words = torch.randint(len(self.seq_tokenizer), labels.shape, dtype=torch.long)
406
+ input_ids[indices_random] = random_words[indices_random]
407
+
408
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
409
+ return input_ids, labels
410
+
411
+ def __atom_mask_tokens__(self, input_ids):
412
+ labels = input_ids.clone()
413
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
414
+
415
+ # 1 at special-token positions
416
+ special_tokens_mask = [
417
+ 1 if v in self.atom_all_special_token_idx_list else 0 for v in labels.tolist()
418
+ ]
419
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
420
+ # set special-token positions to 0.0
421
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
422
+
423
+ # positions of non-special tokens
424
+ masked_indices = torch.bernoulli(probability_matrix).bool()
425
+ # special-token positions get -100
426
+ labels[~masked_indices] = self.ignore_index # We only compute loss on masked tokens
427
+
428
+ # 80% of the time, we replace masked input tokens with alphabet.mask_token ([MASK])
429
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
430
+ input_ids[indices_replaced] = self.atom_mask_idx
431
+
432
+ # 10% of the time, we replace masked input tokens with random word
433
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
434
+ random_words = torch.randint(len(self.atom_tokenizer), labels.shape, dtype=torch.long)
435
+ input_ids[indices_random] = random_words[indices_random]
436
+
437
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
438
+ return input_ids, labels
439
+
440
+ def __seq_encode__(self, batch_size, seqs):
441
+ '''
442
+ This function does not add the special tokens [CLS] and [SEP]
443
+ :param batch_size:
444
+ :param seqs:
445
+ :return:
446
+ '''
447
+ if self.seq_subword:
448
+ seq_encoded_list = []
449
+ for seq_str in seqs:
450
+ seq_to_list = self.seq_subword.process_line(seq_str.upper()).split(" ")
451
+ seq = " ".join(seq_to_list)
452
+ inputs = self.seq_tokenizer.encode_plus(
453
+ seq,
454
+ None,
455
+ add_special_tokens=False,
456
+ max_length=self.truncation_seq_length,
457
+ truncation=True
458
+ )
459
+ seq_encoded_list.append(inputs["input_ids"])
460
+ else:
461
+ seq_encoded_list = [self.seq_tokenizer.encode(seq_str.upper()) for seq_str in seqs]
462
+ # this length already excludes the special tokens that will be added
463
+ if self.truncation_seq_length:
464
+ seq_encoded_list = [encoded[:self.truncation_seq_length] for encoded in seq_encoded_list]
465
+ max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
466
+ max_len = max_len + int(self.prepend_bos) + int(self.append_eos)
467
+ # for input
468
+ input_ids = torch.empty(
469
+ (
470
+ batch_size,
471
+ max_len,
472
+ ),
473
+ dtype=torch.int64,
474
+ )
475
+ input_ids.fill_(self.padding_idx)
476
+
477
+ position_ids = None
478
+ if not self.no_position_embeddings:
479
+ position_ids = torch.empty(
480
+ (
481
+ batch_size,
482
+ max_len,
483
+ ),
484
+ dtype=torch.int64,
485
+ )
486
+ position_ids.fill_(self.padding_idx)
487
+
488
+ token_type_ids = None
489
+ if not self.no_position_embeddings:
490
+ token_type_ids = torch.empty(
491
+ (
492
+ batch_size,
493
+ max_len,
494
+ ),
495
+ dtype=torch.int64,
496
+ )
497
+ token_type_ids.fill_(self.padding_idx)
498
+ attention_masks = torch.empty(
499
+ (
500
+ batch_size,
501
+ max_len,
502
+ ),
503
+ dtype=torch.int64,
504
+ )
505
+ attention_masks.fill_(0)
506
+
507
+ return seq_encoded_list, input_ids, position_ids, token_type_ids, attention_masks, max_len
508
+
509
+ def __multi_seq_encode__(self, batch_size, seqs):
510
+ '''
511
+ This function encodes multiple sentences; [CLS] and [SEP] are added to every sentence
512
+ :param batch_size:
513
+ :param seqs:
514
+ :return:
515
+ '''
516
+ assert hasattr(self, "max_sentences") and hasattr(self, "max_sentence_length")
517
+ max_sentence_len = 0
518
+ max_sentence_num = 0
519
+ if self.seq_subword:
520
+ seq_encoded_list = []
521
+ for cur_sample_seqs in seqs:
522
+ cur_seq_encoded_list = []
523
+ if len(cur_sample_seqs) > self.max_sentences:
524
+ # each sample keeps at most max_sentences sequences
525
+ if self.trunc_type == "left":
526
+ cur_sample_seqs = cur_sample_seqs[-self.max_sentences:]
527
+ else:
528
+ cur_sample_seqs = cur_sample_seqs[:self.max_sentences]
529
+ if max_sentence_num < len(cur_sample_seqs):
530
+ max_sentence_num = len(cur_sample_seqs)
531
+ for seq_idx, seq_str in enumerate(cur_sample_seqs):
532
+ seq_to_list = self.seq_subword.process_line(seq_str.upper()).split(" ")
533
+ seq = " ".join(seq_to_list)
534
+ inputs = self.seq_tokenizer.encode_plus(
535
+ seq,
536
+ None,
537
+ add_special_tokens=False,
538
+ max_length=self.max_sentence_length,
539
+ truncation=True
540
+ )
541
+ if self.prepend_bos:
542
+ inputs["input_ids"] = [self.cls_idx] + inputs["input_ids"]
543
+ if self.append_eos:
544
+ inputs["input_ids"] = inputs["input_ids"] + [self.eos_idx]
545
+ if max_sentence_len < len(inputs["input_ids"]):
546
+ max_sentence_len = len(inputs["input_ids"])
547
+ cur_seq_encoded_list.append(inputs["input_ids"])
548
+ seq_encoded_list.append(cur_seq_encoded_list)
549
+ else:
550
+ seq_encoded_list = []
551
+ for cur_sample_seqs in seqs:
552
+ cur_seq_encoded_list = []
553
+ if len(cur_sample_seqs) > self.max_sentences:
554
+ # each sample keeps at most max_sentences sequences
555
+ if self.trunc_type == "left":
556
+ cur_sample_seqs = cur_sample_seqs[-self.max_sentences:]
557
+ else:
558
+ cur_sample_seqs = cur_sample_seqs[:self.max_sentences]
559
+ if max_sentence_num < len(cur_sample_seqs):
560
+ max_sentence_num = len(cur_sample_seqs)
561
+ for seq_idx, seq_str in enumerate(cur_sample_seqs):
562
+ if len(seq_str) > self.max_sentence_length:
563
+ if self.trunc_type == "left":
564
+ seq_str = seq_str[-self.max_sentence_length:]
565
+ else:
566
+ seq_str = seq_str[:self.max_sentence_length]
567
+
568
+ inputs = self.seq_tokenizer.encode(seq_str.upper())
569
+ # print("len:%d, %s" % (len(seq_str), seq_str.upper()))
570
+ if self.prepend_bos:
571
+ inputs = [self.cls_idx] + inputs
572
+ if self.append_eos:
573
+ inputs = inputs + [self.eos_idx]
574
+ # print("inputs:%d, " %len(inputs), inputs)
575
+ cur_seq_encoded_list.append(inputs)
576
+ if max_sentence_len < len(inputs):
577
+ max_sentence_len = len(inputs)
578
+ seq_encoded_list.append(cur_seq_encoded_list)
579
+ # for input
580
+ input_ids = torch.empty(
581
+ (
582
+ batch_size,
583
+ max_sentence_num,
584
+ max_sentence_len,
585
+ ),
586
+ dtype=torch.int64,
587
+ )
588
+ input_ids.fill_(self.padding_idx)
589
+
590
+ position_ids = None
591
+ if not self.no_position_embeddings:
592
+ position_ids = torch.empty(
593
+ (
594
+ batch_size,
595
+ max_sentence_num,
596
+ max_sentence_len
597
+ ),
598
+ dtype=torch.int64,
599
+ )
600
+ position_ids.fill_(self.padding_idx)
601
+
602
+ token_type_ids = None
603
+ if not self.no_position_embeddings:
604
+ token_type_ids = torch.empty(
605
+ (
606
+ batch_size,
607
+ max_sentence_num,
608
+ max_sentence_len
609
+ ),
610
+ dtype=torch.int64,
611
+ )
612
+ token_type_ids.fill_(self.padding_idx)
613
+ attention_masks = torch.empty(
614
+ (
615
+ batch_size,
616
+ max_sentence_num,
617
+ max_sentence_len
618
+ ),
619
+ dtype=torch.int64,
620
+ )
621
+ attention_masks.fill_(0)
622
+
623
+ return seq_encoded_list, input_ids, position_ids, token_type_ids, attention_masks, max_sentence_num, max_sentence_len
624
+
625
+ def __atom_seq_encode__(self, batch_size, seqs):
626
+ '''
627
+ This function does not add the special tokens [CLS] and [SEP]
628
+ :param batch_size:
629
+ :param seqs:
630
+ :return:
631
+ '''
632
+ seq_encoded_list = []
633
+ for seq_idx, cur_seq in enumerate(seqs):
634
+ if isinstance(cur_seq, str): # smiles
635
+ cur_seq_encoded = self.atom_tokenizer.encode_smi(cur_seq,
636
+ prepend_bos=False,
637
+ append_eos=False)
638
+ elif isinstance(cur_seq, list): # atom list
639
+ cur_seq_encoded = self.atom_tokenizer.encode(cur_seq,
640
+ prepend_bos=False,
641
+ append_eos=False)
642
+ else:
643
+ raise Exception("not support molecule input type:", type(cur_seq))
644
+ # this length already excludes the special tokens that will be added
645
+ if self.atom_truncation_seq_length:
646
+ cur_seq_encoded = cur_seq_encoded[:self.atom_truncation_seq_length]
647
+ seq_encoded_list.append(cur_seq_encoded)
648
+ max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
649
+ max_len = max_len + int(self.atom_prepend_bos) + int(self.atom_append_eos)
650
+ # for input
651
+ input_ids = torch.empty(
652
+ (
653
+ batch_size,
654
+ max_len,
655
+ ),
656
+ dtype=torch.int64,
657
+ )
658
+ input_ids.fill_(self.atom_padding_idx)
659
+
660
+ position_ids = None
661
+ if not self.no_position_embeddings:
662
+ position_ids = torch.empty(
663
+ (
664
+ batch_size,
665
+ max_len,
666
+ ),
667
+ dtype=torch.int64,
668
+ )
669
+ position_ids.fill_(self.atom_padding_idx)
670
+
671
+ token_type_ids = None
672
+ if not self.no_position_embeddings:
673
+ token_type_ids = torch.empty(
674
+ (
675
+ batch_size,
676
+ max_len,
677
+ ),
678
+ dtype=torch.int64,
679
+ )
680
+ token_type_ids.fill_(self.atom_padding_idx)
681
+ attention_masks = torch.empty(
682
+ (
683
+ batch_size,
684
+ max_len,
685
+ ),
686
+ dtype=torch.int64,
687
+ )
688
+ attention_masks.fill_(0)
689
+
690
+ return seq_encoded_list, input_ids, position_ids, token_type_ids, attention_masks, max_len
691
+
692
+ def __vector_encode__(self, batch_size, vectors):
693
+ embedding_vector_dim = vectors[0].shape[0]
694
+ filled_vectors = torch.empty(
695
+ (
696
+ batch_size,
697
+ embedding_vector_dim
698
+ ),
699
+ dtype=torch.float32,
700
+ )
701
+ filled_vectors.fill_(0.0)
702
+ return filled_vectors, 1
703
+
704
+ def __atom_vector_encode__(self, batch_size, vectors):
705
+ return self.__vector_encode__(batch_size, vectors)
706
+
707
+ def __multi_vector_encode__(self, batch_size, vectors):
708
+ embedding_vector_dim = vectors[0][0].shape[0]
709
+ filled_vectors = torch.empty(
710
+ (
711
+ batch_size,
712
+ self.max_sentences,
713
+ embedding_vector_dim
714
+ ),
715
+ dtype=torch.float32,
716
+ )
717
+ filled_vectors.fill_(0.0)
718
+ return filled_vectors, self.max_sentences, 1
719
+
720
+ def __matrix_encode__(self, batch_size, matrices):
721
+ '''
722
+ This function does not add vectors for the special tokens [CLS] and [SEP]
723
+ :param batch_size:
724
+ :param matrices:
725
+ :return:
726
+ '''
727
+ max_len = max(matrix.shape[0] for matrix in matrices)
728
+ if self.matrix_add_special_token:
729
+ max_len -= 2
730
+ if self.truncation_matrix_length:
731
+ max_len = min(max_len, self.truncation_matrix_length)
732
+ if self.matrix_add_special_token:
733
+ max_len += 2
734
+ else:
735
+ max_len = max_len + int(self.prepend_bos) + int(self.append_eos)
736
+ embedding_vector_dim = matrices[0].shape[1]
737
+ # for input
738
+ filled_matrices = torch.empty(
739
+ (
740
+ batch_size,
741
+ max_len,
742
+ embedding_vector_dim
743
+ ),
744
+ dtype=torch.float32,
745
+ )
746
+ filled_matrices.fill_(0.0)
747
+ attention_masks = torch.empty(
748
+ (
749
+ batch_size,
750
+ max_len,
751
+ ),
752
+ dtype=torch.int64,
753
+ )
754
+ attention_masks.fill_(0)
755
+ return filled_matrices, attention_masks, max_len
756
+
757
+ def __atom_matrix_encode__(self, batch_size, matrices):
758
+ '''
759
+ This function does not add vectors for the special tokens [CLS] and [SEP]
760
+ :param batch_size:
761
+ :param matrices:
762
+ :return:
763
+ '''
764
+ max_len = max(matrix.shape[0] for matrix in matrices)
765
+ if self.matrix_add_special_token:
766
+ max_len -= 2
767
+ if self.atom_truncation_matrix_length:
768
+ max_len = min(max_len, self.atom_truncation_matrix_length)
769
+ if self.matrix_add_special_token:
770
+ max_len += 2
771
+ else:
772
+ max_len = max_len + int(self.atom_prepend_bos) + int(self.atom_append_eos)
773
+ embedding_vector_dim = matrices[0].shape[1]
774
+ # for input
775
+ filled_matrices = torch.empty(
776
+ (
777
+ batch_size,
778
+ max_len,
779
+ embedding_vector_dim
780
+ ),
781
+ dtype=torch.float32,
782
+ )
783
+ filled_matrices.fill_(0.0)
784
+ attention_masks = torch.empty(
785
+ (
786
+ batch_size,
787
+ max_len,
788
+ ),
789
+ dtype=torch.int64,
790
+ )
791
+ attention_masks.fill_(0)
792
+ return filled_matrices, attention_masks, max_len
793
+
794
+ def __multi_matrix_encode__(self, batch_size, matrices):
795
+ '''
796
+ This function does not add vectors for the special tokens [CLS] and [SEP]
797
+ :param batch_size:
798
+ :param matrices:
799
+ :return:
800
+ '''
801
+ max_sentence_num = max(len(cur_matrix) for cur_matrix in matrices)
802
+ max_sentence_num = min(max_sentence_num, self.max_sentences)
803
+ if self.trunc_type == "left":
804
+ max_sentence_len = max(max(matrix.shape[0] for matrix in cur_matrix[-max_sentence_num:]) for cur_matrix in matrices)
805
+ else:
806
+ max_sentence_len = max(max(matrix.shape[0] for matrix in cur_matrix[:max_sentence_num]) for cur_matrix in matrices)
807
+ # print("encoder max_sentence_num:%d, max_sentence_len: %d" % (max_sentence_num, max_sentence_len))
808
+ if self.matrix_add_special_token:
809
+ max_sentence_len -= 2
810
+ max_sentence_len = min(max_sentence_len, self.max_sentence_length)
811
+ # print("encoder max_sentence_num:%d, max_sentence_len: %d" % (max_sentence_num, max_sentence_len))
812
+ if self.matrix_add_special_token:
813
+ max_sentence_len += 2
814
+ else:
815
+ max_sentence_len = max_sentence_len + int(self.prepend_bos) + int(self.append_eos)
816
+ # print("encoder max_sentence_num:%d, max_sentence_len: %d" % (max_sentence_num, max_sentence_len))
817
+ # print("self.max_sentence_length: %d" % self.max_sentence_length)
818
+ # print("max_sentence_len: %d" % max_sentence_len)
819
+ embedding_vector_dim = matrices[0][0].shape[1]
820
+ # for input
821
+ filled_matrices = torch.empty(
822
+ (
823
+ batch_size,
824
+ max_sentence_num,
825
+ max_sentence_len,
826
+ embedding_vector_dim
827
+ ),
828
+ dtype=torch.float32,
829
+ )
830
+ filled_matrices.fill_(0.0)
831
+ attention_masks = torch.empty(
832
+ (
833
+ batch_size,
834
+ max_sentence_num,
835
+ max_sentence_len
836
+ ),
837
+ dtype=torch.int64,
838
+ )
839
+ attention_masks.fill_(0)
840
+ return filled_matrices, attention_masks, max_sentence_num, max_sentence_len
841
+
842
+ def __call_single__(self, batch_size, seq_types, seqs, vectors, matrices, labels):
843
+ max_length = sys.maxsize
844
+ input_ids, position_ids, token_type_ids, seq_attention_masks = None, None, None, None
845
+ seq_part_of_input = False
846
+ molecule_flag = False
847
+ multi_seq_flag = False
848
+ if seqs:
849
+ new_seqs = []
850
+ for seq_idx, seq_type in enumerate(seq_types):
851
+ if seq_type == "gene":
852
+ new_seqs.append(gene_seq_replace(seqs[seq_idx].upper()))
853
+ elif seq_type == "molecule":
854
+ if isinstance(seqs[seq_idx], str):
855
+ new_seqs.append(AlphabetAtom.smiles_2_atom_seq(seqs[seq_idx]))
856
+ else:
857
+ new_seqs.append(seqs[seq_idx])
858
+ molecule_flag = True
859
+ elif seq_type == "multi_gene":
860
+ new_seqs.append([gene_seq_replace(seq).upper() for seq in seqs[seq_idx].split(",")])
861
+ multi_seq_flag = True
862
+ elif seq_type == "multi_prot":
863
+ new_seqs.append([seq.upper() for seq in seqs[seq_idx].split(",")])
864
+ multi_seq_flag = True
865
+ else:
866
+ new_seqs.append(seqs[seq_idx].upper())
867
+ if molecule_flag:
868
+ # seq_encoded_list contains no special tokens; input_ids reserves space for them based on the flags; seq_max_length adds the special-token length according to the flags
869
+ seq_encoded_list, input_ids, position_ids, token_type_ids, seq_attention_masks, seq_max_length = self.__atom_seq_encode__(
870
+ batch_size=batch_size, seqs=new_seqs)
871
+
872
+ elif multi_seq_flag:
873
+ # seq_encoded_list adds special tokens according to the flags; input_ids adds special tokens according to the flags; seq_max_len adds the special-token length according to the flags
874
+ seq_encoded_list, input_ids, position_ids, token_type_ids, seq_attention_masks, seq_max_num, seq_max_len = self.__multi_seq_encode__(
875
+ batch_size=batch_size, seqs=new_seqs)
876
+ '''
877
+ print("seq_max_num: %d" % seq_max_num)
878
+ print("seq_max_len: %d" % seq_max_len)
879
+ print(input_ids.shape)
880
+ print("len(seq_encoded_list): %d" % len(seq_encoded_list))
881
+ for input_id in input_ids:
882
+ print(len(input_id))
883
+ for matrix in input_id:
884
+ print(matrix.shape)
885
+ print("*****")
886
+ '''
887
+ else:
888
+ # seq_encoded_list contains no special tokens; input_ids reserves space for them based on the flags; seq_max_length adds the special-token length according to the flags
889
+ seq_encoded_list, input_ids, position_ids, token_type_ids, seq_attention_masks, seq_max_length = self.__seq_encode__(
890
+ batch_size=batch_size, seqs=new_seqs)
891
+ if multi_seq_flag:
892
+ max_length = min(max_length, seq_max_num * seq_max_len)
893
+ else:
894
+ max_length = min(max_length, seq_max_length)
895
+ seq_part_of_input = True
896
+
897
+ encoded_vectors = None
898
+ vector_part_of_input = False
899
+ if vectors is not None and len(vectors) > 0:
900
+ if multi_seq_flag:
901
+ encoded_vectors, vector_max_num, vector_max_len = self.__multi_vector_encode__(batch_size=batch_size, vectors=vectors)
902
+ elif molecule_flag:
903
+ encoded_vectors, vector_max_length = self.__atom_vector_encode__(batch_size=batch_size, vectors=vectors)
904
+ else:
905
+ encoded_vectors, vector_max_length = self.__vector_encode__(batch_size=batch_size, vectors=vectors)
906
+ # max_length = min(max_length, vector_max_length)
907
+ vector_part_of_input = True
908
+
909
+ encoded_matrices, matrix_attention_masks = None, None
910
+ matrix_part_of_input = False
911
+ # print("multi_seq_flag:", multi_seq_flag)
912
+ if matrices is not None and len(matrices) > 0:
913
+ if multi_seq_flag:
914
+ # filled according to the flags; number of sentences; special-token length added depending on the flags
915
+ encoded_matrices, matrix_attention_masks, matrix_max_num, matrix_max_len = self.__multi_matrix_encode__(
916
+ batch_size=batch_size,
917
+ matrices=matrices)
918
+ '''
919
+ print("matrix_max_num: %d" % matrix_max_num)
920
+ print("matrix_max_len: %d" % matrix_max_len)
921
+ print(encoded_matrices.shape)
922
+ print("len(matrices): %d" % len(matrices))
923
+ for matrix_array in matrices:
924
+ print(len(matrix_array))
925
+ for matrix in matrix_array:
926
+ print(matrix.shape)
927
+ print("*****")
928
+ '''
929
+ elif molecule_flag:
930
+ # filled according to the flags; number of sentences; special-token length added depending on the flags
931
+ encoded_matrices, matrix_attention_masks, matrix_max_length = self.__atom_matrix_encode__(batch_size=batch_size,
932
+ matrices=matrices
933
+ )
934
+ else:
935
+ # filled according to the flags; number of sentences; special-token length added depending on the flags
936
+ encoded_matrices, matrix_attention_masks, matrix_max_length = self.__matrix_encode__(batch_size=batch_size,
937
+ matrices=matrices)
938
+ if multi_seq_flag:
939
+ max_length = min(max_length, matrix_max_num * matrix_max_len)
940
+ else:
941
+ max_length = min(max_length, matrix_max_length)
942
+ matrix_part_of_input = True
943
+ has_label = False
944
+ if labels:
945
+ has_label = True
946
+
947
+ new_labels = []
948
+ num_sentences = 1
949
+ sentence_length = 1
950
+ for sample_idx in range(batch_size):
951
+ # seq
952
+ if seq_part_of_input:
953
+ if multi_seq_flag:
954
+ # cls_idx has already been added
955
+ pass
956
+ elif not molecule_flag and self.prepend_bos:
957
+ input_ids[sample_idx, 0] = self.cls_idx
958
+ elif molecule_flag and self.atom_prepend_bos:
959
+ input_ids[sample_idx, 0] = self.atom_cls_idx
960
+
961
+ seq_encoded = seq_encoded_list[sample_idx]
962
+ real_seq_len = len(seq_encoded)
963
+
964
+ # seq_tensor = torch.tensor(seq_encoded, dtype=torch.int64)
965
+ # print("seq_encoded:")
966
+ # print(seq_encoded)
967
+ if multi_seq_flag:
968
+ cur_seq_num = min(len(seq_encoded), seq_max_num)
969
+ if len(seq_encoded) > cur_seq_num:
970
+ if self.trunc_type == "left":
971
+ seq_encoded = seq_encoded[-cur_seq_num:]
972
+ else:
973
+ seq_encoded = seq_encoded[cur_seq_num:]
974
+ if num_sentences < cur_seq_num:
975
+ num_sentences = cur_seq_num
976
+ # print("cur_seq_num: %d" % len(seq_encoded))
977
+ for seq_idx in range(cur_seq_num):
978
+ cur_seq = seq_encoded[seq_idx]
979
+ cur_seq_len = min(len(cur_seq), seq_max_len)
980
+ '''
981
+ print("cur_seq:")
982
+ print(cur_seq_len)
983
+ print("input_ids:")
984
+ print(input_ids.shape)
985
+ '''
986
+ input_ids[sample_idx, seq_idx, :cur_seq_len] = torch.tensor(cur_seq[:cur_seq_len], dtype=torch.int64)
987
+ seq_attention_masks[sample_idx, seq_idx, :cur_seq_len] = 1
988
+ if cur_seq_len > sentence_length:
989
+ sentence_length = cur_seq_len
990
+ elif molecule_flag:
991
+ seq_tensor = torch.tensor(seq_encoded, dtype=torch.int64)
992
+ input_ids[sample_idx, int(self.atom_prepend_bos): real_seq_len + int(self.atom_prepend_bos)] = seq_tensor
993
+ cur_sentence_length = int(self.atom_prepend_bos) + real_seq_len + int(self.atom_prepend_bos)
994
+ if cur_sentence_length > sentence_length:
995
+ sentence_length = cur_sentence_length
996
+ else:
997
+ seq_tensor = torch.tensor(seq_encoded, dtype=torch.int64)
998
+ input_ids[sample_idx, int(self.prepend_bos): real_seq_len + int(self.prepend_bos)] = seq_tensor
999
+ cur_sentence_length = int(self.prepend_bos) + real_seq_len + int(self.prepend_bos)
1000
+ if cur_sentence_length > sentence_length:
1001
+ sentence_length = cur_sentence_length
1002
+
1003
+ if multi_seq_flag:
1004
+ # eos_idx has already been added
1005
+ pass
1006
+ elif not molecule_flag and self.append_eos:
1007
+ input_ids[sample_idx, real_seq_len + int(self.prepend_bos)] = self.eos_idx
1008
+ elif molecule_flag and self.atom_append_eos:
1009
+ input_ids[sample_idx, real_seq_len + int(self.atom_prepend_bos)] = self.atom_eos_idx
1010
+
1011
+ if multi_seq_flag:
1012
+ cur_len = num_sentences * sentence_length
1013
+ elif molecule_flag:
1014
+ cur_len = int(self.atom_prepend_bos) + real_seq_len + int(self.atom_append_eos)
1015
+ else:
1016
+ cur_len = int(self.prepend_bos) + real_seq_len + int(self.append_eos)
1017
+
1018
+ if not self.no_position_embeddings:
1019
+ if multi_seq_flag:
1020
+ for pos_idx in range(0, cur_len):
1021
+ position_ids[sample_idx, pos_idx//sentence_length, pos_idx % sentence_length] = pos_idx % sentence_length
1022
+ else:
1023
+ for pos_idx in range(0, cur_len):
1024
+ position_ids[sample_idx, pos_idx] = pos_idx
1025
+
1026
+ if not self.no_token_type_embeddings:
1027
+ seq_type = seq_types[sample_idx]
1028
+ if seq_type == "gene":
1029
+ type_value = 0
1030
+ else:
1031
+ type_value = 1
1032
+ if multi_seq_flag:
1033
+ for pos_idx in range(0, cur_len):
1034
+ token_type_ids[sample_idx, pos_idx//sentence_length, pos_idx % sentence_length] = type_value
1035
+ else:
1036
+ for pos_idx in range(0, cur_len):
1037
+ token_type_ids[sample_idx, pos_idx] = type_value
1038
+
1039
+ if multi_seq_flag:
1040
+ pass
1041
+ else:
1042
+ seq_attention_masks[sample_idx, 0: cur_len] = 1
1043
+
1044
+ # vector
1045
+ if vector_part_of_input:
1046
+ if multi_seq_flag:
1047
+ cur_vector_num = min(len(vectors[sample_idx]), vector_max_num)
1048
+ if num_sentences < cur_vector_num:
1049
+ num_sentences = cur_vector_num
1050
+ for vector_idx in range(cur_vector_num):
1051
+ encoded_vectors[sample_idx, vector_idx, :] = torch.tensor(vectors[sample_idx][vector_idx], dtype=torch.float32)
1052
+ else:
1053
+ encoded_vectors[sample_idx, :] = torch.tensor(vectors[sample_idx], dtype=torch.float32)
1054
+
1055
+ # matrix
1056
+ if matrix_part_of_input:
1057
+ '''
1058
+ matrix_encoded = matrices[sample_idx]
1059
+ if self.matrix_add_special_token:
1060
+ real_seq_len = matrix_encoded.shape[0] - 2
1061
+ else:
1062
+ real_seq_len = matrix_encoded.shape[0]
1063
+ if multi_seq_flag:
1064
+ pass
1065
+ elif molecule_flag:
1066
+ # real_seq_len = real_seq_len - int(self.atom_prepend_bos) - int(self.atom_append_eos)
1067
+ real_seq_len = min(real_seq_len, self.atom_truncation_matrix_length)
1068
+ else:
1069
+ # real_seq_len = real_seq_len - int(self.prepend_bos) - int(self.append_eos)
1070
+ real_seq_len = min(real_seq_len, self.truncation_matrix_length)
1071
+ # print("real_seq_len: %d" % real_seq_len)
1072
+ '''
1073
+ if multi_seq_flag:
1074
+ # multi-sequence matrix
1075
+ matrix_encoded_list = matrices[sample_idx]
1076
+ cur_matrix_num = min(len(matrix_encoded_list), matrix_max_num)
1077
+ if len(matrix_encoded_list) > cur_matrix_num:
1078
+ if self.trunc_type == "left":
1079
+ matrix_encoded_list = matrix_encoded_list[:cur_matrix_num]
1080
+ else:
1081
+ matrix_encoded_list = matrix_encoded_list[-cur_matrix_num:]
1082
+ if num_sentences < cur_matrix_num:
1083
+ num_sentences = cur_matrix_num
1084
+ # print("matrix_encoded_list: %d" % len(matrix_encoded_list))
1085
+ for matrix_idx in range(cur_matrix_num):
1086
+ # print("matrix_idx: %d" % matrix_idx)
1087
+ cur_matrix = matrix_encoded_list[matrix_idx]
1088
+ cur_matrix = torch.tensor(cur_matrix, dtype=torch.float32)
1089
+ cur_matrix_len = min(cur_matrix.shape[0], matrix_max_len)
1090
+ if self.matrix_add_special_token:
1091
+ encoded_matrices[sample_idx, matrix_idx, 0: cur_matrix_len - 1] = cur_matrix[0:cur_matrix_len - 1]
1092
+ encoded_matrices[sample_idx, matrix_idx, cur_matrix_len - 1] = cur_matrix[-1]
1093
+ matrix_attention_masks[sample_idx, matrix_idx, 0:cur_matrix_len] = 1
1094
+ else:
1095
+ encoded_matrices[sample_idx, matrix_idx, int(self.prepend_bos): cur_matrix_len + int(self.prepend_bos)] = cur_matrix[0:cur_matrix_len]
1096
+ matrix_attention_masks[sample_idx, matrix_idx, 0: int(self.prepend_bos) + cur_matrix_len + int(self.append_eos)] = 1
1097
+ cur_matrix_len = int(self.prepend_bos) + cur_matrix_len + int(self.append_eos)
1098
+ if sentence_length < cur_matrix_len:
1099
+ sentence_length = cur_matrix_len
1100
+ else:
1101
+ matrix_encoded = matrices[sample_idx]
1102
+ if self.matrix_add_special_token:
1103
+ real_seq_len = matrix_encoded.shape[0] - 2
1104
+ else:
1105
+ real_seq_len = matrix_encoded.shape[0]
1106
+ if molecule_flag:
1107
+ # real_seq_len = real_seq_len - int(self.atom_prepend_bos) - int(self.atom_append_eos)
1108
+ real_seq_len = min(real_seq_len, self.atom_truncation_matrix_length)
1109
+ matrix = torch.tensor(matrix_encoded, dtype=torch.float32)
1110
+ if self.matrix_add_special_token:
1111
+ encoded_matrices[sample_idx, 0: real_seq_len + 2] \
1112
+ = matrix[0: real_seq_len + 2]
1113
+ matrix_attention_masks[sample_idx, 0: real_seq_len + 2] = 1
1114
+ cur_sentence_length = real_seq_len + 2
1115
+ else:
1116
+ encoded_matrices[sample_idx, int(self.atom_prepend_bos): real_seq_len + int(self.atom_prepend_bos)] \
1117
+ = matrix[0: real_seq_len]
1118
+ # matrix_attention_masks[sample_idx, int(self.atom_prepend_bos): real_seq_len + int(self.atom_prepend_bos)] = 1
1119
+ matrix_attention_masks[sample_idx, 0: int(self.atom_prepend_bos) + real_seq_len + int(self.atom_append_eos)] = 1
1120
+ cur_sentence_length = int(self.atom_prepend_bos) + real_seq_len + int(self.atom_append_eos)
1121
+ if cur_sentence_length > sentence_length:
1122
+ sentence_length = cur_sentence_length
1123
+ else:
1124
+ # real_seq_len = real_seq_len - int(self.prepend_bos) - int(self.append_eos)
1125
+ real_seq_len = min(real_seq_len, self.truncation_matrix_length)
1126
+ matrix = torch.tensor(matrix_encoded, dtype=torch.float32)
1127
+ if self.matrix_add_special_token:
1128
+ encoded_matrices[sample_idx, 0: real_seq_len + 2] = matrix[0: real_seq_len + 2]
1129
+ matrix_attention_masks[sample_idx, 0: real_seq_len + 2] = 1
1130
+ cur_sentence_length = real_seq_len + 2
1131
+ else:
1132
+ encoded_matrices[sample_idx, int(self.prepend_bos): real_seq_len + int(self.prepend_bos)] = matrix[0: real_seq_len]
1133
+ # matrix_attention_masks[sample_idx, int(self.prepend_bos): real_seq_len + int(self.prepend_bos)] = 1
1134
+ matrix_attention_masks[sample_idx, 0: int(self.prepend_bos) + real_seq_len + int(self.append_eos)] = 1
1135
+ cur_sentence_length = int(self.prepend_bos) + real_seq_len + int(self.append_eos)
1136
+ if cur_sentence_length > sentence_length:
1137
+ sentence_length = cur_sentence_length
1138
+
1139
+ if has_label:
1140
+ if multi_seq_flag:
1141
+ # to do
1142
+ new_labels.append(
1143
+ self.__parse_label__(max_length, self.task_level_type,
1144
+ self.label_size, self.output_mode, labels[sample_idx]))
1145
+ elif molecule_flag:
1146
+ new_labels.append(
1147
+ self.__atom_parse_label__(max_length, self.task_level_type,
1148
+ self.label_size, self.output_mode, labels[sample_idx]))
1149
+ else:
1150
+ new_labels.append(
1151
+ self.__parse_label__(max_length, self.task_level_type,
1152
+ self.label_size, self.output_mode, labels[sample_idx]))
1153
+ if new_labels is not None and new_labels:
1154
+ if self.output_mode in ["regression"]:
1155
+ labels = torch.tensor(new_labels, dtype=torch.float32)
1156
+ else:
1157
+ labels = torch.tensor(new_labels, dtype=torch.int64)
1158
+ else:
1159
+ labels = None
1160
+ '''
1161
+ print(input_ids.shape)
1162
+ print("encoded_matrices:")
1163
+ print(encoded_matrices.shape)
1164
+ print("num_sentences:%d" % num_sentences)
1165
+ print("sentence_length:%d" % sentence_length)
1166
+ if labels is not None:
1167
+ print("labels:")
1168
+ print(labels.shape)
1169
+ '''
1170
+
1171
+ if multi_seq_flag:
1172
+ if seq_part_of_input:
1173
+ input_ids = torch.reshape(input_ids, (input_ids.shape[0], -1))
1174
+ if matrix_part_of_input:
1175
+ encoded_matrices = torch.reshape(encoded_matrices, (encoded_matrices.shape[0], -1, encoded_matrices.shape[-1]))
1176
+ if position_ids is not None:
1177
+ position_ids = torch.reshape(position_ids, (position_ids.shape[0], -1))
1178
+ if token_type_ids is not None:
1179
+ token_type_ids = torch.reshape(token_type_ids, (token_type_ids.shape[0], -1))
1180
+ if seq_attention_masks is not None:
1181
+ seq_attention_masks = torch.reshape(seq_attention_masks, (seq_attention_masks.shape[0], -1))
1182
+ if matrix_attention_masks is not None:
1183
+ matrix_attention_masks = torch.reshape(matrix_attention_masks, (matrix_attention_masks.shape[0], -1))
1184
+ '''
1185
+ print(input_ids.shape)
1186
+ print("encoded_matrices:")
1187
+ print(encoded_matrices.shape)
1188
+ print("num_sentences:%d" % num_sentences)
1189
+ print("sentence_length:%d" % sentence_length)
1190
+ if labels is not None:
1191
+ print("labels:")
1192
+ print(labels.shape)
1193
+ print("-" * 50)
1194
+ '''
1195
+
1196
+ return input_ids, \
1197
+ position_ids, \
1198
+ token_type_ids, \
1199
+ seq_attention_masks, \
1200
+ encoded_vectors, \
1201
+ encoded_matrices, \
1202
+ matrix_attention_masks, \
1203
+ num_sentences, \
1204
+ sentence_length, \
1205
+ labels
1206
+
1207
+ def __call__(self, raw_batch: Sequence[dict]):
1208
+ batch_size = len(raw_batch)
1209
+ # pair
1210
+ if "seq_id_a" in raw_batch[0] and "seq_id_b" in raw_batch[0]:
1211
+ res = {}
1212
+ # seq_ids_a = []
1213
+ seq_types_a = []
1214
+ seqs_a = []
1215
+ vectors_a = []
1216
+ matrices_a = []
1217
+
1218
+ # seq_ids_b = []
1219
+ seq_types_b = []
1220
+ seqs_b = []
1221
+ vectors_b = []
1222
+ matrices_b = []
1223
+
1224
+ labels = []
1225
+ for item in raw_batch:
1226
+ # seq_ids_a.append(item["seq_id_a"])
1227
+ seq_types_a.append(item["seq_type_a"])
1228
+ if item["seq_a"] is not None:
1229
+ seqs_a.append(item["seq_a"])
1230
+ if item["vector_a"] is not None:
1231
+ vectors_a.append(item["vector_a"])
1232
+ if item["matrix_a"] is not None:
1233
+ matrices_a.append(item["matrix_a"])
1234
+
1235
+ # seq_ids_b.append(item["seq_id_b"])
1236
+ seq_types_b.append(item["seq_type_b"])
1237
+ if item["seq_b"] is not None:
1238
+ seqs_b.append(item["seq_b"])
1239
+ if item["vector_b"] is not None:
1240
+ vectors_b.append(item["vector_b"])
1241
+ if item["matrix_b"] is not None:
1242
+ matrices_b.append(item["matrix_b"])
1243
+ if "label" in item and item["label"] is not None:
1244
+ labels.append(item["label"])
1245
+ input_ids_a, position_ids_a, token_type_ids_a, seq_attention_masks_a, encoded_vectors_a, encoded_matrices_a, matrix_attention_masks_a, num_sentences_a, sentence_length_a, labels \
1246
+ = self.__call_single__(batch_size, seq_types_a, seqs_a, vectors_a, matrices_a, labels)
1247
+ if not hasattr(self, "max_sentences") or self.max_sentences is None:
1248
+ res.update({
1249
+ "input_ids_a": input_ids_a,
1250
+ "position_ids_a": position_ids_a,
1251
+ "token_type_ids_a": token_type_ids_a,
1252
+ "seq_attention_masks_a": seq_attention_masks_a,
1253
+ "vectors_a": encoded_vectors_a,
1254
+ "matrices_a": encoded_matrices_a,
1255
+ "matrix_attention_masks_a": matrix_attention_masks_a,
1256
+ "labels": labels if labels is not None and len(labels) > 0 else None
1257
+ })
1258
+ else:
1259
+ res.update({
1260
+ "input_ids_a": input_ids_a,
1261
+ "position_ids_a": position_ids_a,
1262
+ "token_type_ids_a": token_type_ids_a,
1263
+ "seq_attention_masks_a": seq_attention_masks_a,
1264
+ "vectors_a": encoded_vectors_a,
1265
+ "matrices_a": encoded_matrices_a,
1266
+ "matrix_attention_masks_a": matrix_attention_masks_a,
1267
+ "num_sentences_a": num_sentences_a,
1268
+ "sentence_length_a": sentence_length_a,
1269
+ "labels": labels if labels is not None and len(labels) > 0 else None
1270
+ })
1271
+ input_ids_b, position_ids_b, token_type_ids_b, seq_attention_masks_b, encoded_vectors_b, encoded_matrices_b, matrix_attention_masks_b, num_sentences_b, sentence_length_b, _ \
1272
+ = self.__call_single__(batch_size, seq_types_b, seqs_b, vectors_b, matrices_b, labels=None)
1273
+ if not hasattr(self, "max_sentences") or self.max_sentences is None:
1274
+ res.update({
1275
+ "input_ids_b": input_ids_b,
1276
+ "position_ids_b": position_ids_b,
1277
+ "token_type_ids_b": token_type_ids_b,
1278
+ "seq_attention_masks_b": seq_attention_masks_b,
1279
+ "vectors_b": encoded_vectors_b,
1280
+ "matrices_b": encoded_matrices_b,
1281
+ "matrix_attention_masks_b": matrix_attention_masks_b
1282
+ })
1283
+ else:
1284
+ res.update({
1285
+ "input_ids_b": input_ids_b,
1286
+ "position_ids_b": position_ids_b,
1287
+ "token_type_ids_b": token_type_ids_b,
1288
+ "seq_attention_masks_b": seq_attention_masks_b,
1289
+ "vectors_b": encoded_vectors_b,
1290
+ "matrices_b": encoded_matrices_b,
1291
+ "num_sentences_b": num_sentences_b,
1292
+ "sentence_length_b": sentence_length_b,
1293
+ "matrix_attention_masks_b": matrix_attention_masks_b
1294
+ })
1295
+ return res
1296
+ else:
1297
+ res = {}
1298
+ # seq_ids = []
1299
+ seq_types = []
1300
+ seqs = []
1301
+ vectors = []
1302
+ matrices = []
1303
+ labels = []
1304
+ for item in raw_batch:
1305
+ # seq_ids.append(item["seq_id"])
1306
+ seq_types.append(item["seq_type"])
1307
+ if item["seq"] is not None:
1308
+ seqs.append(item["seq"])
1309
+ if item["vector"] is not None:
1310
+ vectors.append(item["vector"])
1311
+ if item["matrix"] is not None:
1312
+ matrices.append(item["matrix"])
1313
+ if item["label"] is not None:
1314
+ labels.append(item["label"])
1315
+ '''
1316
+ print("seqs:")
1317
+ print(seqs)
1318
+ print([len(seq) for seq in seqs])
1319
+ print("matrices:")
1320
+ print(matrices)
1321
+ print([matrix.shape for matrix in matrices])
1322
+ print("labels:")
1323
+ print(labels)
1324
+ print([len(eval(label)) for label in labels])
1325
+ '''
1326
+ input_ids, position_ids, token_type_ids, seq_attention_masks, encoded_vectors, encoded_matrices, matrix_attention_masks, num_sentences, sentence_length, labels = self.__call_single__(
1327
+ batch_size, seq_types, seqs, vectors, matrices, labels=labels)
1328
+
1329
+ if not hasattr(self, "max_sentences") or self.max_sentences is None:
1330
+ res.update({
1331
+ "input_ids": input_ids,
1332
+ "position_ids": position_ids,
1333
+ "token_type_ids": token_type_ids,
1334
+ "seq_attention_masks": seq_attention_masks,
1335
+ "vectors": encoded_vectors,
1336
+ "matrices": encoded_matrices,
1337
+ "matrix_attention_masks": matrix_attention_masks,
1338
+ "labels": labels if labels is not None and len(labels) > 0 else None
1339
+ })
1340
+ else:
1341
+ res.update({
1342
+ "input_ids": input_ids,
1343
+ "position_ids": position_ids,
1344
+ "token_type_ids": token_type_ids,
1345
+ "seq_attention_masks": seq_attention_masks,
1346
+ "vectors": encoded_vectors,
1347
+ "matrices": encoded_matrices,
1348
+ "matrix_attention_masks": matrix_attention_masks,
1349
+ "num_sentences": num_sentences,
1350
+ "sentence_length": sentence_length,
1351
+ "labels": labels if labels is not None and len(labels) > 0 else None
1352
+ })
1353
+
1354
+ '''
1355
+ for item in res.items():
1356
+ key_name = item[0]
1357
+ print(key_name, ":")
1358
+ if item[1] is not None:
1359
+ print(item[1])
1360
+ print(item[1].shape)
1361
+ else:
1362
+ print("None")
1363
+ '''
1364
+ return res
1365
+
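The collator above (`__call_single__` / `__call__`) turns a list of per-sample dicts into padded tensors: `input_ids`, `position_ids`, `token_type_ids`, `seq_attention_masks`, plus optional `vectors`, `matrices`, `matrix_attention_masks`, and `labels`, with `_a`/`_b` suffixes for paired inputs. Below is a minimal, hedged sketch of the expected input format; the collator class name and its constructor arguments are hypothetical, since the class definition lives elsewhere in this repo.

```
# Minimal sketch of a raw batch for the collator's __call__.
# The collator class name and constructor below are hypothetical.
raw_batch = [
    {
        "seq_id": "P0A7G6",      # ignored by the collator, kept for bookkeeping
        "seq_type": "prot",      # "gene" or "prot"
        "seq": "MKTAYIAKQR",
        "vector": None,          # optional precomputed per-sequence vector
        "matrix": None,          # optional precomputed per-residue matrix
        "label": None,
    },
]

# collator = BatchConverter(...)   # hypothetical name, defined elsewhere in the repo
# batch = collator(raw_batch)
# batch["input_ids"], batch["token_type_ids"], batch["seq_attention_masks"], ...
```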
classification_loss.py ADDED
@@ -0,0 +1,296 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: sanyuan.hy@alibaba-inc.com
7
+ @tel: 137****6540
8
+ @datetime: 2023/5/3 20:35
9
+ @project: LucaOne
10
+ @file: classification_loss.py
11
+ @desc: loss
12
+ '''
13
+ import torch
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from .masked_loss import _MaskedLoss
17
+
18
+ class MaskedFocalLoss(_MaskedLoss):
19
+ """Masked FocalLoss"""
20
+ def __init__(self, alpha=1, gamma=2, normalization=False, reduction='mean', ignore_nans=True, ignore_value=-100):
21
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
22
+ self.criterion = FocalLoss(alpha=alpha, gamma=gamma, normalization=normalization, reduction='none')
23
+
24
+
25
+ class FocalLoss(nn.Module):
26
+ '''
27
+ Focal loss
28
+ '''
29
+ def __init__(self, alpha=1, gamma=2, normalization=False, reduction="mean"):
30
+ super(FocalLoss, self).__init__()
31
+ self.alpha = alpha
32
+ self.gamma = gamma
33
+ self.normalization = normalization
34
+ self.reduction = reduction
35
+
36
+ def forward(self, inputs, targets):
37
+ if self.normalization:
38
+ '''
39
+ reduction: the operation on the output loss, which can be set to 'none', 'mean', and 'sum';
40
+ 'none' will not perform any processing on the loss,
41
+ 'mean' will calculate the mean of the loss,
42
+ 'sum' will sum the loss, and the default is 'mean'
43
+ '''
44
+ bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
45
+ probs = torch.sigmoid(inputs)
46
+ else:
47
+ bce = F.binary_cross_entropy(inputs, targets, reduction='none')
48
+ probs = inputs
49
+ pt = targets * probs + (1 - targets) * (1 - probs)
50
+ modulate = 1 if self.gamma is None else (1 - pt) ** self.gamma
51
+
52
+ focal_loss = modulate * bce
53
+
54
+ if self.alpha is not None:
55
+ assert 0 <= self.alpha <= 1
56
+ alpha_weights = targets * self.alpha + (1 - targets) * (1 - self.alpha)
57
+ focal_loss *= alpha_weights
58
+ if self.reduction == "mean":
59
+ # global mean
60
+ return torch.mean(focal_loss)
61
+ if self.reduction in ["summean", "meansum"]:
62
+ # sum of all samples and calc the mean value
63
+ return torch.mean(torch.sum(focal_loss, dim=1))
64
+ elif self.reduction == "sum":
65
+ return torch.sum(focal_loss, dim=1)
66
+ else:
67
+ return focal_loss
68
+
69
+
70
+ class MaskedMultiLabelCCE(_MaskedLoss):
71
+ """Masked MultiLabel CCE"""
72
+ def __init__(self, normalization=False, reduction='mean', ignore_nans=True, ignore_value=-100):
73
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
74
+ self.criterion = MultiLabelCCE(normalization=normalization, reduction='none')
75
+
76
+
77
+ class MultiLabelCCE(nn.Module):
78
+ '''
79
+ Multi Label CCE
80
+ '''
81
+ def __init__(self, normalization=False, reduction='mean'):
82
+ super(MultiLabelCCE, self).__init__()
83
+ self.normalization = normalization
84
+ self.reduction = reduction
85
+
86
+ def forward(self, inputs, targets):
87
+ """
88
+ Cross entropy of multi-label classification
89
+ Note: The shapes of y_true and y_pred are consistent, and the elements of y_true are either 0 or 1. 1 indicates
90
+ that the corresponding class is a target class, and 0 indicates that the corresponding class is a non-target class.
91
+ """
92
+ if self.normalization:
93
+ y_pred = torch.softmax(inputs, dim=-1)
94
+ else:
95
+ y_pred = inputs
96
+ y_true = targets
97
+ y_pred = (1 - 2 * y_true) * y_pred
98
+ y_pred_neg = y_pred - y_true * 1e12
99
+ y_pred_pos = y_pred - (1 - y_true) * 1e12
100
+ zeros = torch.zeros_like(y_pred[..., :1])
101
+ y_pred_neg = torch.cat((y_pred_neg, zeros), axis=-1)
102
+ y_pred_pos = torch.cat((y_pred_pos, zeros), axis=-1)
103
+ neg_loss = torch.logsumexp(y_pred_neg, axis=-1)
104
+ pos_loss = torch.logsumexp(y_pred_pos, axis=-1)
105
+ if self.reduction == 'mean':
106
+ return torch.mean(neg_loss + pos_loss)
107
+ elif self.reduction == 'sum':
108
+ return torch.sum(neg_loss + pos_loss)
109
+ else:
110
+ return neg_loss + pos_loss
111
+
112
+
113
+ class MaskedAsymmetricLoss(_MaskedLoss):
114
+ """Masked AsymmetricLoss"""
115
+ def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False, reduction='mean', ignore_nans=True, ignore_value=-100):
116
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
117
+ self.criterion = AsymmetricLoss(gamma_neg, gamma_pos, clip, eps, disable_torch_grad_focal_loss)
118
+
119
+
120
+ class AsymmetricLoss(nn.Module):
121
+ def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
122
+ super(AsymmetricLoss, self).__init__()
123
+
124
+ self.gamma_neg = gamma_neg
125
+ self.gamma_pos = gamma_pos
126
+ self.clip = clip
127
+ self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
128
+ self.eps = eps
129
+
130
+ def forward(self, x, y):
131
+ """"
132
+ Parameters
133
+ ----------
134
+ x: input logits
135
+ y: targets (multi-label binarized vector)
136
+ """
137
+
138
+ # Calculating Probabilities
139
+ x_sigmoid = torch.sigmoid(x)
140
+ xs_pos = x_sigmoid
141
+ xs_neg = 1 - x_sigmoid
142
+
143
+ # Asymmetric Clipping
144
+ if self.clip is not None and self.clip > 0:
145
+ xs_neg = (xs_neg + self.clip).clamp(max=1)
146
+
147
+ # Basic CE calculation
148
+ los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
149
+ los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
150
+ loss = los_pos + los_neg
151
+
152
+ # Asymmetric Focusing
153
+ if self.gamma_neg > 0 or self.gamma_pos > 0:
154
+ if self.disable_torch_grad_focal_loss:
155
+ torch.set_grad_enabled(False)
156
+ pt0 = xs_pos * y
157
+ pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
158
+ pt = pt0 + pt1
159
+ one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
160
+ one_sided_w = torch.pow(1 - pt, one_sided_gamma)
161
+ if self.disable_torch_grad_focal_loss:
162
+ torch.set_grad_enabled(True)
163
+ loss *= one_sided_w
164
+
165
+ return -loss.sum()
166
+
167
+
168
+ class MaskedAsymmetricLossOptimized(_MaskedLoss):
169
+ """Masked ASLSingleLabel loss"""
170
+ def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False, reduction='mean', ignore_nans=True, ignore_value=-100):
171
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
172
+ self.criterion = AsymmetricLossOptimized(gamma_neg, gamma_pos, clip, eps, disable_torch_grad_focal_loss)
173
+
174
+
175
+ class AsymmetricLossOptimized(nn.Module):
176
+ '''
177
+ Notice - optimized version, minimizes memory allocation and gpu uploading,
178
+ favors inplace operations
179
+ '''
180
+
181
+ def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
182
+ super(AsymmetricLossOptimized, self).__init__()
183
+
184
+ self.gamma_neg = gamma_neg
185
+ self.gamma_pos = gamma_pos
186
+ self.clip = clip
187
+ self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
188
+ self.eps = eps
189
+
190
+ # prevents memory allocation and GPU uploading every iteration, and encourages in-place operations
191
+ self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None
192
+
193
+ def forward(self, x, y):
194
+ """"
195
+ Parameters
196
+ ----------
197
+ x: input logits
198
+ y: targets (multi-label binarized vector)
199
+ """
200
+
201
+ self.targets = y
202
+ self.anti_targets = 1 - y
203
+
204
+ # Calculating Probabilities
205
+ self.xs_pos = torch.sigmoid(x)
206
+ self.xs_neg = 1.0 - self.xs_pos
207
+
208
+ # Asymmetric Clipping
209
+ if self.clip is not None and self.clip > 0:
210
+ self.xs_neg.add_(self.clip).clamp_(max=1)
211
+
212
+ # Basic CE calculation
213
+ self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
214
+ self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))
215
+
216
+ # Asymmetric Focusing
217
+ if self.gamma_neg > 0 or self.gamma_pos > 0:
218
+ if self.disable_torch_grad_focal_loss:
219
+ torch.set_grad_enabled(False)
220
+ self.xs_pos = self.xs_pos * self.targets
221
+ self.xs_neg = self.xs_neg * self.anti_targets
222
+ self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
223
+ self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
224
+ if self.disable_torch_grad_focal_loss:
225
+ torch.set_grad_enabled(True)
226
+ self.loss *= self.asymmetric_w
227
+
228
+ return -self.loss.sum()
229
+
230
+
231
+ class MaskedASLSingleLabel(_MaskedLoss):
232
+ """Masked ASLSingleLabel loss"""
233
+ def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction='mean', ignore_nans=True, ignore_value=-100):
234
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
235
+ self.criterion = ASLSingleLabel(gamma_pos, gamma_neg, eps, reduction='none')
236
+
237
+
238
+ class ASLSingleLabel(nn.Module):
239
+ '''
240
+ This loss is intended for single-label classification problems (multi-class)
241
+ '''
242
+ def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction='mean'):
243
+ super(ASLSingleLabel, self).__init__()
244
+
245
+ self.eps = eps
246
+ self.logsoftmax = nn.LogSoftmax(dim=-1)
247
+ self.targets_classes = []
248
+ self.gamma_pos = gamma_pos
249
+ self.gamma_neg = gamma_neg
250
+ self.reduction = reduction
251
+
252
+ def forward(self, inputs, target):
253
+ '''
254
+ "input" dimensions: - (batch_size, number_classes)
255
+ "target" dimensions: - (batch_size)
256
+ '''
257
+ num_classes = inputs.size()[-1]
258
+ log_preds = self.logsoftmax(inputs)
259
+ self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)
260
+
261
+ # ASL weights
262
+ targets = self.targets_classes
263
+ anti_targets = 1 - targets
264
+ xs_pos = torch.exp(log_preds)
265
+ xs_neg = 1 - xs_pos
266
+ xs_pos = xs_pos * targets
267
+ xs_neg = xs_neg * anti_targets
268
+ asymmetric_w = torch.pow(1 - xs_pos - xs_neg, self.gamma_pos * targets + self.gamma_neg * anti_targets)
269
+ log_preds = log_preds * asymmetric_w
270
+
271
+ if self.eps > 0:
272
+ # label smoothing
273
+ self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)
274
+
275
+ # loss calculation
276
+ loss = - self.targets_classes.mul(log_preds)
277
+
278
+ loss = loss.sum(dim=-1)
279
+ if self.reduction == 'mean':
280
+ loss = loss.mean()
281
+
282
+ return loss
283
+
284
+
285
+ class MaskedBCEWithLogitsLoss(_MaskedLoss):
286
+ """Masked MSE loss"""
287
+ def __init__(self, pos_weight=None, weight=None, reduction='mean', ignore_nans=True, ignore_value=-100):
288
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
289
+ self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight, weight=weight, reduction='none')
290
+
291
+
292
+ class MaskedCrossEntropyLoss(_MaskedLoss):
293
+ """Masked MSE loss"""
294
+ def __init__(self, weight=None, reduction='mean', ignore_nans=True, ignore_value=-100):
295
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
296
+ self.criterion = nn.CrossEntropyLoss(weight=weight, reduction='none', ignore_index=ignore_value)
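For reference, a short, hedged usage sketch of the plain (unmasked) losses defined above. It assumes the repo's modules are importable as a package; the package name `lucaone_hf` is hypothetical, and the masked variants additionally require `masked_loss.py`, which is not shown here.

```
# Hedged sketch; the package name "lucaone_hf" is an assumption.
import torch
from lucaone_hf.classification_loss import FocalLoss, ASLSingleLabel

logits = torch.randn(4, 10)                        # (batch, num_labels) raw scores
multi_hot = torch.randint(0, 2, (4, 10)).float()   # multi-label targets in {0, 1}

focal = FocalLoss(alpha=0.7, gamma=2.0, normalization=True, reduction="mean")
print(focal(logits, multi_hot))                    # scalar focal loss on logits

class_ids = torch.randint(0, 10, (4,))             # single-label class indices
asl = ASLSingleLabel(gamma_pos=0, gamma_neg=4, eps=0.1, reduction="mean")
print(asl(logits, class_ids))                      # scalar asymmetric single-label loss
```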
config.json ADDED
@@ -0,0 +1,71 @@
1
+ {
2
+ "alphabet": "gene_prot",
3
+ "architectures": [
4
+ "LucaGPLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "lucaone_gplm_config.LucaGPLMConfig",
9
+ "AutoModel": "lucaone_gplm.LucaGPLM"
10
+ },
11
+ "bos_token_id": 2,
12
+ "classifier_dropout": 0.0,
13
+ "classifier_dropout_prob": 0.0,
14
+ "classifier_hidden_act": "gelu",
15
+ "embed_scale": 1.0,
16
+ "eos_token_id": 3,
17
+ "gene_mask_classifier_output_size": 2048,
18
+ "gene_mask_label_num": 39,
19
+ "gene_taxonomy_classifier_output_size": 2048,
20
+ "gene_taxonomy_label_num": 735,
21
+ "gene_type_classifier_output_size": 128,
22
+ "gene_type_label_num": 8,
23
+ "hidden_act": "gelu",
24
+ "hidden_dropout_prob": 0.0,
25
+ "hidden_size": 2560,
26
+ "id2label": {
27
+ "0": "LABEL_0",
28
+ "1": "LABEL_1",
29
+ "2": "LABEL_2"
30
+ },
31
+ "ignore_index": -100,
32
+ "label2id": {
33
+ "LABEL_0": 0,
34
+ "LABEL_1": 1,
35
+ "LABEL_2": 2
36
+ },
37
+ "mask_token_id": 4,
38
+ "max_position_embeddings": 1280,
39
+ "model_type": "lucagplm",
40
+ "no_position_embeddings": true,
41
+ "no_token_type_embeddings": false,
42
+ "num_attention_heads": 40,
43
+ "num_hidden_layers": 20,
44
+ "pad_token_id": 0,
45
+ "prot_contact_classifier_output_size": 3072,
46
+ "prot_domain_classifier_output_size": 10240,
47
+ "prot_domain_label_num": 13717,
48
+ "prot_homo_classifier_output_size": 4096,
49
+ "prot_homo_label_num": 3443,
50
+ "prot_keyword_classifier_output_size": 2048,
51
+ "prot_keyword_label_num": 1179,
52
+ "prot_mask_classifier_output_size": 2048,
53
+ "prot_mask_label_num": 39,
54
+ "prot_secondary_classifier_output_size": 3072,
55
+ "prot_site_classifier_output_size": 1024,
56
+ "prot_site_label_num": 946,
57
+ "prot_structure_classifier_output_size": 128,
58
+ "prot_structure_label_num": 3,
59
+ "prot_taxonomy_classifier_output_size": 2048,
60
+ "prot_taxonomy_label_num": 2196,
61
+ "sep_token_id": 3,
62
+ "token_dropout": false,
63
+ "torch_dtype": "float32",
64
+ "trans_classifier_output_size": 128,
65
+ "transformers_version": "4.29.0",
66
+ "type_vocab_size": 2,
67
+ "unk_token_id": 1,
68
+ "use_embed_layer_norm": false,
69
+ "use_last_layer_norm": true,
70
+ "vocab_size": 39
71
+ }
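A short sketch of reading these hyper-parameters through `transformers.AutoConfig` (the `auto_map` above routes it to `LucaGPLMConfig`); network access to the Hugging Face Hub is assumed.

```
# Sketch: inspect the checkpoint's hyper-parameters declared in config.json.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Yuanfei/LucaOne", trust_remote_code=True)
print(config.model_type)           # "lucagplm"
print(config.hidden_size)          # 2560
print(config.num_hidden_layers)    # 20
print(config.num_attention_heads)  # 40
print(config.vocab_size)           # 39 (shared nucleotide/amino-acid alphabet)
```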
file_operator.py ADDED
@@ -0,0 +1,230 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ import csv, sys
5
+ import io, textwrap, itertools
6
+ from Bio import SeqIO
7
+ from Bio.Seq import Seq
8
+ from Bio.SeqRecord import SeqRecord
9
+ csv.field_size_limit(sys.maxsize)
10
+
11
+
12
+ common_nucleotide_set = {'A', 'T', 'C', 'G', 'U', 'N'}
13
+
14
+ # not {'O', 'U', 'Z', 'J', 'B'}
15
+ # Common amino acids
16
+ common_amino_acid_set = {'R', 'X', 'S', 'G', 'W', 'I', 'Q', 'A', 'T', 'V', 'K', 'Y', 'C', 'N', 'L', 'F', 'D', 'M', 'P', 'H', 'E'}
17
+
18
+
19
+ def clean_seq(protein_id, seq):
20
+ seq = seq.upper()
21
+ new_seq = ""
22
+ has_invalid_char = False
23
+ invalid_char_set = set()
24
+ for ch in seq:
25
+ if 'A' <= ch <= 'Z' and ch not in ['J']:
26
+ new_seq += ch
27
+ else:
28
+ invalid_char_set.add(ch)
29
+ has_invalid_char = True
30
+ if has_invalid_char:
31
+ print("id: %s. Seq: %s" % (protein_id, seq))
32
+ print("invalid char set:", invalid_char_set)
33
+ return new_seq
34
+
35
+
36
+ def file_reader(filename, header=True, header_filter=True):
37
+ if filename.endswith(".fa") or filename.endswith(".fas") or filename.endswith(".fasta"):
38
+ return fasta_reader(filename)
39
+ elif filename.endswith(".csv"):
40
+ return csv_reader(filename, header=True, header_filter=True)
41
+ elif filename.endswith(".tsv"):
42
+ return tsv_reader(filename, header=True, header_filter=True)
43
+ else:
44
+ return txt_reader(filename, header=header, header_filter=header_filter)
45
+
46
+
47
+ def txt_reader(handle, header=True, header_filter=True):
48
+ '''
49
+ Text reader, suitable for large files
50
+ :param handle:
51
+ :param header:
52
+ :param header_filter: whether to drop the header row from the yielded results
53
+ :return:
54
+ '''
55
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
56
+ try:
57
+ cnt = 0
58
+ for line in handle:
59
+ cnt += 1
60
+ if header and header_filter and cnt == 1:
61
+ continue
62
+ yield line.strip()
63
+ except Exception as e:
64
+ raise e  # PEP 479: StopIteration must not be raised inside a generator
65
+ finally:
66
+ if not handle.closed:
67
+ handle.close()
68
+
69
+
70
+ def tsv_reader(handle, header=True, header_filter=True):
71
+ '''
72
+ TSV reader, suitable for large files
73
+ :param handle:
74
+ :param header:
75
+ :param header_filter: whether to drop the header row from the yielded results
76
+ :return:
77
+ '''
78
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
79
+ try:
80
+ reader = csv.reader(handle, delimiter="\t")
81
+ cnt = 0
82
+ for row in reader:
83
+ cnt += 1
84
+ if header and header_filter and cnt == 1:
85
+ continue
86
+ yield row
87
+ except Exception as e:
88
+ raise e  # PEP 479: StopIteration must not be raised inside a generator
89
+ finally:
90
+ if not handle.closed:
91
+ handle.close()
92
+
93
+
94
+ def csv_reader(handle, header=True, header_filter=True):
95
+ '''
96
+ CSV reader, suitable for large files
97
+ :param handle:
98
+ :param header:
99
+ :param header_filter: whether to drop the header row from the yielded results
100
+ :return:
101
+ '''
102
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
103
+ try:
104
+ # data = csv.reader((line.replace('\0','') for line in data_initial), delimiter=",")
105
+ # reader = csv.reader(handle)
106
+ reader = csv.reader((line.replace('\0', '') for line in handle))
107
+ cnt = 0
108
+ for row in reader:
109
+ cnt += 1
110
+ if header and header_filter and cnt == 1:
111
+ continue
112
+ yield row
113
+ except Exception as e:
114
+ raise e  # PEP 479: StopIteration must not be raised inside a generator
115
+ finally:
116
+ if not handle.closed:
117
+ handle.close()
118
+
119
+
120
+ def txt_writer(dataset, handle, header=None):
121
+ '''
122
+ Write a txt file
123
+ :param dataset: data
124
+ :param handle: file path or handle
125
+ :param header: header
126
+ :return:
127
+ '''
128
+ '''
129
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'w')
130
+ try:
131
+ if header:
132
+ if isinstance(header, list):
133
+ handle.write(",".join(header) + "\n")
134
+ else:
135
+ handle.write(header + "\n")
136
+ print("header: %s" %header)
137
+ for row in dataset:
138
+ handle.write(str(row) + "\n")
139
+ except Exception as e:
140
+ raise e
141
+ finally:
142
+ if not handle.closed:
143
+ handle.close()
144
+ '''
145
+ with open(handle, "w") as wfp:
146
+ if header:
147
+ if isinstance(header, list):
148
+ wfp.write(",".join(header) + "\n")
149
+ else:
150
+ wfp.write(header + "\n")
151
+ for row in dataset:
152
+ wfp.write(str(row) + "\n")
153
+
154
+
155
+ def csv_writer(dataset, handle, header):
156
+ '''
157
+ Write a csv file, suitable for large files
158
+ :param dataset: data
159
+ :param handle: file path or handle
160
+ :param header: header
161
+ :return:
162
+ '''
163
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'w')
164
+ try:
165
+ writer = csv.writer(handle)
166
+ if header:
167
+ writer.writerow(header)
168
+ for row in dataset:
169
+ writer.writerow(row)
170
+ except Exception as e:
171
+ raise e
172
+ finally:
173
+ if not handle.closed:
174
+ handle.close()
175
+
176
+
177
+ def fasta_reader(handle, width=None):
178
+ """
179
+ Reads a FASTA file, yielding (header, sequence) pairs for each sequence recovered; suitable for large files
180
+ args:
181
+ :handle (str, pathlib.Path, or file pointer) - fasta to read from
182
+ :width (int or None) - formats the sequence to have max `width` character per line.
183
+ If <= 0, processed as None. If None, there is no max width.
184
+ yields:
185
+ :(header, sequence) tuples
186
+ returns:
187
+ :None
188
+ """
189
+ FASTA_STOP_CODON = "*"
190
+
191
+ handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
192
+ width = width if isinstance(width, int) and width > 0 else None
193
+ try:
194
+ header = None
195
+ for is_header, group in itertools.groupby(handle, lambda line: line.startswith(">")):
196
+ if is_header:
197
+ header = group.__next__().strip()
198
+ else:
199
+ seq = ''.join(line.strip() for line in group).strip().rstrip(FASTA_STOP_CODON)
200
+ if width is not None:
201
+ seq = textwrap.fill(seq, width)
202
+ yield header, seq
203
+ except Exception as e:
204
+ raise e  # PEP 479: StopIteration must not be raised inside a generator
205
+ finally:
206
+ if not handle.closed:
207
+ handle.close()
208
+
209
+
210
+ def write_fasta(filepath, sequences):
211
+ '''
212
+ write fasta file
213
+ :param filepath: savepath
214
+ :param sequences: fasta sequence(each item: [id, seq])
215
+ :return:
216
+ '''
217
+
218
+ if sequences:
219
+ with open(filepath, "w") as output_handle:
220
+ if len(sequences[0]) > 1 and isinstance(sequences[0][0], str):
221
+ for row in sequences:
222
+ protein_id = row[0]
223
+ seq = row[1]
224
+ sequence = SeqRecord(Seq(seq, None), id=protein_id[1:] if protein_id and protein_id[0] == ">" else protein_id, description="")
225
+ SeqIO.write(sequence, output_handle, "fasta")
226
+ else:
227
+ for sequence in sequences:
228
+ SeqIO.write(sequence, output_handle, "fasta")
229
+
230
+
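A hedged usage sketch for the readers/writers above, assuming `file_operator.py` is on the Python path and Biopython is installed; the input and output file names are hypothetical.

```
# Sketch: clean every sequence in a FASTA file and write the result back out.
from file_operator import fasta_reader, write_fasta, clean_seq

records = []
for header, seq in fasta_reader("proteins.fasta"):    # hypothetical input file
    seq_id = header[1:].split()[0]                    # drop ">" and the description
    records.append([seq_id, clean_seq(seq_id, seq)])  # each item: [id, seq]

write_fasta("proteins.cleaned.fasta", records)        # hypothetical output file
```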
loss.py ADDED
@@ -0,0 +1,224 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: sanyuan.hy@alibaba-inc.com
7
+ @tel: 137****6540
8
+ @datetime: 2023/5/3 20:35
9
+ @project: LucaOne
10
+ @file: loss.py
11
+ @desc: loss
12
+ '''
13
+ import torch, math
14
+ import torch.nn as nn
15
+
16
+ from .classification_loss import *
17
+ from .regression_loss import *
18
+
19
+
20
+
21
+ class NewGELUActivation(nn.Module):
22
+ """
23
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
24
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
25
+ """
26
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
27
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
28
+
29
+
30
+ def create_activate(activate_func):
31
+ if activate_func:
32
+ activate_func = activate_func.lower()
33
+ if activate_func == "tanh":
34
+ return nn.Tanh()
35
+ elif activate_func == "relu":
36
+ return nn.ReLU()
37
+ elif activate_func == "leakyrelu":
38
+ return nn.LeakyReLU()
39
+ elif activate_func == "gelu":
40
+ return nn.GELU()
41
+ elif activate_func == "gelu_new":
42
+ return NewGELUActivation()
43
+ else:
44
+ return nn.Tanh()
45
+
46
+
47
+ def create_loss_function(config,
48
+ args,
49
+ task_level_type,
50
+ task_level_name,
51
+ sigmoid,
52
+ output_mode,
53
+ num_labels,
54
+ loss_type,
55
+ ignore_index=-100,
56
+ pair_level=False,
57
+ return_types=["dropout", "hidden_layer", "hidden_act", "classifier", "output", "loss"]
58
+ ):
59
+ '''
60
+ create the output layer and loss layer
61
+ :param task_level_name:
62
+ :param task_level_type:
63
+ :param pair_level:
64
+ :param config:
65
+ :param args:
66
+ :param sigmoid:
67
+ :param output_mode:
68
+ :param num_labels:
69
+ :param loss_type:
70
+ :param ignore_index:
71
+ :param return_types:
72
+ :return:
73
+ '''
74
+ dropout, hidden_layer, hidden_act, classifier, output, loss_fct = None, None, None, None, None, None
75
+ if "dropout" in return_types:
76
+ if hasattr(config, "classifier_dropout_prob"):
77
+ dropout = nn.Dropout(config.classifier_dropout_prob)
78
+ elif hasattr(config, "dropout_prob"):
79
+ dropout = nn.Dropout(config.dropout_prob)
80
+ else:
81
+ dropout = nn.Dropout(0.1)
82
+
83
+ if pair_level:
84
+ hidden_size = 2 * config.hidden_size
85
+ else:
86
+ hidden_size = config.hidden_size
87
+ if "hidden_layer" in return_types:
88
+ if isinstance(args.classifier_size, int):
89
+ hidden_layer_size = args.classifier_size
90
+ else:
91
+ hidden_layer_size = args.classifier_size[task_level_type][task_level_name]
92
+ hidden_layer = nn.Linear(hidden_size, hidden_layer_size, bias=True)
93
+ hidden_size = hidden_layer_size
94
+
95
+ if "hidden_act" in return_types:
96
+ if hasattr(args, "classifier_hidden_act"):
97
+ hidden_act = create_activate(args.classifier_hidden_act)
98
+ elif hasattr(config, "classifier_hidden_act"):
99
+ hidden_act = create_activate(config.classifier_hidden_act)
100
+
101
+ if "classifier" in return_types:
102
+ if sigmoid:
103
+ if output_mode in ["binary_class", "binary-class"]:
104
+ classifier = nn.Linear(hidden_size, 1, bias=True)
105
+ else:
106
+ classifier = nn.Linear(hidden_size, num_labels, bias=True)
107
+ else:
108
+ classifier = nn.Linear(hidden_size, num_labels, bias=True)
109
+ if "output" in return_types:
110
+ if sigmoid or output_mode in ["multi_label", "multi-label", "binary_class", "binary-class"]:
111
+ output = nn.Sigmoid()
112
+ elif output_mode in ["multi_class", "multi-class"]:
113
+ output = nn.Softmax(dim=-1)
114
+ else:
115
+ output = None
116
+
117
+ if "loss" in return_types:
118
+ # positive weight
119
+ if hasattr(args, "pos_weight") and args.pos_weight:
120
+ pos_weight = args.pos_weight
121
+ elif hasattr(config, "pos_weight") and config.pos_weight:
122
+ pos_weight = config.pos_weight
123
+ else:
124
+ pos_weight = None
125
+
126
+ if hasattr(args, "weight") and args.weight is not None:
127
+ weight = args.weight
128
+ elif hasattr(config, "weight") and config.weight is not None:
129
+ weight = config.weight
130
+ else:
131
+ weight = None
132
+
133
+ reduction = config.loss_reduction if hasattr(config, "loss_reduction") else "meanmean"
134
+ if output_mode in ["regression"]:
135
+ if loss_type == "l2":
136
+ loss_fct = MaskedMSELoss(reduction=reduction, ignore_nans=True,
137
+ ignore_value=ignore_index * 1.0 if ignore_index else None)
138
+ elif loss_type == "l1":
139
+ loss_fct = MaskedL1Loss(reduction=reduction, ignore_nans=True,
140
+ ignore_value=ignore_index * 1.0 if ignore_index else None)
141
+ elif output_mode in ["multi_label", "multi-label"]:
142
+ if loss_type == "bce":
143
+ if pos_weight:
144
+ if isinstance(pos_weight, str) or isinstance(pos_weight, int):
145
+ pos_weight = [float(pos_weight)] * num_labels
146
+ elif isinstance(pos_weight, float):
147
+ pos_weight = [pos_weight] * num_labels
148
+ pos_weight = torch.tensor(pos_weight, dtype=torch.float32).to(args.device)
149
+ print("multi_label pos_weight:")
150
+ print(pos_weight)
151
+ assert pos_weight.ndim == 1 and pos_weight.shape[0] == num_labels
152
+ print("multi_label reduction:")
153
+ print(reduction)
154
+ loss_fct = MaskedBCEWithLogitsLoss(pos_weight=pos_weight, reduction=reduction,
155
+ ignore_nans=True, ignore_value=ignore_index)
156
+ else:
157
+ loss_fct = MaskedBCEWithLogitsLoss(reduction=reduction,
158
+ ignore_nans=True, ignore_value=ignore_index)
159
+ elif loss_type == "asl":
160
+ loss_fct = MaskedAsymmetricLossOptimized(gamma_neg=args.asl_gamma_neg if hasattr(args, "asl_gamma_neg") else 4.0,
161
+ gamma_pos=args.asl_gamma_pos if hasattr(args, "asl_gamma_pos") else 1.0,
162
+ clip=args.clip if hasattr(args, "clip") else 0.05,
163
+ eps=args.eps if hasattr(args, "eps") else 1e-8,
164
+ disable_torch_grad_focal_loss=args.disable_torch_grad_focal_loss if hasattr(args, "disable_torch_grad_focal_loss") else False,
165
+ reduction=reduction,
166
+ ignore_nans=True,
167
+ ignore_value=ignore_index)
168
+ elif loss_type == "focal_loss":
169
+ loss_fct = MaskedFocalLoss(alpha=args.focal_loss_alpha if hasattr(args, "focal_loss_alpha") else 0.7,
170
+ gamma=args.focal_loss_gamma if hasattr(args, "focal_loss_gamma") else 2.0,
171
+ normalization=True,
172
+ reduction=reduction,
173
+ ignore_nans=True,
174
+ ignore_value=ignore_index)
175
+ elif loss_type == "multilabel_cce":
176
+ loss_fct = MaskedMultiLabelCCE(normalization=True,
177
+ reduction=reduction,
178
+ ignore_nans=True,
179
+ ignore_value=ignore_index)
180
+ elif output_mode in ["binary_class", "binary-class"]:
181
+ if loss_type == "bce":
182
+ if pos_weight:
183
+ if isinstance(pos_weight, int) or isinstance(pos_weight, str):
184
+ pos_weight = torch.tensor([float(pos_weight)], dtype=torch.float32).to(args.device)
185
+ elif isinstance(pos_weight, float):
186
+ pos_weight = torch.tensor([pos_weight], dtype=torch.float32).to(args.device)
187
+ print("binary_class pos_weight:")
188
+ print(pos_weight)
189
+ assert pos_weight.ndim == 1 and pos_weight.shape[0] == 1
190
+ loss_fct = MaskedBCEWithLogitsLoss(pos_weight=pos_weight, reduction=reduction, ignore_nans=True,
191
+ ignore_value=ignore_index)
192
+ else:
193
+ loss_fct = MaskedBCEWithLogitsLoss(reduction=reduction, ignore_nans=True, ignore_value=ignore_index)
194
+ elif loss_type == "focal_loss":
195
+ loss_fct = MaskedFocalLoss(alpha=args.focal_loss_alpha if hasattr(args, "focal_loss_alpha") else 0.7,
196
+ gamma=args.focal_loss_gamma if hasattr(args, "focal_loss_gamma") else 2.0,
197
+ normalization=True,
198
+ reduction=reduction,
199
+ ignore_nans=True,
200
+ ignore_value=ignore_index)
201
+ elif output_mode in ["multi_class", "multi-class"]:
202
+ if weight:
203
+ # [1, 1, 1, ,1, 1...] length: num_labels
204
+ if isinstance(weight, str) or isinstance(weight, int):
205
+ weight = [float(weight)] * num_labels
206
+ if isinstance(weight, float):
207
+ weight = [weight] * num_labels
208
+ weight = torch.tensor(weight, dtype=torch.float32).to(args.device)
209
+ print("multi_class weight:")
210
+ print(weight)
211
+ assert weight.ndim == 1 and weight.shape[0] == num_labels
212
+ if ignore_index is None:
213
+ loss_fct = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
214
+ else:
215
+ loss_fct = MaskedCrossEntropyLoss(weight=weight, reduction=reduction, ignore_nans=True, ignore_value=ignore_index)
216
+ else:
217
+ if ignore_index is None:
218
+ loss_fct = nn.CrossEntropyLoss(reduction=reduction)
219
+ else:
220
+ loss_fct = MaskedCrossEntropyLoss(reduction=reduction, ignore_nans=True, ignore_value=ignore_index)
221
+ else:
222
+ raise Exception("Not support output mode: %s." % output_mode)
223
+
224
+ return dropout, hidden_layer, hidden_act, classifier, output, loss_fct
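A hedged sketch of calling `create_loss_function` for a sequence-level multi-class head. Here `config` and `args` are stand-ins (`SimpleNamespace`) for the real training objects, the package name `lucaone_hf` is hypothetical, and the task name and label count simply mirror the `prot_taxonomy` entries in config.json.

```
# Hedged sketch; package name and the stand-in config/args objects are assumptions.
from types import SimpleNamespace
from lucaone_hf.loss import create_loss_function

config = SimpleNamespace(hidden_size=2560,
                         classifier_dropout_prob=0.1,
                         classifier_hidden_act="gelu")
args = SimpleNamespace(classifier_size=2048,
                       classifier_hidden_act="gelu",
                       device="cpu")

dropout, hidden_layer, hidden_act, classifier, output, loss_fct = create_loss_function(
    config, args,
    task_level_type="seq_level",
    task_level_name="prot_taxonomy",
    sigmoid=False,
    output_mode="multi_class",
    num_labels=2196,
    loss_type="cce",
    ignore_index=-100,
)
# dropout -> nn.Dropout(0.1), hidden_layer -> nn.Linear(2560, 2048),
# classifier -> nn.Linear(2048, 2196), output -> nn.Softmax,
# loss_fct -> MaskedCrossEntropyLoss (requires masked_loss.py from the repo)
```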
lucaone_gplm.py ADDED
@@ -0,0 +1,605 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ from .loss import *
5
+ from .model_utils import AllOutput, create_output_loss_lucagplm
6
+ from .alphabet import Alphabet
7
+ from .modeling_gplm import *
8
+ from .lucaone_gplm_config import LucaGPLMConfig
9
+ from transformers import PreTrainedModel
10
+
11
+ class LucaGPLM(PreTrainedModel):
12
+ config_class = LucaGPLMConfig
13
+
14
+ def __init__(self, config):
15
+ super().__init__(config)
16
+ self.config = config
17
+ self.max_position_embeddings = config.max_position_embeddings
18
+ self.type_vocab_size = config.type_vocab_size
19
+ self.num_layers = config.num_hidden_layers
20
+ self.embed_dim = config.hidden_size
21
+ self.attention_heads = config.num_attention_heads
22
+ self.no_position_embeddings = config.no_position_embeddings
23
+ self.no_token_type_embeddings = config.no_token_type_embeddings
24
+ if not isinstance(config.alphabet, Alphabet):
25
+ self.alphabet = Alphabet.from_predefined(config.alphabet)
26
+ else:
27
+ self.alphabet = config.alphabet
28
+ self.alphabet_size = len(self.alphabet)
29
+ self.padding_idx = self.alphabet.padding_idx
30
+ self.mask_idx = self.alphabet.mask_idx
31
+ self.cls_idx = self.alphabet.cls_idx
32
+ self.eos_idx = self.alphabet.eos_idx
33
+ self.prepend_bos = self.alphabet.prepend_bos
34
+ self.append_eos = self.alphabet.append_eos
35
+ self.token_dropout = config.token_dropout
36
+ self.ignore_index = config.ignore_index
37
+ self.use_embed_layer_norm = config.use_embed_layer_norm
38
+ self.use_last_layer_norm = config.use_last_layer_norm
39
+ self.embed_scale = config.embed_scale
40
+ self.embedding_inference = True
41
+ self._init_submodules()
42
+
43
+ def _init_submodules(self):
44
+ # normal_(0, 1)
45
+ self.embed_tokens = nn.Embedding(
46
+ self.alphabet_size,
47
+ self.embed_dim,
48
+ padding_idx=self.padding_idx,
49
+ )
50
+ self.embed_pos = None
51
+ if not self.no_position_embeddings:
52
+ self.embed_pos = nn.Embedding(self.max_position_embeddings, self.embed_dim)
53
+ self.embed_type = None
54
+ if not self.no_token_type_embeddings:
55
+ self.embed_type = nn.Embedding(self.type_vocab_size, self.embed_dim)
56
+ if self.use_embed_layer_norm:
57
+ self.embed_layer_norm = LucaGPLM1bLayerNorm(self.embed_dim)
58
+ else:
59
+ self.embed_layer_norm = None
60
+
61
+ self.layers = nn.ModuleList(
62
+ [
63
+ LucaGPLMTransformerLayer(
64
+ self.embed_dim,
65
+ 4 * self.embed_dim,
66
+ self.attention_heads,
67
+ add_bias_kv=False,
68
+ use_lucagplm1b_layer_norm=True,
69
+ use_rotary_embeddings=True,
70
+ )
71
+ for _ in range(self.num_layers)
72
+ ]
73
+ )
74
+ self.layer_size = len(self.layers)
75
+
76
+ if not self.embedding_inference:
77
+ self.contact_head = ContactPredictionHead(
78
+ self.num_layers * self.attention_heads,
79
+ self.prepend_bos,
80
+ self.append_eos,
81
+ eos_idx=self.eos_idx,
82
+ )
83
+ if self.use_last_layer_norm:
84
+ self.last_layer_norm = LucaGPLM1bLayerNorm(self.embed_dim)
85
+ else:
86
+ self.last_layer_norm = None
87
+ if not self.embedding_inference:
88
+ self.lm_head = RobertaLMHead(
89
+ embed_dim=self.embed_dim,
90
+ output_dim=self.alphabet_size,
91
+ weight=self.embed_tokens.weight,
92
+ )
93
+
94
+ def _init_embedding(self, pretrained_token_matrix, token_matrix):
95
+ '''
96
+ 0->2
97
+ 1->0
98
+ 2->3
99
+ 3->1
100
+ 4->10
101
+ ...
102
+ 28->34
103
+ 29->36
104
+ 30->37
105
+ 31->38
106
+ 32->4
107
+ '''
108
+ # print("Load pretrained exists embedding vectors:")
109
+ token_matrix[2, :] = pretrained_token_matrix[0, :]
110
+ token_matrix[0, :] = pretrained_token_matrix[1, :]
111
+ token_matrix[3, :] = pretrained_token_matrix[2, :]
112
+ token_matrix[1, :] = pretrained_token_matrix[3, :]
113
+ for idx in range(10, 35):
114
+ token_matrix[idx, :] = pretrained_token_matrix[idx - 6, :]
115
+ token_matrix[36, :] = pretrained_token_matrix[29, :]
116
+ token_matrix[37, :] = pretrained_token_matrix[30, :]
117
+ token_matrix[38, :] = pretrained_token_matrix[31, :]
118
+ token_matrix[4, :] = pretrained_token_matrix[32, :]
119
+ return token_matrix
120
+
121
+ def _init_submodules_new(self, pretrained_model_name):
122
+ # print("Load pretrained model exists weights:")
123
+ from esm import pretrained
124
+ from collections import OrderedDict
125
+ pretrained, _ = pretrained.load_model_and_alphabet(pretrained_model_name)
126
+ pretrained_state_dict = pretrained.state_dict()
127
+ new_state_dict = OrderedDict()
128
+ our_model_state_dict = {}
129
+ for key, value in self.state_dict().items():
130
+ our_model_state_dict[key] = value
131
+ for name, weight in pretrained_state_dict.items():
132
+ if "final_layer_norm" in name:
133
+ name = name.replace("final_layer_norm", "post_layer_norm")
134
+ elif "self_attn_layer_norm" in name:
135
+ name = name.replace("self_attn_layer_norm", "pre_layer_norm")
136
+ elif "emb_layer_norm_after" in name:
137
+ name = name.replace("emb_layer_norm_after", "last_layer_norm")
138
+ if name.startswith("layers."):
139
+ layer_id = name.split(".")[1]
140
+ if int(layer_id) >= self.num_layers:
141
+ continue
142
+ if name == "embed_tokens.weight":
143
+ new_state_dict[name] = self._init_embedding(weight, our_model_state_dict[name])
144
+ del our_model_state_dict[name]
145
+ elif name in our_model_state_dict and our_model_state_dict[name].shape == weight.shape:
146
+ del our_model_state_dict[name]
147
+ new_state_dict[name] = weight
148
+ '''
149
+ print("Exists layer names:")
150
+ print(new_state_dict.keys())
151
+ print("Not exists Layer names:")
152
+ print(our_model_state_dict.keys())
153
+ '''
154
+ new_state_dict.update(our_model_state_dict)
155
+ self.load_state_dict(new_state_dict)
156
+
157
+ def __calc_loss__(self, task_level_type, output_mode, logits, label, label_size, loss_fct, loss_reduction):
158
+ if output_mode in ["regression"]:
159
+ if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
160
+ # structure-level regression
161
+ # logits: N, seq_len, 3
162
+ # label: N, seq_len, 3
163
+ loss = loss_fct(logits, label)
164
+ else:
165
+ # structure-level regression
166
+ # logits: N * seq_len * 3
167
+ # label: N * seq_len * 3
168
+ loss = loss_fct(logits.view(-1), label.view(-1))
169
+ elif output_mode in ["multi_label", "multi-label"]:
170
+ # only for seq-level
171
+ if loss_reduction == "meanmean":
172
+ # logits: N , label_size
173
+ # label: N , label_size
174
+ loss = loss_fct(logits, label.float())
175
+ else:
176
+ # logits: N , label_size
177
+ # label: N , label_size
178
+ loss = loss_fct(logits.view(-1, label_size), label.view(-1, label_size).float())
179
+ elif label_size <= 2 or output_mode in ["binary_class", "binary-class"]:
180
+ if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
181
+ # token-level & meanmean
182
+ # logits: N ,seq_len, 1
183
+ # label: N, seq_len
184
+ loss = loss_fct(logits, label.float())
185
+ else:
186
+ # seq-level || token-level
187
+ # logits: N
188
+ # label: N
189
+ loss = loss_fct(logits.view(-1), label.view(-1).float())
190
+ elif output_mode in ["multi_class", "multi-class"]:
191
+ if task_level_type not in ["seq_level"] and loss_reduction == "meanmean":
192
+ # token-level
193
+ # logits: N ,seq_len, label_size
194
+ # label: N , seq_len
195
+ loss = loss_fct(logits, label)
196
+ else:
197
+ # token-level
198
+ # logits: N * seq_len, label_size
199
+ # label: N * seq_len
200
+ # seq-level
201
+ # logits: N, label_size
202
+ # label: N
203
+ loss = loss_fct(logits.view(-1, label_size), label.view(-1))
204
+ else:
205
+ raise Exception("Not support output_mode=%s" % output_mode)
206
+ return loss
207
+
208
+ def __forword__(self,
209
+ input_ids: Optional[torch.Tensor] = None,
210
+ attention_mask: Optional[torch.Tensor] = None,
211
+ token_type_ids: Optional[torch.Tensor] = None,
212
+ position_ids: Optional[torch.Tensor] = None,
213
+ output_keys: Optional[dict[str, set[str]]] = None,
214
+ labels: Optional[dict[str, dict[str, torch.Tensor]]] = None,
215
+ repr_layers=[-1],
216
+ need_head_weights=False,
217
+ return_contacts=False,
218
+ use_last_layer_norm=True):
219
+ assert all(-(self.layer_size + 1) <= i <= self.layer_size for i in repr_layers)
220
+ repr_layers = [(i + self.layer_size + 1) % (self.layer_size + 1) for i in repr_layers]
221
+
222
+ if return_contacts:
223
+ need_head_weights = True
224
+
225
+ assert input_ids.ndim == 2
226
+ # Build the padding mask dynamically: (B, seq_len), True at padded positions
227
+ if attention_mask is None:
228
+ padding_mask = input_ids.eq(self.padding_idx)
229
+ else:
230
+ padding_mask = attention_mask.eq(self.padding_idx)
231
+
232
+ x = self.embed_scale * self.embed_tokens(input_ids)
233
+ if self.embed_pos is not None and position_ids is not None:
234
+ x += self.embed_scale * self.embed_pos(position_ids)
235
+ if self.embed_type is not None and token_type_ids is not None:
236
+ x += self.embed_scale * self.embed_type(token_type_ids)
237
+ if self.embed_layer_norm is not None:
238
+ x = self.embed_layer_norm(x)
239
+ # Token dropout
240
+ if self.token_dropout:
241
+ x.masked_fill_((input_ids == self.mask_idx).unsqueeze(-1), 0.0)
242
+ # x: B x L x C
243
+ mask_ratio_train = 0.15 * 0.8
244
+ src_lengths = (~padding_mask).sum(-1)
245
+ mask_ratio_observed = (input_ids == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
246
+ x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
247
+
248
+ # Apply the padding mask
249
+ if padding_mask is not None:
250
+ x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
251
+
252
+ # Which layers to include in the returned representations
253
+ repr_layers = set(repr_layers)
254
+ hidden_representations = {}
255
+ # 0:embedding
256
+ if 0 in repr_layers:
257
+ hidden_representations[0] = x
258
+
259
+ # Whether attention head weights need to be returned
260
+ if need_head_weights:
261
+ attn_weights = []
262
+
263
+ # (B, L, E) => (L, B, E)
264
+ x = x.transpose(0, 1)
265
+
266
+ if not padding_mask.any():
267
+ padding_mask = None
268
+
269
+ for layer_idx, layer in enumerate(self.layers):
270
+ x, attn = layer(
271
+ x,
272
+ self_attn_padding_mask=padding_mask,
273
+ need_head_weights=need_head_weights,
274
+ )
275
+ if (layer_idx + 1) in repr_layers:
276
+ hidden_representations[layer_idx + 1] = x.transpose(0, 1)
277
+ if need_head_weights:
278
+ # (H, B, L, L) => (B, H, L, L)
279
+ attn_weights.append(attn.transpose(1, 0))
280
+
281
+ # (L, B, E)
282
+ if self.last_layer_norm is not None and use_last_layer_norm:
283
+ # Apply a final layer norm to the last hidden layer
284
+ x = self.last_layer_norm(x)
285
+ x = x.transpose(0, 1) # (L, B, E) => (B, L, E)
286
+
287
+ # last hidden representation should have layer norm applied
288
+ if (layer_idx + 1) in repr_layers:
289
+ hidden_representations[layer_idx + 1] = x
290
+ # The last layer serves as the representation matrix
291
+ # (B, L, E)
292
+ representation_matrix = hidden_representations[self.layer_size]
293
+ # Masked LM task
294
+ # B * Seq_len * vocab_size
295
+ if not self.embedding_inference:
296
+ lm_mask_logits = self.lm_head(x)
297
+ # The first ([CLS]) position of the representation matrix serves as the representation vector
298
+ # (B, E)
299
+ representation_vector = representation_matrix[:, 0, :]
300
+
301
+ logits = {}
302
+ losses = {}
303
+ outputs = {}
304
+ representations = {
305
+ "representation_matrix": representation_matrix,
306
+ "representation_vector": representation_vector
307
+ }
308
+ # Attention weights of every layer
309
+ if need_head_weights:
310
+ # attentions: B x Layers x H x L x L
311
+ attentions = torch.stack(attn_weights, 1)
312
+ if padding_mask is not None:
313
+ attention_mask = 1 - padding_mask.type_as(attentions)
314
+ attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
315
+ attentions = attentions * attention_mask[:, None, None, :, :]
316
+ representations["attentions"] = attentions
317
+ # Predict the contact matrix
318
+ if return_contacts and hasattr(self, "contact_head") \
319
+ and not self.embedding_inference:
320
+ contacts = self.contact_head(input_ids, attentions)
321
+ representations["contacts"] = contacts
322
+ '''
323
+ print("output_keys:")
324
+ print(output_keys)
325
+ '''
326
+ if not self.embedding_inference and output_keys:
327
+ for item in output_keys.items():
328
+ cur_task_level_type = item[0]
329
+ if cur_task_level_type not in logits:
330
+ logits[cur_task_level_type] = {}
331
+ outputs[cur_task_level_type] = {}
332
+ for cur_task_level_name in item[1]:
333
+ if cur_task_level_type == "token_level":
334
+ cur_logits = lm_mask_logits
335
+ elif cur_task_level_type == "seq_level":
336
+ cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](representation_vector)
337
+ cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
338
+ if cur_hidden_layer is not None:
339
+ cur_logits = cur_hidden_layer(cur_logits)
340
+ cur_hidden_act = self.hidden_act[cur_task_level_type][cur_task_level_name]
341
+ if cur_hidden_act is not None:
342
+ cur_logits = cur_hidden_act(cur_logits)
343
+ cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
344
+ elif cur_task_level_type == "span_level":
345
+ cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](representation_matrix)
346
+ cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
347
+ if cur_hidden_layer is not None:
348
+ cur_logits = cur_hidden_layer(cur_logits)
349
+ cur_hidden_act = self.hidden_act[cur_task_level_type][cur_task_level_name]
350
+ if cur_hidden_act is not None:
351
+ cur_logits = cur_hidden_act(cur_logits)
352
+ cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
353
+ elif cur_task_level_type == "structure_level":
354
+ cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](representation_matrix)
355
+ cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
356
+ if cur_hidden_layer is not None:
357
+ cur_logits = cur_hidden_layer(cur_logits)
358
+ cur_hidden_act = self.hidden_act[cur_task_level_type][cur_task_level_name]
359
+ if cur_hidden_act is not None:
360
+ cur_logits = cur_hidden_act(cur_logits)
361
+ cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
362
+ logits[cur_task_level_type][cur_task_level_name] = cur_logits
363
+ if cur_task_level_type in self.output and cur_task_level_name in self.output[cur_task_level_type] \
364
+ and self.output[cur_task_level_type][cur_task_level_name] is not None:
365
+ outputs[cur_task_level_type][cur_task_level_name] = self.output[cur_task_level_type][cur_task_level_name](cur_logits)
366
+ else:
367
+ outputs[cur_task_level_type][cur_task_level_name] = cur_logits
368
+ if labels is not None and cur_task_level_type in labels and cur_task_level_name in labels[cur_task_level_type]:
369
+ if cur_task_level_type not in losses:
370
+ losses[cur_task_level_type] = {}
371
+ cur_label = labels[cur_task_level_type][cur_task_level_name]
372
+ cur_label_size = self.label_size[cur_task_level_type][cur_task_level_name]
373
+ cur_output_mode = self.output_mode[cur_task_level_type][cur_task_level_name]
374
+ cur_loss_fct = self.loss_fct[cur_task_level_type][cur_task_level_name]
375
+ cur_loss = self.__calc_loss__(
376
+ task_level_type=cur_task_level_type,
377
+ output_mode=cur_output_mode,
378
+ logits=cur_logits,
379
+ label=cur_label,
380
+ label_size=cur_label_size,
381
+ loss_fct=cur_loss_fct,
382
+ loss_reduction="meanmean")
383
+ losses[cur_task_level_type][cur_task_level_name] = cur_loss
384
+ return representations, logits, outputs, losses
385
+
386
+ def forward(
387
+ self,
388
+ input_ids: Optional[torch.Tensor] = None,
389
+ attention_mask: Optional[torch.Tensor] = None,
390
+ global_attention_mask: Optional[torch.Tensor] = None,
391
+ token_type_ids: Optional[torch.Tensor] = None,
392
+ position_ids: Optional[torch.Tensor] = None,
393
+ head_mask: Optional[torch.Tensor] = None,
394
+ inputs_embeds: Optional[torch.Tensor] = None,
395
+ output_keys: Optional[dict[str, set[str]]] = None,
396
+ labels: Optional[dict[str, dict[str, torch.Tensor]]] = None,
397
+ input_ids_b: Optional[torch.Tensor] = None,
398
+ attention_mask_b: Optional[torch.Tensor] = None,
399
+ global_attention_mask_b: Optional[torch.Tensor] = None,
400
+ token_type_ids_b: Optional[torch.Tensor] = None,
401
+ position_ids_b: Optional[torch.Tensor] = None,
402
+ head_mask_b: Optional[torch.Tensor] = None,
403
+ inputs_embeds_b: Optional[torch.Tensor] = None,
404
+ output_keys_b: Optional[dict[str, set[str]]] = None,
405
+ labels_b: Optional[dict[str, dict[str, torch.Tensor]]] = None,
406
+ pair_label: Optional[dict[str, dict[str, torch.Tensor]]] = None,
407
+ pair_output_keys: Optional[dict[str, set[str]]] = None,
408
+ output_hidden_states: Optional[dict[str, set[str]]] = None,
409
+ output_attentions: Optional[dict[str, set[str]]] = None,
410
+ need_head_weights: Optional[bool] = None,
411
+ return_contacts: Optional[bool] = None,
412
+ repr_layers: Optional[list[int]] = None,
413
+ return_dict: Optional[bool] = None,
414
+ use_last_layer_norm: Optional[bool] = True
415
+ ) -> Union[Tuple[torch.Tensor], AllOutput]:
416
+ if return_dict is None and self.config is not None:
417
+ return_dict = self.config.use_return_dict
418
+ if return_dict is None:
419
+ return_dict = False
420
+ if repr_layers is None or len(repr_layers) == 0:
421
+ repr_layers = [-1]
422
+ if return_contacts is None:
423
+ return_contacts = False
424
+ if need_head_weights is None:
425
+ need_head_weights = True
426
+ has_pair = False
427
+ has_pair_b = False
428
+ if input_ids is not None or inputs_embeds is not None:
429
+ encoding, logits, outputs, losses = self.__forword__(
430
+ input_ids=input_ids,
431
+ attention_mask=attention_mask,
432
+ token_type_ids=token_type_ids,
433
+ position_ids=position_ids,
434
+ output_keys=output_keys,
435
+ labels=labels,
436
+ repr_layers=repr_layers,
437
+ need_head_weights=need_head_weights,
438
+ return_contacts=return_contacts,
439
+ use_last_layer_norm=use_last_layer_norm
440
+ )
441
+ has_pair = True
442
+ if input_ids_b is not None or inputs_embeds_b is not None:
443
+ encoding_b, logits_b, outputs_b, losses_b = self.__forword__(
444
+ input_ids=input_ids_b,
445
+ attention_mask=attention_mask_b,
446
+ token_type_ids=token_type_ids_b,
447
+ position_ids=position_ids_b,
448
+ output_keys=output_keys_b,
449
+ labels=labels_b,
450
+ repr_layers=repr_layers,
451
+ need_head_weights=need_head_weights,
452
+ return_contacts=return_contacts,
453
+ use_last_layer_norm=use_last_layer_norm
454
+ )
455
+ has_pair_b = True
456
+
457
+ if not self.embedding_inference:
458
+ if has_pair and has_pair_b and pair_output_keys and len(pair_output_keys) > 0:
459
+ cur_representation_vector = encoding["representation_vector"]
460
+ cur_representation_vector_b = encoding_b["representation_vector"]
461
+
462
+ pair_logits = {}
463
+ pair_outputs = {}
464
+ for item1 in pair_output_keys.items():
465
+ cur_task_level_type = item1[0]
466
+ if cur_task_level_type not in pair_outputs:
467
+ pair_outputs[cur_task_level_type] = {}
468
+ pair_logits[cur_task_level_type] = {}
469
+ for cur_task_level_name in item1[1]:
470
+ cur_logits = self.classifier_dropout[cur_task_level_type][cur_task_level_name](
471
+ torch.cat((cur_representation_vector, cur_representation_vector_b), dim=-1)
472
+ )
473
+ cur_hidden_layer = self.hidden_layer[cur_task_level_type][cur_task_level_name]
474
+ if cur_hidden_layer is not None:
475
+ cur_logits = cur_hidden_layer(cur_logits)
476
+ cur_logits = self.classifier[cur_task_level_type][cur_task_level_name](cur_logits)
477
+ pair_logits[cur_task_level_type][cur_task_level_name] = cur_logits
478
+ pair_outputs[cur_task_level_type][cur_task_level_name] = self.output[cur_task_level_type][cur_task_level_name](cur_logits)
479
+
480
+ pair_loss = {}
+ if pair_label is not None:
481
+ pair_loss = {}
482
+ for item1 in pair_output_keys.items():
483
+ cur_task_level_type = item1[0]
484
+ if cur_task_level_type not in pair_label:
485
+ continue
486
+ if cur_task_level_type in pair_label:
487
+ pair_loss[cur_task_level_type] = {}
488
+ for cur_task_level_name in item1[1]:
489
+ if cur_task_level_name not in pair_label[cur_task_level_type]:
490
+ continue
491
+ cur_label = pair_label[cur_task_level_type][cur_task_level_name]
492
+ cur_label_size = self.label_size[cur_task_level_type][cur_task_level_name]
493
+ cur_output_mode = self.output_mode[cur_task_level_type][cur_task_level_name]
494
+ cur_loss_fct = self.loss_fct[cur_task_level_type][cur_task_level_name]
495
+ cur_logits = pair_logits[cur_task_level_type][cur_task_level_name]
496
+ cur_loss = self.__calc_loss__(
497
+ task_level_type=cur_task_level_type,
498
+ output_mode=cur_output_mode, logits=cur_logits,
499
+ label=cur_label, label_size=cur_label_size, loss_fct=cur_loss_fct,
500
+ loss_reduction="meanmean")
501
+ pair_loss[cur_task_level_type][cur_task_level_name] = cur_loss
502
+
503
+ if not return_dict:
504
+ return [[losses, losses_b, pair_loss], [outputs, outputs_b, pair_outputs]] + [[encoding, encoding_b]]
505
+ return AllOutput(
506
+ losses=losses,
507
+ outputs=outputs,
508
+ hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
509
+ attentions=encoding["attentions"] if "attentions" in encoding else None,
510
+ global_attentions=None,
511
+ contacts=encoding["contacts"] if "contacts" in encoding else None,
512
+ losses_b=losses_b,
513
+ outputs_b=outputs_b,
514
+ hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
515
+ attentions_b=encoding_b["attentions"] if "hidden_states" in encoding_b else None,
516
+ global_attentions_b=None,
517
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None,
518
+ pair_outputs=pair_outputs,
519
+ pair_losses=pair_loss)
520
+ else:
521
+ if not return_dict:
522
+ return [[losses, losses_b], [outputs, outputs_b]] + [[encoding, encoding_b]]
523
+ return AllOutput(
524
+ losses=losses,
525
+ outputs=outputs,
526
+ hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
527
+ attentions=encoding["attentions"] if "attentions" in encoding else None,
528
+ global_attentions=None,
529
+ contacts=encoding["contacts"] if "contacts" in encoding else None,
530
+ losses_b=losses_b,
531
+ outputs_b=outputs_b,
532
+ hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
533
+ attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
534
+ global_attentions_b=None,
535
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
536
+ )
537
+ elif has_pair:
538
+ if not return_dict:
539
+ return [[losses], [outputs], [encoding]]
540
+ return AllOutput(
541
+ losses=losses,
542
+ outputs=outputs,
543
+ hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
544
+ attentions=encoding["attentions"] if "attentions" in encoding else None,
545
+ global_attentions=None,
546
+ contacts=encoding["contacts"] if "contacts" in encoding else None
547
+ )
548
+ else:
549
+ if not return_dict:
550
+ return [[losses_b], [outputs_b], [encoding_b]]
551
+ return AllOutput(
552
+ losses_b=losses_b,
553
+ outputs_b=outputs_b,
554
+ hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
555
+ attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
556
+ global_attentions_b=None,
557
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
558
+ )
559
+ else:
560
+ if has_pair and has_pair_b:
561
+ if not return_dict:
562
+ return [[None, None], [None, None]] + [[encoding, encoding_b]]
563
+ return AllOutput(
564
+ losses=None,
565
+ outputs=None,
566
+ hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
567
+ attentions=encoding["attentions"] if "attentions" in encoding else None,
568
+ global_attentions=None,
569
+ contacts=encoding["contacts"] if "contacts" in encoding else None,
570
+ losses_b=None,
571
+ outputs_b=None,
572
+ hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
573
+ attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
574
+ global_attentions_b=None,
575
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
576
+ )
577
+ elif has_pair:
578
+ if not return_dict:
579
+ return [[None], [None], [encoding]]
580
+ return AllOutput(
581
+ losses=None,
582
+ outputs=None,
583
+ hidden_states=encoding["representation_matrix"] if "representation_matrix" in encoding else None,
584
+ attentions=encoding["attentions"] if "attentions" in encoding else None,
585
+ global_attentions=None,
586
+ contacts=encoding["contacts"] if "contacts" in encoding else None
587
+ )
588
+ else:
589
+ if not return_dict:
590
+ return [[None], [None], [encoding_b]]
591
+ return AllOutput(
592
+ losses_b=None,
593
+ outputs_b=None,
594
+ hidden_states_b=encoding_b["representation_matrix"] if "representation_matrix" in encoding_b else None,
595
+ attentions_b=encoding_b["attentions"] if "attentions" in encoding_b else None,
596
+ global_attentions_b=None,
597
+ contacts_b=encoding_b["contacts"] if "contacts" in encoding_b else None
598
+ )
599
+
600
+ def predict_contacts(self, input_ids, position_ids=None, token_type_ids=None):
601
+ return self(
602
+ input_ids=input_ids,
603
+ position_ids=position_ids,
604
+ token_type_ids=token_type_ids,
605
+ return_contacts=True)["contacts"]
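+
+ # Usage sketch (illustrative only, not part of the original file): once a model instance is
+ # available, a forward pass with return_dict=True yields an AllOutput whose fields map to the
+ # encodings built above. The names `model`, `input_ids` and `token_type_ids` are assumptions.
+ #
+ #   out = model(input_ids=input_ids, token_type_ids=token_type_ids, return_dict=True)
+ #   matrix = out.hidden_states            # (B, L, E) per-token representation matrix
+ #   vector = out.hidden_states[:, 0, :]   # (B, E) sequence-level representation vector
+ #   attn = out.attentions                 # (B, layers, heads, L, L) when head weights are collected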
lucaone_gplm_config.py ADDED
@@ -0,0 +1,49 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ from transformers import PretrainedConfig
5
+
6
+ class LucaGPLMConfig(PretrainedConfig):
7
+ model_type = "lucagplm"
8
+
9
+ def __init__(
10
+ self,
11
+ vocab_size=-1,
12
+ pad_token_id=0,
13
+ max_position_embeddings: int = 4096,
14
+ type_vocab_size: int = 2,
15
+ num_hidden_layers: int = 24,
16
+ hidden_size: int = 1280,
17
+ num_attention_heads: int = 20,
18
+ no_position_embeddings: bool = False,
19
+ no_token_type_embeddings: bool = False,
20
+ alphabet: str = "gene_prot",
21
+ token_dropout: bool = True,
22
+ attention_probs_dropout_prob=0.1,
23
+ hidden_dropout_prob=0.1,
24
+ classifier_dropout_prob=0.1,
25
+ use_embed_layer_norm=True,
26
+ use_last_layer_norm=True,
27
+ embed_scale=1.0,
28
+ ignore_index=-100,
29
+ **kwargs
30
+ ):
31
+
32
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
33
+ self.alphabet = alphabet
34
+ self.vocab_size = vocab_size
35
+ self.max_position_embeddings = max_position_embeddings
36
+ self.type_vocab_size = type_vocab_size
37
+ self.no_token_type_embeddings = no_token_type_embeddings
38
+ self.no_position_embeddings = no_position_embeddings
39
+ self.num_hidden_layers = num_hidden_layers
40
+ self.hidden_size = hidden_size
41
+ self.num_attention_heads = num_attention_heads
42
+ self.token_dropout = token_dropout
43
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
44
+ self.hidden_dropout_prob = hidden_dropout_prob
45
+ self.classifier_dropout_prob = classifier_dropout_prob
46
+ self.ignore_index = ignore_index
47
+ self.use_embed_layer_norm = use_embed_layer_norm
48
+ self.use_last_layer_norm = use_last_layer_norm
49
+ self.embed_scale = embed_scale
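+
+ # Example (sketch, not part of the original file): constructing the config directly.
+ # The sizes below are toy values for illustration, not the released checkpoint's settings.
+ #
+ #   config = LucaGPLMConfig(
+ #       vocab_size=64,
+ #       num_hidden_layers=12,
+ #       hidden_size=768,
+ #       num_attention_heads=12,
+ #   )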
masked_loss.py ADDED
@@ -0,0 +1,159 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: sanyuan.hy@alibaba-inc.com
7
+ @tel: 137****6540
8
+ @datetime: 2023/6/28 10:25
9
+ @project: LucaOne
10
+ @file: masked_loss.py
11
+ @desc: masked loss
12
+ '''
13
+ import warnings
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+
18
+ class _MaskedLoss(nn.Module):
19
+ """Base class for masked losses"""
20
+
21
+ def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
22
+ super().__init__()
23
+ self.reduction = reduction
24
+ self.ignore_nans = ignore_nans
25
+ self.ignore_value = ignore_value
26
+
27
+ def forward(self, pred, target, mask=None):
28
+ """Compute a loss between pred and target for given mask.
29
+ Note that this implementation is faster than loss(pred[mask], target[mask])
30
+ for a given loss, and is nan-proof."""
31
+ '''
32
+ if not (target.size() == pred.size()):
33
+ warnings.warn(
34
+ "Using a target size ({}) that is different to the pred size ({}). "
35
+ "This will likely lead to incorrect results due to broadcasting. "
36
+ "Please ensure they have the same size.".format(
37
+ target.size(), pred.size()),
38
+ stacklevel=2,
39
+ )
40
+ '''
41
+ if mask is None and self.ignore_value is not None:
42
+ mask = target != self.ignore_value
43
+ elif mask is None:
44
+ mask = torch.ones_like(target, dtype=bool)
45
+ target_proxy = target
46
+ if self.ignore_nans:
47
+ target_proxy = target.clone()
48
+ nans = torch.isnan(target)
49
+ if nans.any():
50
+ with torch.no_grad():
51
+ mask = mask & ~nans
52
+ target_proxy[nans] = 0
53
+ # full_loss = self.criterion(pred, target_proxy)
54
+ # print("mask shape")
55
+ # print(mask.shape)
56
+ if self.reduction == 'meanmean' and pred.ndim == 3 and pred.shape[-1] == 1:
57
+ # token-level binary classification
58
+ # pred: n , seq_len, 1 -> n * seq_len
59
+ # target: n, seq_len -> n * seq_len
60
+ full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
61
+ full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
62
+ # print("ok1")
63
+ elif self.reduction == 'meanmean' and pred.ndim == 3:
64
+ if target.ndim == 3:
65
+ # token-level regression
66
+ # pred: n , seq_len, label_size -> n * seq_len * label_size
67
+ # target: n, seq_len, label_size -> n * seq_len * label_size
68
+ full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
69
+ full_loss = torch.reshape(full_loss, (-1, pred.shape[1], pred.shape[-1]))
70
+ # print("ok21")
71
+ else:
72
+ # token-level multi classification
73
+ # pred: n , seq_len, label_size -> n * seq_len, label_size
74
+ # target: n, seq_len -> n * seq_len
75
+ full_loss = self.criterion(pred.view(-1, pred.shape[-1]), target_proxy.view(-1))
76
+ full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
77
+ # print("ok22")
78
+ elif self.reduction == 'meanmean' and pred.ndim == 2 and target.ndim == 2:
79
+ # seq-level multi label
80
+ # pred: n , label_size -> n * label_size
81
+ # target: n, label_size -> n * label_size
82
+ full_loss = self.criterion(pred.view(-1), target_proxy.view(-1))
83
+ full_loss = torch.reshape(full_loss, (-1, pred.shape[1]))
84
+ # print("ok3")
85
+ elif self.reduction == 'meanmean':
86
+ self.reduction = "mean"
87
+ full_loss = self.criterion(pred, target_proxy)
88
+ # print("ok4")
89
+ else:
90
+ full_loss = self.criterion(pred, target_proxy)
91
+ # print("ok5")
92
+
93
+ full_loss[~mask] = 0
94
+ '''
95
+ if not mask.any():
96
+ warnings.warn("Evaluation mask is False everywhere, this might lead to incorrect results.")
97
+ print(full_loss.sum(), mask.to(full_loss.dtype).sum())
98
+ '''
99
+ if self.reduction == 'none':
100
+ return full_loss
101
+ if self.reduction == 'sum':
102
+ return full_loss.sum()
103
+ if self.reduction == 'mean':
104
+ '''
105
+ print("mask:")
106
+ print(mask.to(full_loss.dtype).sum(dim=-1))
107
+ print(mask.to(full_loss.dtype).sum())
108
+ '''
109
+ return full_loss.sum() / (mask.to(full_loss.dtype).sum() + 1e-12)
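+ # 'meanmean' (descriptive note): average the per-element losses over the last dimension first
+ # (e.g. tokens within a sequence), then average over the remaining dimension(s) (e.g. sequences
+ # in the batch), so that short and long sequences contribute equally.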
110
+ if self.reduction == 'meanmean':
111
+ if mask.ndim == 3:
112
+ mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
113
+ '''
114
+ print("mask:")
115
+ print(mask_sum)
116
+ '''
117
+ full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
118
+ mask_sum = mask_sum.to(torch.bool).sum(dim=-1)
119
+ # print(mask_sum)
120
+ full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
121
+ mask_sum = mask_sum.to(torch.bool).sum()
122
+ # print(mask_sum)
123
+ loss = full_loss.sum() / (mask_sum + 1e-12)
124
+ else:
125
+ mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
126
+ '''
127
+ print("mask:")
128
+ print(mask_sum)
129
+ print(mask_sum.to(torch.bool).sum())
130
+ '''
131
+ loss = torch.sum(full_loss.sum(dim=-1) / (mask_sum + 1e-12)) / (mask_sum.to(torch.bool).sum() + 1e-12)
132
+ # print(full_loss.sum() / (mask.to(full_loss.dtype).sum() + 1e-12), loss)
133
+ return loss
134
+ if self.reduction in ["summean", "meansum"]:
135
+ if mask.ndim == 3:
136
+ mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
137
+ '''
138
+ print("mask:")
139
+ print(mask_sum)
140
+ '''
141
+ full_loss = full_loss.sum(dim=-1)
142
+ mask_sum = mask_sum.to(torch.bool).sum(dim=-1)
143
+ # print(mask_sum)
144
+ full_loss = full_loss.sum(dim=-1) / (mask_sum + 1e-12)
145
+ mask_sum = mask_sum.to(torch.bool).sum()
146
+ # print(mask_sum)
147
+ loss = full_loss.sum() / (mask_sum + 1e-12)
148
+ else:
149
+ mask_sum = mask.to(full_loss.dtype).sum(dim=-1)
150
+ '''
151
+ print("mask:")
152
+ print(mask_sum)
153
+ print(mask_sum.to(torch.bool).sum())
154
+ '''
155
+ loss = full_loss.sum() / (mask_sum.to(torch.bool).sum() + 1e-12)
156
+ return loss
157
+ return full_loss
158
+
159
+
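+ # Illustrative sketch (not part of the original file): a concrete masked loss is expected to set
+ # `self.criterion` to an element-wise loss created with reduction='none', so that the base class
+ # can zero out masked positions before reducing. For example:
+ #
+ #   class MaskedMSELoss(_MaskedLoss):
+ #       def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
+ #           super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
+ #           self.criterion = nn.MSELoss(reduction='none')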
metrics.py ADDED
@@ -0,0 +1,549 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: sanyuan.**@**.com
7
+ @tel: 137****6540
8
+ @datetime: 2022/11/26 21:05
9
+ @project: LucaOne
10
+ @file: metrics.py
11
+ @desc: metrics for binary classification or multi-class classification
12
+ '''
13
+ import csv
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+ plt.rcParams.update({'font.size': 18})
17
+ plt.rcParams['axes.unicode_minus'] = False
18
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, \
19
+ average_precision_score, confusion_matrix, mean_absolute_error, mean_squared_error, r2_score
20
+
21
+
22
+ def topk_accuracy_score(targets, probs, k=3):
23
+ '''
24
+ topk accuracy
25
+ :param targets:
26
+ :param probs:
27
+ :param k:
28
+ :return:
29
+ '''
30
+ # obtain top-k label
31
+ max_k_preds = probs.argsort(axis=1)[:, -k:][:, ::-1]
32
+ a_real = np.resize(targets, (targets.shape[0], 1))
33
+ # obtain the match result
34
+ match_array = np.logical_or.reduce(max_k_preds == a_real, axis=1)
35
+ topk_acc_score = match_array.sum() / match_array.shape[0]
36
+ return topk_acc_score
37
+
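+ # Worked example (sketch): targets = np.array([2, 0]) and
+ # probs = np.array([[0.1, 0.3, 0.6], [0.2, 0.5, 0.3]]). With k=2 the top-2 classes are
+ # {2, 1} and {1, 2}; only the first sample's true class is covered, so the score is 0.5.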
38
+
39
+ def multi_class_acc(targets, probs, threshold=0.5):
40
+ if targets.ndim == 2:
41
+ targets = np.argmax(targets, axis=1)
42
+ preds = np.argmax(probs, axis=1)
43
+ return accuracy_score(targets, preds)
44
+
45
+
46
+ def multi_class_precision(targets, probs, average='macro'):
47
+ if targets.ndim == 2:
48
+ targets = np.argmax(targets, axis=1)
49
+ preds = np.argmax(probs, axis=1)
50
+ return precision_score(targets, preds, average=average)
51
+
52
+
53
+ def multi_class_recall(targets, probs, average='macro'):
54
+ if targets.ndim == 2:
55
+ targets = np.argmax(targets, axis=1)
56
+ preds = np.argmax(probs, axis=1)
57
+ return recall_score(targets, preds, average=average)
58
+
59
+
60
+ def multi_class_f1(targets, probs, average='macro'):
61
+ if targets.ndim == 2:
62
+ targets = np.argmax(targets, axis=1)
63
+ preds = np.argmax(probs, axis=1)
64
+ return f1_score(targets, preds, average=average)
65
+
66
+
67
+ def multi_class_roc_auc(targets, probs, average='macro'):
68
+ if targets.ndim == 2:
69
+ targets = np.argmax(targets, axis=1)
70
+ return roc_auc_score(targets, probs, average=average, multi_class='ovr')
71
+
72
+
73
+ def multi_class_pr_auc(targets, probs, average='macro'):
74
+ if targets.ndim == 2:
75
+ targets = np.argmax(targets, axis=1)
76
+ z = probs.shape[1]
77
+ new_targets = np.eye(z)[targets]
78
+ pr_auc = average_precision_score(new_targets, probs, average=average)
79
+ return pr_auc
80
+
81
+
82
+ def metrics_multi_class(targets, probs, average="macro"):
83
+ '''
84
+ metrics of multi-class classification
85
+ :param targets: 1d-array class index (n_samples, )
86
+ :param probs: 2d-array probability (n_samples, m_classes)
87
+ :return:
88
+ '''
89
+ if targets.ndim == 2 and targets.shape[1] > 1:
90
+ targets = np.argmax(targets, axis=1)
91
+ elif targets.ndim == 2 and targets.shape[1] == 1:
92
+ targets = np.squeeze(targets, axis=1)
93
+
94
+ preds = np.argmax(probs, axis=1)
95
+ acc = accuracy_score(targets, preds)
96
+ prec = precision_score(targets, preds, average=average)
97
+ recall = recall_score(targets, preds, average=average)
98
+ f1 = f1_score(targets, preds, average=average)
99
+ result = {
100
+ "acc": round(float(acc), 6),
101
+ "prec": round(float(prec), 6),
102
+ "recall": round(float(recall), 6),
103
+ "f1": round(float(f1), 6)
104
+ }
105
+ result.update({
106
+ "top2_acc": round(float(topk_accuracy_score(targets, probs, k=2)), 6),
107
+ "top3_acc": round(float(topk_accuracy_score(targets, probs, k=3)), 6),
108
+ "top5_acc": round(float(topk_accuracy_score(targets, probs, k=5)), 6),
109
+ "top10_acc": round(float(topk_accuracy_score(targets, probs, k=10)), 6)
110
+ })
111
+ try:
112
+ roc_auc = roc_auc_score(targets, probs, average=average, multi_class='ovr')
113
+ result.update({
114
+ "roc_auc": round(float(roc_auc), 6)
115
+ })
116
+ except Exception as e:
117
+ pass
118
+ try:
119
+ z = probs.shape[1]
120
+ new_targets = np.eye(z)[targets]
121
+ pr_auc = average_precision_score(new_targets, probs, average=average)
122
+ result.update({
123
+ "pr_auc": round(float(pr_auc), 6),
124
+ })
125
+ except Exception as e:
126
+ pass
127
+ return result
128
+
129
+
130
+ def metrics_multi_class_for_pred(targets, preds, probs=None, average="macro", savepath=None):
131
+ '''
132
+ metrics for multi-class classification
133
+ :param targets: 1d-array class index (n_samples, )
134
+ :param preds: 1d-array class index (n_samples, )
135
+ :return:
136
+ '''
137
+ if targets.ndim == 2 and targets.shape[1] > 1:
138
+ targets = np.argmax(targets, axis=1)
139
+ elif targets.ndim == 2 and targets.shape[1] == 1:
140
+ targets = np.squeeze(targets, axis=1)
141
+
142
+ acc = accuracy_score(targets, preds)
143
+ prec = precision_score(targets, preds, average=average)
144
+ recall = recall_score(targets, preds, average=average)
145
+ f1 = f1_score(y_true=targets, y_pred=preds, average=average)
146
+ result = {
147
+ "acc": round(float(acc), 6),
148
+ "prec": round(float(prec), 6),
149
+ "recall": round(float(recall), 6),
150
+ "f1": round(float(f1), 6)
151
+ }
152
+ try:
153
+ roc_auc = roc_auc_score(targets, probs, average=average, multi_class='ovr')
154
+ result.update({
155
+ "roc_auc": round(float(roc_auc), 6)
156
+ })
157
+ except Exception as e:
158
+ pass
159
+ try:
160
+ z = probs.shape[1]
161
+ new_targets = np.eye(z)[targets]
162
+ pr_auc = average_precision_score(new_targets, probs, average=average)
163
+ result.update({
164
+ "pr_auc": round(float(pr_auc), 6),
165
+ })
166
+ except Exception as e:
167
+ pass
168
+ return result
169
+
170
+
171
+ def metrics_regression(targets, preds):
172
+ '''
173
+ metrics for regression
174
+ :param targets: 1d-array class index (n_samples, )
175
+ :param preds: 1d-array class index (n_samples, )
176
+ :return:
177
+ '''
178
+ mae = mean_absolute_error(targets, preds)
179
+ mse = mean_squared_error(targets, preds)
180
+ r2 = r2_score(targets, preds)
181
+ return {
182
+ "mae": round(float(mae), 6),
183
+ "mse": round(float(mse), 6),
184
+ "r2": round(float(r2), 6)
185
+ }
186
+
187
+
188
+ def transform(targets, probs, threshold):
189
+ '''
190
+ metrics of binary classification
191
+ :param targets: 1d-array class index (n_samples, )
192
+ :param probs: 1d-array larger class probability (n_samples, )
193
+ :param threshold: 0-1 probability threshold
194
+ :return:
195
+ '''
196
+ if targets.ndim == 2:
197
+ if targets.shape[1] == 2: # [[0, 1], [1, 0]]
198
+ targets = np.argmax(targets, axis=1)
199
+ else: # [[1], [0]]
200
+ targets = targets.flatten()
201
+ if probs.ndim == 2:
202
+ if probs.shape[1] == 2: # [[0.1, 0.9], [0.9, 0.1]]
203
+ preds = np.argmax(probs, axis=1)
204
+ probs = probs[:, 1].flatten()
205
+ else: # [[0.9], [0.1]]
206
+ preds = (probs >= threshold).astype(int).flatten()
207
+ probs = probs.flatten()
208
+ else:
209
+ preds = (probs >= threshold).astype(int)
210
+ return targets, probs, preds
211
+
212
+
213
+ def binary_acc(targets, probs, threshold=0.5):
214
+ targets, probs, preds = transform(targets, probs, threshold)
215
+ return accuracy_score(targets, preds)
216
+
217
+
218
+ def binary_precision(targets, probs, threshold=0.5, average='binary'):
219
+ targets, probs, preds = transform(targets, probs, threshold)
220
+ return precision_score(targets, preds, average=average)
221
+
222
+
223
+ def binary_recall(targets, probs, threshold=0.5, average='binary'):
224
+ targets, probs, preds = transform(targets, probs, threshold)
225
+ return recall_score(targets, preds, average=average)
226
+
227
+
228
+ def binary_f1(targets, probs, threshold=0.5, average='binary'):
229
+ targets, probs, preds = transform(targets, probs, threshold)
230
+ return f1_score(targets, preds, average=average)
231
+
232
+
233
+ def binary_roc_auc(targets, probs, threshold=0.5, average='macro'):
234
+ targets, probs, preds = transform(targets, probs, threshold)
235
+ return roc_auc_score(targets, probs, average=average)
236
+
237
+
238
+ def binary_pr_auc(targets, probs, threshold=0.5, average='macro'):
239
+ targets, probs, preds = transform(targets, probs, threshold)
240
+ return average_precision_score(targets, probs, average=average)
241
+
242
+
243
+ def binary_confusion_matrix(targets, probs, threshold=0.5, savepath=None):
244
+ targets, probs, preds = transform(targets, probs, threshold)
245
+ cm_obj = confusion_matrix(targets, preds, labels=[0, 1])
246
+ plot_confusion_matrix_for_binary_class(targets, preds, cm=cm_obj, savepath=savepath)
247
+ tn, fp, fn, tp = cm_obj.ravel()
248
+ cm = {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
249
+ return cm
250
+
251
+
252
+ def metrics_binary(targets, probs, threshold=0.5, average="binary", savepath=None):
253
+ '''
254
+ metrics for binary classification
255
+ :param targets: 1d-array class index (n_samples, )
256
+ :param probs: 1d-array larger class probability (n_samples, )
257
+ :param threshold: 0-1 prob threshold
258
+ :return:
259
+ '''
260
+ if targets.ndim == 2:
261
+ if targets.shape[1] == 2: # [[0, 1], [1, 0]]
262
+ targets = np.argmax(targets, axis=1)
263
+ else: # [[1], [0]]
264
+ targets = targets.flatten()
265
+ if probs.ndim == 2:
266
+ if probs.shape[1] == 2: # [[0.1, 0.9], [0.9, 0.1]]
267
+ preds = np.argmax(probs, axis=1)
268
+ probs = probs[:, 1].flatten()
269
+ else: # [[0.9], [0.1]]
270
+ preds = (probs >= threshold).astype(int).flatten()
271
+ probs = probs.flatten()
272
+ else:
273
+ preds = (probs >= threshold).astype(int)
274
+ acc = accuracy_score(targets, preds)
275
+ prec = precision_score(targets, preds, average=average)
276
+ recall = recall_score(targets, preds, average=average)
277
+ f1 = f1_score(targets, preds, average=average)
278
+ result = {
279
+ "acc": round(float(acc), 6),
280
+ "prec": round(float(prec), 6),
281
+ "recall": round(float(recall), 6),
282
+ "f1": round(float(f1), 6)
283
+ }
284
+ try:
285
+ roc_auc = roc_auc_score(targets, probs, average="macro")
286
+ result.update({
287
+ "roc_auc": round(float(roc_auc), 6)
288
+ })
289
+ except Exception as e:
290
+ pass
291
+ try:
292
+ pr_auc = average_precision_score(targets, probs, average="macro")
293
+ result.update({
294
+ "pr_auc": round(float(pr_auc), 6)
295
+ })
296
+ except Exception as e:
297
+ pass
298
+ try:
299
+ cm_obj = confusion_matrix(targets, preds, labels=[0, 1])
300
+ plot_confusion_matrix_for_binary_class(targets, preds, cm=cm_obj, savepath=savepath)
301
+ tn, fp, fn, tp = cm_obj.ravel()
302
+ cm = {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
303
+ result.update({
304
+ "confusion_matrix": cm
305
+ })
306
+ except Exception as e:
307
+ pass
308
+ # add mcc
309
+ try:
310
+ tn, fp, fn, tp = cm["tn"], cm["fp"], cm["fn"], cm["tp"]
311
+ mcc = (tn*tp - fp*fn) / (((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)) ** 0.5)
312
+ result.update({
313
+ "mcc": round(mcc, 6)
314
+ })
315
+ except Exception as e:
316
+ pass
317
+ return result
318
+
319
+
320
+ def metrics_binary_for_pred(targets, preds, probs=None, average="binary", savepath=None):
321
+ '''
322
+ metrics for binary classification
323
+ :param targets: 1d-array class index (n_samples, )
324
+ :param preds: 1d-array larger class index (n_samples, )
325
+ :return:
326
+ '''
327
+ if targets.ndim == 2:
328
+ if targets.shape[1] == 2: # [[1, 0], [0, 1]]
329
+ targets = np.argmax(targets, axis=1)
330
+ else: # [[1], [0]]
331
+ targets = targets.flatten()
332
+ if preds.ndim == 2:
333
+ if preds.shape[1] == 2: # [[0.9, 0.1], [0.1, 0.9]]
334
+ preds = np.argmax(preds, axis=1)
335
+ else: # [[0], [1]]
336
+ preds = preds.flatten()
337
+ cm_obj = confusion_matrix(targets, preds, labels=[0, 1])
338
+ plot_confusion_matrix_for_binary_class(targets, preds, cm=cm_obj, savepath=savepath)
339
+ tn, fp, fn, tp = cm_obj.ravel()
340
+ cm = {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
341
+ if len(np.unique(targets)) > 1:
342
+ acc = accuracy_score(targets, preds)
343
+ prec = precision_score(targets, preds, average=average)
344
+ recall = recall_score(targets, preds, average=average)
345
+ f1 = f1_score(y_true=targets, y_pred=preds, average=average)
346
+ result = {
347
+ "acc": round(float(acc), 6),
348
+ "prec": round(float(prec), 6),
349
+ "recall": round(float(recall), 6),
350
+ "f1": round(float(f1), 6)
351
+ }
352
+ else:
353
+
354
+ result = {
355
+ "acc": round(float((cm["tp"] + cm["tn"]) / (cm["tp"] + cm["tn"] + cm["fp"] + cm["fn"])), 6),
356
+ "prec": round(float(cm["tp"]/(cm["tp"] + cm["fp"]) if cm["tp"] + cm["fp"] > 0 else 1.0), 6),
357
+ "recall": round(float(cm["tp"]/(cm["tp"] + cm["fn"]) if cm["tp"] + cm["fn"] > 0 else 1.0), 6),
358
+ }
359
+ result["f1"] = 2 * result["prec"] * result["recall"] / (result["prec"] + result["recall"])
360
+
361
+ try:
362
+ pr_auc = average_precision_score(targets, probs, average="macro")
363
+ result.update({
364
+ "pr_auc": round(float(pr_auc), 6)
365
+ })
366
+ except Exception as e:
367
+ pass
368
+ try:
369
+ roc_auc = roc_auc_score(targets, probs, average="macro")
370
+ result.update({
371
+ "roc_auc": round(float(roc_auc), 6)
372
+ })
373
+ except Exception as e:
374
+ pass
375
+ try:
376
+ tn, fp, fn, tp = cm["tn"], cm["fp"], cm["fn"], cm["tp"]
377
+ mcc = (tn*tp - fp*fn) / (((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)) ** 0.5)
378
+ result.update({
379
+ "mcc": round(mcc, 6)
380
+ })
381
+ except Exception as e:
382
+ pass
383
+ result.update({
384
+ "confusion_matrix": cm
385
+ })
386
+ return result
387
+
388
+
389
+ def write_error_samples_multi_class(filepath, samples, input_indexs, input_id_2_names, output_id_2_name, targets, probs,
390
+ use_other_diags=False, use_other_operas=False, use_checkin_department=False):
391
+ '''
392
+ write the bad cases of multi-class classification
393
+ :param filepath:
394
+ :param samples:
395
+ :param input_indexs:
396
+ :param input_id_2_names:
397
+ :param output_id_2_name:
398
+ :param targets:
399
+ :param probs:
400
+ :param use_other_diags:
401
+ :param use_other_operas:
402
+ :param use_checkin_department:
403
+ :return:
404
+ '''
405
+ targets = np.argmax(targets, axis=1)
406
+ preds = np.argmax(probs, axis=1)
407
+ with open(filepath, "w") as fp:
408
+ writer = csv.writer(fp)
409
+ writer.writerow(["score", "y_true", "y_pred", "inputs"])
410
+ for i in range(len(targets)):
411
+ target = targets[i]
412
+ pred = preds[i]
413
+ score = 1
414
+ if target != pred:
415
+ score = 0
416
+ if output_id_2_name:
417
+ target_label = output_id_2_name[target]
418
+ pred_label = output_id_2_name[pred]
419
+ else:
420
+ target_label = target
421
+ pred_label = pred
422
+ sample = samples[i]
423
+ if input_id_2_names:
424
+ new_sample = []
425
+ for idx, input_index in enumerate(input_indexs):
426
+ if input_index == 3 and not use_checkin_department:
427
+ input_index = 12
428
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index]])
429
+ if (input_index == 6 and use_other_diags) or (input_index == 8 and use_other_operas) or (input_index == 10 and use_other_diags):
430
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index + 1]])
431
+ else:
432
+ new_sample = sample
433
+ row = [score, target_label, pred_label, new_sample]
434
+ writer.writerow(row)
435
+
436
+
437
+ def write_error_samples_binary(filepath, samples, input_indexs, input_id_2_names, targets, probs, threshold=0.5,
438
+ use_other_diags=False, use_other_operas=False, use_checkin_department=False):
439
+ '''
440
+ write bad cases of binary classification
441
+ :param filepath:
442
+ :param samples:
443
+ :param input_indexs:
444
+ :param input_id_2_names:
445
+ :param targets:
446
+ :param probs:
447
+ :param threshold:
448
+ :param use_other_diags:
449
+ :param use_other_operas:
450
+ :param use_checkin_department:
451
+ :return:
452
+ '''
453
+ with open(filepath, "w") as fp:
454
+ writer = csv.writer(fp)
455
+ writer.writerow(["score", "y_true", "y_pred", "inputs"])
456
+ for i in range(len(targets)):
457
+ target = targets[i][0]
458
+ if target != 1:
459
+ target = 0
460
+ prob = probs[i][0]
461
+ if prob >= threshold:
462
+ pred = 1
463
+ else:
464
+ pred = 0
465
+ score = 1
466
+ if target != pred:
467
+ score = 0
468
+ target_label = "True" if target == 1 else "False"
469
+ pred_label = "True" if target == 1 else "False"
470
+ sample = samples[i]
471
+ if input_id_2_names:
472
+ new_sample = []
473
+ for idx, input_index in enumerate(input_indexs):
474
+ if input_index == 3 and not use_checkin_department:
475
+ input_index = 12
476
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index]])
477
+ if (input_index == 6 and use_other_diags) or (input_index == 8 and use_other_operas) or (input_index == 10 and use_other_diags):
478
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index + 1]])
479
+ else:
480
+ new_sample = sample
481
+ row = [score, target_label, pred_label, new_sample]
482
+ writer.writerow(row)
483
+
484
+
485
+ def plot_confusion_matrix_for_binary_class(targets, preds, cm=None, savepath=None):
486
+ '''
487
+ :param targets: ground truth
488
+ :param preds: predicted labels
489
+ :param cm: confusion matrix
490
+ :param savepath: confusion matrix picture save path
491
+ '''
492
+
493
+ plt.figure(figsize=(40, 20), dpi=100)
494
+ if cm is None:
495
+ cm = confusion_matrix(targets, preds, labels=[0, 1])
496
+
497
+ plt.matshow(cm, cmap=plt.cm.Oranges)
498
+ plt.colorbar()
499
+
500
+ for x in range(len(cm)):
501
+ for y in range(len(cm)):
502
+ plt.annotate(cm[x, y], xy=(y, x), verticalalignment='center', horizontalalignment='center')
503
+ plt.ylabel('True')
504
+ plt.xlabel('Prediction')
505
+ if savepath:
506
+ plt.savefig(savepath, dpi=100)
507
+ else:
508
+ plt.show()
509
+ plt.close("all")
510
+
511
+
512
+ if __name__ == "__main__":
513
+ '''multi_class'''
514
+ targets = np.array([0, 1, 2, 1, 3])
515
+ probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.55, 0.25, 0.1], [0.4, 0.25, 0.35, 0]])
516
+ print(metrics_multi_class(targets, probs))
517
+
518
+ targets = np.array([0, 1, 2, 3, 3])
519
+ probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.25, 0.25, 0.4], [0.1, 0.25, 0.25, 0.4]])
520
+ print(metrics_multi_class(targets, probs))
521
+ targets = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1]])
522
+ probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.25, 0.25, 0.4], [0.1, 0.25, 0.25, 0.4]])
523
+ print(metrics_multi_class(targets, probs))
524
+
525
+ '''binary'''
526
+ targets = np.array([0, 0, 1, 1])
527
+ probs = np.array([[0.1], [0.1], [0.1], [0.9]])
528
+ print(metrics_binary(targets, probs))
529
+
530
+ targets = np.array([[0], [0], [1], [1]])
531
+ probs = np.array([[0.1], [0.1], [0.1], [0.9]])
532
+ print(metrics_binary(targets, probs))
533
+
534
+ targets = np.array([0, 0, 1, 1])
535
+ probs = np.array([[0.1, 0.1, 0.1, 0.9]])
536
+ print(metrics_binary(targets, probs))
537
+
538
+ targets = np.array([0, 0, 1, 1])
539
+ probs = np.array([0.1, 0.1, 0.1, 0.9])
540
+ print(metrics_binary(targets, probs))
541
+
542
+ targets = np.array([0, 1, 2, 1, 3])
543
+ probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.55, 0.25, 0.1], [0.4, 0.25, 0.25, 0.1]])
544
+ z = probs.shape[1]
545
+ # print(z)
546
+ print(np.eye(z))
547
+ new_targets = np.eye(z)[targets]
548
+ print(new_targets)
549
+
model_utils.py ADDED
@@ -0,0 +1,99 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ from typing import Optional, Tuple
5
+ from dataclasses import dataclass
6
+ from transformers.modeling_outputs import ModelOutput
7
+ import sys, copy, math
8
+
9
+ from .pooling import *
10
+ from .loss import *
11
+
12
+ @dataclass
13
+ class AllOutput(ModelOutput):
14
+ losses: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
15
+ outputs: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
16
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
17
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
18
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
19
+ global_attentions: Optional[Tuple[torch.FloatTensor]] = None
20
+ contacts: Optional[Tuple[torch.FloatTensor]] = None
21
+ losses_b: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
22
+ outputs_b: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
23
+ hidden_states_b: Optional[Tuple[torch.FloatTensor]] = None
24
+ attentions_b: Optional[Tuple[torch.FloatTensor]] = None
25
+ cross_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
26
+ global_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
27
+ contacts_b: Optional[Tuple[torch.FloatTensor]] = None
28
+ pair_outputs: Optional[Tuple[torch.FloatTensor]] = None
29
+ pair_losses: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
30
+
31
+
32
+ def create_pooler(task_level_type, task_level_name, config, args):
33
+ '''
34
+ pooler building
35
+ :param task_level_type:
36
+ :param task_level_name:
37
+ :param config:
38
+ :param args:
39
+ :return:
40
+ '''
41
+ hidden_size = config.hidden_size[task_level_type][task_level_name]
42
+ pooling_type = args.pooling_type[task_level_type][task_level_name]
43
+
44
+ if pooling_type == "max":
45
+ return GlobalMaskMaxPooling1D()
46
+ elif pooling_type == "sum":
47
+ return GlobalMaskSumPooling1D(axis=1)
48
+ elif pooling_type == "avg":
49
+ return GlobalMaskAvgPooling1D()
50
+ elif pooling_type == "attention":
51
+ return GlobalMaskContextAttentionPooling1D(embed_size=hidden_size)
52
+ elif pooling_type == "context_attention":
53
+ return GlobalMaskContextAttentionPooling1D(embed_size=hidden_size)
54
+ elif pooling_type == "weighted_attention":
55
+ return GlobalMaskWeightedAttentionPooling1D(embed_size=hidden_size)
56
+ elif pooling_type == "value_attention":
57
+ return GlobalMaskValueAttentionPooling1D(embed_size=hidden_size)
58
+ elif pooling_type == "transformer":
59
+ copy_config = copy.deepcopy(config)
60
+ copy_config.hidden_size = hidden_size
61
+ return GlobalMaskTransformerPooling1D(copy_config)
62
+ else:
63
+ return None
64
+
65
+
66
+ def create_output_loss_lucagplm(task_level_type, task_level_name, config):
67
+ '''Build output/loss components (no cls module).'''
68
+ if not hasattr(config, "sigmoid"):
69
+ config.sigmoid = {task_level_type: {}}
70
+ elif task_level_type not in config.sigmoid:
71
+ config.sigmoid[task_level_type] = {}
72
+ config.sigmoid[task_level_type][task_level_name] = False if config.output_mode[task_level_type][task_level_name] \
73
+ in ["multi_class", "multi-class", "regression"] else True
74
+ # Special case: contact prediction requires sigmoid; whether structure-level outputs need sigmoid is still to be decided
75
+ if task_level_name == "prot_contact":
76
+ config.sigmoid[task_level_type][task_level_name] = True
77
+ config.num_labels = config.label_size[task_level_type][task_level_name]
78
+ if task_level_type in ["token_level", "whole_level"]:
79
+ return_types = ["output", "loss"]
80
+ else:
81
+ return_types = ["dropout", "hidden_layer", "hidden_act", "classifier", "output", "loss"]
82
+ return create_loss_function(config,
83
+ task_level_type=task_level_type,
84
+ task_level_name=task_level_name,
85
+ sigmoid=config.sigmoid[task_level_type][task_level_name],
86
+ output_mode=config.output_mode[task_level_type][task_level_name],
87
+ num_labels=config.num_labels,
88
+ loss_type=config.loss_type[task_level_type][task_level_name],
89
+ ignore_index=config.ignore_index,
90
+ pair_level=True if task_level_type == "pair_level" else False,
91
+ return_types=return_types)
92
+
93
+
94
+ def create_output_loss(task_level_type, task_level_name, cls_module, config, args):
95
+ cls = None
96
+ if task_level_type in ["token_level", "whole_level"]:
97
+ cls = cls_module(config)
98
+ dropout, hidden_layer, hidden_act, classifier, output, loss_fct = create_output_loss_lucagplm(task_level_type, task_level_name, config)
99
+ return cls, dropout, hidden_layer, hidden_act, classifier, output, loss_fct
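+
+ # Illustrative sketch (not part of the original file): create_pooler() expects both
+ # config.hidden_size and args.pooling_type to be nested dicts keyed by task level type and
+ # task name. The task name "taxonomy" below is only a placeholder.
+ #
+ #   config.hidden_size = {"seq_level": {"taxonomy": 1280}}
+ #   args.pooling_type = {"seq_level": {"taxonomy": "value_attention"}}
+ #   pooler = create_pooler("seq_level", "taxonomy", config, args)
+ #   # -> GlobalMaskValueAttentionPooling1D(embed_size=1280)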
modeling_bert.py ADDED
@@ -0,0 +1,1911 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: sanyuan.**@**.com
7
+ @tel: 137****6540
8
+ @datetime: 2022/12/2 09:38
9
+ @project: LucaOne
10
+ @file: modeling_bert
11
+ @desc: transformer layers
12
+ '''
13
+ import math
14
+ import os
15
+ import warnings
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from packaging import version
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+
25
+ from transformers.activations import ACT2FN
26
+
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ BaseModelOutputWithPoolingAndCrossAttentions,
30
+ CausalLMOutputWithCrossAttentions,
31
+ MaskedLMOutput,
32
+ MultipleChoiceModelOutput,
33
+ NextSentencePredictorOutput,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutput,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
40
+ from transformers.utils import (
41
+ ModelOutput,
42
+ add_code_sample_docstrings,
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from transformers.models.bert.configuration_bert import BertConfig
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ _CHECKPOINT_FOR_DOC = "bert-base-uncased"
54
+ _CONFIG_FOR_DOC = "BertConfig"
55
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
56
+
57
+ # TokenClassification docstring
58
+ _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
59
+ _TOKEN_CLASS_EXPECTED_OUTPUT = (
60
+ "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
61
+ )
62
+ _TOKEN_CLASS_EXPECTED_LOSS = 0.01
63
+
64
+ # QuestionAnswering docstring
65
+ _CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
66
+ _QA_EXPECTED_OUTPUT = "'a nice puppet'"
67
+ _QA_EXPECTED_LOSS = 7.41
68
+ _QA_TARGET_START_INDEX = 14
69
+ _QA_TARGET_END_INDEX = 15
70
+
71
+ # SequenceClassification docstring
72
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
73
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
74
+ _SEQ_CLASS_EXPECTED_LOSS = 0.01
75
+
76
+
77
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
78
+ "bert-base-uncased",
79
+ "bert-large-uncased",
80
+ "bert-base-cased",
81
+ "bert-large-cased",
82
+ "bert-base-multilingual-uncased",
83
+ "bert-base-multilingual-cased",
84
+ "bert-base-chinese",
85
+ "bert-base-german-cased",
86
+ "bert-large-uncased-whole-word-masking",
87
+ "bert-large-cased-whole-word-masking",
88
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
89
+ "bert-large-cased-whole-word-masking-finetuned-squad",
90
+ "bert-base-cased-finetuned-mrpc",
91
+ "bert-base-german-dbmdz-cased",
92
+ "bert-base-german-dbmdz-uncased",
93
+ "cl-tohoku/bert-base-japanese",
94
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
95
+ "cl-tohoku/bert-base-japanese-char",
96
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
97
+ "TurkuNLP/bert-base-finnish-cased-v1",
98
+ "TurkuNLP/bert-base-finnish-uncased-v1",
99
+ "wietsedv/bert-base-dutch-cased",
100
+ # See all BERT models at https://huggingface.co/models?filter=bert
101
+ ]
102
+
103
+
104
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
105
+ """Load tf checkpoints in a pytorch model."""
106
+ try:
107
+ import re
108
+
109
+ import numpy as np
110
+ import tensorflow as tf
111
+ except ImportError:
112
+ logger.error(
113
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
114
+ "https://www.tensorflow.org/install/ for installation instructions."
115
+ )
116
+ raise
117
+ tf_path = os.path.abspath(tf_checkpoint_path)
118
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
119
+ # Load weights from TF model
120
+ init_vars = tf.train.list_variables(tf_path)
121
+ names = []
122
+ arrays = []
123
+ for name, shape in init_vars:
124
+ logger.info(f"Loading TF weight {name} with shape {shape}")
125
+ array = tf.train.load_variable(tf_path, name)
126
+ names.append(name)
127
+ arrays.append(array)
128
+
129
+ for name, array in zip(names, arrays):
130
+ name = name.split("/")
131
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
132
+ # which are not required for using pretrained model
133
+ if any(
134
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
135
+ for n in name
136
+ ):
137
+ logger.info(f"Skipping {'/'.join(name)}")
138
+ continue
139
+ pointer = model
140
+ for m_name in name:
141
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
142
+ scope_names = re.split(r"_(\d+)", m_name)
143
+ else:
144
+ scope_names = [m_name]
145
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
146
+ pointer = getattr(pointer, "weight")
147
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
148
+ pointer = getattr(pointer, "bias")
149
+ elif scope_names[0] == "output_weights":
150
+ pointer = getattr(pointer, "weight")
151
+ elif scope_names[0] == "squad":
152
+ pointer = getattr(pointer, "classifier")
153
+ else:
154
+ try:
155
+ pointer = getattr(pointer, scope_names[0])
156
+ except AttributeError:
157
+ logger.info(f"Skipping {'/'.join(name)}")
158
+ continue
159
+ if len(scope_names) >= 2:
160
+ num = int(scope_names[1])
161
+ pointer = pointer[num]
162
+ if m_name[-11:] == "_embeddings":
163
+ pointer = getattr(pointer, "weight")
164
+ elif m_name == "kernel":
165
+ array = np.transpose(array)
166
+ try:
167
+ if pointer.shape != array.shape:
168
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
169
+ except ValueError as e:
170
+ e.args += (pointer.shape, array.shape)
171
+ raise
172
+ logger.info(f"Initialize PyTorch weight {name}")
173
+ pointer.data = torch.from_numpy(array)
174
+ return model
175
+
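+ # Illustrative usage sketch (commented out; not part of the module's behavior). The checkpoint
+ # path and config below are placeholders: the config must match the architecture stored in the
+ # TF checkpoint, and TensorFlow must be installed.
+ #
+ #   config = BertConfig()                       # e.g. vocab_size=30522, hidden_size=768, ...
+ #   model = BertForPreTraining(config)
+ #   model = load_tf_weights_in_bert(model, config, "./bert_model.ckpt")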
176
+
177
+ class BertEmbeddings(nn.Module):
178
+ """Construct the embeddings from word, position and token_type embeddings."""
179
+
180
+ def __init__(self, config):
181
+ super().__init__()
182
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
183
+ if hasattr(config, "no_position_embeddings"):
184
+ self.no_position_embeddings = config.no_position_embeddings
185
+ else:
186
+ self.no_position_embeddings = False
187
+
188
+ if hasattr(config, "no_token_type_embeddings"):
189
+ self.no_token_type_embeddings = config.no_token_type_embeddings
190
+ else:
191
+ self.no_token_type_embeddings = False
192
+
193
+ if not self.no_position_embeddings:
194
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
195
+
196
+ if not self.no_token_type_embeddings:
197
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
198
+
199
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
200
+ # any TensorFlow checkpoint file
201
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
202
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
203
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
204
+ if not self.no_position_embeddings:
205
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
206
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
207
+
208
+ if not self.no_token_type_embeddings:
209
+ if not hasattr(self, "position_ids"):
210
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
211
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
212
+ self.register_buffer(
213
+ "token_type_ids",
214
+ torch.zeros(self.position_ids.size(), dtype=torch.long),
215
+ persistent=False,
216
+ )
217
+
218
+ def forward(
219
+ self,
220
+ input_ids: Optional[torch.LongTensor] = None,
221
+ token_type_ids: Optional[torch.LongTensor] = None,
222
+ position_ids: Optional[torch.LongTensor] = None,
223
+ inputs_embeds: Optional[torch.FloatTensor] = None,
224
+ past_key_values_length: int = 0,
225
+ ) -> torch.Tensor:
226
+ if input_ids is not None:
227
+ input_shape = input_ids.size()
228
+ else:
229
+ input_shape = inputs_embeds.size()[:-1]
230
+
231
+ seq_length = input_shape[1]
232
+
233
+ if (not self.no_position_embeddings or not self.no_token_type_embeddings) and position_ids is None:
234
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
235
+
236
+ # Set token_type_ids to the registered buffer defined in the constructor (all zeros), which is usually the case
237
+ # when it is auto-generated. The registered buffer lets users trace the model without passing token_type_ids
238
+ # and solves issue #5664.
239
+ if not self.no_token_type_embeddings:
240
+ if token_type_ids is None:
241
+ if hasattr(self, "token_type_ids"):
242
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
243
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
244
+ token_type_ids = buffered_token_type_ids_expanded
245
+ else:
246
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
247
+
248
+ if inputs_embeds is None:
249
+ inputs_embeds = self.word_embeddings(input_ids)
250
+ embeddings = inputs_embeds
251
+
252
+ if not self.no_token_type_embeddings:
253
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
254
+ embeddings += token_type_embeddings
255
+
256
+ if not self.no_position_embeddings and self.position_embedding_type == "absolute":
257
+ position_embeddings = self.position_embeddings(position_ids)
258
+ embeddings += position_embeddings
259
+
260
+ embeddings = self.LayerNorm(embeddings)
261
+ embeddings = self.dropout(embeddings)
262
+ return embeddings
263
+
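+ # Shape sketch (illustrative, commented out): for input_ids of shape (batch, seq_len),
+ # BertEmbeddings.forward returns (batch, seq_len, hidden_size):
+ #   emb = word_embeddings(input_ids)                      # (B, L, H)
+ #   emb += token_type_embeddings(token_type_ids)          # skipped if no_token_type_embeddings
+ #   emb += position_embeddings(position_ids)              # skipped if no_position_embeddings or non-"absolute"
+ #   emb = dropout(LayerNorm(emb))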
264
+
265
+ class BertSelfAttention(nn.Module):
266
+ def __init__(self, config, position_embedding_type=None):
267
+ super().__init__()
268
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
269
+ raise ValueError(
270
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
271
+ f"heads ({config.num_attention_heads})"
272
+ )
273
+
274
+ self.num_attention_heads = config.num_attention_heads
275
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
276
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
277
+
278
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
279
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
280
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
281
+
282
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
283
+ self.position_embedding_type = position_embedding_type or getattr(
284
+ config, "position_embedding_type", "absolute"
285
+ )
286
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
287
+ self.max_position_embeddings = config.max_position_embeddings
288
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
289
+
290
+ self.is_decoder = config.is_decoder
291
+
292
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
293
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
294
+ x = x.view(new_x_shape)
295
+ return x.permute(0, 2, 1, 3)
296
+
297
+ def forward(
298
+ self,
299
+ hidden_states: torch.Tensor,
300
+ attention_mask: Optional[torch.FloatTensor] = None,
301
+ head_mask: Optional[torch.FloatTensor] = None,
302
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
303
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
304
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
305
+ output_attentions: Optional[bool] = False,
306
+ ) -> Tuple[torch.Tensor]:
307
+ mixed_query_layer = self.query(hidden_states)
308
+
309
+ # If this is instantiated as a cross-attention module, the keys
310
+ # and values come from an encoder; the attention mask needs to be
311
+ # such that the encoder's padding tokens are not attended to.
312
+ is_cross_attention = encoder_hidden_states is not None
313
+
314
+ if is_cross_attention and past_key_value is not None:
315
+ # reuse k,v, cross_attentions
316
+ key_layer = past_key_value[0]
317
+ value_layer = past_key_value[1]
318
+ attention_mask = encoder_attention_mask
319
+ elif is_cross_attention:
320
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
321
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
322
+ attention_mask = encoder_attention_mask
323
+ elif past_key_value is not None:
324
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
325
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
326
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
327
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
328
+ else:
329
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
330
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
331
+
332
+ query_layer = self.transpose_for_scores(mixed_query_layer)
333
+
334
+ if self.is_decoder:
335
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
336
+ # Further calls to cross_attention layer can then reuse all cross-attention
337
+ # key/value_states (first "if" case)
338
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
339
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
340
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
341
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
342
+ past_key_value = (key_layer, value_layer)
343
+
344
+ # Take the dot product between "query" and "key" to get the raw attention scores.
345
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
346
+
347
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
348
+ seq_length = hidden_states.size()[1]
349
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
350
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
351
+ distance = position_ids_l - position_ids_r
352
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
353
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
354
+
355
+ if self.position_embedding_type == "relative_key":
356
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
357
+ attention_scores = attention_scores + relative_position_scores
358
+ elif self.position_embedding_type == "relative_key_query":
359
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
360
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
361
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
362
+
363
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
364
+ if attention_mask is not None:
365
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
366
+ attention_scores = attention_scores + attention_mask
367
+
368
+ # Normalize the attention scores to probabilities.
369
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
370
+
371
+ # This is actually dropping out entire tokens to attend to, which might
372
+ # seem a bit unusual, but is taken from the original Transformer paper.
373
+ attention_probs = self.dropout(attention_probs)
374
+
375
+ # Mask heads if we want to
376
+ if head_mask is not None:
377
+ attention_probs = attention_probs * head_mask
378
+
379
+ context_layer = torch.matmul(attention_probs, value_layer)
380
+
381
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
382
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
383
+ context_layer = context_layer.view(new_context_layer_shape)
384
+
385
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
386
+
387
+ if self.is_decoder:
388
+ outputs = outputs + (past_key_value,)
389
+ return outputs
390
+
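+ # Attention sketch (illustrative, commented out): with Q, K, V of shape (B, num_heads, L, head_size)
+ # produced by transpose_for_scores:
+ #   scores  = Q @ K.transpose(-1, -2) / sqrt(head_size)   # (B, num_heads, L, L)
+ #   scores += attention_mask                              # additive mask: large negatives at padded positions
+ #   probs   = dropout(softmax(scores, dim=-1))
+ #   context = probs @ V                                   # (B, num_heads, L, head_size) -> (B, L, all_head_size)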
391
+
392
+ class BertSelfOutput(nn.Module):
393
+ def __init__(self, config):
394
+ super().__init__()
395
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
396
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
397
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
398
+
399
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
400
+ hidden_states = self.dense(hidden_states)
401
+ hidden_states = self.dropout(hidden_states)
402
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
403
+ return hidden_states
404
+
405
+
406
+ class BertAttention(nn.Module):
407
+ def __init__(self, config, position_embedding_type=None):
408
+ super().__init__()
409
+ self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
410
+ self.output = BertSelfOutput(config)
411
+ self.pruned_heads = set()
412
+
413
+ def prune_heads(self, heads):
414
+ if len(heads) == 0:
415
+ return
416
+ heads, index = find_pruneable_heads_and_indices(
417
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
418
+ )
419
+
420
+ # Prune linear layers
421
+ self.self.query = prune_linear_layer(self.self.query, index)
422
+ self.self.key = prune_linear_layer(self.self.key, index)
423
+ self.self.value = prune_linear_layer(self.self.value, index)
424
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
425
+
426
+ # Update hyper params and store pruned heads
427
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
428
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
429
+ self.pruned_heads = self.pruned_heads.union(heads)
430
+
431
+ def forward(
432
+ self,
433
+ hidden_states: torch.Tensor,
434
+ attention_mask: Optional[torch.FloatTensor] = None,
435
+ head_mask: Optional[torch.FloatTensor] = None,
436
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
437
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
438
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
439
+ output_attentions: Optional[bool] = False,
440
+ ) -> Tuple[torch.Tensor]:
441
+ self_outputs = self.self(
442
+ hidden_states,
443
+ attention_mask,
444
+ head_mask,
445
+ encoder_hidden_states,
446
+ encoder_attention_mask,
447
+ past_key_value,
448
+ output_attentions,
449
+ )
450
+ attention_output = self.output(self_outputs[0], hidden_states)
451
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
452
+ return outputs
453
+
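+ # Pruning sketch (illustrative): heads are usually pruned through BertModel._prune_heads defined further
+ # below, e.g. `model._prune_heads({0: [1, 2]})` removes heads 1 and 2 of layer 0 by slicing the
+ # query/key/value and output projections with prune_linear_layer, as done above.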
454
+
455
+ class BertIntermediate(nn.Module):
456
+ def __init__(self, config):
457
+ super().__init__()
458
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
459
+ if isinstance(config.hidden_act, str):
460
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
461
+ else:
462
+ self.intermediate_act_fn = config.hidden_act
463
+
464
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
465
+ hidden_states = self.dense(hidden_states)
466
+ hidden_states = self.intermediate_act_fn(hidden_states)
467
+ return hidden_states
468
+
469
+
470
+ class BertOutput(nn.Module):
471
+ def __init__(self, config):
472
+ super().__init__()
473
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
474
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
475
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
476
+
477
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
478
+ hidden_states = self.dense(hidden_states)
479
+ hidden_states = self.dropout(hidden_states)
480
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
481
+ return hidden_states
482
+
483
+
484
+ class BertLayer(nn.Module):
485
+ def __init__(self, config):
486
+ super().__init__()
487
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
488
+ self.seq_len_dim = 1
489
+ self.attention = BertAttention(config)
490
+ self.is_decoder = config.is_decoder
491
+ self.add_cross_attention = config.add_cross_attention
492
+ if self.add_cross_attention:
493
+ if not self.is_decoder:
494
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
495
+ self.crossattention = BertAttention(config, position_embedding_type="absolute")
496
+ self.intermediate = BertIntermediate(config)
497
+ self.output = BertOutput(config)
498
+
499
+ def forward(
500
+ self,
501
+ hidden_states: torch.Tensor,
502
+ attention_mask: Optional[torch.FloatTensor] = None,
503
+ head_mask: Optional[torch.FloatTensor] = None,
504
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
505
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
506
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
507
+ output_attentions: Optional[bool] = False,
508
+ ) -> Tuple[torch.Tensor]:
509
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
510
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
511
+ self_attention_outputs = self.attention(
512
+ hidden_states,
513
+ attention_mask,
514
+ head_mask,
515
+ output_attentions=output_attentions,
516
+ past_key_value=self_attn_past_key_value,
517
+ )
518
+ attention_output = self_attention_outputs[0]
519
+
520
+ # if decoder, the last output is tuple of self-attn cache
521
+ if self.is_decoder:
522
+ outputs = self_attention_outputs[1:-1]
523
+ present_key_value = self_attention_outputs[-1]
524
+ else:
525
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
526
+
527
+ cross_attn_present_key_value = None
528
+ if self.is_decoder and encoder_hidden_states is not None:
529
+ if not hasattr(self, "crossattention"):
530
+ raise ValueError(
531
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
532
+ " by setting `config.add_cross_attention=True`"
533
+ )
534
+
535
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
536
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
537
+ cross_attention_outputs = self.crossattention(
538
+ attention_output,
539
+ attention_mask,
540
+ head_mask,
541
+ encoder_hidden_states,
542
+ encoder_attention_mask,
543
+ cross_attn_past_key_value,
544
+ output_attentions,
545
+ )
546
+ attention_output = cross_attention_outputs[0]
547
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
548
+
549
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
550
+ cross_attn_present_key_value = cross_attention_outputs[-1]
551
+ present_key_value = present_key_value + cross_attn_present_key_value
552
+
553
+ layer_output = apply_chunking_to_forward(
554
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
555
+ )
556
+ outputs = (layer_output,) + outputs
557
+
558
+ # if decoder, return the attn key/values as the last output
559
+ if self.is_decoder:
560
+ outputs = outputs + (present_key_value,)
561
+
562
+ return outputs
563
+
564
+ def feed_forward_chunk(self, attention_output):
565
+ intermediate_output = self.intermediate(attention_output)
566
+ layer_output = self.output(intermediate_output, attention_output)
567
+ return layer_output
568
+
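+ # Note (sketch): apply_chunking_to_forward splits `attention_output` along the sequence dimension
+ # (self.seq_len_dim = 1) into chunks of `config.chunk_size_feed_forward` (0 means no chunking) and
+ # runs `feed_forward_chunk` on each chunk, trading a little extra compute for lower peak memory.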
569
+
570
+ class BertEncoder(nn.Module):
571
+ def __init__(self, config):
572
+ super().__init__()
573
+ self.config = config
574
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
575
+ self.gradient_checkpointing = False
576
+
577
+ def forward(
578
+ self,
579
+ hidden_states: torch.Tensor,
580
+ attention_mask: Optional[torch.FloatTensor] = None,
581
+ head_mask: Optional[torch.FloatTensor] = None,
582
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
583
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
584
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
585
+ use_cache: Optional[bool] = None,
586
+ output_attentions: Optional[bool] = False,
587
+ output_hidden_states: Optional[bool] = False,
588
+ return_dict: Optional[bool] = True,
589
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
590
+ all_hidden_states = () if output_hidden_states else None
591
+ all_self_attentions = () if output_attentions else None
592
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
593
+
594
+ next_decoder_cache = () if use_cache else None
595
+ for i, layer_module in enumerate(self.layer):
596
+ if output_hidden_states:
597
+ all_hidden_states = all_hidden_states + (hidden_states,)
598
+
599
+ layer_head_mask = head_mask[i] if head_mask is not None else None
600
+ past_key_value = past_key_values[i] if past_key_values is not None else None
601
+
602
+ if self.gradient_checkpointing and self.training:
603
+
604
+ if use_cache:
605
+ logger.warning(
606
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
607
+ )
608
+ use_cache = False
609
+
610
+ def create_custom_forward(module):
611
+ def custom_forward(*inputs):
612
+ return module(*inputs, past_key_value, output_attentions)
613
+
614
+ return custom_forward
615
+
616
+ layer_outputs = torch.utils.checkpoint.checkpoint(
617
+ create_custom_forward(layer_module),
618
+ hidden_states,
619
+ attention_mask,
620
+ layer_head_mask,
621
+ encoder_hidden_states,
622
+ encoder_attention_mask,
623
+ )
624
+ else:
625
+ layer_outputs = layer_module(
626
+ hidden_states,
627
+ attention_mask,
628
+ layer_head_mask,
629
+ encoder_hidden_states,
630
+ encoder_attention_mask,
631
+ past_key_value,
632
+ output_attentions,
633
+ )
634
+
635
+ hidden_states = layer_outputs[0]
636
+ if use_cache:
637
+ next_decoder_cache += (layer_outputs[-1],)
638
+ if output_attentions:
639
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
640
+ if self.config.add_cross_attention:
641
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
642
+
643
+ if output_hidden_states:
644
+ all_hidden_states = all_hidden_states + (hidden_states,)
645
+
646
+ if not return_dict:
647
+ return tuple(
648
+ v
649
+ for v in [
650
+ hidden_states,
651
+ next_decoder_cache,
652
+ all_hidden_states,
653
+ all_self_attentions,
654
+ all_cross_attentions,
655
+ ]
656
+ if v is not None
657
+ )
658
+ return BaseModelOutputWithPastAndCrossAttentions(
659
+ last_hidden_state=hidden_states,
660
+ past_key_values=next_decoder_cache,
661
+ hidden_states=all_hidden_states,
662
+ attentions=all_self_attentions,
663
+ cross_attentions=all_cross_attentions,
664
+ )
665
+
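+ # Usage note (sketch): gradient checkpointing is typically enabled on a built model via the standard
+ # `model.gradient_checkpointing_enable()` helper from PreTrainedModel, which sets
+ # `gradient_checkpointing = True` on this encoder through `_set_gradient_checkpointing` below;
+ # during training this recomputes activations in the backward pass and forces `use_cache=False`.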
666
+
667
+ class BertPooler(nn.Module):
668
+ def __init__(self, config):
669
+ super().__init__()
670
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
671
+ self.activation = nn.Tanh()
672
+
673
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
674
+ # We "pool" the model by simply taking the hidden state corresponding
675
+ # to the first token.
676
+ first_token_tensor = hidden_states[:, 0]
677
+ pooled_output = self.dense(first_token_tensor)
678
+ pooled_output = self.activation(pooled_output)
679
+ return pooled_output
680
+
681
+
682
+ class BertPredictionHeadTransform(nn.Module):
683
+ def __init__(self, config):
684
+ super().__init__()
685
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
686
+ if isinstance(config.hidden_act, str):
687
+ self.transform_act_fn = ACT2FN[config.hidden_act]
688
+ else:
689
+ self.transform_act_fn = config.hidden_act
690
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
691
+
692
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
693
+ hidden_states = self.dense(hidden_states)
694
+ hidden_states = self.transform_act_fn(hidden_states)
695
+ hidden_states = self.LayerNorm(hidden_states)
696
+ return hidden_states
697
+
698
+
699
+ class BertLMPredictionHead(nn.Module):
700
+ def __init__(self, config):
701
+ super().__init__()
702
+ self.transform = BertPredictionHeadTransform(config)
703
+
704
+ # The output weights are the same as the input embeddings, but there is
705
+ # an output-only bias for each token.
706
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
707
+
708
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
709
+
710
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
711
+ self.decoder.bias = self.bias
712
+
713
+ def forward(self, hidden_states):
714
+ hidden_states = self.transform(hidden_states)
715
+ hidden_states = self.decoder(hidden_states)
716
+ return hidden_states
717
+
718
+
719
+ class BertOnlyMLMHead(nn.Module):
720
+ def __init__(self, config):
721
+ super().__init__()
722
+ self.predictions = BertLMPredictionHead(config)
723
+
724
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
725
+ prediction_scores = self.predictions(sequence_output)
726
+ return prediction_scores
727
+
728
+
729
+ class BertOnlyNSPHead(nn.Module):
730
+ def __init__(self, config):
731
+ super().__init__()
732
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
733
+
734
+ def forward(self, pooled_output):
735
+ seq_relationship_score = self.seq_relationship(pooled_output)
736
+ return seq_relationship_score
737
+
738
+
739
+ class BertPreTrainingHeads(nn.Module):
740
+ def __init__(self, config):
741
+ super().__init__()
742
+ self.predictions = BertLMPredictionHead(config)
743
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
744
+
745
+ def forward(self, sequence_output, pooled_output):
746
+ prediction_scores = self.predictions(sequence_output)
747
+ seq_relationship_score = self.seq_relationship(pooled_output)
748
+ return prediction_scores, seq_relationship_score
749
+
750
+
751
+ class BertPreTrainedModel(PreTrainedModel):
752
+ """
753
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
754
+ models.
755
+ """
756
+
757
+ config_class = BertConfig
758
+ load_tf_weights = load_tf_weights_in_bert
759
+ base_model_prefix = "bert"
760
+ supports_gradient_checkpointing = True
761
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
762
+
763
+ def _init_weights(self, module):
764
+ """Initialize the weights"""
765
+ if isinstance(module, nn.Linear):
766
+ # Slightly different from the TF version which uses truncated_normal for initialization
767
+ # cf https://github.com/pytorch/pytorch/pull/5617
768
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
769
+ if module.bias is not None:
770
+ module.bias.data.zero_()
771
+ elif isinstance(module, nn.Embedding):
772
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
773
+ if module.padding_idx is not None:
774
+ module.weight.data[module.padding_idx].zero_()
775
+ elif isinstance(module, nn.LayerNorm):
776
+ module.bias.data.zero_()
777
+ module.weight.data.fill_(1.0)
778
+
779
+ def _set_gradient_checkpointing(self, module, value=False):
780
+ if isinstance(module, BertEncoder):
781
+ module.gradient_checkpointing = value
782
+
783
+
784
+ @dataclass
785
+ class BertForPreTrainingOutput(ModelOutput):
786
+ """
787
+ Output type of [`BertForPreTraining`].
788
+
789
+ Args:
790
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
791
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
792
+ (classification) loss.
793
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
794
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
795
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
796
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
797
+ before SoftMax).
798
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
799
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
800
+ shape `(batch_size, sequence_length, hidden_size)`.
801
+
802
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
803
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
804
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
805
+ sequence_length)`.
806
+
807
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
808
+ heads.
809
+ """
810
+
811
+ loss: Optional[torch.FloatTensor] = None
812
+ prediction_logits: torch.FloatTensor = None
813
+ seq_relationship_logits: torch.FloatTensor = None
814
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
815
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
816
+
817
+
818
+ BERT_START_DOCSTRING = r"""
819
+
820
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
821
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
822
+ etc.)
823
+
824
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
825
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
826
+ and behavior.
827
+
828
+ Parameters:
829
+ config ([`BertConfig`]): Model configuration class with all the parameters of the model.
830
+ Initializing with a config file does not load the weights associated with the model, only the
831
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
832
+ """
833
+
834
+ BERT_INPUTS_DOCSTRING = r"""
835
+ Args:
836
+ input_ids (`torch.LongTensor` of shape `({0})`):
837
+ Indices of input sequence tokens in the vocabulary.
838
+
839
+ Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
840
+ [`PreTrainedTokenizer.__call__`] for details.
841
+
842
+ [What are input IDs?](../glossary#input-ids)
843
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
844
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
845
+
846
+ - 1 for tokens that are **not masked**,
847
+ - 0 for tokens that are **masked**.
848
+
849
+ [What are attention masks?](../glossary#attention-mask)
850
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
851
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
852
+ 1]`:
853
+
854
+ - 0 corresponds to a *sentence A* token,
855
+ - 1 corresponds to a *sentence B* token.
856
+
857
+ [What are token type IDs?](../glossary#token-type-ids)
858
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
859
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
860
+ config.max_position_embeddings - 1]`.
861
+
862
+ [What are position IDs?](../glossary#position-ids)
863
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
864
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
865
+
866
+ - 1 indicates the head is **not masked**,
867
+ - 0 indicates the head is **masked**.
868
+
869
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
870
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
871
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
872
+ model's internal embedding lookup matrix.
873
+ output_attentions (`bool`, *optional*):
874
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
875
+ tensors for more detail.
876
+ output_hidden_states (`bool`, *optional*):
877
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
878
+ more detail.
879
+ return_dict (`bool`, *optional*):
880
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
881
+ """
882
+
883
+
884
+ @add_start_docstrings(
885
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
886
+ BERT_START_DOCSTRING,
887
+ )
888
+ class BertModel(BertPreTrainedModel):
889
+ """
890
+
891
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
892
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
893
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
894
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
895
+
896
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
897
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` and
898
+ `add_cross_attention` arguments set to `True`; `encoder_hidden_states` is then expected as an input to the forward pass.
899
+ """
900
+
901
+ def __init__(self, config, add_pooling_layer=True):
902
+ super().__init__(config)
903
+ self.config = config
904
+
905
+ self.embeddings = BertEmbeddings(config)
906
+ self.encoder = BertEncoder(config)
907
+
908
+ self.pooler = BertPooler(config) if add_pooling_layer else None
909
+
910
+ # Initialize weights and apply final processing
911
+ self.post_init()
912
+
913
+ def get_input_embeddings(self):
914
+ return self.embeddings.word_embeddings
915
+
916
+ def set_input_embeddings(self, value):
917
+ self.embeddings.word_embeddings = value
918
+
919
+ def _prune_heads(self, heads_to_prune):
920
+ """
921
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
922
+ class PreTrainedModel
923
+ """
924
+ for layer, heads in heads_to_prune.items():
925
+ self.encoder.layer[layer].attention.prune_heads(heads)
926
+
927
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
928
+ @add_code_sample_docstrings(
929
+ processor_class=_TOKENIZER_FOR_DOC,
930
+ checkpoint=_CHECKPOINT_FOR_DOC,
931
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
932
+ config_class=_CONFIG_FOR_DOC,
933
+ )
934
+ def forward(
935
+ self,
936
+ input_ids: Optional[torch.Tensor] = None,
937
+ attention_mask: Optional[torch.Tensor] = None,
938
+ token_type_ids: Optional[torch.Tensor] = None,
939
+ position_ids: Optional[torch.Tensor] = None,
940
+ head_mask: Optional[torch.Tensor] = None,
941
+ inputs_embeds: Optional[torch.Tensor] = None,
942
+ encoder_hidden_states: Optional[torch.Tensor] = None,
943
+ encoder_attention_mask: Optional[torch.Tensor] = None,
944
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
945
+ use_cache: Optional[bool] = None,
946
+ output_attentions: Optional[bool] = None,
947
+ output_hidden_states: Optional[bool] = None,
948
+ return_dict: Optional[bool] = None,
949
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
950
+ r"""
951
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
952
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
953
+ the model is configured as a decoder.
954
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
955
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
956
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
957
+
958
+ - 1 for tokens that are **not masked**,
959
+ - 0 for tokens that are **masked**.
960
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
961
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
962
+
963
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
964
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
965
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
966
+ use_cache (`bool`, *optional*):
967
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
968
+ `past_key_values`).
969
+ """
970
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
971
+ output_hidden_states = (
972
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
973
+ )
974
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
975
+
976
+ if self.config.is_decoder:
977
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
978
+ else:
979
+ use_cache = False
980
+
981
+ if input_ids is not None and inputs_embeds is not None:
982
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
983
+ elif input_ids is not None:
984
+ input_shape = input_ids.size()
985
+ elif inputs_embeds is not None:
986
+ input_shape = inputs_embeds.size()[:-1]
987
+ else:
988
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
989
+
990
+ batch_size, seq_length = input_shape
991
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
992
+
993
+ # past_key_values_length
994
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
995
+
996
+ if attention_mask is None:
997
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
998
+
999
+ if token_type_ids is None:
1000
+ if hasattr(self.embeddings, "token_type_ids"):
1001
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
1002
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
1003
+ token_type_ids = buffered_token_type_ids_expanded
1004
+ else:
1005
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1006
+
1007
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1008
+ # ourselves in which case we just need to make it broadcastable to all heads.
1009
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
1010
+
1011
+ # If a 2D or 3D attention mask is provided for the cross-attention
1012
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
1013
+ if self.config.is_decoder and encoder_hidden_states is not None:
1014
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1015
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1016
+ if encoder_attention_mask is None:
1017
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1018
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1019
+ else:
1020
+ encoder_extended_attention_mask = None
1021
+
1022
+ # Prepare head mask if needed
1023
+ # 1.0 in head_mask indicate we keep the head
1024
+ # attention_probs has shape bsz x n_heads x N x N
1025
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1026
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1027
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1028
+
1029
+ embedding_output = self.embeddings(
1030
+ input_ids=input_ids,
1031
+ position_ids=position_ids,
1032
+ token_type_ids=token_type_ids,
1033
+ inputs_embeds=inputs_embeds,
1034
+ past_key_values_length=past_key_values_length,
1035
+ )
1036
+ encoder_outputs = self.encoder(
1037
+ embedding_output,
1038
+ attention_mask=extended_attention_mask,
1039
+ head_mask=head_mask,
1040
+ encoder_hidden_states=encoder_hidden_states,
1041
+ encoder_attention_mask=encoder_extended_attention_mask,
1042
+ past_key_values=past_key_values,
1043
+ use_cache=use_cache,
1044
+ output_attentions=output_attentions,
1045
+ output_hidden_states=output_hidden_states,
1046
+ return_dict=return_dict,
1047
+ )
1048
+ sequence_output = encoder_outputs[0]
1049
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1050
+
1051
+ if not return_dict:
1052
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1053
+
1054
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1055
+ last_hidden_state=sequence_output,
1056
+ pooler_output=pooled_output,
1057
+ past_key_values=encoder_outputs.past_key_values,
1058
+ hidden_states=encoder_outputs.hidden_states,
1059
+ attentions=encoder_outputs.attentions,
1060
+ cross_attentions=encoder_outputs.cross_attentions,
1061
+ )
1062
+
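+ # Usage sketch (illustrative, commented out; names are placeholders). As described in the class
+ # docstring, the model can act as a decoder with cross-attention:
+ #   config = BertConfig(is_decoder=True, add_cross_attention=True)
+ #   decoder = BertModel(config)
+ #   out = decoder(input_ids=ids, encoder_hidden_states=enc_states, encoder_attention_mask=enc_mask)
+ #   hidden = out.last_hidden_state    # (batch, seq_len, hidden_size)
+ #   pooled = out.pooler_output        # (batch, hidden_size), from the first-token BertPooler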
1063
+
1064
+ @add_start_docstrings(
1065
+ """
1066
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
1067
+ sentence prediction (classification)` head.
1068
+ """,
1069
+ BERT_START_DOCSTRING,
1070
+ )
1071
+ class BertForPreTraining(BertPreTrainedModel):
1072
+ def __init__(self, config):
1073
+ super().__init__(config)
1074
+
1075
+ self.bert = BertModel(config)
1076
+ self.cls = BertPreTrainingHeads(config)
1077
+
1078
+ # Initialize weights and apply final processing
1079
+ self.post_init()
1080
+
1081
+ def get_output_embeddings(self):
1082
+ return self.cls.predictions.decoder
1083
+
1084
+ def set_output_embeddings(self, new_embeddings):
1085
+ self.cls.predictions.decoder = new_embeddings
1086
+
1087
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1088
+ @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1089
+ def forward(
1090
+ self,
1091
+ input_ids: Optional[torch.Tensor] = None,
1092
+ attention_mask: Optional[torch.Tensor] = None,
1093
+ token_type_ids: Optional[torch.Tensor] = None,
1094
+ position_ids: Optional[torch.Tensor] = None,
1095
+ head_mask: Optional[torch.Tensor] = None,
1096
+ inputs_embeds: Optional[torch.Tensor] = None,
1097
+ labels: Optional[torch.Tensor] = None,
1098
+ next_sentence_label: Optional[torch.Tensor] = None,
1099
+ output_attentions: Optional[bool] = None,
1100
+ output_hidden_states: Optional[bool] = None,
1101
+ return_dict: Optional[bool] = None,
1102
+ ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
1103
+ r"""
1104
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1105
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1106
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
1107
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1108
+ next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1109
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
1110
+ pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
1111
+
1112
+ - 0 indicates sequence B is a continuation of sequence A,
1113
+ - 1 indicates sequence B is a random sequence.
1114
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
1115
+ Used to hide legacy arguments that have been deprecated.
1116
+
1117
+ Returns:
1118
+
1119
+ Example:
1120
+
1121
+ ```python
1122
+ >>> from transformers import BertTokenizer, BertForPreTraining
1123
+ >>> import torch
1124
+
1125
+ >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
1126
+ >>> model = BertForPreTraining.from_pretrained("bert-base-uncased")
1127
+
1128
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1129
+ >>> outputs = model(**inputs)
1130
+
1131
+ >>> prediction_logits = outputs.prediction_logits
1132
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1133
+ ```
1134
+ """
1135
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1136
+
1137
+ outputs = self.bert(
1138
+ input_ids,
1139
+ attention_mask=attention_mask,
1140
+ token_type_ids=token_type_ids,
1141
+ position_ids=position_ids,
1142
+ head_mask=head_mask,
1143
+ inputs_embeds=inputs_embeds,
1144
+ output_attentions=output_attentions,
1145
+ output_hidden_states=output_hidden_states,
1146
+ return_dict=return_dict,
1147
+ )
1148
+
1149
+ sequence_output, pooled_output = outputs[:2]
1150
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1151
+
1152
+ total_loss = None
1153
+ if labels is not None and next_sentence_label is not None:
1154
+ loss_fct = CrossEntropyLoss()
1155
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1156
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1157
+ total_loss = masked_lm_loss + next_sentence_loss
1158
+
1159
+ if not return_dict:
1160
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1161
+ return ((total_loss,) + output) if total_loss is not None else output
1162
+
1163
+ return BertForPreTrainingOutput(
1164
+ loss=total_loss,
1165
+ prediction_logits=prediction_scores,
1166
+ seq_relationship_logits=seq_relationship_score,
1167
+ hidden_states=outputs.hidden_states,
1168
+ attentions=outputs.attentions,
1169
+ )
1170
+
1171
+
1172
+ @add_start_docstrings(
1173
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
1174
+ )
1175
+ class BertLMHeadModel(BertPreTrainedModel):
1176
+
1177
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1178
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1179
+
1180
+ def __init__(self, config):
1181
+ super().__init__(config)
1182
+
1183
+ if not config.is_decoder:
1184
+ logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
1185
+
1186
+ self.bert = BertModel(config, add_pooling_layer=False)
1187
+ self.cls = BertOnlyMLMHead(config)
1188
+
1189
+ # Initialize weights and apply final processing
1190
+ self.post_init()
1191
+
1192
+ def get_output_embeddings(self):
1193
+ return self.cls.predictions.decoder
1194
+
1195
+ def set_output_embeddings(self, new_embeddings):
1196
+ self.cls.predictions.decoder = new_embeddings
1197
+
1198
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1199
+ @add_code_sample_docstrings(
1200
+ processor_class=_TOKENIZER_FOR_DOC,
1201
+ checkpoint=_CHECKPOINT_FOR_DOC,
1202
+ output_type=CausalLMOutputWithCrossAttentions,
1203
+ config_class=_CONFIG_FOR_DOC,
1204
+ )
1205
+ def forward(
1206
+ self,
1207
+ input_ids: Optional[torch.Tensor] = None,
1208
+ attention_mask: Optional[torch.Tensor] = None,
1209
+ token_type_ids: Optional[torch.Tensor] = None,
1210
+ position_ids: Optional[torch.Tensor] = None,
1211
+ head_mask: Optional[torch.Tensor] = None,
1212
+ inputs_embeds: Optional[torch.Tensor] = None,
1213
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1214
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1215
+ labels: Optional[torch.Tensor] = None,
1216
+ past_key_values: Optional[List[torch.Tensor]] = None,
1217
+ use_cache: Optional[bool] = None,
1218
+ output_attentions: Optional[bool] = None,
1219
+ output_hidden_states: Optional[bool] = None,
1220
+ return_dict: Optional[bool] = None,
1221
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1222
+ r"""
1223
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1224
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1225
+ the model is configured as a decoder.
1226
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1227
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1228
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1229
+
1230
+ - 1 for tokens that are **not masked**,
1231
+ - 0 for tokens that are **masked**.
1232
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1233
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1234
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1235
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1236
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1237
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1238
+
1239
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1240
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1241
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1242
+ use_cache (`bool`, *optional*):
1243
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1244
+ `past_key_values`).
1245
+ """
1246
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1247
+ if labels is not None:
1248
+ use_cache = False
1249
+
1250
+ outputs = self.bert(
1251
+ input_ids,
1252
+ attention_mask=attention_mask,
1253
+ token_type_ids=token_type_ids,
1254
+ position_ids=position_ids,
1255
+ head_mask=head_mask,
1256
+ inputs_embeds=inputs_embeds,
1257
+ encoder_hidden_states=encoder_hidden_states,
1258
+ encoder_attention_mask=encoder_attention_mask,
1259
+ past_key_values=past_key_values,
1260
+ use_cache=use_cache,
1261
+ output_attentions=output_attentions,
1262
+ output_hidden_states=output_hidden_states,
1263
+ return_dict=return_dict,
1264
+ )
1265
+
1266
+ sequence_output = outputs[0]
1267
+ prediction_scores = self.cls(sequence_output)
1268
+
1269
+ lm_loss = None
1270
+ if labels is not None:
1271
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1272
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1273
+ labels = labels[:, 1:].contiguous()
1274
+ loss_fct = CrossEntropyLoss()
1275
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1276
+
1277
+ if not return_dict:
1278
+ output = (prediction_scores,) + outputs[2:]
1279
+ return ((lm_loss,) + output) if lm_loss is not None else output
1280
+
1281
+ return CausalLMOutputWithCrossAttentions(
1282
+ loss=lm_loss,
1283
+ logits=prediction_scores,
1284
+ past_key_values=outputs.past_key_values,
1285
+ hidden_states=outputs.hidden_states,
1286
+ attentions=outputs.attentions,
1287
+ cross_attentions=outputs.cross_attentions,
1288
+ )
1289
+
1290
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
1291
+ input_shape = input_ids.shape
1292
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1293
+ if attention_mask is None:
1294
+ attention_mask = input_ids.new_ones(input_shape)
1295
+
1296
+ # cut decoder_input_ids if past is used
1297
+ if past is not None:
1298
+ input_ids = input_ids[:, -1:]
1299
+
1300
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
1301
+
1302
+ def _reorder_cache(self, past, beam_idx):
1303
+ reordered_past = ()
1304
+ for layer_past in past:
1305
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1306
+ return reordered_past
1307
+
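+ # Usage sketch (illustrative, commented out): causal LM use requires `is_decoder=True`:
+ #   config = BertConfig(is_decoder=True)
+ #   lm = BertLMHeadModel(config)
+ #   out = lm(input_ids=ids, labels=ids)   # labels are shifted internally for next-token prediction
+ #   loss, logits = out.loss, out.logits   # logits: (batch, seq_len, vocab_size)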
1308
+
1309
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
1310
+ class BertForMaskedLM(BertPreTrainedModel):
1311
+
1312
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1313
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1314
+
1315
+ def __init__(self, config):
1316
+ super().__init__(config)
1317
+
1318
+ if config.is_decoder:
1319
+ logger.warning(
1320
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
1321
+ "bi-directional self-attention."
1322
+ )
1323
+
1324
+ self.bert = BertModel(config, add_pooling_layer=False)
1325
+ self.cls = BertOnlyMLMHead(config)
1326
+
1327
+ # Initialize weights and apply final processing
1328
+ self.post_init()
1329
+
1330
+ def get_output_embeddings(self):
1331
+ return self.cls.predictions.decoder
1332
+
1333
+ def set_output_embeddings(self, new_embeddings):
1334
+ self.cls.predictions.decoder = new_embeddings
1335
+
1336
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1337
+ @add_code_sample_docstrings(
1338
+ processor_class=_TOKENIZER_FOR_DOC,
1339
+ checkpoint=_CHECKPOINT_FOR_DOC,
1340
+ output_type=MaskedLMOutput,
1341
+ config_class=_CONFIG_FOR_DOC,
1342
+ expected_output="'paris'",
1343
+ expected_loss=0.88,
1344
+ )
1345
+ def forward(
1346
+ self,
1347
+ input_ids: Optional[torch.Tensor] = None,
1348
+ attention_mask: Optional[torch.Tensor] = None,
1349
+ token_type_ids: Optional[torch.Tensor] = None,
1350
+ position_ids: Optional[torch.Tensor] = None,
1351
+ head_mask: Optional[torch.Tensor] = None,
1352
+ inputs_embeds: Optional[torch.Tensor] = None,
1353
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1354
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1355
+ labels: Optional[torch.Tensor] = None,
1356
+ output_attentions: Optional[bool] = None,
1357
+ output_hidden_states: Optional[bool] = None,
1358
+ return_dict: Optional[bool] = None,
1359
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1360
+ r"""
1361
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1362
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1363
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1364
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1365
+ """
1366
+
1367
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1368
+
1369
+ outputs = self.bert(
1370
+ input_ids,
1371
+ attention_mask=attention_mask,
1372
+ token_type_ids=token_type_ids,
1373
+ position_ids=position_ids,
1374
+ head_mask=head_mask,
1375
+ inputs_embeds=inputs_embeds,
1376
+ encoder_hidden_states=encoder_hidden_states,
1377
+ encoder_attention_mask=encoder_attention_mask,
1378
+ output_attentions=output_attentions,
1379
+ output_hidden_states=output_hidden_states,
1380
+ return_dict=return_dict,
1381
+ )
1382
+
1383
+ sequence_output = outputs[0]
1384
+ prediction_scores = self.cls(sequence_output)
1385
+
1386
+ masked_lm_loss = None
1387
+ if labels is not None:
1388
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1389
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1390
+
1391
+ if not return_dict:
1392
+ output = (prediction_scores,) + outputs[2:]
1393
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1394
+
1395
+ return MaskedLMOutput(
1396
+ loss=masked_lm_loss,
1397
+ logits=prediction_scores,
1398
+ hidden_states=outputs.hidden_states,
1399
+ attentions=outputs.attentions,
1400
+ )
1401
+
1402
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1403
+ input_shape = input_ids.shape
1404
+ effective_batch_size = input_shape[0]
1405
+
1406
+ # add a dummy token
1407
+ if self.config.pad_token_id is None:
1408
+ raise ValueError("The PAD token should be defined for generation")
1409
+
1410
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1411
+ dummy_token = torch.full(
1412
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1413
+ )
1414
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1415
+
1416
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1417
+
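+ # Usage sketch (illustrative, commented out; `mask_positions` is a placeholder boolean tensor):
+ #   mlm = BertForMaskedLM(config)
+ #   labels = input_ids.clone(); labels[~mask_positions] = -100   # only masked positions contribute to the loss
+ #   out = mlm(input_ids=masked_input_ids, attention_mask=attention_mask, labels=labels)
+ #   loss, logits = out.loss, out.logits   # logits: (batch, seq_len, vocab_size)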
1418
+
1419
+ @add_start_docstrings(
1420
+ """Bert Model with a `next sentence prediction (classification)` head on top.""",
1421
+ BERT_START_DOCSTRING,
1422
+ )
1423
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1424
+ def __init__(self, config):
1425
+ super().__init__(config)
1426
+
1427
+ self.bert = BertModel(config)
1428
+ self.cls = BertOnlyNSPHead(config)
1429
+
1430
+ # Initialize weights and apply final processing
1431
+ self.post_init()
1432
+
1433
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1434
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1435
+ def forward(
1436
+ self,
1437
+ input_ids: Optional[torch.Tensor] = None,
1438
+ attention_mask: Optional[torch.Tensor] = None,
1439
+ token_type_ids: Optional[torch.Tensor] = None,
1440
+ position_ids: Optional[torch.Tensor] = None,
1441
+ head_mask: Optional[torch.Tensor] = None,
1442
+ inputs_embeds: Optional[torch.Tensor] = None,
1443
+ labels: Optional[torch.Tensor] = None,
1444
+ output_attentions: Optional[bool] = None,
1445
+ output_hidden_states: Optional[bool] = None,
1446
+ return_dict: Optional[bool] = None,
1447
+ **kwargs,
1448
+ ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
1449
+ r"""
1450
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1451
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1452
+ (see `input_ids` docstring). Indices should be in `[0, 1]`:
1453
+
1454
+ - 0 indicates sequence B is a continuation of sequence A,
1455
+ - 1 indicates sequence B is a random sequence.
1456
+
1457
+ Returns:
1458
+
1459
+ Example:
1460
+
1461
+ ```python
1462
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1463
+ >>> import torch
1464
+
1465
+ >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
1466
+ >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
1467
+
1468
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1469
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1470
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
1471
+
1472
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1473
+ >>> logits = outputs.logits
1474
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1475
+ ```
1476
+ """
1477
+
1478
+ if "next_sentence_label" in kwargs:
1479
+ warnings.warn(
1480
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
1481
+ " `labels` instead.",
1482
+ FutureWarning,
1483
+ )
1484
+ labels = kwargs.pop("next_sentence_label")
1485
+
1486
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1487
+
1488
+ outputs = self.bert(
1489
+ input_ids,
1490
+ attention_mask=attention_mask,
1491
+ token_type_ids=token_type_ids,
1492
+ position_ids=position_ids,
1493
+ head_mask=head_mask,
1494
+ inputs_embeds=inputs_embeds,
1495
+ output_attentions=output_attentions,
1496
+ output_hidden_states=output_hidden_states,
1497
+ return_dict=return_dict,
1498
+ )
1499
+
1500
+ pooled_output = outputs[1]
1501
+
1502
+ seq_relationship_scores = self.cls(pooled_output)
1503
+
1504
+ next_sentence_loss = None
1505
+ if labels is not None:
1506
+ loss_fct = CrossEntropyLoss()
1507
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
1508
+
1509
+ if not return_dict:
1510
+ output = (seq_relationship_scores,) + outputs[2:]
1511
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1512
+
1513
+ return NextSentencePredictorOutput(
1514
+ loss=next_sentence_loss,
1515
+ logits=seq_relationship_scores,
1516
+ hidden_states=outputs.hidden_states,
1517
+ attentions=outputs.attentions,
1518
+ )
1519
+
1520
+
1521
+ @add_start_docstrings(
1522
+ """
1523
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1524
+ output) e.g. for GLUE tasks.
1525
+ """,
1526
+ BERT_START_DOCSTRING,
1527
+ )
1528
+ class BertForSequenceClassification(BertPreTrainedModel):
1529
+ def __init__(self, config):
1530
+ super().__init__(config)
1531
+ self.num_labels = config.num_labels
1532
+ self.config = config
1533
+
1534
+ self.bert = BertModel(config)
1535
+ classifier_dropout_prob = (
1536
+ config.classifier_dropout_prob if config.classifier_dropout_prob is not None else config.hidden_dropout_prob
1537
+ )
1538
+ self.dropout = nn.Dropout(classifier_dropout_prob)
1539
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1540
+
1541
+ # Initialize weights and apply final processing
1542
+ self.post_init()
1543
+
1544
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1545
+ @add_code_sample_docstrings(
1546
+ processor_class=_TOKENIZER_FOR_DOC,
1547
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1548
+ output_type=SequenceClassifierOutput,
1549
+ config_class=_CONFIG_FOR_DOC,
1550
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
1551
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1552
+ )
1553
+ def forward(
1554
+ self,
1555
+ input_ids: Optional[torch.Tensor] = None,
1556
+ attention_mask: Optional[torch.Tensor] = None,
1557
+ token_type_ids: Optional[torch.Tensor] = None,
1558
+ position_ids: Optional[torch.Tensor] = None,
1559
+ head_mask: Optional[torch.Tensor] = None,
1560
+ inputs_embeds: Optional[torch.Tensor] = None,
1561
+ labels: Optional[torch.Tensor] = None,
1562
+ output_attentions: Optional[bool] = None,
1563
+ output_hidden_states: Optional[bool] = None,
1564
+ return_dict: Optional[bool] = None,
1565
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1566
+ r"""
1567
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1568
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1569
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1570
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1571
+ """
1572
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1573
+
1574
+ outputs = self.bert(
1575
+ input_ids,
1576
+ attention_mask=attention_mask,
1577
+ token_type_ids=token_type_ids,
1578
+ position_ids=position_ids,
1579
+ head_mask=head_mask,
1580
+ inputs_embeds=inputs_embeds,
1581
+ output_attentions=output_attentions,
1582
+ output_hidden_states=output_hidden_states,
1583
+ return_dict=return_dict,
1584
+ )
1585
+
1586
+ pooled_output = outputs[1]
1587
+
1588
+ pooled_output = self.dropout(pooled_output)
1589
+ logits = self.classifier(pooled_output)
1590
+
1591
+ loss = None
1592
+ if labels is not None:
1593
+ if self.config.problem_type is None:
1594
+ if self.num_labels == 1:
1595
+ self.config.problem_type = "regression"
1596
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1597
+ self.config.problem_type = "single_label_classification"
1598
+ else:
1599
+ self.config.problem_type = "multi_label_classification"
1600
+
1601
+ if self.config.problem_type == "regression":
1602
+ loss_fct = MSELoss()
1603
+ if self.num_labels == 1:
1604
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1605
+ else:
1606
+ loss = loss_fct(logits, labels)
1607
+ elif self.config.problem_type == "single_label_classification":
1608
+ loss_fct = CrossEntropyLoss()
1609
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1610
+ elif self.config.problem_type == "multi_label_classification":
1611
+ loss_fct = BCEWithLogitsLoss()
1612
+ loss = loss_fct(logits, labels)
1613
+ if not return_dict:
1614
+ output = (logits,) + outputs[2:]
1615
+ return ((loss,) + output) if loss is not None else output
1616
+
1617
+ return SequenceClassifierOutput(
1618
+ loss=loss,
1619
+ logits=logits,
1620
+ hidden_states=outputs.hidden_states,
1621
+ attentions=outputs.attentions,
1622
+ )
1623
+
1624
+
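The `problem_type` dispatch above selects between three losses. A minimal sketch of the three branches with toy tensors (values are illustrative only):

```
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

num_labels, batch = 3, 4
logits = torch.randn(batch, num_labels)

# "regression": float targets, MSE on the squeezed logits
mse = MSELoss()(torch.randn(batch, 1).squeeze(), torch.randn(batch))

# "single_label_classification": integer class indices, cross-entropy
ce = CrossEntropyLoss()(logits.view(-1, num_labels),
                        torch.randint(0, num_labels, (batch,)))

# "multi_label_classification": independent 0/1 targets per label, BCE with logits
bce = BCEWithLogitsLoss()(logits, torch.randint(0, 2, (batch, num_labels)).float())
```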
1625
+ @add_start_docstrings(
1626
+ """
1627
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1628
+ softmax) e.g. for RocStories/SWAG tasks.
1629
+ """,
1630
+ BERT_START_DOCSTRING,
1631
+ )
1632
+ class BertForMultipleChoice(BertPreTrainedModel):
1633
+ def __init__(self, config):
1634
+ super().__init__(config)
1635
+
1636
+ self.bert = BertModel(config)
1637
+ classifier_dropout_prob = (
1638
+ config.classifier_dropout_prob if config.classifier_dropout_prob is not None else config.hidden_dropout_prob
1639
+ )
1640
+ self.dropout = nn.Dropout(classifier_dropout_prob)
1641
+ self.classifier = nn.Linear(config.hidden_size, 1)
1642
+
1643
+ # Initialize weights and apply final processing
1644
+ self.post_init()
1645
+
1646
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1647
+ @add_code_sample_docstrings(
1648
+ processor_class=_TOKENIZER_FOR_DOC,
1649
+ checkpoint=_CHECKPOINT_FOR_DOC,
1650
+ output_type=MultipleChoiceModelOutput,
1651
+ config_class=_CONFIG_FOR_DOC,
1652
+ )
1653
+ def forward(
1654
+ self,
1655
+ input_ids: Optional[torch.Tensor] = None,
1656
+ attention_mask: Optional[torch.Tensor] = None,
1657
+ token_type_ids: Optional[torch.Tensor] = None,
1658
+ position_ids: Optional[torch.Tensor] = None,
1659
+ head_mask: Optional[torch.Tensor] = None,
1660
+ inputs_embeds: Optional[torch.Tensor] = None,
1661
+ labels: Optional[torch.Tensor] = None,
1662
+ output_attentions: Optional[bool] = None,
1663
+ output_hidden_states: Optional[bool] = None,
1664
+ return_dict: Optional[bool] = None,
1665
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1666
+ r"""
1667
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1668
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1669
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1670
+ `input_ids` above)
1671
+ """
1672
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1673
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1674
+
1675
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1676
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1677
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1678
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1679
+ inputs_embeds = (
1680
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1681
+ if inputs_embeds is not None
1682
+ else None
1683
+ )
1684
+
1685
+ outputs = self.bert(
1686
+ input_ids,
1687
+ attention_mask=attention_mask,
1688
+ token_type_ids=token_type_ids,
1689
+ position_ids=position_ids,
1690
+ head_mask=head_mask,
1691
+ inputs_embeds=inputs_embeds,
1692
+ output_attentions=output_attentions,
1693
+ output_hidden_states=output_hidden_states,
1694
+ return_dict=return_dict,
1695
+ )
1696
+
1697
+ pooled_output = outputs[1]
1698
+
1699
+ pooled_output = self.dropout(pooled_output)
1700
+ logits = self.classifier(pooled_output)
1701
+ reshaped_logits = logits.view(-1, num_choices)
1702
+
1703
+ loss = None
1704
+ if labels is not None:
1705
+ loss_fct = CrossEntropyLoss()
1706
+ loss = loss_fct(reshaped_logits, labels)
1707
+
1708
+ if not return_dict:
1709
+ output = (reshaped_logits,) + outputs[2:]
1710
+ return ((loss,) + output) if loss is not None else output
1711
+
1712
+ return MultipleChoiceModelOutput(
1713
+ loss=loss,
1714
+ logits=reshaped_logits,
1715
+ hidden_states=outputs.hidden_states,
1716
+ attentions=outputs.attentions,
1717
+ )
1718
+
1719
+
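The multiple-choice head works purely by reshaping: choices are folded into the batch dimension before the encoder, and the per-choice scores are unfolded again before the loss. A shape-only sketch with hypothetical sizes:

```
import torch

batch, num_choices, seq_len, hidden = 2, 4, 16, 8
input_ids = torch.randint(0, 100, (batch, num_choices, seq_len))

flat_ids = input_ids.view(-1, input_ids.size(-1))        # (batch*num_choices, seq_len)

pooled = torch.randn(batch * num_choices, hidden)        # stand-in for the pooled encoder output
logits = torch.nn.Linear(hidden, 1)(pooled)              # one score per choice
reshaped_logits = logits.view(-1, num_choices)           # (batch, num_choices)

labels = torch.tensor([0, 2])                            # index of the correct choice
loss = torch.nn.CrossEntropyLoss()(reshaped_logits, labels)
```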
1720
+ @add_start_docstrings(
1721
+ """
1722
+ Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1723
+ Named-Entity-Recognition (NER) tasks.
1724
+ """,
1725
+ BERT_START_DOCSTRING,
1726
+ )
1727
+ class BertForTokenClassification(BertPreTrainedModel):
1728
+
1729
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1730
+
1731
+ def __init__(self, config):
1732
+ super().__init__(config)
1733
+ self.num_labels = config.num_labels
1734
+
1735
+ self.bert = BertModel(config, add_pooling_layer=False)
1736
+ classifier_dropout_prob = (
1737
+ config.classifier_dropout_prob if config.classifier_dropout_prob is not None else config.hidden_dropout_prob
1738
+ )
1739
+ self.dropout = nn.Dropout(classifier_dropout_prob)
1740
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1741
+
1742
+ # Initialize weights and apply final processing
1743
+ self.post_init()
1744
+
1745
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1746
+ @add_code_sample_docstrings(
1747
+ processor_class=_TOKENIZER_FOR_DOC,
1748
+ checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
1749
+ output_type=TokenClassifierOutput,
1750
+ config_class=_CONFIG_FOR_DOC,
1751
+ expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
1752
+ expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
1753
+ )
1754
+ def forward(
1755
+ self,
1756
+ input_ids: Optional[torch.Tensor] = None,
1757
+ attention_mask: Optional[torch.Tensor] = None,
1758
+ token_type_ids: Optional[torch.Tensor] = None,
1759
+ position_ids: Optional[torch.Tensor] = None,
1760
+ head_mask: Optional[torch.Tensor] = None,
1761
+ inputs_embeds: Optional[torch.Tensor] = None,
1762
+ labels: Optional[torch.Tensor] = None,
1763
+ output_attentions: Optional[bool] = None,
1764
+ output_hidden_states: Optional[bool] = None,
1765
+ return_dict: Optional[bool] = None,
1766
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1767
+ r"""
1768
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1769
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1770
+ """
1771
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1772
+
1773
+ outputs = self.bert(
1774
+ input_ids,
1775
+ attention_mask=attention_mask,
1776
+ token_type_ids=token_type_ids,
1777
+ position_ids=position_ids,
1778
+ head_mask=head_mask,
1779
+ inputs_embeds=inputs_embeds,
1780
+ output_attentions=output_attentions,
1781
+ output_hidden_states=output_hidden_states,
1782
+ return_dict=return_dict,
1783
+ )
1784
+
1785
+ sequence_output = outputs[0]
1786
+
1787
+ sequence_output = self.dropout(sequence_output)
1788
+ logits = self.classifier(sequence_output)
1789
+
1790
+ loss = None
1791
+ if labels is not None:
1792
+ loss_fct = CrossEntropyLoss()
1793
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1794
+
1795
+ if not return_dict:
1796
+ output = (logits,) + outputs[2:]
1797
+ return ((loss,) + output) if loss is not None else output
1798
+
1799
+ return TokenClassifierOutput(
1800
+ loss=loss,
1801
+ logits=logits,
1802
+ hidden_states=outputs.hidden_states,
1803
+ attentions=outputs.attentions,
1804
+ )
1805
+
1806
+
1807
+ @add_start_docstrings(
1808
+ """
1809
+ Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1810
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1811
+ """,
1812
+ BERT_START_DOCSTRING,
1813
+ )
1814
+ class BertForQuestionAnswering(BertPreTrainedModel):
1815
+
1816
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1817
+
1818
+ def __init__(self, config):
1819
+ super().__init__(config)
1820
+ self.num_labels = config.num_labels
1821
+
1822
+ self.bert = BertModel(config, add_pooling_layer=False)
1823
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1824
+
1825
+ # Initialize weights and apply final processing
1826
+ self.post_init()
1827
+
1828
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1829
+ @add_code_sample_docstrings(
1830
+ processor_class=_TOKENIZER_FOR_DOC,
1831
+ checkpoint=_CHECKPOINT_FOR_QA,
1832
+ output_type=QuestionAnsweringModelOutput,
1833
+ config_class=_CONFIG_FOR_DOC,
1834
+ qa_target_start_index=_QA_TARGET_START_INDEX,
1835
+ qa_target_end_index=_QA_TARGET_END_INDEX,
1836
+ expected_output=_QA_EXPECTED_OUTPUT,
1837
+ expected_loss=_QA_EXPECTED_LOSS,
1838
+ )
1839
+ def forward(
1840
+ self,
1841
+ input_ids: Optional[torch.Tensor] = None,
1842
+ attention_mask: Optional[torch.Tensor] = None,
1843
+ token_type_ids: Optional[torch.Tensor] = None,
1844
+ position_ids: Optional[torch.Tensor] = None,
1845
+ head_mask: Optional[torch.Tensor] = None,
1846
+ inputs_embeds: Optional[torch.Tensor] = None,
1847
+ start_positions: Optional[torch.Tensor] = None,
1848
+ end_positions: Optional[torch.Tensor] = None,
1849
+ output_attentions: Optional[bool] = None,
1850
+ output_hidden_states: Optional[bool] = None,
1851
+ return_dict: Optional[bool] = None,
1852
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1853
+ r"""
1854
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1855
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1856
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1857
+ are not taken into account for computing the loss.
1858
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1859
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1860
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1861
+ are not taken into account for computing the loss.
1862
+ """
1863
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1864
+
1865
+ outputs = self.bert(
1866
+ input_ids,
1867
+ attention_mask=attention_mask,
1868
+ token_type_ids=token_type_ids,
1869
+ position_ids=position_ids,
1870
+ head_mask=head_mask,
1871
+ inputs_embeds=inputs_embeds,
1872
+ output_attentions=output_attentions,
1873
+ output_hidden_states=output_hidden_states,
1874
+ return_dict=return_dict,
1875
+ )
1876
+
1877
+ sequence_output = outputs[0]
1878
+
1879
+ logits = self.qa_outputs(sequence_output)
1880
+ start_logits, end_logits = logits.split(1, dim=-1)
1881
+ start_logits = start_logits.squeeze(-1).contiguous()
1882
+ end_logits = end_logits.squeeze(-1).contiguous()
1883
+
1884
+ total_loss = None
1885
+ if start_positions is not None and end_positions is not None:
1886
+ # If we are on multi-GPU, splitting may add an extra dimension; squeeze it
1887
+ if len(start_positions.size()) > 1:
1888
+ start_positions = start_positions.squeeze(-1)
1889
+ if len(end_positions.size()) > 1:
1890
+ end_positions = end_positions.squeeze(-1)
1891
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1892
+ ignored_index = start_logits.size(1)
1893
+ start_positions = start_positions.clamp(0, ignored_index)
1894
+ end_positions = end_positions.clamp(0, ignored_index)
1895
+
1896
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1897
+ start_loss = loss_fct(start_logits, start_positions)
1898
+ end_loss = loss_fct(end_logits, end_positions)
1899
+ total_loss = (start_loss + end_loss) / 2
1900
+
1901
+ if not return_dict:
1902
+ output = (start_logits, end_logits) + outputs[2:]
1903
+ return ((total_loss,) + output) if total_loss is not None else output
1904
+
1905
+ return QuestionAnsweringModelOutput(
1906
+ loss=total_loss,
1907
+ start_logits=start_logits,
1908
+ end_logits=end_logits,
1909
+ hidden_states=outputs.hidden_states,
1910
+ attentions=outputs.attentions,
1911
+ )
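The span head above splits a 2-channel output into start/end logits and clamps out-of-range gold positions to `ignored_index` so they drop out of the loss. A toy sketch of that logic:

```
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len = 2, 12
logits = torch.randn(batch, seq_len, 2)                  # qa_outputs: 2 scores per token
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)                  # (batch, seq_len)
end_logits = end_logits.squeeze(-1)

start_positions = torch.tensor([3, 50])                  # 50 lies outside the sequence
end_positions = torch.tensor([5, 60])
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)  # clamped positions contribute no loss
total_loss = (loss_fct(start_logits, start_positions)
              + loss_fct(end_logits, end_positions)) / 2
```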
modeling_gplm.py ADDED
@@ -0,0 +1,1145 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: sanyuan.hy@alibaba-inc.com
7
+ @tel: 137****6540
8
+ @datetime: 2023/7/24 10:01
9
+ @project: LucaOne
10
+ @file: modeling_gplm
11
+ @desc: LucaOne Model Detail
12
+ '''
13
+ import math
14
+ from typing import Dict, Optional, Sequence, Tuple, List, Union
15
+ import uuid
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import Tensor, nn
19
+ from torch.nn import Parameter
20
+
21
+
22
+ def gelu(x):
23
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
24
+
25
+
26
+ def symmetrize(x):
27
+ return x + x.transpose(-1, -2)
28
+
29
+
30
+ def apc(x):
31
+ a1 = x.sum(-1, keepdims=True)
32
+ a2 = x.sum(-2, keepdims=True)
33
+ a12 = x.sum((-1, -2), keepdims=True)
34
+
35
+ avg = a1 * a2
36
+ avg.div_(a12) # in-place to reduce memory
37
+ normalized = x - avg
38
+ return normalized
39
+
40
+
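`symmetrize` and `apc` implement the symmetrization and average product correction that the contact-prediction head below applies to stacked attention maps. A toy check, assuming the two helpers are importable from this module (`modeling_gplm` is an assumed import path):

```
import torch
from modeling_gplm import apc, symmetrize    # assumed module path

attn = torch.rand(1, 2, 6, 6)                # (batch, layers*heads, seq_len, seq_len)
sym = symmetrize(attn)                       # sym[..., i, j] == sym[..., j, i]
corrected = apc(sym)                         # subtract the row-sum * col-sum / total background

assert torch.allclose(sym, sym.transpose(-1, -2))
print(corrected.shape)                       # torch.Size([1, 2, 6, 6])
```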
41
+ class LucaGPLM1LayerNorm(nn.Module):
42
+ def __init__(self, hidden_size, eps=1e-12, affine=True):
43
+ """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
44
+ super().__init__()
45
+ self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
46
+ self.eps = eps
47
+ self.affine = bool(affine)
48
+ if self.affine:
49
+ self.weight = nn.Parameter(torch.ones(hidden_size))
50
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
51
+ else:
52
+ self.weight, self.bias = None, None
53
+
54
+ def forward(self, x):
55
+ dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
56
+ means = x.mean(dims, keepdim=True)
57
+ x_zeromean = x - means
58
+ variances = x_zeromean.pow(2).mean(dims, keepdim=True)
59
+ x = x_zeromean / torch.sqrt(variances + self.eps)
60
+ if self.affine:
61
+ x = (self.weight * x) + self.bias
62
+ return x
63
+
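`LucaGPLM1LayerNorm` normalizes over the trailing dimensions with the epsilon inside the square root, which should match `torch.nn.LayerNorm` when the same `eps` is used (both start from `weight=1`, `bias=0`). A quick sanity check, with the module path again assumed:

```
import torch
from torch import nn
from modeling_gplm import LucaGPLM1LayerNorm   # assumed module path

x = torch.randn(2, 5, 16)
ours = LucaGPLM1LayerNorm(16, eps=1e-5)
ref = nn.LayerNorm(16, eps=1e-5)
print(torch.allclose(ours(x), ref(x), atol=1e-5))   # expected: True
```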
64
+ from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
65
+
66
+ class LucaGPLMTransformerLayer(nn.Module):
67
+ """LucaGPLM Transformer layer block."""
68
+
69
+ def __init__(
70
+ self,
71
+ embed_dim,
72
+ ffn_embed_dim,
73
+ attention_heads,
74
+ add_bias_kv=True,
75
+ use_lucagplm1b_layer_norm=False,
76
+ use_rotary_embeddings: bool = False,
77
+ ):
78
+ '''
79
+ Transformer-Encoder layer
80
+ :param embed_dim: token embedding dim
81
+ :param ffn_embed_dim: fully connected layer dim
82
+ :param attention_heads: heads num
83
+ :param add_bias_kv: key-value layer add bias
84
+ :param use_lucagplm1b_layer_norm: whether to use lucagplm 1b layer norm
85
+ :param use_rotary_embeddings: whether to use rotary embedding
86
+ '''
87
+ super().__init__()
88
+ self.embed_dim = embed_dim
89
+ self.ffn_embed_dim = ffn_embed_dim
90
+ self.attention_heads = attention_heads
91
+ self.use_rotary_embeddings = use_rotary_embeddings
92
+ self._init_submodules(add_bias_kv, use_lucagplm1b_layer_norm)
93
+
94
+ def _init_submodules(self, add_bias_kv, use_lucagplm1b_layer_norm):
95
+ LucaGPLMLayerNorm = LucaGPLM1bLayerNorm if use_lucagplm1b_layer_norm else LucaGPLM1LayerNorm
96
+
97
+ # pre layer norm
98
+ self.pre_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
99
+
100
+ self.self_attn = LucaGPLMMultiheadAttention(
101
+ self.embed_dim,
102
+ self.attention_heads,
103
+ add_bias_kv=add_bias_kv,
104
+ add_zero_attn=False,
105
+ use_rotary_embeddings=self.use_rotary_embeddings,
106
+ )
107
+
108
+ # post layer norm
109
+ self.post_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
110
+
111
+ # dimension increase by the fully connected layer
112
+ self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
113
+
114
+ # dimension reduction by the fully connected layer
115
+ self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
116
+
117
+ def forward(
118
+ self,
119
+ x,
120
+ self_attn_mask=None,
121
+ self_attn_padding_mask=None,
122
+ need_head_weights=False
123
+ ):
124
+ residual = x
125
+ x = self.pre_layer_norm(x)
126
+ x, attn = self.self_attn(
127
+ query=x,
128
+ key=x,
129
+ value=x,
130
+ key_padding_mask=self_attn_padding_mask,
131
+ need_weights=True,
132
+ need_head_weights=need_head_weights,
133
+ attn_mask=self_attn_mask,
134
+ )
135
+ x = residual + x
136
+
137
+ residual = x
138
+ x = self.post_layer_norm(x)
139
+ x = gelu(self.fc1(x))
140
+ x = self.fc2(x)
141
+ x = residual + x
142
+
143
+ return x, attn
144
+
145
+
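A minimal sketch of driving the pre-LN encoder block above. The `(seq_len, batch, embed_dim)` layout is inferred from the attention implementation further down, so treat both the layout and the module path as assumptions:

```
import torch
from modeling_gplm import LucaGPLMTransformerLayer   # assumed module path

layer = LucaGPLMTransformerLayer(
    embed_dim=64,
    ffn_embed_dim=256,
    attention_heads=4,
    add_bias_kv=False,
    use_lucagplm1b_layer_norm=True,
    use_rotary_embeddings=False,
)
x = torch.randn(10, 2, 64)        # (seq_len, batch, embed_dim), assumed layout
out, attn = layer(x)
print(out.shape)                  # torch.Size([10, 2, 64])
```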
146
+ class AxialTransformerLayer(nn.Module):
147
+ def __init__(
148
+ self,
149
+ embedding_dim: int = 768,
150
+ ffn_embedding_dim: int = 3072,
151
+ num_attention_heads: int = 8,
152
+ dropout: float = 0.1,
153
+ attention_dropout: float = 0.1,
154
+ activation_dropout: float = 0.1,
155
+ max_tokens_per_msa: int = 2**14,
156
+ ) -> None:
157
+ super().__init__()
158
+
159
+ # Initialize parameters
160
+ self.embedding_dim = embedding_dim
161
+ self.dropout_prob = dropout
162
+
163
+ row_self_attention = RowSelfAttention(
164
+ embedding_dim,
165
+ num_attention_heads,
166
+ dropout=dropout,
167
+ max_tokens_per_msa=max_tokens_per_msa,
168
+ )
169
+
170
+ column_self_attention = ColumnSelfAttention(
171
+ embedding_dim,
172
+ num_attention_heads,
173
+ dropout=dropout,
174
+ max_tokens_per_msa=max_tokens_per_msa,
175
+ )
176
+
177
+ feed_forward_layer = FeedForwardNetwork(
178
+ embedding_dim,
179
+ ffn_embedding_dim,
180
+ activation_dropout=activation_dropout,
181
+ max_tokens_per_msa=max_tokens_per_msa,
182
+ )
183
+
184
+ self.row_self_attention = self.build_residual(row_self_attention)
185
+ self.column_self_attention = self.build_residual(column_self_attention)
186
+ self.feed_forward_layer = self.build_residual(feed_forward_layer)
187
+
188
+ def build_residual(self, layer: nn.Module):
189
+ return NormalizedResidualBlock(
190
+ layer,
191
+ self.embedding_dim,
192
+ self.dropout_prob,
193
+ )
194
+
195
+ def forward(
196
+ self,
197
+ x: torch.Tensor,
198
+ self_attn_mask: Optional[torch.Tensor] = None,
199
+ self_attn_padding_mask: Optional[torch.Tensor] = None,
200
+ need_head_weights: bool = False,
201
+ ):
202
+ x, row_attn = self.row_self_attention(
203
+ x,
204
+ self_attn_mask=self_attn_mask,
205
+ self_attn_padding_mask=self_attn_padding_mask,
206
+ )
207
+ x, column_attn = self.column_self_attention(
208
+ x,
209
+ self_attn_mask=self_attn_mask,
210
+ self_attn_padding_mask=self_attn_padding_mask,
211
+ )
212
+ x = self.feed_forward_layer(x)
213
+ if need_head_weights:
214
+ return x, column_attn, row_attn
215
+ else:
216
+ return x
217
+
218
+
219
+ class LearnedPositionalEmbedding(nn.Embedding):
220
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
221
+ if padding_idx is not None:
222
+ num_embeddings_ = num_embeddings + padding_idx + 1
223
+ else:
224
+ num_embeddings_ = num_embeddings
225
+ super().__init__(num_embeddings_, embedding_dim, padding_idx)
226
+ self.max_positions = num_embeddings
227
+
228
+ def forward(self, input: torch.Tensor):
229
+ """Input is expected to be of size [bsz x seqlen]."""
230
+ if input.size(1) > self.max_positions:
231
+ raise ValueError(
232
+ f"Sequence length {input.size(1)} above maximum "
233
+ f" sequence length of {self.max_positions}"
234
+ )
235
+ mask = input.ne(self.padding_idx).int()
236
+ positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
237
+ return F.embedding(
238
+ positions,
239
+ self.weight,
240
+ self.padding_idx,
241
+ self.max_norm,
242
+ self.norm_type,
243
+ self.scale_grad_by_freq,
244
+ self.sparse,
245
+ )
246
+
247
+
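The cumulative-sum trick in `forward` assigns consecutive positions only to non-padding tokens and maps padding back to `padding_idx`. A toy trace of that computation:

```
import torch

padding_idx = 1
tokens = torch.tensor([[5, 6, 7, 1, 1],
                       [5, 1, 1, 1, 1]])                 # 1 = padding
mask = tokens.ne(padding_idx).int()
positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
print(positions)
# tensor([[2, 3, 4, 1, 1],
#         [2, 1, 1, 1, 1]])
```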
248
+ class SinusoidalPositionalEmbedding(nn.Module):
249
+ def __init__(self, embed_dim, padding_idx, learned=False):
250
+ super().__init__()
251
+ self.embed_dim = embed_dim
252
+ self.padding_idx = padding_idx
253
+ self.register_buffer("_float_tensor", torch.FloatTensor(1))
254
+ self.weights = None
255
+
256
+ def forward(self, x):
257
+ bsz, seq_len = x.shape
258
+ max_pos = self.padding_idx + 1 + seq_len
259
+ if self.weights is None or max_pos > self.weights.size(0):
260
+ self.weights = self.get_embedding(max_pos)
261
+ self.weights = self.weights.type_as(self._float_tensor)
262
+
263
+ positions = self.make_positions(x)
264
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
265
+
266
+ def make_positions(self, x):
267
+ mask = x.ne(self.padding_idx)
268
+ range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
269
+ positions = range_buf.expand_as(x)
270
+ return positions * mask.long() + self.padding_idx * (1 - mask.long())
271
+
272
+ def get_embedding(self, num_embeddings):
273
+ half_dim = self.embed_dim // 2
274
+ emb = math.log(10000) / (half_dim - 1)
275
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
276
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
277
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
278
+ if self.embed_dim % 2 == 1:
279
+ # zero pad
280
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
281
+ if self.padding_idx is not None:
282
+ emb[self.padding_idx, :] = 0
283
+ return emb
284
+
285
+
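`get_embedding` builds the standard sinusoidal table, but concatenates the sine and cosine halves instead of interleaving them: column `i < d/2` holds `sin(pos * 10000^(-i/(d/2-1)))` and column `d/2 + i` holds the matching cosine. A small numeric check (module path assumed):

```
import math
import torch
from modeling_gplm import SinusoidalPositionalEmbedding   # assumed module path

embed_dim, padding_idx = 8, 1
spe = SinusoidalPositionalEmbedding(embed_dim, padding_idx)
emb = spe.get_embedding(num_embeddings=6)     # rows are positions 0..5

pos, i = 3, 2
freq = math.exp(-i * math.log(10000) / (embed_dim // 2 - 1))
print(torch.isclose(emb[pos, i], torch.tensor(math.sin(pos * freq))))                   # True
print(torch.isclose(emb[pos, embed_dim // 2 + i], torch.tensor(math.cos(pos * freq))))  # True
```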
286
+ class RobertaLMHead(nn.Module):
287
+ def __init__(self, embed_dim, output_dim, weight):
288
+ super().__init__()
289
+ self.dense = nn.Linear(embed_dim, embed_dim)
290
+ self.layer_norm = LucaGPLM1bLayerNorm(embed_dim)
291
+ self.weight = weight
292
+ self.bias = nn.Parameter(torch.zeros(output_dim))
293
+
294
+ def forward(self, features):
295
+ x = self.dense(features)
296
+ x = gelu(x)
297
+ x = self.layer_norm(x)
298
+ # project back to size of vocabulary with bias
299
+ x = F.linear(x, self.weight) + self.bias
300
+ return x
301
+
302
+
303
+ class ContactPredictionHead(nn.Module):
304
+ def __init__(
305
+ self,
306
+ in_features: int,
307
+ prepend_bos: bool,
308
+ append_eos: bool,
309
+ bias=True,
310
+ eos_idx: Optional[int] = None,
311
+ ):
312
+ super().__init__()
313
+ self.in_features = in_features
314
+ self.prepend_bos = prepend_bos
315
+ self.append_eos = append_eos
316
+ if append_eos and eos_idx is None:
317
+ raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
318
+ self.eos_idx = eos_idx
319
+ self.regression = nn.Linear(in_features, 1, bias)
320
+ self.activation = nn.Sigmoid()
321
+
322
+ def forward(self, tokens, attentions):
323
+ # remove eos token attentions
324
+ if self.append_eos:
325
+ eos_mask = tokens.ne(self.eos_idx).to(attentions)
326
+ eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
327
+ attentions = attentions * eos_mask[:, None, None, :, :]
328
+ attentions = attentions[..., :-1, :-1]
329
+ # remove cls token attentions
330
+ if self.prepend_bos:
331
+ attentions = attentions[..., 1:, 1:]
332
+ batch_size, layers, heads, seqlen, _ = attentions.size()
333
+ attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
334
+
335
+ # features: B x C x T x T
336
+ attentions = attentions.to(
337
+ self.regression.weight.device
338
+ ) # attentions always float32, may need to convert to float16
339
+ attentions = apc(symmetrize(attentions))
340
+ attentions = attentions.permute(0, 2, 3, 1)
341
+ return self.activation(self.regression(attentions).squeeze(3))
342
+
343
+
344
+ class NormalizedResidualBlock(nn.Module):
345
+ def __init__(
346
+ self,
347
+ layer: nn.Module,
348
+ embedding_dim: int,
349
+ dropout: float = 0.1,
350
+ ):
351
+ super().__init__()
352
+ self.embedding_dim = embedding_dim
353
+
354
+ self.layer = layer
355
+ self.dropout_module = nn.Dropout(
356
+ dropout,
357
+ )
358
+ self.layer_norm = LucaGPLM1bLayerNorm(self.embedding_dim)
359
+
360
+ def forward(self, x, *args, **kwargs):
361
+ residual = x
362
+ x = self.layer_norm(x)
363
+ outputs = self.layer(x, *args, **kwargs)
364
+ if isinstance(outputs, tuple):
365
+ x, *out = outputs
366
+ else:
367
+ x = outputs
368
+ out = None
369
+
370
+ x = self.dropout_module(x)
371
+ x = residual + x
372
+
373
+ if out is not None:
374
+ return (x,) + tuple(out)
375
+ else:
376
+ return x
377
+
378
+
379
+ class FeedForwardNetwork(nn.Module):
380
+ def __init__(
381
+ self,
382
+ embedding_dim: int,
383
+ ffn_embedding_dim: int,
384
+ activation_dropout: float = 0.1,
385
+ max_tokens_per_msa: int = 2**14,
386
+ ):
387
+ super().__init__()
388
+ self.embedding_dim = embedding_dim
389
+ self.ffn_embedding_dim = ffn_embedding_dim
390
+ self.max_tokens_per_msa = max_tokens_per_msa
391
+ self.activation_fn = nn.GELU()
392
+ self.activation_dropout_module = nn.Dropout(
393
+ activation_dropout,
394
+ )
395
+ self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
396
+ self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
397
+
398
+ def forward(self, x):
399
+ x = self.activation_fn(self.fc1(x))
400
+ x = self.activation_dropout_module(x)
401
+ x = self.fc2(x)
402
+ return x
403
+
404
+
405
+ class RowSelfAttention(nn.Module):
406
+ """Compute self-attention over rows of a 2D input."""
407
+
408
+ def __init__(
409
+ self,
410
+ embed_dim,
411
+ num_heads,
412
+ dropout=0.0,
413
+ max_tokens_per_msa: int = 2 ** 16,
414
+ ):
415
+ super().__init__()
416
+ self.num_heads = num_heads
417
+ self.dropout = dropout
418
+ self.head_dim = embed_dim // num_heads
419
+ self.scaling = self.head_dim ** -0.5
420
+ self.max_tokens_per_msa = max_tokens_per_msa
421
+ self.attn_shape = "hnij"
422
+
423
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
424
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
425
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
426
+
427
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
428
+ self.dropout_module = nn.Dropout(dropout)
429
+
430
+ def align_scaling(self, q):
431
+ num_rows = q.size(0)
432
+ return self.scaling / math.sqrt(num_rows)
433
+
434
+ def _batched_forward(
435
+ self,
436
+ x,
437
+ self_attn_mask=None,
438
+ self_attn_padding_mask=None,
439
+ ):
440
+ num_rows, num_cols, batch_size, embed_dim = x.size()
441
+ max_rows = max(1, self.max_tokens_per_msa // num_cols)
442
+ attns = 0
443
+ scaling = self.align_scaling(x)
444
+ for start in range(0, num_rows, max_rows):
445
+ attn_weights = self.compute_attention_weights(
446
+ x[start : start + max_rows],
447
+ scaling,
448
+ self_attn_mask=self_attn_mask,
449
+ self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
450
+ if self_attn_padding_mask is not None
451
+ else None,
452
+ )
453
+ attns += attn_weights
454
+ attn_probs = attns.softmax(-1)
455
+ attn_probs = self.dropout_module(attn_probs)
456
+
457
+ outputs = []
458
+ for start in range(0, num_rows, max_rows):
459
+ output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
460
+ outputs.append(output)
461
+
462
+ output = torch.cat(outputs, 0)
463
+ return output, attn_probs
464
+
465
+ def compute_attention_weights(
466
+ self,
467
+ x,
468
+ scaling: float,
469
+ self_attn_mask=None,
470
+ self_attn_padding_mask=None,
471
+ ):
472
+ num_rows, num_cols, batch_size, embed_dim = x.size()
473
+ q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
474
+ k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
475
+ q *= scaling
476
+ if self_attn_padding_mask is not None:
477
+ # Zero out any padded aligned positions - this is important since
478
+ # we take a sum across the alignment axis.
479
+ q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
480
+
481
+ attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)
482
+
483
+ if self_attn_mask is not None:
484
+ raise NotImplementedError
485
+ # Mask Size: [B x R x C], Weights Size: [H x B x C x C]
486
+
487
+ if self_attn_padding_mask is not None:
488
+ attn_weights = attn_weights.masked_fill(
489
+ self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
490
+ -10000,
491
+ )
492
+
493
+ return attn_weights
494
+
495
+ def compute_attention_update(
496
+ self,
497
+ x,
498
+ attn_probs,
499
+ ):
500
+ num_rows, num_cols, batch_size, embed_dim = x.size()
501
+ v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
502
+ context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
503
+ context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
504
+ output = self.out_proj(context)
505
+ return output
506
+
507
+ def forward(
508
+ self,
509
+ x,
510
+ self_attn_mask=None,
511
+ self_attn_padding_mask=None,
512
+ ):
513
+ num_rows, num_cols, batch_size, embed_dim = x.size()
514
+ if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
515
+ return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
516
+ else:
517
+ scaling = self.align_scaling(x)
518
+ attn_weights = self.compute_attention_weights(
519
+ x, scaling, self_attn_mask, self_attn_padding_mask
520
+ )
521
+ attn_probs = attn_weights.softmax(-1)
522
+ attn_probs = self.dropout_module(attn_probs)
523
+ output = self.compute_attention_update(x, attn_probs)
524
+ return output, attn_probs
525
+
526
+
527
+ class ColumnSelfAttention(nn.Module):
528
+ """Compute self-attention over columns of a 2D input."""
529
+
530
+ def __init__(
531
+ self,
532
+ embed_dim,
533
+ num_heads,
534
+ dropout=0.0,
535
+ max_tokens_per_msa: int = 2 ** 16,
536
+ ):
537
+ super().__init__()
538
+
539
+ self.num_heads = num_heads
540
+ self.dropout = dropout
541
+ self.head_dim = embed_dim // num_heads
542
+ self.scaling = self.head_dim ** -0.5
543
+ self.max_tokens_per_msa = max_tokens_per_msa
544
+
545
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
546
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
547
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
548
+
549
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
550
+ self.dropout_module = nn.Dropout(dropout)
551
+
552
+ def _batched_forward(
553
+ self,
554
+ x,
555
+ self_attn_mask=None,
556
+ self_attn_padding_mask=None,
557
+ ):
558
+ num_rows, num_cols, batch_size, embed_dim = x.size()
559
+ max_cols = max(1, self.max_tokens_per_msa // num_rows)
560
+ outputs = []
561
+ attns = []
562
+ for start in range(0, num_cols, max_cols):
563
+ output, attn = self(
564
+ x[:, start : start + max_cols],
565
+ self_attn_mask=self_attn_mask,
566
+ self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
567
+ if self_attn_padding_mask is not None
568
+ else None,
569
+ )
570
+ outputs.append(output)
571
+ attns.append(attn)
572
+ output = torch.cat(outputs, 1)
573
+ attns = torch.cat(attns, 1)
574
+ return output, attns
575
+
576
+ def compute_attention_update(
577
+ self,
578
+ x,
579
+ self_attn_mask=None,
580
+ self_attn_padding_mask=None,
581
+ ):
582
+ num_rows, num_cols, batch_size, embed_dim = x.size()
583
+ if num_rows == 1:
584
+ # if there is only 1 position, this is equivalent and doesn't break with padding
585
+ attn_probs = torch.ones(
586
+ self.num_heads,
587
+ num_cols,
588
+ batch_size,
589
+ num_rows,
590
+ num_rows,
591
+ device=x.device,
592
+ dtype=x.dtype,
593
+ )
594
+ output = self.out_proj(self.v_proj(x))
595
+ else:
596
+ q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
597
+ k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
598
+ v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
599
+ q *= self.scaling
600
+
601
+ attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)
602
+
603
+ if self_attn_mask is not None:
604
+ raise NotImplementedError
605
+ if self_attn_padding_mask is not None:
606
+ attn_weights = attn_weights.masked_fill(
607
+ self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
608
+ -10000,
609
+ )
610
+
611
+ attn_probs = attn_weights.softmax(-1)
612
+ attn_probs = self.dropout_module(attn_probs)
613
+ context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
614
+ context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
615
+ output = self.out_proj(context)
616
+ return output, attn_probs
617
+
618
+ def forward(
619
+ self,
620
+ x,
621
+ self_attn_mask=None,
622
+ self_attn_padding_mask=None,
623
+ ):
624
+ num_rows, num_cols, batch_size, embed_dim = x.size()
625
+ # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled():
626
+ if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
627
+ return self._batched_forward(
628
+ x,
629
+ self_attn_mask,
630
+ self_attn_padding_mask,
631
+ )
632
+ else:
633
+ return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
634
+
635
+
636
+ def utils_softmax(x, dim: int, onnx_trace: bool = False):
637
+ if onnx_trace:
638
+ return F.softmax(x.float(), dim=dim)
639
+ else:
640
+ return F.softmax(x, dim=dim, dtype=torch.float32)
641
+
642
+
643
+ class FairseqIncrementalState(object):
644
+ def __init__(self, *args, **kwargs):
645
+ super().__init__(*args, **kwargs)
646
+ self.init_incremental_state()
647
+
648
+ def init_incremental_state(self):
649
+ self._incremental_state_id = str(uuid.uuid4())
650
+
651
+ def _get_full_incremental_state_key(self, key: str) -> str:
652
+ return "{}.{}".format(self._incremental_state_id, key)
653
+
654
+ def get_incremental_state(
655
+ self,
656
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
657
+ key: str,
658
+ ) -> Optional[Dict[str, Optional[Tensor]]]:
659
+ """Helper for getting incremental state for an nn.Module."""
660
+ full_key = self._get_full_incremental_state_key(key)
661
+ if incremental_state is None or full_key not in incremental_state:
662
+ return None
663
+ return incremental_state[full_key]
664
+
665
+ def set_incremental_state(
666
+ self,
667
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
668
+ key: str,
669
+ value: Dict[str, Optional[Tensor]],
670
+ ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
671
+ """Helper for setting incremental state for an nn.Module."""
672
+ if incremental_state is not None:
673
+ full_key = self._get_full_incremental_state_key(key)
674
+ incremental_state[full_key] = value
675
+ return incremental_state
676
+
677
+
678
+ def with_incremental_state(cls):
679
+ cls.__bases__ = (FairseqIncrementalState,) + tuple(
680
+ b for b in cls.__bases__ if b != FairseqIncrementalState
681
+ )
682
+ return cls
683
+
684
+
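`with_incremental_state` splices `FairseqIncrementalState` into a module's bases so each instance stores its cached tensors under a UUID-prefixed key. A hypothetical toy module just to exercise the helpers (module path assumed):

```
import torch
from modeling_gplm import with_incremental_state   # assumed module path

@with_incremental_state
class ToyCache(torch.nn.Module):                    # hypothetical, for illustration only
    pass

m = ToyCache()
state = {}
m.set_incremental_state(state, "attn_state", {"prev_key": torch.zeros(1, 2, 3)})
print(list(state.keys()))                           # ['<uuid>.attn_state']
print(m.get_incremental_state(state, "attn_state")["prev_key"].shape)   # torch.Size([1, 2, 3])
```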
685
+ @with_incremental_state
686
+ class LucaGPLMMultiheadAttention(nn.Module):
687
+ def __init__(
688
+ self,
689
+ embed_dim,
690
+ num_heads,
691
+ kdim=None,
692
+ vdim=None,
693
+ dropout=0.0,
694
+ bias=True,
695
+ add_bias_kv: bool = False,
696
+ add_zero_attn: bool = False,
697
+ self_attention: bool = False,
698
+ encoder_decoder_attention: bool = False,
699
+ use_rotary_embeddings: bool = False,
700
+ ):
701
+ super().__init__()
702
+ self.embed_dim = embed_dim
703
+ self.kdim = kdim if kdim is not None else embed_dim
704
+ self.vdim = vdim if vdim is not None else embed_dim
705
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
706
+
707
+ self.num_heads = num_heads
708
+ self.dropout = dropout
709
+ self.head_dim = embed_dim // num_heads
710
+ assert (
711
+ self.head_dim * num_heads == self.embed_dim
712
+ ), "embed_dim must be divisible by num_heads"
713
+ self.scaling = self.head_dim**-0.5
714
+
715
+ self.self_attention = self_attention
716
+ self.encoder_decoder_attention = encoder_decoder_attention
717
+
718
+ assert not self.self_attention or self.qkv_same_dim, (
719
+ "Self-attention requires query, key and " "value to be of the same size"
720
+ )
721
+
722
+ self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
723
+ self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
724
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
725
+
726
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
727
+
728
+ if add_bias_kv:
729
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
730
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
731
+ else:
732
+ self.bias_k = self.bias_v = None
733
+
734
+ self.add_zero_attn = add_zero_attn
735
+
736
+ self.reset_parameters()
737
+
738
+ self.onnx_trace = False
739
+ self.rot_emb = None
740
+ if use_rotary_embeddings:
741
+ self.rot_emb = RotaryEmbedding(dim=self.head_dim)
742
+
743
+ self.enable_torch_version = False
744
+ if hasattr(F, "multi_head_attention_forward"):
745
+ self.enable_torch_version = True
746
+ else:
747
+ self.enable_torch_version = False
748
+
749
+ def prepare_for_onnx_export_(self):
750
+ self.onnx_trace = True
751
+
752
+ def reset_parameters(self):
753
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
754
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
755
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
756
+
757
+ nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))
758
+ # nn.init.xavier_uniform_(self.out_proj.weight)
759
+ if self.out_proj.bias is not None:
760
+ nn.init.constant_(self.out_proj.bias, 0.0)
761
+ if self.bias_k is not None:
762
+ nn.init.xavier_normal_(self.bias_k)
763
+ if self.bias_v is not None:
764
+ nn.init.xavier_normal_(self.bias_v)
765
+
766
+ def forward(
767
+ self,
768
+ query,
769
+ key: Optional[Tensor],
770
+ value: Optional[Tensor],
771
+ key_padding_mask: Optional[Tensor] = None,
772
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
773
+ need_weights: bool = True,
774
+ static_kv: bool = False,
775
+ attn_mask: Optional[Tensor] = None,
776
+ before_softmax: bool = False,
777
+ need_head_weights: bool = False,
778
+ ) -> Tuple[Tensor, Optional[Tensor]]:
779
+ if need_head_weights:
780
+ need_weights = True
781
+
782
+ tgt_len, bsz, embed_dim = query.size()
783
+ assert embed_dim == self.embed_dim
784
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
785
+
786
+ if (
787
+ not self.rot_emb
788
+ and self.enable_torch_version
789
+ and not self.onnx_trace
790
+ and incremental_state is None
791
+ and not static_kv
792
+ # A workaround for quantization to work. Otherwise JIT compilation
793
+ # treats bias in linear module as method.
794
+ and not torch.jit.is_scripting()
795
+ and not need_head_weights
796
+ ):
797
+ assert key is not None and value is not None
798
+ return F.multi_head_attention_forward(
799
+ query,
800
+ key,
801
+ value,
802
+ self.embed_dim,
803
+ self.num_heads,
804
+ torch.empty([0]),
805
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
806
+ self.bias_k,
807
+ self.bias_v,
808
+ self.add_zero_attn,
809
+ self.dropout,
810
+ self.out_proj.weight,
811
+ self.out_proj.bias,
812
+ self.training,
813
+ key_padding_mask,
814
+ need_weights,
815
+ attn_mask,
816
+ use_separate_proj_weight=True,
817
+ q_proj_weight=self.q_proj.weight,
818
+ k_proj_weight=self.k_proj.weight,
819
+ v_proj_weight=self.v_proj.weight,
820
+ )
821
+ if incremental_state is not None:
822
+ saved_state = self._get_input_buffer(incremental_state)
823
+ if saved_state is not None and "prev_key" in saved_state:
824
+ # previous time steps are cached - no need to recompute
825
+ # key and value if they are static
826
+ if static_kv:
827
+ assert self.encoder_decoder_attention and not self.self_attention
828
+ key = value = None
829
+ else:
830
+ saved_state = None
831
+
832
+ if self.self_attention:
833
+ q = self.q_proj(query)
834
+ k = self.k_proj(query)
835
+ v = self.v_proj(query)
836
+ elif self.encoder_decoder_attention:
837
+ # encoder-decoder attention
838
+ q = self.q_proj(query)
839
+ if key is None:
840
+ assert value is None
841
+ k = v = None
842
+ else:
843
+ k = self.k_proj(key)
844
+ v = self.v_proj(key)
845
+
846
+ else:
847
+ assert key is not None and value is not None
848
+ q = self.q_proj(query)
849
+ k = self.k_proj(key)
850
+ v = self.v_proj(value)
851
+ q *= self.scaling
852
+
853
+ if self.bias_k is not None:
854
+ assert self.bias_v is not None
855
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
856
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
857
+ if attn_mask is not None:
858
+ attn_mask = torch.cat(
859
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
860
+ )
861
+ if key_padding_mask is not None:
862
+ key_padding_mask = torch.cat(
863
+ [
864
+ key_padding_mask,
865
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
866
+ ],
867
+ dim=1,
868
+ )
869
+
870
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
871
+ if k is not None:
872
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
873
+ if v is not None:
874
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
875
+
876
+ if saved_state is not None:
877
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
878
+ if "prev_key" in saved_state:
879
+ _prev_key = saved_state["prev_key"]
880
+ assert _prev_key is not None
881
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
882
+ if static_kv:
883
+ k = prev_key
884
+ else:
885
+ assert k is not None
886
+ k = torch.cat([prev_key, k], dim=1)
887
+ if "prev_value" in saved_state:
888
+ _prev_value = saved_state["prev_value"]
889
+ assert _prev_value is not None
890
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
891
+ if static_kv:
892
+ v = prev_value
893
+ else:
894
+ assert v is not None
895
+ v = torch.cat([prev_value, v], dim=1)
896
+ prev_key_padding_mask: Optional[Tensor] = None
897
+ if "prev_key_padding_mask" in saved_state:
898
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
899
+ assert k is not None and v is not None
900
+ key_padding_mask = LucaGPLMMultiheadAttention._append_prev_key_padding_mask(
901
+ key_padding_mask=key_padding_mask,
902
+ prev_key_padding_mask=prev_key_padding_mask,
903
+ batch_size=bsz,
904
+ src_len=k.size(1),
905
+ static_kv=static_kv,
906
+ )
907
+
908
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
909
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
910
+ saved_state["prev_key_padding_mask"] = key_padding_mask
911
+ # In this branch incremental_state is never None
912
+ assert incremental_state is not None
913
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
914
+ assert k is not None
915
+ src_len = k.size(1)
916
+
917
+ # This is part of a workaround to get around fork/join parallelism
918
+ # not supporting Optional types.
919
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
920
+ key_padding_mask = None
921
+
922
+ if key_padding_mask is not None:
923
+ assert key_padding_mask.size(0) == bsz
924
+ assert key_padding_mask.size(1) == src_len
925
+
926
+ if self.add_zero_attn:
927
+ assert v is not None
928
+ src_len += 1
929
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
930
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
931
+ if attn_mask is not None:
932
+ attn_mask = torch.cat(
933
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
934
+ )
935
+ if key_padding_mask is not None:
936
+ key_padding_mask = torch.cat(
937
+ [
938
+ key_padding_mask,
939
+ torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
940
+ ],
941
+ dim=1,
942
+ )
943
+
944
+ if self.rot_emb:
945
+ q, k = self.rot_emb(q, k)
946
+
947
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
948
+ attn_weights = LucaGPLMMultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
949
+
950
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
951
+
952
+ if attn_mask is not None:
953
+ attn_mask = attn_mask.unsqueeze(0)
954
+ if self.onnx_trace:
955
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
956
+ attn_weights += attn_mask
957
+
958
+ if key_padding_mask is not None:
959
+ # don't attend to padding symbols
960
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
961
+ attn_weights = attn_weights.masked_fill(
962
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
963
+ )
964
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
965
+
966
+ if before_softmax:
967
+ return attn_weights, v
968
+
969
+ attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
970
+ attn_weights = attn_weights_float.type_as(attn_weights)
971
+ attn_probs = F.dropout(
972
+ attn_weights_float.type_as(attn_weights),
973
+ p=self.dropout,
974
+ training=self.training,
975
+ )
976
+ assert v is not None
977
+ attn = torch.bmm(attn_probs, v)
978
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
979
+ if self.onnx_trace and attn.size(1) == 1:
980
+ # when ONNX tracing a single decoder step (sequence length == 1)
981
+ # the transpose is a no-op copy before view, thus unnecessary
982
+ attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
983
+ else:
984
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
985
+ attn = self.out_proj(attn)
986
+ attn_weights: Optional[Tensor] = None
987
+ if need_weights:
988
+ attn_weights = attn_weights_float.view(
989
+ bsz, self.num_heads, tgt_len, src_len
990
+ ).type_as(attn).transpose(1, 0)
991
+ if not need_head_weights:
992
+ # average attention weights over heads
993
+ attn_weights = attn_weights.mean(dim=0)
994
+
995
+ return attn, attn_weights
996
+
997
+ @staticmethod
998
+ def _append_prev_key_padding_mask(
999
+ key_padding_mask: Optional[Tensor],
1000
+ prev_key_padding_mask: Optional[Tensor],
1001
+ batch_size: int,
1002
+ src_len: int,
1003
+ static_kv: bool,
1004
+ ) -> Optional[Tensor]:
1005
+ # saved key padding masks have shape (bsz, seq_len)
1006
+ if prev_key_padding_mask is not None and static_kv:
1007
+ new_key_padding_mask = prev_key_padding_mask
1008
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
1009
+ new_key_padding_mask = torch.cat(
1010
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
1011
+ )
1012
+ # During incremental decoding, as the padding token enters and
1013
+ # leaves the frame, there will be a time when prev or current
1014
+ # is None
1015
+ elif prev_key_padding_mask is not None:
1016
+ filler = torch.zeros(
1017
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
1018
+ device=prev_key_padding_mask.device,
1019
+ )
1020
+ new_key_padding_mask = torch.cat(
1021
+ [prev_key_padding_mask.float(), filler.float()], dim=1
1022
+ )
1023
+ elif key_padding_mask is not None:
1024
+ filler = torch.zeros(
1025
+ (batch_size, src_len - key_padding_mask.size(1)),
1026
+ device=key_padding_mask.device,
1027
+ )
1028
+ new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
1029
+ else:
1030
+ new_key_padding_mask = prev_key_padding_mask
1031
+ return new_key_padding_mask
1032
+
1033
+ @torch.jit.export
1034
+ def reorder_incremental_state(
1035
+ self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
1036
+ ):
1037
+ input_buffer = self._get_input_buffer(incremental_state)
1038
+ if input_buffer is not None:
1039
+ for k in input_buffer.keys():
1040
+ input_buffer_k = input_buffer[k]
1041
+ if input_buffer_k is not None:
1042
+ if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(
1043
+ 0
1044
+ ):
1045
+ break
1046
+ input_buffer[k] = input_buffer_k.index_select(0, new_order)
1047
+ incremental_state = self._set_input_buffer(incremental_state, input_buffer)
1048
+ return incremental_state
1049
+
1050
+ def _get_input_buffer(
1051
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
1052
+ ) -> Dict[str, Optional[Tensor]]:
1053
+ result = self.get_incremental_state(incremental_state, "attn_state")
1054
+ if result is not None:
1055
+ return result
1056
+ else:
1057
+ empty_result: Dict[str, Optional[Tensor]] = {}
1058
+ return empty_result
1059
+
1060
+ def _set_input_buffer(
1061
+ self,
1062
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
1063
+ buffer: Dict[str, Optional[Tensor]],
1064
+ ):
1065
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
1066
+
1067
+ def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
1068
+ return attn_weights
1069
+
1070
+ def upgrade_state_dict_named(self, state_dict, name):
1071
+ prefix = name + "." if name != "" else ""
1072
+ items_to_add = {}
1073
+ keys_to_remove = []
1074
+ for k in state_dict.keys():
1075
+ if k.endswith(prefix + "in_proj_weight"):
1076
+ dim = int(state_dict[k].shape[0] / 3)
1077
+ items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
1078
+ items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
1079
+ items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
1080
+
1081
+ keys_to_remove.append(k)
1082
+
1083
+ k_bias = prefix + "in_proj_bias"
1084
+ if k_bias in state_dict.keys():
1085
+ dim = int(state_dict[k].shape[0] / 3)
1086
+ items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
1087
+ items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
1088
+ items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
1089
+
1090
+ keys_to_remove.append(prefix + "in_proj_bias")
1091
+
1092
+ for k in keys_to_remove:
1093
+ del state_dict[k]
1094
+
1095
+ for key, value in items_to_add.items():
1096
+ state_dict[key] = value
1097
+
1098
+
1099
+ def rotate_half(x):
1100
+ x1, x2 = x.chunk(2, dim=-1)
1101
+ return torch.cat((-x2, x1), dim=-1)
1102
+
1103
+
1104
+ def apply_rotary_pos_emb(x, cos, sin):
1105
+ cos = cos[:, : x.shape[-2], :]
1106
+ sin = sin[:, : x.shape[-2], :]
1107
+
1108
+ return (x * cos) + (rotate_half(x) * sin)
1109
+
1110
+
1111
+ class RotaryEmbedding(torch.nn.Module):
1112
+ def __init__(self, dim: int, *_, **__):
1113
+ super().__init__()
1114
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
1115
+ self.register_buffer("inv_freq", inv_freq)
1116
+
1117
+ self._seq_len_cached = None
1118
+ self._cos_cached = None
1119
+ self._sin_cached = None
1120
+
1121
+ def _update_cos_sin_tables(self, x, seq_dimension=1):
1122
+ seq_len = x.shape[seq_dimension]
1123
+
1124
+ if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
1125
+ self._seq_len_cached = seq_len
1126
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
1127
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
1128
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1129
+
1130
+ self._cos_cached = emb.cos()[None, :, :]
1131
+ self._sin_cached = emb.sin()[None, :, :]
1132
+
1133
+ return self._cos_cached, self._sin_cached
1134
+
1135
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1136
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
1137
+
1138
+ return (
1139
+ apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
1140
+ apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
1141
+ )
1142
+
1143
+
1144
+
1145
+
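A minimal usage sketch for the rotary-embedding helpers above. The `(batch * heads, seq_len, head_dim)` layout and the concrete sizes are only illustrative assumptions that mirror how the attention module above reshapes its projections; the head dimension must be even.

```
import torch

# Illustrative shapes only: (batch * heads, seq_len, head_dim); head_dim must be even.
q = torch.randn(2, 8, 16)
k = torch.randn(2, 8, 16)

rot = RotaryEmbedding(dim=16)   # RotaryEmbedding is defined in the file above
q_rot, k_rot = rot(q, k)        # returns rotated copies of q and k with unchanged shapes

assert q_rot.shape == q.shape and k_rot.shape == k.shape
```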
multi_label_metrics.py ADDED
@@ -0,0 +1,536 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: sanyuan.**@**.com
7
+ @tel: 137****6540
8
+ @datetime: 2022/11/26 21:05
9
+ @project: LucaOne
10
+ @file: multi_label_metrics.py
11
+ @desc: metrics for multi-label classification
12
+ '''
13
+ import csv
14
+ import numpy as np
15
+ import torch
16
+ from sklearn.metrics import roc_auc_score, average_precision_score
17
+
18
+
19
+ def multi_label_acc(targets, probs, threshold=0.5):
20
+ targets_relevant = relevant_indexes(targets)
21
+ preds_relevant = relevant_indexes((probs >= threshold).astype(int))
22
+ acc_list = []
23
+ for idx in range(targets.shape[0]):
24
+ target_relevant = targets_relevant[idx]
25
+ pred_relevant = preds_relevant[idx]
26
+ union_len = len(set(target_relevant).union(set(pred_relevant)))
27
+ intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
28
+ if union_len == 0:
29
+ acc_list.append(1.0)
30
+ else:
31
+ # acc
32
+ acc = 1.0 - (union_len - intersection_len) / targets.shape[1]
33
+ acc_list.append(acc)
34
+ return round(sum(acc_list)/len(acc_list), 6) if len(acc_list) > 0 else 0
35
+
36
+
37
+ def multi_label_precision(targets, probs, threshold=0.5):
38
+ targets_relevant = relevant_indexes(targets)
39
+ preds_relevant = relevant_indexes((probs >= threshold).astype(int))
40
+ prec_list = []
41
+
42
+ for idx in range(targets.shape[0]):
43
+ target_relevant = targets_relevant[idx]
44
+ pred_relevant = preds_relevant[idx]
45
+ target_len = len(target_relevant)
46
+ predict_len = len(pred_relevant)
47
+ union_len = len(set(target_relevant).union(set(pred_relevant)))
48
+ intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
49
+ if union_len == 0:
50
+ prec_list.append(1.0)
51
+ else:
52
+ # precision
53
+ prec = 0.0
54
+ if predict_len > 0:
55
+ prec = intersection_len / predict_len
56
+ prec_list.append(prec)
57
+
58
+ return round(sum(prec_list)/len(prec_list), 6) if len(prec_list) > 0 else 0
59
+
60
+
61
+ def multi_label_recall(targets, probs, threshold=0.5):
62
+ targets_relevant = relevant_indexes(targets)
63
+ preds_relevant = relevant_indexes((probs >= threshold).astype(int))
64
+ recall_list = []
65
+ for idx in range(targets.shape[0]):
66
+ target_relevant = targets_relevant[idx]
67
+ pred_relevant = preds_relevant[idx]
68
+ target_len = len(target_relevant)
69
+ union_len = len(set(target_relevant).union(set(pred_relevant)))
70
+ intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
71
+ if union_len == 0:
72
+ recall_list.append(1.0)
73
+ else:
74
+ # recall
75
+ if target_len > 0:
76
+ recall = intersection_len / target_len
77
+ else:
78
+ recall = 1.0
79
+ recall_list.append(recall)
80
+ return round(sum(recall_list)/len(recall_list), 6) if len(recall_list) > 0 else 0
81
+
82
+
83
+ def multi_label_jaccard(targets, probs, threshold=0.5):
84
+ targets_relevant = relevant_indexes(targets)
85
+ preds_relevant = relevant_indexes((probs >= threshold).astype(int))
86
+ jaccard_list = []
87
+ for idx in range(targets.shape[0]):
88
+ target_relevant = targets_relevant[idx]
89
+ pred_relevant = preds_relevant[idx]
90
+ union_len = len(set(target_relevant).union(set(pred_relevant)))
91
+ intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
92
+ if union_len == 0:
93
+ jaccard_list.append(1.0)
94
+ else:
95
+ # jaccard sim
96
+ jac = intersection_len / union_len
97
+ jaccard_list.append(jac)
98
+ return round(sum(jaccard_list)/len(jaccard_list), 6) if len(jaccard_list) > 0 else 0
99
+
100
+
101
+ def multi_label_f1(targets, probs, threshold=0.5):
102
+ targets_relevant = relevant_indexes(targets)
103
+ preds_relevant = relevant_indexes((probs >= threshold).astype(int))
104
+ f1_list = []
105
+ for idx in range(targets.shape[0]):
106
+ target_relevant = targets_relevant[idx]
107
+ pred_relevant = preds_relevant[idx]
108
+ target_len = len(target_relevant)
109
+ predict_len = len(pred_relevant)
110
+ union_len = len(set(target_relevant).union(set(pred_relevant)))
111
+ intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
112
+ if union_len == 0:
113
+ f1_list.append(1.0)
114
+ else:
115
+ # precision
116
+ prec = 0.0
+ if predict_len > 0:
+ prec = intersection_len / predict_len
117
+
118
+ # recall
119
+ if target_len > 0:
120
+ recall = intersection_len / target_len
121
+ else:
122
+ recall = 1.0
123
+ # f1
124
+ if prec + recall == 0:
125
+ f1 = 0.0
126
+ else:
127
+ f1 = 2.0 * prec * recall / (prec + recall)
128
+ f1_list.append(f1)
129
+ return round(sum(f1_list)/len(f1_list), 6) if len(f1_list) > 0 else 0
130
+
131
+
132
+ def multi_label_roc_auc(targets, probs, threshold=0.5):
133
+ targets_relevant = relevant_indexes(targets)
134
+ preds_relevant = relevant_indexes((probs >= threshold).astype(int))
135
+ roc_auc_list = []
136
+ for idx in range(targets.shape[0]):
137
+ target_relevant = targets_relevant[idx]
138
+ pred_relevant = preds_relevant[idx]
139
+ union_len = len(set(target_relevant).union(set(pred_relevant)))
140
+ if union_len == 0:
141
+ roc_auc_list.append(1.0)
142
+ else:
143
+ # roc_auc
144
+ if len(np.unique(targets[idx, :])) > 1:
145
+ roc_auc = roc_auc_macro(targets[idx, :], probs[idx, :])
146
+ roc_auc_list.append(roc_auc)
147
+ return round(sum(roc_auc_list)/len(roc_auc_list), 6) if len(roc_auc_list) > 0 else 0
148
+
149
+
150
+ def multi_label_pr_auc(targets, probs, threshold=0.5):
151
+ targets_relevant = relevant_indexes(targets)
152
+ preds_relevant = relevant_indexes((probs >= threshold).astype(int))
153
+ pr_auc_list = []
154
+ for idx in range(targets.shape[0]):
155
+ target_relevant = targets_relevant[idx]
156
+ pred_relevant = preds_relevant[idx]
157
+ union_len = len(set(target_relevant).union(set(pred_relevant)))
158
+ if union_len == 0:
159
+ pr_auc_list.append(1.0)
160
+ else:
161
+ # pr_auc
162
+ if len(np.unique(targets[idx, :])) > 1:
163
+
164
+ pr_auc = pr_auc_macro(targets[idx, :], probs[idx, :])
165
+ pr_auc_list.append(pr_auc)
166
+
167
+ return round(sum(pr_auc_list)/len(pr_auc_list), 6) if len(pr_auc_list) > 0 else 0
168
+
169
+
170
+ def metrics_multi_label(targets, probs, threshold=0.5):
171
+ '''
172
+ metrics of multi-label classification
173
+ calculate metrics by comparing the true 0-1 matrix with the predicted probability matrix
174
+ :param targets: true 0-1 indicator matrix (n_samples, n_labels)
175
+ :param probs: probs 0~1 probability matrix (n_samples, n_labels)
176
+ :param threshold: negative-positive threshold
177
+ :return: some metrics
178
+ '''
179
+ targets_relevant = relevant_indexes(targets)
180
+ preds_relevant = relevant_indexes((probs >= threshold).astype(int))
181
+ acc_list = []
182
+ prec_list = []
183
+ recall_list = []
184
+ jaccard_list = []
185
+ f1_list = []
186
+ roc_auc_list = []
187
+ pr_auc_list = []
188
+ for idx in range(targets.shape[0]):
189
+ target_relevant = targets_relevant[idx]
190
+ pred_relevant = preds_relevant[idx]
191
+ target_len = len(target_relevant)
192
+ predict_len = len(pred_relevant)
193
+ union_len = len(set(target_relevant).union(set(pred_relevant)))
194
+ intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
195
+ if union_len == 0:
196
+ acc_list.append(1.0)
197
+ prec_list.append(1.0)
198
+ recall_list.append(1.0)
199
+ roc_auc_list.append(1.0)
200
+ jaccard_list.append(1.0)
201
+ f1_list.append(1.0)
202
+ pr_auc_list.append(1.0)
203
+ else:
204
+ # acc
205
+ acc = 1.0 - (union_len - intersection_len) / targets.shape[1]
206
+ acc_list.append(acc)
207
+
208
+ # precision
209
+ prec = 0.0
210
+ if predict_len > 0:
211
+ prec = intersection_len / predict_len
212
+ prec_list.append(prec)
213
+
214
+ # recall
215
+ if target_len > 0:
216
+ recall = intersection_len / target_len
217
+ else:
218
+ recall = 1.0
219
+ recall_list.append(recall)
220
+
221
+ # jaccard sim
222
+ jac = intersection_len / union_len
223
+ jaccard_list.append(jac)
224
+
225
+ # f1
226
+ if prec + recall == 0:
227
+ f1 = 0.0
228
+ else:
229
+ f1 = 2.0 * prec * recall / (prec + recall)
230
+ f1_list.append(f1)
231
+
232
+ # roc_auc
233
+ if len(np.unique(targets[idx, :])) > 1:
234
+ roc_auc = roc_auc_macro(targets[idx, :], probs[idx, :])
235
+ roc_auc_list.append(roc_auc)
236
+ pr_auc = pr_auc_macro(targets[idx, :], probs[idx, :])
237
+ pr_auc_list.append(pr_auc)
238
+
239
+ f_max_value, p_max_value, r_max_value, t_max_value, preds_max_value = f_max(targets, probs)
240
+ return {
241
+ "acc": round(float(sum(acc_list)/len(acc_list)), 6) if len(acc_list) > 0 else 0,
242
+ "jaccard": round(float(sum(jaccard_list)/len(jaccard_list)), 6) if len(jaccard_list) > 0 else 0,
243
+ "prec": round(float(sum(prec_list)/len(prec_list)), 6) if len(prec_list) > 0 else 0,
244
+ "recall": round(float(sum(recall_list)/len(recall_list)), 6) if len(recall_list) > 0 else 0,
245
+ "f1": round(float(sum(f1_list)/len(f1_list)), 6) if len(f1_list) > 0 else 0,
246
+ "pr_auc": round(float(sum(pr_auc_list)/len(pr_auc_list)), 6) if len(pr_auc_list) > 0 else 0,
247
+ "roc_auc": round(float(sum(roc_auc_list)/len(roc_auc_list)), 6) if len(roc_auc_list) > 0 else 0,
248
+ "fmax": round(float(f_max_value), 6),
249
+ "pmax": round(float(p_max_value), 6) ,
250
+ "rmax": round(float(r_max_value), 6),
251
+ "tmax": round(float(t_max_value), 6)
252
+ }
253
+
254
+
255
+ def f_max(targets, probs, gos=None):
256
+ '''
257
+ f-max for multi-label classification
258
+ :param targets: true 0-1 indicator matrix (n_samples, n_labels)
259
+ :param probs: probs 0~1 probability matrix (n_samples, n_labels)
260
+ :param gos:
261
+ :return: fmax, p_max(precision max), r_max(recall max), t_max(classification threshold), preds_max(0-1 indicator matrix)
262
+ '''
263
+ preds_max = None
264
+ f_max = 0
265
+ p_max = 0
266
+ r_max = 0
267
+ t_max = 0
268
+ # from 0.01 to 1 (100 thresholds)
269
+ for tt in range(1, 101):
270
+ threshold = tt / 100.0
271
+ preds = (probs > threshold).astype(np.int32)
272
+ p = 0.0
273
+ r = 0.0
274
+ total = 0
275
+ p_total = 0
276
+ for i in range(preds.shape[0]):
277
+ tp = np.sum(preds[i, :] * targets[i, :])
278
+ fp = np.sum(preds[i, :]) - tp
279
+ fn = np.sum(targets[i, :]) - tp
280
+ if gos:
281
+ fn += gos[i]
282
+
283
+ if tp == 0 and fp == 0 and fn == 0:
284
+ continue
285
+ total += 1
286
+ if tp != 0:
287
+ p_total += 1
288
+ precision = tp / (1.0 * (tp + fp))
289
+ recall = tp / (1.0 * (tp + fn))
290
+ p += precision
291
+ r += recall
292
+
293
+ if total > 0 and p_total > 0:
294
+ r /= total
295
+ p /= p_total
296
+ if p + r > 0:
297
+ f = 2 * p * r / (p + r)
298
+ if f_max < f:
299
+ f_max = f
300
+ p_max = p
301
+ r_max = r
302
+ t_max = threshold
303
+ preds_max = preds
304
+
305
+ return f_max, p_max, r_max, t_max, preds_max
306
+
307
+
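A quick, hand-checked sanity example for `f_max` (the numbers are made up): the threshold sweep from 0.01 to 1.00 should find a cut-off at which the binarized predictions match the targets exactly, giving an F-max of 1.0.

```
import numpy as np

targets = np.array([[1, 0, 1],
                    [0, 1, 0]])
probs = np.array([[0.9, 0.2, 0.6],
                  [0.1, 0.8, 0.3]])

f, p, r, t, preds = f_max(targets, probs)
# The first threshold with a perfect split is 0.30, so f == p == r == 1.0 and t == 0.30
print(f, p, r, t)
```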
308
+ def metrics_multi_label_for_pred(targets, preds, savepath=None):
309
+ '''
310
+ metrics for multi-label classification
311
+ calculate metrics by comparing the true 0-1 matrix with the predicted 0-1 matrix
312
+ :param targets: true 0-1 indicator matrix (n_samples, n_labels)
313
+ :param preds: predicted 0-1 indicator matrix (n_samples, n_labels)
314
+ :return: some metrics
315
+ '''
316
+ targets_relevant = relevant_indexes(targets)
317
+ preds_relevant = relevant_indexes(preds)
318
+ acc_list = []
319
+ prec_list = []
320
+ recall_list = []
321
+ jaccard_list = []
322
+ f1_list = []
323
+ for idx in range(targets.shape[0]):
324
+ target_relevant = targets_relevant[idx]
325
+ pred_relevant = preds_relevant[idx]
326
+
327
+ target_len = len(target_relevant)
328
+ predict_len = len(pred_relevant)
329
+ union_len = len(set(target_relevant).union(set(pred_relevant)))
330
+ intersection_len = len(set(target_relevant).intersection(set(pred_relevant)))
331
+ acc = 1.0 - (union_len - intersection_len) / targets.shape[1]
332
+ prec = 0.0
333
+ if predict_len > 0:
334
+ prec = intersection_len / predict_len
335
+ recall = 0
336
+ if target_len > 0:
337
+ recall = intersection_len / target_len
338
+ else:
339
+ print(targets[idx])
340
+ jac = intersection_len / union_len
341
+ if prec + recall == 0:
342
+ f1 = 0.0
343
+ else:
344
+ f1 = 2.0 * prec * recall / (prec + recall)
345
+
346
+ acc_list.append(acc)
347
+ prec_list.append(prec)
348
+ recall_list.append(recall)
349
+ jaccard_list.append(jac)
350
+ f1_list.append(f1)
351
+
352
+ return {
353
+ "acc": round(sum(acc_list)/targets.shape[0], 6),
354
+ "jaccard": round(sum(jaccard_list)/targets.shape[0], 6),
355
+ "prec": round(sum(prec_list)/targets.shape[0], 6),
356
+ "recall": round(sum(recall_list)/targets.shape[0], 6),
357
+ "f1": round(sum(f1_list)/targets.shape[0], 6)
358
+ }
359
+
360
+
361
+ def label_id_2_array(label_ids, label_size):
362
+ '''
363
+ building 0-1 indicator array for multi-label classification
364
+ :param label_ids:
365
+ :param label_size:
366
+ :return:
367
+ '''
368
+ arr = np.zeros(label_size)
369
+ arr[label_ids] = 1
370
+ return arr
371
+
372
+
373
+ def relevant_indexes(matrix):
374
+ '''
375
+ Which positions in the multi-label are labeled as 1
376
+ :param matrix:
377
+ :return:
378
+ '''
379
+ if torch.is_tensor(matrix):
380
+ matrix = matrix.detach().cpu().numpy()
381
+ relevants = []
382
+ shape = matrix.shape
383
+ if matrix.ndim == 3:
384
+
385
+ for x in range(shape[0]):
386
+ relevant_x = []
387
+ for y in range(shape[1]):
388
+ relevant_y = []
389
+ for z in range(shape[2]):
390
+ if matrix[x, y, z] == 1:
391
+ relevant_y.append(int(z))
392
+ relevant_x.append(relevant_y)
393
+ relevants.append(relevant_x)
394
+ elif matrix.ndim == 2:
395
+ for row in range(shape[0]):
396
+ relevant = []
397
+ for col in range(shape[1]):
398
+ if matrix[row, col] == 1:
399
+ relevant.append(int(col))
400
+ relevants.append(relevant)
401
+ else:
402
+ for idx in range(matrix.shape[0]):
403
+ if matrix[idx] == 1:
404
+ relevants.append(int(idx))
405
+ return relevants
406
+
407
+
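For 2-D inputs, `relevant_indexes` simply collects the column indices of the 1-entries per row; a vectorized equivalent is shown here only as a reading aid, not a replacement.

```
import numpy as np

m = np.array([[1, 0, 1],
              [0, 0, 1]])
print(relevant_indexes(m))                      # [[0, 2], [2]]
print([list(np.nonzero(row)[0]) for row in m])  # same indices via numpy (integer types differ)
```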
408
+ def irrelevant_indexes(matrix):
409
+ '''
410
+ Which positions in the multi-label label are 0
411
+ :param matrix:
412
+ :return:
413
+ '''
414
+ if torch.is_tensor(matrix):
415
+ matrix = matrix.detach().cpu().numpy()
416
+
417
+ irrelevants = []
418
+ if matrix.ndim == 3:
419
+ for x in range(matrix.shape[0]):
420
+ irrelevant_x = []
421
+ for y in range(matrix.shape[1]):
422
+ irrelevant_y = []
423
+ for z in range(matrix.shape[2]):
424
+ if matrix[x, y, z] == 0:
425
+ irrelevant_y.append(int(z))
426
+ irrelevant_x.append(irrelevant_y)
427
+ irrelevants.append(irrelevant_x)
428
+ elif matrix.ndim == 2:
429
+ for row in range(matrix.shape[0]):
430
+ irrelevant = []
431
+ for col in range(matrix.shape[1]):
432
+ if matrix[row, col] == 0:
433
+ irrelevant.append(int(col))
434
+ irrelevants.append(irrelevant)
435
+ else:
436
+ for idx in range(matrix.shape[0]):
437
+ if matrix[idx] == 0:
438
+ irrelevants.append(int(idx))
439
+
440
+ return irrelevants
441
+
442
+
443
+ def prob_2_pred(prob, threshold):
444
+ '''
445
+ Probabilities converted to 0-1 predicted labels
446
+ :param prob:
447
+ :param threshold:
448
+ :return:
449
+ '''
450
+ if torch.is_tensor(prob):
451
+ prob = prob.detach().cpu().numpy()
452
+
453
+ if isinstance(prob, (np.ndarray, np.generic)):
454
+ return (prob >= threshold).astype(int)
455
+
456
+
457
+ def roc_auc_macro(target, prob):
458
+ '''
459
+ macro roc auc
460
+ :param target:
461
+ :param prob:
462
+ :return:
463
+ '''
464
+ return roc_auc_score(target, prob, average="macro")
465
+
466
+
467
+ def pr_auc_macro(target, prob):
468
+ '''
469
+ macro pr-auc
470
+ :param target:
471
+ :param prob:
472
+ :return:
473
+ '''
474
+ return average_precision_score(target, prob, average="macro")
475
+
476
+
477
+ def write_error_samples_multi_label(filepath, samples, input_indexs, input_id_2_names, output_id_2_name, targets,
478
+ probs, threshold=0.5,
479
+ use_other_diags=False, use_other_operas=False, use_checkin_department=False):
480
+ '''
481
+ write bad cases for multi-label classification
482
+ :param filepath:
483
+ :param samples:
484
+ :param input_indexs:
485
+ :param input_id_2_names:
486
+ :param output_id_2_name:
487
+ :param targets:
488
+ :param probs:
489
+ :param threshold:
490
+ :param use_other_diags:
491
+ :param use_other_operas:
492
+ :param use_checkin_department:
493
+ :return:
494
+ '''
495
+ preds = prob_2_pred(probs, threshold=threshold)
496
+ targets_relevant = relevant_indexes(targets)
497
+ preds_relevant = relevant_indexes(preds)
498
+ with open(filepath, "w") as fp:
499
+ writer = csv.writer(fp)
500
+ writer.writerow(["score", "y_true", "y_pred", "inputs"])
501
+ for i in range(len(targets_relevant)):
502
+ target = set(targets_relevant[i])
503
+ pred = set(preds_relevant[i])
504
+ jacc = len(target.intersection(pred))/(len(target.union(pred)))
505
+ if output_id_2_name:
506
+ target_labels = [output_id_2_name[v] for v in target]
507
+ pred_labels = [output_id_2_name[v] for v in pred]
508
+ else:
509
+ target_labels = target
510
+ pred_labels = pred
511
+ sample = samples[i]
512
+ if input_id_2_names:
513
+ new_sample = []
514
+ for idx, input_index in enumerate(input_indexs):
515
+ if input_index == 3 and not use_checkin_department:
516
+ input_index = 12
517
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index]])
518
+ if input_index == 6 and use_other_diags or input_index == 8 and use_other_operas or input_index == 10 and use_other_diags:
519
+ new_sample.append([input_id_2_names[idx][v] for v in sample[input_index + 1]])
520
+ else:
521
+ new_sample = sample
522
+ row = [jacc, target_labels, pred_labels, new_sample]
523
+ writer.writerow(row)
524
+
525
+
526
+ if __name__ == "__main__":
527
+ '''multi_label'''
528
+ probs = np.array([[0.6, 0.1, 0.1], [0.8, 0.3, 0.8], [0.8, 0.1, 0.1], [0.8, 0.1, 0.1]])
529
+ targets = np.array([[1, 1, 0], [1, 0, 1], [1, 0, 0], [0, 0, 1]])
530
+ print(metrics_multi_label(targets, probs))
531
+ t = np.array([[0, 0, 0], [1, 1, 1]])
532
+ print(t[0, :])
533
+ print(np.unique(t[0, :]))
534
+
535
+
536
+
pooling.py ADDED
@@ -0,0 +1,301 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .modeling_bert import BertEncoder, BertPooler
8
+
9
+ class GlobalMaskMaxPooling1D(nn.Module):
10
+ def __init__(self, ):
11
+ super(GlobalMaskMaxPooling1D, self).__init__()
12
+
13
+ def forward(self, x, mask=None):
14
+ if mask is not None:
15
+ # (B, Seq_len) -> (B, Seq_len, 1)
16
+ mask = 1.0 - mask
17
+ mask = mask * (-2**10 + 1)
18
+ mask = torch.unsqueeze(mask, dim=-1)
19
+ x += mask
20
+ return torch.max(x, dim=1)[0]
21
+
22
+
23
+ class GlobalMaskMinPooling1D(nn.Module):
24
+ def __init__(self, ):
25
+ super(GlobalMaskMinPooling1D, self).__init__()
26
+
27
+ def forward(self, x, mask=None):
28
+ if mask is not None:
29
+ # (B, Seq_len) -> (B, Seq_len, 1)
30
+ mask = 1.0 - mask
31
+ mask = mask * (2**10+1)
32
+ mask = torch.unsqueeze(mask, dim=-1)
33
+ x += mask
34
+ return torch.min(x, dim=1)[0]
35
+
36
+
37
+ class GlobalMaskAvgPooling1D(nn.Module):
38
+ def __init__(self):
39
+ super(GlobalMaskAvgPooling1D, self).__init__()
40
+
41
+ def forward(self, x, mask=None):
42
+ if mask is not None:
43
+ # (B, Seq_len) -> (B, Seq_len, 1)
44
+ mask = torch.unsqueeze(mask, dim=-1)
45
+ x *= mask
46
+ return torch.sum(x, dim=1)/torch.sum(mask, dim=1)
47
+ else:
48
+ return torch.mean(x, dim=1)
49
+
50
+
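A small usage sketch for the masked poolings above (shapes are illustrative): the mask carries one entry per position, 1 for real tokens and 0 for padding, and the pooled output drops the sequence dimension.

```
import torch

pool = GlobalMaskAvgPooling1D()
x = torch.randn(2, 5, 8)                     # (batch, seq_len, embed)
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 1.]])  # 1 = token, 0 = padding
out = pool(x, mask)                          # (2, 8): mean over unmasked positions only
print(out.shape)
```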
51
+ class GlobalMaskSumPooling1D(nn.Module):
52
+ def __init__(self, axis):
53
+ '''
54
+ sum pooling
55
+ :param axis: axis=0 adds up all the rows of the matrix; axis=1 adds up all the columns
56
+ '''
57
+ super(GlobalMaskSumPooling1D, self).__init__()
58
+ self.axis = axis
59
+
60
+ def forward(self, x, mask=None):
61
+ if mask is not None:
62
+ # (B, Seq_len) -> (B, Seq_len, 1)
63
+ mask = torch.unsqueeze(mask, dim=-1)
64
+ x *= mask
65
+ return torch.sum(x, dim=self.axis)
66
+
67
+
68
+ class GlobalMaskWeightedAttentionPooling1D(nn.Module):
69
+ def __init__(self, embed_size, use_bias=False):
70
+ super(GlobalMaskWeightedAttentionPooling1D, self).__init__()
71
+ self.embed_size = embed_size
72
+ self.use_bias = use_bias
73
+
74
+ self.W = nn.Parameter(torch.Tensor(self.embed_size))
75
+ nn.init.trunc_normal_(self.W, std=0.01)
76
+ if self.use_bias:
77
+ self.b = nn.Parameter(torch.Tensor(1))
78
+ nn.init.trunc_normal_(self.b, std=0.01)
79
+
80
+ def forward(self, x, mask=None):
81
+ # (B, Len, Embed) x (Embed,) = (B, Len)
82
+ logits = torch.matmul(x, self.W)
83
+ if self.use_bias:
84
+ logits += self.b
85
+
86
+ if mask is not None:
87
+ attention_probs = nn.Softmax(dim=-1)(logits + (1.0 - mask) * -10000)
88
+ else:
89
+ attention_probs = nn.Softmax(dim=-1)(logits)
90
+ x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
91
+ return x
92
+
93
+
94
+ class GlobalMaskContextAttentionPooling1D(nn.Module):
95
+ def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
96
+ super(GlobalMaskContextAttentionPooling1D, self).__init__()
97
+ self.embed_size = embed_size
98
+ self.use_additive_bias = use_additive_bias
99
+ self.use_attention_bias = use_attention_bias
100
+ self.units = units if units else embed_size
101
+
102
+ self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
103
+ self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
104
+ if self.use_additive_bias:
105
+ self.b1 = nn.Parameter(torch.Tensor(self.units))
106
+ nn.init.trunc_normal_(self.b1, std=0.01)
107
+ if self.use_attention_bias:
108
+ self.b2 = nn.Parameter(torch.Tensor(1))
109
+ nn.init.trunc_normal_(self.b2, std=0.01)
110
+
111
+ self.c = nn.Parameter(torch.Tensor(self.units))
112
+
113
+ nn.init.trunc_normal_(self.U, std=0.01)
114
+ nn.init.trunc_normal_(self.V, std=0.01)
115
+ nn.init.trunc_normal_(self.c, std=0.01)
116
+
117
+ def forward(self, x, mask=None):
118
+ # (B, Len, Embed) x (Embed, Units) = (B, Len, Units)
119
+ q = torch.matmul(x, self.U)
120
+ k = torch.matmul(x, self.V)
121
+ if self.use_additive_bias:
122
+ h = torch.tanh(q + k + self.b1)
123
+ else:
124
+ h = torch.tanh(q + k)
125
+
126
+ if self.use_attention_bias:
127
+ e = torch.matmul(h, self.c) + self.b2
128
+ else:
129
+ e = torch.matmul(h, self.c)
130
+ if mask is not None:
131
+ attention_probs = nn.Softmax(dim=-1)(e + (1.0 - mask) * -10000)
132
+ else:
133
+ attention_probs = nn.Softmax(dim=-1)(e)
134
+ x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
135
+ return x
136
+
137
+
138
+ class GlobalMaskValueAttentionPooling1D(nn.Module):
139
+ def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
140
+ super(GlobalMaskValueAttentionPooling1D, self).__init__()
141
+ self.embed_size = embed_size
142
+ self.use_additive_bias = use_additive_bias
143
+ self.use_attention_bias = use_attention_bias
144
+ self.units = units if units else embed_size
145
+
146
+ self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
147
+ self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
148
+ if self.use_additive_bias:
149
+ self.b1 = nn.Parameter(torch.Tensor(self.units))
150
+ nn.init.trunc_normal_(self.b1, std=0.01)
151
+ if self.use_attention_bias:
152
+ self.b2 = nn.Parameter(torch.Tensor(self.embed_size))
153
+ nn.init.trunc_normal_(self.b2, std=0.01)
154
+
155
+ self.W = nn.Parameter(torch.Tensor(self.units, self.embed_size))
156
+
157
+ nn.init.trunc_normal_(self.U, std=0.01)
158
+ nn.init.trunc_normal_(self.V, std=0.01)
159
+ nn.init.trunc_normal_(self.W, std=0.01)
160
+
161
+ def forward(self, x, mask=None):
162
+ # (B, Len, Embed) x (Embed, Units) = (B, Len, Units)
163
+ q = torch.matmul(x, self.U)
164
+ k = torch.matmul(x, self.V)
165
+ if self.use_additive_bias:
166
+ h = torch.tanh(q + k + self.b1)
167
+ else:
168
+ h = torch.tanh(q + k)
169
+
170
+ # (B, Len, Units) x (Units, Embed) = (B, Len, Embed)
171
+ if self.use_attention_bias:
172
+ e = torch.matmul(h, self.W) + self.b2
173
+ else:
174
+ e = torch.matmul(h, self.W)
175
+ if mask is not None:
176
+ attention_probs = nn.Softmax(dim=1)(e + torch.unsqueeze((1.0 - mask) * -10000, dim=-1))
177
+ else:
178
+ attention_probs = nn.Softmax(dim=1)(e)
179
+ x = torch.sum(attention_probs * x, dim=1)
180
+ return x
181
+
182
+ def __repr__(self):
183
+ return self.__class__.__name__ + ' (' + str(self.embed_size) + ' -> ' + str(self.embed_size) + ')'
184
+
185
+
186
+ class GlobalMaskTransformerPooling1D(nn.Module):
187
+ def __init__(self, config):
188
+ super(GlobalMaskTransformerPooling1D, self).__init__()
189
+ self.embeddings = nn.Parameter(torch.Tensor(1, 1, config.hidden_size))
190
+ nn.init.trunc_normal_(self.embeddings, std=0.02)
191
+ config.num_hidden_layers = 2
192
+ self.encoder = BertEncoder(config)
193
+ self.pooler = BertPooler(config)
194
+
195
+ def forward(self, x, mask=None):
196
+ B, Seq_len, Enbed = x.size()
197
+ cls_emb_batch = self.embeddings.expand(B, 1, Enbed)
198
+ merged_output = torch.cat((cls_emb_batch, x), dim=1) # [B, Seq_len + 1, Enbed]
199
+ if mask is not None:
200
+ device = x.device
201
+ cls_mask = torch.ones(B, 1).to(device)
202
+ mask = torch.cat([cls_mask, mask], dim=1)
203
+ mask = mask[:, None, None, :]
204
+
205
+ sequence_output = self.encoder(merged_output,
206
+ attention_mask=mask,
207
+ head_mask=None,
208
+ encoder_hidden_states=None,
209
+ encoder_attention_mask=None,
210
+ output_attentions=False,
211
+ output_hidden_states=False,
212
+ return_dict=False)[0]
213
+ pooled_output = self.pooler(sequence_output)
214
+ return pooled_output
215
+
216
+
217
+ class GlobalMaxPool1d(nn.Module):
218
+ def __init__(self):
219
+ super(GlobalMaxPool1d,self).__init__()
220
+ self.fc = nn.AdaptiveMaxPool1d(1)
221
+
222
+ def forward(self, x):
223
+ x = x.permute(0, 2, 1)
224
+ x = self.fc(x)
225
+ x = torch.squeeze(x, dim=-1)
226
+ return x
227
+
228
+
229
+ class GlobalAvgPool1d(nn.Module):
230
+ def __init__(self, ):
231
+ super(GlobalAvgPool1d, self).__init__()
232
+ self.fc = nn.AdaptiveAvgPool1d(1)
233
+
234
+ def forward(self, x):
235
+ x = x.permute(0, 2, 1)
236
+ x = self.fc(x)
237
+ x = torch.squeeze(x, dim=-1)
238
+ return x
239
+
240
+
241
+ class AttentionPool1d(nn.Module):
242
+ def __init__(self, embed_size, device="cuda"):
243
+ super(AttentionPool1d, self).__init__()
244
+ self.embed_size = embed_size
245
+ self.W = nn.Parameter(torch.Tensor(self.embed_size, self.embed_size))
246
+ self.b = nn.Parameter(torch.Tensor(self.embed_size))
247
+ self.c = nn.Parameter(torch.Tensor(self.embed_size))
248
+ nn.init.trunc_normal_(self.W, std=0.02)
249
+ nn.init.trunc_normal_(self.b, std=0.02)
250
+ nn.init.trunc_normal_(self.c, std=0.02)
251
+
252
+ def forward(self, x):
253
+ '''
254
+ # x:(B, Seq_len, Enbed)
255
+ # mul: (B, Seq_len)
256
+ mul = torch.matmul(x, self.w)
257
+ # B, Seq_len
258
+ attention_probs = nn.Softmax(dim=-1)(mul)
259
+ # B, Seq_len
260
+ x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
261
+ '''
262
+ mul = torch.tanh(torch.matmul(x, self.W) + self.b)
263
+ mul = torch.matmul(mul, self.c)
264
+ attention_probs = nn.Softmax(dim=-1)(mul)
265
+ x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
266
+ return x
267
+
268
+
269
+ class TransformerPool1d(nn.Module):
270
+ def __init__(self, config, embeddings, embed_size, num_transformer_layers=2, CLS_ID=102, device="cuda"):
271
+ super(TransformerPool1d, self).__init__()
272
+ if embeddings:
273
+ self.embeddings = embeddings
274
+ else:
275
+ self.embeddings = nn.Parameter(torch.Tensor(1, 1, embed_size))
276
+ nn.init.trunc_normal_(self.embeddings, std=0.02)
277
+ # self.embeddings = BertEmbeddings(config)
278
+ self.CLS_ID = CLS_ID
279
+ self.device = device
280
+ config.num_hidden_layers = num_transformer_layers
281
+ self.encoder = BertEncoder(config)
282
+ self.pooler = BertPooler(config)
283
+
284
+ def forward(self, x):
285
+ # x:(B, Seq_len, Enbed)
286
+ B, Seq_len, Enbed = x.size()
287
+ #cls_emb_batch = self.embeddings(torch.tensor([[self.CLS_ID]] * x.size()[0], dtype=torch.long).to(self.device)) # B, 1
288
+ cls_emb_batch = self.embeddings.expand(B, 1, Enbed)
289
+ merged_output = torch.cat((cls_emb_batch, x), dim=1) # [B, Seq_len + 1, Enbed]
290
+ sequence_output = self.encoder(merged_output,
291
+ attention_mask=None,
292
+ head_mask=None,
293
+ encoder_hidden_states=None,
294
+ encoder_attention_mask=None,
295
+ output_attentions=False,
296
+ output_hidden_states=False,
297
+ return_dict=False)[0]
298
+ pooled_output = self.pooler(sequence_output)
299
+ return pooled_output
300
+
301
+
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:234ed601e664ca2e736f2427dfb8544b47370f641bbd82612297efca3943892a
3
+ size 6320919985
regression_loss.py ADDED
@@ -0,0 +1,238 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: sanyuan.hy@alibaba-inc.com
7
+ @tel: 137****6540
8
+ @datetime: 2023/6/15 22:53
9
+ @project: LucaOne
10
+ @file: regression_loss.py
11
+ @desc: regression loss
12
+ '''
13
+ import warnings
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ from statsmodels.stats.stattools import durbin_watson
18
+
19
+ from .masked_loss import _MaskedLoss
20
+
21
+
22
+ def nanstd(input, dim=None, keepdim=False):
23
+ mu = torch.nanmean(input, dim=dim, keepdim=True)
24
+ return torch.sqrt(torch.nanmean((input - mu)**2, dim=dim, keepdim=keepdim))
25
+
26
+
27
+ def iqr(batch, dim=None, reduction='mean'):
28
+ if dim is None:
29
+ if len(batch.shape) == 1:
30
+ dim = 0
31
+ else:
32
+ dim = 1
33
+ if isinstance(batch, np.ndarray):
34
+ out = np.quantile(batch, 0.75, axis=dim) - \
35
+ np.quantile(batch, 0.25, axis=dim)
36
+ elif isinstance(batch, torch.Tensor):
37
+ out = torch.quantile(batch, 0.75, dim=dim) - \
38
+ torch.quantile(batch, 0.25, dim=dim)
39
+ if reduction == 'none':
40
+ return out
41
+ elif reduction == 'mean':
42
+ return out.mean()
43
+ else:
44
+ raise NotImplementedError
45
+
46
+
47
+ def naniqr(batch, dim=None, reduction='none'):
48
+ if dim is None:
49
+ if len(batch.shape) == 1:
50
+ dim = 0
51
+ else:
52
+ dim = 1
53
+ if isinstance(batch, np.ndarray):
54
+ out = np.nanquantile(batch, 0.75, axis=dim) - \
55
+ np.nanquantile(batch, 0.25, axis=dim)
56
+ elif isinstance(batch, torch.Tensor):
57
+ out = torch.nanquantile(batch, 0.75, dim=dim) - \
58
+ torch.nanquantile(batch, 0.25, dim=dim)
59
+ if reduction == 'none':
60
+ return out
61
+ elif reduction == 'mean':
62
+ return out.mean()
63
+ elif reduction == 'nanmean':
64
+ return torch.nanmean(out)
65
+ else:
66
+ raise NotImplementedError
67
+
68
+
69
+ def compute_dw(res, dim=1, replace_missing=0., reduction='none'):
70
+ """Durbin-Watson statistics
71
+ https://www.statsmodels.org/devel/generated/statsmodels.stats.stattools.durbin_watson.html
72
+ """
73
+ if isinstance(res, torch.Tensor):
74
+ res = res.detach().cpu().numpy()
75
+ if replace_missing is not None:
76
+ res = res.copy()
77
+ res[np.isnan(res)] = replace_missing
78
+ out = durbin_watson(res, axis=dim)
79
+ if reduction == 'mean':
80
+ return out.mean()
81
+ elif reduction == 'none':
82
+ return out
83
+ elif reduction == 'median':
84
+ return np.median(out)
85
+
86
+
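`compute_dw` wraps statsmodels' Durbin-Watson statistic to score residual autocorrelation per sample; values near 2 indicate uncorrelated residuals. A toy check on white-noise residuals, with illustrative shapes:

```
import numpy as np

rng = np.random.default_rng(0)
res = rng.normal(size=(4, 200))   # (batch, time) residuals
dw = compute_dw(res, dim=1)       # one Durbin-Watson value per row, ~2.0 for white noise
print(dw.shape, dw)
```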
87
+ def estimate_noise(x, dim=1, window_size=10, step=5, reduce='nanmean', keepdim=True):
88
+ noises = nanstd(x.unfold(dim, window_size, step), -1, keepdim=False)
89
+ if reduce == 'nanmedian':
90
+ return noises.nanmedian(dim, keepdim=keepdim).values
91
+ if reduce == 'nanmean':
92
+ return noises.nanmean(dim, keepdim=keepdim)
93
+ if reduce == 'median':
94
+ return noises.median(dim, keepdim=keepdim).values
95
+ if reduce == 'mean':
96
+ return noises.mean(dim, keepdim=keepdim)
97
+ if reduce == 'none':
98
+ return noises
99
+ raise ValueError
100
+
101
+
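`estimate_noise` slides a window along `dim` (via `Tensor.unfold`), takes the NaN-aware standard deviation inside each window, and then reduces across windows. A quick sketch with illustrative shapes:

```
import torch

signal = torch.randn(4, 100)                                   # (batch, time)
noise = estimate_noise(signal, dim=1, window_size=10, step=5)  # default reduce='nanmean' -> (4, 1)
print(noise.shape, noise.squeeze())                            # roughly 1.0 for unit-variance noise
```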
102
+ class MaskedMSELoss(_MaskedLoss):
103
+ """Masked MSE loss"""
104
+ def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
105
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
106
+ self.criterion = nn.MSELoss(reduction='none')
107
+
108
+
109
+ class MaskedL1Loss(_MaskedLoss):
110
+ """Masked L1 loss."""
111
+
112
+ def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
113
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
114
+ self.criterion = nn.L1Loss(reduction='none')
115
+
116
+
117
+ class MaskedHuberLoss(_MaskedLoss):
118
+ """Masked L1 loss."""
119
+
120
+ def __init__(self, reduction='mean', ignore_nans=True, delta=1, ignore_value=-100.0):
121
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
122
+ self.criterion = nn.HuberLoss(reduction='none', delta=delta)
123
+
124
+
125
+ class IQRLoss(nn.Module):
126
+ "IQR of the residuals"
127
+ def __init__(self, reduction='nanmean', ignore_nans=True, ignore_value=-100.0):
128
+ super().__init__()
129
+ self.reduction = reduction
130
+ self.ignore_nans = ignore_nans
131
+ self.ignore_value = ignore_value
132
+
133
+ def forward(self, input, target=0.):
134
+ if isinstance(target, torch.Tensor) and not (target.size() == input.size()):
135
+ warnings.warn(
136
+ "Using a target size ({}) that is different to the input size ({}). "
137
+ "This will likely lead to incorrect results due to broadcasting. "
138
+ "Please ensure they have the same size.".format(
139
+ target.size(), input.size()),
140
+ stacklevel=2,
141
+ )
142
+ if self.ignore_nans:
143
+ return naniqr(target-input, reduction=self.reduction)
144
+ else:
145
+ return iqr(target-input, reduction=self.reduction)
146
+
147
+
148
+ class MaskedLogCoshLoss(_MaskedLoss):
149
+ def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
150
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
151
+ self.criterion = LogCoshLoss(reduction='none')
152
+
153
+
154
+ class MaskedXTanhLoss(_MaskedLoss):
155
+ def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
156
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
157
+ self.criterion = XTanhLoss(reduction='none')
158
+
159
+
160
+ class MaskedXSigmoidLoss(_MaskedLoss):
161
+ def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
162
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
163
+ self.criterion = XSigmoidLoss(reduction='none')
164
+
165
+
166
+ class MaskedAlgebraicLoss(_MaskedLoss):
167
+ def __init__(self, reduction='mean', ignore_nans=True, ignore_value=-100.0):
168
+ super().__init__(reduction=reduction, ignore_nans=ignore_nans, ignore_value=ignore_value)
169
+ self.criterion = AlgebraicLoss(reduction='none')
170
+
171
+
172
+ class LogCoshLoss(torch.nn.Module):
173
+ def __init__(self, reduction='none'):
174
+ super().__init__()
175
+ self.reduction = reduction
176
+
177
+ def forward(self, input, target):
178
+ diff = input - target
179
+ if self.reduction == 'mean':
180
+ return torch.mean(torch.log(torch.cosh(diff + 1e-12)))
181
+ elif self.reduction == 'sum':
182
+ return torch.sum(torch.log(torch.cosh(diff + 1e-12)))
183
+ else:
184
+ return torch.log(torch.cosh(diff + 1e-12))
185
+
186
+
187
+ class XTanhLoss(torch.nn.Module):
188
+ def __init__(self, reduction='none'):
189
+ super().__init__()
190
+ self.reduction = reduction
191
+
192
+ def forward(self, input, target):
193
+ diff = input - target
194
+ if self.reduction == 'mean':
195
+ return torch.mean(diff * torch.tanh(diff))
196
+ elif self.reduction == 'sum':
197
+ return torch.sum(diff * torch.tanh(diff))
198
+ else:
199
+ return diff * torch.tanh(diff)
200
+
201
+
202
+ class XSigmoidLoss(torch.nn.Module):
203
+ def __init__(self, reduction='none'):
204
+ super().__init__()
205
+ self.reduction = reduction
206
+
207
+ def forward(self, input, target):
208
+ diff = input - target
209
+ if self.reduction == 'mean':
210
+ return torch.mean(2 * diff * torch.sigmoid(diff) - diff)
211
+ elif self.reduction == 'sum':
212
+ return torch.sum(2 * diff * torch.sigmoid(diff) - diff)
213
+ else:
214
+ return 2 * diff * torch.sigmoid(diff) - diff
215
+
216
+
217
+ class AlgebraicLoss(torch.nn.Module):
218
+ def __init__(self, reduction='none'):
219
+ super().__init__()
220
+ self.reduction = reduction
221
+
222
+ def forward(self, input, target):
223
+ diff = input - target
224
+ if self.reduction == 'mean':
225
+ return torch.mean(diff * diff / torch.sqrt(1 + diff * diff))
226
+ elif self.reduction == 'sum':
227
+ return torch.sum(diff * diff / torch.sqrt(1 + diff * diff))
228
+ else:
229
+ return diff * diff / torch.sqrt(1 + diff * diff)
230
+
231
+
232
+ if __name__ == "__main__":
233
+ import torch
234
+ label = torch.Tensor([[[1], [1], [-100]], [[1], [-100], [0]]])
235
+ pred = torch.Tensor([[[2], [1], [3]], [[2], [1], [3]]])
236
+ loss = MaskedMSELoss(reduction="mean", ignore_nans=True, ignore_value=-100.0)
237
+ print("loss:")
238
+ print(loss(pred, label))
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "alphabet.AlphabetTokenizer",
5
+ null
6
+ ]
7
+ },
8
+ "clean_up_tokenization_spaces": true,
9
+ "model_max_length": 1000000000000000019884624838656,
10
+ "tokenizer_class": "AlphabetTokenizer"
11
+ }
utils.py ADDED
@@ -0,0 +1,979 @@
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ import math
5
+ import os, csv, json
6
+ import io, textwrap, itertools
7
+ import subprocess
8
+ from Bio import SeqIO
9
+ import torch
10
+ import numpy as np
11
+ import sys, random
12
+ from sklearn.metrics import confusion_matrix
13
+ import matplotlib.pyplot as plt
14
+ import pynvml, requests
15
+ from collections import OrderedDict
16
+
17
+ plt.rcParams.update({'font.size': 18})
18
+ plt.rcParams['axes.unicode_minus'] = False
19
+
20
+ from .file_operator import file_reader
21
+ from .multi_label_metrics import prob_2_pred, relevant_indexes, metrics_multi_label
22
+ from .metrics import metrics_multi_class, metrics_binary, metrics_regression
23
+
24
+ common_nucleotide_set = {'A', 'T', 'C', 'G', 'U', 'N'}
25
+
26
+ # not {'O', 'U', 'Z', 'J', 'B'}
27
+ # Common amino acids
28
+ common_amino_acid_set = {'R', 'X', 'S', 'G', 'W', 'I', 'Q', 'A', 'T', 'V', 'K', 'Y', 'C', 'N', 'L', 'F', 'D', 'M', 'P', 'H', 'E'}
29
+
30
+
31
+ def to_device(device, batch):
32
+ '''
33
+ input to device
34
+ :param device:
35
+ :param batch:
36
+ :return:
37
+ '''
38
+ new_batch = {}
39
+ sample_num = 0
40
+ tens = None
41
+ for item1 in batch.items():
42
+ new_batch[item1[0]] = {}
43
+ if isinstance(item1[1], dict):
44
+ for item2 in item1[1].items():
45
+ new_batch[item1[0]][item2[0]] = {}
46
+ if isinstance(item2[1], dict):
47
+ for item3 in item2[1].items():
48
+ if item3[1] is not None and not isinstance(item3[1], int) and not isinstance(item3[1], str) and not isinstance(item3[1], float):
49
+ new_batch[item1[0]][item2[0]][item3[0]] = item3[1].to(device)
50
+ tens = item3[1]
51
+ else:
52
+ new_batch[item1[0]][item2[0]][item3[0]] = item3[1]
53
+ else:
54
+ if item2[1] is not None and not isinstance(item2[1], int) and not isinstance(item2[1], str) and not isinstance(item2[1], float):
55
+ new_batch[item1[0]][item2[0]] = item2[1].to(device)
56
+ tens = item2[1]
57
+ else:
58
+ new_batch[item1[0]][item2[0]] = item2[1]
59
+ else:
60
+ if item1[1] is not None and not isinstance(item1[1], int) and not isinstance(item1[1], str) and not isinstance(item1[1], float):
61
+ new_batch[item1[0]] = item1[1].to(device)
62
+ tens = item1[1]
63
+ else:
64
+ new_batch[item1[0]] = item1[1]
65
+ if tens is not None:
66
+ sample_num = tens.shape[0]
67
+ return new_batch, sample_num
68
+
69
+
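`to_device` walks a (possibly nested) dict batch, moves every tensor it finds to the target device, and also returns the batch size read off the first tensor it sees. A minimal sketch with made-up keys:

```
import torch

batch = {
    "input_ids": torch.ones(8, 128, dtype=torch.long),  # keys here are only illustrative
    "attention_mask": torch.ones(8, 128),
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
new_batch, sample_num = to_device(device, batch)
print(sample_num)  # 8
```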
70
+ def get_parameter_number(model):
71
+ '''
72
+ calculate the parameter number of the model
73
+ :param model:
74
+ :return:
75
+ '''
76
+ param_size = 0
77
+ param_sum = 0
78
+ trainable_size = 0
79
+ trainable_num = 0
80
+ for param in model.parameters():
81
+ cur_size = param.nelement() * param.element_size()
82
+ cur_num = param.nelement()
83
+ param_size += cur_size
84
+ param_sum += cur_num
85
+ if param.requires_grad:
86
+ trainable_size += cur_size
87
+ trainable_num += cur_num
88
+ buffer_size = 0
89
+ buffer_sum = 0
90
+ for buffer in model.buffers():
91
+ buffer_size += buffer.nelement() * buffer.element_size()
92
+ buffer_sum += buffer.nelement()
93
+ '''
94
+ total_num = sum(p.numel() for p in model.parameters())
95
+ total_size = sum(p.numel() * p.element_size() for p in model.parameters())
96
+ total_num += sum(p.numel() for p in model.buffers())
97
+ total_size += sum(p.numel() * p.element_size() for p in model.buffers())
98
+ trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
99
+ trainable_size = sum(p.numel() * p.element_size() for p in model.parameters() if p.requires_grad)
100
+ '''
101
+ return {
102
+ 'total_num': "%fM" % round((buffer_sum + param_sum)/(1024 * 1024), 2),
103
+ 'total_size': "%fMB" % round((buffer_size + param_size)/(1024 * 1024), 2),
104
+ 'param_sum': "%fM" % round(param_sum/(1024 * 1024), 2),
105
+ 'param_size': "%fMB" % round(param_size/(1024 * 1024), 2),
106
+ 'buffer_sum': "%fM" % round(buffer_sum/(1024 * 1024), 2),
107
+ 'buffer_size': "%fMB" % round(buffer_size/(1024 * 1024), 2),
108
+ 'trainable_num': "%fM" % round(trainable_num/(1024 * 1024), 10),
109
+ 'trainable_size': "%fMB" % round(trainable_size/(1024 * 1024), 10)
110
+ }
111
+
112
+
113
+ def set_seed(args):
114
+ random.seed(args.seed)
115
+ np.random.seed(args.seed)
116
+ torch.manual_seed(args.seed)
117
+ if args.n_gpu > 0:
118
+ torch.cuda.manual_seed(args.seed)
119
+ torch.cuda.manual_seed_all(args.seed)
120
+
121
+
122
+ def label_id_2_label_name(output_mode, label_list, prob, threshold=0.5):
123
+ '''
124
+ convert label id to label name
125
+ :param output_mode:
126
+ :param label_list:
127
+ :param prob:
128
+ :param threshold:
129
+ :return:
130
+ '''
131
+ if output_mode in ["multi-label", "multi_label"]:
132
+ res = []
133
+ pred = prob_2_pred(prob, threshold)
134
+ pred_index = relevant_indexes(pred)
135
+ for row in range(prob.shape[0]):
136
+ label_names = [label_list[idx] for idx in pred_index[row]]
137
+ res.append(label_names)
138
+ return res
139
+ elif output_mode in ["multi-class", "multi_class"]:
140
+ pred = np.argmax(prob, axis=1)
141
+ label_names = [label_list[idx] for idx in pred]
142
+ return label_names
143
+ elif output_mode in ["binary-class", "binary_class"]:
144
+ if prob.ndim == 2:
145
+ prob = prob.flatten(order="C")
146
+ pred = prob_2_pred(prob, threshold)
147
+ label_names = [label_list[idx] for idx in pred]
148
+ return label_names
149
+ else:
150
+ raise KeyError(output_mode)
151
+
152
+
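A small example of `label_id_2_label_name` in the multi-label mode (the label names are made up): probabilities are binarized at the threshold and mapped back to names per sample.

```
import numpy as np

label_list = ["label_A", "label_B", "label_C"]  # illustrative names
prob = np.array([[0.9, 0.1, 0.7],
                 [0.2, 0.6, 0.3]])
print(label_id_2_label_name("multi_label", label_list, prob, threshold=0.5))
# [['label_A', 'label_C'], ['label_B']]
```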
153
+ def plot_bins(data, xlabel, ylabel, bins, filepath):
154
+ '''
155
+ plot bins
156
+ :param data:
157
+ :param xlabel:
158
+ :param ylabel:
159
+ :param bins: bins number
160
+ :param filepath: png save filepath
161
+ :return:
162
+ '''
163
+ plt.figure(figsize=(40, 20), dpi=100)
164
+ plt.hist(data, bins=bins)
165
+ # plt.xticks(range(min(data), max(data)))
166
+ # plt.grid(linestyle='--', alpha=0.5)
167
+
168
+ plt.xlabel(xlabel)
169
+ plt.ylabel(ylabel)
170
+ if filepath is None:
171
+ plt.show()
172
+ else:
173
+ plt.savefig(filepath)
174
+ plt.clf()
175
+ plt.close()
176
+
177
+
178
+ def plot_confusion_matrix_for_binary_class(targets, preds, cm=None, savepath=None):
179
+ '''
180
+ :param targets: ground truth
181
+ :param preds: prediction probs
182
+ :param cm: confusion matrix
183
+ :param savepath: confusion matrix picture savepth
184
+ '''
185
+
186
+ plt.figure(figsize=(40, 20), dpi=100)
187
+ if cm is None:
188
+ cm = confusion_matrix(targets, preds, labels=[0, 1])
189
+
190
+ plt.matshow(cm, cmap=plt.cm.Oranges)
191
+ plt.colorbar()
192
+
193
+ for x in range(len(cm)):
194
+ for y in range(len(cm)):
195
+ plt.annotate(cm[x, y], xy=(y, x), verticalalignment='center', horizontalalignment='center')
196
+ plt.ylabel('True')
197
+ plt.xlabel('Prediction')
198
+ if savepath:
199
+ plt.savefig(savepath, dpi=100)
200
+ else:
201
+ plt.show()
202
+ plt.close("all")
203
+
204
+
205
+ def save_labels(filepath, label_list):
206
+ '''
207
+ save labels
208
+ :param filepath:
209
+ :param label_list:
210
+ :return:
211
+ '''
212
+ with open(filepath, "w") as wfp:
213
+ wfp.write("label" + "\n")
214
+ for label in label_list:
215
+ wfp.write(label + "\n")
216
+
217
+
218
+ def load_labels(filepath, header=True):
219
+ '''
220
+ load labels
221
+ :param filepath:
222
+ :param header: whether the file has a header or not
223
+ :return:
224
+ '''
225
+ label_list = []
226
+ with open(filepath, "r") as rfp:
227
+ for label in rfp:
228
+ label_list.append(label.strip())
229
+ if len(label_list) > 0 and (header or label_list[0] == "label"):
230
+ return label_list[1:]
231
+ return label_list
232
+
233
+
234
+ def load_vocab(vocab_path):
235
+ '''
236
+ load vocab
237
+ :param vocab_path:
238
+ :return:
239
+ '''
240
+ vocab = {}
241
+ with open(vocab_path, "r") as rfp:
242
+ for line in rfp:
243
+ v = line.strip()
244
+ vocab[v] = len(vocab)
245
+ return vocab
246
+
247
+
248
+ def subprocess_popen(statement):
249
+ '''
250
+ execute shell cmd
251
+ :param statement:
252
+ :return:
253
+ '''
254
+ p = subprocess.Popen(statement, shell=True, stdout=subprocess.PIPE)
255
+ while p.poll() is None:
256
+ if p.wait() != 0:
257
+ print("fail.")
258
+ return False
259
+ else:
260
+ re = p.stdout.readlines()
261
+ result = []
262
+ for i in range(len(re)):
263
+ res = re[i].decode('utf-8').strip('\r\n')
264
+ result.append(res)
265
+ return result
266
+
267
+
268
+ def prepare_inputs(input_type, embedding_type, batch):
269
+ if input_type == "sequence":
270
+ inputs = {
271
+ "input_ids_a": batch[0],
272
+ "attention_mask_a": batch[1],
273
+ "token_type_ids_a": batch[2],
274
+ "input_ids_b": batch[4],
275
+ "attention_mask_b": batch[5],
276
+ "token_type_ids_b": batch[6],
277
+ "labels": batch[-1]
278
+ }
279
+ elif input_type == "embedding":
280
+ if embedding_type not in ["vector", "bos"]:
281
+ inputs = {
282
+ "embedding_info_a": batch[0],
283
+ "embedding_attention_mask_a": batch[1],
284
+ "embedding_info_b": batch[2],
285
+ "embedding_attention_mask_b": batch[3],
286
+ "labels": batch[-1]
287
+ }
288
+ else:
289
+ inputs = {
290
+ "embedding_info_a": batch[0],
291
+ "embedding_attention_mask_a": None,
292
+ "embedding_info_b": batch[1],
293
+ "embedding_attention_mask_b": None,
294
+ "labels": batch[-1]
295
+ }
296
+ elif input_type == "structure":
297
+ inputs = {
298
+ "struct_input_ids_a": batch[0],
299
+ "struct_contact_map_a": batch[1],
300
+ "struct_input_ids_b": batch[2],
301
+ "struct_contact_map_b": batch[3],
302
+ "labels": batch[-1]
303
+ }
304
+ elif input_type == "sefn":
305
+ if embedding_type not in ["vector", "bos"]:
306
+ inputs = {
307
+ "input_ids_a": batch[0],
308
+ "attention_mask_a": batch[1],
309
+ "token_type_ids_a": batch[2],
310
+ "embedding_info_a": batch[4],
311
+ "embedding_attention_mask_a": batch[5],
312
+ "input_ids_b": batch[6],
313
+ "attention_mask_b": batch[7],
314
+ "token_type_ids_b": batch[8],
315
+ "embedding_info_b": batch[10],
316
+ "embedding_attention_mask_b": batch[11],
317
+ "labels": batch[-1],
318
+ }
319
+ else:
320
+ inputs = {
321
+ "input_ids_a": batch[0],
322
+ "attention_mask_a": batch[1],
323
+ "token_type_ids_a": batch[2],
324
+ "embedding_info_a": batch[4],
325
+ "embedding_attention_mask_a": None,
326
+ "input_ids_b": batch[5],
327
+ "attention_mask_b": batch[6],
328
+ "token_type_ids_b": batch[7],
329
+ "embedding_info_b": batch[9],
330
+ "embedding_attention_mask_b": None,
331
+ "labels": batch[-1],
332
+ }
333
+ elif input_type == "ssfn":
334
+ inputs = {
335
+ "input_ids_a": batch[0],
336
+ "attention_mask_a": batch[1],
337
+ "token_type_ids_a": batch[2],
338
+ "struct_input_ids_a": batch[4],
339
+ "struct_contact_map_a": batch[5],
340
+ "input_ids_b": batch[6],
341
+ "attention_mask_b": batch[7],
342
+ "token_type_ids_b": batch[8],
343
+ "struct_input_ids_b": batch[10],
344
+ "struct_contact_map_b": batch[11],
345
+ "labels": batch[-1]
346
+ }
347
+ else:
348
+ inputs = None
349
+ return inputs
350
+
351
+
352
+ def gene_seq_replace_re(seq):
353
+ '''
354
+ Nucleic acid restore (map the digit encoding back to bases)
355
+ :param seq:
356
+ :return:
357
+ '''
358
+ new_seq = ""
359
+ for ch in seq:
360
+ if ch == '1':
361
+ new_seq += "A"
362
+ elif ch == '2':
363
+ new_seq += "T"
364
+ elif ch == '3':
365
+ new_seq += "C"
366
+ elif ch == '4':
367
+ new_seq += "G"
368
+ else: # unknown
369
+ new_seq += "N"
370
+ return new_seq
371
+
372
+
373
+ def gene_seq_replace(seq):
374
+ '''
375
+ Nucleic acid (gene replace: A->1, U/T->2, C->3, G->4, N->5
376
+ :param seq:
377
+ :return:
378
+ '''
379
+ new_seq = ""
380
+ for ch in seq:
381
+ if ch in ["A", "a"]:
382
+ new_seq += "1"
383
+ elif ch in ["T", "U", "t", "u"]:
384
+ new_seq += "2"
385
+ elif ch in ["C", "c"]:
386
+ new_seq += "3"
387
+ elif ch in ["G", "g"]:
388
+ new_seq += "4"
389
+ else: # unknown
390
+ new_seq += "5"
391
+ return new_seq
392
+
393
+
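A quick round-trip check for the two helpers above; note that U and T both encode to "2", so `gene_seq_replace_re` always decodes that digit back to T.
```
encoded = gene_seq_replace("ATGCNNAU")
decoded = gene_seq_replace_re(encoded)
print(encoded)   # 12435512
print(decoded)   # ATGCNNAT (U comes back as T)
```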
394
+ def get_labels(label_filepath, header=True):
395
+ '''
396
+ get labels from a label file (which may contain a header row)
397
+ :param label_filepath:
398
+ :param header:
399
+ :return:
400
+ '''
401
+ with open(label_filepath, "r") as fp:
402
+ labels = []
403
+ multi_cols = False
404
+ cnt = 0
405
+ for line in fp:
406
+ line = line.strip()
407
+ cnt += 1
408
+ if cnt == 1 and (header or line == "label"):
409
+ if line.find(",") > 0:
410
+ multi_cols = True
411
+ continue
412
+ if multi_cols:
413
+ idx = line.find(",")
414
+ if idx > 0:
415
+ label_name = line[idx + 1:].strip()
416
+ else:
417
+ label_name = line
418
+ else:
419
+ label_name = line
420
+ labels.append(label_name)
421
+ return labels
422
+
423
+
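An illustrative call to `get_labels`; the file name and the two-column `id,label` layout are hypothetical. When the header line contains a comma, the function keeps only the text after the first comma of each subsequent row.
```
# hypothetical label.txt:
#   label_id,label
#   0,negative
#   1,positive
labels = get_labels("label.txt", header=True)
print(labels)    # ['negative', 'positive']
```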
424
+ def available_gpu_id():
425
+ '''
426
+ Find the most idle available GPU id
427
+ :return:
428
+ '''
429
+ pynvml.nvmlInit()
430
+ if not torch.cuda.is_available():
431
+ print("GPU not available")
432
+ return -1
433
+ # get the number of GPUs
434
+ device_count = pynvml.nvmlDeviceGetCount()
435
+ max_available_gpu = -1
436
+ max_available_rate = 0
437
+
438
+ # iterate over all GPUs and check their availability
439
+ for i in range(device_count):
440
+ handle = pynvml.nvmlDeviceGetHandleByIndex(i)
441
+ memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
442
+ utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
443
+ # assume a GPU is idle if its utilization is below a threshold (e.g. 10%)
444
+ if utilization.gpu < 10 and max_available_rate < 100 - utilization.gpu:
445
+ max_available_rate = 100 - utilization.gpu
446
+ max_available_gpu = i
447
+ # print the available GPU id
448
+ if max_available_gpu > -1:
449
+ print("Available GPU ID: %d, Free Rate: %0.2f%%" % (max_available_gpu, max_available_rate))
450
+ else:
451
+ print("No Available GPU!")
452
+
453
+ # Shutdown NVML
454
+ pynvml.nvmlShutdown()
455
+ return max_available_gpu
456
+
457
+
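A small sketch for selecting a device with `available_gpu_id`; it assumes `pynvml` and an NVIDIA driver are present, and falls back to CPU when no idle GPU is found.
```
import torch

gpu_id = available_gpu_id()   # -1 when no idle GPU is found
device = torch.device("cuda:%d" % gpu_id) if gpu_id >= 0 else torch.device("cpu")
print(device)
```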
458
+ def eval_metrics(output_mode, truths, preds, threshold=0.5):
459
+ '''
460
+ eval metrics
461
+ :param output_mode:
462
+ :param truths:
463
+ :param preds:
464
+ :param threshold:
465
+ :return:
466
+ '''
467
+ print("\ntruths size: ", truths.shape)
468
+ print("\npreds size: ", preds.shape)
469
+ if output_mode in ["multi-label", "multi_label"]:
470
+ return metrics_multi_label(truths, preds, threshold=threshold)
471
+ elif output_mode in ["multi-class", "multi_class"]:
472
+ return metrics_multi_class(truths, preds)
473
+ elif output_mode == "regression":
474
+ return metrics_regression(truths, preds)
475
+ elif output_mode in ["binary-class", "binary_class"]:
476
+ return metrics_binary(truths, preds, threshold=threshold)
477
+ else:
478
+ raise Exception("Not Support this output mode: %s" % output_mode)
479
+
480
+
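A hedged example of calling `eval_metrics` for binary classification; `metrics_binary` and the other metric helpers are defined elsewhere in this repo, so the call only works where they are importable.
```
import numpy as np

truths = np.array([0, 1, 1, 0])
preds = np.array([0.1, 0.8, 0.6, 0.4])   # predicted probabilities
print(eval_metrics("binary_class", truths, preds, threshold=0.5))
```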
481
+ def load_trained_model(model_config, args, model_class, model_dirpath):
482
+ # load an existing checkpoint
483
+ print("load pretrained model: %s" % model_dirpath)
484
+ try:
485
+ model = model_class.from_pretrained(model_dirpath, args=args)
486
+ except Exception as e:
487
+ model = model_class(model_config, args=args)
488
+ pretrained_net_dict = torch.load(os.path.join(args.model_dirpath, "pytorch.pth"),
489
+ map_location=torch.device("cpu"))
490
+ model_state_dict_keys = set()
491
+ for key in model.state_dict():
492
+ model_state_dict_keys.add(key)
493
+ new_state_dict = OrderedDict()
494
+ for k, v in pretrained_net_dict.items():
495
+ if k.startswith("module."):
496
+ # remove `module.`
497
+ name = k[7:]
498
+ else:
499
+ name = k
500
+ if name in model_state_dict_keys:
501
+ new_state_dict[name] = v
502
+ # print("diff:")
503
+ # print(model_state_dict_keys.difference(new_state_dict.keys()))
504
+ model.load_state_dict(new_state_dict)
505
+ return model
506
+
507
+
508
+ def clean_seq(protein_id, seq, return_rm_index=False):
509
+ seq = seq.upper()
510
+ new_seq = ""
511
+ has_invalid_char = False
512
+ invalid_char_set = set()
513
+ return_rm_index_set = set()
514
+ for idx, ch in enumerate(seq):
515
+ if 'A' <= ch <= 'Z' and ch not in ['J']:
516
+ new_seq += ch
517
+ else:
518
+ invalid_char_set.add(ch)
519
+ return_rm_index_set.add(idx)
520
+ has_invalid_char = True
521
+ if has_invalid_char:
522
+ print("id: %s. Seq: %s" % (protein_id, seq))
523
+ print("invalid char set:", invalid_char_set)
524
+ print("return_rm_index:", return_rm_index_set)
525
+ if return_rm_index:
526
+ return new_seq, return_rm_index_set
527
+ return new_seq
528
+
529
+
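A short example for `clean_seq`, which uppercases a protein sequence and drops any character outside A-Z (plus 'J'); the sequence id is arbitrary.
```
cleaned, removed = clean_seq("seq_1", "mfv*flvj", return_rm_index=True)
print(cleaned)   # MFVFLV
print(removed)   # {3, 7}  (positions of '*' and 'J')
```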
530
+ def sample_size(data_dirpath):
531
+ if os.path.isdir(data_dirpath):
532
+ new_filepaths = []
533
+ for filename in os.listdir(data_dirpath):
534
+ if not filename.startswith("."):
535
+ new_filepaths.append(os.path.join(data_dirpath, filename))
536
+ filepaths = new_filepaths
537
+ else:
538
+ filepaths = [data_dirpath]
539
+ total = 0
540
+ for filepath in filepaths:
541
+ header = filepath.endswith(".tsv") or filepath.endswith(".csv")
542
+ print("sample_size filepath: %s" % filepath)
543
+ for _ in file_reader(filepath, header=header, header_filter=True):
544
+ total += 1
545
+ return total
546
+
547
+
548
+ def writer_info_tb(tb_writer, logs, global_step, prefix=None):
549
+ '''
550
+ write info to tensorboard
551
+ :param tb_writer:
552
+ :param logs:
553
+ :param global_step:
554
+ :param prefix:
555
+ :return:
556
+ '''
557
+ for key, value in logs.items():
558
+ if isinstance(value, dict):
559
+ '''
560
+ for key1, value1 in value.items():
561
+ tb_writer.add_scalar(key + "_" + key1, value1, global_step)
562
+ '''
563
+ writer_info_tb(tb_writer, value, global_step, prefix=key)
564
+ elif not math.isnan(value) and not math.isinf(value):
565
+ tb_writer.add_scalar(prefix + "_" + key if prefix else key, value, global_step)
566
+ else:
567
+ print("writer_info_tb NaN or Inf, Key-Value: %s=%s" % (key, value))
568
+
569
+
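A minimal sketch for `writer_info_tb`, assuming `torch.utils.tensorboard` is available; nested dicts are written recursively with the parent key used as a prefix.
```
from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir="./tb_logs")   # hypothetical log dir
logs = {"loss": 0.52, "eval": {"acc": 0.91, "f1": 0.88}}
writer_info_tb(tb_writer, logs, global_step=100)
tb_writer.close()
```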
570
+ def get_lr(optimizer):
571
+ '''
572
+ get learning rate
573
+ :param optimizer:
574
+ :return:
575
+ '''
576
+ for p in optimizer.param_groups:
577
+ if "lr" in p:
578
+ return p["lr"]
579
+
580
+
581
+ def metrics_merge(results, all_results):
582
+ '''
583
+ merge metrics
584
+ :param results:
585
+ :param all_results:
586
+ :return:
587
+ '''
588
+ for item1 in results.items():
589
+ if item1[0] not in all_results:
590
+ all_results[item1[0]] = {}
591
+ for item2 in item1[1].items():
592
+ if item2[0] not in all_results[item1[0]]:
593
+ all_results[item1[0]][item2[0]] = {}
594
+ for item3 in item2[1].items():
595
+ if item3[0] not in all_results[item1[0]][item2[0]]:
596
+ all_results[item1[0]][item2[0]][item3[0]] = item3[1]
597
+ else:
598
+ all_results[item1[0]][item2[0]][item3[0]] += item3[1]
599
+ return all_results
600
+
601
+
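A small illustration of `metrics_merge`, which accumulates a three-level nested dict of metric values into a running total; the key names are made up.
```
all_results = {}
metrics_merge({"gene": {"token_level": {"tp": 10, "fp": 2}}}, all_results)
metrics_merge({"gene": {"token_level": {"tp": 7, "fp": 1}}}, all_results)
print(all_results)   # {'gene': {'token_level': {'tp': 17, 'fp': 3}}}
```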
602
+ def print_shape(item):
603
+ '''
604
+ print shape
605
+ :param item:
606
+ :return:
607
+ '''
608
+ if isinstance(item, dict):
609
+ for item1 in item.items():
610
+ print(item1[0] + ":")
611
+ print_shape(item1[1])
612
+ elif isinstance(item, list):
613
+ for idx, item1 in enumerate(item):
614
+ print("idx: %d" % idx)
615
+ print_shape(item1)
616
+ else:
617
+ print("shape:", item.shape)
618
+
619
+
620
+ def process_outputs(output_mode, truth, pred, output_truth, output_pred, ignore_index, keep_seq=False):
621
+ if keep_seq:
622
+ # TODO: the keep_seq branch is not implemented yet
623
+ return None, None
624
+ else:
625
+ if output_mode in ["multi_class", "multi-class"]:
626
+ cur_truth = truth.view(-1)
627
+ cur_mask = cur_truth != ignore_index
628
+ cur_pred = pred.view(-1, pred.shape[-1])
629
+ cur_truth = cur_truth[cur_mask]
630
+ cur_pred = cur_pred[cur_mask, :]
631
+ sum_v = cur_mask.sum().item()
632
+ elif output_mode in ["multi_label", "multi-label"]:
633
+ cur_truth = truth.view(-1, truth.shape[-1])
634
+ cur_pred = pred.view(-1, pred.shape[-1])
635
+ sum_v = pred.shape[0]
636
+ elif output_mode in ["binary_class", "binary-class"]:
637
+ cur_truth = truth.view(-1)
638
+ cur_mask = cur_truth != ignore_index
639
+ cur_pred = pred.view(-1)
640
+ cur_truth = cur_truth[cur_mask]
641
+ cur_pred = cur_pred[cur_mask]
642
+ sum_v = cur_mask.sum().item()
643
+ elif output_mode in ["regression"]:
644
+ cur_truth = truth.view(-1)
645
+ cur_mask = cur_truth != ignore_index
646
+ cur_pred = pred.view(-1)
647
+ cur_truth = cur_truth[cur_mask]
648
+ cur_pred = cur_pred[cur_mask]
649
+ sum_v = cur_mask.sum().item()
650
+ else:
651
+ raise Exception("not output mode: %s" % output_mode)
652
+ if sum_v > 0:
653
+ cur_truth = cur_truth.detach().cpu().numpy()
654
+ cur_pred = cur_pred.detach().cpu().numpy()
655
+ if output_truth is None or output_pred is None:
656
+ return cur_truth, cur_pred
657
+ else:
658
+ output_truth = np.append(output_truth, cur_truth, axis=0)
659
+ output_pred = np.append(output_pred, cur_pred, axis=0)
660
+ return output_truth, output_pred
661
+ return truth, pred
662
+
663
+
664
+ def print_batch(value, key=None, debug_path=None, wfp=None, local_rank=-1):
665
+ '''
666
+ print a batch
667
+ :param value:
668
+ :param key:
669
+ :param debug_path:
670
+ :param wfp:
671
+ :param local_rank:
672
+ :return:
673
+ '''
674
+ if isinstance(value, list):
675
+ for idx, v in enumerate(value):
676
+ if wfp is not None:
677
+ if v is not None:
678
+ wfp.write(str([torch.min(v), torch.min(torch.where(v == -100, 10000, v)), torch.max(v)]) + "\n")
679
+ wfp.write(str(v.shape) + "\n")
680
+ else:
681
+ wfp.write("None\n")
682
+ wfp.write("-" * 10 + "\n")
683
+ else:
684
+ if v is not None:
685
+ print([torch.min(v), torch.min(torch.where(v == -100, 10000, v)), torch.max(v)])
686
+ print(v.shape)
687
+ else:
688
+ print("None")
689
+ print("-" * 50)
690
+ if v is not None:
691
+ try:
692
+ value = v.detach().cpu().numpy().astype(int)
693
+ if debug_path is not None:
694
+ if value.ndim == 3:
695
+ for dim_1_idx in range(value.shape[0]):
696
+ np.savetxt(os.path.join(debug_path, "%s_batch_%d.txt" % (key, dim_1_idx)), value[dim_1_idx, :, :], fmt='%i', delimiter=",")
697
+ else:
698
+ np.savetxt(os.path.join(debug_path, "%d.txt" % idx), value, fmt='%i', delimiter=",")
699
+ else:
700
+ if value.ndim == 3:
701
+ for dim_1_idx in range(value.shape[0]):
702
+ np.savetxt("%s_batch_%d.txt" % (key, dim_1_idx), value[dim_1_idx, :, :], fmt='%i', delimiter=",")  # debug_path is None in this branch, so write to the current directory
703
+ else:
704
+ np.savetxt("%d.txt" % idx, value, fmt='%i', delimiter=",")
705
+ except Exception as e:
706
+ print(e)
707
+ elif isinstance(value, dict):
708
+ for item in value.items():
709
+ if wfp is not None:
710
+ wfp.write(str(item[0]) + ":\n")
711
+ else:
712
+ print(str(item[0]) + ':')
713
+ print_batch(item[1], item[0], debug_path, wfp, local_rank)
714
+ else:
715
+ if wfp is not None:
716
+ if value is not None:
717
+ wfp.write(str([torch.min(value), torch.min(torch.where(value == -100, 10000, value)), torch.max(value)]) + "\n")
718
+ wfp.write(str(value.shape) + "\n")
719
+ else:
720
+ wfp.write("None\n")
721
+ wfp.write("-" * 10 + "\n")
722
+ else:
723
+ if value is not None:
724
+ print([torch.min(value), torch.min(torch.where(value == -100, 10000, value)), torch.max(value)])
725
+ print(value.shape)
726
+ else:
727
+ print("None")
728
+ print("-" * 10)
729
+ if value is not None:
730
+ if key != "prot_structure":
731
+ fmt = '%i'
732
+ d_type = int
733
+ else:
734
+ fmt = '%0.4f'
735
+ d_type = float
736
+ try:
737
+ value = value.detach().cpu().numpy().astype(d_type)
738
+ if debug_path is not None:
739
+ if value.ndim == 3:
740
+ for dim_1_idx in range(value.shape[0]):
741
+ np.savetxt(os.path.join(debug_path, "%s_batch_%d.txt" % (key, dim_1_idx)), value[dim_1_idx, :, :], fmt=fmt, delimiter=",")
742
+ else:
743
+ np.savetxt(os.path.join(debug_path, "%s.txt" % key), value, fmt=fmt, delimiter=",")
744
+ else:
745
+ if value.ndim == 3:
746
+ for dim_1_idx in range(value.shape[0]):
747
+ np.savetxt("%s_batch_%d.txt" % (key, dim_1_idx), value[dim_1_idx, :, :], fmt=fmt, delimiter=",")
748
+ else:
749
+ np.savetxt("%s.txt" % key, value, fmt=fmt, delimiter=",")
750
+ except Exception as e:
751
+ print(e)
752
+
753
+
754
+ def gcd(x, y):
755
+ '''
756
+ greatest common divisor
757
+ :param x:
758
+ :param y:
759
+ :return:
760
+ '''
761
+ m = max(x, y)
762
+ n = min(x, y)
763
+ while m % n:
764
+ m, n = n, m % n
765
+ return n
766
+
767
+
768
+ def lcm(x, y):
769
+ '''
770
+ least common multiple
771
+ :param x:
772
+ :param y:
773
+ :return:
774
+ '''
775
+ m = max(x, y)
776
+ n = min(x, y)
777
+ while m % n:
778
+ m, n = n, m % n
779
+ return x*y//n
780
+
781
+
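A quick check of the two helpers above; `lcm(x, y)` is x*y divided by their greatest common divisor.
```
print(gcd(12, 18))   # 6
print(lcm(12, 18))   # 36
```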
782
+ def device_memory(gpu_id):
783
+ if gpu_id is None or gpu_id < 0:
784
+ return
785
+ pynvml.nvmlInit()
786
+ device_cnt = pynvml.nvmlDeviceGetCount()
787
+ for idx in range(device_cnt):
788
+ if gpu_id is not None and gpu_id != idx:
789
+ continue
790
+ handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
791
+ info = pynvml.nvmlDeviceGetMemoryInfo(handle)
792
+ print(f"Device {idx}: {pynvml.nvmlDeviceGetName(handle)}")
793
+ print(f"Total memory: {info.total / 1024**3:.8f} GB")
794
+ print(f"Used memory: {info.used / 1024**3:.8f} GB")
795
+ print(f"Free memory: {info.free / 1024**3:.8f} GB")
796
+ pynvml.nvmlShutdown()
797
+
798
+
799
+ def calc_emb_filename_by_seq_id(seq_id, embedding_type):
800
+ """
801
+ derive the embedding filename from the seq_id
802
+ :param seq_id:
803
+ :param embedding_type:
804
+ :return:
805
+ """
806
+ if seq_id[0] == ">":
807
+ seq_id = seq_id[1:]
808
+ if "|" in seq_id:
809
+ strs = seq_id.split("|")
810
+ if len(strs) > 1:
811
+ emb_filename = embedding_type + "_" + strs[1].strip() + ".pt"
812
+ else:
813
+ emb_filename = embedding_type + "_" + seq_id.replace(" ", "").replace("/", "_") + ".pt"
814
+ else:
815
+ emb_filename = embedding_type + "_" + seq_id.replace(" ", "").replace("/", "_") + ".pt"
816
+ return emb_filename
817
+
818
+
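Two illustrative calls to `calc_emb_filename_by_seq_id`; FASTA-style ids containing '|' keep only the accession between the first two '|' characters, while other ids are sanitized. The ids below are hypothetical.
```
print(calc_emb_filename_by_seq_id(">sp|P12345|EXAMPLE", "matrix"))   # matrix_P12345.pt
print(calc_emb_filename_by_seq_id("seq 1/a", "matrix"))              # matrix_seq1_a.pt
```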
819
+ def download_file(url, local_filename):
820
+ with requests.get(url, stream=True) as r:
821
+ r.raise_for_status()
822
+ dir_name = os.path.dirname(local_filename)
823
+ if not os.path.exists(dir_name):
824
+ os.makedirs(dir_name)
825
+ with open(local_filename, 'wb') as f:
826
+ for chunk in r.iter_content(chunk_size=8192):
827
+ if chunk: # filter out keep-alive new chunks
828
+ f.write(chunk)
829
+ return local_filename
830
+
831
+
832
+ def download_folder(base_url, file_names, local_dir):
833
+ if not os.path.exists(local_dir):
834
+ os.makedirs(local_dir)
835
+
836
+ for file_name in file_names:
837
+ file_url = f"{base_url}/{file_name}"
838
+ local_filename = os.path.join(local_dir, file_name)
839
+ download_file(file_url, local_filename)
840
+ print(f"Downloaded {file_name}")
841
+
842
+
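A hedged usage sketch for the two download helpers; the base URL and file names below are placeholders, not real endpoints.
```
# placeholder URL and file list -- replace with real ones before running
base_url = "http://example.com/lucaone/files"
download_folder(base_url, ["config.json", "vocab.json"], "./downloads")
```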
843
+ def download_trained_checkpoint_lucaone(
844
+ llm_dir,
845
+ llm_type="lucaone_gplm",
846
+ llm_version="v2.0",
847
+ llm_task_level="token_level,span_level,seq_level,structure_level",
848
+ llm_time_str="20231125113045",
849
+ llm_step="5600000",
850
+ base_url="http://47.93.21.181/lucaone/TrainedCheckPoint"
851
+ ):
852
+ """
853
+ download the trained checkpoint of LucaOne
854
+ :param llm_dir:
855
+ :param llm_type:
856
+ :param llm_version:
857
+ :param llm_task_level:
858
+ :param llm_time_str:
859
+ :param llm_step:
860
+ :param base_url:
861
+ :return:
862
+ """
863
+ print("------Download Trained LLM(LucaOne)------")
864
+ try:
865
+ logs_file_names = ["logs.txt"]
866
+ models_file_names = ["config.json", "pytorch.pth", "training_args.bin", "tokenizer/alphabet.pkl"]
867
+ logs_path = "logs/lucagplm/%s/%s/%s/%s" % (llm_version, llm_task_level, llm_type, llm_time_str)
868
+ models_path = "models/lucagplm/%s/%s/%s/%s/checkpoint-step%s" % (llm_version, llm_task_level, llm_type, llm_time_str, llm_step)
869
+ logs_local_dir = os.path.join(llm_dir, logs_path)
870
+ exists = True
871
+ for logs_file_name in logs_file_names:
872
+ if not os.path.exists(os.path.join(logs_local_dir, logs_file_name)):
873
+ exists = False
874
+ break
875
+ models_local_dir = os.path.join(llm_dir, models_path)
876
+ if exists:
877
+ for models_file_name in models_file_names:
878
+ if not os.path.exists(os.path.join(models_local_dir, models_file_name)):
879
+ exists = False
880
+ break
881
+ if not exists:
882
+ print("*" * 20 + "Downloading" + "*" * 20)
883
+ print("Downloading LucaOne TrainedCheckPoint: LucaOne-%s-%s-%s ..." % (llm_version, llm_time_str, llm_step))
884
+ print("Wait a moment, please.")
885
+ # download logs
886
+ if not os.path.exists(logs_local_dir):
887
+ os.makedirs(logs_local_dir)
888
+ logs_base_url = os.path.join(base_url, logs_path)
889
+ download_folder(logs_base_url, logs_file_names, logs_local_dir)
890
+ # download models
891
+ if not os.path.exists(models_local_dir):
892
+ os.makedirs(models_local_dir)
893
+ models_base_url = os.path.join(base_url, models_path)
894
+ download_folder(models_base_url, models_file_names, models_local_dir)
895
+ print("LucaOne Download Succeed.")
896
+ print("*" * 50)
897
+ except Exception as e:
898
+ print(e)
899
+ print("Download automatically LucaOne Trained CheckPoint failed!")
900
+ print("You can manually download 'logs/' and 'models/' into local directory: %s/ from %s" % (os.path.abspath(llm_dir), os.path.join(base_url, "TrainedCheckPoint/")))
901
+ raise Exception(e)
902
+
903
+
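Calling the checkpoint downloader with its default arguments fetches the LucaOne logs and model files only if they are not already present locally; the target directory here is an assumption.
```
# hypothetical local directory for the LucaOne checkpoint
download_trained_checkpoint_lucaone(llm_dir="./llm")
```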
904
+ def download_trained_checkpoint_downstream_tasks(
905
+ save_dir="../",
906
+ dataset_name=["CentralDogma", "GenusTax", "InfA", "ncRNAFam", "ncRPI", "PPI", "ProtLoc", "ProtStab", "SpeciesTax", "SupKTax"],
907
+ dataset_type=["gene_protein", "gene", "gene_gene", "gene", "gene_protein", "protein", "protein", "protein", "gene", "gene"],
908
+ task_type=["binary_class", "multi_class", "binary_class", "multi_class", "binary_class", "binary_class", "multi_class", "regression", "multi_class", "multi_class"],
909
+ model_type=["lucappi2", "luca_base", "lucappi", "luca_base", "lucappi2", "lucappi", "luca_base", "luca_base", "luca_base", "luca_base"],
910
+ input_type=["matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix", "matrix"],
911
+ time_str=["20240406173806", "20240412100337", "20240214105653", "20240414155526", "20240404105148", "20240216205421", "20240412140824", "20240404104215", "20240411144916", "20240212202328"],
912
+ step=[64000, 24500, 9603, 1958484, 716380, 52304, 466005, 70371, 24000, 37000],
913
+ base_url="http://47.93.21.181/lucaone/DownstreamTasksTrainedModels"
914
+ ):
915
+ """
916
+ download trained downstream-task models
917
+ :param save_dir: local save directory
918
+ :param dataset_name:
919
+ :param dataset_type:
920
+ :param task_type:
921
+ :param model_type:
922
+ :param input_type:
923
+ :param time_str:
924
+ :param step:
925
+ :param base_url:
926
+ :return:
927
+ """
928
+ assert len(dataset_name) == len(dataset_type) == len(task_type) == \
929
+ len(model_type) == len(input_type) == len(time_str) == len(step)
930
+ assert isinstance(dataset_name, list)
931
+ assert isinstance(dataset_type, list)
932
+ assert isinstance(task_type, list)
933
+ assert isinstance(model_type, list)
934
+ assert isinstance(input_type, list)
935
+ assert isinstance(time_str, list)
936
+ assert isinstance(step, list)
937
+ download_succeed_task_num = 0
938
+ print("------Download Trained Models------")
939
+ for idx in range(len(dataset_name)):
940
+ try:
941
+ logs_file_names = ["logs.txt", "label.txt"]
942
+ models_file_names = ["config.json", "pytorch_model.bin", "training_args.bin", "tokenizer/alphabet.pkl"]
943
+ logs_path = "logs/%s/%s/%s/%s/%s/%s" % (dataset_name[idx], dataset_type[idx], task_type[idx], model_type[idx], input_type[idx], time_str[idx])
944
+ models_path = "models/%s/%s/%s/%s/%s/%s/checkpoint-%s" % (dataset_name[idx], dataset_type[idx], task_type[idx], model_type[idx], input_type[idx], time_str[idx], str(step[idx]))
945
+ logs_local_dir = os.path.join(save_dir, logs_path)
946
+ exists = True
947
+ for logs_file_name in logs_file_names:
948
+ if not os.path.exists(os.path.join(logs_local_dir, logs_file_name)):
949
+ exists = False
950
+ break
951
+ models_local_dir = os.path.join(save_dir, models_path)
952
+ if exists:
953
+ for models_file_name in models_file_names:
954
+ if not os.path.exists(os.path.join(models_local_dir, models_file_name)):
955
+ exists = False
956
+ break
957
+ if not exists:
958
+ print("*" * 20 + "Downloading" + "*" * 20)
959
+ print("Downloading Downstream Task: %s TrainedCheckPoint: %s-%s-%s ..." % (dataset_name[idx], dataset_name[idx], time_str[idx], str(step[idx])))
960
+ print("Wait a moment, please.")
961
+ # download logs
962
+ if not os.path.exists(logs_local_dir):
963
+ os.makedirs(logs_local_dir)
964
+ logs_base_url = os.path.join(base_url, dataset_name[idx], logs_path)
965
+ download_folder(logs_base_url, logs_file_names, logs_local_dir)
966
+ # download models
967
+ if not os.path.exists(models_local_dir):
968
+ os.makedirs(models_local_dir)
969
+ models_base_url = os.path.join(base_url, dataset_name[idx], models_path)
970
+ download_folder(models_base_url, models_file_names, models_local_dir)
971
+ print("Downstream Task: %s Trained Model Download Succeed." % dataset_name[idx])
972
+ print("*" * 50)
973
+ download_succeed_task_num += 1
974
+ except Exception as e:
975
+ print(e)
976
+ print("Download automatically LucaDownstream Task: %s Trained CheckPoint failed!" % dataset_name[idx])
977
+ print("You can manually download 'logs/' and 'models/' into local directory: %s/ from %s" % (os.path.abspath(save_dir), os.path.join(base_url, dataset_name[idx])))
978
+ raise Exception(e)
979
+ print("%d Downstream Task Trained Model Download Succeed." % download_succeed_task_num)
vocab.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4, "1": 5, "2": 6, "3": 7, "4": 8, "5": 9, "L": 10, "A": 11, "G": 12, "V": 13, "S": 14, "E": 15, "R": 16, "T": 17, "I": 18, "D": 19, "P": 20, "K": 21, "Q": 22, "N": 23, "F": 24, "Y": 25, "M": 26, "H": 27, "W": 28, "C": 29, "X": 30, "B": 31, "U": 32, "Z": 33, "O": 34, "J": 35, ".": 36, "-": 37, "*": 38}