3v324v23 committed on
Commit 8e64bfa · 1 Parent(s): ea83b6a
.gitignore ADDED
@@ -0,0 +1,168 @@
+ # Initially taken from Github's Python gitignore file
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # tests and logs
+ tests/fixtures/cached_*_text.txt
+ logs/
+ lightning_logs/
+ lang_code_data/
+ nohup.out
+ output/
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # vscode
+ .vs
+ .vscode
+
+ # Pycharm
+ .idea
+
+ # TF code
+ tensorflow_code
+
+ # Models
+ proc_data
+
+ # examples
+ runs
+ /runs_old
+ /wandb
+ /examples/runs
+ /examples/**/*.args
+ /examples/rag/sweep
+ /inv
+
+ # data
+ /data
+ serialization_dir
+
+ # emacs
+ *.*~
+ debug.env
+
+ # vim
+ .*.swp
+
+ #ctags
+ tags
+
+ # pre-commit
+ .pre-commit*
+
+ # .lock
+ *.lock
+
+ inv.py
config.json ADDED
@@ -0,0 +1,28 @@
+ {
+ "_name_or_path": "bozhou/DeBERTa-base",
+ "architectures": [
+ "DeBERTa"
+ ],
+ "auto_map": {
+ "AutoConfig": "modeling.config.ModelConfig",
+ "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
+ "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration"
+ },
+ "bos_token_id": 130004,
+ "eos_token_id": 130005,
+ "mask_token_id": 130000,
+ "gmask_token_id": 130001,
+ "pad_token_id": 3,
+ "hidden_size": 4096,
+ "inner_hidden_size": 16384,
+ "layernorm_epsilon": 1e-05,
+ "max_sequence_length": 2048,
+ "model_type": "chatglm",
+ "num_attention_heads": 32,
+ "num_layers": 28,
+ "position_encoding_2d": true,
+ "torch_dtype": "float16",
+ "transformers_version": "4.23.1",
+ "use_cache": true,
+ "vocab_size": 130528
+ }
modeling/__init__.py CHANGED
@@ -1,37 +0,0 @@
- #
- # Zhou Bo
-
- #
-
- """ Components for NN
- """
-
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- from .tokenizers import *
- from .pooling import *
- from .mlm import MLMPredictionHead
- from .nnmodule import NNModule
- from .deberta import *
- from .disentangled_attention import *
- from .ops import *
- from .bert import *
- from .config import *
- from .cache_utils import *
- from .focal_loss import *
- # from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
- from .modeling import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM,
- BertForNextSentencePrediction, PreTrainedBertModel,
- BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification,
- BertForQuestionAnswering, BertForPreTrainingLossMask, BertPreTrainingPairRel,
- BertPreTrainingPairTransform, BertPreTrainingHeads, MLMHead)
- # from .optimization import BertAdam, BertAdamFineTune
- try:
- from .optimization_fp16 import FP16_Optimizer_State
- except:
- pass
- from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
- from .flash import FlashQuadModel
- from .gat import GatModel
modeling/bert.py CHANGED
@@ -6,17 +6,10 @@

 # This piece of code is modified based on https://github.com/huggingface/transformers

- import copy
 import torch
 from torch import nn
 from collections import Sequence
 from packaging import version
- import numpy as np
- import math
- import os
- import pdb
-
- import json
 from .ops import *
 from .disentangled_attention import *
 from .da_utils import *
modeling/cache_utils.py CHANGED
@@ -13,10 +13,7 @@ import os
 import requests
 from .config import ModelConfig
 import pathlib
- from ..utils import xtqdm as tqdm
- from zipfile import ZipFile
 import loguru
- # from ..utils import get_logger
 logger = loguru.logger

 __all__ = ['pretrained_models', 'load_model_state', 'load_vocab']
@@ -49,36 +46,7 @@ pretrained_models= {
 'deberta-v3-xsmall': PretrainedModel('deberta-v3-xsmall', 'spm.model', 'spm'),
 }

- def download_asset(url, name, tag=None, no_cache=False, cache_dir=None):
- _tag = tag
- if _tag is None:
- _tag = 'latest'
- if not cache_dir:
- cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/')
- os.makedirs(cache_dir, exist_ok=True)
- output=os.path.join(cache_dir, name)
- if os.path.exists(output) and (not no_cache):
- return output

- #repo=f'https://huggingface.co/microsoft/deberta-{name}/blob/main/bpe_encoder.bin'
- headers = {}
- headers['Accept'] = 'application/octet-stream'
- resp = requests.get(url, stream=True, headers=headers)
- if resp.status_code != 200:
- raise Exception(f'Request for {url} return {resp.status_code}, {resp.text}')
-
- try:
- with open(output, 'wb') as fs:
- progress = tqdm(total=int(resp.headers['Content-Length']) if 'Content-Length' in resp.headers else -1, ncols=80, desc=f'Downloading {name}')
- for c in resp.iter_content(chunk_size=1024*1024):
- fs.write(c)
- progress.update(len(c))
- progress.close()
- except:
- os.remove(output)
- raise
-
- return output

 def load_model_state(path_or_pretrained_id, tag=None, no_cache=False, cache_dir=None):
 model_path = path_or_pretrained_id
@@ -91,9 +59,6 @@ def load_model_state(path_or_pretrained_id, tag=None, no_cache=False, cache_dir=
 cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/{pretrained.name}')
 os.makedirs(cache_dir, exist_ok=True)
 model_path = os.path.join(cache_dir, 'pytorch_model.bin')
- if (not os.path.exists(model_path)) or no_cache:
- asset = download_asset(pretrained.model_url, 'pytorch_model.bin', tag=tag, no_cache=no_cache, cache_dir=cache_dir)
- asset = download_asset(pretrained.config_url, 'model_config.json', tag=tag, no_cache=no_cache, cache_dir=cache_dir)
 elif not model_path:
 return None,None

@@ -107,26 +72,3 @@ def load_model_state(path_or_pretrained_id, tag=None, no_cache=False, cache_dir=
 else:
 model_config = None
 return model_state, model_config
-
- def load_vocab(vocab_path=None, vocab_type=None, pretrained_id=None, tag=None, no_cache=False, cache_dir=None):
- if pretrained_id and (pretrained_id.lower() in pretrained_models):
- _tag = tag
- if _tag is None:
- _tag = 'latest'
-
- pretrained = pretrained_models[pretrained_id.lower()]
- if not cache_dir:
- cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/{pretrained.name}')
- os.makedirs(cache_dir, exist_ok=True)
- vocab_type = pretrained.vocab_type
- url = pretrained.vocab_url
- outname = os.path.basename(url)
- vocab_path =os.path.join(cache_dir, outname)
- if (not os.path.exists(vocab_path)) or no_cache:
- asset = download_asset(url, outname, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
- if vocab_type is None:
- vocab_type = 'spm'
- return vocab_path, vocab_type
-
- def test_download():
- vocab = load_vocab()
modeling/config.py CHANGED
@@ -1,8 +1,114 @@
 import json
 import copy

+ from transformers.configuration_utils import PretrainedConfig
+
 __all__=['AbsModelConfig', 'ModelConfig']

+
+ class DebertaConfig(PretrainedConfig):
+ model_type = 'deberta-v2'
+
+ def __init__(self,
+ vocab_size_or_config_json_file,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ relax_projection=0,
+ new_pos_ids=False,
+ initializer_range=0.02,
+ task_idx=None,
+ fp32_embedding=False,
+ ffn_type=0,
+ label_smoothing=None,
+ num_qkv=0,
+ seg_emb=False):
+ """Constructs DebertaConfig.
+
+ Args:
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
+ hidden_size: Size of the encoder layers and the pooler layer.
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
+ num_attention_heads: Number of attention heads for each attention layer in
+ the Transformer encoder.
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
+ layer in the Transformer encoder.
+ hidden_act: The non-linear activation function (function or string) in the
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
+ hidden_dropout_prob: The dropout probability for all fully connected
+ layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob: The dropout ratio for the attention
+ probabilities.
+ max_position_embeddings: The maximum sequence length that this model might
+ ever be used with. Typically set this to something large just in case
+ (e.g., 512 or 1024 or 2048).
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
+ `BertModel`.
+ initializer_range: The stdev of the truncated_normal_initializer for
+ initializing all weight matrices.
+ """
+ if isinstance(vocab_size_or_config_json_file, str):
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
+ json_config = json.loads(reader.read())
+ for key, value in json_config.items():
+ self.__dict__[key] = value
+ elif isinstance(vocab_size_or_config_json_file, int):
+ self.vocab_size = vocab_size_or_config_json_file
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.relax_projection = relax_projection
+ self.new_pos_ids = new_pos_ids
+ self.initializer_range = initializer_range
+ self.task_idx = task_idx
+ self.fp32_embedding = fp32_embedding
+ self.ffn_type = ffn_type
+ self.label_smoothing = label_smoothing
+ self.num_qkv = num_qkv
+ self.seg_emb = seg_emb
+ else:
+ raise ValueError("First argument must be either a vocabulary size (int)"
+ "or the path to a pretrained model config file (str)")
+
+ # @classmethod
+ # def from_dict(cls, json_object):
+ # """Constructs a `BertConfig` from a Python dictionary of parameters."""
+ # config = DebertaConfig(vocab_size_or_config_json_file=-1)
+ # for key, value in json_object.items():
+ # config.__dict__[key] = value
+ # return config
+
+ # @classmethod
+ # def from_json_file(cls, json_file):
+ # """Constructs a `BertConfig` from a json file of parameters."""
+ # with open(json_file, "r", encoding='utf-8') as reader:
+ # text = reader.read()
+ # return cls.from_dict(json.loads(text))
+
+ # def __repr__(self):
+ # return str(self.to_json_string())
+
+ # def to_dict(self):
+ # """Serializes this instance to a Python dictionary."""
+ # output = copy.deepcopy(self.__dict__)
+ # return output
+
+ # def to_json_string(self):
+ # """Serializes this instance to a JSON string."""
+ # return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
+
 class AbsModelConfig(object):
 def __init__(self):
 pass
modeling/da_utils.py CHANGED
@@ -1,5 +1,4 @@
 import torch
- import pdb
 from functools import lru_cache
 import numpy as np

modeling/deberta.py CHANGED
@@ -9,14 +9,10 @@

 import copy
 import torch
- import os
-
- import json
 from .ops import *
 from .bert import *
 from .config import ModelConfig
 from .cache_utils import load_model_state
- import pdb

 __all__ = ['DeBERTa']

modeling/disentangled_attention.py CHANGED
@@ -11,12 +11,9 @@
 Disentangled SelfAttention module
 """

- import numpy as np
 import math
 import torch
 from torch import nn
- import functools
- import pdb

 from .ops import *
 from .da_utils import build_relative_position
modeling/flash.py DELETED
@@ -1,794 +0,0 @@
1
- #
2
- # Zhoubo
3
- #
4
- """
5
- FLASH: https://arxiv.org/abs/2202.10447
6
- """
7
- import copy
8
- import torch
9
- import os
10
- from collections import Sequence
11
- import json
12
-
13
- import torch
14
- import torch.nn as nn
15
- import torch.nn.functional as F
16
- from transformers.activations import ACT2FN
17
- from .modeling import *
18
- from .ops import XSoftmax, sequence_masking
19
-
20
- from .bert import *
21
- from .config import ModelConfig
22
- from .cache_utils import load_model_state
23
- import einops
24
-
25
-
26
- class ScaleNorm(nn.Module):
27
- def __init__(self, eps=1e-5):
28
- super().__init__()
29
- self.eps = eps
30
- self.scala = nn.Parameter(torch.ones(1))
31
-
32
- def forward(self, x):
33
- mean_square = (x ** 2).mean(dim=-1, keepdim=True)
34
- x = x * torch.rsqrt(mean_square + self.eps) * self.scala
35
- return x
36
-
37
-
38
-
39
- class OffsetScale(nn.Module):
40
- def __init__(self, dim, heads = 1):
41
- super().__init__()
42
- self.gamma = nn.Parameter(torch.ones(heads, dim))
43
- self.beta = nn.Parameter(torch.zeros(heads, dim))
44
- # nn.init.normal_(self.gamma, std = 0.02)
45
- # nn.init.xavier_uniform_(self.gamma)
46
-
47
- def forward(self, x):
48
- out = (x * self.gamma) + self.beta
49
- return out
50
-
51
-
52
- class ScaledSinuEmbedding(nn.Module):
53
- def __init__(self, dim):
54
- super().__init__()
55
- self.scale = nn.Parameter(torch.ones(1,))
56
- inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
57
- self.register_buffer('inv_freq', inv_freq)
58
-
59
- def forward(self, x):
60
- n, device = x.shape[1], x.device
61
- t = torch.arange(n, device = device).type_as(self.inv_freq)
62
- sinu = torch.einsum('i , j -> i j', t, self.inv_freq)
63
- emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
64
- return emb * self.scale
65
-
66
-
67
- def RoPE(x, dim):
68
- """
69
- :param x: input tensor
70
- :param dim: dimension(s) to operate on
71
- :return: tensor
72
- """
73
- shape = x.shape
74
- if isinstance(dim, int):
75
- dim = [dim]
76
-
77
- spatial_shape = [shape[i] for i in dim]
78
- total_len = 1
79
- for i in spatial_shape:
80
- total_len *= i
81
- position = torch.reshape(torch.arange(total_len, dtype=torch.float, device=x.device), spatial_shape)
82
-
83
- for i in range(dim[-1] + 1, len(shape) - 1, 1):
84
- position = torch.unsqueeze(position, dim=-1)
85
-
86
- half_size = shape[-1] // 2
87
- freq_seq = -torch.arange(half_size, dtype=torch.float, device=x.device) / float(half_size)
88
- inv_freq = 10000 ** -freq_seq
89
- sinusoid = torch.einsum("...,d->...d", position, inv_freq)
90
- sin = torch.sin(sinusoid).repeat_interleave(2, -1)
91
- cos = torch.cos(sinusoid).repeat_interleave(2, -1)
92
- tensor_cross = torch.stack([-x[..., 1:: 2], x[..., :: 2]], -1).reshape(x.shape)
93
- # x1, x2 = torch.chunk(x, 2, dim=-1)
94
- # return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
95
- return x * cos + tensor_cross * sin
96
-
97
-
98
- def rel_pos_bias(seq_len, s):
99
- a = torch.rand([1, s], dtype=torch.float)
100
- b = torch.rand([1, s], dtype=torch.float)
101
- w = torch.rand([2 * seq_len - 1], dtype=torch.float)
102
- if seq_len <= 512:
103
- t = F.pad(w[: 2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
104
- t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
105
- r = (2 * seq_len - 1) // 2
106
- t = t[..., r:-r]
107
- else:
108
- a = RoPE(a.repeat(seq_len, 1), dim=[0])
109
- b = RoPE(b.repeat(seq_len, 1), dim=[0])
110
- t = torch.einsum("mk,nk->mn", a, b)
111
- return t
112
-
113
- def squared_relu(x, attention_mask, dim=-1):
114
- rmask = ~(attention_mask.bool())
115
- x = x.masked_fill(rmask, 0)
116
- return torch.square(F.relu(x))
117
-
118
-
119
- def attention_normalize(a, axis=-1, mask=None, fn='softmax'):
120
- if fn == 'softmax':
121
- return XSoftmax.apply(a, mask, axis)
122
- else:
123
- mask_ = a > -float('inf') / 10
124
- # mask_ = mask_.byte()
125
- mask_ = torch.sum(mask_, axis=axis, keepdim=True)
126
- l = torch.maximum(mask_, torch.ones_like(mask_))
127
- if fn == 'squared_relu':
128
- rmask = ~(mask.bool())
129
- a = a.masked_fill(rmask, 0)
130
- return torch.square(F.relu(a)) / l
131
- elif fn == 'softmax_plus':
132
- return XSoftmax.apply(a * torch.log(l) / np.log(512), mask, axis)
133
- return a
134
-
135
-
136
- class GAULinear(nn.Linear):
137
- def init_weight(self):
138
- nn.init.xavier_uniform_(self.weight)
139
-
140
-
141
- class GatedAttentionUnit(nn.Module):
142
- """
143
- GAU Block: Gate Attention Unit
144
- """
145
- def __init__(
146
- self,
147
- max_seq_length,
148
- hidden_size,
149
- attention_key_size=128,
150
- activation='swish',
151
- use_bias=True,
152
- attention_norm_type='squared_relu',
153
- attention_scale=True,
154
- dropout=0.1,
155
- pre_norm=False,
156
- norm_type="layer_norm",
157
- eps=1e-5,
158
- shift_token=False,
159
- use_rel_bias=False,
160
- add_residual=True,
161
- **kwargs,):
162
-
163
- super(GatedAttentionUnit, self).__init__(**kwargs)
164
- self.max_seq_length = max_seq_length
165
- self.units = hidden_size
166
- self.intermediate_size = self.units * 2
167
- self.key_size = attention_key_size
168
- self.activation = activation
169
- self.use_bias = use_bias
170
- self.attention_norm_type = attention_norm_type
171
- self.attention_scale = attention_scale
172
- self.dropout = StableDropout(dropout)
173
- self.i_dense = nn.Sequential(
174
- nn.Linear(self.units, 2 * self.intermediate_size + self.key_size, bias=self.use_bias),
175
- nn.SiLU()
176
- )
177
- self.o_dense = nn.Sequential(
178
- nn.Linear(self.intermediate_size, self.units, bias=self.use_bias),
179
- self.dropout)
180
- self.q_scaleoffset = OffsetScale(self.key_size)
181
- self.k_scaleoffset = OffsetScale(self.key_size)
182
- self.pre_norm = pre_norm
183
- self.norm = (nn.LayerNorm(hidden_size, eps=eps) if norm_type.lower() == "layer_norm" else ScaleNorm(eps=eps))
184
- self.add_residual = add_residual
185
-
186
- def forward(self, x, attention_mask=None, **kwargs):
187
- shortcut = x
188
-
189
- if self.pre_norm:
190
- x = self.norm(x)
191
-
192
- x = self.i_dense(x)
193
- u, v, qk = torch.split(x, [self.intermediate_size, self.intermediate_size, self.key_size], dim=-1)
194
- q, k = self.q_scaleoffset(qk), self.k_scaleoffset(qk)
195
- qk = RoPE(torch.stack([q, k], 2), dim=1)
196
- q, k = qk[:, :, 0], qk[:, :, 1]
197
- a = torch.einsum('bmd,bnd->bmn', q, k)
198
- if self.attention_scale:
199
- a = a / self.key_size**0.5
200
- a = sequence_masking(a, attention_mask, '-inf', -1)
201
- A = attention_normalize(a, -1, fn=self.attention_norm_type)
202
- if self.dropout:
203
- A = self.dropout(A)
204
- out = self.o_dense(u * torch.einsum('bmn,bnd->bmd', A, v))
205
-
206
- if self.add_residual:
207
- out = out + shortcut
208
- if not self.pre_norm:
209
- out = self.norm(out)
210
- return out
211
- # # Add RoPE
212
- # if p_bias == 'rotary':
213
- # qk = K.stack([q, k], 2)
214
- # qk = apply_rotary_position_embeddings(inputs[n], qk)[0]
215
- # q, k = qk[:, :, 0], qk[:, :, 1]
216
- # # Attention
217
- # a = tf.einsum('bmd,bnd->bmn', q, k)
218
- # if self.attention_scale:
219
- # a = a / self.key_size**0.5
220
- # if a_bias is not None:
221
- # a = a + a_bias
222
- # a = sequence_masking(a, mask, '-inf', -1)
223
- # A = attention_normalize(a, -1, self.normalization)
224
- # if self.attention_dropout:
225
- # A = Dropout(self.attention_dropout)(A)
226
- # # Compute the output
227
- # o = self.o_dense(u * tf.einsum('bmn,bnd->bmd', A, v))
228
-
229
- # return o
230
-
231
- class GAU(nn.Module):
232
- def __init__(self, max_seq_length, hidden_size, expansion_factor=2, s=128, norm_type="layer_norm", eps=1e-5,
233
- hidden_act="silu", shift_token=False, use_rel_bias=False, attention_norm_type='softmax',
234
- pre_norm=False, dropout=0, add_residual = True):
235
- super(GAU, self).__init__()
236
- self.max_seq_length = max_seq_length
237
- self.shift_token = shift_token
238
- hidden_dim = int(expansion_factor * hidden_size)
239
- self.norm = (nn.LayerNorm(hidden_size, eps=eps) if norm_type == "layer_norm" else ScaleNorm(eps=eps))
240
- self.use_rel_bias = use_rel_bias
241
- self.attention_norm_type = attention_norm_type
242
- # if attention_norm_type == 'relu':
243
- # self.attention_norm_func = squared_relu
244
- # else:
245
- # self.attention_norm_func = XSoftmax.apply
246
- # self.norm = norm_klass(hidden_size)
247
-
248
- self.dropout = nn.Dropout(dropout)
249
-
250
- self.to_hidden = nn.Sequential(
251
- nn.Linear(hidden_size, hidden_dim * 2),
252
- nn.SiLU()
253
- )
254
-
255
- self.to_qk = nn.Sequential(
256
- nn.Linear(hidden_size, s),
257
- nn.SiLU()
258
- )
259
-
260
- self.offsetscale = OffsetScale(s, heads = 2)
261
-
262
- self.to_out = nn.Sequential(
263
- nn.Linear(hidden_dim, hidden_size),
264
- nn.Dropout(dropout)
265
- )
266
-
267
- self.add_residual = add_residual
268
- self.act_fn = ACT2FN[hidden_act]
269
- self.pre_norm = pre_norm
270
-
271
-
272
- def forward(
273
- self,
274
- x,
275
- relative_pos = None,
276
- attention_mask = None
277
- ):
278
- seq_len, device = x.shape[-2], x.device
279
- if self.pre_norm:
280
- normed_x = self.norm(x)
281
- else:
282
- normed_x = x
283
- v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)
284
-
285
- qk = self.to_qk(normed_x)
286
- base = self.offsetscale(qk)
287
- base = RoPE(base, 1)
288
- q, k = base.unbind(dim = -2)
289
- sim = torch.einsum('b i d, b j d -> b i j', q, k)
290
-
291
- if relative_pos is not None:
292
- sim = sim + relative_pos
293
- if attention_mask is not None:
294
- if attention_mask.dim() < 3:
295
- attention_mask = einops.rearrange(attention_mask, 'b j -> b 1 j')
296
- # attn = attn.masked_fill(~attention_mask.bool(), 0.)
297
- attn = attention_normalize(sim, mask=attention_mask, fn=self.attention_norm_type)
298
- # attn = F.relu(sim) ** 2 / seq_len# / q.size(-1)
299
- # logger.info(attn.max())
300
- attn = self.dropout(attn)
301
- # if self.causal:
302
- # causal_mask = torch.ones((seq_len, seq_len), dtype = torch.bool, device = device).triu(1)
303
- # attn = attn.masked_fill(causal_mask, 0.)
304
-
305
- out = torch.einsum('b i j, b j d -> b i d', attn, v)
306
- out = out * gate
307
-
308
- out = self.to_out(out)
309
-
310
- if self.add_residual:
311
- out = out + x
312
- if not self.pre_norm:
313
- out = self.norm(out)
314
- return out
315
-
316
-
317
- class GAULayer(nn.Module):
318
- def __init__(self, config, shift_token=False, use_ffn=False):
319
- super(GAULayer, self).__init__()
320
- self.attention = GatedAttentionUnit(config.max_position_embeddings, config.hidden_size,
321
- shift_token=shift_token, use_rel_bias=config.use_rel_bias,
322
- norm_type=config.norm_type, attention_norm_type=config.attention_norm_type,
323
- pre_norm=config.pre_norm, dropout=config.hidden_dropout_prob)
324
- if use_ffn:
325
- self.intermediate = BertIntermediate(config)
326
- self.output = BertOutput(config)
327
- self.use_ffn = use_ffn
328
-
329
- def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
330
- attention_output = self.attention(hidden_states, attention_mask=attention_mask, relative_pos=relative_pos)
331
- if self.use_ffn:
332
- intermediate_output = self.intermediate(attention_output)
333
- layer_output = self.output(intermediate_output, attention_output)
334
- return layer_output
335
- else:
336
- return attention_output
337
-
338
-
339
- class FlashBlock(nn.Module):
340
- """
341
- FLASH Block: Fast Linear Attention with a Single Head
342
- """
343
-
344
- def __init__(self, model_size, sequence_length, chunk_size=256, expansion_factor=2, s=128, norm_type="layer_norm", eps=1e-5,
345
- hidden_act="silu"):
346
- super(FlashBlock, self).__init__()
347
- self.s = s
348
- self.eps = eps
349
- self.norm_type = norm_type
350
- self.model_size = model_size
351
- self.chunk_size = chunk_size
352
- self.hidden_act = hidden_act
353
- self.sequence_length = sequence_length
354
- self.expansion_factor = expansion_factor
355
- self.e = int(self.model_size * self.expansion_factor)
356
-
357
- self.dense1 = nn.Linear(self.model_size, 2 * self.e + self.s, bias=True)
358
- self.gamma = nn.Parameter(torch.rand((4, self.s)))
359
- self.beta = nn.Parameter(torch.rand((4, self.s)))
360
- self.dense2 = nn.Linear(self.e, self.model_size)
361
- self.LayerNorm = (
362
- nn.LayerNorm(model_size, eps=self.eps) if norm_type == "layer_norm" else ScaleNorm(eps=self.eps))
363
-
364
- nn.init.xavier_normal_(self.dense1.weight)
365
- self.act_fn = ACT2FN(self.hidden_act)
366
-
367
- def global_linear_attention(self, query, key, value, causal):
368
- if causal:
369
- kv = torch.einsum("bgcs, bgce->bgse", key, value)
370
- kv = torch.cumsum(kv, dim=1)
371
- lin_v = torch.einsum("bgcs, bgse->bgce", query, kv)
372
- return lin_v
373
- else:
374
- kv = torch.einsum("bgcs, bgce->bse", key, value)
375
- lin_v = torch.einsum("bgcs, bse->bgce", query, kv)
376
- return lin_v
377
-
378
- def segment_ids_to_mask(self, segment_ids, causal=False):
379
- """Generate the segment mask from the segment ids.
380
- The segment mask is used to remove the attention between tokens in different documents.
381
- """
382
- min_ids, max_ids = torch.min(segment_ids, dim=-1).values, torch.max(segment_ids, dim=-1).values
383
- # 1.0 indicates in the same group and 0.0 otherwise
384
- mask = torch.logical_and(torch.less_equal(min_ids[:, :, None], max_ids[:, None, :]),
385
- torch.greater_equal(max_ids[:, :, None], min_ids[:, None, :]))
386
- mask = torch.tensor(mask, torch.float32)
387
- if causal:
388
- g = segment_ids.size()[1]
389
- causal_mask = 1.0 - torch.triu(torch.ones([g, g], dtype=torch.float32)) # keep the main diagonal and the elements above it
390
- mask *= causal_mask
391
- mask = torch.div(mask, torch.sum(mask, dim=-1, keepdim=True))
392
- return mask
393
-
394
- def forward(self, x, causal=False, attention_mask=None, sequence_mask=None, **kwargs):
395
- """
396
- inputs: [batch_size, num_chunk, chunk_length, model_size]
397
- """
398
- _, g, n, d = x.size()
399
- shortcut, x = x, self.LayerNorm(x)
400
- # Obtain Z via a linear projection, see Eq. (4) of the paper
401
- uv = self.dense1(x)
402
- # Split uv along the last dimension into Ug: [C*e], Vg: [C*e], Zg: [C*s] (Section 3.2 of the paper)
403
- # u:[batch_size, num_chunk, chunk_length, self.e]
404
- # v:[batch_size, num_chunk, chunk_length, self.e]
405
- # z:[batch_size, num_chunk, chunk_length, self.s]
406
- u, v, z = torch.split(self.act_fn(uv), [self.e, self.e, self.s], dim=-1)
407
-
408
- # Generate quad_q, quad_k, lin_q, lin_k
409
- # First apply a simple offset and scale, and fuse in the RoPE position embeddings
410
- z = torch.einsum("...r, hr->...hr", z, self.gamma) + self.beta
411
- z = RoPE(z, dim=[1, 2])
412
- quad_q, quad_k, lin_q, lin_k = torch.unbind(z, dim=-2) # unbind along dim -2 to get quad_q, quad_k, lin_q and lin_k
413
- # Compute the global lin_v
414
- lin_v = self.global_linear_attention(lin_q, lin_k, v, causal)
415
- if causal:
416
- # Linear attention part
417
- lin_kv = torch.einsum("bgnk, bgne->bgke", lin_k, lin_v) / torch.tensor(n, x.dtype) # see Eq. (7)
418
- mask = self.segment_ids_to_mask(segment_ids=segment_ids, causal=causal)
419
- cum_lin_kv = torch.einsum('bhke, bgh->bgke', lin_kv, mask)
420
- linear = torch.einsum("bgnk, bgke->bgne", lin_kv, cum_lin_kv)
421
- # Quadratic attention
422
- quad_qk = torch.einsum("bgnk, bgmk->bgnm", quad_q, quad_k) # the "local attention per chunk" part of the paper
423
- bias = rel_pos_bias(self.sequence_length, self.s)[:, :n, :n]
424
- kernel = torch.square(F.relu(quad_qk / n + bias)) # the relu**2 part of the paper
425
- causal_mask = torch.triu(torch.ones([n, n], dtype=x.dtype))
426
- quadratic = torch.einsum("bgnm, bgme->bgne", kernel * causal_mask, v)
427
- else:
428
- lin_kv = torch.einsum("bgnk, bgne->bgke", lin_k, lin_v) / torch.tensor(n, x.dtype) # see Eq. (7)
429
- mask = self.segment_ids_to_mask(segment_ids=segment_ids, causal=causal)
430
- lin_kv = torch.einsum("bhke, bgh->bgke", lin_kv, mask)
431
- linear = torch.einsum("bgnk, bgke->bgne", lin_q, lin_kv)
432
- # Quadratic attention
433
- quad_qk = torch.einsum("bgnk, bgmk->bgnm", quad_q, quad_k) # the "local attention per chunk" part of the paper
434
- bias = rel_pos_bias(self.sequence_length, self.s)[:, :n, :n]
435
- kernel = torch.square(F.relu(quad_qk / n + bias)) # the relu**2 part of the paper
436
- quadratic = torch.einsum("bgnm, bgme->bgne", kernel, v)
437
- x = u * (quadratic + linear)
438
- x = self.dense2(x)
439
- x = x + shortcut
440
- return x
441
-
442
- class RelativePositionBias(nn.Module):
443
- def __init__(
444
- self,
445
- scale,
446
- causal = False,
447
- num_buckets = 32,
448
- max_distance = 128
449
- ):
450
- super().__init__()
451
- self.scale = scale
452
- self.causal = causal
453
- self.num_buckets = num_buckets
454
- self.max_distance = max_distance
455
- self.relative_attention_bias = nn.Embedding(num_buckets, 1)
456
-
457
- @staticmethod
458
- def _relative_position_bucket(
459
- relative_position,
460
- causal = True,
461
- num_buckets = 32,
462
- max_distance = 128
463
- ):
464
- ret = 0
465
- n = -relative_position
466
- if not causal:
467
- num_buckets //= 2
468
- ret += (n < 0).long() * num_buckets
469
- n = torch.abs(n)
470
- else:
471
- n = torch.max(n, torch.zeros_like(n))
472
-
473
- max_exact = num_buckets // 2
474
- is_small = n < max_exact
475
-
476
- val_if_large = max_exact + (
477
- torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
478
- ).long()
479
- val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
480
-
481
- ret += torch.where(is_small, n, val_if_large)
482
- return ret
483
-
484
- def forward(self, x):
485
- i, j, device = *x.shape[-2:], x.device
486
- q_pos = torch.arange(i, dtype = torch.long, device = device)
487
- k_pos = torch.arange(j, dtype = torch.long, device = device)
488
- rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
489
- rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
490
- values = self.relative_attention_bias(rp_bucket)
491
- bias = rearrange(values, 'i j 1 -> i j')
492
- return bias * self.scale
493
-
494
-
495
- class FlashEmbeddings(nn.Module):
496
- """Construct the embeddings from word, position and token_type embeddings.
497
- """
498
- def __init__(self, config, with_position=False):
499
- super(FlashEmbeddings, self).__init__()
500
- self.word_embeddings = nn.Embedding(
501
- config.vocab_size, config.hidden_size)
502
- self.token_type_embeddings = nn.Embedding(
503
- config.type_vocab_size, config.hidden_size)
504
- self.with_position = with_position
505
- if with_position:
506
- self.position_embeddings = ScaledSinuEmbedding(config.hidden_size)
507
-
508
- # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
509
- # any TensorFlow checkpoint file
510
- self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5)
511
- self.dropout = StableDropout(config.hidden_dropout_prob)
512
-
513
- def forward(self, input_ids, token_type_ids=None, position_ids=None, token_mask=None):
514
- seq_length = input_ids.size(1)
515
- if position_ids is None:
516
- position_ids = torch.arange(
517
- seq_length, dtype=torch.long, device=input_ids.device)
518
- position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
519
- if token_type_ids is None:
520
- token_type_ids = torch.zeros_like(input_ids)
521
-
522
- words_embeddings = self.word_embeddings(input_ids)
523
- if self.with_position:
524
- position_embeddings = self.position_embeddings(words_embeddings)
525
- else:
526
- position_embeddings = 0
527
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
528
-
529
- # if self.num_pos_emb > 1:
530
- # num_batch = position_embeddings.size(0)
531
- # num_pos = position_embeddings.size(1)
532
- # position_embeddings = position_embeddings.view(
533
- # num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]
534
-
535
- embeddings = words_embeddings + position_embeddings + token_type_embeddings
536
- # if self.fp32_embedding:
537
- # embeddings = embeddings.half()
538
- embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, token_mask)
539
- embeddings = self.dropout(embeddings)
540
- return {
541
- 'embeddings': embeddings,
542
- 'position_embeddings': position_embeddings}
543
-
544
-
545
- class GAUEncoder(nn.Module):
546
- def __init__(self, config, shift_token=False):
547
- super().__init__()
548
- layer = GAULayer(config, shift_token=shift_token)
549
- self.layer = nn.ModuleList([copy.deepcopy(layer)
550
- for _ in range(config.num_hidden_layers)])
551
-
552
- def get_attention_mask(self, attention_mask):
553
- if attention_mask.dim() <= 2:
554
- extended_attention_mask = attention_mask.unsqueeze(1)
555
- attention_mask = extended_attention_mask*extended_attention_mask.squeeze(-2).unsqueeze(-1)
556
- attention_mask = attention_mask #.byte()
557
- return attention_mask
558
-
559
- def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, return_att=False, query_states = None, relative_pos=None):
560
- all_encoder_layers = []
561
- att_matrices = []
562
- if isinstance(hidden_states, Sequence):
563
- next_kv = hidden_states[0]
564
- else:
565
- next_kv = hidden_states
566
- # rel_embeddings = self.get_rel_embedding()
567
- for i, layer_module in enumerate(self.layer):
568
- output_states = layer_module(next_kv, attention_mask, query_states = query_states, relative_pos=relative_pos)
569
- if return_att:
570
- output_states, att_m = output_states
571
-
572
- # if i == 0 and self.with_conv:
573
- # prenorm = output_states #output['prenorm_states']
574
- # output_states = self.conv(hidden_states, prenorm, input_mask)
575
-
576
- if query_states is not None:
577
- query_states = output_states
578
- if isinstance(hidden_states, Sequence):
579
- next_kv = hidden_states[i+1] if i+1 < len(self.layer) else None
580
- else:
581
- next_kv = output_states
582
-
583
- if output_all_encoded_layers:
584
- all_encoder_layers.append(output_states)
585
- if return_att:
586
- att_matrices.append(att_m)
587
- if not output_all_encoded_layers:
588
- all_encoder_layers.append(output_states)
589
- if return_att:
590
- att_matrices.append(att_m)
591
- return {
592
- 'hidden_states': all_encoder_layers,
593
- 'attention_matrices': att_matrices
594
- }
595
-
596
- class FlashEncoder(nn.Module):
597
- def __init__(self, config):
598
- super().__init__(config)
599
- layer = GateAttentionUnit(config.max_position_embeddings, config.hidden_size)
600
- self.layer = nn.ModuleList([copy.deepcopy(layer)
601
- for _ in range(config.num_hidden_layers)])
602
-
603
- def forward(self, hidden_states, attention_mask, token_mask=None,
604
- output_all_encoded_layers=True,
605
- prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, seg_ids=None):
606
- # history embedding and encoded layer must be simultaneously given
607
- assert (prev_embedding is None) == (prev_encoded_layers is None)
608
-
609
- all_encoder_layers = []
610
- if (prev_embedding is not None) and (prev_encoded_layers is not None):
611
- history_states = prev_embedding
612
- for i, layer_module in enumerate(self.layer):
613
- hidden_states = layer_module(
614
- hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids)
615
- if output_all_encoded_layers:
616
- all_encoder_layers.append(hidden_states)
617
- if prev_encoded_layers is not None:
618
- history_states = prev_encoded_layers[i]
619
- else:
620
- for layer_module in self.layer:
621
- hidden_states = layer_module(
622
- hidden_states, attention_mask=attention_mask, mask_qkv=mask_qkv, seg_ids=seg_ids)
623
- if output_all_encoded_layers:
624
- all_encoder_layers.append(hidden_states)
625
- if not output_all_encoded_layers:
626
- all_encoder_layers.append(hidden_states)
627
- return all_encoder_layers
628
-
629
- # class FlashQuadModel(BertModel):
630
- # def __init__(self, config, pooler=False, shift_token=False, causal=False) -> None:
631
- # super().__init__(config)
632
- # self.embeddings = FlashEmbeddings(config)
633
- # self.encoder = GAUEncoder(config, causal=causal, shift_token=shift_token)
634
- # if not pooler:
635
- # self.pooler = None
636
- # self.apply(self.init_bert_weights)
637
-
638
-
639
- class FlashQuadModel(torch.nn.Module):
640
- """
641
- Parameters:
642
- config:
643
- A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`,
644
-
645
- pre_trained:
646
- The pre-trained DeBERTa model, it can be a physical path of a pre-trained DeBERTa model or a released configurations,
647
- i.e. [**base, large, base_mnli, large_mnli**]
648
-
649
- """
650
-
651
- def __init__(self, config=None, pre_trained=None, pooler=False, shift_token=False, causal=False, **kwargs):
652
- super().__init__()
653
- state = None
654
- if pre_trained is not None:
655
- state, model_config = load_model_state(pre_trained)
656
- if config is not None and model_config is not None:
657
- for k in config.__dict__:
658
- if k not in ['hidden_size',
659
- 'intermediate_size',
660
- 'num_attention_heads',
661
- 'num_hidden_layers',
662
- 'vocab_size',
663
- 'max_position_embeddings']:
664
- model_config.__dict__[k] = config.__dict__[k]
665
- config = copy.copy(model_config)
666
- self.embeddings = FlashEmbeddings(config, with_position=True)
667
- self.encoder = GAUEncoder(config, shift_token=shift_token)
668
- if not pooler:
669
- self.pooler = None
670
- self.config = config
671
- self.pre_trained = pre_trained
672
- self.apply_state(state)
673
-
674
- def get_attention_mask(self, input_ids=None, token_type_ids=None, attention_mask=None, input_mask=None):
675
- if attention_mask is None:
676
- if input_mask is not None:
677
- return input_mask.unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), input_mask.size(1))
678
- else:
679
- return torch.ones_like(input_ids, dtype=torch.uint8).unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), input_mask.size(1))
680
- else:
681
- if attention_mask.dim() == 2:
682
- if input_mask is not None:
683
- attention_mask = attention_mask * input_mask
684
- return attention_mask.unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), attention_mask.size(-1))
685
- if attention_mask.dim() == 4:
686
- attention_mask = attention_mask.squeeze(2)
687
- if attention_mask.dim() == 3:
688
- if input_mask is not None:
689
- return attention_mask * input_mask.unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), attention_mask.size(-1))
690
- else:
691
- return attention_mask
692
-
693
-
694
- def forward(self, input_ids, input_mask, attention_mask=None, token_type_ids=None,
695
- output_all_encoded_layers=True, position_ids=None, return_att=False):
696
- """
697
- Args:
698
- input_ids:
699
- a torch.LongTensor of shape [batch_size, sequence_length] \
700
- with the word token indices in the vocabulary
701
-
702
- attention_mask:
703
- an optional parameter for input mask or attention mask.
704
-
705
- - If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices \
706
- selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
707
- input sequence length in the current batch. It's the mask that we typically use for attention when \
708
- a batch has varying length sentences.
709
-
710
- - If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
711
- In this case, it's a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.
712
-
713
- token_type_ids:
714
- an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
715
- types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
716
- a `sentence B` token (see BERT paper for more details).
717
-
718
- output_all_encoded_layers:
719
- whether to output results of all encoder layers, default, True
720
-
721
- Returns:
722
-
723
- - The output of the stacked transformer layers if `output_all_encoded_layers=True`, else \
724
- the last layer of stacked transformer layers
725
-
726
- - Attention matrix of self-attention layers if `return_att=True`
727
-
728
-
729
- Example::
730
-
731
- # Batch of wordPiece token ids.
732
- # Each sample was padded with zero to the maximum length of the batch
733
- input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
734
- # Mask of valid input ids
735
- attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
736
-
737
- # DeBERTa model initialized with pretrained base model
738
- bert = DeBERTa(pre_trained='base')
739
-
740
- encoder_layers = bert(input_ids, attention_mask=attention_mask)
741
-
742
- """
743
- if token_type_ids is None:
744
- token_type_ids = torch.zeros_like(input_ids)
745
- # input_mask = torch.ones_like(input_ids)
746
-
747
- if input_mask is None:
748
- idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), axis=1), [-1])
749
- input_mask = idxs > 0
750
- if not torch.any(input_mask):
751
- input_mask = torch.ones_like(input_ids)
752
- input_mask = input_mask # .byte()
753
- attention_mask = self.get_attention_mask(input_ids, token_type_ids, attention_mask, input_mask)
754
- attention_mask = attention_mask #.byte()
755
- embedding_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, input_mask)
756
- encoder_output = self.encoder(embedding_output['embeddings'], attention_mask, output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
757
- encoder_output.update(embedding_output)
758
- return encoder_output
759
-
760
- def apply_state(self, state = None):
761
- """ Load state from previous loaded model state dictionary.
762
-
763
- Args:
764
- state (:obj:`dict`, optional): State dictionary as the state returned by torch.module.state_dict(), default: `None`. \
765
- If it's `None`, then will use the pre-trained state loaded via the constructor to re-initialize \
766
- the `DeBERTa` model
767
- """
768
- if self.pre_trained is None and state is None:
769
- return
770
- if state is None:
771
- state, config = load_model_state(self.pre_trained)
772
- self.config = config
773
-
774
- prefix = ''
775
- for k in state:
776
- if 'embeddings.' in k:
777
- if not k.startswith('embeddings.'):
778
- prefix = k[:k.index('embeddings.')]
779
- break
780
-
781
- missing_keys = []
782
- unexpected_keys = []
783
- error_msgs = []
784
- self._load_from_state_dict(state, prefix = prefix, local_metadata=None, strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)
785
-
786
-
787
- class FlashModel(BertModel):
788
- def __init__(self, config) -> None:
789
- super().__init__(config)
790
- self.encoder = FlashEncoder(config)
791
- self.apply(self.init_bert_weights)
792
-
793
- if __name__ == '__main__':
794
- model = FlashModel(768, 64)
modeling/focal_loss.py DELETED
@@ -1,200 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torch.cuda.amp as amp
5
-
6
-
7
- ##
8
- # version 1: use torch.autograd
9
- class FocalLossV1(nn.Module):
10
-
11
- def __init__(self,
12
- alpha=0.25,
13
- gamma=2,
14
- reduction='mean',):
15
- super(FocalLossV1, self).__init__()
16
- self.alpha = alpha
17
- self.gamma = gamma
18
- self.reduction = reduction
19
- self.crit = nn.BCEWithLogitsLoss(reduction='none')
20
-
21
- def forward(self, logits, label):
22
- '''
23
- Usage is same as nn.BCEWithLogits:
24
- >>> criteria = FocalLossV1()
25
- >>> logits = torch.randn(8, 19, 384, 384)
26
- >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
27
- >>> loss = criteria(logits, lbs)
28
- '''
29
- probs = torch.sigmoid(logits)
30
- coeff = torch.abs(label - probs).pow(self.gamma).neg()
31
- log_probs = torch.where(logits >= 0,
32
- F.softplus(logits, -1, 50),
33
- logits - F.softplus(logits, 1, 50))
34
- log_1_probs = torch.where(logits >= 0,
35
- -logits + F.softplus(logits, -1, 50),
36
- -F.softplus(logits, 1, 50))
37
- loss = label * self.alpha * log_probs + (1. - label) * (1. - self.alpha) * log_1_probs
38
- loss = loss * coeff
39
-
40
- if self.reduction == 'mean':
41
- loss = loss.mean()
42
- if self.reduction == 'sum':
43
- loss = loss.sum()
44
- return loss
45
-
46
-
47
- ##
48
- # version 2: user derived grad computation
49
- class FocalSigmoidLossFuncV2(torch.autograd.Function):
50
- '''
51
- compute backward directly for better numeric stability
52
- '''
53
- @staticmethod
54
- @amp.custom_fwd(cast_inputs=torch.float32)
55
- def forward(ctx, logits, label, alpha, gamma):
56
- # logits = logits.float()
57
-
58
- probs = torch.sigmoid(logits)
59
- coeff = (label - probs).abs_().pow_(gamma).neg_()
60
- log_probs = torch.where(logits >= 0,
61
- F.softplus(logits, -1, 50),
62
- logits - F.softplus(logits, 1, 50))
63
- log_1_probs = torch.where(logits >= 0,
64
- -logits + F.softplus(logits, -1, 50),
65
- -F.softplus(logits, 1, 50))
66
- ce_term1 = log_probs.mul_(label).mul_(alpha)
67
- ce_term2 = log_1_probs.mul_(1. - label).mul_(1. - alpha)
68
- ce = ce_term1.add_(ce_term2)
69
- loss = ce * coeff
70
-
71
- ctx.vars = (coeff, probs, ce, label, gamma, alpha)
72
-
73
- return loss
74
-
75
- @staticmethod
76
- @amp.custom_bwd
77
- def backward(ctx, grad_output):
78
- '''
79
- compute gradient of focal loss
80
- '''
81
- (coeff, probs, ce, label, gamma, alpha) = ctx.vars
82
-
83
- d_coeff = (label - probs).abs_().pow_(gamma - 1.).mul_(gamma)
84
- d_coeff.mul_(probs).mul_(1. - probs)
85
- d_coeff = torch.where(label < probs, d_coeff.neg(), d_coeff)
86
- term1 = d_coeff.mul_(ce)
87
-
88
- d_ce = label * alpha
89
- d_ce.sub_(probs.mul_((label * alpha).mul_(2).add_(1).sub_(label).sub_(alpha)))
90
- term2 = d_ce.mul(coeff)
91
-
92
- grads = term1.add_(term2)
93
- grads.mul_(grad_output)
94
-
95
- return grads, None, None, None
96
-
97
-
98
- class FocalLossV2(nn.Module):
99
-
100
- def __init__(self,
101
- alpha=0.25,
102
- gamma=2,
103
- reduction='mean'):
104
- super(FocalLossV2, self).__init__()
105
- self.alpha = alpha
106
- self.gamma = gamma
107
- self.reduction = reduction
108
-
109
- def forward(self, logits, label):
110
- '''
111
- Usage is same as nn.BCEWithLogits:
112
- >>> criteria = FocalLossV2()
113
- >>> logits = torch.randn(8, 19, 384, 384)
114
- >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
115
- >>> loss = criteria(logits, lbs)
116
- '''
117
- loss = FocalSigmoidLossFuncV2.apply(logits, label, self.alpha, self.gamma)
118
- if self.reduction == 'mean':
119
- loss = loss.mean()
120
- if self.reduction == 'sum':
121
- loss = loss.sum()
122
- return loss
123
-
124
-
125
- if __name__ == '__main__':
126
- import torchvision
127
- import torch
128
- import numpy as np
129
- import random
130
- torch.manual_seed(15)
131
- random.seed(15)
132
- np.random.seed(15)
133
- torch.backends.cudnn.deterministic = True
134
-
135
- class Model(nn.Module):
136
- def __init__(self):
137
- super(Model, self).__init__()
138
- net = torchvision.models.resnet18(pretrained=False)
139
- self.conv1 = net.conv1
140
- self.bn1 = net.bn1
141
- self.maxpool = net.maxpool
142
- self.relu = net.relu
143
- self.layer1 = net.layer1
144
- self.layer2 = net.layer2
145
- self.layer3 = net.layer3
146
- self.layer4 = net.layer4
147
- self.out = nn.Conv2d(512, 3, 3, 1, 1)
148
- def forward(self, x):
149
- feat = self.conv1(x)
150
- feat = self.bn1(feat)
151
- feat = self.relu(feat)
152
- feat = self.maxpool(feat)
153
- feat = self.layer1(feat)
154
- feat = self.layer2(feat)
155
- feat = self.layer3(feat)
156
- feat = self.layer4(feat)
157
- feat = self.out(feat)
158
- out = F.interpolate(feat, x.size()[2:], mode='bilinear', align_corners=True)
159
- return out
160
- net1 = Model()
161
- net2 = Model()
162
- net2.load_state_dict(net1.state_dict())
163
-
164
- criteria1 = FocalLossV2()
165
- # criteria2 = FocalLossV3()
166
- net1.cuda()
167
- net2.cuda()
168
- net1.train()
169
- net2.train()
170
- net1.double()
171
- net2.double()
172
- criteria1.cuda()
173
- # criteria2.cuda()
174
-
175
- optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
176
- # optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)
177
-
178
- bs = 16
179
- for it in range(300000):
180
- inten = torch.randn(bs, 3, 224, 244).cuda()
181
- # lbs = torch.randint(0, 2, (bs, 3, 224, 244)).float().cuda()
182
- lbs = torch.randn(bs, 3, 224, 244).sigmoid().cuda()
183
- inten = inten.double()
184
- lbs = lbs.double()
185
- logits = net1(inten)
186
- loss1 = criteria1(logits, lbs)
187
- optim1.zero_grad()
188
- loss1.backward()
189
- optim1.step()
190
- # logits = net2(inten)
191
- # loss2 = criteria2(logits, lbs)
192
- # optim2.zero_grad()
193
- # loss2.backward()
194
- # optim2.step()
195
- # with torch.no_grad():
196
- # if (it+1) % 50 == 0:
197
- # print('iter: {}, ================='.format(it+1))
198
- # print('out.weight: ', torch.mean(torch.abs(net1.out.weight - net2.out.weight)).item())
199
- # print('conv1.weight: ', torch.mean(torch.abs(net1.conv1.weight - net2.conv1.weight)).item())
200
- # print('loss: ', loss1.item() - loss2.item())
modeling/gat.py DELETED
@@ -1,665 +0,0 @@
1
- #
2
- # Zhoubo
3
- #
4
- """
5
- FLASH: https://arxiv.org/abs/2202.10447
6
- """
7
- import copy
8
- import torch
9
- import math
10
- import os
11
- from collections import Sequence
12
- import json
13
- import numpy as np
14
- import torch
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
- from transformers.activations import ACT2FN
18
- from .ops import sequence_masking, XSoftmax, StableDropout, MaskedLayerNorm
19
- from .config import ModelConfig
20
- from .cache_utils import load_model_state
21
- import einops
22
-
23
-
24
- class ScaleNorm(nn.Module):
25
- def __init__(self, eps=1e-5):
26
- super().__init__()
27
- self.eps = eps
28
- self.scala = nn.Parameter(torch.ones(1))
29
-
30
- def forward(self, x):
31
- mean_square = (x ** 2).mean(dim=-1, keepdim=True)
32
- x = x * torch.rsqrt(mean_square + self.eps) * self.scala
33
- return x
34
-
35
-
36
- class BertLayerNorm(nn.Module):
37
- def __init__(self, hidden_size, eps=1e-5):
38
- """Construct a layernorm module in the TF style (epsilon inside the square root).
39
- """
40
- super(BertLayerNorm, self).__init__()
41
- self.weight = nn.Parameter(torch.ones(hidden_size))
42
- self.bias = nn.Parameter(torch.zeros(hidden_size))
43
- self.variance_epsilon = eps
44
-
45
- def forward(self, x):
46
- u = x.mean(-1, keepdim=True)
47
- s = (x - u).pow(2).mean(-1, keepdim=True)
48
- x = (x - u) / torch.sqrt(s + self.variance_epsilon)
49
- return self.weight * x + self.bias
50
-
51
-
52
- class ScaledSinuEmbedding(nn.Module):
53
- def __init__(self, dim):
54
- super().__init__()
55
- self.scale = nn.Parameter(torch.ones(1,))
56
- inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
57
- self.register_buffer('inv_freq', inv_freq)
58
-
59
- def forward(self, x):
60
- n, device = x.shape[1], x.device
61
- t = torch.arange(n, device = device).type_as(self.inv_freq)
62
- sinu = torch.einsum('i , j -> i j', t, self.inv_freq)
63
- emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
64
- return emb * self.scale
65
-
66
-
67
- def RoPE(x, dim):
68
- """
69
- :param x: input tensor
70
- :param dim: dimension(s) to operate on
71
- :return: tensor
72
- """
73
- shape = x.shape
74
- if isinstance(dim, int):
75
- dim = [dim]
76
-
77
- spatial_shape = [shape[i] for i in dim]
78
- total_len = 1
79
- for i in spatial_shape:
80
- total_len *= i
81
- position = torch.reshape(torch.arange(total_len, dtype=torch.float, device=x.device), spatial_shape)
82
-
83
- for i in range(dim[-1] + 1, len(shape) - 1, 1):
84
- position = torch.unsqueeze(position, dim=-1)
85
-
86
- half_size = shape[-1] // 2
87
- freq_seq = -torch.arange(half_size, dtype=torch.float, device=x.device) / float(half_size)
88
- inv_freq = 10000 ** -freq_seq
89
- sinusoid = torch.einsum("...,d->...d", position, inv_freq)
90
- sin = torch.sin(sinusoid)
91
- cos = torch.cos(sinusoid)
92
- x1, x2 = torch.chunk(x, 2, dim=-1)
93
- return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
94
-
95
-
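For reference, the RoPE helper above generalizes rotary position embedding to arbitrary axes. A minimal single-axis sketch for a (batch, seq_len, dim) tensor, using the conventional decaying 10000**(-k/half) frequencies (the helper's exponent sign differs slightly), illustrative only::

    import torch

    def rope_1d(x):
        # x: (batch, seq_len, dim) with dim even; rotate channel pairs by
        # position-dependent angles so dot products encode relative offsets.
        _, n, d = x.shape
        half = d // 2
        inv_freq = 10000.0 ** (-torch.arange(half, dtype=torch.float32) / half)
        angle = torch.arange(n, dtype=torch.float32)[:, None] * inv_freq[None, :]
        sin, cos = angle.sin(), angle.cos()
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)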
96
- def rel_pos_bias(seq_len, s):
97
- a = torch.rand([1, s], dtype=torch.float)
98
- b = torch.rand([1, s], dtype=torch.float)
99
- w = torch.rand([2 * seq_len - 1], dtype=torch.float)
100
- if seq_len <= 512:
101
- t = F.pad(w[: 2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
102
- t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
103
- r = (2 * seq_len - 1) // 2
104
- t = t[..., r:-r]
105
- else:
106
- a = RoPE(a.repeat(seq_len, 1), dim=[0])
107
- b = RoPE(b.repeat(seq_len, 1), dim=[0])
108
- t = torch.einsum("mk,nk->mn", a, b)
109
- return t
110
-
111
- def squared_relu(x, attention_mask, dim=-1):
112
- rmask = ~(attention_mask.bool())
113
- x = x.masked_fill(rmask, 0)
114
- return torch.square(F.relu(x))
115
-
116
-
117
- def attention_normalize(a, axis=-1, mask=None, fn='softmax'):
118
- if fn == 'softmax':
119
- return XSoftmax.apply(a, mask, axis)
120
- else:
121
- mask_ = a > -float('inf') / 10
122
- # mask_ = mask_.byte()
123
- mask_ = torch.sum(mask_, axis=axis, keepdim=True)
124
- l = torch.maximum(mask_, torch.ones_like(mask_))
125
- if fn == 'relu':
126
- rmask = ~(mask.bool())
127
- a = a.masked_fill(rmask, 0)
128
- return torch.square(F.relu(a)) / l
129
- elif fn == 'softmax_plus':
130
- return XSoftmax.apply(a * torch.log(l) / np.log(512), mask, axis)
131
- return a
132
-
133
-
134
- class GAULinear(nn.Linear):
135
- def init_weight(self):
136
- nn.init.xavier_uniform_(self.weight)
137
-
138
-
139
- class GatedAttentionUnit(nn.Module):
140
- """
141
- GAU Block: Gate Attention Unit
142
- """
143
- def __init__(
144
- self,
145
- max_seq_length,
146
- hidden_size,
147
- attention_key_size=128,
148
- activation='swish',
149
- use_bias=True,
150
- attention_norm_type='squared_relu',
151
- attention_scale=True,
152
- dropout=0.1,
153
- pre_norm=False,
154
- norm_type="layer_norm",
155
- eps=1e-5,
156
- shift_token=False,
157
- use_rel_bias=False,
158
- add_residual=True,
159
- **kwargs,):
160
-
161
- super(GatedAttentionUnit, self).__init__(**kwargs)
162
- self.max_seq_length = max_seq_length
163
- self.units = hidden_size
164
- self.intermediate_size = self.units * 2
165
- self.key_size = attention_key_size
166
- self.activation = activation
167
- self.use_bias = use_bias
168
- self.attention_norm_type = attention_norm_type
169
- self.attention_scale = attention_scale
170
- self.dropout = StableDropout(dropout)
171
- self.i_dense = nn.Sequential(
172
- nn.Linear(self.units, 2 * self.intermediate_size + self.key_size, bias=self.use_bias),
173
- nn.SiLU()
174
- )
175
- self.o_dense = nn.Sequential(
176
- nn.Linear(self.intermediate_size, self.units, bias=self.use_bias),
177
- self.dropout)
178
- self.q_scaleoffset = OffsetScale(self.key_size)
179
- self.k_scaleoffset = OffsetScale(self.key_size)
180
- self.pre_norm = pre_norm
181
- self.norm = (nn.LayerNorm(hidden_size, eps=eps) if norm_type.lower() == "layer_norm" else ScaleNorm(eps=eps))
182
- self.add_residual = add_residual
183
-
184
- def forward(self, x, attention_mask=None, **kwargs):
185
- shortcut = x
186
-
187
- if self.pre_norm:
188
- x = self.norm(x)
189
-
190
- x = self.i_dense(x)
191
- u, v, qk = torch.split(x, [self.intermediate_size, self.intermediate_size, self.key_size], dim=-1)
192
- q, k = self.q_scaleoffset(qk), self.k_scaleoffset(qk)
193
- qk = RoPE(torch.stack([q, k], 2), dim=1)
194
- q, k = qk[:, :, 0], qk[:, :, 1]
195
- a = torch.einsum('bmd,bnd->bmn', q, k)
196
- if self.attention_scale:
197
- a = a / self.key_size**0.5
198
- a = sequence_masking(a, attention_mask, '-inf', -1)
199
- A = attention_normalize(a, -1, fn=self.attention_norm_type)
200
- if self.dropout:
201
- A = self.dropout(A)
202
- out = self.o_dense(u * torch.einsum('bmn,bnd->bmd', A, v))
203
-
204
- if self.add_residual:
205
- out = out + shortcut
206
- if not self.pre_norm:
207
- out = self.norm(out)
208
- return out
209
-
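The forward pass above is single-head gated attention: one shared qk projection is scaled/offset into q and k, the score matrix is normalized via attention_normalize (which recognises 'softmax', 'relu' and 'softmax_plus'; the default string 'squared_relu' falls through unchanged), and the attended values are gated by u. A stripped-down sketch of that core, with masking, RoPE and dropout omitted and the valid-length normalizer approximated by the sequence length::

    import torch
    import torch.nn.functional as F

    def gau_core(u, v, q, k, scale=True):
        # u, v: (batch, seq, expanded_dim); q, k: (batch, seq, key_size)
        a = torch.einsum('bmd,bnd->bmn', q, k)
        if scale:
            a = a / q.size(-1) ** 0.5
        attn = F.relu(a).square() / a.size(-1)  # squared-ReLU attention, as in FLASH
        return u * torch.einsum('bmn,bnd->bmd', attn, v)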
210
-
211
- class OffsetScale(nn.Module):
212
- def __init__(self, dim, heads = 1):
213
- super().__init__()
214
- self.gamma = nn.Parameter(torch.ones(heads, dim))
215
- self.beta = nn.Parameter(torch.zeros(heads, dim))
216
- # nn.init.normal_(self.gamma, std = 0.02)
217
- nn.init.xavier_uniform_(self.gamma)
218
-
219
- def forward(self, x):
220
- out = torch.einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
221
- return out
222
-
223
-
224
- class BertIntermediate(nn.Module):
225
- def __init__(self, config):
226
- super().__init__()
227
- self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
228
- self.intermediate_act_fn = ACT2FN[config.hidden_act] \
229
- if isinstance(config.hidden_act, str) else config.hidden_act
230
-
231
- def forward(self, hidden_states):
232
- hidden_states = self.dense(hidden_states)
233
- hidden_states = self.intermediate_act_fn(hidden_states)
234
- return hidden_states
235
-
236
-
237
- class BertOutput(nn.Module):
238
- def __init__(self, config):
239
- super(BertOutput, self).__init__()
240
- self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
241
- self.LayerNorm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)
242
- self.dropout = StableDropout(config.hidden_dropout_prob)
243
- self.config = config
244
-
245
- def forward(self, hidden_states, input_states, mask=None):
246
- hidden_states = self.dense(hidden_states)
247
- hidden_states = self.dropout(hidden_states)
248
- hidden_states += input_states
249
- hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
250
- return hidden_states
251
-
252
-
253
- class GAU(nn.Module):
254
- def __init__(self, max_seq_length, hidden_size, expansion_factor=2, s=128, norm_type="layer_norm", eps=1e-5,
255
- hidden_act="silu", shift_token=False, use_rel_bias=False, attention_norm_type='softmax',
256
- pre_norm=False, dropout=0, add_residual = True):
257
- super(GAU, self).__init__()
258
- self.max_seq_length = max_seq_length
259
- self.shift_token = shift_token
260
- hidden_dim = int(expansion_factor * hidden_size)
261
- self.norm = (nn.LayerNorm(hidden_size, eps=eps) if norm_type == "layer_norm" else ScaleNorm(eps=eps))
262
- self.use_rel_bias = use_rel_bias
263
- self.attention_norm_type = attention_norm_type
264
- # if attention_norm_type == 'relu':
265
- # self.attention_norm_func = squared_relu
266
- # else:
267
- # self.attention_norm_func = XSoftmax.apply
268
- # self.norm = norm_klass(hidden_size)
269
-
270
- self.dropout = nn.Dropout(dropout)
271
-
272
- self.to_hidden = nn.Sequential(
273
- nn.Linear(hidden_size, hidden_dim * 2),
274
- nn.SiLU()
275
- )
276
-
277
- self.to_qk = nn.Sequential(
278
- nn.Linear(hidden_size, s),
279
- nn.SiLU()
280
- )
281
-
282
- self.offsetscale = OffsetScale(s, heads = 2)
283
-
284
- self.to_out = nn.Sequential(
285
- nn.Linear(hidden_dim, hidden_size),
286
- nn.Dropout(dropout)
287
- )
288
-
289
- self.add_residual = add_residual
290
- self.act_fn = ACT2FN[hidden_act]
291
- self.pre_norm = pre_norm
292
-
293
-
294
- def forward(
295
- self,
296
- x,
297
- relative_pos = None,
298
- attention_mask = None
299
- ):
300
- seq_len, device = x.shape[-2], x.device
301
- if self.pre_norm:
302
- normed_x = self.norm(x)
303
- else:
304
- normed_x = x
305
- v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)
306
-
307
- qk = self.to_qk(normed_x)
308
- base = self.offsetscale(qk)
309
- base = RoPE(base, 1).half()
310
- q, k = base.unbind(dim = -2)
311
- sim = torch.einsum('b i d, b j d -> b i j', q, k)
312
-
313
- if relative_pos is not None:
314
- sim = sim + relative_pos
315
- if attention_mask is not None:
316
- if attention_mask.dim() < 3:
317
- attention_mask = einops.rearrange(attention_mask, 'b j -> b 1 j')
318
- # attn = attn.masked_fill(~attention_mask.bool(), 0.)
319
- attn = attention_normalize(sim, mask=attention_mask, fn=self.attention_norm_type)
320
- # attn = F.relu(sim) ** 2 / seq_len# / q.size(-1)
321
- # logger.info(attn.max())
322
- attn = self.dropout(attn)
323
- # if self.causal:
324
- # causal_mask = torch.ones((seq_len, seq_len), dtype = torch.bool, device = device).triu(1)
325
- # attn = attn.masked_fill(causal_mask, 0.)
326
-
327
- out = torch.einsum('b i j, b j d -> b i d', attn.half(), v)
328
- out = out * gate
329
-
330
- out = self.to_out(out)
331
-
332
- if self.add_residual:
333
- out = out + x
334
- if not self.pre_norm:
335
- out = self.norm(out)
336
- return out
337
-
338
-
339
- class GatLayer(nn.Module):
340
- def __init__(self, config, shift_token=False, use_ffn=False):
341
- super(GatLayer, self).__init__()
342
- self.attention = GatedAttentionUnit(config.max_position_embeddings, config.hidden_size,
343
- shift_token=shift_token, use_rel_bias=config.use_rel_bias,
344
- norm_type=config.norm_type, attention_norm_type=config.attention_norm_type,
345
- pre_norm=config.pre_norm, dropout=config.hidden_dropout_prob)
346
- if use_ffn:
347
- self.intermediate = BertIntermediate(config)
348
- self.output = BertOutput(config)
349
- self.use_ffn = use_ffn
350
-
351
- def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
352
- attention_output = self.attention(hidden_states, attention_mask=attention_mask, relative_pos=relative_pos)
353
- if self.use_ffn:
354
- intermediate_output = self.intermediate(attention_output)
355
- layer_output = self.output(intermediate_output, attention_output)
356
- return layer_output
357
- else:
358
- return attention_output
359
-
360
-
361
- class RelativePositionBias(nn.Module):
362
- def __init__(
363
- self,
364
- scale,
365
- causal = False,
366
- num_buckets = 32,
367
- max_distance = 128
368
- ):
369
- super().__init__()
370
- self.scale = scale
371
- self.causal = causal
372
- self.num_buckets = num_buckets
373
- self.max_distance = max_distance
374
- self.relative_attention_bias = nn.Embedding(num_buckets, 1)
375
-
376
- @staticmethod
377
- def _relative_position_bucket(
378
- relative_position,
379
- causal = True,
380
- num_buckets = 32,
381
- max_distance = 128
382
- ):
383
- ret = 0
384
- n = -relative_position
385
- if not causal:
386
- num_buckets //= 2
387
- ret += (n < 0).long() * num_buckets
388
- n = torch.abs(n)
389
- else:
390
- n = torch.max(n, torch.zeros_like(n))
391
-
392
- max_exact = num_buckets // 2
393
- is_small = n < max_exact
394
-
395
- val_if_large = max_exact + (
396
- torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
397
- ).long()
398
- val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
399
-
400
- ret += torch.where(is_small, n, val_if_large)
401
- return ret
402
-
403
- def forward(self, x):
404
- i, j, device = *x.shape[-2:], x.device
405
- q_pos = torch.arange(i, dtype = torch.long, device = device)
406
- k_pos = torch.arange(j, dtype = torch.long, device = device)
407
- rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
408
- rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
409
- values = self.relative_attention_bias(rp_bucket)
410
- bias = einops.rearrange(values, 'i j 1 -> i j')
411
- return bias * self.scale
412
-
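The bucketing above follows the T5 scheme: half of the buckets encode the sign of the offset, small offsets get exact buckets, and larger offsets share log-spaced buckets up to max_distance. A compact restatement of the bidirectional branch (illustrative only; it clamps n before the log instead of relying on torch.where to discard the invalid branch)::

    import math
    import torch

    def relative_bucket(rel_pos, num_buckets=32, max_distance=128):
        num_buckets //= 2
        ret = (rel_pos > 0).long() * num_buckets  # sign half
        n = rel_pos.abs()
        max_exact = num_buckets // 2
        val_if_large = max_exact + (
            torch.log(n.clamp(min=1).float() / max_exact)
            / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.minimum(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        return ret + torch.where(n < max_exact, n, val_if_large)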
413
-
414
- class GatEmbeddings(nn.Module):
415
- """Construct the embeddings from word, position and token_type embeddings.
416
- """
417
- def __init__(self, config, with_position=False):
418
- super(GatEmbeddings, self).__init__()
419
- self.word_embeddings = nn.Embedding(
420
- config.vocab_size, config.hidden_size)
421
- self.token_type_embeddings = nn.Embedding(
422
- config.type_vocab_size, config.hidden_size)
423
- self.with_position = with_position
424
- if with_position:
425
- self.position_embeddings = ScaledSinuEmbedding(config.hidden_size)
426
-
427
- # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
428
- # any TensorFlow checkpoint file
429
- self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5)
430
- self.dropout = StableDropout(config.hidden_dropout_prob)
431
-
432
- def forward(self, input_ids, token_type_ids=None, position_ids=None, token_mask=None):
433
- seq_length = input_ids.size(1)
434
- if position_ids is None:
435
- position_ids = torch.arange(
436
- seq_length, dtype=torch.long, device=input_ids.device)
437
- position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
438
- if token_type_ids is None:
439
- token_type_ids = torch.zeros_like(input_ids)
440
-
441
- words_embeddings = self.word_embeddings(input_ids)
442
- if self.with_position:
443
- position_embeddings = self.position_embeddings(words_embeddings)
444
- else:
445
- position_embeddings = 0
446
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
447
-
448
- # if self.num_pos_emb > 1:
449
- # num_batch = position_embeddings.size(0)
450
- # num_pos = position_embeddings.size(1)
451
- # position_embeddings = position_embeddings.view(
452
- # num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]
453
-
454
- embeddings = words_embeddings + position_embeddings + token_type_embeddings
455
- # if self.fp32_embedding:
456
- # embeddings = embeddings.half()
457
- embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, token_mask)
458
- embeddings = self.dropout(embeddings)
459
- return {
460
- 'embeddings': embeddings,
461
- 'position_embeddings': position_embeddings}
462
-
463
-
464
- class GatEncoder(nn.Module):
465
- def __init__(self, config, shift_token=False):
466
- super().__init__()
467
- layer = GatLayer(config, shift_token=shift_token)
468
- self.layer = nn.ModuleList([copy.deepcopy(layer)
469
- for _ in range(config.num_hidden_layers)])
470
-
471
- def get_attention_mask(self, attention_mask):
472
- if attention_mask.dim() <= 2:
473
- extended_attention_mask = attention_mask.unsqueeze(1)
474
- attention_mask = extended_attention_mask*extended_attention_mask.squeeze(-2).unsqueeze(-1)
475
- attention_mask = attention_mask.byte()
476
- return attention_mask
477
-
478
- def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, return_att=False, query_states = None, relative_pos=None):
479
- all_encoder_layers = []
480
- att_matrices = []
481
- if isinstance(hidden_states, Sequence):
482
- next_kv = hidden_states[0]
483
- else:
484
- next_kv = hidden_states
485
- # rel_embeddings = self.get_rel_embedding()
486
- for i, layer_module in enumerate(self.layer):
487
- output_states = layer_module(next_kv, attention_mask, query_states = query_states, relative_pos=relative_pos)
488
- if return_att:
489
- output_states, att_m = output_states
490
-
491
- # if i == 0 and self.with_conv:
492
- # prenorm = output_states #output['prenorm_states']
493
- # output_states = self.conv(hidden_states, prenorm, input_mask)
494
-
495
- if query_states is not None:
496
- query_states = output_states
497
- if isinstance(hidden_states, Sequence):
498
- next_kv = hidden_states[i+1] if i+1 < len(self.layer) else None
499
- else:
500
- next_kv = output_states
501
-
502
- if output_all_encoded_layers:
503
- all_encoder_layers.append(output_states)
504
- if return_att:
505
- att_matrices.append(att_m)
506
- if not output_all_encoded_layers:
507
- all_encoder_layers.append(output_states)
508
- if return_att:
509
- att_matrices.append(att_m)
510
- return {
511
- 'hidden_states': all_encoder_layers,
512
- 'attention_matrices': att_matrices
513
- }
514
-
515
-
516
- class GatModel(torch.nn.Module):
517
- """
518
- Parameters:
519
- config:
520
- A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`,
521
-
522
- pre_trained:
523
- The pre-trained DeBERTa model; it can be the physical path of a pre-trained DeBERTa model or one of the released configurations,
524
- i.e. [**base, large, base_mnli, large_mnli**]
525
-
526
- """
527
-
528
- def __init__(self, config=None, pre_trained=None, pooler=False, shift_token=False, causal=False, **kwargs):
529
- super().__init__()
530
- state = None
531
- if pre_trained is not None:
532
- state, model_config = load_model_state(pre_trained)
533
- if config is not None and model_config is not None:
534
- for k in config.__dict__:
535
- if k not in ['hidden_size',
536
- 'intermediate_size',
537
- 'num_attention_heads',
538
- 'num_hidden_layers',
539
- 'vocab_size',
540
- 'max_position_embeddings']:
541
- model_config.__dict__[k] = config.__dict__[k]
542
- config = copy.copy(model_config)
543
- self.embeddings = GatEmbeddings(config, with_position=True)
544
- self.encoder = GatEncoder(config, shift_token=shift_token)
545
- if not pooler:
546
- self.pooler = None
547
- self.config = config
548
- self.pre_trained = pre_trained
549
- self.apply_state(state)
550
-
551
- def get_attention_mask(self, input_ids=None, token_type_ids=None, attention_mask=None, input_mask=None):
552
- if attention_mask is None:
553
- if input_mask is not None:
554
- return input_mask.unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), input_mask.size(1))
555
- else:
556
- return torch.ones_like(input_ids, dtype=torch.uint8).unsqueeze(-1).expand(input_ids.size(0), input_ids.size(1), input_ids.size(1))  # input_mask is None in this branch, so sizes must come from input_ids
557
- else:
558
- if attention_mask.dim() == 2:
559
- if input_mask is not None:
560
- attention_mask = attention_mask * input_mask
561
- return attention_mask.unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), attention_mask.size(-1))
562
- if attention_mask.dim() == 4:
563
- attention_mask = attention_mask.squeeze(2)
564
- if attention_mask.dim() == 3:
565
- if input_mask is not None:
566
- return attention_mask * input_mask.unsqueeze(-1).expand(input_mask.size(0), input_mask.size(1), attention_mask.size(-1))
567
- else:
568
- return attention_mask
569
-
570
-
571
- def forward(self, input_ids, input_mask, attention_mask=None, token_type_ids=None,
572
- output_all_encoded_layers=True, position_ids=None, return_att=False):
573
- """
574
- Args:
575
- input_ids:
576
- a torch.LongTensor of shape [batch_size, sequence_length] \
577
- with the word token indices in the vocabulary
578
-
579
- attention_mask:
580
- an optional parameter for input mask or attention mask.
581
-
582
- - If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices \
583
- selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
584
- input sequence length in the current batch. It's the mask that we typically use for attention when \
585
- a batch has varying length sentences.
586
-
587
- - If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
588
- In this case, it's a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.
589
-
590
- token_type_ids:
591
- an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
592
- types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
593
- a `sentence B` token (see BERT paper for more details).
594
-
595
- output_all_encoded_layers:
596
- whether to output results of all encoder layers; default: True
597
-
598
- Returns:
599
-
600
- - The output of the stacked transformer layers if `output_all_encoded_layers=True`, else \
601
- the last layer of stacked transformer layers
602
-
603
- - Attention matrix of self-attention layers if `return_att=True`
604
-
605
-
606
- Example::
607
-
608
- # Batch of wordPiece token ids.
609
- # Each sample was padded with zero to the maximum length of the batch
610
- input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
611
- # Mask of valid input ids
612
- attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
613
-
614
- # DeBERTa model initialized with pretrained base model
615
- bert = DeBERTa(pre_trained='base')
616
-
617
- encoder_layers = bert(input_ids, attention_mask=attention_mask)
618
-
619
- """
620
- if token_type_ids is None:
621
- token_type_ids = torch.zeros_like(input_ids)
622
- # input_mask = torch.ones_like(input_ids)
623
-
624
- if input_mask is None:
625
- idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), axis=1), [-1])
626
- input_mask = idxs > 0
627
- if not torch.any(input_mask):
628
- input_mask = torch.ones_like(input_ids)
629
- input_mask = input_mask.byte()
630
- attention_mask = self.get_attention_mask(input_ids, token_type_ids, attention_mask, input_mask)
631
- attention_mask = attention_mask.byte()
632
- embedding_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, input_mask)
633
- encoder_output = self.encoder(embedding_output['embeddings'], attention_mask, output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
634
- encoder_output.update(embedding_output)
635
- return encoder_output
636
-
637
- def apply_state(self, state = None):
638
- """ Load state from previous loaded model state dictionary.
639
-
640
- Args:
641
- state (:obj:`dict`, optional): State dictionary as the state returned by torch.module.state_dict(), default: `None`. \
642
- If it's `None`, then the method will use the pre-trained state loaded via the constructor to re-initialize \
643
- the `DeBERTa` model
644
- """
645
- if self.pre_trained is None and state is None:
646
- return
647
- if state is None:
648
- state, config = load_model_state(self.pre_trained)
649
- self.config = config
650
-
651
- prefix = ''
652
- for k in state:
653
- if 'embeddings.' in k:
654
- if not k.startswith('embeddings.'):
655
- prefix = k[:k.index('embeddings.')]
656
- break
657
-
658
- missing_keys = []
659
- unexpected_keys = []
660
- error_msgs = []
661
- self._load_from_state_dict(state, prefix = prefix, local_metadata=None, strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)
662
-
663
-
664
- if __name__ == '__main__':
665
- model = GatModel(ModelConfig())  # smoke test; the constructor expects a config object, not bare ints (assumes ModelConfig() provides usable defaults)
 
modeling/mlm.py DELETED
@@ -1,38 +0,0 @@
1
- # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2
- # Copyright (c) Microsoft, Inc. 2020
3
- #
4
- # This source code is licensed under the MIT license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- # This piece of code is modified based on https://github.com/huggingface/transformers
8
-
9
- import torch
10
- from torch import nn
11
- import pdb
12
-
13
- from .bert import LayerNorm, ACT2FN
- from .ops import MaskedLayerNorm  # used in forward() below; exported by .ops
14
-
15
- __all__ = ['MLMPredictionHead']
16
-
17
- class MLMPredictionHead(nn.Module):
18
- def __init__(self, config, vocab_size):
19
- super().__init__()
20
- self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
21
- self.dense = nn.Linear(config.hidden_size, self.embedding_size)
22
- self.transform_act_fn = ACT2FN[config.hidden_act] \
23
- if isinstance(config.hidden_act, str) else config.hidden_act
24
-
25
- self.LayerNorm = LayerNorm(self.embedding_size, config.layer_norm_eps)
26
- self.bias = nn.Parameter(torch.zeros(vocab_size))
27
- self.pre_norm = PreLayerNorm(config)  # NOTE: PreLayerNorm is not imported above; it must be provided by the package
28
-
29
- def forward(self, hidden_states, embeding_weight):
30
- hidden_states = self.pre_norm(hidden_states)
31
- hidden_states = self.dense(hidden_states)
32
- hidden_states = self.transform_act_fn(hidden_states)
33
- # b x s x d
34
- hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
35
-
36
- # b x s x v
37
- logits = torch.matmul(hidden_states, embeding_weight.t().to(hidden_states)) + self.bias
38
- return logits
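The head ties its output projection to the input embedding matrix instead of learning a separate vocab-size projection; the last line is just a matrix product against the transposed embedding table plus a per-token bias. A minimal sketch of that weight tying (names illustrative)::

    import torch

    def tied_lm_logits(hidden, embedding_weight, bias):
        # hidden: (batch, seq, emb); embedding_weight: (vocab, emb); bias: (vocab,)
        return torch.matmul(hidden, embedding_weight.t().to(hidden)) + bias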
 
modeling/modeling.py DELETED
The diff for this file is too large to render. See raw diff
 
modeling/nnmodule.py DELETED
@@ -1,184 +0,0 @@
1
- import pdb
2
- import os
3
- import torch
4
- import copy
5
- from torch import nn, tensor
6
- from .config import ModelConfig
7
- from ..utils import xtqdm as tqdm
8
- from .cache_utils import load_model_state
9
- from .flash import GAULinear
10
-
11
- from ..utils import get_logger
12
- logger = get_logger()
13
-
14
- __all__ = ['NNModule']
15
-
16
- def truncated_normal_(shape, mean=0, std=0.09):
17
- with torch.no_grad():
18
- tensor = torch.zeros(shape)
19
- tmp = tensor.new_empty(shape + (4,)).normal_()
20
- valid = (tmp < 2) & (tmp > -2)
21
- ind = valid.max(-1, keepdim=True)[1]
22
- tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
23
- tensor.data.mul_(std).add_(mean)
24
- return tensor
25
-
26
- class NNModule(nn.Module):
27
- """ An abstract class to handle weights initialization and \
28
- a simple interface for downloading and loading pretrained models.
29
-
30
- Args:
31
-
32
- config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config to the module
33
-
34
- """
35
-
36
- def __init__(self, config, *inputs, **kwargs):
37
- super().__init__()
38
- self.config = config
39
-
40
- def init_weights(self, module):
41
- """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.
42
-
43
- Args:
44
-
45
- module (:obj:`torch.nn.Module`): The module to apply the initialization.
46
-
47
- Example::
48
-
49
- class MyModule(NNModule):
50
- def __init__(self, config):
51
- # Add construction instructions
52
- self.bert = DeBERTa(config)
53
-
54
- # Add other modules
55
- ...
56
-
57
- # Apply initialization
58
- self.apply(self.init_weights)
59
-
60
- """
61
- if isinstance(module, (nn.Linear, nn.Embedding)):
62
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
63
- if isinstance(module, nn.Linear) and module.bias is not None:
64
- module.bias.data.zero_()
65
-
66
- def init_weights_gau(self, module):
67
- """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.
68
-
69
- Args:
70
-
71
- module (:obj:`torch.nn.Module`): The module to apply the initialization.
72
-
73
- Example::
74
-
75
- class MyModule(NNModule):
76
- def __init__(self, config):
77
- # Add construction instructions
78
- self.bert = DeBERTa(config)
79
-
80
- # Add other modules
81
- ...
82
-
83
- # Apply initialization
84
- self.apply(self.init_weights)
85
-
86
- """
87
- if isinstance(module, GAULinear):
88
- module.init_weight()
89
- else:
90
- if isinstance(module, (nn.Linear, nn.Embedding)):
91
- # module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
92
- module.weight.data.copy_(self.initializer(module.weight.data.shape))
93
- if isinstance(module, nn.Linear) and module.bias is not None:
94
- module.bias.data.zero_()
95
-
96
- def initializer(self, shape, dtype=None, order=3, gain=1.0):
97
- if shape[1] > 10000 or shape[1] < 10:
98
- hidden_size = shape[0]
99
- else:
100
- hidden_size = shape[1]
101
- gain *= self.config.num_hidden_layers ** (-1.0 / order)
102
- stddev = 1.13684723 / hidden_size**0.5 * gain
103
- return torch.nn.init.trunc_normal_(torch.empty(shape, dtype=dtype), std=stddev)# truncated_normal_(shape, std=stddev)
104
-
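For a sense of scale of the initializer above (assumed values, for illustration only)::

    # hidden_size=768, num_hidden_layers=12, order=3:
    # gain = 12 ** (-1/3) ~= 0.437; std ~= 1.13684723 / 768 ** 0.5 * 0.437 ~= 0.018
    # i.e. deeper stacks get proportionally smaller truncated-normal weights.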
105
- @classmethod
106
- def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None , *inputs, **kwargs):
107
- """ Instantiate a sub-class of NNModule from a pre-trained model file.
108
-
109
- Args:
110
-
111
- model_path (:obj:`str`): Path or name of the pre-trained model, which can be either:
112
-
113
- - The path of pre-trained model
114
-
115
- - The pre-trained DeBERTa model name in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, i.e. [**base, base_mnli, large, large_mnli**].
116
-
117
- If `model_path` is `None` or `-`, then the method will create a new sub-class without initializing it from pre-trained models.
118
-
119
- model_config (:obj:`str`): The path of the model config file. If it's `None`, then the method will try to find the config in the following order:
120
-
121
- 1. ['config'] in the model state dictionary.
122
-
123
- 2. `model_config.json` aside the `model_path`.
124
-
125
- If it failed to find a config the method will fail.
126
-
127
- tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`.
128
-
129
- no_cache (:obj:`bool`, optional): Disable local cache of downloaded models, default: `False`.
130
-
131
- cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, then the models will be saved at `$HOME/.~DeBERTa`
132
-
133
- Return:
134
-
135
- :obj:`NNModule` : The sub-class object.
136
-
137
- """
138
- # Load config
139
- if model_config:
140
- config = ModelConfig.from_json_file(model_config)
141
- else:
142
- config = None
143
- model_config = None
144
- model_state = None
145
- if (model_path is not None) and (model_path.strip() == '-' or model_path.strip()==''):
146
- model_path = None
147
- try:
148
- model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
149
- except Exception as exp:
150
- raise Exception(f'Failed to get model {model_path}. Exception: {exp}')
151
-
152
- if config is not None and model_config is not None:
153
- for k in config.__dict__:
154
- if k not in ['hidden_size',
155
- 'intermediate_size',
156
- 'num_attention_heads',
157
- 'num_hidden_layers',
158
- 'vocab_size',
159
- 'max_position_embeddings'] or (k not in model_config.__dict__) or (model_config.__dict__[k] < 0):
160
- model_config.__dict__[k] = config.__dict__[k]
161
- if model_config is not None:
162
- config = copy.copy(model_config)
163
- vocab_size = config.vocab_size
164
- # Instantiate model.
165
- model = cls(config, *inputs, **kwargs)
166
- if not model_state:
167
- return model
168
- # copy state_dict so _load_from_state_dict can modify it
169
- state_dict = model_state.copy()
170
-
171
- missing_keys = []
172
- unexpected_keys = []
173
- error_msgs = []
174
- metadata = getattr(state_dict, '_metadata', None)
175
- def load(module, prefix=''):
176
- local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
177
- module._load_from_state_dict(
178
- state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
179
- for name, child in module._modules.items():
180
- if child is not None:
181
- load(child, prefix + name + '.')
182
- load(model)
183
- logger.warning(f'Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}, error_msgs: {error_msgs}')
184
- return model
 
modeling/ops.py CHANGED
@@ -7,12 +7,10 @@
7
  # Date: 01/15/2020
8
  #
9
 
10
- import pdb
11
  import math
12
  from packaging import version
13
  import torch
14
  from torch.nn import LayerNorm
15
- from wywLM.utils.jit_tracing import traceable
16
 
17
  if version.Version(torch.__version__) >= version.Version('1.0.0'):
18
  from torch import _softmax_backward_data as _softmax_backward_data
@@ -21,7 +19,7 @@ else:
21
 
22
  __all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax', 'ACT2FN', 'LayerNorm']
23
 
24
- @traceable
25
  class XSoftmax(torch.autograd.Function):
26
  """ Masked Softmax which is optimized for saving memory
27
 
@@ -113,7 +111,7 @@ def get_mask(input, local_context):
113
 
114
  return mask, dropout
115
 
116
- @traceable
117
  class XDropout(torch.autograd.Function):
118
  @staticmethod
119
  def forward(ctx, input, local_ctx):
 
7
  # Date: 01/15/2020
8
  #
9
 
 
10
  import math
11
  from packaging import version
12
  import torch
13
  from torch.nn import LayerNorm
 
14
 
15
  if version.Version(torch.__version__) >= version.Version('1.0.0'):
16
  from torch import _softmax_backward_data as _softmax_backward_data
 
19
 
20
  __all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax', 'ACT2FN', 'LayerNorm']
21
 
22
+
23
  class XSoftmax(torch.autograd.Function):
24
  """ Masked Softmax which is optimized for saving memory
25
 
 
111
 
112
  return mask, dropout
113
 
114
+
115
  class XDropout(torch.autograd.Function):
116
  @staticmethod
117
  def forward(ctx, input, local_ctx):
modeling/pretrained_models.py DELETED
@@ -1,2 +0,0 @@
1
-
2
-
 
 
 
modeling/wywlm_modeling.py DELETED
@@ -1,446 +0,0 @@
1
- # Copyright (c) Microsoft, Inc. 2020
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
- #
6
- # Zhou Bo
7
- # Date: 01/15/2020
8
- #
9
-
10
- import copy
11
- import torch
12
- import os
13
- import random
14
-
15
- import json
16
- from .ops import *
17
- from .bert import *
18
- from .bert import BertLayer
19
- from .config import ModelConfig
20
- from .cache_utils import load_model_state
21
- from .nnmodule import NNModule
22
-
23
- # from ..utils.bad_grad_viz import register_hooks
24
-
25
- __all__ = ['WywLM']
26
-
27
- def flatten_states(q_states, mask_index):
28
- q_states = q_states.reshape((-1, q_states.size(-1)))
29
- q_states = q_states.index_select(0, mask_index)
30
- return q_states
31
-
32
-
33
- class UGDecoder(torch.nn.Module):
34
- def __init__(self, config, vocab_size):
35
- super().__init__()
36
- self.config = config
37
- self.position_biased_input = getattr(config, 'position_biased_input', True)
38
- # self.layer = torch.nn.ModuleList([BertLayer(config) for _ in range(2)])
39
-
40
- # self.causal_mask = torch.tril(torch.ones((input_ids.dim(0), input_ids.dim(1), input_ids.dim(1))), diagonal=0)
41
-
42
- def forward(self, ctx_layers, word_embedding, input_ids, z_states, attention_mask, \
43
- encoder, target_ids=None, relative_pos=None, decode=False, s2s_idx=None):
44
- causal_outputs, lm_outputs = self.emd_context_layer(ctx_layers, z_states, attention_mask,
45
- encoder, target_ids, input_ids,
46
- relative_pos=relative_pos, decode=decode,
47
- word_embedding=word_embedding, s2s_idx=s2s_idx)
48
- # loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
49
-
50
- # ctx_layer = mlm_ctx_layers[-1]
51
-
52
- # lm_logits = lm_logits.view(-1, lm_logits.size(-1))
53
-
54
- return causal_outputs[-1], lm_outputs[-1]
55
-
56
- def emd_context_layer(self, encoder_layers, z_states, attention_mask, encoder, target_ids, input_ids,\
57
- relative_pos=None, decode=False, word_embedding=None, s2s_idx=None):
58
- # if decode:
59
- # attention_mask = torch.tril(torch.ones((input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])), diagonal=0).to(input_ids.device)
60
- # else:
61
- if attention_mask.dim()<=2:
62
- extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
63
- att_mask = extended_attention_mask.byte()
64
- attention_mask = att_mask*att_mask.squeeze(-2).unsqueeze(-1)
65
- elif attention_mask.dim()==3:
66
- attention_mask = attention_mask.unsqueeze(1)
67
-
68
-
69
- if not self.position_biased_input:
70
-
71
-
72
- lm_outputs = []
73
- # else:
74
- hidden_states = encoder_layers[-2]
75
- layers = [encoder.layer[-1] for _ in range(2)]
76
- z_states += hidden_states
77
- query_states = z_states
78
- query_mask = attention_mask
79
- rel_embeddings = encoder.get_rel_embedding()
80
- for layer in layers:
81
- # TODO: pass relative pos ids
82
- output = layer(hidden_states, query_mask, return_att=False,
83
- query_states=query_states, relative_pos=relative_pos,
84
- rel_embeddings=rel_embeddings)
85
- query_states = output
86
- lm_outputs.append(query_states)
87
-
88
- # if decode:
89
- attention_mask = torch.tril(torch.ones((input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])),
90
- diagonal=0).to(input_ids.device)
91
- causal_outputs = []
92
- # with torch.no_grad():
93
- target_embd = word_embedding(target_ids)
94
-
95
- target_embd += z_states.detach()
96
- # self attention of target
97
- output = layers[-2](target_embd, attention_mask, return_att=False,
98
- query_states=target_embd, relative_pos=relative_pos,
99
- rel_embeddings=encoder.get_rel_embedding())
100
- causal_outputs.append(output)
101
- # cross attention
102
- output = layers[-1](output, attention_mask, return_att=False,
103
- query_states=query_states, relative_pos=relative_pos,
104
- rel_embeddings=encoder.get_rel_embedding())
105
- causal_outputs.append(output)
106
-
107
- else:
108
- causal_outputs = [encoder_layers[-1]]
109
- lm_outputs = [encoder_layers[-1]]
110
- return causal_outputs, lm_outputs
111
-
112
-
113
- def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
114
- """
115
- Shift input ids one token to the right.
116
- """
117
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
118
- shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
119
- shifted_input_ids[:, 0] = decoder_start_token_id
120
-
121
- if pad_token_id is None:
122
- raise ValueError("self.model.config.pad_token_id has to be defined.")
123
- # replace possible -100 values in labels by `pad_token_id`
124
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
125
-
126
- return shifted_input_ids
127
-
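A quick usage example of the shift above (token ids are illustrative: pad_token_id=0, decoder_start_token_id=1)::

    import torch

    labels = torch.tensor([[5, -100, 7, 8]])
    decoder_inputs = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=1)
    # labels [[5, -100, 7, 8]] -> decoder inputs [[1, 5, 0, 7]]:
    # everything moves one step right, position 0 becomes the start token,
    # and the ignored label -100 is replaced by the pad id.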
128
-
129
- class WywLMLoss(torch.nn.Module):
130
- def __init__(self, config) -> None:
131
- super().__init__()
132
- self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
133
- hidden_size = getattr(config, 'embedding_size', config.hidden_size)
134
- self.compare = torch.nn.Linear(hidden_size * 3, 2)
135
- # self.mlm_head = BertLMPredictionHead(config, config.vocab_size)
136
- self.lm_head = BertLMPredictionHead(config, config.vocab_size)
137
-
138
- def forward(self, logits, lm_logits, target_ids, dict_pos, input_ids, target_ids_s2s, decode=False, ebd_weight=None, task=0):
139
- loss_compare = torch.tensor(0).to(logits).float()
140
- mlm_loss = torch.tensor(0).to(logits).float()
141
- lm_loss = torch.tensor(0).to(logits).float()
142
-
143
- # else:
144
- if task == 1:
145
- compare_logits = []
146
- compare_labels = []
147
- for bi, sampel_pos in enumerate(dict_pos):
148
- num_pos = int((sampel_pos > 0).sum().detach().cpu().numpy() / 4) - 1
149
- if num_pos <= 1:
150
- continue
151
- for pi in range(num_pos):
152
- pos = sampel_pos[pi]
153
- entry_logits = logits[bi][pos[0]: pos[1]]
154
- desc_logits = logits[bi][pos[2]: pos[3]]
155
- neg_num = random.randint(0, num_pos) # torch.randint(low=0, high=num_pos, size=(1,))
156
- ids_neg = input_ids[bi][sampel_pos[neg_num][0]: sampel_pos[neg_num][1]]
157
- ids_pos = input_ids[bi][pos[0]: pos[1]]
158
- if neg_num == pi or (ids_neg.shape == ids_pos.shape and torch.all(ids_neg == ids_pos)):
159
- neg_num = -1
160
- for ni in range(num_pos):
161
- neg_num = random.randint(0, num_pos)# torch.randint(low=0, high=num_pos, size=(1,))
162
- ids_neg = input_ids[bi][sampel_pos[neg_num][0]: sampel_pos[neg_num][1]]
163
- if neg_num != pi and (ids_neg.shape != ids_pos.shape or not torch.all(ids_neg == ids_pos)):
164
- break
165
- else:
166
- neg_num = -1
167
- if neg_num == -1:
168
- continue
169
- neg_desc_logits = logits[bi][sampel_pos[neg_num][2]: sampel_pos[neg_num][3]]
170
- if torch.any(torch.isnan(neg_desc_logits)):
171
- print('error')
172
- entry_logits = entry_logits.mean(dim=0, keepdim=True).float()
173
- desc_logits = desc_logits.mean(dim=0, keepdim=True).float()
174
- neg_desc_logits = neg_desc_logits.mean(dim=0, keepdim=True).float()
175
- compare_logits.append(torch.concat([entry_logits, desc_logits, entry_logits - desc_logits], dim=1))
176
- compare_logits.append(torch.concat([entry_logits, neg_desc_logits, entry_logits - neg_desc_logits], dim=1))
177
- compare_labels += [1, 0]
178
- if len(compare_logits) > 0:
179
- compare_logits = torch.concat(compare_logits, dim=0).to(logits.dtype)
180
- compare_pred = self.compare(compare_logits)
181
- loss_compare = self.loss_fn(compare_pred, torch.tensor(compare_labels, dtype=torch.long, device=compare_logits.device)).mean()
182
-
183
- if torch.all(loss_compare == 0):
184
- entry_logits = logits[0][0].unsqueeze(0)
185
- compare_logits = torch.concat([entry_logits, entry_logits, entry_logits - entry_logits], dim=1)
186
- compare_pred = self.compare(compare_logits)
187
- compare_labels = [1]
188
- loss_compare = self.loss_fn(compare_pred, torch.tensor(compare_labels, dtype=torch.long, device=compare_logits.device)).mean()
189
-
190
- # if decode:
191
- # lm_labels = target_ids_s2s.index_select(0, (target_ids_s2s.sum(-1) > 0).nonzero().view(-1)[0])
192
- # lm_labels = lm_labels.repeat(logits.shape[0], 1).clone().view(-1)
193
- # lm_labels = target_ids_s2s.clone()
194
- # target_ids_s2s = shift_tokens_right(target_ids_s2s, 0, 1)
195
- # target_ids_s2s.masked_fill_(target_ids_s2s==0, 3)
196
- if task == 0:
197
- _mask_index = (target_ids_s2s > 0).view(-1).nonzero().view(-1)
198
- lm_logits_ = flatten_states(lm_logits, _mask_index)
199
- lm_pred = self.lm_head(lm_logits_, ebd_weight).float()
200
- lm_labels = target_ids_s2s.clone().reshape(-1)
201
- lm_labels = lm_labels.index_select(0, _mask_index)
202
- # lm_pred = torch.nn.functional.log_softmax(lm_pred)
203
- # lm_loss = torch.nn.functional.nll_loss(lm_pred, lm_labels.long())
204
- lm_loss = self.loss_fn(lm_pred, lm_labels.long())
205
- # dot = register_hooks(lm_loss)
206
- # lm_loss.backward()
207
- # dot().save('tmp.dot')
208
-
209
-
210
- _mask_index = (target_ids > 0).view(-1).nonzero().view(-1)
211
- mlm_logits = flatten_states(logits, _mask_index)
212
- mlm_pred = self.lm_head(mlm_logits, ebd_weight).float()
213
- mlm_labels = target_ids.view(-1)
214
- mlm_labels = mlm_labels.index_select(0, _mask_index)
215
- mlm_loss = self.loss_fn(mlm_pred, mlm_labels.long())
216
- return loss_compare, mlm_loss, lm_loss
217
-
218
- class WywLM(torch.nn.Module):
219
- """ DeBERTa encoder
220
- This module is composed of the input embedding layer with stacked transformer layers with disentangled attention.
221
-
222
- Parameters:
223
- config:
224
- A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \
225
- for more details, please refer to :class:`~DeBERTa.deberta.ModelConfig`
226
-
227
- pre_trained:
228
- The pre-trained DeBERTa model; it can be the physical path of a pre-trained DeBERTa model or one of the released configurations, \
229
- i.e. [**base, large, base_mnli, large_mnli**]
230
-
231
- """
232
-
233
- def __init__(self, config=None, pre_trained=None):
234
- super().__init__()
235
- state = None
236
- if pre_trained is not None:
237
- state, model_config = load_model_state(pre_trained)
238
- if config is not None and model_config is not None:
239
- for k in config.__dict__:
240
- if k not in ['hidden_size',
241
- 'intermediate_size',
242
- 'num_attention_heads',
243
- 'num_hidden_layers',
244
- 'vocab_size',
245
- 'max_position_embeddings']:
246
- model_config.__dict__[k] = config.__dict__[k]
247
- config = copy.copy(model_config)
248
- self.embeddings = BertEmbeddings(config)
249
- self.encoder = BertEncoder(config)
250
- self.config = config
251
- self.pre_trained = pre_trained
252
- self.apply_state(state)
253
-
254
- def forward(self, input_ids, attention_mask=None, token_type_ids=None, output_all_encoded_layers=True, position_ids = None, return_att = False):
255
- """
256
- Args:
257
- input_ids:
258
- a torch.LongTensor of shape [batch_size, sequence_length] \
259
- with the word token indices in the vocabulary
260
-
261
- attention_mask:
262
- an optional parameter for input mask or attention mask.
263
-
264
- - If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices \
265
- selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
266
- input sequence length in the current batch. It's the mask that we typically use for attention when \
267
- a batch has varying length sentences.
268
-
269
- - If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
270
- In this case, it's a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.
271
-
272
- token_type_ids:
273
- an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
274
- types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
275
- a `sentence B` token (see BERT paper for more details).
276
-
277
- output_all_encoded_layers:
278
- whether to output results of all encoder layers; default: True
279
-
280
- Returns:
281
-
282
- - The output of the stacked transformer layers if `output_all_encoded_layers=True`, else \
283
- the last layer of stacked transformer layers
284
-
285
- - Attention matrix of self-attention layers if `return_att=True`
286
-
287
-
288
- Example::
289
-
290
- # Batch of wordPiece token ids.
291
- # Each sample was padded with zero to the maximum length of the batch
292
- input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
293
- # Mask of valid input ids
294
- attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
295
-
296
- # DeBERTa model initialized with pretrained base model
297
- bert = DeBERTa(pre_trained='base')
298
-
299
- encoder_layers = bert(input_ids, attention_mask=attention_mask)
300
-
301
- """
302
-
303
- if attention_mask is None:
304
- attention_mask = torch.ones_like(input_ids)
305
- if token_type_ids is None:
306
- token_type_ids = torch.zeros_like(input_ids)
307
- token_mask = torch.ones_like(input_ids)
308
- else:
309
- idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), axis=1), [-1])
310
- token_mask = idxs > 0
311
- token_mask = token_mask.byte()
312
- ebd_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, token_mask)
313
- embedding_output = ebd_output['embeddings']
314
- encoder_output = self.encoder(embedding_output,
315
- attention_mask,
316
- output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
317
- encoder_output.update(ebd_output)
318
- return encoder_output
319
-
320
- def apply_state(self, state = None):
321
- """ Load state from previous loaded model state dictionary.
322
-
323
- Args:
324
- state (:obj:`dict`, optional): State dictionary as the state returned by torch.module.state_dict(), default: `None`. \
325
- If it's `None`, then the method will use the pre-trained state loaded via the constructor to re-initialize \
326
- the `DeBERTa` model
327
- """
328
- if self.pre_trained is None and state is None:
329
- return
330
- if state is None:
331
- state, config = load_model_state(self.pre_trained)
332
- self.config = config
333
-
334
- prefix = ''
335
- for k in state:
336
- if 'embeddings.' in k:
337
- if not k.startswith('embeddings.'):
338
- prefix = k[:k.index('embeddings.')]
339
- break
340
-
341
- missing_keys = []
342
- unexpected_keys = []
343
- error_msgs = []
344
- self._load_from_state_dict(state, prefix = prefix, local_metadata=None, strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)
345
-
346
-
347
- class MaskedLanguageModel(NNModule):
348
- """ Masked language model
349
- """
350
- def __init__(self, config, *wargs, **kwargs):
351
- super().__init__(config)
352
- self.backbone = WywLM(config)
353
-
354
- self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
355
- self.position_buckets = getattr(config, 'position_buckets', -1)
356
- if self.max_relative_positions <1:
357
- self.max_relative_positions = config.max_position_embeddings
358
- # self.mlm_predictions = UGDecoder(self.backbone.config, self.backbone.embeddings.word_embeddings.weight.size(0))
359
- self.lm_predictions = UGDecoder(self.backbone.config, self.backbone.embeddings.word_embeddings.weight.size(0))
360
- self.device = None
361
- self.loss = WywLMLoss(config)
362
- # self.loss_lm = WywLMLoss(config)
363
- self.apply(self.init_weights)
364
-
365
- def forward(self, samples, position_ids=None):
366
- task = samples['task']
367
- if task == 0:
368
- input_ids = samples['s2s_input_ids']
369
- type_ids = samples['s2s_token_type_ids']
370
- attention_mask = samples['s2s_attention_mask']
371
- labels = samples['s2s_masked_lm_labels']
372
- dict_pos = samples['dict_pos']
373
- s2s_label = samples['s2s_label']
374
- else:
375
- input_ids = samples['input_ids']
376
- type_ids = samples['token_type_ids']
377
- attention_mask = samples['attention_mask']
378
- labels = samples['masked_lm_labels']
379
- dict_pos = samples['dict_pos']
380
- s2s_label = samples['s2s_label']
381
-
382
- if self.device is None:
383
- self.device = list(self.parameters())[0].device
384
-
385
- input_ids = input_ids.to(self.device)
386
-
387
- type_ids = None
388
- lm_labels = labels.to(self.device)
389
- s2s_label = s2s_label.to(self.device)
390
- attention_mask = attention_mask.to(self.device)
391
-
392
- encoder_output = self.backbone(input_ids, attention_mask, type_ids, output_all_encoded_layers=True, position_ids = position_ids)
393
- encoder_layers = encoder_output['hidden_states']
394
- z_states = encoder_output['position_embeddings']
395
- ctx_layer = encoder_layers[-1]
396
- mlm_loss = torch.tensor(0).to(ctx_layer).float()
397
- lm_loss = torch.tensor(0).to(ctx_layer).float()
398
- lm_logits = None
399
- label_inputs = None
400
- loss = torch.tensor(0).to(ctx_layer).float()
401
- loss_compare = torch.tensor(0).to(ctx_layer).float()
402
-
403
- ebd_weight = self.backbone.embeddings.word_embeddings.weight
404
- lm_logits, mlm_logits = self.lm_predictions(encoder_layers, self.backbone.embeddings.word_embeddings,
405
- input_ids, z_states,
406
- attention_mask, self.backbone.encoder,
407
- target_ids=lm_labels)
408
- # if lm_labels.detach().sum() != 0:
409
- loss_compare, mlm_loss, lm_loss = self.loss(mlm_logits,
410
- lm_logits,
411
- lm_labels,
412
- dict_pos,
413
- target_ids_s2s=s2s_label,
414
- decode=False,
415
- ebd_weight=ebd_weight,
416
- input_ids=input_ids,
417
- task=task)
418
- loss = loss_compare * 10 + mlm_loss + lm_loss
419
- # if s2s_label.detach().sum() != 0:
420
- # s2s_idx = (s2s_label.sum(-1)>0).nonzero().view(-1)
421
- # s2s_label = s2s_label.index_select(0, s2s_idx)
422
- # # ebd_weight = self.backbone.embeddings.word_embeddings.weight
423
- # # lm_logits = self.lm_predictions(encoder_layers[-3], self.backbone.embeddings.word_embeddings,
424
- # # input_ids.index_select(0, s2s_idx), z_states.index_select(0, s2s_idx),
425
- # # attention_mask.index_select(0, s2s_idx), self.backbone.encoder,
426
- # # target_ids=s2s_label,
427
- # # decode=True, s2s_idx=s2s_idx)
428
- # # lm_logits = encoder_layers[-1].detach().index_select(0, s2s_idx)
429
- # _, lm_loss = self.loss_lm(lm_logits,
430
- # s2s_label,
431
- # torch.zeros_like(dict_pos),
432
- # decode=True,
433
- # ebd_weight=ebd_weight,
434
- # input_ids=input_ids.index_select(0, s2s_idx))
435
- # lm_loss = lm_logits.max()
436
- # loss = loss + lm_loss
437
-
438
- return {
439
- 'logits' : lm_logits,
440
- 'labels' : lm_labels,
441
- 's2s_label': s2s_label,
442
- 'loss' : loss.float(),
443
- 'loss_compare': loss_compare.float(),
444
- 'lm_loss': lm_loss.float(),
445
- 'mlm_loss': mlm_loss.float()
446
- }