alanzty commited on
Commit
e5b81bf
1 Parent(s): 94a810e

enabling transformers

Browse files
config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MarqoFashionSigLIP"
4
+ ],
5
+ "open_clip_model_name": "hf-hub:Marqo/marqo-ecommerce-embeddings-L",
6
+ "torch_dtype": "float32",
7
+ "transformers_version": "4.45.1"
8
+ }
marqo_fashionSigLIP.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from open_clip import create_model
3
+ from transformers import PretrainedConfig, PreTrainedModel
4
+ from transformers.models.siglip.modeling_siglip import SiglipOutput
5
+ from typing import Optional, Tuple, Union, List
6
+ from transformers.feature_extraction_utils import BatchFeature
7
+ from transformers.image_utils import ImageInput
8
+ from transformers.processing_utils import ProcessorMixin
9
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
10
+ from transformers.utils import TensorType
11
+ import string
12
+ import ftfy
13
+ import html
14
+
15
+ def basic_clean(text):
16
+ text = ftfy.fix_text(text)
17
+ text = html.unescape(html.unescape(text))
18
+ return text.strip()
19
+
20
+ def canonicalize_text(
21
+ text,
22
+ *,
23
+ keep_punctuation_exact_string=None,
24
+ trans_punctuation: dict = str.maketrans("", "", string.punctuation),
25
+ ):
26
+ """Returns canonicalized `text` (lowercase and punctuation removed).
27
+
28
+ From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
29
+
30
+ Args:
31
+ text: string to be canonicalized.
32
+ keep_punctuation_exact_string: If provided, then this exact string kept.
33
+ For example providing '{}' will keep any occurrences of '{}' (but will
34
+ still remove '{' and '}' that appear separately).
35
+ """
36
+ text = text.replace("_", " ")
37
+ if keep_punctuation_exact_string:
38
+ text = keep_punctuation_exact_string.join(
39
+ part.translate(trans_punctuation)
40
+ for part in text.split(keep_punctuation_exact_string)
41
+ )
42
+ else:
43
+ text = text.translate(trans_punctuation)
44
+ text = text.lower()
45
+ text = " ".join(text.split())
46
+ return text.strip()
47
+
48
+ def _clean_canonicalize(x):
49
+ # basic, remove whitespace, remove punctuation, lower case
50
+ return canonicalize_text(basic_clean(x))
51
+
52
+ class MarqoFashionSigLIPConfig(PretrainedConfig):
53
+ def __init__(
54
+ self,
55
+ open_clip_model_name: str = "",
56
+ **kwargs,
57
+ ):
58
+ super().__init__(**kwargs)
59
+ self.open_clip_model_name = open_clip_model_name
60
+
61
+ class MarqoFashionSigLIPProcessor(ProcessorMixin):
62
+ r"""
63
+ Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.
64
+
65
+ [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
66
+ [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.
67
+
68
+ Args:
69
+ image_processor ([`SiglipImageProcessor`]):
70
+ The image processor is a required input.
71
+ tokenizer ([`T5TokenizerFast`]):
72
+ The tokenizer is a required input.
73
+ """
74
+
75
+ attributes = ["image_processor", "tokenizer"]
76
+ image_processor_class = "SiglipImageProcessor"
77
+ tokenizer_class = "T5TokenizerFast"
78
+
79
+ def __init__(self, image_processor, tokenizer):
80
+ super().__init__(image_processor, tokenizer)
81
+
82
+ def __call__(
83
+ self,
84
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
85
+ images: ImageInput = None,
86
+ padding: Union[bool, str, PaddingStrategy] = False,
87
+ truncation: Union[bool, str, TruncationStrategy] = None,
88
+ max_length: int = None,
89
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
90
+ ) -> BatchFeature:
91
+ """
92
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
93
+ and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode
94
+ the text. To prepare the image(s), this method forwards the `images` argument to
95
+ SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
96
+ of the above two methods for more information.
97
+
98
+ Args:
99
+ text (`str`, `List[str]`, `List[List[str]]`):
100
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
101
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
102
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
103
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
104
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
105
+ tensor. Both channels-first and channels-last formats are supported.
106
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
107
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
108
+ index) among:
109
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
110
+ sequence if provided).
111
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
112
+ acceptable input length for the model if that argument is not provided.
113
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
114
+ lengths).
115
+ max_length (`int`, *optional*):
116
+ Maximum length of the returned list and optionally padding length (see above).
117
+ truncation (`bool`, *optional*):
118
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
119
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
120
+ If set, will return tensors of a particular framework. Acceptable values are:
121
+
122
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
123
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
124
+ - `'np'`: Return NumPy `np.ndarray` objects.
125
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
126
+
127
+ Returns:
128
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
129
+
130
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
131
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
132
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
133
+ `None`).
134
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
135
+ """
136
+
137
+ if text is None and images is None:
138
+ raise ValueError("You have to specify either text or images. Both cannot be none.")
139
+
140
+ if text is not None:
141
+ if isinstance(text, str):
142
+ text = [text]
143
+ text = [_clean_canonicalize(raw_text) for raw_text in text]
144
+ encoding = self.tokenizer(
145
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
146
+ )
147
+
148
+ if images is not None:
149
+ try:
150
+ images = [image.convert('RGB') for image in images] if isinstance(images, list) else images.convert('RGB')
151
+ except:
152
+ images = images
153
+ image_features = self.image_processor(images, return_tensors=return_tensors)
154
+
155
+ if text is not None and images is not None:
156
+ encoding["pixel_values"] = image_features.pixel_values
157
+ return encoding
158
+ elif text is not None:
159
+ return encoding
160
+ else:
161
+ return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)
162
+
163
+ def decode(self, *args, **kwargs):
164
+ """
165
+ This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
166
+ the docstring of this method for more information.
167
+ """
168
+ return self.tokenizer.decode(*args, **kwargs)
169
+
170
+ def batch_decode(self, *args, **kwargs):
171
+ """
172
+ This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
173
+ refer to the docstring of this method for more information.
174
+ """
175
+ return self.tokenizer.batch_decode(*args, **kwargs)
176
+
177
+ @property
178
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
179
+ def model_input_names(self):
180
+ tokenizer_input_names = self.tokenizer.model_input_names
181
+ image_processor_input_names = self.image_processor.model_input_names
182
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
183
+
184
+ class MarqoFashionSigLIP(PreTrainedModel):
185
+ config_class = MarqoFashionSigLIPConfig
186
+
187
+ def __init__(self, config: MarqoFashionSigLIPConfig):
188
+ super().__init__(config)
189
+ self.config = config
190
+ self.model = create_model(config.open_clip_model_name, output_dict=True)
191
+ self.model.eval()
192
+ self.model.to(self.device)
193
+
194
+ def get_image_features(
195
+ self,
196
+ pixel_values: torch.FloatTensor,
197
+ normalize: bool = False,
198
+ **kwargs
199
+ ) -> torch.FloatTensor:
200
+
201
+ with torch.inference_mode():
202
+ image_features = self.model.encode_image(pixel_values, normalize=normalize)
203
+ return image_features
204
+
205
+ def get_text_features(
206
+ self,
207
+ input_ids: torch.Tensor,
208
+ normalize: bool = False,
209
+ **kwargs
210
+ ) -> torch.FloatTensor:
211
+
212
+ with torch.inference_mode():
213
+ text_features = self.model.encode_text(input_ids, normalize=normalize)
214
+ return text_features
215
+
216
+ def forward(
217
+ self,
218
+ input_ids: Optional[torch.LongTensor] = None,
219
+ pixel_values: Optional[torch.FloatTensor] = None,
220
+ return_dict: Optional[bool] = None,
221
+ ) -> Union[Tuple, SiglipOutput]:
222
+
223
+ vision_outputs = self.get_image_features(pixel_values=pixel_values, normalize=True)
224
+ text_outputs = self.get_text_features(input_ids=input_ids, normalize=True)
225
+
226
+ logits_per_text = text_outputs @ vision_outputs.T
227
+ logits_per_image = logits_per_text.T
228
+
229
+ if not return_dict:
230
+ return logits_per_image, logits_per_text, text_outputs, vision_outputs
231
+
232
+ return SiglipOutput(
233
+ logits_per_image=logits_per_image,
234
+ logits_per_text=logits_per_text,
235
+ text_embeds=text_outputs,
236
+ image_embeds=vision_outputs
237
+ )
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f54e3323fc98caddba9626aa9771efd873c3cb9d63cc65b4619c2ccb6213e4e
3
+ size 2608674872
preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "marqo_fashionSigLIP.MarqoFashionSigLIPProcessor"
4
+ },
5
+ "do_normalize": true,
6
+ "do_rescale": true,
7
+ "do_resize": true,
8
+ "do_convert_rgb": true,
9
+ "image_processor_type": "SiglipImageProcessor",
10
+ "image_mean": [
11
+ 0.5,
12
+ 0.5,
13
+ 0.5
14
+ ],
15
+ "processor_class": "marqo_fashionSigLIP.MarqoFashionSigLIPProcessor",
16
+ "resample": 3,
17
+ "rescale_factor": 0.00392156862745098,
18
+ "size": {
19
+ "height": 224,
20
+ "width": 224
21
+ },
22
+ "image_std": [
23
+ 0.5,
24
+ 0.5,
25
+ 0.5
26
+ ]
27
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656