gmastrapas committed on
Commit
d779277
·
1 Parent(s): a4480ad

feat: remove adapter_mask from interface

Browse files
Files changed (2) hide show
  1. hf_model.py +76 -23
  2. modeling_clip.py +10 -39
hf_model.py CHANGED
@@ -1,6 +1,6 @@
1
  import re
2
  import warnings
3
- from typing import Dict, Optional
4
 
5
  import torch
6
  import torch.nn as nn
@@ -208,21 +208,48 @@ class HFTextEncoder(nn.Module):
208
  self._task_instructions = self.transformer._task_instructions
209
  self._supports_task_instructions = True
210
 
211
- self.default_instruction_task = None
212
- self.default_lora_task = None
213
- self.default_instruction = None
214
- self.default_loraid = None
 
215
  if default_instruction_task is not None:
216
- self.default_instruction_task = default_instruction_task
217
- self.default_instruction = self.get_instruction_from_task(
218
  default_instruction_task
219
  )
220
  if default_lora_task is not None:
221
- self.default_lora_task = default_lora_task
222
- self.default_loraid = self.get_loraid_from_task(default_lora_task)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
- def get_instruction_from_task(self, task: str) -> Optional[str]:
 
 
 
 
 
 
 
 
225
  if self._supports_task_instructions:
 
 
226
  if task not in self._task_instructions:
227
  raise ValueError(
228
  f'Unsupported task \'{task}\'. Choose one of the following: '
@@ -231,14 +258,17 @@ class HFTextEncoder(nn.Module):
231
  )
232
  return self._task_instructions[task]
233
  else:
234
- warnings.warn(
235
- 'Model does not support task instructions, ignoring instruction '
236
- f"task '{task}'"
237
- )
 
238
  return None
239
 
240
- def get_loraid_from_task(self, task: str) -> Optional[int]:
241
  if self._supports_lora:
 
 
242
  if task not in self._lora_adaptation_map:
243
  raise ValueError(
244
  f'Unsupported task \'{task}\'. Choose one of the following: '
@@ -247,11 +277,18 @@ class HFTextEncoder(nn.Module):
247
  )
248
  return self._lora_adaptation_map[task]
249
  else:
250
- warnings.warn(
251
- f"Model does not support LoRA adapters, ignoring LoRA task '{task}'"
252
- )
 
253
  return None
254
 
 
 
 
 
 
 
255
  @torch.jit.ignore
256
  def set_grad_checkpointing(self, _=True):
257
  self.transformer.gradient_checkpointing_enable()
@@ -260,12 +297,28 @@ class HFTextEncoder(nn.Module):
260
  pass
261
 
262
  def forward(self, x: torch.Tensor, adapter_mask: Optional[torch.Tensor] = None):
263
- attn_mask = (x != self.config.pad_token_id).long()
264
- kwargs = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  if adapter_mask is not None:
266
- kwargs['adapter_mask'] = adapter_mask
267
- out = self.transformer(input_ids=x, attention_mask=attn_mask, **kwargs)
268
- pooled_out = self.pooler(out, attn_mask)
 
 
 
269
  projected = self.proj(pooled_out)
270
  seqlen = out.last_hidden_state.shape[1]
