ydshieh commited on
Commit
b860da0
1 Parent(s): eb8e175
configuration_my_model.py CHANGED
@@ -7,7 +7,9 @@ class MyModelConfig(PretrainedConfig):
7
  def __init__(
8
  self,
9
  n_layers=2,
 
10
  **kwargs,
11
  ):
12
  self.n_layers = n_layers
 
13
  super().__init__(**kwargs)
 
7
  def __init__(
8
  self,
9
  n_layers=2,
10
+ hidden_dim=3,
11
  **kwargs,
12
  ):
13
  self.n_layers = n_layers
14
+ self.hidden_dim = hidden_dim
15
  super().__init__(**kwargs)
modeling_my_model.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from transformers.modeling_utils import PreTrainedModel
2
 
3
  from .configuration_my_model import MyModelConfig
@@ -11,5 +13,8 @@ class MyModel(MyModelPretrainedModel):
11
 
12
  def __init__(self, config: MyModelConfig):
13
  super().__init__(config)
 
14
 
15
  self.n_layers = config.n_layers
 
 
 
1
+ from torch import nn
2
+
3
  from transformers.modeling_utils import PreTrainedModel
4
 
5
  from .configuration_my_model import MyModelConfig
 
13
 
14
  def __init__(self, config: MyModelConfig):
15
  super().__init__(config)
16
+ self.config = config
17
 
18
  self.n_layers = config.n_layers
19
+ self.hidden_dim = config.hidden_dim
20
+ self.linear = nn.Linear(config.hidden_dim, config.hidden_dim)