Markus28 commited on
Commit
32458be
1 Parent(s): 3b35eab

feat: added encode method

Browse files
Files changed (2) hide show
  1. configuration_bert.py +2 -0
  2. modeling_bert.py +165 -0
configuration_bert.py CHANGED
@@ -84,6 +84,7 @@ class JinaBertConfig(PretrainedConfig):
84
  num_tasks=0,
85
  use_flash_attn=True,
86
  use_qk_norm=True,
 
87
  **kwargs,
88
  ):
89
  assert 'position_embedding_type' not in kwargs
@@ -112,3 +113,4 @@ class JinaBertConfig(PretrainedConfig):
112
  self.num_tasks = num_tasks
113
  self.use_flash_attn = use_flash_attn
114
  self.use_qk_norm = use_qk_norm
 
 
84
  num_tasks=0,
85
  use_flash_attn=True,
86
  use_qk_norm=True,
87
+ emb_pooler=None,
88
  **kwargs,
89
  ):
90
  assert 'position_embedding_type' not in kwargs
 
113
  self.num_tasks = num_tasks
114
  self.use_flash_attn = use_flash_attn
115
  self.use_qk_norm = use_qk_norm
116
+ self.emb_pooler = emb_pooler
modeling_bert.py CHANGED
@@ -15,7 +15,10 @@ and made modifications to use ALiBi.
15
  import logging
16
  from collections.abc import Sequence
17
  from functools import partial
 
 
18
 
 
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
@@ -54,6 +57,10 @@ try:
54
  except ImportError:
55
  CrossEntropyLoss = None
56
 
 
 
 
 
57
 
58
  logger = logging.getLogger(__name__)
59
 
@@ -346,6 +353,15 @@ class BertModel(BertPreTrainedModel):
346
  self.pooler = BertPooler(config) if add_pooling_layer else None
347
  self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
348
 
 
 
 
 
 
 
 
 
 
349
  # We now initialize the task embeddings to 0; We do not use task types during
350
  # pretraining. When we start using task types during embedding training,
351
  # we want the model to behave exactly as in pretraining (i.e. task types
@@ -419,6 +435,155 @@ class BertModel(BertPreTrainedModel):
419
  )
420
 
421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  class BertForPreTraining(BertPreTrainedModel):
423
  def __init__(self, config: JinaBertConfig):
424
  super().__init__(config)
 
15
  import logging
16
  from collections.abc import Sequence
17
  from functools import partial
18
+ from typing import Union, List, Optional
19
+ import warnings
20
 
21
+ import numpy as np
22
  import torch
23
  import torch.nn as nn
24
  import torch.nn.functional as F
 
57
  except ImportError:
58
  CrossEntropyLoss = None
59
 
60
+ try:
61
+ from tqdm.autonotebook import trange
62
+ except ImportError:
63
+ trange = None
64
 
65
  logger = logging.getLogger(__name__)
66
 
 
353
  self.pooler = BertPooler(config) if add_pooling_layer else None
354
  self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
355
 
356
+ self.emb_pooler = config.emb_pooler
357
+ self._name_or_path = config._name_or_path
358
+ if self.emb_pooler is not None:
359
+ from transformers import AutoTokenizer
360
+
361
+ self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
362
+ else:
363
+ self.tokenizer = None
364
+
365
  # We now initialize the task embeddings to 0; We do not use task types during
366
  # pretraining. When we start using task types during embedding training,
367
  # we want the model to behave exactly as in pretraining (i.e. task types
 
435
  )
436
 
437
 