271
  tokens = (
 
1
  import re
2
  import warnings
3
+ from typing import Dict, Optional, Union
4
 
5
  import torch
6
  import torch.nn as nn
 
208
  self._task_instructions = self.transformer._task_instructions
209
  self._supports_task_instructions = True
210
 
211
+ self._default_instruction_task = None
212
+ self._default_lora_task = None
213
+ self._default_instruction = None
214
+ self._default_loraid = None
215
+
216
  if default_instruction_task is not None:
217
+ self._default_instruction_task = default_instruction_task
218
+ self._default_instruction = self.get_instruction_from_task(
219
  default_instruction_task
220
  )
221
  if default_lora_task is not None:
222
+ self._default_lora_task = default_lora_task
223
+ self._default_loraid = self.get_loraid_from_task(default_lora_task)
224
+
225
+ @property
226
+ def supports_task_instructions(self) -> bool:
227
+ return self._supports_task_instructions
228
+
229
+ @property
230
+ def supports_lora(self) -> bool:
231
+ return self._supports_lora
232
+
233
+ @property
234
+ def task_instructions(self) -> Dict[str, str]:
235
+ return self._task_instructions
236
+
237
+ @property
238
+ def lora_adaptation_map(self) -> Dict[str, int]:
239
+ return self._lora_adaptation_map
240
 
241
+ @property
242
+ def default_instruction(self) -> Optional[str]:
243
+ return self._default_instruction
244
+
245
+ @property
246
+ def default_loraid(self) -> Optional[int]:
247
+ return self._default_loraid
248
+
249
+ def get_instruction_from_task(self, task: Optional[str]) -> Optional[str]:
250
  if self._supports_task_instructions:
251
+ if task is None:
252
+ return self._default_instruction
253
  if task not in self._task_instructions:
254
  raise ValueError(
255
  f'Unsupported task \'{task}\'. Choose one of the following: '
 
258
  )
259
  return self._task_instructions[task]
260
  else:
261
+ if task is not None:
262
+ warnings.warn(
263
+ 'Model does not support task instructions, ignoring instruction '
264
+ f"task '{task}'"
265
+ )
266
  return None
267
 
268
+ def get_loraid_from_task(self, task: Optional[str]) -> Optional[int]:
269
  if self._supports_lora:
270
+ if task is None:
271
+ return self._default_loraid
272
  if task not in self._lora_adaptation_map:
273
  raise ValueError(
274
  f'Unsupported task \'{task}\'. Choose one of the following: '
 
277
  )
278
  return self._lora_adaptation_map[task]
279
  else:
280
+ if task is not None:
281
+ warnings.warn(
282
+ f"Model does not support LoRA adapters, ignoring LoRA task '{task}'"
283
+ )
284
  return None
285
 
286
+ @staticmethod
287
+ def get_adapter_mask_from_loraid(
288
+ batch_size: int, loraid: int, device: Union[str, torch.device]
289
+ ):
290
+ return torch.full((batch_size,), loraid, dtype=torch.int32, device=device)
291
+
292
  @torch.jit.ignore
293
  def set_grad_checkpointing(self, _=True):
294
  self.transformer.gradient_checkpointing_enable()
 
297
  pass
298
 
299
  def forward(self, x: torch.Tensor, adapter_mask: Optional[torch.Tensor] = None):
300
+ if adapter_mask is None:
301
+ default_loraid = self.default_loraid
302
+ if default_loraid is not None:
303
+ adapter_mask = self.get_adapter_mask_from_loraid(
304
+ x.shape[0], default_loraid, x.device
305
+ )
306
+ else:
307
+ if not self.supports_lora:
308
+ warnings.warn(
309
+ 'Model does not support LoRA adapters, setting adapter_mask to None'
310
+ )
311
+ adapter_mask = None
312
+
313
+ attention_mask = (x != self.config.pad_token_id).long()
314
+ lora_kwargs = {}
315
  if adapter_mask is not None:
316
+ lora_kwargs['adapter_mask'] = adapter_mask
317
+
318
+ out = self.transformer(
319
+ input_ids=x, attention_mask=attention_mask, **lora_kwargs
320
+ )
321
+ pooled_out = self.pooler(out, attention_mask)
322
  projected = self.proj(pooled_out)
323
  seqlen = out.last_hidden_state.shape[1]
324
  tokens = (
modeling_clip.py CHANGED
@@ -159,9 +159,6 @@ class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
159
  self,
160
  input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
161
  return_dict: Optional[bool] = None,
162
- use_lora: bool = False,
163
- adapter_mask: Optional[torch.Tensor] = None,
164
- task: Optional[str] = None,
165
  *_,
166
  **__,
167
  ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
@@ -169,12 +166,7 @@ class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
169
  return_dict if return_dict is not None else self.config.use_return_dict
170
  )
171
  x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
172
- feats = self.text_model(
173
- x=x,
174
- use_lora=use_lora,
175
- adapter_mask=adapter_mask,
176
- task=task,
177
- )
178
  out = CLIPTextModelOutput(text_embeds=feats)
179
  return out if return_dict else out.to_tuple()
180
 
@@ -277,12 +269,11 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
277
  def get_text_features(
278
  self,
279
  input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
280
- adapter_mask: Optional[torch.Tensor] = None,
281
  *_,
282
  **__,
283
  ) -> torch.FloatTensor:
284
  x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
285
- return self.text_projection(self.text_model(x=x, adapter_mask=adapter_mask))
286
 
287
  def get_image_features(
288
  self,
@@ -461,9 +452,9 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
461
  sentences(`str` or `List[str]`):
462
  Sentence or sentences to be encoded
463
  task(`str`, *optional*, defaults to `None`):
464
- Specifies the task for which the encoding is intended. If `task` is
465
- not provided, all LoRA adapters are disabled, and the model reverts
466
- to its original, general-purpose weights
467
  batch_size(`int`, *optional*, defaults to 32):
468
  Batch size for the computation
469
  show_progress_bar(`bool`, *optional*, defaults to None):
@@ -534,35 +525,17 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
534
 
535
  truncate_dim = truncate_dim or self.config.truncate_dim
536
 
537
- instruction = self.text_model.default_instruction
538
- loraid = self.text_model.default_loraid
539
- if task:
540
- _selected_instruction = self.text_model.get_instruction_from_task(task)
541
- if _selected_instruction is not None:
542
- instruction = _selected_instruction
543
- _selected_loraid = self.text_model.get_loraid_from_task(task)
544
- if _selected_loraid is not None:
545
- loraid = _selected_loraid
546
-
547
- if instruction is not None:
548
  sentences = [instruction + sentence for sentence in sentences]
549
 
550
- adapter_mask = None
551
- if loraid is not None:
552
- nexamples = 1 if isinstance(sentences, str) else len(sentences)
553
- adapter_mask = torch.full(
554
- (nexamples,), loraid, dtype=torch.int32, device=self.device
555
- )
556
-
557
  for i in range_iter:
558
  tokens = self.tokenizer(
559
  sentences[i: i + batch_size],
560
  return_tensors='pt',
561
  **tokenizer_kwargs,
562
  ).to(self.device)
563
- embeddings = self.get_text_features(
564
- input_ids=tokens, adapter_mask=adapter_mask
565
- )
566
  if truncate_dim:
567
  embeddings = self.truncate_embeddings(embeddings, truncate_dim)
568
  if normalize_embeddings:
@@ -589,7 +562,6 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
589
  self,
590
  input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
591
  pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
592
- adapter_mask: Optional[torch.Tensor] = None,
593
  return_dict: Optional[bool] = None,
594
  return_loss: Optional[bool] = None,
595
  *_,
@@ -599,9 +571,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
599
  return_dict if return_dict is not None else self.config.use_return_dict
600
  )
601
  image_embeds = self.get_image_features(pixel_values=pixel_values)
602
- text_embeds = self.get_text_features(
603
- input_ids=input_ids, adapter_mask=adapter_mask
604
- )
605
  # normalized features
606
  image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
607
  text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
 
159
  self,
160
  input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
161
  return_dict: Optional[bool] = None,
 
 
 
162
  *_,
163
  **__,
164
  ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
 
166
  return_dict if return_dict is not None else self.config.use_return_dict
167
  )
