Markus28 committed on
Commit
559b3ae
1 Parent(s): 8e3d0b8

feat: removed tokenizer

Browse files
Files changed (1) hide show
  1. tokenizer.py +0 -107
tokenizer.py DELETED
@@ -1,107 +0,0 @@
1
- import torch
2
- from enum import IntEnum
3
- import numpy as np
4
- from transformers import RobertaTokenizer, BatchEncoding, RobertaTokenizerFast
5
- import warnings
6
-
7
-
8
def get_tokenizer(parent_class):
    """Build a tokenizer class that extends ``parent_class`` with task-type ids.

    The returned class behaves like ``parent_class`` but accepts an extra
    keyword argument ``task_type`` on ``__call__``, ``_batch_encode_plus`` and
    ``_encode_plus``. When ``task_type`` is given, the resulting encoding gains
    a ``'task_type_ids'`` entry shaped like ``'input_ids'``.

    Args:
        parent_class: The tokenizer class to subclass (this file applies it to
            ``RobertaTokenizer`` and ``RobertaTokenizerFast``).

    Returns:
        The dynamically created ``TokenizerClass``.
    """

    class TokenizerClass(parent_class):
        class TaskTypes(IntEnum):
            NULL = 0
            QUERY = 1
            DOCUMENT = 2
            STS = 3
            CLUSTERING = 4
            CLASSIFICATION = 5

        def __init__(self, *args, **kwargs):
            """
            This class dynamically extends a given tokenizer class from the HF
            Transformers library (RobertaTokenizer or RobertaTokenizerFast).
            The task_type_ids are used to pass instruction information to the model.
            A task_type should either be an integer or a sequence of integers with the same
            length as the batch size.
            """
            super().__init__(*args, **kwargs)

        def __call__(self, *args, task_type: TaskTypes = None, **kwargs):
            """Encode via the parent class and attach task-type ids if requested."""
            batch_encoding = super().__call__(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(
                    batch_encoding, task_type, kwargs.get('return_tensors')
                )
            return batch_encoding

        def _batch_encode_plus(self, *args, task_type: TaskTypes = None, **kwargs):
            """Batch-encode via the parent class and attach task-type ids if requested."""
            batch_encoding = super()._batch_encode_plus(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(
                    batch_encoding, task_type, kwargs.get('return_tensors')
                )
            return batch_encoding

        def _encode_plus(self, *args, task_type: TaskTypes = None, **kwargs):
            """Encode one input via the parent class and attach task-type ids if requested."""
            batch_encoding = super()._encode_plus(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(
                    batch_encoding, task_type, kwargs.get('return_tensors')
                )
            return batch_encoding

        @classmethod
        def _add_task_type_ids(
            cls, batch_encoding: 'BatchEncoding', task_type: TaskTypes, tensor_type: str
        ):
            """Return a new ``BatchEncoding`` with a ``'task_type_ids'`` entry added.

            The existing entries of ``batch_encoding`` are unpacked after the new
            key, so an existing ``'task_type_ids'`` entry would take precedence.
            ``tensor_type`` is forwarded to the ``BatchEncoding`` constructor.
            """
            return BatchEncoding(
                {
                    'task_type_ids': cls._get_task_type_ids(batch_encoding, task_type),
                    **batch_encoding,
                },
                tensor_type=tensor_type,
            )

        @staticmethod
        def _get_task_type_ids(batch_encoding: 'BatchEncoding', task_type: TaskTypes):
            """Build task-type ids with the same shape and container type as ``input_ids``.

            Args:
                batch_encoding: Mapping that holds ``'input_ids'`` as a torch
                    tensor, a (nested) list, or a NumPy array.
                task_type: A scalar, or a sequence whose length equals the
                    first dimension of ``input_ids``.

            Returns:
                A torch tensor, list, or NumPy array (matching the type of
                ``input_ids``) filled with the task type. Any other container
                type triggers a warning and yields a torch tensor.

            Raises:
                ValueError: If ``input_ids`` cannot be converted to a tensor
                    (for example ragged, unpadded sequences).
                AssertionError: If a non-scalar ``task_type`` does not match
                    the batch size.
            """

            def apply_task_type(m, x):
                x = torch.tensor(x)
                assert (
                    len(x.shape) == 0 or x.shape[0] == m.shape[0]
                ), 'The shape of task_type does not match the size of the batch.'
                # A scalar broadcasts over everything; a per-example vector is
                # expanded along a trailing axis to broadcast over each row.
                return m * x if len(x.shape) == 0 else m * x[:, None]

            input_ids = batch_encoding['input_ids']
            if isinstance(input_ids, torch.Tensor):
                shape = input_ids.shape
            else:
                try:
                    shape = torch.tensor(input_ids).shape
                except Exception as err:
                    # Narrower than a bare ``except`` (which would also trap
                    # KeyboardInterrupt/SystemExit); keep the cause attached.
                    raise ValueError(
                        "Unable to create tensor, you should probably "
                        "activate truncation and/or padding with "
                        "'padding=True' 'truncation=True' to have batched "
                        "tensors with the same length."
                    ) from err

            task_type_ids = apply_task_type(
                torch.ones(shape, dtype=torch.long), task_type
            )

            if isinstance(input_ids, torch.Tensor):
                return task_type_ids
            if isinstance(input_ids, list):
                return task_type_ids.tolist()
            # ``np.ndarray`` is the array type; ``np.array`` is a factory
            # function and makes ``isinstance`` raise TypeError.
            if isinstance(input_ids, np.ndarray):
                return task_type_ids.numpy()
            warnings.warn(
                'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
            )
            return task_type_ids

    return TokenizerClass
103
-
104
-
105
# Concrete tokenizer classes: get_tokenizer applied to RobertaTokenizer and RobertaTokenizerFast.
- JinaTokenizer = get_tokenizer(RobertaTokenizer)
106
- JinaTokenizerFast = get_tokenizer(RobertaTokenizerFast)
107
-