438
+ @torch.inference_mode()
439
+ def encode(
440
+ self: 'BertModel',
441
+ sentences: Union[str, List[str]],
442
+ batch_size: int = 32,
443
+ show_progress_bar: Optional[bool] = None,
444
+ output_value: str = 'sentence_embedding',
445
+ convert_to_numpy: bool = True,
446
+ convert_to_tensor: bool = False,
447
+ device: Optional[torch.device] = None,
448
+ normalize_embeddings: bool = False,
449
+ **tokenizer_kwargs,
450
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
451
+ """
452
+ Computes sentence embeddings
453
+ Args:
454
+ sentences(`str` or `List[str]`):
455
+ Sentence or sentences to be encoded
456
+ batch_size(`int`, *optional*, defaults to 32):
457
+ Batch size for the computation
458
+ show_progress_bar(`bool`, *optional*, defaults to None):
459
+ Show a progress bar when encoding sentences.
460
+ If set to None, progress bar is only shown when `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
461
+ output_value(`str`, *optional*, defaults to 'sentence_embedding'):
462
+ Default sentence_embedding, to get sentence embeddings.
463
+ Can be set to token_embeddings to get wordpiece token embeddings.
464
+ Set to None, to get all output values
465
+ convert_to_numpy(`bool`, *optional*, defaults to True):
466
+ If true, the output is a list of numpy vectors.
467
+ Else, it is a list of pytorch tensors.
468
+ convert_to_tensor(`bool`, *optional*, defaults to False):
469
+ If true, you get one large tensor as return.
470
+ Overwrites any setting from convert_to_numpy
471
+ device(`torch.device`, *optional*, defaults to None):
472
+ Which torch.device to use for the computation
473
+ normalize_embeddings(`bool`, *optional*, defaults to False):
474
+ If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
475
+ tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
476
+ Keyword arguments for the tokenizer
477
+ Returns:
478
+ By default, a list of tensors is returned.
479
+ If convert_to_tensor, a stacked tensor is returned.
480
+ If convert_to_numpy, a numpy matrix is returned.
481
+ """
482
+ if self.emb_pooler is None:
483
+ warnings.warn("No emb_pooler specified, defaulting to mean pooling.")
484
+ self.emb_pooler = 'mean'
485
+ from transformers import AutoTokenizer
486
+
487
+ self.tokenizer = AutoTokenizer.from_pretrained(self._name_or_path)
488
+ if self.emb_pooler != 'mean':
489
+ raise NotImplementedError
490
+
491
+ is_training = self.training
492
+ self.eval()
493
+
494
+ if show_progress_bar is None:
495
+ show_progress_bar = (
496
+ logger.getEffectiveLevel() == logging.INFO
497
+ or logger.getEffectiveLevel() == logging.DEBUG
498
+ )
499
+
500
+ if convert_to_tensor:
501
+ convert_to_numpy = False
502
+
503
+ if output_value != 'sentence_embedding':
504
+ convert_to_tensor = False
505
+ convert_to_numpy = False
506
+
507
+ input_was_string = False
508
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
509
+ sentences = [sentences]
510
+ input_was_string = True
511
+
512
+ if device is not None:
513
+ self.to(device)
514
+
515
+ # TODO: Maybe use better length heuristic?
516
+ permutation = np.argsort([-len(i) for i in sentences])
517
+ inverse_permutation = np.argsort(permutation)
518
+ sentences = [sentences[idx] for idx in permutation]
519
+
520
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
521
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 8192)
522
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
523
+
524
+ all_embeddings = []
525
+
526
+ if trange is not None:
527
+ range_iter = trange(
528
+ 0,
529
+ len(sentences),
530
+ batch_size,
531
+ desc="Encoding",
532
+ disable=not show_progress_bar,
533
+ )
534
+ else:
535
+ range_iter = range(0, len(sentences), batch_size)
536
+
537
+ for i in range_iter:
538
+ encoded_input = self.tokenizer(
539
+ sentences[i : i + batch_size],
540
+ return_tensors='pt',
541
+ **tokenizer_kwargs,
542
+ ).to(self.device)
543
+ token_embs = self.forward(**encoded_input)[0]
544
+
545
+ # Accumulate in fp32 to avoid overflow
546
+ token_embs = token_embs.float()
547
+
548
+ if output_value == 'token_embeddings':
549
+ raise NotImplementedError
550
+ elif output_value is None:
551
+ raise NotImplementedError
552
+ else:
553
+ embeddings = self.mean_pooling(
554
+ token_embs, encoded_input['attention_mask']
555
+ )
556
+
557
+ if normalize_embeddings:
558
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
559
+
560
+ if convert_to_numpy:
561
+ embeddings = embeddings.cpu()
562
+ all_embeddings.extend(embeddings)
563
+
564
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
565
+
566
+ if convert_to_tensor:
567
+ all_embeddings = torch.stack(all_embeddings)
568
+ elif convert_to_numpy:
569
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
570
+
571
+ if input_was_string:
572
+ all_embeddings = all_embeddings[0]
573
+
574
+ self.train(is_training)
575
+ return all_embeddings
576
+
577
+ def mean_pooling(
578
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
579
+ ):
580
+ input_mask_expanded = (
581
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
582
+ )
583
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
584
+ input_mask_expanded.sum(1), min=1e-9
585
+ )
586
+
587
  class BertForPreTraining(BertPreTrainedModel):
588
  def __init__(self, config: JinaBertConfig):
589
  super().__init__(config)