sahilnishad committed on
Commit 3558c2e
1 Parent(s): 6699d1c

Create bert_model.py

Files changed (1)
  1. bert_model.py +31 -0
bert_model.py ADDED
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+
+class BERTEmbedding(nn.Module):
+    def __init__(self, vocab_size, n_segments, max_len, embed_dim, dropout):
+        super().__init__()
+        self.token_embed = nn.Embedding(vocab_size, embed_dim)
+        self.segment_embed = nn.Embedding(n_segments, embed_dim)
+        self.pos_embed = nn.Embedding(max_len, embed_dim)  # learned positions, BERT-style
+        self.drop = nn.Dropout(dropout)
+
+    def forward(self, seq, seg):
+        current_max_len = seq.size(1)  # current sequence length
+        # Build position indices dynamically, on the input's device
+        pos_inp = torch.arange(0, current_max_len, device=seq.device).unsqueeze(0)
+        embed_val = self.token_embed(seq) + self.segment_embed(seg) + self.pos_embed(pos_inp)
+        embed_val = self.drop(embed_val)
+        return embed_val
+
+class BERT(nn.Module):
+    def __init__(self, vocab_size, n_segments, max_len, embed_dim, n_layers, attn_heads, dropout):
+        super().__init__()
+        self.embedding = BERTEmbedding(vocab_size, n_segments, max_len, embed_dim, dropout)
+        # batch_first=True matches the (batch, seq, embed) tensors the embedding produces
+        self.encoder_layer = nn.TransformerEncoderLayer(embed_dim, attn_heads, embed_dim*4, dropout, batch_first=True)
+        self.encoder_block = nn.TransformerEncoder(self.encoder_layer, n_layers)
+
+    def forward(self, seq, seg):
+        out = self.embedding(seq, seg)
+        out = self.encoder_block(out)
+        return out
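
A minimal usage sketch for a quick smoke test of the committed classes (the `from bert_model import BERT` path and all hyperparameter values are illustrative assumptions, not part of this commit):

import torch
from bert_model import BERT  # assumes bert_model.py is on the import path

# Illustrative BERT-base-like hyperparameters (not specified in this commit)
model = BERT(vocab_size=30000, n_segments=2, max_len=512,
             embed_dim=768, n_layers=12, attn_heads=12, dropout=0.1)
seq = torch.randint(0, 30000, (2, 128))  # (batch, seq_len) token IDs
seg = torch.randint(0, 2, (2, 128))      # segment IDs in {0, 1}
out = model(seq, seg)                    # contextual embeddings
print(out.shape)                         # torch.Size([2, 128, 768])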