gmastrapas committed
Commit 440a9f4
Parent: b8b8f72

fix: sentence-transformers port

Files changed (2):
  1. custom_st.py +108 -124
  2. modules.json +4 -4
custom_st.py CHANGED
@@ -12,168 +12,147 @@ from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenize
 
 
 class Transformer(nn.Module):
-    """Huggingface AutoModel to generate token embeddings.
-    Loads the correct class, e.g. BERT / RoBERTa etc.
-
-    Args:
-        model_name_or_path: Huggingface models name
-            (https://huggingface.co/models)
-        max_seq_length: Truncate any inputs longer than max_seq_length
-        model_args: Keyword arguments passed to the Huggingface
-            Transformers model
-        tokenizer_args: Keyword arguments passed to the Huggingface
-            Transformers tokenizer
-        config_args: Keyword arguments passed to the Huggingface
-            Transformers config
-        cache_dir: Cache dir for Huggingface Transformers to store/load
-            models
-        do_lower_case: If true, lowercases the input (independent if the
-            model is cased or not)
-        tokenizer_name_or_path: Name or path of the tokenizer. When
-            None, then model_name_or_path is used
-    """
-
     def __init__(
         self,
         model_name_or_path: str,
+        tokenizer_name_or_path: Optional[str] = None,
+        image_processor_name_or_path: Optional[str] = None,
         max_seq_length: Optional[int] = None,
-        model_args: Optional[Dict[str, Any]] = None,
-        tokenizer_args: Optional[Dict[str, Any]] = None,
-        config_args: Optional[Dict[str, Any]] = None,
-        cache_dir: Optional[str] = None,
-        do_lower_case: bool = False,
-        tokenizer_name_or_path: str = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        image_processor_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         super(Transformer, self).__init__()
-        self.config_keys = ['max_seq_length', 'do_lower_case']
-        self.do_lower_case = do_lower_case
-        if model_args is None:
-            model_args = {}
-        if tokenizer_args is None:
-            tokenizer_args = {}
-        if config_args is None:
-            config_args = {}
-
-        config = AutoConfig.from_pretrained(
-            model_name_or_path, **config_args, cache_dir=cache_dir
-        )
-        self.jina_clip = AutoModel.from_pretrained(
-            model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+
+        config_kwargs = config_kwargs or {}
+        model_kwargs = model_kwargs or {}
+        tokenizer_kwargs = tokenizer_kwargs or {}
+        image_processor_kwargs = image_processor_kwargs or {}
+
+        config = AutoConfig.from_pretrained(model_name_or_path, **config_kwargs)
+        self.model = AutoModel.from_pretrained(
+            model_name_or_path, config=config, **model_kwargs
         )
-        if max_seq_length is not None and 'model_max_length' not in tokenizer_args:
-            tokenizer_args['model_max_length'] = max_seq_length
+        if max_seq_length is not None and "model_max_length" not in tokenizer_kwargs:
+            tokenizer_kwargs["model_max_length"] = max_seq_length
+
         self.tokenizer = AutoTokenizer.from_pretrained(
-            (
-                tokenizer_name_or_path
-                if tokenizer_name_or_path is not None
-                else model_name_or_path
-            ),
-            cache_dir=cache_dir,
-            **tokenizer_args,
+            tokenizer_name_or_path or model_name_or_path,
+            **tokenizer_kwargs,
         )
-        self.preprocessor = AutoImageProcessor.from_pretrained(
-            (
-                tokenizer_name_or_path
-                if tokenizer_name_or_path is not None
-                else model_name_or_path
-            ),
-            cache_dir=cache_dir,
-            **tokenizer_args,
+        self.image_processor = AutoImageProcessor.from_pretrained(
+            image_processor_name_or_path or model_name_or_path,
+            **image_processor_kwargs,
         )
 
         # No max_seq_length set. Try to infer from model
         if max_seq_length is None:
             if (
-                hasattr(self.jina_clip, 'config')
-                and hasattr(self.jina_clip.config, 'max_position_embeddings')
-                and hasattr(self.tokenizer, 'model_max_length')
+                hasattr(self.model, "config")
+                and hasattr(self.model.config, "max_position_embeddings")
+                and hasattr(self.tokenizer, "model_max_length")
             ):
                 max_seq_length = min(
-                    self.jina_clip.config.max_position_embeddings,
+                    self.model.config.max_position_embeddings,
                     self.tokenizer.model_max_length,
                 )
-
         self.max_seq_length = max_seq_length
-
         if tokenizer_name_or_path is not None:
-            self.jina_clip.config.tokenizer_class = self.tokenizer.__class__.__name__
-
-    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        """Returns token_embeddings, cls_token"""
-        if 'input_ids' in features:
-            embedding = self.jina_clip.get_text_features(
-                input_ids=features['input_ids']
-            )
-        else:
-            embedding = self.jina_clip.get_image_features(
-                pixel_values=features['pixel_values']
-            )
-        return {'sentence_embedding': embedding}
-
-    def get_word_embedding_dimension(self) -> int:
-        return self.config.text_config.embed_dim
+            self.model.config.tokenizer_class = self.tokenizer.__class__.__name__
 
     @staticmethod
-    def decode_data_image(data_image_str):
-        header, data = data_image_str.split(',', 1)
+    def _decode_data_image(data_image_str: str) -> Image.Image:
+        header, data = data_image_str.split(",", 1)
         image_data = base64.b64decode(data)
         return Image.open(BytesIO(image_data))
 
     def tokenize(
-        self, batch: Union[List[str]], padding: Union[str, bool] = True
+        self, texts: List[Union[str, Image.Image]], padding: Union[str, bool] = True
     ) -> Dict[str, torch.Tensor]:
-        """Tokenizes a text and maps tokens to token-ids"""
-        images = []
-        texts = []
-        for sample in batch:
+        """
+        Encodes input samples. Text samples are tokenized. Image URLs, image data
+        buffers and PIL images are passed through the image processor.
+        """
+        _images = []
+        _texts = []
+        _image_or_text_descriptors = []
+        for sample in texts:
             if isinstance(sample, str):
-                if sample.startswith('http'):
+                if sample.startswith("http"):
                     response = requests.get(sample)
-                    images.append(Image.open(BytesIO(response.content)).convert('RGB'))
-                elif sample.startswith('data:image/'):
-                    images.append(self.decode_data_image(sample).convert('RGB'))
+                    _images.append(Image.open(BytesIO(response.content)).convert("RGB"))
+                    _image_or_text_descriptors.append(0)
+                elif sample.startswith("data:image/"):
+                    _images.append(self._decode_data_image(sample).convert("RGB"))
+                    _image_or_text_descriptors.append(0)
                 else:
                     try:
-                        images.append(Image.open(sample).convert('RGB'))
+                        _images.append(Image.open(sample).convert("RGB"))
+                        _image_or_text_descriptors.append(0)
                     except Exception as e:
                         _ = str(e)
-                        texts.append(sample)
+                        _texts.append(sample)
+                        _image_or_text_descriptors.append(1)
             elif isinstance(sample, Image.Image):
-                images.append(sample.convert('RGB'))
+                _images.append(sample.convert("RGB"))
+                _image_or_text_descriptors.append(0)
 
-        if images and texts:
-            raise ValueError('Batch must contain either images or texts, not both')
-
-        if texts:
-            return self.tokenizer(
+        encoding = {}
+        if len(_texts):
+            encoding["input_ids"] = self.tokenizer(
                 texts,
                 padding=padding,
-                truncation='longest_first',
-                return_tensors='pt',
+                truncation="longest_first",
+                return_tensors="pt",
                 max_length=self.max_seq_length,
-            )
-        elif images:
-            return self.preprocessor(images)
-        return {}
+            ).input_ids
+
+        if len(_images):
+            encoding["pixel_values"] = self.image_processor(
+                _images, return_tensors="pt"
+            ).pixel_values
+
+        encoding["image_text_info"] = _image_or_text_descriptors
+        return encoding
+
+    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        image_embeddings = []
+        text_embeddings = []
+
+        if "pixel_values" in features:
+            image_embeddings = self.model.get_image_features(features["pixel_values"])
+        if "input_ids" in features:
+            text_embeddings = self.model.get_text_features(features["input_ids"])
+
+        sentence_embedding = []
+        image_features = iter(image_embeddings)
+        text_features = iter(text_embeddings)
+        for _, _input_type in enumerate(features["image_text_info"]):
+            if _input_type == 0:
+                sentence_embedding.append(next(image_features))
+            else:
+                sentence_embedding.append(next(text_features))
+
+        features["sentence_embedding"] = torch.stack(sentence_embedding).float()
+        return features
 
     def save(self, output_path: str, safe_serialization: bool = True) -> None:
-        self.jina_clip.save_pretrained(
-            output_path, safe_serialization=safe_serialization
-        )
+        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
         self.tokenizer.save_pretrained(output_path)
-        self.preprocessor.save_pretrained(output_path)
+        self.image_processor.save_pretrained(output_path)
 
     @staticmethod
-    def load(input_path: str) -> 'Transformer':
+    def load(input_path: str) -> "Transformer":
         # Old classes used other config names than 'sentence_bert_config.json'
         for config_name in [
-            'sentence_bert_config.json',
-            'sentence_roberta_config.json',
-            'sentence_distilbert_config.json',
-            'sentence_camembert_config.json',
-            'sentence_albert_config.json',
-            'sentence_xlm-roberta_config.json',
-            'sentence_xlnet_config.json',
+            "sentence_bert_config.json",
+            "sentence_roberta_config.json",
+            "sentence_distilbert_config.json",
+            "sentence_camembert_config.json",
+            "sentence_albert_config.json",
+            "sentence_xlm-roberta_config.json",
+            "sentence_xlnet_config.json",
         ]:
             sbert_config_path = os.path.join(input_path, config_name)
             if os.path.exists(sbert_config_path):
@@ -183,14 +162,19 @@ class Transformer(nn.Module):
             config = json.load(fIn)
 
         # Don't allow configs to set trust_remote_code
-        if 'model_args' in config and 'trust_remote_code' in config['model_args']:
-            config['model_args'].pop('trust_remote_code')
+        if "config_kwargs" in config and "trust_remote_code" in config["config_kwargs"]:
+            config["config_kwargs"].pop("trust_remote_code")
+        if "model_kwargs" in config and "trust_remote_code" in config["model_kwargs"]:
+            config["model_kwargs"].pop("trust_remote_code")
+        if (
+            "tokenizer_kwargs" in config
+            and "trust_remote_code" in config["tokenizer_kwargs"]
+        ):
+            config["tokenizer_kwargs"].pop("trust_remote_code")
         if (
-            'tokenizer_args' in config
-            and 'trust_remote_code' in config['tokenizer_args']
+            "image_processor_kwargs" in config
+            and "trust_remote_code" in config["image_processor_kwargs"]
         ):
-            config['tokenizer_args'].pop('trust_remote_code')
-        if 'config_args' in config and 'trust_remote_code' in config['config_args']:
-            config['config_args'].pop('trust_remote_code')
+            config["image_processor_kwargs"].pop("trust_remote_code")
 
         return Transformer(model_name_or_path=input_path, **config)
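
The ported module plugs into the standard sentence-transformers loading path. A minimal usage sketch, assuming a checkpoint that ships this custom_st.py (the repo id below is a placeholder, not confirmed by this commit):

from sentence_transformers import SentenceTransformer

# Placeholder repo id; substitute the checkpoint that ships this module.
# trust_remote_code=True is needed so the library may import custom_st.py.
model = SentenceTransformer("jinaai/jina-clip-v1", trust_remote_code=True)

# tokenize() routes plain strings to the tokenizer and URLs, data URIs,
# file paths and PIL images to the image processor; forward() restores the
# original batch order via the "image_text_info" descriptors.
embeddings = model.encode(
    [
        "A photograph of a cat",
        "https://example.com/cat.jpg",
    ]
)
print(embeddings.shape)  # (2, dim); rows are L2-normalized by the Normalize module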
modules.json CHANGED
@@ -1,14 +1,14 @@
 [
   {
     "idx": 0,
-    "name": "0",
+    "name": "transformer",
     "path": "",
     "type": "custom_st.Transformer"
   },
   {
-    "idx": 2,
-    "name": "2",
-    "path": "2_Normalize",
+    "idx": 1,
+    "name": "normalizer",
+    "path": "1_Normalize",
     "type": "sentence_transformers.models.Normalize"
   }
 ]
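
For reference, sentence-transformers builds the pipeline from modules.json by importing the class named in "type" and calling its static load() on the sub-directory named in "path", in "idx" order; the commit renames the modules and renumbers the Normalize stage from idx 2 / 2_Normalize to idx 1 / 1_Normalize so the indices are consecutive. A simplified sketch of that resolution (an approximation for illustration, not the library's exact code):

import importlib
import json
import os
import sys

def load_modules(model_dir: str) -> dict:
    # Approximation of the sentence-transformers module loader.
    sys.path.insert(0, model_dir)  # lets "custom_st" import from the model dir
    with open(os.path.join(model_dir, "modules.json")) as f:
        entries = json.load(f)
    modules = {}
    for entry in entries:
        module_name, class_name = entry["type"].rsplit(".", 1)
        cls = getattr(importlib.import_module(module_name), class_name)
        # "" resolves to the model root, "1_Normalize" to a sub-directory
        modules[entry["name"]] = cls.load(os.path.join(model_dir, entry["path"]))
    return modules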