ysharma (HF staff) committed
Commit 4f39709
1 Parent(s): da673fb
Files changed (1)
  1. modules.py +6 -4
modules.py CHANGED
@@ -31,8 +31,9 @@ class ClassEmbedder(nn.Module):
 
 class TransformerEmbedder(AbstractEncoder):
     """Some transformer encoder layers"""
-    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cpu"):
         super().__init__()
+        #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
         self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                               attn_layers=Encoder(dim=n_embed, depth=n_layer))
@@ -48,10 +49,11 @@ class TransformerEmbedder(AbstractEncoder):
 
 class BERTTokenizer(AbstractEncoder):
     """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
-    def __init__(self, device="cuda", vq_interface=True, max_length=77):
+    def __init__(self, device="cpu", vq_interface=True, max_length=77):
         super().__init__()
         from transformers import BertTokenizerFast  # TODO: add to requirements
         self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+        #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
         self.vq_interface = vq_interface
         self.max_length = max_length
@@ -76,7 +78,7 @@ class BERTTokenizer(AbstractEncoder):
 class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizer model and adds some transformer encoder layers"""
     def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
-                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
+                 device="cpu", use_tokenizer=True, embedding_dropout=0.0):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
         if self.use_tknz_fn:
@@ -88,7 +90,7 @@ class BERTEmbedder(AbstractEncoder):
 
     def forward(self, text):
         if self.use_tknz_fn:
-            tokens = self.tknz_fn(text)#.to(self.device)
+            tokens = self.tknz_fn(text) #.to(self.device)
         else:
             tokens = text
         z = self.transformer(tokens, return_embeddings=True)
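
The commit hard-codes the default device to "cpu" so the module can load on machines without a GPU, and the commented-out lines hint at picking the device at runtime instead. A minimal sketch of that runtime-selection pattern, using only plain PyTorch; the helper name and the example constructor arguments are illustrative, not part of this repository:

import torch

def pick_default_device() -> torch.device:
    # Mirrors the commented-out line in the commit:
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical usage with one of the embedders touched by this commit
# (n_embed/n_layer values are placeholders):
# embedder = BERTEmbedder(n_embed=1280, n_layer=32, device=str(pick_default_device()))

Passing the device this way keeps "cpu" as a safe fallback while still using CUDA automatically when it is present.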