Markus28 commited on
Commit
a416a9d
1 Parent(s): cdf5490

feat: make num of loras part of the config

Browse files
Files changed (2) hide show
  1. configuration_bert.py +3 -1
  2. modeling_lora.py +3 -3
configuration_bert.py CHANGED
@@ -86,6 +86,7 @@ class JinaBertConfig(PretrainedConfig):
86
  use_qk_norm=True,
87
  emb_pooler=None,
88
  classifier_dropout=None,
 
89
  **kwargs,
90
  ):
91
  assert 'position_embedding_type' not in kwargs
@@ -118,4 +119,5 @@ class JinaBertConfig(PretrainedConfig):
118
  self.use_flash_attn = use_flash_attn
119
  self.use_qk_norm = use_qk_norm
120
  self.emb_pooler = emb_pooler
121
- self.classifier_dropout = classifier_dropout
 
 
86
  use_qk_norm=True,
87
  emb_pooler=None,
88
  classifier_dropout=None,
89
+ num_loras=5,
90
  **kwargs,
91
  ):
92
  assert 'position_embedding_type' not in kwargs
 
119
  self.use_flash_attn = use_flash_attn
120
  self.use_qk_norm = use_qk_norm
121
  self.emb_pooler = emb_pooler
122
+ self.classifier_dropout = classifier_dropout
123
+ self.num_loras = num_loras
modeling_lora.py CHANGED
@@ -201,14 +201,14 @@ class LoRAParametrization(nn.Module):
201
 
202
 
203
  class BertLoRA(BertPreTrainedModel):
204
- def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True, num_adaptions=1):
205
  super().__init__(config)
206
  if bert is None:
207
  self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
208
  else:
209
  self.bert = bert
210
- self._num_adaptions = num_adaptions
211
- self._register_lora(num_adaptions)
212
  self.main_params_trainable = False
213
  self.current_task = 0
214
 
 
201
 
202
 
203
  class BertLoRA(BertPreTrainedModel):
204
+ def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True):
205
  super().__init__(config)
206
  if bert is None:
207
  self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
208
  else:
209
  self.bert = bert
210
+ self._num_adaptions = config.num_loras
211
+ self._register_lora(self._num_adaptions)
212
  self.main_params_trainable = False
213
  self.current_task = 0
214