168
  x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
169
+ feats = self.text_model(x=x)
 
 
 
 
 
170
  out = CLIPTextModelOutput(text_embeds=feats)
171
  return out if return_dict else out.to_tuple()
172
 
 
269
  def get_text_features(
270
  self,
271
  input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
 
272
  *_,
273
  **__,
274
  ) -> torch.FloatTensor:
275
  x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
276
+ return self.text_projection(self.text_model(x=x))
277
 
278
  def get_image_features(
279
  self,
 
452
  sentences(`str` or `List[str]`):
453
  Sentence or sentences to be encoded
454
  task(`str`, *optional*, defaults to `None`):
455
+ Specifies the task for which the encoding is intended. If a `task` is
456
+ provided, a task-specific instruction is added to the beginning of each
457
+ sentence. If `task` is not provided, no instructions are added.
458
  batch_size(`int`, *optional*, defaults to 32):
459
  Batch size for the computation
460
  show_progress_bar(`bool`, *optional*, defaults to None):
 
525
 
526
  truncate_dim = truncate_dim or self.config.truncate_dim
527
 
528
+ instruction = self.text_model.get_instruction_from_task(task)
529
+ if instruction:
 
 
 
 
 
 
 
 
 
530
  sentences = [instruction + sentence for sentence in sentences]
531
 
 
 
 
 
 
 
 
532
  for i in range_iter:
533
  tokens = self.tokenizer(
534
  sentences[i: i + batch_size],
535
  return_tensors='pt',
536
  **tokenizer_kwargs,
537
  ).to(self.device)
538
+ embeddings = self.get_text_features(input_ids=tokens)
 
 
539
  if truncate_dim:
540
  embeddings = self.truncate_embeddings(embeddings, truncate_dim)
541
  if normalize_embeddings:
 
562
  self,
563
  input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
564
  pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
 
565
  return_dict: Optional[bool] = None,
566
  return_loss: Optional[bool] = None,
567
  *_,
 
571
  return_dict if return_dict is not None else self.config.use_return_dict
572
  )
573
  image_embeds = self.get_image_features(pixel_values=pixel_values)
574
+ text_embeds = self.get_text_features(input_ids=input_ids)
575
+
 
576
  # normalized features
577
  image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
578
  text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)