Galuh Sahid committed on
Commit
bb13925
1 Parent(s): d354c6f
hybrid_clip/configuration_hybrid_clip.py ADDED
@@ -0,0 +1,108 @@
1
+ import copy
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class HybridCLIPConfig(PretrainedConfig):
11
+ r"""
12
+ :class:`HybridCLIPConfig` is the configuration class to store the configuration of a
13
+ :class:`~HybridCLIPModel`. It is used to instantiate a HybridCLIP model according to the specified arguments,
14
+ defining the text model and vision model configs.
15
+
16
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
17
+ outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
18
+
19
+ Args:
20
+ text_config_dict (:obj:`dict`):
21
+ Dictionary of configuration options that defines the text model config.
22
+ vision_config_dict (:obj:`dict`):
23
+ Dictionary of configuration options that defines the vision model config.
24
+ projection_dim (:obj:`int`, `optional`, defaults to 512):
25
+ Dimensionality of text and vision projection layers.
26
+ kwargs (`optional`):
27
+ Dictionary of keyword arguments.
28
+
29
+ Examples::
30
+
31
+ >>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
32
+
33
+ >>> # Initializing a BERT and CLIP configuration
34
+ >>> config_text = BertConfig()
35
+ >>> config_vision = CLIPConfig()
36
+
37
+ >>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
38
+
39
+ >>> # Initializing a FlaxHybridCLIP model from the BERT and CLIP configurations
40
+ >>> model = FlaxHybridCLIP(config=config)
41
+
42
+ >>> # Accessing the model configuration
43
+ >>> config_text = model.config.text_config
44
+ >>> config_vision = model.config.vision_config
45
+
46
+ >>> # Saving the model, including its configuration
47
+ >>> model.save_pretrained('my-model')
48
+
49
+ >>> # loading model and config from pretrained folder
50
+ >>> config = HybridCLIPConfig.from_pretrained('my-model')
51
+ >>> model = FlaxHybridCLIP.from_pretrained('my-model', config=config)
52
+ """
53
+
54
+ model_type = "hybrid-clip"
55
+ is_composition = True
56
+
57
+ def __init__(self, projection_dim=512, **kwargs):
58
+ super().__init__(**kwargs)
59
+
60
+ if "text_config" not in kwargs:
61
+ raise ValueError("`text_config` can not be `None`.")
62
+
63
+ if "vision_config" not in kwargs:
64
+ raise ValueError("`vision_config` can not be `None`.")
65
+
66
+ text_config = kwargs.pop("text_config")
67
+ vision_config = kwargs.pop("vision_config")
68
+
69
+ text_model_type = text_config.pop("model_type")
70
+ vision_model_type = vision_config.pop("model_type")
71
+
72
+ from transformers import AutoConfig
73
+
74
+ self.text_config = AutoConfig.for_model(text_model_type, **text_config)
75
+
76
+ if vision_model_type == "clip":
77
+ self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
78
+ else:
79
+ self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
80
+
81
+ self.projection_dim = projection_dim
82
+ self.initializer_factor = 1.0
83
+
84
+ @classmethod
85
+ def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
86
+ r"""
87
+ Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
88
+ vision model configuration.
89
+
90
+ Returns:
91
+ :class:`HybridCLIPConfig`: An instance of a configuration object
92
+ """
93
+
94
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
95
+
96
+ def to_dict(self):
97
+ """
98
+ Serializes this instance to a Python dictionary. Overrides the default
99
+ :meth:`~transformers.PretrainedConfig.to_dict`.
100
+
101
+ Returns:
102
+ :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
103
+ """
104
+ output = copy.deepcopy(self.__dict__)
105
+ output["text_config"] = self.text_config.to_dict()
106
+ output["vision_config"] = self.vision_config.to_dict()
107
+ output["model_type"] = self.__class__.model_type
108
+ return output
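
A usage sketch (not part of the commit) of the configuration above; it assumes a `transformers` version providing `BertConfig`/`CLIPConfig` and that `configuration_hybrid_clip.py` is on the import path, as in the training script below.

# Illustrative only: build a HybridCLIPConfig and round-trip it through to_dict().
from transformers import BertConfig, CLIPConfig
from configuration_hybrid_clip import HybridCLIPConfig

text_config = BertConfig()      # any AutoConfig-compatible text config works
vision_config = CLIPConfig()    # for model_type "clip", only the nested vision_config is kept

config = HybridCLIPConfig.from_text_vision_configs(text_config, vision_config, projection_dim=512)
config_dict = config.to_dict()  # contains nested "text_config" and "vision_config" dicts
print(config_dict["model_type"])  # "hybrid-clip"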
hybrid_clip/modeling_hybrid_clip.py ADDED
@@ -0,0 +1,433 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import flax.linen as nn
19
+ import jax
20
+ import jax.numpy as jnp
21
+ from configuration_hybrid_clip import HybridCLIPConfig
22
+ from flax.core.frozen_dict import FrozenDict
23
+ from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
24
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
25
+ from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ import warnings
32
+ warnings.filterwarnings("ignore")
33
+
34
+
35
+ class FlaxHybridCLIPModule(nn.Module):
36
+ config: HybridCLIPConfig
37
+ dtype: jnp.dtype = jnp.float32
38
+ freeze_backbones: bool = False
39
+
40
+
41
+ def setup(self):
42
+ text_config = self.config.text_config
43
+ vision_config = self.config.vision_config
44
+
45
+ self.projection_dim = self.config.projection_dim
46
+ self.text_embed_dim = text_config.hidden_size
47
+ self.vision_embed_dim = vision_config.hidden_size
48
+
49
+ text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
50
+ vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
51
+
52
+ self.text_model = text_module(text_config, dtype=self.dtype)
53
+ self.vision_model = vision_module(vision_config, dtype=self.dtype)
54
+
55
+ self.visual_projection = nn.Dense(
56
+ self.projection_dim,
57
+ dtype=self.dtype,
58
+ kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
59
+ use_bias=False,
60
+ )
61
+ self.text_projection = nn.Dense(
62
+ self.projection_dim,
63
+ dtype=self.dtype,
64
+ kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
65
+ use_bias=False,
66
+ )
67
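+ # logit_scale is initialized to ones and scaled to 20 here; __call__ stops its gradient, so it acts as a fixed temperature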
+ self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, []) * 20
68
+ #self.logit_scale = self.param("logit_scale", jnp.array([20.]), [], mutable=False)
69
+ #self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
70
+
71
+ def __call__(
72
+ self,
73
+ input_ids=None,
74
+ pixel_values=None,
75
+ attention_mask=None,
76
+ position_ids=None,
77
+ token_type_ids=None,
78
+ deterministic: bool = True,
79
+ output_attentions=None,
80
+ output_hidden_states=None,
81
+ return_dict=None,
82
+ ):
83
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
84
+
85
+ vision_outputs = self.vision_model(
86
+ pixel_values=pixel_values,
87
+ deterministic=deterministic,
88
+ output_attentions=output_attentions,
89
+ output_hidden_states=output_hidden_states,
90
+ return_dict=return_dict,
91
+ )
92
+
93
+ text_outputs = self.text_model(
94
+ input_ids=input_ids,
95
+ attention_mask=attention_mask,
96
+ token_type_ids=token_type_ids,
97
+ position_ids=position_ids,
98
+ deterministic=deterministic,
99
+ output_attentions=output_attentions,
100
+ output_hidden_states=output_hidden_states,
101
+ return_dict=return_dict,
102
+ )
103
+
104
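+ # pooled encoder outputs feed the projection heads; with freeze_backbones, gradients are stopped here so only the projections are trained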
+ image_embeds = vision_outputs[1]
105
+ if self.freeze_backbones:
106
+ image_embeds = jax.lax.stop_gradient(image_embeds)
107
+ image_embeds = self.visual_projection(image_embeds)
108
+
109
+ text_embeds = text_outputs[1]
110
+ if self.freeze_backbones:
111
+ text_embeds = jax.lax.stop_gradient(text_embeds)
112
+ text_embeds = self.text_projection(text_embeds)
113
+
114
+ # normalized features
115
+ image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
116
+ text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
117
+
118
+ # cosine similarity as logits
119
+ # logit_scale = jnp.exp(self.logit_scale)
120
+ logit_scale = jax.lax.stop_gradient(self.logit_scale)
121
+ logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
122
+ logits_per_image = logits_per_text.T
123
+
124
+ if not return_dict:
125
+ return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
126
+
127
+ return FlaxCLIPOutput(
128
+ logits_per_image=logits_per_image,
129
+ logits_per_text=logits_per_text,
130
+ text_embeds=text_embeds,
131
+ image_embeds=image_embeds,
132
+ text_model_output=text_outputs,
133
+ vision_model_output=vision_outputs,
134
+ )
135
+
136
+
137
+ class FlaxHybridCLIP(FlaxPreTrainedModel):
138
+ config_class = HybridCLIPConfig
139
+ module_class = FlaxHybridCLIPModule
140
+
141
+ def __init__(
142
+ self,
143
+ config: HybridCLIPConfig,
144
+ input_shape: Optional[Tuple] = None,
145
+ seed: int = 0,
146
+ dtype: jnp.dtype = jnp.float32,
147
+ **kwargs
148
+ ):
149
+ if input_shape is None:
150
+ input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
151
+
152
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
153
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
154
+
155
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
156
+ # init input tensor
157
+ input_ids = jnp.zeros(input_shape[0], dtype="i4")
158
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
159
+ token_type_ids = jnp.ones_like(input_ids)
160
+ attention_mask = jnp.ones_like(input_ids)
161
+
162
+ pixel_values = jax.random.normal(rng, input_shape[1])
163
+
164
+ params_rng, dropout_rng = jax.random.split(rng)
165
+ rngs = {"params": params_rng, "dropout": dropout_rng}
166
+
167
+ return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
168
+
169
+ def __call__(
170
+ self,
171
+ input_ids,
172
+ pixel_values,
173
+ attention_mask=None,
174
+ position_ids=None,
175
+ token_type_ids=None,
176
+ params: dict = None,
177
+ dropout_rng: jax.random.PRNGKey = None,
178
+ train: bool = False,
179
+ output_attentions: Optional[bool] = None,
180
+ output_hidden_states: Optional[bool] = None,
181
+ return_dict: Optional[bool] = None,
182
+ ):
183
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
184
+ output_hidden_states = (
185
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
186
+ )
187
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
188
+
189
+ if position_ids is None:
190
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
191
+
192
+ if token_type_ids is None:
193
+ token_type_ids = jnp.zeros_like(input_ids)
194
+
195
+ if attention_mask is None:
196
+ attention_mask = jnp.ones_like(input_ids)
197
+
198
+ # Handle any PRNG if needed
199
+ rngs = {}
200
+ if dropout_rng is not None:
201
+ rngs["dropout"] = dropout_rng
202
+
203
+ return self.module.apply(
204
+ {"params": params or self.params},
205
+ jnp.array(input_ids, dtype="i4"),
206
+ jnp.array(pixel_values, dtype=jnp.float32),
207
+ jnp.array(attention_mask, dtype="i4"),
208
+ jnp.array(position_ids, dtype="i4"),
209
+ jnp.array(token_type_ids, dtype="i4"),
210
+ not train,
211
+ output_attentions,
212
+ output_hidden_states,
213
+ return_dict,
214
+ rngs=rngs,
215
+ )
216
+
217
+ def get_text_features(
218
+ self,
219
+ input_ids,
220
+ attention_mask=None,
221
+ position_ids=None,
222
+ token_type_ids=None,
223
+ dropout_rng: jax.random.PRNGKey = None,
224
+ train=False,
225
+ ):
226
+ r"""
227
+ Args:
228
+ input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
229
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
230
+ provide it.
231
+
232
+ Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
233
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
234
+ for details.
235
+
236
+ `What are input IDs? <../glossary.html#input-ids>`__
237
+
238
+ Returns:
239
+ text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The text embeddings
240
+ obtained by applying the projection layer to the pooled output of the text model.
241
+ """
242
+ if position_ids is None:
243
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
244
+
245
+ if token_type_ids is None:
246
+ token_type_ids = jnp.zeros_like(input_ids)
247
+
248
+ if attention_mask is None:
249
+ attention_mask = jnp.ones_like(input_ids)
250
+
251
+ # Handle any PRNG if needed
252
+ rngs = {}
253
+ if dropout_rng is not None:
254
+ rngs["dropout"] = dropout_rng
255
+
256
+ def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
257
+ text_outputs = module.text_model(
258
+ input_ids=input_ids,
259
+ attention_mask=attention_mask,
260
+ position_ids=position_ids,
261
+ token_type_ids=token_type_ids,
262
+ deterministic=deterministic,
263
+ )
264
+ pooled_output = text_outputs[1]
265
+ text_features = module.text_projection(pooled_output)
266
+ return text_features
267
+
268
+ return self.module.apply(
269
+ {"params": self.params},
270
+ jnp.array(input_ids, dtype="i4"),
271
+ jnp.array(attention_mask, dtype="i4"),
272
+ jnp.array(position_ids, dtype="i4"),
273
+ jnp.array(token_type_ids, dtype="i4"),
274
+ not train,
275
+ method=_get_features,
276
+ rngs=rngs,
277
+ )
278
+
279
+ def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
280
+ r"""
281
+ Args:
282
+ pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, height, width, num_channels)`):
283
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
284
+ using :class:`~transformers.ImageFeatureExtractionMixin`. See
285
+ :meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
286
+
287
+ Returns:
288
+ image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The image embeddings
289
+ obtained by applying the projection layer to the pooled output of the vision model.
290
+ """
291
+
292
+ # Handle any PRNG if needed
293
+ rngs = {}
294
+ if dropout_rng is not None:
295
+ rngs["dropout"] = dropout_rng
296
+
297
+ def _get_features(module, pixel_values, deterministic):
298
+ vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
299
+ pooled_output = vision_outputs[1] # pooled_output
300
+ image_features = module.visual_projection(pooled_output)
301
+ return image_features
302
+
303
+ return self.module.apply(
304
+ {"params": self.params},
305
+ jnp.array(pixel_values, dtype=jnp.float32),
306
+ not train,
307
+ method=_get_features,
308
+ rngs=rngs,
309
+ )
310
+
311
+ @classmethod
312
+ def from_text_vision_pretrained(
313
+ cls,
314
+ text_model_name_or_path: str = None,
315
+ vision_model_name_or_path: str = None,
316
+ *model_args,
317
+ **kwargs,
318
+ ) -> FlaxPreTrainedModel:
319
+ """
320
+ Params:
321
+ text_model_name_or_path (:obj: `str`, `optional`):
322
+ Information necessary to initiate the text model. Can be either:
323
+
324
+ - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
325
+ Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
326
+ a user or organization name, like ``dbmdz/bert-base-german-cased``.
327
+ - A path to a `directory` containing model weights saved using
328
+ :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
329
+ - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
330
+ this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
331
+ as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
332
+ a Flax model using the provided conversion scripts and loading the Flax model afterwards.
333
+
334
+ vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
335
+ Information necessary to initiate the vision model. Can be either:
336
+
337
+ - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
338
+ Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
339
+ a user or organization name, like ``dbmdz/bert-base-german-cased``.
340
+ - A path to a `directory` containing model weights saved using
341
+ :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
342
+ - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
343
+ this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
344
+ as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
345
+ a Flax model using the provided conversion scripts and loading the Flax model afterwards.
346
+
347
+ model_args (remaining positional arguments, `optional`):
348
+ All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
349
+
350
+ kwargs (remaining dictionary of keyword arguments, `optional`):
351
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
352
+ :obj:`output_attentions=True`).
353
+
354
+ - To update the text configuration, use the prefix `text_` for each configuration parameter.
355
+ - To update the vision configuration, use the prefix `vision_` for each configuration parameter.
356
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
357
+
358
+ Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
359
+
360
+ Example::
361
+
362
+ >>> from transformers import FlaxHybridCLIP
363
+ >>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
364
+ >>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
365
+ >>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
366
+ >>> # saving model after fine-tuning
367
+ >>> model.save_pretrained("./bert-clip")
368
+ >>> # load fine-tuned model
369
+ >>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
370
+ """
371
+
372
+ kwargs_text = {
373
+ argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
374
+ }
375
+
376
+ kwargs_vision = {
377
+ argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
378
+ }
379
+
380
+ # remove text, vision kwargs from kwargs
381
+ for key in kwargs_text.keys():
382
+ del kwargs["text_" + key]
383
+ for key in kwargs_vision.keys():
384
+ del kwargs["vision_" + key]
385
+
386
+ # Load and initialize the text and vision model
387
+ text_model = kwargs_text.pop("model", None)
388
+ if text_model is None:
389
+ assert (
390
+ text_model_name_or_path is not None
391
+ ), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
392
+ from transformers import FlaxAutoModel
393
+
394
+ if "config" not in kwargs_text:
395
+ from transformers import AutoConfig
396
+
397
+ text_config = AutoConfig.from_pretrained(text_model_name_or_path)
398
+ kwargs_text["config"] = text_config
399
+
400
+ text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
401
+
402
+ vision_model = kwargs_vision.pop("model", None)
403
+ if vision_model is None:
404
+ assert (
405
+ vision_model_name_or_path is not None
406
+ ), "If `model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
407
+ from transformers import FlaxAutoModel
408
+
409
+ if "config" not in kwargs_vision:
410
+ from transformers import AutoConfig
411
+
412
+ vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
413
+ kwargs_vision["config"] = vision_config
414
+
415
+ vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
416
+
417
+ # instantiate config with corresponding kwargs
418
+ dtype = kwargs.pop("dtype", jnp.float32)
419
+ config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
420
+
421
+ # init model
422
+ model = cls(config, *model_args, dtype=dtype, **kwargs)
423
+
424
+ if vision_model.config.model_type == "clip":
425
+ model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
426
+ model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
427
+ else:
428
+ model.params["vision_model"] = vision_model.params
429
+
430
+ model.params["text_model"] = text_model.params
431
+
432
+ return model
433
+
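
A usage sketch (not part of the commit) of feature extraction with the model above. The checkpoint names come from the docstring example and the call downloads pretrained weights; the dummy image uses the channels-last layout expected by `init_weights` and by the training script's collate_fn.

# Illustrative only: score two captions against one placeholder image.
import jax.numpy as jnp
from transformers import AutoTokenizer
from modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
pixel_values = jnp.zeros((1, 224, 224, 3), dtype=jnp.float32)  # placeholder image, NHWC

text_embeds = model.get_text_features(inputs["input_ids"], attention_mask=inputs["attention_mask"])
image_embeds = model.get_image_features(pixel_values)

# cosine similarity after L2 normalization, mirroring FlaxHybridCLIPModule.__call__
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
print(jnp.matmul(text_embeds, image_embeds.T))  # shape (2, 1)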
hybrid_clip/requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ jax>=0.2.8
2
+ jaxlib>=0.1.59
3
+ flax>=0.3.4
4
+ optax>=0.0.8
5
+ -f https://download.pytorch.org/whl/torch_stable.html
6
+ torch==1.9.0+cpu
7
+ -f https://download.pytorch.org/whl/torch_stable.html
8
+ torchvision==0.10.0+cpu
9
+ comet_ml==3.12.2
10
+ python-dotenv==0.18.0
11
+ tqdm
12
+ transformers
hybrid_clip/run_hybrid_clip.py ADDED
@@ -0,0 +1,976 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training CLIP-like dual encoder models using text and vision encoders from the library.
18
+
19
+ The script can be used to train CLIP-like models for languages other than English by using
20
+ a text encoder pre-trained in the desired language. Currently this script supports the following vision
21
+ and text models:
22
+ Vision models: ViT (https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
23
+ Text models: BERT, RoBERTa (https://huggingface.co/models?filter=masked-lm)
24
+ """
25
+
26
+ import json
27
+ import logging
28
+ import os
29
+ import sys
30
+ import time
31
+ import numpy as np
32
+ from dataclasses import dataclass, field
33
+ from pathlib import Path
34
+ from typing import Callable, Optional
35
+ import shutil
36
+ import gc
37
+ import pyarrow as pa
38
+
39
+ try:
40
+ from dotenv import load_dotenv
41
+ load_dotenv("../.env")
42
+ except Exception:
43
+ print("Couldn't find ../.env file")
44
+
45
+ import wandb
46
+ from transformers.file_utils import PushToHubMixin
47
+
48
+
49
+ import torch
50
+ from torchvision.datasets import VisionDataset
51
+ from torchvision.io import ImageReadMode, read_image
52
+ from torchvision.transforms import (
53
+ CenterCrop,
54
+ ConvertImageDtype,
55
+ Normalize,
56
+ Resize,
57
+ ColorJitter,
58
+ RandomHorizontalFlip,
59
+ RandomRotation,
60
+ RandomCrop,
61
+ RandomAffine,
62
+ RandomPerspective,
63
+ RandomAutocontrast,
64
+ RandomEqualize,
65
+ )
66
+ from torchvision.transforms.functional import InterpolationMode
67
+ from tqdm import tqdm
68
+
69
+ import jax
70
+ import jax.numpy as jnp
71
+ import optax
72
+ import transformers
73
+ from flax import jax_utils
74
+ from flax.jax_utils import unreplicate
75
+ from flax.training import train_state
76
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
77
+ from modeling_hybrid_clip import FlaxHybridCLIP
78
+ from configuration_hybrid_clip import HybridCLIPConfig
79
+ from transformers import (
80
+ AutoTokenizer,
81
+ HfArgumentParser,
82
+ TrainingArguments,
83
+ is_tensorboard_available,
84
+ set_seed,
85
+ )
86
+ from numpy.random import default_rng
87
+ from flax.serialization import to_bytes, from_bytes
88
+
89
+ logger = logging.getLogger(__name__)
90
+
91
+ def mb_item(x):
92
+ return x.item() if hasattr(x, "item") else x
93
+
94
+ # checkpoint functions
95
+ def save_model_checkpoint(
96
+ model,
97
+ save_dir,
98
+ state,
99
+ logger,
100
+ organization,
101
+ with_opt: bool = False,
102
+ push_to_hub: bool = False,
103
+ overwrite=False,
104
+ **kwargs,
105
+ ):
106
+ state = jax_utils.unreplicate(state)
107
+ #params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
108
+ logger.info(f"Saving Checkpoint in {save_dir}")
109
+ ckpt_save_dir = f"{save_dir}/ckpt-{mb_item(state.step)-1}"
110
+ if os.path.exists(ckpt_save_dir) and not overwrite:
111
+ logger.info("checkpoint exists, skipping overwrite")
112
+ else:
113
+ model.save_pretrained(
114
+ ckpt_save_dir, params=state.params, push_to_hub=False, **kwargs
115
+ )
116
+ if with_opt:
117
+ with open(os.path.join(ckpt_save_dir, "opt_state.msgpack"), "wb") as f:
118
+ f.write(to_bytes(state.opt_state))
119
+ with open(os.path.join(ckpt_save_dir, "training_state.json"), "w") as f:
120
+ json.dump({"step": state.step.item()}, f)
121
+
122
+ logger.info("checkpoint saved")
123
+
124
+ if push_to_hub:
125
+ repo_name = Path(save_dir).name
126
+ repo_url = PushToHubMixin._get_repo_url_from_name(
127
+ repo_name, organization=organization, private=False, use_auth_token=True
128
+ )
129
+ repo = PushToHubMixin._create_or_get_repo(
130
+ save_dir,
131
+ repo_url=repo_url,
132
+ organization=organization,
133
+ use_auth_token=True,
134
+ )
135
+ commit_message = f"Saving weights and logs at step {mb_item(state.step)-1}"
136
+ url = PushToHubMixin._push_to_hub(repo=repo, commit_message=commit_message)
137
+ logger.info(f"Model pushed to the hub in this commit: {url}")
138
+
139
+
140
+ def restore_model_checkpoint(save_dir, state, logger):
141
+ logger.info(f"Restoring checkpoint from {save_dir}.")
142
+ with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
143
+ params = from_bytes(state.params, f.read())
144
+
145
+ with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
146
+ opt_state = from_bytes(state.opt_state, f.read())
147
+
148
+ with open(os.path.join(save_dir, "training_state.json"), "r") as f:
149
+ training_state = json.load(f)
150
+ step = training_state["step"]
151
+
152
+ logger.info("checkpoint restored")
153
+ # return state.replace(step=step, params=params, opt_state=opt_state), step
154
+ return params, opt_state, step
155
+
156
+
157
+ def rotate_checkpoints(ckpt_dir: str, save_total_limit: int, logger):
158
+ "Removes older checkpoints so that `save_total_limit` checkpoints are kept"
159
+ # TODO: what to remove is decided using step number only, we might want to improve that
160
+ ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
161
+ # sort checkpoints by step
162
+ ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))
163
+ ckpts_to_delete = ckpts_sorted[:-save_total_limit]
164
+ for ckpt in ckpts_to_delete:
165
+ logger.info(
166
+ f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})"
167
+ )
168
+ shutil.rmtree(ckpt)
169
+
170
+ # Cache the result
171
+ has_tensorboard = is_tensorboard_available()
172
+ if has_tensorboard:
173
+ try:
174
+ from flax.metrics.tensorboard import SummaryWriter
175
+ except ImportError as ie:
176
+ has_tensorboard = False
177
+ print(
178
+ f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
179
+ )
180
+
181
+ else:
182
+ print(
183
+ "Unable to display metrics through TensorBoard because the package is not installed: "
184
+ "Please run pip install tensorboard to enable."
185
+ )
186
+
187
+
188
+ @dataclass
189
+ class ModelArguments:
190
+ """
191
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
192
+ """
193
+
194
+ text_model_name_or_path: str = field(
195
+ metadata={
196
+ "help": "The text model checkpoint for weights initialization."
197
+ "Don't set if you want to train a model from scratch."
198
+ },
199
+ )
200
+ vision_model_name_or_path: str = field(
201
+ metadata={
202
+ "help": "The vision model checkpoint for weights initialization."
203
+ "Don't set if you want to train a model from scratch."
204
+ },
205
+ )
206
+ from_pt: bool = field(
207
+ default=True,
208
+ metadata={
209
+ "help": "whether to load the text and vision model using PyTorch checkpoints."
210
+ },
211
+ )
212
+ config_name: Optional[str] = field(
213
+ default=None,
214
+ metadata={
215
+ "help": "Pretrained config name or path if not the same as model_name"
216
+ },
217
+ )
218
+ tokenizer_name: Optional[str] = field(
219
+ default=None,
220
+ metadata={
221
+ "help": "Pretrained tokenizer name or path if not the same as model_name"
222
+ },
223
+ )
224
+ cache_dir: Optional[str] = field(
225
+ default=None,
226
+ metadata={
227
+ "help": "Where do you want to store the pretrained models downloaded from s3"
228
+ },
229
+ )
230
+ use_fast_tokenizer: bool = field(
231
+ default=True,
232
+ metadata={
233
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
234
+ },
235
+ )
236
+ dtype: Optional[str] = field(
237
+ default="float32",
238
+ metadata={
239
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
240
+ },
241
+ )
242
+
243
+
244
+ @dataclass
245
+ class DataTrainingArguments:
246
+ """
247
+ Arguments pertaining to what data we are going to input our model for training and eval.
248
+ """
249
+
250
+ data_dir: Optional[str] = field(
251
+ default=None, metadata={"help": "The data directory containing input files."}
252
+ )
253
+ train_file: Optional[str] = field(
254
+ default=None,
255
+ metadata={"help": "The input training data file (a jsonlines file)."},
256
+ )
257
+ validation_file: Optional[str] = field(
258
+ default=None,
259
+ metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
260
+ )
261
+ max_seq_length: Optional[int] = field(
262
+ default=72,
263
+ metadata={
264
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
265
+ "than this will be truncated, sequences shorter will be padded."
266
+ },
267
+ )
268
+ max_train_samples: Optional[int] = field(
269
+ default=None,
270
+ metadata={
271
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
272
+ "value if set."
273
+ },
274
+ )
275
+ max_eval_samples: Optional[int] = field(
276
+ default=None,
277
+ metadata={
278
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
279
+ "value if set."
280
+ },
281
+ )
282
+ overwrite_cache: bool = field(
283
+ default=False,
284
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
285
+ )
290
+ preprocessing_num_workers: Optional[int] = field(
291
+ default=None,
292
+ metadata={"help": "The number of processes to use for the preprocessing."},
293
+ )
294
+
295
+ def __post_init__(self):
296
+ if self.train_file is None and self.validation_file is None:
297
+ raise ValueError(
298
+ "Need either a dataset name or a training/validation file."
299
+ )
300
+ else:
301
+ if self.train_file is not None:
302
+ extension = self.train_file.split(".")[-1]
303
+ assert extension == "json", "`train_file` should be a json file."
304
+ if self.validation_file is not None:
305
+ extension = self.validation_file.split(".")[-1]
306
+ assert extension == "json", "`validation_file` should be a json file."
307
+
308
+
309
+ # We use torchvision for faster image pre-processing.
310
+ # We need to ensure faster processing speed as it can become a bottleneck on TPU
311
+ class Transform(torch.nn.Module):
312
+ def __init__(self, image_size, augment=False):
313
+ super().__init__()
314
+ if not augment:
315
+ self.transforms = torch.nn.Sequential(
316
+ Resize([image_size], interpolation=InterpolationMode.BICUBIC),
317
+ CenterCrop(image_size),
318
+ ConvertImageDtype(torch.float),
319
+ Normalize(
320
+ (0.48145466, 0.4578275, 0.40821073),
321
+ (0.26862954, 0.26130258, 0.27577711),
322
+ ),
323
+ )
324
+ else:
325
+ self.transforms = torch.nn.Sequential(
326
+ Resize([image_size], interpolation=InterpolationMode.BICUBIC),
327
+ # CenterCrop(image_size),
328
+ RandomCrop([image_size], pad_if_needed=True, padding_mode="edge"),
329
+ ColorJitter(hue=0.1),
330
+ RandomHorizontalFlip(),
331
+ # RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=128),
332
+ RandomAffine(
333
+ degrees=15,
334
+ translate=(0.1, 0.1),
335
+ scale=(0.8, 1.2),
336
+ shear=(-15, 15, -15, 15),
337
+ interpolation=InterpolationMode.BILINEAR,
338
+ fill=127,
339
+ ),
340
+ RandomPerspective(
341
+ distortion_scale=0.3,
342
+ p=0.3,
343
+ interpolation=InterpolationMode.BILINEAR,
344
+ fill=127,
345
+ ),
346
+ RandomAutocontrast(p=0.3),
347
+ RandomEqualize(p=0.3),
348
+ ConvertImageDtype(torch.float),
349
+ Normalize(
350
+ (0.48145466, 0.4578275, 0.40821073),
351
+ (0.26862954, 0.26130258, 0.27577711),
352
+ ),
353
+ )
354
+
355
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
356
+ with torch.no_grad():
357
+ x = self.transforms(x)
358
+ return x
359
+
360
+
361
+ class ImageTextDataset(VisionDataset):
362
+ """
363
+ Dataset for loading image-text data for tasks like CLIP training and image captioning.
364
+
365
+ Args:
366
+ root: (string): The root path where the dataset is stored
367
+ file_path: (string): Path to the file containing the image_paths and associated captions.
368
+ The expected format is jsonlines where each line is a json object containing two keys:
369
+ `image_path`: The path to the image.
370
+ `captions`: An `array` of captions.
371
+ transform (callable, optional): A function/transform that takes in an PIL image
372
+ and returns a transformed version. E.g, ``transforms.ToTensor``
373
+ target_transform (callable, optional): A function/transform that takes in the
374
+ target and transforms it.
375
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
376
+ and returns a transformed version.
377
+ """
378
+
379
+ def __init__(
380
+ self,
381
+ root: str,
382
+ file_path: str,
383
+ captions_per_image=-1,
384
+ transform: Optional[Callable] = None,
385
+ target_transform: Optional[Callable] = None,
386
+ transforms: Optional[Callable] = None,
387
+ seed=42,
388
+ ):
389
+ super().__init__(root, transforms, transform, target_transform)
390
+ with open(file_path, "r") as f:
391
+ examples = [json.loads(line) for line in f.readlines()]
392
+ #examples = pa.array([json.loads(line) for line in f.readlines()])
393
+
394
+ self.rand_generator = default_rng(seed)
395
+
396
+ self.captions = []
397
+ self.image_paths = []
398
+
399
+ for example in examples:
400
+ if captions_per_image <= -1:
401
+ self.captions.append(example["captions"])
402
+ elif captions_per_image > 0:
403
+ self.captions.append(example["captions"][:captions_per_image])
404
+ else:
405
+ raise ValueError("captions per image cannot be zero")
406
+
407
+ #self.image_paths.append(str(example["image_path"]))
408
+ self.image_paths.append(example["image_path"])
409
+
410
+ self.captions = self.captions
411
+ self.image_paths = self.image_paths
412
+
413
+ def _load_image(self, idx: int):
414
+ path = self.image_paths[idx]
415
+ im = read_image(path, mode=ImageReadMode.RGB)
416
+ return im
417
+
418
+ def _load_target(self, idx):
419
+ return str(self.rand_generator.choice(self.captions[idx]))
420
+ # if len(self.captions[idx]) > 1:
421
+ # caption_idx = np.random.randint(0, len(self.captions[idx]))
422
+ # else:
423
+ # caption_idx = 0
424
+ # return self.captions[idx][caption_idx]
425
+
426
+ def __getitem__(self, index: int):
427
+ image = self._load_image(index)
428
+ target = self._load_target(index)
429
+
430
+ if self.transforms is not None:
431
+ image, target = self.transforms(image, target)
432
+
433
+ return image, target
434
+
435
+ def __len__(self) -> int:
436
+ return len(self.captions)
437
+
438
+
439
+ class TrainState(train_state.TrainState):
440
+ dropout_rng: jnp.ndarray
441
+
442
+ def replicate(self):
443
+ return jax_utils.replicate(self).replace(
444
+ dropout_rng=shard_prng_key(self.dropout_rng)
445
+ )
446
+
447
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
448
+ summary_writer.scalar("train_time", train_time, step)
449
+
450
+ train_metrics = get_metrics(train_metrics)
451
+ for key, vals in train_metrics.items():
452
+ tag = f"train_{key}"
453
+ for i, val in enumerate(vals):
454
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
455
+
456
+
457
+ def write_eval_metric(summary_writer, eval_metrics, step):
458
+ for metric_name, value in eval_metrics.items():
459
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
460
+
461
+
462
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
463
+ summary_writer.scalar("train_time", train_time, step)
464
+
465
+ train_metrics = get_metrics(train_metrics)
466
+ for key, vals in train_metrics.items():
467
+ tag = f"train_{key}"
468
+ for i, val in enumerate(vals):
469
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
470
+
471
+ for metric_name, value in eval_metrics.items():
472
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
473
+
474
+
475
+ def create_learning_rate_fn(
476
+ train_ds_size: int,
477
+ train_batch_size: int,
478
+ num_train_epochs: int,
479
+ num_warmup_steps: int,
480
+ learning_rate: float,
481
+ linear=False,
482
+ ) -> Callable[[int], jnp.array]:
483
+ """Returns a learning rate function with linear warmup followed by linear or cosine decay."""
484
+ steps_per_epoch = train_ds_size // train_batch_size
485
+ num_train_steps = steps_per_epoch * num_train_epochs
486
+ if linear:
487
+ warmup_fn = optax.linear_schedule(
488
+ init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
489
+ )
490
+ decay_fn = optax.linear_schedule(
491
+ init_value=learning_rate,
492
+ end_value=0,
493
+ transition_steps=num_train_steps - num_warmup_steps,
494
+ )
495
+ else:
496
+ warmup_fn = optax.linear_schedule(
497
+ init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
498
+ )
499
+ decay_fn = optax.cosine_decay_schedule(
500
+ init_value=learning_rate,
501
+ decay_steps=num_train_steps - num_warmup_steps,
502
+ alpha=0.0,
503
+ )
504
+ schedule_fn = optax.join_schedules(
505
+ schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
506
+ )
507
+ return schedule_fn
508
+
509
+
510
+ def main():
511
+ parser = HfArgumentParser(
512
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
513
+ )
514
+ parser.add_argument("--log_wandb", action="store_true")
515
+ parser.add_argument("--freeze_backbones", action="store_true")
516
+ parser.add_argument("--exp_name", type=str, default=None)
517
+ parser.add_argument("--run_from_checkpoint", type=str, default=None)
518
+
519
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
520
+ # If we pass only one argument to the script and it's the path to a json file,
521
+ # let's parse it to get our arguments.
522
+ model_args, data_args, training_args = parser.parse_json_file(
523
+ json_file=os.path.abspath(sys.argv[1])
524
+ )
525
+ else:
526
+ (
527
+ model_args,
528
+ data_args,
529
+ training_args,
530
+ args,
531
+ ) = parser.parse_args_into_dataclasses()
532
+
533
+ if (
534
+ os.path.exists(training_args.output_dir)
535
+ and os.listdir(training_args.output_dir)
536
+ and training_args.do_train
537
+ and not training_args.overwrite_output_dir
538
+ ):
539
+ raise ValueError(
540
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
541
+ " Use --overwrite_output_dir to overcome."
542
+ )
543
+
544
+ # Make one log on every process with the configuration for debugging.
545
+ logging.basicConfig(
546
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
547
+ datefmt="%m/%d/%Y %H:%M:%S",
548
+ level=logging.INFO,
549
+ )
550
+ # Setup logging, we only want one process per machine to log things on the screen.
551
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
552
+ if jax.process_index() == 0:
553
+ transformers.utils.logging.set_verbosity_info()
554
+ else:
555
+ transformers.utils.logging.set_verbosity_error()
556
+
557
+ # Set the verbosity to info of the Transformers logger (on main process only):
558
+ logger.info(f"Training/evaluation parameters {training_args}")
559
+
560
+ if model_args.tokenizer_name:
561
+ tokenizer = AutoTokenizer.from_pretrained(
562
+ model_args.tokenizer_name,
563
+ cache_dir=model_args.cache_dir,
564
+ use_fast=model_args.use_fast_tokenizer
565
+ )
566
+ elif model_args.text_model_name_or_path:
567
+ tokenizer = AutoTokenizer.from_pretrained(
568
+ model_args.text_model_name_or_path,
569
+ cache_dir=model_args.cache_dir,
570
+ use_fast=model_args.use_fast_tokenizer,
571
+ )
572
+ else:
573
+ raise ValueError(
574
+ " You can do it from another script, save it, and load it from here, using --tokenizer_name."
575
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
576
+ )
577
+
578
+
579
+ if args.run_from_checkpoint is not None:
580
+ with open(f"{args.run_from_checkpoint}/config.json", "r") as fp:
581
+ config_dict = json.load(fp)
582
+ config_dict["vision_config"]["model_type"] = "clip"
583
+ config = HybridCLIPConfig(**config_dict)
584
+ model = FlaxHybridCLIP.from_pretrained(
585
+ args.run_from_checkpoint,
586
+ seed=training_args.seed,
587
+ dtype=getattr(jnp, model_args.dtype),
588
+ config=config,
589
+ freeze_backbones=args.freeze_backbones
590
+ )
591
+ else:
592
+
593
+ model = FlaxHybridCLIP.from_text_vision_pretrained(
594
+ model_args.text_model_name_or_path,
595
+ model_args.vision_model_name_or_path,
596
+ seed=training_args.seed,
597
+ dtype=getattr(jnp, model_args.dtype),
598
+ text_from_pt=False,
599
+ vision_from_pt=model_args.from_pt,
600
+ freeze_backbones=args.freeze_backbones
601
+ )
602
+ config = model.config
603
+ # set seed for torch dataloaders
604
+ set_seed(training_args.seed)
605
+
606
+ # Initialize torchvision transforms and jit them for faster processing
607
+ train_preprocess = Transform(config.vision_config.image_size, augment=True)
608
+ train_preprocess = torch.jit.script(train_preprocess)
609
+
610
+ val_preprocess = Transform(config.vision_config.image_size)
611
+ val_preprocess = torch.jit.script(val_preprocess)
612
+
613
+ # Initialize the image-text dataset
614
+ train_dataset = ImageTextDataset(
615
+ data_args.data_dir,
616
+ data_args.train_file,
617
+ captions_per_image=-1,
618
+ transform=train_preprocess,
619
+ seed=training_args.seed,
620
+ )
621
+
622
+ eval_dataset = ImageTextDataset(
623
+ data_args.data_dir,
624
+ data_args.validation_file,
625
+ captions_per_image=-1,
626
+ transform=val_preprocess,
627
+ seed=training_args.seed,
628
+ )
629
+
630
+ # Store some constant
631
+ num_epochs = int(training_args.num_train_epochs)
632
+ train_batch_size = (
633
+ int(training_args.per_device_train_batch_size) * jax.device_count()
634
+ )
635
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
636
+ steps_per_epoch = len(train_dataset) // train_batch_size
637
+ total_train_steps = steps_per_epoch * num_epochs
638
+
639
+ # Use the collate function to tokenize the text and convert the processed images to numpy
640
+ def collate_fn(examples):
641
+ pixel_values = (
642
+ torch.stack([example[0] for example in examples])
643
+ .permute(0, 2, 3, 1)
644
+ .numpy()
645
+ )
646
+ captions = [example[1] for example in examples]
647
+ inputs = tokenizer(
648
+ captions,
649
+ max_length=data_args.max_seq_length,
650
+ padding="max_length",
651
+ truncation=True,
652
+ return_tensors="np",
653
+ )
654
+
655
+ batch = {
656
+ "pixel_values": pixel_values,
657
+ "input_ids": inputs["input_ids"],
658
+ "attention_mask": inputs["attention_mask"],
659
+ }
660
+
661
+ return batch
662
+
663
+ # Create data loaders
664
+ train_loader = torch.utils.data.DataLoader(
665
+ train_dataset,
666
+ batch_size=train_batch_size,
667
+ shuffle=True,
668
+ num_workers=data_args.preprocessing_num_workers,
669
+ #persistent_workers=True,
670
+ drop_last=True,
671
+ collate_fn=collate_fn,
672
+ )
673
+
674
+ eval_loader = torch.utils.data.DataLoader(
675
+ eval_dataset,
676
+ batch_size=eval_batch_size,
677
+ shuffle=False,
678
+ num_workers=data_args.preprocessing_num_workers,
679
+ #persistent_workers=True,
680
+ drop_last=True,
681
+ collate_fn=collate_fn,
682
+ )
683
+
684
+ # Enable tensorboard only on the master node
685
+ if has_tensorboard and jax.process_index() == 0:
686
+ summary_writer = SummaryWriter(
687
+ log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()
688
+ )
689
+
690
+ # Enable wandb
691
+ if jax.process_index() == 0 and args.log_wandb:
692
+ try:
693
+ wandb.init(
694
+ name=args.exp_name,
695
+ entity="galuh",
696
+ project="indoclip",
697
+ sync_tensorboard=True
698
+ )
699
+ wandb.config.update(training_args)
700
+ wandb.config.update(model_args)
701
+ wandb.config.update(data_args)
702
+ except ImportError as e:
703
+ print(e)
704
+
705
+ # Initialize our training
706
+ rng = jax.random.PRNGKey(training_args.seed)
707
+ rng, dropout_rng = jax.random.split(rng)
708
+
709
+ # Create learning rate schedule
710
+ if training_args.warmup_steps:
711
+ warmup_steps = training_args.warmup_steps
712
+ elif training_args.warmup_ratio:
713
+ warmup_steps = int(training_args.warmup_ratio * total_train_steps)
714
+ else:
715
+ raise RuntimeError(
716
+ "You have to specify either the warmup_steps or warmup_ratio CLI parameter"
717
+ )
718
+
719
+ decay_lr_schedule_fn = create_learning_rate_fn(
720
+ len(train_dataset),
721
+ train_batch_size,
722
+ training_args.num_train_epochs,
723
+ warmup_steps,
724
+ training_args.learning_rate,
725
+ linear=False, # set False to activate cosine annealing
726
+ )
727
+
728
+ # create adam optimizer
729
+ # optimizer = optax.adamw(
730
+ # learning_rate=decay_lr_schedule_fn,
731
+ # b1=training_args.adam_beta1,
732
+ # b2=training_args.adam_beta2,
733
+ # eps=training_args.adam_epsilon,
734
+ # weight_decay=training_args.weight_decay,
735
+ # )
736
+
737
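+ # AdaBelief-style optimizer: adaptive gradient clipping, scale_by_belief, the LR schedule, then a descent step (scale by -1)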
+ optimizer = optax.chain(
738
+ optax.adaptive_grad_clip(0.01, eps=0.001),
739
+ optax.scale_by_belief(),
740
+ optax.scale_by_schedule(decay_lr_schedule_fn),
741
+ optax.scale(-1.0),
742
+ )
743
+
744
+ '''optimizer = optax.adafactor(
745
+ learning_rate=decay_lr_schedule_fn,
746
+ )'''
747
+
748
+ # Setup train state
749
+ state = TrainState.create(
750
+ apply_fn=model.__call__,
751
+ params=model.params,
752
+ tx=optimizer,
753
+ dropout_rng=dropout_rng,
754
+ )
755
+
756
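+ # Contrastive loss helpers: the diagonal of the (text x image) similarity matrix holds the matched pairs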
+ def cross_entropy(logits, axis):
757
+ logprobs = jax.nn.log_softmax(logits, axis=axis)
758
+ nll = jnp.diag(logprobs)
759
+ ce = -jnp.mean(nll)
760
+ return ce
761
+
762
+ def clip_loss(similarity):
763
+ loss = (
764
+ cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)
765
+ ) / 2
766
+ return loss
767
+
768
+ # Define gradient update step fn
769
+ def train_step(state, batch):
770
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
771
+
772
+ def compute_loss(params):
773
+ logits = state.apply_fn(
774
+ **batch, params=params, dropout_rng=dropout_rng, train=True
775
+ )[0]
776
+ loss = clip_loss(logits)
777
+ return loss
778
+
779
+ grad_fn = jax.value_and_grad(compute_loss)
780
+ loss, grad = grad_fn(state.params)
781
+ grad = jax.lax.pmean(grad, "batch")
782
+
783
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
784
+
785
+ metrics = {
786
+ "loss": loss,
787
+ "learning_rate": decay_lr_schedule_fn(state.step),
788
+ }
789
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
790
+
791
+ return new_state, metrics
792
+
793
+ # Define eval fn
794
+ def eval_step(params, batch):
795
+ logits = model(**batch, params=params, train=False)[0]
796
+ loss = clip_loss(logits)
797
+
798
+ # summarize metrics
799
+ metrics = {"loss": loss}
800
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
801
+ return metrics
802
+
803
+ # Create parallel version of the train and eval step
804
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
805
+ p_eval_step = jax.pmap(eval_step, "batch")
806
+
807
+ # Replicate the train state on each device
808
+ state = state.replicate()
809
+
810
+ logger.info("***** Running training *****")
811
+ logger.info(f" Num devices = {jax.device_count()}")
812
+ logger.info(f" Num examples = {len(train_dataset)}")
813
+ logger.info(f" Num Epochs = {num_epochs}")
814
+ logger.info(
815
+ f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
816
+ )
817
+ logger.info(
818
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
819
+ )
820
+ logger.info(f" Total optimization steps = {total_train_steps}")
821
+ logger.info(f" Total warmup steps = {warmup_steps}")
822
+
823
+ train_time = 0
824
+ # Create sampling rng
825
+ rng, input_rng = jax.random.split(rng)
826
+
827
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
828
+ for epoch in epochs:
829
+ # ======================== Training ================================
830
+ train_start = time.time()
831
+
832
+ # Create sampling rng
833
+ rng, input_rng = jax.random.split(rng)
834
+ train_metrics = []
835
+
836
+ num_train_samples = len(train_dataset)
837
+
838
+ steps_per_epoch = len(train_dataset) // train_batch_size
839
+ train_step_progress_bar = tqdm(
840
+ total=steps_per_epoch, desc="Training...", position=1, leave=False
841
+ )
842
+ # train
843
+ for step, batch in enumerate(train_loader):
844
+ batch = shard(batch)
845
+ state, train_metric = p_train_step(state, batch)
846
+ train_metrics.append(train_metric)
847
+
848
+ train_step_progress_bar.update(1)
849
+
850
+ cur_step = epoch * (num_train_samples // train_batch_size) + step + 1
851
+
852
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
853
+ train_time += time.time() - train_start
854
+ train_metric = unreplicate(train_metric)
855
+
856
+ # Save tensorboard metrics
857
+ if has_tensorboard and jax.process_index() == 0:
858
+ write_train_metric(
859
+ summary_writer, train_metrics, train_time, cur_step
860
+ )
861
+
862
+ # Save wandb metrics
863
+ if args.log_wandb and jax.process_index() == 0:
864
+ #_metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
866
+ _metrics = {f'train_{k}': jax.device_get(v) for k,v in train_metric.items()}
867
+ wandb.log({"train_step":cur_step, **_metrics}, commit=True)
868
+
869
+ epochs.write(
870
+ f"Log at Step: {cur_step} (Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
871
+ )
872
+
873
+ logging.info("Emptying train metrics")
874
+
875
+ del train_metric
876
+ del train_metrics
877
+ train_metrics = []
878
+
879
+ gc.collect()
880
+ torch.cuda.empty_cache()
881
+
882
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
883
+ # ======================== Evaluating ==============================
884
+ num_eval_samples = len(eval_dataset)
885
+ eval_metrics = []
886
+ eval_steps = len(eval_dataset) // eval_batch_size
887
+ eval_step_progress_bar = tqdm(
888
+ total=eval_steps, desc="Evaluating...", position=2, leave=False
889
+ )
890
+ for batch in eval_loader:
891
+ # Model forward
892
+ batch = shard(batch)
893
+ metrics = p_eval_step(state.params, batch)
894
+ eval_metrics.append(metrics)
895
+
896
+ eval_step_progress_bar.update(1)
897
+
898
+ # normalize eval metrics
899
+ eval_metrics = get_metrics(eval_metrics)
900
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
901
+
902
+ # Print metrics and update progress bar
903
+ desc = f"Eval at Step: {cur_step} (Loss: {eval_metrics['loss']})"
904
+ epochs.write(desc)
905
+ epochs.desc = desc
906
+
907
+ # Save tfboard eval
908
+ if has_tensorboard and jax.process_index() == 0:
909
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
910
+
911
+ # Save eval wandb
912
+ if args.log_wandb and jax.process_index() == 0:
913
+ #_metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
914
+ _metrics = {f'eval_{k}': jax.device_get(v) for k,v in eval_metrics.items()}
915
+ wandb.log({"eval_step":cur_step, **_metrics})
916
+
917
+ logging.info("Emptying eval metrics")
918
+ del eval_metrics
919
+
920
+ eval_metrics = []
921
+
922
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
923
+ # save a checkpoint every save_steps steps and push it to the hub
924
+ if jax.process_index() == 0:
925
+ # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
926
+ # model.save_pretrained(
927
+ # training_args.output_dir,
928
+ # params=params,
929
+ # push_to_hub=training_args.push_to_hub,
930
+ # commit_message=f"Saving weights and logs of step {cur_step}",
931
+ # )
932
+ save_model_checkpoint(
933
+ model,
934
+ training_args.output_dir,
935
+ state,
936
+ logger,
937
+ training_args.push_to_hub_organization,
938
+ with_opt=True,
939
+ push_to_hub=training_args.push_to_hub,
940
+ overwrite=True,
941
+ )
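+ # With with_opt=True the helper above writes the Flax weights (flax_model.msgpack) plus
+ # opt_state.msgpack and training_state.json into a ckpt-<step> folder under output_dir,
+ # so a run can later be resumed through restore_model_checkpoint.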
942
+ # if model_args.save_optimizer:
943
+ # # this saves full state including optimizer
944
+ # save_checkpoint(training_args.output_dir, state, state.step, keep=training_args.save_total_limit, overwrite=True)
945
+ if training_args.save_total_limit is not None:
946
+ rotate_checkpoints(
947
+ training_args.output_dir,
948
+ training_args.save_total_limit,
949
+ logger,
950
+ )
951
+
952
+ train_step_progress_bar.close() #check
953
+
954
+ '''# save checkpoint after each epoch and push checkpoint to the hub
955
+ if jax.process_index() == 0:
956
+ params = jax.device_get(unreplicate(state.params))
957
+ model.save_pretrained(
958
+ training_args.output_dir + f"/{epoch+1}/",
959
+ params=params,
960
+ push_to_hub=training_args.push_to_hub,
961
+ commit_message=f"Saving weights and logs of epoch {epoch+1}",
962
+ )'''
963
+
964
+ # save model after training is over
965
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
966
+ model.save_pretrained(
967
+ training_args.output_dir,
968
+ params=params,
969
+ push_to_hub=training_args.push_to_hub,
970
+ commit_message="Add final model",
971
+ )
972
+
973
+
974
+ if __name__ == "__main__":
975
+ main()
976
+
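+ # Minimal reload sketch (illustrative only; "./output_dir" is a placeholder for the
+ # --output_dir used above, mirroring the FlaxHybridCLIP.from_pretrained call used for
+ # the --run_from_checkpoint path earlier in this script):
+ #
+ # from modeling_hybrid_clip import FlaxHybridCLIP
+ # loaded = FlaxHybridCLIP.from_pretrained("./output_dir")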
hybrid_clip/run_hybrid_clip_backup.py ADDED
@@ -0,0 +1,970 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training a CLIP-like dual encoder model using text and vision encoders in the library.
18
+
19
+ The script can be used to train CLIP-like models for languages other than English by using
20
+ a text encoder pre-trained in the desired language. Currently this script supports the following vision
21
+ and text models:
22
+ Vision models: ViT(https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
23
+ Text models: BERT, ROBERTa (https://huggingface.co/models?filter=masked-lm)
24
+ """
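+ # Example invocation (an illustrative sketch rather than a command taken from this repo;
+ # the flags map onto the ModelArguments/DataTrainingArguments/TrainingArguments fields
+ # parsed below, and the model names, paths and hyperparameters are placeholders):
+ #
+ # python run_hybrid_clip.py \
+ #     --output_dir ./clip-indonesian-output \
+ #     --text_model_name_or_path indobenchmark/indobert-base-p1 \
+ #     --vision_model_name_or_path openai/clip-vit-base-patch32 \
+ #     --data_dir ./data --train_file train.json --validation_file valid.json \
+ #     --do_train --do_eval \
+ #     --per_device_train_batch_size 64 --per_device_eval_batch_size 64 \
+ #     --num_train_epochs 10 --learning_rate 5e-5 --warmup_steps 500 \
+ #     --preprocessing_num_workers 16 --overwrite_output_dir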
25
+
26
+ import json
27
+ import logging
28
+ import os
29
+ import sys
30
+ import time
31
+ import numpy as np
32
+ from dataclasses import dataclass, field
33
+ from pathlib import Path
34
+ from typing import Callable, Optional
35
+ import shutil
36
+ import gc
37
+
38
+ try:
39
+ from dotenv import load_dotenv
40
+ load_dotenv("../.env")
41
+ except:
42
+ print("Couldn't find ../.env file")
43
+
44
+ import wandb
45
+ from transformers.file_utils import PushToHubMixin
46
+
47
+
48
+ import torch
49
+ from torchvision.datasets import VisionDataset
50
+ from torchvision.io import ImageReadMode, read_image
51
+ from torchvision.transforms import (
52
+ CenterCrop,
53
+ ConvertImageDtype,
54
+ Normalize,
55
+ Resize,
56
+ ColorJitter,
57
+ RandomHorizontalFlip,
58
+ RandomRotation,
59
+ RandomCrop,
60
+ RandomAffine,
61
+ RandomPerspective,
62
+ RandomAutocontrast,
63
+ RandomEqualize,
64
+ )
65
+ from torchvision.transforms.functional import InterpolationMode
66
+ from tqdm import tqdm
67
+
68
+ import jax
69
+ import jax.numpy as jnp
70
+ import optax
71
+ import transformers
72
+ from flax import jax_utils
73
+ from flax.jax_utils import unreplicate
74
+ from flax.training import train_state
75
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
76
+ from modeling_hybrid_clip import FlaxHybridCLIP
77
+ from configuration_hybrid_clip import HybridCLIPConfig
78
+ from transformers import (
79
+ AutoTokenizer,
80
+ HfArgumentParser,
81
+ TrainingArguments,
82
+ is_tensorboard_available,
83
+ set_seed,
84
+ )
85
+ from numpy.random import default_rng
86
+ from flax.serialization import to_bytes, from_bytes
87
+
88
+ logger = logging.getLogger(__name__)
89
+
90
+ def mb_item(x):
91
+ return x.item() if hasattr(x, "item") else x
92
+
93
+ # checkpoint functions
94
+ def save_model_checkpoint(
95
+ model,
96
+ save_dir,
97
+ state,
98
+ logger,
99
+ organization,
100
+ with_opt: bool = False,
101
+ push_to_hub: bool = False,
102
+ overwrite=False,
103
+ **kwargs,
104
+ ):
105
+ state = jax_utils.unreplicate(state)
106
+ #params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
107
+ logger.info(f"Saving Checkpoint in {save_dir}")
108
+ ckpt_save_dir = f"{save_dir}/ckpt-{mb_item(state.step)-1}"
109
+ if os.path.exists(ckpt_save_dir) and not overwrite:
110
+ logger.info("checkpoint exists, skipping overwrite")
111
+ else:
112
+ model.save_pretrained(
113
+ ckpt_save_dir, params=state.params, push_to_hub=False, **kwargs
114
+ )
115
+ if with_opt:
116
+ with open(os.path.join(ckpt_save_dir, "opt_state.msgpack"), "wb") as f:
117
+ f.write(to_bytes(state.opt_state))
118
+ with open(os.path.join(ckpt_save_dir, "training_state.json"), "w") as f:
119
+ json.dump({"step": state.step.item()}, f)
120
+
121
+ logger.info("checkpoint saved")
122
+
123
+ if push_to_hub:
124
+ repo_name = Path(save_dir).name
125
+ repo_url = PushToHubMixin._get_repo_url_from_name(
126
+ repo_name, organization=organization, private=False, use_auth_token=True
127
+ )
128
+ repo = PushToHubMixin._create_or_get_repo(
129
+ save_dir,
130
+ repo_url=repo_url,
131
+ organization=organization,
132
+ use_auth_token=True,
133
+ )
134
+ commit_message = f"Saving weights and logs at step {mb_item(state.step)-1}"
135
+ url = PushToHubMixin._push_to_hub(repo=repo, commit_message=commit_message)
136
+ logger.info(f"Model pushed to the hub in this commit: {url}")
137
+
138
+
139
+ def restore_model_checkpoint(save_dir, state, logger):
140
+ logger.info(f"Restoring checkpoint from {save_dir}.")
141
+ with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
142
+ params = from_bytes(state.params, f.read())
143
+
144
+ with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
145
+ opt_state = from_bytes(state.opt_state, f.read())
146
+
147
+ with open(os.path.join(save_dir, "training_state.json"), "r") as f:
148
+ training_state = json.load(f)
149
+ step = training_state["step"]
150
+
151
+ logger.info("checkpoint restored")
152
+ # return state.replace(step=step, params=params, opt_state=opt_state), step
153
+ return params, opt_state, step
154
+
155
+
156
+ def rotate_checkpoints(ckpt_dir: str, save_total_limit: int, logger):
157
+ "Removes older checkpoints so that `save_total_limit` checkpoints are kept"
158
+ # TODO: what to remove is decided using step number only, we might want to improve that
159
+ ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
160
+ # sort checkpoints by step
161
+ ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))
162
+ ckpts_to_delete = ckpts_sorted[:-save_total_limit]
163
+ for ckpt in ckpts_to_delete:
164
+ logger.info(
165
+ f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})"
166
+ )
167
+ shutil.rmtree(ckpt)
168
+
169
+ # Cache the result
170
+ has_tensorboard = is_tensorboard_available()
171
+ if has_tensorboard:
172
+ try:
173
+ from flax.metrics.tensorboard import SummaryWriter
174
+ except ImportError as ie:
175
+ has_tensorboard = False
176
+ print(
177
+ f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
178
+ )
179
+
180
+ else:
181
+ print(
182
+ "Unable to display metrics through TensorBoard because the package is not installed: "
183
+ "Please run pip install tensorboard to enable."
184
+ )
185
+
186
+
187
+ @dataclass
188
+ class ModelArguments:
189
+ """
190
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
191
+ """
192
+
193
+ text_model_name_or_path: str = field(
194
+ metadata={
195
+ "help": "The text model checkpoint for weights initialization."
196
+ "Don't set if you want to train a model from scratch."
197
+ },
198
+ )
199
+ vision_model_name_or_path: str = field(
200
+ metadata={
201
+ "help": "The vision model checkpoint for weights initialization."
202
+ "Don't set if you want to train a model from scratch."
203
+ },
204
+ )
205
+ from_pt: bool = field(
206
+ default=True,
207
+ metadata={
208
+ "help": "whether to load the text and vision model using PyTorch checkpoints."
209
+ },
210
+ )
211
+ config_name: Optional[str] = field(
212
+ default=None,
213
+ metadata={
214
+ "help": "Pretrained config name or path if not the same as model_name"
215
+ },
216
+ )
217
+ tokenizer_name: Optional[str] = field(
218
+ default=None,
219
+ metadata={
220
+ "help": "Pretrained tokenizer name or path if not the same as model_name"
221
+ },
222
+ )
223
+ cache_dir: Optional[str] = field(
224
+ default=None,
225
+ metadata={
226
+ "help": "Where do you want to store the pretrained models downloaded from s3"
227
+ },
228
+ )
229
+ use_fast_tokenizer: bool = field(
230
+ default=True,
231
+ metadata={
232
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
233
+ },
234
+ )
235
+ dtype: Optional[str] = field(
236
+ default="float32",
237
+ metadata={
238
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
239
+ },
240
+ )
241
+
242
+
243
+ @dataclass
244
+ class DataTrainingArguments:
245
+ """
246
+ Arguments pertaining to what data we are going to input our model for training and eval.
247
+ """
248
+
249
+ data_dir: Optional[str] = field(
250
+ default=None, metadata={"help": "The data directory containing input files."}
251
+ )
252
+ train_file: Optional[str] = field(
253
+ default=None,
254
+ metadata={"help": "The input training data file (a jsonlines file)."},
255
+ )
256
+ validation_file: Optional[str] = field(
257
+ default=None,
258
+ metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
259
+ )
260
+ max_seq_length: Optional[int] = field(
261
+ default=72,
262
+ metadata={
263
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
264
+ "than this will be truncated, sequences shorter will be padded."
265
+ },
266
+ )
267
+ max_train_samples: Optional[int] = field(
268
+ default=None,
269
+ metadata={
270
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
271
+ "value if set."
272
+ },
273
+ )
274
+ max_eval_samples: Optional[int] = field(
275
+ default=None,
276
+ metadata={
277
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
278
+ "value if set."
279
+ },
280
+ )
281
+ overwrite_cache: bool = field(
282
+ default=False,
283
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
284
+ )
289
+ preprocessing_num_workers: Optional[int] = field(
290
+ default=None,
291
+ metadata={"help": "The number of processes to use for the preprocessing."},
292
+ )
293
+
294
+ def __post_init__(self):
295
+ if self.train_file is None and self.validation_file is None:
296
+ raise ValueError(
297
+ "Need either a dataset name or a training/validation file."
298
+ )
299
+ else:
300
+ if self.train_file is not None:
301
+ extension = self.train_file.split(".")[-1]
302
+ assert extension == "json", "`train_file` should be a json file."
303
+ if self.validation_file is not None:
304
+ extension = self.validation_file.split(".")[-1]
305
+ assert extension == "json", "`validation_file` should be a json file."
306
+
307
+
308
+ # We use torchvision for faster image pre-processing.
309
+ # We need to ensure faster processing speed as it can become a bottleneck on TPU
310
+ class Transform(torch.nn.Module):
311
+ def __init__(self, image_size, augment=False):
312
+ super().__init__()
313
+ if not augment:
314
+ self.transforms = torch.nn.Sequential(
315
+ Resize([image_size], interpolation=InterpolationMode.BICUBIC),
316
+ CenterCrop(image_size),
317
+ ConvertImageDtype(torch.float),
318
+ Normalize(
319
+ (0.48145466, 0.4578275, 0.40821073),
320
+ (0.26862954, 0.26130258, 0.27577711),
321
+ ),
322
+ )
323
+ else:
324
+ self.transforms = torch.nn.Sequential(
325
+ Resize([image_size], interpolation=InterpolationMode.BICUBIC),
326
+ # CenterCrop(image_size),
327
+ RandomCrop([image_size], pad_if_needed=True, padding_mode="edge"),
328
+ ColorJitter(hue=0.1),
329
+ RandomHorizontalFlip(),
330
+ # RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=128),
331
+ RandomAffine(
332
+ degrees=15,
333
+ translate=(0.1, 0.1),
334
+ scale=(0.8, 1.2),
335
+ shear=(-15, 15, -15, 15),
336
+ interpolation=InterpolationMode.BILINEAR,
337
+ fill=127,
338
+ ),
339
+ RandomPerspective(
340
+ distortion_scale=0.3,
341
+ p=0.3,
342
+ interpolation=InterpolationMode.BILINEAR,
343
+ fill=127,
344
+ ),
345
+ RandomAutocontrast(p=0.3),
346
+ RandomEqualize(p=0.3),
347
+ ConvertImageDtype(torch.float),
348
+ Normalize(
349
+ (0.48145466, 0.4578275, 0.40821073),
350
+ (0.26862954, 0.26130258, 0.27577711),
351
+ ),
352
+ )
353
+
354
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
355
+ with torch.no_grad():
356
+ x = self.transforms(x)
357
+ return x
358
+
359
+
360
+ class ImageTextDataset(VisionDataset):
361
+ """
362
+ Dataset for loading image-text data for tasks like CLIP training and image captioning.
363
+
364
+ Args:
365
+ root: (string): The root path where the dataset is stored
366
+ file_path: (string): Path to the file containing the image_paths and associated captions.
367
+ The expected format is jsonlines, where each line is a JSON object containing two keys:
368
+ `image_path`: The path to the image.
369
+ `captions`: An `array` of captions.
370
+ transform (callable, optional): A function/transform that takes in a PIL image
371
+ and returns a transformed version. E.g., ``transforms.ToTensor``
372
+ target_transform (callable, optional): A function/transform that takes in the
373
+ target and transforms it.
374
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
375
+ and returns a transformed version.
376
+ """
377
+
378
+ def __init__(
379
+ self,
380
+ root: str,
381
+ file_path: str,
382
+ captions_per_image=-1,
383
+ transform: Optional[Callable] = None,
384
+ target_transform: Optional[Callable] = None,
385
+ transforms: Optional[Callable] = None,
386
+ seed=42,
387
+ ):
388
+ super().__init__(root, transforms, transform, target_transform)
389
+ with open(file_path, "r") as f:
390
+ examples = [json.loads(line) for line in f.readlines()]
391
+
392
+ self.rand_generator = default_rng(seed)
393
+
394
+ self.captions = []
395
+ self.image_paths = []
396
+
397
+ for example in examples:
398
+ if captions_per_image <= -1:
399
+ self.captions.append(example["captions"])
400
+ elif captions_per_image > 0:
401
+ self.captions.append(example["captions"][:captions_per_image])
402
+ else:
403
+ raise ValueError("captions per image cannot be zero")
404
+
405
+ self.image_paths.append(example["image_path"])
406
+
407
+ def _load_image(self, idx: int):
408
+ path = self.image_paths[idx]
409
+ im = read_image(path, mode=ImageReadMode.RGB)
410
+ return im
411
+
412
+ def _load_target(self, idx):
413
+ return self.rand_generator.choice(self.captions[idx])
414
+ # if len(self.captions[idx]) > 1:
415
+ # caption_idx = np.random.randint(0, len(self.captions[idx]))
416
+ # else:
417
+ # caption_idx = 0
418
+ # return self.captions[idx][caption_idx]
419
+
420
+ def __getitem__(self, index: int):
421
+ image = self._load_image(index)
422
+ target = self._load_target(index)
423
+
424
+ if self.transforms is not None:
425
+ image, target = self.transforms(image, target)
426
+
427
+ return image, target
428
+
429
+ def __len__(self) -> int:
430
+ return len(self.captions)
431
+
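+ # The jsonlines file passed as file_path is expected to contain one object per line with
+ # the two keys described in the docstring, e.g. (an illustrative record, not taken from
+ # the actual dataset):
+ #
+ # {"image_path": "/data/images/000001.jpg", "captions": ["a cat sitting on a table"]}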
432
+
433
+ class TrainState(train_state.TrainState):
434
+ dropout_rng: jnp.ndarray
435
+
436
+ def replicate(self):
437
+ return jax_utils.replicate(self).replace(
438
+ dropout_rng=shard_prng_key(self.dropout_rng)
439
+ )
440
+
441
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
442
+ summary_writer.scalar("train_time", train_time, step)
443
+
444
+ train_metrics = get_metrics(train_metrics)
445
+ for key, vals in train_metrics.items():
446
+ tag = f"train_{key}"
447
+ for i, val in enumerate(vals):
448
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
449
+
450
+
451
+ def write_eval_metric(summary_writer, eval_metrics, step):
452
+ for metric_name, value in eval_metrics.items():
453
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
454
+
455
+
456
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
457
+ summary_writer.scalar("train_time", train_time, step)
458
+
459
+ train_metrics = get_metrics(train_metrics)
460
+ for key, vals in train_metrics.items():
461
+ tag = f"train_{key}"
462
+ for i, val in enumerate(vals):
463
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
464
+
465
+ for metric_name, value in eval_metrics.items():
466
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
467
+
468
+
469
+ def create_learning_rate_fn(
470
+ train_ds_size: int,
471
+ train_batch_size: int,
472
+ num_train_epochs: int,
473
+ num_warmup_steps: int,
474
+ learning_rate: float,
475
+ linear=False,
476
+ ) -> Callable[[int], jnp.array]:
477
+ """Returns a linear warmup, linear_decay learning rate function."""
478
+ steps_per_epoch = train_ds_size // train_batch_size
479
+ num_train_steps = steps_per_epoch * num_train_epochs
480
+ if linear:
481
+ warmup_fn = optax.linear_schedule(
482
+ init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
483
+ )
484
+ decay_fn = optax.linear_schedule(
485
+ init_value=learning_rate,
486
+ end_value=0,
487
+ transition_steps=num_train_steps - num_warmup_steps,
488
+ )
489
+ else:
490
+ warmup_fn = optax.linear_schedule(
491
+ init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
492
+ )
493
+ decay_fn = optax.cosine_decay_schedule(
494
+ init_value=learning_rate,
495
+ decay_steps=num_train_steps - num_warmup_steps,
496
+ alpha=0.0,
497
+ )
498
+ schedule_fn = optax.join_schedules(
499
+ schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
500
+ )
501
+ return schedule_fn
502
+
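+ # Shape of the schedule returned above, as a worked example (assuming train_ds_size=100_000,
+ # train_batch_size=256, num_train_epochs=10, num_warmup_steps=500, learning_rate=5e-5 and
+ # linear=False): steps_per_epoch = 100_000 // 256 = 390 and num_train_steps = 3_900, so the
+ # learning rate ramps linearly from 0 to 5e-5 over the first 500 steps and then follows a
+ # cosine decay from 5e-5 down to 0 over the remaining 3_400 steps.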
503
+
504
+ def main():
505
+ parser = HfArgumentParser(
506
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
507
+ )
508
+ parser.add_argument("--log_wandb", action="store_true")
509
+ parser.add_argument("--freeze_backbones", action="store_true")
510
+ parser.add_argument("--exp_name", type=str, default=None)
511
+ parser.add_argument("--run_from_checkpoint", type=str, default=None)
512
+
513
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
514
+ # If we pass only one argument to the script and it's the path to a json file,
515
+ # let's parse it to get our arguments.
516
+ model_args, data_args, training_args = parser.parse_json_file(
517
+ json_file=os.path.abspath(sys.argv[1])
518
+ )
519
+ else:
520
+ (
521
+ model_args,
522
+ data_args,
523
+ training_args,
524
+ args,
525
+ ) = parser.parse_args_into_dataclasses()
526
+
527
+ if (
528
+ os.path.exists(training_args.output_dir)
529
+ and os.listdir(training_args.output_dir)
530
+ and training_args.do_train
531
+ and not training_args.overwrite_output_dir
532
+ ):
533
+ raise ValueError(
534
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
535
+ "Use --overwrite_output_dir to overcome."
536
+ )
537
+
538
+ # Make one log on every process with the configuration for debugging.
539
+ logging.basicConfig(
540
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
541
+ datefmt="%m/%d/%Y %H:%M:%S",
542
+ level=logging.INFO,
543
+ )
544
+ # Setup logging, we only want one process per machine to log things on the screen.
545
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
546
+ if jax.process_index() == 0:
547
+ transformers.utils.logging.set_verbosity_info()
548
+ else:
549
+ transformers.utils.logging.set_verbosity_error()
550
+
551
+ # Set the verbosity to info of the Transformers logger (on main process only):
552
+ logger.info(f"Training/evaluation parameters {training_args}")
553
+
554
+ if model_args.tokenizer_name:
555
+ tokenizer = AutoTokenizer.from_pretrained(
556
+ model_args.tokenizer_name,
557
+ cache_dir=model_args.cache_dir,
558
+ use_fast=model_args.use_fast_tokenizer
559
+ )
560
+ elif model_args.text_model_name_or_path:
561
+ tokenizer = AutoTokenizer.from_pretrained(
562
+ model_args.text_model_name_or_path,
563
+ cache_dir=model_args.cache_dir,
564
+ use_fast=model_args.use_fast_tokenizer,
565
+ )
566
+ else:
567
+ raise ValueError(
568
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
569
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
570
+ )
571
+
572
+
573
+ if args.run_from_checkpoint is not None:
574
+ with open(f"{args.run_from_checkpoint}/config.json", "r") as fp:
575
+ config_dict = json.load(fp)
576
+ config_dict["vision_config"]["model_type"] = "clip"
577
+ config = HybridCLIPConfig(**config_dict)
578
+ model = FlaxHybridCLIP.from_pretrained(
579
+ args.run_from_checkpoint,
580
+ seed=training_args.seed,
581
+ dtype=getattr(jnp, model_args.dtype),
582
+ config=config,
583
+ freeze_backbones=args.freeze_backbones
584
+ )
585
+ else:
586
+
587
+ model = FlaxHybridCLIP.from_text_vision_pretrained(
588
+ model_args.text_model_name_or_path,
589
+ model_args.vision_model_name_or_path,
590
+ seed=training_args.seed,
591
+ dtype=getattr(jnp, model_args.dtype),
592
+ text_from_pt=False,
593
+ vision_from_pt=model_args.from_pt,
594
+ freeze_backbones=args.freeze_backbones
595
+ )
596
+ config = model.config
597
+ # set seed for torch dataloaders
598
+ set_seed(training_args.seed)
599
+
600
+ # Initialize torchvision transforms and jit them for faster processing
601
+ train_preprocess = Transform(config.vision_config.image_size, augment=True)
602
+ train_preprocess = torch.jit.script(train_preprocess)
603
+
604
+ val_preprocess = Transform(config.vision_config.image_size)
605
+ val_preprocess = torch.jit.script(val_preprocess)
606
+
607
+ # Initialize the image-text dataset
608
+ train_dataset = ImageTextDataset(
609
+ data_args.data_dir,
610
+ data_args.train_file,
611
+ captions_per_image=-1,
612
+ transform=train_preprocess,
613
+ seed=training_args.seed,
614
+ )
615
+
616
+ eval_dataset = ImageTextDataset(
617
+ data_args.data_dir,
618
+ data_args.validation_file,
619
+ captions_per_image=-1,
620
+ transform=val_preprocess,
621
+ seed=training_args.seed,
622
+ )
623
+
624
+ # Store some constant
625
+ num_epochs = int(training_args.num_train_epochs)
626
+ train_batch_size = (
627
+ int(training_args.per_device_train_batch_size) * jax.device_count()
628
+ )
629
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
630
+ steps_per_epoch = len(train_dataset) // train_batch_size
631
+ total_train_steps = steps_per_epoch * num_epochs
632
+
633
+ # Use the collate function to tokenize the text and convert the processed images to numpy
634
+ def collate_fn(examples):
635
+ pixel_values = (
636
+ torch.stack([example[0] for example in examples])
637
+ .permute(0, 2, 3, 1)
638
+ .numpy()
639
+ )
640
+ captions = [example[1] for example in examples]
641
+ inputs = tokenizer(
642
+ captions,
643
+ max_length=data_args.max_seq_length,
644
+ padding="max_length",
645
+ truncation=True,
646
+ return_tensors="np",
647
+ )
648
+
649
+ batch = {
650
+ "pixel_values": pixel_values,
651
+ "input_ids": inputs["input_ids"],
652
+ "attention_mask": inputs["attention_mask"],
653
+ }
654
+
655
+ return batch
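+ # Note: the permute(0, 2, 3, 1) above converts the torchvision NCHW image tensors into the
+ # channels-last (NHWC) numpy arrays that this script feeds to the Flax vision encoder.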
656
+
657
+ # Create data loaders
658
+ train_loader = torch.utils.data.DataLoader(
659
+ train_dataset,
660
+ batch_size=train_batch_size,
661
+ shuffle=True,
662
+ num_workers=data_args.preprocessing_num_workers,
663
+ #persistent_workers=True,
664
+ drop_last=True,
665
+ collate_fn=collate_fn,
666
+ )
667
+
668
+ eval_loader = torch.utils.data.DataLoader(
669
+ eval_dataset,
670
+ batch_size=eval_batch_size,
671
+ shuffle=False,
672
+ num_workers=data_args.preprocessing_num_workers,
673
+ #persistent_workers=True,
674
+ drop_last=True,
675
+ collate_fn=collate_fn,
676
+ )
677
+
678
+ # Enable tensorboard only on the master node
679
+ if has_tensorboard and jax.process_index() == 0:
680
+ summary_writer = SummaryWriter(
681
+ log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()
682
+ )
683
+
684
+ # Enable wandb
685
+ if jax.process_index() == 0 and args.log_wandb:
686
+ try:
687
+ wandb.init(
688
+ name=args.exp_name,
689
+ entity="galuh",
690
+ project="clip-indonesian",
691
+ sync_tensorboard=True
692
+ )
693
+ wandb.config.update(training_args)
694
+ wandb.config.update(model_args)
695
+ wandb.config.update(data_args)
696
+ except ImportError as e:
697
+ print(e)
698
+
699
+ # Initialize our training
700
+ rng = jax.random.PRNGKey(training_args.seed)
701
+ rng, dropout_rng = jax.random.split(rng)
702
+
703
+ # Create learning rate schedule
704
+ if training_args.warmup_steps:
705
+ warmup_steps = training_args.warmup_steps
706
+ elif training_args.warmup_ratio:
707
+ warmup_steps = int(training_args.warmup_ratio * total_train_steps)
708
+ else:
709
+ raise RuntimeError(
710
+ "You have to specify either the warmup_steps or warmup_ratio CLI parameter"
711
+ )
712
+
713
+ decay_lr_schedule_fn = create_learning_rate_fn(
714
+ len(train_dataset),
715
+ train_batch_size,
716
+ training_args.num_train_epochs,
717
+ warmup_steps,
718
+ training_args.learning_rate,
719
+ linear=False, # set False to activate cosine annealing
720
+ )
721
+
722
+ # create adam optimizer
723
+ # optimizer = optax.adamw(
724
+ # learning_rate=decay_lr_schedule_fn,
725
+ # b1=training_args.adam_beta1,
726
+ # b2=training_args.adam_beta2,
727
+ # eps=training_args.adam_epsilon,
728
+ # weight_decay=training_args.weight_decay,
729
+ # )
730
+
731
+ optimizer = optax.chain(
732
+ optax.adaptive_grad_clip(0.01, eps=0.001),
733
+ optax.scale_by_belief(),
734
+ optax.scale_by_schedule(decay_lr_schedule_fn),
735
+ optax.scale(-1.0),
736
+ )
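+ # The optax.chain above composes the update from adaptive gradient clipping
+ # (adaptive_grad_clip, clip factor 0.01), AdaBelief-style preconditioning (scale_by_belief),
+ # multiplication by the warmup/cosine schedule (scale_by_schedule), and a final scale(-1.0)
+ # so parameters move in the descent direction; it replaces the commented-out AdamW above.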
737
+
738
+ '''optimizer = optax.adafactor(
739
+ learning_rate=decay_lr_schedule_fn,
740
+ )'''
741
+
742
+ # Setup train state
743
+ state = TrainState.create(
744
+ apply_fn=model.__call__,
745
+ params=model.params,
746
+ tx=optimizer,
747
+ dropout_rng=dropout_rng,
748
+ )
749
+
750
+ def cross_entropy(logits, axis):
751
+ logprobs = jax.nn.log_softmax(logits, axis=axis)
752
+ nll = jnp.diag(logprobs)
753
+ ce = -jnp.mean(nll)
754
+ return ce
755
+
756
+ def clip_loss(similarity):
757
+ loss = (
758
+ cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)
759
+ ) / 2
760
+ return loss
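+ # clip_loss implements the symmetric contrastive objective: for a batch of N matched pairs,
+ # `similarity` is the N x N logit matrix whose diagonal holds the matching image/text pairs,
+ # and the cross-entropy of that diagonal is taken along both axes (both retrieval directions)
+ # and averaged. Illustrative sanity check (not part of the training path):
+ # clip_loss(jnp.zeros((4, 4))) equals log(4) ~= 1.386, while clip_loss(10.0 * jnp.eye(4))
+ # is close to 0.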
761
+
762
+ # Define gradient update step fn
763
+ def train_step(state, batch):
764
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
765
+
766
+ def compute_loss(params):
767
+ logits = state.apply_fn(
768
+ **batch, params=params, dropout_rng=dropout_rng, train=True
769
+ )[0]
770
+ loss = clip_loss(logits)
771
+ return loss
772
+
773
+ grad_fn = jax.value_and_grad(compute_loss)
774
+ loss, grad = grad_fn(state.params)
775
+ grad = jax.lax.pmean(grad, "batch")
776
+
777
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
778
+
779
+ metrics = {
780
+ "loss": loss,
781
+ "learning_rate": decay_lr_schedule_fn(state.step),
782
+ }
783
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
784
+
785
+ return new_state, metrics
786
+
787
+ # Define eval fn
788
+ def eval_step(params, batch):
789
+ logits = model(**batch, params=params, train=False)[0]
790
+ loss = clip_loss(logits)
791
+
792
+ # summarize metrics
793
+ metrics = {"loss": loss}
794
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
795
+ return metrics
796
+
797
+ # Create parallel version of the train and eval step
798
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
799
+ p_eval_step = jax.pmap(eval_step, "batch")
800
+
801
+ # Replicate the train state on each device
802
+ state = state.replicate()
803
+
804
+ logger.info("***** Running training *****")
805
+ logger.info(f" TPU = {jax.device_count()}")
806
+ logger.info(f" Num examples = {len(train_dataset)}")
807
+ logger.info(f" Num Epochs = {num_epochs}")
808
+ logger.info(
809
+ f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
810
+ )
811
+ logger.info(
812
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
813
+ )
814
+ logger.info(f" Total optimization steps = {total_train_steps}")
815
+ logger.info(f" Total warmup steps = {warmup_steps}")
816
+
817
+ train_time = 0
818
+ # Create sampling rng
819
+ rng, input_rng = jax.random.split(rng)
820
+
821
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
822
+ for epoch in epochs:
823
+ # ======================== Training ================================
824
+ train_start = time.time()
825
+
826
+ # Create sampling rng
827
+ rng, input_rng = jax.random.split(rng)
828
+ train_metrics = []
829
+
830
+ num_train_samples = len(train_dataset)
831
+
832
+ steps_per_epoch = len(train_dataset) // train_batch_size
833
+ train_step_progress_bar = tqdm(
834
+ total=steps_per_epoch, desc="Training...", position=1, leave=False
835
+ )
836
+ # train
837
+ for step, batch in enumerate(train_loader):
838
+ batch = shard(batch)
839
+ state, train_metric = p_train_step(state, batch)
840
+ train_metrics.append(train_metric)
841
+
842
+ train_step_progress_bar.update(1)
843
+
844
+ cur_step = epoch * (num_train_samples // train_batch_size) + step + 1
845
+
846
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
847
+ train_time += time.time() - train_start
848
+ train_metric = unreplicate(train_metric)
849
+
850
+ # Save tensorboard metrics
851
+ if has_tensorboard and jax.process_index() == 0:
852
+ write_train_metric(
853
+ summary_writer, train_metrics, train_time, cur_step
854
+ )
855
+
856
+ # Save wandb metrics
857
+ if args.log_wandb and jax.process_index() == 0:
858
+ #_metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
859
+ #_metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
860
+ _metrics = {f'train_{k}': jax.device_get(v) for k,v in train_metric.items()}
861
+ wandb.log({"train_step":cur_step, **_metrics}, commit=True)
862
+
863
+ epochs.write(
864
+ f"Log at Step: {cur_step} (Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
865
+ )
866
+
867
+ logging.info("Emptying train metrics")
868
+
869
+ del train_metric
870
+ del train_metrics
871
+ train_metrics = []
872
+
873
+ gc.collect()
874
+ torch.cuda.empty_cache()
875
+
876
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
877
+ # ======================== Evaluating ==============================
878
+ num_eval_samples = len(eval_dataset)
879
+ eval_metrics = []
880
+ eval_steps = len(eval_dataset) // eval_batch_size
881
+ eval_step_progress_bar = tqdm(
882
+ total=eval_steps, desc="Evaluating...", position=2, leave=False
883
+ )
884
+ for batch in eval_loader:
885
+ # Model forward
886
+ batch = shard(batch)
887
+ metrics = p_eval_step(state.params, batch)
888
+ eval_metrics.append(metrics)
889
+
890
+ eval_step_progress_bar.update(1)
891
+
892
+ # normalize eval metrics
893
+ eval_metrics = get_metrics(eval_metrics)
894
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
895
+
896
+ # Print metrics and update progress bar
897
+ desc = f"Eval at Step: {cur_step} (Loss: {eval_metrics['loss']})"
898
+ epochs.write(desc)
899
+ epochs.desc = desc
900
+
901
+ # Save tfboard eval
902
+ if has_tensorboard and jax.process_index() == 0:
903
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
904
+
905
+ # Save eval wandb
906
+ if args.log_wandb and jax.process_index() == 0:
907
+ #_metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
908
+ _metrics = {f'eval_{k}': jax.device_get(v) for k,v in eval_metrics.items()}
909
+ wandb.log({"eval_step":cur_step, **_metrics})
910
+
911
+ logging.info("Emptying eval metrics")
912
+ del eval_metrics
913
+
914
+ eval_metrics = []
915
+
916
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
917
+ # save a checkpoint every save_steps steps and push it to the hub
918
+ if jax.process_index() == 0:
919
+ # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
920
+ # model.save_pretrained(
921
+ # training_args.output_dir,
922
+ # params=params,
923
+ # push_to_hub=training_args.push_to_hub,
924
+ # commit_message=f"Saving weights and logs of step {cur_step}",
925
+ # )
926
+ save_model_checkpoint(
927
+ model,
928
+ training_args.output_dir,
929
+ state,
930
+ logger,
931
+ training_args.push_to_hub_organization,
932
+ with_opt=True,
933
+ push_to_hub=training_args.push_to_hub,
934
+ overwrite=True,
935
+ )
936
+ # if model_args.save_optimizer:
937
+ # # this saves full state including optimizer
938
+ # save_checkpoint(training_args.output_dir, state, state.step, keep=training_args.save_total_limit, overwrite=True)
939
+ if training_args.save_total_limit is not None:
940
+ rotate_checkpoints(
941
+ training_args.output_dir,
942
+ training_args.save_total_limit,
943
+ logger,
944
+ )
945
+
946
+ train_step_progress_bar.close() #check
947
+
948
+ '''# save checkpoint after each epoch and push checkpoint to the hub
949
+ if jax.process_index() == 0:
950
+ params = jax.device_get(unreplicate(state.params))
951
+ model.save_pretrained(
952
+ training_args.output_dir + f"/{epoch+1}/",
953
+ params=params,
954
+ push_to_hub=training_args.push_to_hub,
955
+ commit_message=f"Saving weights and logs of epoch {epoch+1}",
956
+ )'''
957
+
958
+ # save model after training is over
959
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
960
+ model.save_pretrained(
961
+ training_args.output_dir,
962
+ params=params,
963
+ push_to_hub=training_args.push_to_hub,
964
+ commit_message="Add final model",
965
+ )
966
+
967
+
968
+ if __name__ == "__main__":
969
+ main()
970
+
hybrid_clip/run_hybrid_clip_backup_2.py ADDED
@@ -0,0 +1,971 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training a CLIP-like dual encoder model using text and vision encoders in the library.
18
+
19
+ The script can be used to train CLIP-like models for languages other than English by using
20
+ a text encoder pre-trained in the desired language. Currently this script supports the following vision
21
+ and text models:
22
+ Vision models: ViT(https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
23
+ Text models: BERT, ROBERTa (https://huggingface.co/models?filter=masked-lm)
24
+ """
25
+
26
+ import json
27
+ import logging
28
+ import os
29
+ import sys
30
+ import time
31
+ import numpy as np
32
+ from dataclasses import dataclass, field
33
+ from pathlib import Path
34
+ from typing import Callable, Optional
35
+ import shutil
36
+ import gc
37
+
38
+ try:
39
+ from dotenv import load_dotenv
40
+ load_dotenv("../.env")
41
+ except:
42
+ print("Couldn't find ../.env file")
43
+
44
+ import wandb
45
+ from transformers.file_utils import PushToHubMixin
46
+
47
+
48
+ import torch
49
+ from torchvision.datasets import VisionDataset
50
+ from torchvision.io import ImageReadMode, read_image
51
+ from torchvision.transforms import (
52
+ CenterCrop,
53
+ ConvertImageDtype,
54
+ Normalize,
55
+ Resize,
56
+ ColorJitter,
57
+ RandomHorizontalFlip,
58
+ RandomRotation,
59
+ RandomCrop,
60
+ RandomAffine,
61
+ RandomPerspective,
62
+ RandomAutocontrast,
63
+ RandomEqualize,
64
+ )
65
+ from torchvision.transforms.functional import InterpolationMode
66
+ from tqdm import tqdm
67
+
68
+ import jax
69
+ import jax.numpy as jnp
70
+ import optax
71
+ import transformers
72
+ from flax import jax_utils
73
+ from flax.jax_utils import unreplicate
74
+ from flax.training import train_state
75
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
76
+ from modeling_hybrid_clip import FlaxHybridCLIP
77
+ from configuration_hybrid_clip import HybridCLIPConfig
78
+ from transformers import (
79
+ AutoTokenizer,
80
+ HfArgumentParser,
81
+ TrainingArguments,
82
+ is_tensorboard_available,
83
+ set_seed,
84
+ )
85
+ from numpy.random import default_rng
86
+ from flax.serialization import to_bytes, from_bytes
87
+
88
+ logger = logging.getLogger(__name__)
89
+
90
+ def mb_item(x):
91
+ return x.item() if hasattr(x, "item") else x
92
+
93
+ # checkpoint functions
94
+ def save_model_checkpoint(
95
+ model,
96
+ save_dir,
97
+ state,
98
+ logger,
99
+ organization,
100
+ with_opt: bool = False,
101
+ push_to_hub: bool = False,
102
+ overwrite=False,
103
+ **kwargs,
104
+ ):
105
+ state = jax_utils.unreplicate(state)
106
+ #params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
107
+ logger.info(f"Saving Checkpoint in {save_dir}")
108
+ ckpt_save_dir = f"{save_dir}/ckpt-{mb_item(state.step)-1}"
109
+ if os.path.exists(ckpt_save_dir) and not overwrite:
110
+ logger.info("checkpoint exists, skipping overwrite")
111
+ else:
112
+ model.save_pretrained(
113
+ ckpt_save_dir, params=state.params, push_to_hub=False, **kwargs
114
+ )
115
+ if with_opt:
116
+ with open(os.path.join(ckpt_save_dir, "opt_state.msgpack"), "wb") as f:
117
+ f.write(to_bytes(state.opt_state))
118
+ with open(os.path.join(ckpt_save_dir, "training_state.json"), "w") as f:
119
+ json.dump({"step": state.step.item()}, f)
120
+
121
+ logger.info("checkpoint saved")
122
+
123
+ if push_to_hub:
124
+ repo_name = Path(save_dir).name
125
+ repo_url = PushToHubMixin._get_repo_url_from_name(
126
+ repo_name, organization=organization, private=False, use_auth_token=True
127
+ )
128
+ repo = PushToHubMixin._create_or_get_repo(
129
+ save_dir,
130
+ repo_url=repo_url,
131
+ organization=organization,
132
+ use_auth_token=True,
133
+ )
134
+ commit_message = f"Saving weights and logs at step {mb_item(state.step)-1}"
135
+ url = PushToHubMixin._push_to_hub(repo=repo, commit_message=commit_message)
136
+ logger.info(f"Model pushed to the hub in this commit: {url}")
137
+
138
+
139
+ def restore_model_checkpoint(save_dir, state, logger):
140
+ logger.info(f"Restoring checkpoint from {save_dir}.")
141
+ with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
142
+ params = from_bytes(state.params, f.read())
143
+
144
+ with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
145
+ opt_state = from_bytes(state.opt_state, f.read())
146
+
147
+ with open(os.path.join(save_dir, "training_state.json"), "r") as f:
148
+ training_state = json.load(f)
149
+ step = training_state["step"]
150
+
151
+ logger.info("checkpoint restored")
152
+ # return state.replace(step=step, params=params, opt_state=opt_state), step
153
+ return params, opt_state, step
154
+
155
+
156
+ def rotate_checkpoints(ckpt_dir: str, save_total_limit: int, logger):
157
+ "Removes older checkpoints so that `save_total_limit` checkpoints are kept"
158
+ # TODO: what to remove is decided using step number only, we might want to improve that
159
+ ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
160
+ # sort checkpoints by step
161
+ ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))
162
+ ckpts_to_delete = ckpts_sorted[:-save_total_limit]
163
+ for ckpt in ckpts_to_delete:
164
+ logger.info(
165
+ f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})"
166
+ )
167
+ shutil.rmtree(ckpt)
168
+
169
+ # Cache the result
170
+ has_tensorboard = is_tensorboard_available()
171
+ if has_tensorboard:
172
+ try:
173
+ from flax.metrics.tensorboard import SummaryWriter
174
+ except ImportError as ie:
175
+ has_tensorboard = False
176
+ print(
177
+ f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
178
+ )
179
+
180
+ else:
181
+ print(
182
+ "Unable to display metrics through TensorBoard because the package is not installed: "
183
+ "Please run pip install tensorboard to enable."
184
+ )
185
+
186
+
187
+ @dataclass
188
+ class ModelArguments:
189
+ """
190
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
191
+ """
192
+
193
+ text_model_name_or_path: str = field(
194
+ metadata={
195
+ "help": "The text model checkpoint for weights initialization."
196
+ "Don't set if you want to train a model from scratch."
197
+ },
198
+ )
199
+ vision_model_name_or_path: str = field(
200
+ metadata={
201
+ "help": "The vision model checkpoint for weights initialization."
202
+ "Don't set if you want to train a model from scratch."
203
+ },
204
+ )
205
+ from_pt: bool = field(
206
+ default=True,
207
+ metadata={
208
+ "help": "whether to load the text and vision model using PyTorch checkpoints."
209
+ },
210
+ )
211
+ config_name: Optional[str] = field(
212
+ default=None,
213
+ metadata={
214
+ "help": "Pretrained config name or path if not the same as model_name"
215
+ },
216
+ )
217
+ tokenizer_name: Optional[str] = field(
218
+ default=None,
219
+ metadata={
220
+ "help": "Pretrained tokenizer name or path if not the same as model_name"
221
+ },
222
+ )
223
+ cache_dir: Optional[str] = field(
224
+ default=None,
225
+ metadata={
226
+ "help": "Where do you want to store the pretrained models downloaded from s3"
227
+ },
228
+ )
229
+ use_fast_tokenizer: bool = field(
230
+ default=True,
231
+ metadata={
232
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
233
+ },
234
+ )
235
+ dtype: Optional[str] = field(
236
+ default="float32",
237
+ metadata={
238
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
239
+ },
240
+ )
241
+
242
+
243
+ @dataclass
244
+ class DataTrainingArguments:
245
+ """
246
+ Arguments pertaining to what data we are going to input our model for training and eval.
247
+ """
248
+
249
+ data_dir: Optional[str] = field(
250
+ default=None, metadata={"help": "The data directory containing input files."}
251
+ )
252
+ train_file: Optional[str] = field(
253
+ default=None,
254
+ metadata={"help": "The input training data file (a jsonlines file)."},
255
+ )
256
+ validation_file: Optional[str] = field(
257
+ default=None,
258
+ metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
259
+ )
260
+ max_seq_length: Optional[int] = field(
261
+ default=72,
262
+ metadata={
263
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
264
+ "than this will be truncated, sequences shorter will be padded."
265
+ },
266
+ )
267
+ max_train_samples: Optional[int] = field(
268
+ default=None,
269
+ metadata={
270
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
271
+ "value if set."
272
+ },
273
+ )
274
+ max_eval_samples: Optional[int] = field(
275
+ default=None,
276
+ metadata={
277
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
278
+ "value if set."
279
+ },
280
+ )
281
+ overwrite_cache: bool = field(
282
+ default=False,
283
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
284
+ )
289
+ preprocessing_num_workers: Optional[int] = field(
290
+ default=None,
291
+ metadata={"help": "The number of processes to use for the preprocessing."},
292
+ )
293
+
294
+ def __post_init__(self):
295
+ if self.train_file is None and self.validation_file is None:
296
+ raise ValueError(
297
+ "Need either a dataset name or a training/validation file."
298
+ )
299
+ else:
300
+ if self.train_file is not None:
301
+ extension = self.train_file.split(".")[-1]
302
+ assert extension == "json", "`train_file` should be a json file."
303
+ if self.validation_file is not None:
304
+ extension = self.validation_file.split(".")[-1]
305
+ assert extension == "json", "`validation_file` should be a json file."
306
+
307
+
308
+ # We use torchvision for faster image pre-processing.
309
+ # We need to ensure faster processing speed as it can become a bottleneck on TPU
310
+ class Transform(torch.nn.Module):
311
+ def __init__(self, image_size, augment=False):
312
+ super().__init__()
313
+ if not augment:
314
+ self.transforms = torch.nn.Sequential(
315
+ Resize([image_size], interpolation=InterpolationMode.BICUBIC),
316
+ CenterCrop(image_size),
317
+ ConvertImageDtype(torch.float),
318
+ Normalize(
319
+ (0.48145466, 0.4578275, 0.40821073),
320
+ (0.26862954, 0.26130258, 0.27577711),
321
+ ),
322
+ )
323
+ else:
324
+ self.transforms = torch.nn.Sequential(
325
+ Resize([image_size], interpolation=InterpolationMode.BICUBIC),
326
+ # CenterCrop(image_size),
327
+ RandomCrop([image_size], pad_if_needed=True, padding_mode="edge"),
328
+ ColorJitter(hue=0.1),
329
+ RandomHorizontalFlip(),
330
+ # RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=128),
331
+ RandomAffine(
332
+ degrees=15,
333
+ translate=(0.1, 0.1),
334
+ scale=(0.8, 1.2),
335
+ shear=(-15, 15, -15, 15),
336
+ interpolation=InterpolationMode.BILINEAR,
337
+ fill=127,
338
+ ),
339
+ RandomPerspective(
340
+ distortion_scale=0.3,
341
+ p=0.3,
342
+ interpolation=InterpolationMode.BILINEAR,
343
+ fill=127,
344
+ ),
345
+ RandomAutocontrast(p=0.3),
346
+ RandomEqualize(p=0.3),
347
+ ConvertImageDtype(torch.float),
348
+ Normalize(
349
+ (0.48145466, 0.4578275, 0.40821073),
350
+ (0.26862954, 0.26130258, 0.27577711),
351
+ ),
352
+ )
353
+
354
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
355
+ with torch.no_grad():
356
+ x = self.transforms(x)
357
+ return x
358
+
359
+
360
+ class ImageTextDataset(VisionDataset):
361
+ """
362
+ Dataset for loading image-text data for tasks like CLIP training and image captioning.
363
+
364
+ Args:
365
+ root: (string): The root path where the dataset is stored
366
+ file_path: (string): Path to the file containing the image_paths and associated captions.
367
+ The expected format is jsonlines, where each line is a JSON object containing two keys:
368
+ `image_path`: The path to the image.
369
+ `captions`: An `array` of captions.
370
+ transform (callable, optional): A function/transform that takes in a PIL image
371
+ and returns a transformed version. E.g., ``transforms.ToTensor``
372
+ target_transform (callable, optional): A function/transform that takes in the
373
+ target and transforms it.
374
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
375
+ and returns a transformed version.
376
+ """
377
+
378
+ def __init__(
379
+ self,
380
+ root: str,
381
+ file_path: str,
382
+ captions_per_image=-1,
383
+ transform: Optional[Callable] = None,
384
+ target_transform: Optional[Callable] = None,
385
+ transforms: Optional[Callable] = None,
386
+ seed=42,
387
+ ):
388
+ super().__init__(root, transforms, transform, target_transform)
389
+ with open(file_path, "r") as f:
390
+ #examples = [json.loads(line) for line in f.readlines()]
391
+ examples = np.array([json.loads(line) for line in f.readlines()])
392
+
393
+ self.rand_generator = default_rng(seed)
394
+
395
+ self.captions = []
396
+ self.image_paths = []
397
+
398
+ for example in examples:
399
+ if captions_per_image <= -1:
400
+ self.captions.append(example["captions"])
401
+ elif captions_per_image > 0:
402
+ self.captions.append(example["captions"][:captions_per_image])
403
+ else:
404
+ raise ValueError("captions per image cannot be zero")
405
+
406
+ self.image_paths.append(example["image_path"])
407
+
408
+ def _load_image(self, idx: int):
409
+ path = self.image_paths[idx]
410
+ im = read_image(path, mode=ImageReadMode.RGB)
411
+ return im
412
+
413
+ def _load_target(self, idx):
414
+ return self.rand_generator.choice(self.captions[idx])
415
+ # if len(self.captions[idx]) > 1:
416
+ # caption_idx = np.random.randint(0, len(self.captions[idx]))
417
+ # else:
418
+ # caption_idx = 0
419
+ # return self.captions[idx][caption_idx]
420
+
421
+ def __getitem__(self, index: int):
422
+ image = self._load_image(index)
423
+ target = self._load_target(index)
424
+
425
+ if self.transforms is not None:
426
+ image, target = self.transforms(image, target)
427
+
428
+ return image, target
429
+
430
+ def __len__(self) -> int:
431
+ return len(self.captions)
432
+
433
+
434
+ class TrainState(train_state.TrainState):
435
+ dropout_rng: jnp.ndarray
436
+
437
+ def replicate(self):
438
+ return jax_utils.replicate(self).replace(
439
+ dropout_rng=shard_prng_key(self.dropout_rng)
440
+ )
441
+
442
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
443
+ summary_writer.scalar("train_time", train_time, step)
444
+
445
+ train_metrics = get_metrics(train_metrics)
446
+ for key, vals in train_metrics.items():
447
+ tag = f"train_{key}"
448
+ for i, val in enumerate(vals):
449
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
450
+
451
+
452
+ def write_eval_metric(summary_writer, eval_metrics, step):
453
+ for metric_name, value in eval_metrics.items():
454
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
455
+
456
+
457
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
458
+ summary_writer.scalar("train_time", train_time, step)
459
+
460
+ train_metrics = get_metrics(train_metrics)
461
+ for key, vals in train_metrics.items():
462
+ tag = f"train_{key}"
463
+ for i, val in enumerate(vals):
464
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
465
+
466
+ for metric_name, value in eval_metrics.items():
467
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
468
+
469
+
470
+ def create_learning_rate_fn(
471
+ train_ds_size: int,
472
+ train_batch_size: int,
473
+ num_train_epochs: int,
474
+ num_warmup_steps: int,
475
+ learning_rate: float,
476
+ linear=False,
477
+ ) -> Callable[[int], jnp.array]:
478
+ """Returns a linear warmup, linear_decay learning rate function."""
479
+ steps_per_epoch = train_ds_size // train_batch_size
480
+ num_train_steps = steps_per_epoch * num_train_epochs
481
+ if linear:
482
+ warmup_fn = optax.linear_schedule(
483
+ init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
484
+ )
485
+ decay_fn = optax.linear_schedule(
486
+ init_value=learning_rate,
487
+ end_value=0,
488
+ transition_steps=num_train_steps - num_warmup_steps,
489
+ )
490
+ else:
491
+ warmup_fn = optax.linear_schedule(
492
+ init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
493
+ )
494
+ decay_fn = optax.cosine_decay_schedule(
495
+ init_value=learning_rate,
496
+ decay_steps=num_train_steps - num_warmup_steps,
497
+ alpha=0.0,
498
+ )
499
+ schedule_fn = optax.join_schedules(
500
+ schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
501
+ )
502
+ return schedule_fn
503
+
504
+
505
+ def main():
506
+ parser = HfArgumentParser(
507
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
508
+ )
509
+ parser.add_argument("--log_wandb", action="store_true")
510
+ parser.add_argument("--freeze_backbones", action="store_true")
511
+ parser.add_argument("--exp_name", type=str, default=None)
512
+ parser.add_argument("--run_from_checkpoint", type=str, default=None)
513
+
514
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
515
+ # If we pass only one argument to the script and it's the path to a json file,
516
+ # let's parse it to get our arguments.
517
+ model_args, data_args, training_args = parser.parse_json_file(
518
+ json_file=os.path.abspath(sys.argv[1])
519
+ )
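+ # Note: the extra argparse flags added above (--log_wandb, --freeze_backbones, ...) are only populated by the CLI branch below; they are not read from the JSON file.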
520
+ else:
521
+ (
522
+ model_args,
523
+ data_args,
524
+ training_args,
525
+ args,
526
+ ) = parser.parse_args_into_dataclasses()
527
+
528
+ if (
529
+ os.path.exists(training_args.output_dir)
530
+ and os.listdir(training_args.output_dir)
531
+ and training_args.do_train
532
+ and not training_args.overwrite_output_dir
533
+ ):
534
+ raise ValueError(
535
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
536
+ " Use --overwrite_output_dir to overwrite it."
537
+ )
538
+
539
+ # Set up basic logging on every process for debugging.
540
+ logging.basicConfig(
541
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
542
+ datefmt="%m/%d/%Y %H:%M:%S",
543
+ level=logging.INFO,
544
+ )
545
+ # Set up logging; we only want one process per machine to log to the screen.
546
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
547
+ if jax.process_index() == 0:
548
+ transformers.utils.logging.set_verbosity_info()
549
+ else:
550
+ transformers.utils.logging.set_verbosity_error()
551
+
552
+ # Log the training/evaluation parameters (main process only):
553
+ logger.info(f"Training/evaluation parameters {training_args}")
554
+
555
+ if model_args.tokenizer_name:
556
+ tokenizer = AutoTokenizer.from_pretrained(
557
+ model_args.tokenizer_name,
558
+ cache_dir=model_args.cache_dir,
559
+ use_fast=model_args.use_fast_tokenizer
560
+ )
561
+ elif model_args.text_model_name_or_path:
562
+ tokenizer = AutoTokenizer.from_pretrained(
563
+ model_args.text_model_name_or_path,
564
+ cache_dir=model_args.cache_dir,
565
+ use_fast=model_args.use_fast_tokenizer,
566
+ )
567
+ else:
568
+ raise ValueError(
569
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
570
+ " You can create one with another script, save it, and then load it here with --tokenizer_name."
571
+ )
572
+
573
+
574
+ if args.run_from_checkpoint is not None:
575
+ with open(f"{args.run_from_checkpoint}/config.json", "r") as fp:
576
+ config_dict = json.load(fp)
577
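+ # HybridCLIPConfig resolves the vision sub-config from model_type, so restore it to "clip" for the vision encoder stored in the checkpoint.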
+ config_dict["vision_config"]["model_type"] = "clip"
578
+ config = HybridCLIPConfig(**config_dict)
579
+ model = FlaxHybridCLIP.from_pretrained(
580
+ args.run_from_checkpoint,
581
+ seed=training_args.seed,
582
+ dtype=getattr(jnp, model_args.dtype),
583
+ config=config,
584
+ freeze_backbones=args.freeze_backbones
585
+ )
586
+ else:
587
+
588
+ model = FlaxHybridCLIP.from_text_vision_pretrained(
589
+ model_args.text_model_name_or_path,
590
+ model_args.vision_model_name_or_path,
591
+ seed=training_args.seed,
592
+ dtype=getattr(jnp, model_args.dtype),
593
+ text_from_pt=False,
594
+ vision_from_pt=model_args.from_pt,
595
+ freeze_backbones=args.freeze_backbones
596
+ )
597
+ config = model.config
598
+ # set seed for reproducibility (python, numpy and torch, including the dataloaders)
599
+ set_seed(training_args.seed)
600
+
601
+ # Initialize torchvision transforms and jit them for faster processing
602
+ train_preprocess = Transform(config.vision_config.image_size, augment=True)
603
+ train_preprocess = torch.jit.script(train_preprocess)
604
+
605
+ val_preprocess = Transform(config.vision_config.image_size)
606
+ val_preprocess = torch.jit.script(val_preprocess)
607
+
608
+ # Initialize the image-text dataset
609
+ train_dataset = ImageTextDataset(
610
+ data_args.data_dir,
611
+ data_args.train_file,
612
+ captions_per_image=-1,
613
+ transform=train_preprocess,
614
+ seed=training_args.seed,
615
+ )
616
+
617
+ eval_dataset = ImageTextDataset(
618
+ data_args.data_dir,
619
+ data_args.validation_file,
620
+ captions_per_image=-1,
621
+ transform=val_preprocess,
622
+ seed=training_args.seed,
623
+ )
624
+
625
+ # Store some constants
626
+ num_epochs = int(training_args.num_train_epochs)
627
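+ # The effective batch size is the per-device batch size times the number of JAX devices, since each device processes its own shard.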
+ train_batch_size = (
628
+ int(training_args.per_device_train_batch_size) * jax.device_count()
629
+ )
630
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
631
+ steps_per_epoch = len(train_dataset) // train_batch_size
632
+ total_train_steps = steps_per_epoch * num_epochs
633
+
634
+ # Collate function: tokenize the captions and convert the processed images to numpy arrays
635
+ def collate_fn(examples):
636
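+ # torchvision yields NCHW tensors; permute to channel-last (NHWC), the layout the Flax vision encoder here expects, then convert to numpy.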
+ pixel_values = (
637
+ torch.stack([example[0] for example in examples])
638
+ .permute(0, 2, 3, 1)
639
+ .numpy()
640
+ )
641
+ captions = [example[1] for example in examples]
642
+ inputs = tokenizer(
643
+ captions,
644
+ max_length=data_args.max_seq_length,
645
+ padding="max_length",
646
+ truncation=True,
647
+ return_tensors="np",
648
+ )
649
+
650
+ batch = {
651
+ "pixel_values": pixel_values,
652
+ "input_ids": inputs["input_ids"],
653
+ "attention_mask": inputs["attention_mask"],
654
+ }
655
+
656
+ return batch
657
+
658
+ # Create data loaders
659
+ train_loader = torch.utils.data.DataLoader(
660
+ train_dataset,
661
+ batch_size=train_batch_size,
662
+ shuffle=True,
663
+ num_workers=data_args.preprocessing_num_workers,
664
+ #persistent_workers=True,
665
+ drop_last=True,
666
+ collate_fn=collate_fn,
667
+ )
668
+
669
+ eval_loader = torch.utils.data.DataLoader(
670
+ eval_dataset,
671
+ batch_size=eval_batch_size,
672
+ shuffle=False,
673
+ num_workers=data_args.preprocessing_num_workers,
674
+ #persistent_workers=True,
675
+ drop_last=True,
676
+ collate_fn=collate_fn,
677
+ )
678
+
679
+ # Enable tensorboard only on the master node
680
+ if has_tensorboard and jax.process_index() == 0:
681
+ summary_writer = SummaryWriter(
682
+ log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()
683
+ )
684
+
685
+ # Enable wandb
686
+ if jax.process_index() == 0 and args.log_wandb:
687
+ try:
688
+ wandb.init(
689
+ name=args.exp_name,
690
+ entity="galuh",
691
+ project="clip-indonesian",
692
+ sync_tensorboard=True
693
+ )
694
+ wandb.config.update(training_args)
695
+ wandb.config.update(model_args)
696
+ wandb.config.update(data_args)
697
+ except ImportError as e:
698
+ print(e)
699
+
700
+ # Initialize our training
701
+ rng = jax.random.PRNGKey(training_args.seed)
702
+ rng, dropout_rng = jax.random.split(rng)
703
+
704
+ # Create learning rate schedule
705
+ if training_args.warmup_steps:
706
+ warmup_steps = training_args.warmup_steps
707
+ elif training_args.warmup_ratio:
708
+ warmup_steps = int(training_args.warmup_ratio * total_train_steps)
709
+ else:
710
+ raise RuntimeError(
711
+ "You have to specify either the warmup_steps or warmup_ratio CLI parameter"
712
+ )
713
+
714
+ decay_lr_schedule_fn = create_learning_rate_fn(
715
+ len(train_dataset),
716
+ train_batch_size,
717
+ training_args.num_train_epochs,
718
+ warmup_steps,
719
+ training_args.learning_rate,
720
+ linear=False, # set False to activate cosine annealing
721
+ )
722
+
723
+ # create adam optimizer
724
+ # optimizer = optax.adamw(
725
+ # learning_rate=decay_lr_schedule_fn,
726
+ # b1=training_args.adam_beta1,
727
+ # b2=training_args.adam_beta2,
728
+ # eps=training_args.adam_epsilon,
729
+ # weight_decay=training_args.weight_decay,
730
+ # )
731
+
732
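+ # Replaces AdamW with a chain of adaptive gradient clipping (AGC), AdaBelief-style second-moment scaling, the decay schedule, and a final sign flip so updates descend the loss.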
+ optimizer = optax.chain(
733
+ optax.adaptive_grad_clip(0.01, eps=0.001),
734
+ optax.scale_by_belief(),
735
+ optax.scale_by_schedule(decay_lr_schedule_fn),
736
+ optax.scale(-1.0),
737
+ )
738
+
739
+ '''optimizer = optax.adafactor(
740
+ learning_rate=decay_lr_schedule_fn,
741
+ )'''
742
+
743
+ # Setup train state
744
+ state = TrainState.create(
745
+ apply_fn=model.__call__,
746
+ params=model.params,
747
+ tx=optimizer,
748
+ dropout_rng=dropout_rng,
749
+ )
750
+
751
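+ # Contrastive loss: matching image-text pairs sit on the diagonal of the similarity matrix, so the loss averages the image-to-text and text-to-image cross-entropies over those positives.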
+ def cross_entropy(logits, axis):
752
+ logprobs = jax.nn.log_softmax(logits, axis=axis)
753
+ nll = jnp.diag(logprobs)
754
+ ce = -jnp.mean(nll)
755
+ return ce
756
+
757
+ def clip_loss(similarity):
758
+ loss = (
759
+ cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)
760
+ ) / 2
761
+ return loss
762
+
763
+ # Define gradient update step fn
764
+ def train_step(state, batch):
765
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
766
+
767
+ def compute_loss(params):
768
+ logits = state.apply_fn(
769
+ **batch, params=params, dropout_rng=dropout_rng, train=True
770
+ )[0]
771
+ loss = clip_loss(logits)
772
+ return loss
773
+
774
+ grad_fn = jax.value_and_grad(compute_loss)
775
+ loss, grad = grad_fn(state.params)
776
+ grad = jax.lax.pmean(grad, "batch")
777
+
778
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
779
+
780
+ metrics = {
781
+ "loss": loss,
782
+ "learning_rate": decay_lr_schedule_fn(state.step),
783
+ }
784
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
785
+
786
+ return new_state, metrics
787
+
788
+ # Define eval fn
789
+ def eval_step(params, batch):
790
+ logits = model(**batch, params=params, train=False)[0]
791
+ loss = clip_loss(logits)
792
+
793
+ # summarize metrics
794
+ metrics = {"loss": loss}
795
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
796
+ return metrics
797
+
798
+ # Create parallel version of the train and eval step
799
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
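+ # donate_argnums=(0,) lets XLA reuse the old train state's device buffers for the updated state.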
800
+ p_eval_step = jax.pmap(eval_step, "batch")
801
+
802
+ # Replicate the train state on each device
803
+ state = state.replicate()
804
+
805
+ logger.info("***** Running training *****")
806
+ logger.info(f" Num devices = {jax.device_count()}")
807
+ logger.info(f" Num examples = {len(train_dataset)}")
808
+ logger.info(f" Num Epochs = {num_epochs}")
809
+ logger.info(
810
+ f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
811
+ )
812
+ logger.info(
813
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
814
+ )
815
+ logger.info(f" Total optimization steps = {total_train_steps}")
816
+ logger.info(f" Total warmup steps = {warmup_steps}")
817
+
818
+ train_time = 0
819
+ # Create sampling rng
820
+ rng, input_rng = jax.random.split(rng)
821
+
822
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
823
+ for epoch in epochs:
824
+ # ======================== Training ================================
825
+ train_start = time.time()
826
+
827
+ # Create sampling rng
828
+ rng, input_rng = jax.random.split(rng)
829
+ train_metrics = []
830
+
831
+ num_train_samples = len(train_dataset)
832
+
833
+ steps_per_epoch = len(train_dataset) // train_batch_size
834
+ train_step_progress_bar = tqdm(
835
+ total=steps_per_epoch, desc="Training...", position=1, leave=False
836
+ )
837
+ # train
838
+ for step, batch in enumerate(train_loader):
839
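+ # shard() reshapes the leading batch axis to (num_devices, per_device_batch) so pmap can split the batch across devices.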
+ batch = shard(batch)
840
+ state, train_metric = p_train_step(state, batch)
841
+ train_metrics.append(train_metric)
842
+
843
+ train_step_progress_bar.update(1)
844
+
845
+ cur_step = epoch * (num_train_samples // train_batch_size) + step + 1
846
+
847
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
848
+ train_time += time.time() - train_start
849
+ train_metric = unreplicate(train_metric)
850
+
851
+ # Save tensorboard metrics
852
+ if has_tensorboard and jax.process_index() == 0:
853
+ write_train_metric(
854
+ summary_writer, train_metrics, train_time, cur_step
855
+ )
856
+
857
+ # Save wandb metrics
858
+ if args.log_wandb and jax.process_index() == 0:
859
+ #_metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
861
+ _metrics = {f'train_{k}': jax.device_get(v) for k,v in train_metric.items()}
862
+ wandb.log({"train_step":cur_step, **_metrics}, commit=True)
863
+
864
+ epochs.write(
865
+ f"Log at Step: {cur_step} (Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
866
+ )
867
+
868
+ logger.info("Emptying train metrics")
869
+
870
+ del train_metric
871
+ del train_metrics
872
+ train_metrics = []
873
+
874
+ gc.collect()
875
+ torch.cuda.empty_cache()
876
+
877
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
878
+ # ======================== Evaluating ==============================
879
+ num_eval_samples = len(eval_dataset)
880
+ eval_metrics = []
881
+ eval_steps = len(eval_dataset) // eval_batch_size
882
+ eval_step_progress_bar = tqdm(
883
+ total=eval_steps, desc="Evaluating...", position=2, leave=False
884
+ )
885
+ for batch in eval_loader:
886
+ # Model forward
887
+ batch = shard(batch)
888
+ metrics = p_eval_step(state.params, batch)
889
+ eval_metrics.append(metrics)
890
+
891
+ eval_step_progress_bar.update(1)
892
+
893
+ # normalize eval metrics
894
+ eval_metrics = get_metrics(eval_metrics)
895
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
896
+
897
+ # Print metrics and update progress bar
898
+ desc = f"Eval at Step: {cur_step} (Loss: {eval_metrics['loss']})"
899
+ epochs.write(desc)
900
+ epochs.desc = desc
901
+
902
+ # Save tfboard eval
903
+ if has_tensorboard and jax.process_index() == 0:
904
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
905
+
906
+ # Save eval wandb
907
+ if args.log_wandb and jax.process_index() == 0:
908
+ #_metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
909
+ _metrics = {f'eval_{k}': jax.device_get(v) for k,v in eval_metrics.items()}
910
+ wandb.log({"eval_step":cur_step, **_metrics})
911
+
912
+ logger.info("Emptying eval metrics")
913
+ del eval_metrics
914
+
915
+ eval_metrics = []
916
+
917
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
918
+ # save a checkpoint every save_steps and optionally push it to the hub
919
+ if jax.process_index() == 0:
920
+ # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
921
+ # model.save_pretrained(
922
+ # training_args.output_dir,
923
+ # params=params,
924
+ # push_to_hub=training_args.push_to_hub,
925
+ # commit_message=f"Saving weights and logs of step {cur_step}",
926
+ # )
927
+ save_model_checkpoint(
928
+ model,
929
+ training_args.output_dir,
930
+ state,
931
+ logger,
932
+ training_args.push_to_hub_organization,
933
+ with_opt=True,
934
+ push_to_hub=training_args.push_to_hub,
935
+ overwrite=True,
936
+ )
937
+ # if model_args.save_optimizer:
938
+ # # this saves full state including optimizer
939
+ # save_checkpoint(training_args.output_dir, state, state.step, keep=training_args.save_total_limit, overwrite=True)
940
+ if training_args.save_total_limit is not None:
941
+ rotate_checkpoints(
942
+ training_args.output_dir,
943
+ training_args.save_total_limit,
944
+ logger,
945
+ )
946
+
947
+ train_step_progress_bar.close()
948
+
949
+ '''# save checkpoint after each epoch and push checkpoint to the hub
950
+ if jax.process_index() == 0:
951
+ params = jax.device_get(unreplicate(state.params))
952
+ model.save_pretrained(
953
+ training_args.output_dir + f"/{epoch+1}/",
954
+ params=params,
955
+ push_to_hub=training_args.push_to_hub,
956
+ commit_message=f"Saving weights and logs of epoch {epoch+1}",
957
+ )'''
958
+
959
+ # save model after training is over
960
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
961
+ model.save_pretrained(
962
+ training_args.output_dir,
963
+ params=params,
964
+ push_to_hub=training_args.push_to_hub,
965
+ commit_message="Add final model",
966
+ )
967
+
968
+
969
+ if __name__ == "__main__":
970
+ main()
971
+
hybrid_clip/run_training.sh ADDED
@@ -0,0 +1,31 @@
1
+ #!/bin/bash
2
+
3
+ SCRIPT_DIR=.
4
+ MODEL_DIR=/mnt/disks/data-1/models/training_v4_unfreeze
5
+
6
+ IMAGE_ENCODER="openai/clip-vit-base-patch32"
7
+ TEXT_ENCODER="flax-community/indonesian-roberta-base"
8
+
9
+ python ${SCRIPT_DIR}/run_hybrid_clip.py \
10
+ --output_dir ${MODEL_DIR} \
11
+ --overwrite_output_dir \
12
+ --tokenizer_name=${TEXT_ENCODER} \
13
+ --train_file="../data/train_dataset_v6.json" \
14
+ --validation_file="../data/val_dataset_v6.json" \
15
+ --do_train --do_eval \
16
+ --num_train_epochs="20" --max_seq_length 96 \
17
+ --per_device_train_batch_size="64" \
18
+ --per_device_eval_batch_size="64" \
19
+ --learning_rate="0.00001" --warmup_ratio 0.1 --weight_decay 0.0 \
20
+ --preprocessing_num_workers 16 \
21
+ --exp_name training_v4_unfreeze \
22
+ --text_model_name_or_path=${TEXT_ENCODER} \
23
+ --vision_model_name_or_path=${IMAGE_ENCODER} \
24
+ --eval_steps 500 \
25
+ --logging_steps 100 \
26
+ --save_steps 500 \
27
+ --save_total_limit 5 \
28
+ --log_wandb \
29
+ --run_from_checkpoint="/mnt/disks/data-1/models/training_v4/ckpt-70999" # edit this path to resume from a different checkpoint
30
+ #--freeze_backbones
31
+ #--push_to_hub
hybrid_clip/run_training_backup.sh ADDED
@@ -0,0 +1,30 @@
1
+ #!/bin/bash
2
+
3
+ SCRIPT_DIR=.
4
+ MODEL_DIR=~/models/training_v3_new
5
+
6
+ IMAGE_ENCODER="openai/clip-vit-base-patch32"
7
+ TEXT_ENCODER="indobenchmark/indobert-base-p2"
8
+
9
+ python ${SCRIPT_DIR}/run_hybrid_clip.py \
10
+ --output_dir ${MODEL_DIR} \
11
+ --overwrite_output_dir \
12
+ --tokenizer_name=${TEXT_ENCODER} \
13
+ --train_file="../data/train_dataset_v3.json" \
14
+ --validation_file="../data/val_dataset_v3.json" \
15
+ --do_train --do_eval \
16
+ --num_train_epochs="10" --max_seq_length 96 \
17
+ --per_device_train_batch_size="64" \
18
+ --per_device_eval_batch_size="64" \
19
+ --learning_rate="0.00005" --warmup_ratio 0.1 --weight_decay 0.0 \
20
+ --preprocessing_num_workers 16 \
21
+ --exp_name training_v3 \
22
+ --text_model_name_or_path=${TEXT_ENCODER} \
23
+ --vision_model_name_or_path=${IMAGE_ENCODER} \
24
+ --eval_steps 2500 \
25
+ --logging_steps 200 \
26
+ --save_steps 2500 \
27
+ --save_total_limit 5 \
28
+ --log_wandb \
29
+ --freeze_backbones
30
+ #--push_to_hub
hybrid_clip/run_training_unfreeze.sh ADDED
@@ -0,0 +1,31 @@
1
+ #!/bin/bash
2
+
3
+ SCRIPT_DIR=.
4
+ MODEL_DIR=~/models/training_v3_new_unfreeze
5
+
6
+ IMAGE_ENCODER="openai/clip-vit-base-patch32"
7
+ TEXT_ENCODER="indobenchmark/indobert-base-p2"
8
+
9
+ python ${SCRIPT_DIR}/run_hybrid_clip.py \
10
+ --output_dir ${MODEL_DIR} \
11
+ --overwrite_output_dir \
12
+ --tokenizer_name=${TEXT_ENCODER} \
13
+ --train_file="../data/train_dataset_v3.json" \
14
+ --validation_file="../data/val_dataset_v3.json" \
15
+ --do_train --do_eval \
16
+ --num_train_epochs="10" --max_seq_length 96 \
17
+ --per_device_train_batch_size="64" \
18
+ --per_device_eval_batch_size="64" \
19
+ --learning_rate="0.00005" --warmup_ratio 0.1 --weight_decay 0.0 \
20
+ --preprocessing_num_workers 16 \
21
+ --exp_name training_v3_unfreeze \
22
+ --text_model_name_or_path=${TEXT_ENCODER} \
23
+ --vision_model_name_or_path=${IMAGE_ENCODER} \
24
+ --eval_steps 2500 \
25
+ --logging_steps 200 \
26
+ --save_steps 2500 \
27
+ --save_total_limit 5 \
28
+ --log_wandb \
29
+ --run_from_checkpoint="../../models/training_v3_new/ckpt-42499"
30
+ #--freeze_backbones
31
+ #--push_to_hub
hybrid_clip/run_training_unfreeze_2.sh ADDED
@@ -0,0 +1,31 @@
1
+ #!/bin/bash
2
+
3
+ SCRIPT_DIR=.
4
+ MODEL_DIR=~/models/training_v3_new_unfreeze_2
5
+
6
+ IMAGE_ENCODER="openai/clip-vit-base-patch32"
7
+ TEXT_ENCODER="indobenchmark/indobert-base-p2"
8
+
9
+ python ${SCRIPT_DIR}/run_hybrid_clip.py \
10
+ --output_dir ${MODEL_DIR} \
11
+ --overwrite_output_dir \
12
+ --tokenizer_name=${TEXT_ENCODER} \
13
+ --train_file="../data/train_dataset_v3.json" \
14
+ --validation_file="../data/val_dataset_v3.json" \
15
+ --do_train --do_eval \
16
+ --num_train_epochs="10" --max_seq_length 96 \
17
+ --per_device_train_batch_size="64" \
18
+ --per_device_eval_batch_size="64" \
19
+ --learning_rate="0.00005" --warmup_ratio 0.1 --weight_decay 0.0 \
20
+ --preprocessing_num_workers 16 \
21
+ --exp_name training_v3_unfreeze_2 \
22
+ --text_model_name_or_path=${TEXT_ENCODER} \
23
+ --vision_model_name_or_path=${IMAGE_ENCODER} \
24
+ --eval_steps 2500 \
25
+ --logging_steps 200 \
26
+ --save_steps 2500 \
27
+ --save_total_limit 5 \
28
+ --log_wandb \
29
+ --run_from_checkpoint="../../models/training_v3_new_unfreeze/ckpt-12499"
30
+ #--freeze_backbones
31
+ #--push_to_hub