zhihan1996 commited on
Commit
3ea006c
1 Parent(s): 46964ab

Update dnabert_layer.py

Browse files
Files changed (1) hide show
  1. dnabert_layer.py +13 -1
dnabert_layer.py CHANGED
@@ -3,9 +3,21 @@ from typing import List, Optional, Tuple, Union
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
- from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
 
 
7
  from transformers.modeling_outputs import SequenceClassifierOutput
8
 
 
 
 
 
 
 
 
 
 
 
9
  class DNABertForSequenceClassification(BertPreTrainedModel):
10
  def __init__(self, config):
11
  super().__init__(config)
 
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
+ from transformers.models.bert.modeling_bert import BertModel as TransformersBertModel
7
+ from transformers.models.bert.modeling_bert import BertForMaskedLM as TransformersBertForMaskedLM
8
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel
9
  from transformers.modeling_outputs import SequenceClassifierOutput
10
 
11
+
12
+ class BertModel(TransformersBertModel):
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+
16
+ class BertForMaskedLM(TransformersBertForMaskedLM):
17
+ def __init__(self, config):
18
+ super().__init__(config)
19
+
20
+
21
  class DNABertForSequenceClassification(BertPreTrainedModel):
22
  def __init__(self, config):
23
  super().__init__(config)