chore: clean up embedding.py
embedding.py   +13 -6
modeling_bert.py   +1 -1
embedding.py
@@ -7,10 +7,9 @@ https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c
 
 import torch
 import torch.nn as nn
-from torch import Tensor
 
 
-class BertEmbeddings(nn.Module):
+class JinaBertEmbeddings(nn.Module):
     def __init__(
         self,
         embed_dim,
@@ -37,24 +36,32 @@ class BertEmbeddings(nn.Module):
             max_position_embeddings, embed_dim, **factory_kwargs
         )
         if self.type_vocab_size > 0:
-            self.token_type_embeddings = nn.Embedding(
+            self.token_type_embeddings = nn.Embedding(
+                type_vocab_size, embed_dim, **factory_kwargs
+            )
 
     def forward(self, input_ids, position_ids=None, token_type_ids=None):
         """
         input_ids: (batch, seqlen)
         position_ids: (batch, seqlen)
         token_type_ids: (batch, seqlen)
+        .. note:: Layer norm and dropout have been moved out of the Embeddings forward and into `modeling_bert.py`.
+            This is different from the jina_bert_implementation.
         """
-
+        _, seqlen = input_ids.shape
         embeddings = self.word_embeddings(input_ids)
         if self.max_position_embeddings > 0:
             if position_ids is None:
-                position_ids = torch.arange(
+                position_ids = torch.arange(
+                    seqlen, dtype=torch.long, device=input_ids.device
+                )
             position_embeddings = self.position_embeddings(position_ids)
             embeddings = embeddings + position_embeddings
         if self.type_vocab_size > 0:
             if token_type_ids is None:
-                token_type_ids = torch.zeros(
+                token_type_ids = torch.zeros(
+                    seqlen, dtype=torch.long, device=input_ids.device
+                )
             token_type_embeddings = self.token_type_embeddings(token_type_ids)
             embeddings = embeddings + token_type_embeddings
         return embeddings
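For context, a minimal sketch of the resulting layout, as described by the docstring note above: the embedding module only sums word, position, and token-type embeddings, and the caller (the role modeling_bert.py plays) applies layer norm and dropout afterwards. The class name `Embeddings`, the constructor signature, and all numeric values below are illustrative stand-ins, not taken from the commit.

import torch
import torch.nn as nn


class Embeddings(nn.Module):
    # Simplified stand-in for JinaBertEmbeddings; mirrors the forward shown in the
    # diff, with no LayerNorm or dropout inside forward().
    def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size):
        super().__init__()
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim)
        if max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
        if type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        _, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            embeddings = embeddings + self.position_embeddings(position_ids)
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            embeddings = embeddings + self.token_type_embeddings(token_type_ids)
        return embeddings  # no LayerNorm / dropout here


# Caller-side wiring: LayerNorm and dropout are applied to the returned
# embeddings by the model, not inside Embeddings.forward.
emb = Embeddings(embed_dim=768, vocab_size=30528, max_position_embeddings=512, type_vocab_size=2)
emb_ln = nn.LayerNorm(768)
emb_drop = nn.Dropout(0.1)

input_ids = torch.randint(0, 30528, (2, 16))       # (batch, seqlen)
hidden_states = emb_drop(emb_ln(emb(input_ids)))    # (2, 16, 768)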
modeling_bert.py
@@ -37,7 +37,7 @@ from .bert_padding import (
 )
 
 from .block import Block
-from .embedding import BertEmbeddings
+from .embedding import JinaBertEmbeddings
 from .mha import MHA
 from .mlp import FusedMLP, Mlp
 