makram93 committed
Commit: 71b163e
Parent: 7ad815b

feat: set adapter based on prompt


Signed-off-by: Mohammad Kalim Akram <kalim.akram@jina.ai>

Files changed (2):
  1. modeling_lora.py +34 -15
  2. modeling_xlm_roberta.py +3 -9
modeling_lora.py CHANGED

@@ -14,9 +14,6 @@ from transformers import PretrainedConfig
 from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
 
 
-LORA_NO_UPDATE = '__lora_no_update__'
-
-
 def initialized_weights(
     shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
 ) -> torch.Tensor:
@@ -247,6 +244,13 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         self._task_idx = None
         # By default, disable LoRA until it's specified which adapter/task to use
         self.current_task = None
+        self.prompts = {
+            'query': 'Represent the query for retrieving supporting documents: ',
+            'document': 'Represent the document for retrieval: ',
+            'sts': 'Represent the text for Semantic Textual Similarity: ',
+            'clustering': 'Cluster the text: ',
+            'classification': 'Classify the text: ',
+        }
 
     @property
     def main_params_trainable(self):
@@ -332,9 +336,18 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
         )
 
-    def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
-        if task != LORA_NO_UPDATE:
-            self.current_task = task
+    def forward(self, *args, task_type: Union[str, None] = None, **kwargs):
+        if task_type:
+            self.current_task = task_type
+        else:
+            input_ids = kwargs["input_ids"]
+            input_text = self.roberta.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+            for task_name, prompt in self.prompts.items():
+                if input_text.startswith(prompt):
+                    self.current_task = task_name
+                    break
+            else:
+                self.current_task = None  # No task-specific adapter is found, just use the general-purpose weights
 
         return self.roberta(*args, **kwargs)
 
@@ -355,7 +368,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     def encode(
         self,
         *args,
-        task: Union[str, None] = LORA_NO_UPDATE,
+        task_type: Union[str, None] = None,
         **kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -364,18 +377,24 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
             Specifies the task for which the encoding is intended. This parameter controls the
             use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
-            to `LORA_NO_UPDATE`, there will be no update to the current task, retaining the
-            existing adapter configuration. If `task` is explicitly set to `None`, all LoRA
-            adapters are disabled, and the model reverts to its original, general-purpose weights.
-            If `task` is set to a specific LoRA adaptation, that adaptation is activated.
+            to `None`, all LoRA adapters are disabled, and the model reverts to its original,
+            general-purpose weights. If `task` is set to a specific LoRA adaptation, that adaptation
+            is activated.
         """
-        if task != LORA_NO_UPDATE:
-            if not task:
+        if task_type:
+            self.current_task = task_type
+        else:  # infer the task from the input text
+            input_text = args[0][0] if isinstance(args[0], list) else args[0]  # take only the first sentence
+            for task_name, prompt in self.prompts.items():
+                if input_text.startswith(prompt):
+                    self.current_task = task_name
+                    break
+            else:
                 warnings.warn(
                     f"Task-specific embeddings are disabled. To enable, specify the `task` "
                     f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
                     category=UserWarning,
                 )
-            self.current_task = task
+                self.current_task = None  # No task-specific adapter is found, just use the general-purpose weights
 
-        return self.roberta.encode(*args, **kwargs)
+        return self.roberta.encode(*args, **kwargs)
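Taken together, the modeling_lora.py changes mean `encode` and `forward` accept an optional `task_type`, and when it is omitted the adapter is chosen by matching the input against the instruction prefixes in `self.prompts`. Below is a minimal usage sketch, not part of the commit: it assumes the model is loaded through `AutoModel` with `trust_remote_code=True`, and the checkpoint id is a placeholder.

from transformers import AutoModel

# Placeholder checkpoint id; substitute the repository that ships this code.
model = AutoModel.from_pretrained("<org>/<checkpoint-with-this-code>", trust_remote_code=True)

# Explicit selection: task_type sets the adapter directly.
emb = model.encode(["How do I install the package?"], task_type="query")

# Implicit selection: no task_type, but the text starts with the 'query' prompt,
# so the 'query' adapter is activated by prefix matching.
emb = model.encode(
    ["Represent the query for retrieving supporting documents: How do I install the package?"]
)

# Neither task_type nor a recognized prefix: a UserWarning is emitted and the
# general-purpose weights (no LoRA adapter) are used.
emb = model.encode(["Just some text without an instruction prefix."])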
modeling_xlm_roberta.py CHANGED

@@ -21,7 +21,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from einops import rearrange
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, AutoTokenizer
 from transformers.modeling_utils import PreTrainedModel
 from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
 from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
@@ -440,7 +440,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
-
+        self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
 
     @torch.inference_mode()
     def encode(
@@ -492,12 +492,6 @@
             If convert_to_tensor, a stacked tensor is returned.
             If convert_to_numpy, a numpy matrix is returned.
         """
-        from transformers import AutoTokenizer
-
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            self.name_or_path, trust_remote_code=True
-        )
-
         is_training = self.training
         self.eval()
 
@@ -1278,4 +1272,4 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
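On the modeling_xlm_roberta.py side, the tokenizer is now built once in `__init__` via `AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)` rather than on every `encode` call, which is what makes `self.roberta.tokenizer` available to the prefix matching added in `forward`. The sketch below mirrors that matching step in isolation; it is an illustration only, using `xlm-roberta-base` as a stand-in checkpoint and just two of the prompts.

from transformers import AutoTokenizer

prompts = {
    'query': 'Represent the query for retrieving supporting documents: ',
    'document': 'Represent the document for retrieval: ',
}

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # stand-in checkpoint
batch = tokenizer([prompts['query'] + "what is LoRA?"], return_tensors="pt")

# Decode only the first sequence without special tokens, then match a known prefix,
# mirroring the loop that the new forward() runs over self.prompts.
decoded = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
task = next((name for name, p in prompts.items() if decoded.startswith(p)), None)
print(task)  # expected 'query', assuming decoding round-trips the prompt text exactly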