Commit 4463ade by edugp (1 parent: de5d267)

Format with black and isort and lint with flake8

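The commit is a pure formatting and linting pass; the exact commands are not recorded in the repository. A minimal sketch of how such a pass is typically run over the project root (the flags and target directory are assumptions, only the tool names come from the commit message):

```python
# Hypothetical helper reproducing the formatting/lint pass named in the commit
# message. Tool arguments are assumptions; black and isort rewrite files in
# place, while flake8 only reports remaining issues.
import subprocess


def format_and_lint(repo_root: str = ".") -> None:
    for cmd in (["black", repo_root], ["isort", repo_root], ["flake8", repo_root]):
        subprocess.run(cmd, check=False)


if __name__ == "__main__":
    format_and_lint()
```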
README.md CHANGED
@@ -7,18 +7,22 @@ tags:
 - vit
 ---
 # CLIP-Spanish
+
 CLIP Spanish is a CLIP-like model for Spanish language. It is composed of [BERTIN](https://huggingface.co/bertin-project/bertin-roberta-base-spanish) as a language encoder and the ViT-B/32 image encoder from [CLIP](https://huggingface.co/openai/clip-vit-base-patch32). The model is implemented in [Flax](https://github.com/google/flax), including training scripts (see `training.md`).
 This is part of the [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104), organised by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google.
 
 ## Spanish WIT
+
 We used a subset of 141,230 Spanish captions from the [WIT dataset](https://github.com/google-research-datasets/wit) for training.
 
 ## Team members
+
 - Eduardo González Ponferrada ([edugp](https://huggingface.co/edugp))
 - Manu Romero ([mrm8488](https://huggingface.co/))
 - María Grandury ([mariagrandury](https://huggingface.co/))
 
 ## Useful links
+
 - [Community Week timeline](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104#summary-timeline-calendar-6)
 - [Community Week README](https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/README.md)
 - [Community Week thread](https://discuss.huggingface.co/t/bertin-pretrain-roberta-large-from-scratch-in-spanish/7125)
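The README above describes the BERTIN text encoder paired with CLIP's ViT-B/32 image encoder. A minimal, hedged inference sketch mirroring `test_on_image.py` and the `Transform` pipeline in `run_hybrid_clip.py` from this same commit — the local checkpoint directory (`"./"`) and the example image path are placeholders, not values confirmed by the repository:

```python
# Hedged usage sketch; checkpoint and image paths are placeholders.
import jax
import torch
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer

from modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_pretrained("./")  # assumed local checkpoint directory
tokenizer = AutoTokenizer.from_pretrained("bertin-project/bertin-roberta-base-spanish")

# Same preprocessing as the Transform module in run_hybrid_clip.py.
image_size = model.config.vision_config.image_size
preprocess = torch.nn.Sequential(
    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
    CenterCrop(image_size),
    ConvertImageDtype(torch.float),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
)

image = read_image("example.jpg", mode=ImageReadMode.RGB)  # placeholder image path
pixel_values = torch.stack([preprocess(image)]).permute(0, 2, 3, 1).numpy()  # NHWC
inputs = tokenizer("Fachada del Santuario", return_tensors="np")

outputs = model(
    inputs["input_ids"],
    pixel_values,
    attention_mask=inputs["attention_mask"],
    train=False,
    return_dict=True,
)
# Sigmoid of the image-text logit as a rough match score, as in test_on_image.py.
print(jax.nn.sigmoid(outputs["logits_per_image"])[0][0])
```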
configuration_hybrid_clip.py CHANGED
@@ -3,7 +3,6 @@ import copy
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
-
 logger = logging.get_logger(__name__)
 
 
@@ -64,19 +63,25 @@ class HybridCLIPConfig(PretrainedConfig):
         self.text_config = AutoConfig.for_model(text_model_type, **text_config)
 
         if vision_model_type == "clip":
-            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
+            self.vision_config = AutoConfig.for_model(
+                vision_model_type, **vision_config
+            ).vision_config
         elif vision_model_type == "clip_vision_model":
             from transformers import CLIPVisionConfig
 
             self.vision_config = CLIPVisionConfig(**vision_config)
         else:
-            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
+            self.vision_config = AutoConfig.for_model(
+                vision_model_type, **vision_config
+            )
 
         self.projection_dim = projection_dim
         self.initializer_factor = 1.0
 
     @classmethod
-    def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
+    def from_text_vision_configs(
+        cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs
+    ):
         r"""
         Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
         vision model configuration.
@@ -84,7 +89,11 @@ class HybridCLIPConfig(PretrainedConfig):
             :class:`HybridCLIPConfig`: An instance of a configuration object
         """
 
-        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
+        return cls(
+            text_config=text_config.to_dict(),
+            vision_config=vision_config.to_dict(),
+            **kwargs
+        )
 
     def to_dict(self):
         """
discard_incorrect_files.py CHANGED
@@ -1,9 +1,7 @@
 import json
 import os
-from tqdm import tqdm
 
-import torch
-from torchvision.io import ImageReadMode, read_image
+from tqdm import tqdm
 
 JOINT_JSON_DIRECTORY = f"/home/{os.environ['USER']}/data/wit/all_jsons"
 SCALE_CONVERTED_DIRECTORY = f"/home/{os.environ['USER']}/data/wit_scale_converted"
@@ -16,13 +14,16 @@ for split in ["train", "valid", "test"]:
 
     supported_examples = []
    for example in tqdm(examples):
-        directory, filename = os.path.split(example['image_path'])
+        directory, filename = os.path.split(example["image_path"])
         if filename in valid_files:
             example["image_path"] = os.path.join(SCALE_CONVERTED_DIRECTORY, filename)
             supported_examples.append(json.dumps(example, ensure_ascii=False))
 
     print(f"Total {split} examples: {len(supported_examples)}")
-    with open(f"{SCALE_CONVERTED_DIRECTORY}/{split}_dataset_scale_converted_98_1_1_split.json", "w") as f:
+    with open(
+        f"{SCALE_CONVERTED_DIRECTORY}/{split}_dataset_scale_converted_98_1_1_split.json",
+        "w",
+    ) as f:
         f.write("\n".join(supported_examples))
 
 print("DONE!")
join_datasets_custom_split.py CHANGED
@@ -1,10 +1,7 @@
-import os
 import json
+import os
 import random
 
-import pandas as pd
-
-
 DATA_DIR = f"/home/{os.environ['USER']}/data/wit/all_jsons"
 SEED = 0
 PROPORTION_TRAIN = 0.98
@@ -12,7 +9,9 @@ PROPORTION_VALID = 0.01
 
 random.seed(SEED)
 
-all_files = [f"{DATA_DIR}/{file_}" for file_ in os.listdir(DATA_DIR) if ("all" not in file_)]
+all_files = [
+    f"{DATA_DIR}/{file_}" for file_ in os.listdir(DATA_DIR) if ("all" not in file_)
+]
 
 print(all_files)
 
@@ -20,7 +19,9 @@ examples = []
 for file_ in all_files:
     print(file_)
     with open(file_) as f:
-        file_examples = [json.dumps(json.loads(line), ensure_ascii=False) for line in f.readlines()]
+        file_examples = [
+            json.dumps(json.loads(line), ensure_ascii=False) for line in f.readlines()
+        ]
     print(len(file_examples))
     examples.extend(file_examples)
 
@@ -34,15 +35,23 @@ random.shuffle(examples)
 print(examples[0])
 
 split_dataset = {}
-split_dataset["train"] = examples[:int(len(examples) * PROPORTION_TRAIN)]
-split_dataset["valid"] = examples[int(len(examples) * PROPORTION_TRAIN): int(len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID))]
-split_dataset["test"] = examples[int(len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID)):]
+split_dataset["train"] = examples[: int(len(examples) * PROPORTION_TRAIN)]
+split_dataset["valid"] = examples[
+    int(len(examples) * PROPORTION_TRAIN) : int(
+        len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID)
+    )
+]
+split_dataset["test"] = examples[
+    int(len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID)) :
+]
 
 
 for split in ["train", "valid", "test"]:
     print("-----")
     print(len(split_dataset[split]))
     print("-----")
-    with open(f"/home/{os.environ['USER']}/data/wit/all_jsons/{split}_dataset_all_98_1_1_split.json", "w") as f:
+    with open(
+        f"/home/{os.environ['USER']}/data/wit/all_jsons/{split}_dataset_all_98_1_1_split.json",
+        "w",
+    ) as f:
         f.write("\n".join(split_dataset[split]))
-
modeling_hybrid_clip.py CHANGED
@@ -18,13 +18,13 @@ from typing import Optional, Tuple
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from configuration_hybrid_clip import HybridCLIPConfig
 from flax.core.frozen_dict import FrozenDict
 from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
 from transformers.modeling_flax_utils import FlaxPreTrainedModel
 from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
 from transformers.utils import logging
 
+from configuration_hybrid_clip import HybridCLIPConfig
 
 logger = logging.get_logger(__name__)
 
@@ -42,7 +42,9 @@ class FlaxHybridCLIPModule(nn.Module):
         self.vision_embed_dim = vision_config.hidden_size
 
         text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
-        vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
+        vision_module = FLAX_MODEL_MAPPING.get(
+            self.config.vision_config.__class__, FlaxCLIPVisionModel
+        ).module_class
 
         self.text_model = text_module(text_config, dtype=self.dtype)
         self.vision_model = vision_module(vision_config, dtype=self.dtype)
@@ -73,7 +75,9 @@ class FlaxHybridCLIPModule(nn.Module):
         output_hidden_states=None,
         return_dict=None,
     ):
-        return_dict = return_dict if return_dict is not None else self.config.return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.return_dict
+        )
 
         vision_outputs = self.vision_model(
             pixel_values=pixel_values,
@@ -101,7 +105,9 @@ class FlaxHybridCLIPModule(nn.Module):
         text_embeds = self.text_projection(text_embeds)
 
         # normalized features
-        image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
+        image_embeds = image_embeds / jnp.linalg.norm(
+            image_embeds, axis=-1, keepdims=True
+        )
         text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
 
         # cosine similarity as logits
@@ -110,7 +116,14 @@ class FlaxHybridCLIPModule(nn.Module):
         logits_per_image = logits_per_text.T
 
         if not return_dict:
-            return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+            return (
+                logits_per_image,
+                logits_per_text,
+                text_embeds,
+                image_embeds,
+                text_outputs,
+                vision_outputs,
+            )
 
         return FlaxCLIPOutput(
             logits_per_image=logits_per_image,
@@ -132,18 +145,30 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         input_shape: Optional[Tuple] = None,
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
-        **kwargs
+        **kwargs,
     ):
         if input_shape is None:
-            input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
+            input_shape = (
+                (1, 1),
+                (
+                    1,
+                    config.vision_config.image_size,
+                    config.vision_config.image_size,
+                    3,
+                ),
+            )
 
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(
+            config, module, input_shape=input_shape, seed=seed, dtype=dtype
+        )
 
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
         # init input tensor
         input_ids = jnp.zeros(input_shape[0], dtype="i4")
-        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
+        position_ids = jnp.broadcast_to(
+            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0]
+        )
         token_type_ids = jnp.ones_like(input_ids)
         attention_mask = jnp.ones_like(input_ids)
 
@@ -152,7 +177,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
+        return self.module.init(
+            rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids
+        )["params"]
 
     def __call__(
         self,
@@ -168,14 +195,24 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ):
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.return_dict
 
         if position_ids is None:
-            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+            position_ids = jnp.broadcast_to(
+                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
+            )
 
         if token_type_ids is None:
             token_type_ids = jnp.zeros_like(input_ids)
@@ -225,7 +262,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
                 obtained by applying the projection layer to the pooled output of text model.
         """
         if position_ids is None:
-            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+            position_ids = jnp.broadcast_to(
+                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
+            )
 
         if token_type_ids is None:
             token_type_ids = jnp.zeros_like(input_ids)
@@ -238,7 +277,14 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         if dropout_rng is not None:
            rngs["dropout"] = dropout_rng
 
-        def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
+        def _get_features(
+            module,
+            input_ids,
+            attention_mask,
+            position_ids,
+            token_type_ids,
+            deterministic,
+        ):
             text_outputs = module.text_model(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
@@ -261,7 +307,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
             rngs=rngs,
         )
 
-    def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
+    def get_image_features(
+        self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False
+    ):
         r"""
         Args:
             pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
@@ -279,7 +327,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
            rngs["dropout"] = dropout_rng
 
         def _get_features(module, pixel_values, deterministic):
-            vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
+            vision_outputs = module.vision_model(
+                pixel_values=pixel_values, deterministic=deterministic
+            )
             pooled_output = vision_outputs[1]  # pooled_output
             image_features = module.visual_projection(pooled_output)
             return image_features
@@ -345,11 +395,15 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         """
 
         kwargs_text = {
-            argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
+            argument[len("text_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("text_")
         }
 
         kwargs_vision = {
-            argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
+            argument[len("vision_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("vision_")
         }
 
         # remove text, vision kwargs from kwargs
@@ -372,7 +426,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
                 text_config = AutoConfig.from_pretrained(text_model_name_or_path)
                 kwargs_text["config"] = text_config
 
-            text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
+            text_model = FlaxAutoModel.from_pretrained(
+                text_model_name_or_path, *model_args, **kwargs_text
+            )
 
         vision_model = kwargs_vision.pop("model", None)
         if vision_model is None:
@@ -387,21 +443,29 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
                 vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
                 kwargs_vision["config"] = vision_config
 
-            vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
+            vision_model = FlaxAutoModel.from_pretrained(
+                vision_model_name_or_path, *model_args, **kwargs_vision
+            )
 
         # instantiate config with corresponding kwargs
         dtype = kwargs.pop("dtype", jnp.float32)
-        config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
+        config = HybridCLIPConfig.from_text_vision_configs(
+            text_model.config, vision_model.config, **kwargs
+        )
 
         # init model
         model = cls(config, *model_args, dtype=dtype, **kwargs)
 
         if vision_config.model_type == "clip":
-            model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
-            model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
+            model.params["vision_model"]["vision_model"] = vision_model.params[
+                "vision_model"
+            ]
+            model.params["visual_projection"]["kernel"] = vision_model.params[
+                "visual_projection"
+            ]["kernel"]
        else:
             model.params["vision_model"] = vision_model.params
 
         model.params["text_model"] = text_model.params
 
-        return model
+        return model
prepare_wit.py CHANGED
@@ -3,14 +3,13 @@ import json
 import logging
 import os
 import time
-from typing import List
-import urllib.request
 import urllib.error
+import urllib.request
+from typing import List
 
 import pandas as pd
 from tqdm import tqdm
 
-
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
     datefmt="%m/%d/%Y %H:%M:%S",
@@ -18,11 +17,18 @@
 )
 logger = logging.getLogger(__name__)
 
-def split_and_save_datasets(lines: List[str], output_dir: str, train_proportion: float, valid_proportion: float):
+
+def split_and_save_datasets(
+    lines: List[str], output_dir: str, train_proportion: float, valid_proportion: float
+):
     total_lines = len(lines)
-    train_lines = lines[:int(total_lines * train_proportion)]
-    valid_lines = lines[int(total_lines * train_proportion):int(total_lines * (train_proportion + valid_proportion))]
-    test_lines = lines[int(total_lines * (train_proportion + valid_proportion)):]
+    train_lines = lines[: int(total_lines * train_proportion)]
+    valid_lines = lines[
+        int(total_lines * train_proportion) : int(
+            total_lines * (train_proportion + valid_proportion)
+        )
+    ]
+    test_lines = lines[int(total_lines * (train_proportion + valid_proportion)) :]
 
     with open(f"{output_dir}/train_dataset.json", "w") as f:
         f.write("\n".join(train_lines))
@@ -33,14 +39,33 @@ def split_and_save_datasets(lines: List[str], output_dir: str, train_proportion:
     with open(f"{output_dir}/test_dataset.json", "w") as f:
         f.write("\n".join(test_lines))
 
+
 def prepare_wit(
-    tsv: str, language: str, output_dir: str, seed: int, train_proportion: float, valid_proportion: float, backup_period: int, language_col: str="language", caption_col: str="caption_reference_description", url_col: str="image_url", pause=0.875, retries: int=10):
+    tsv: str,
+    language: str,
+    output_dir: str,
+    seed: int,
+    train_proportion: float,
+    valid_proportion: float,
+    backup_period: int,
+    language_col: str = "language",
+    caption_col: str = "caption_reference_description",
+    url_col: str = "image_url",
+    pause=0.875,
+    retries: int = 10,
+):
     os.makedirs(output_dir, exist_ok=True)
     logger.info("Loading dataset")
     df = pd.read_csv(tsv, sep="\t", engine="python")
     existing_files = set(os.listdir(output_dir))
-    not_exists_condition = (~(df[url_col].map(lambda x: x.split("/")[-1][-100:]).isin(existing_files)))
-    df = df[(df["language"] == language) & (~df["caption_reference_description"].isnull()) & not_exists_condition]
+    not_exists_condition = ~(
+        df[url_col].map(lambda x: x.split("/")[-1][-100:]).isin(existing_files)
+    )
+    df = df[
+        (df["language"] == language)
+        & (~df["caption_reference_description"].isnull())
+        & not_exists_condition
+    ]
     # Shuffle
     df = df.sample(frac=1.0, random_state=seed)
     logger.info(f"Trying to downloading {df.shape[0]} files")
@@ -58,14 +83,21 @@ def prepare_wit(
             try:
                 # Download file
                 urllib.request.urlretrieve(url, image_path)
-                lines.append(json.dumps({"image_path": image_path, "captions": [caption]}, ensure_ascii=False))
+                lines.append(
+                    json.dumps(
+                        {"image_path": image_path, "captions": [caption]},
+                        ensure_ascii=False,
+                    )
+                )
                 count += 1
                 break
-            except urllib.error.HTTPError as e:
+            except urllib.error.HTTPError:
                 time.sleep(pause * 10)
             if count % backup_period == 0:
                 logger.info(f"Saving dataset backup: Number of lines {len(lines)}")
-                split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
+                split_and_save_datasets(
+                    lines, output_dir, train_proportion, valid_proportion
+                )
             if retry == retries - 1:
                 logger.info(f"Skipping {image_filename}")
                 pbar.update(1)
@@ -73,16 +105,35 @@ def prepare_wit(
     finally:
         split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
 
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description = "Download and prepare the WIT dataset")
-    parser.add_argument("--tsv", type=str, default=f"/home/{os.environ['USER']}/data/wit/wit_v1.train.all-1percent_sample.tsv")
+    parser = argparse.ArgumentParser(description="Download and prepare the WIT dataset")
+    parser.add_argument(
+        "--tsv",
+        type=str,
+        default=f"/home/{os.environ['USER']}/data/wit/wit_v1.train.all-1percent_sample.tsv",
+    )
     parser.add_argument("--language", type=str, default="es")
-    parser.add_argument("--output_dir", type=str, default=f"/home/{os.environ['USER']}/data/wit/prepared_dataset")
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default=f"/home/{os.environ['USER']}/data/wit/prepared_dataset",
+    )
     parser.add_argument("--random_seed", type=int, default=0)
     parser.add_argument("--train_proportion", type=float, default=0.8)
     parser.add_argument("--valid_proportion", type=float, default=0.1)
     parser.add_argument("--backup_period", type=int, default=1000)
 
     args = parser.parse_args()
-    assert args.train_proportion + args.valid_proportion < 1.0, "The sum of train_proportion and valid_proportion has to be < 1.0"
-    prepare_wit(args.tsv, args.language, args.output_dir, args.random_seed, args.train_proportion, args.valid_proportion, args.backup_period)
+    assert (
+        args.train_proportion + args.valid_proportion < 1.0
+    ), "The sum of train_proportion and valid_proportion has to be < 1.0"
+    prepare_wit(
+        args.tsv,
+        args.language,
+        args.output_dir,
+        args.random_seed,
+        args.train_proportion,
+        args.valid_proportion,
+        args.backup_period,
+    )
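The script above downloads WIT images and writes JSON-lines train/valid/test splits. A hedged sketch of calling `prepare_wit` directly from Python instead of through the CLI — the TSV and output paths are placeholders, and the proportions mirror the argparse defaults shown above:

```python
# Hypothetical direct call to the prepare_wit function defined in prepare_wit.py.
from prepare_wit import prepare_wit

prepare_wit(
    tsv="wit_v1.train.all-1percent_sample.tsv",  # placeholder path to a WIT TSV dump
    language="es",
    output_dir="prepared_dataset",  # placeholder output directory
    seed=0,
    train_proportion=0.8,
    valid_proportion=0.1,
    backup_period=1000,
)
```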
run_hybrid_clip.py CHANGED
@@ -32,25 +32,26 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Callable, Optional
 
-import numpy as np
-import torch
-from torchvision.datasets import VisionDataset
-from torchvision.io import ImageReadMode, read_image
-from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
-from torchvision.transforms.functional import InterpolationMode
-from tqdm import tqdm
-
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
+import torch
 import transformers
 from flax import jax_utils
 from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, shard, shard_prng_key
-from modeling_hybrid_clip import FlaxHybridCLIP
-from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
+from torchvision.datasets import VisionDataset
+from torchvision.io import ImageReadMode, read_image
+from torchvision.transforms import (CenterCrop, ConvertImageDtype, Normalize,
+                                    Resize)
+from torchvision.transforms.functional import InterpolationMode
+from tqdm import tqdm
+from transformers import (AutoTokenizer, HfArgumentParser, TrainingArguments,
+                          is_tensorboard_available, set_seed)
 
+from modeling_hybrid_clip import FlaxHybridCLIP
 
 logger = logging.getLogger(__name__)
 
@@ -61,7 +62,9 @@ if has_tensorboard:
         from flax.metrics.tensorboard import SummaryWriter
     except ImportError as ie:
         has_tensorboard = False
-        print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
+        print(
+            f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+        )
 
 else:
     print(
@@ -90,20 +93,33 @@ class ModelArguments:
     )
     from_pt: bool = field(
         default=True,
-        metadata={"help": "whether to load the text and vision model using PyTorch checkpoints."},
+        metadata={
+            "help": "whether to load the text and vision model using PyTorch checkpoints."
+        },
     )
     config_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+        default=None,
+        metadata={
+            "help": "Pretrained config name or path if not the same as model_name"
+        },
     )
     tokenizer_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+        default=None,
+        metadata={
+            "help": "Pretrained tokenizer name or path if not the same as model_name"
+        },
    )
     cache_dir: Optional[str] = field(
-        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+        default=None,
+        metadata={
+            "help": "Where do you want to store the pretrained models downloaded from s3"
+        },
     )
     use_fast_tokenizer: bool = field(
         default=True,
-        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+        metadata={
+            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
+        },
     )
     dtype: Optional[str] = field(
         default="float32",
@@ -119,9 +135,12 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
 
-    data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
+    data_dir: Optional[str] = field(
+        default=None, metadata={"help": "The data directory containing input files."}
+    )
     train_file: Optional[str] = field(
-        default=None, metadata={"help": "The input training data file (a jsonlines file)."}
+        default=None,
+        metadata={"help": "The input training data file (a jsonlines file)."},
     )
     validation_file: Optional[str] = field(
         default=None,
@@ -149,10 +168,12 @@
         },
     )
     overwrite_cache: bool = field(
-        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+        default=False,
+        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
     overwrite_cache: bool = field(
-        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+        default=False,
+        metadata={"help": "Overwrite the cached training and evaluation sets"},
     )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
@@ -161,7 +182,9 @@
 
     def __post_init__(self):
         if self.train_file is None and self.validation_file is None:
-            raise ValueError("Need either a dataset name or a training/validation file.")
+            raise ValueError(
+                "Need either a dataset name or a training/validation file."
+            )
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
@@ -180,7 +203,10 @@ class Transform(torch.nn.Module):
             Resize([image_size], interpolation=InterpolationMode.BICUBIC),
             CenterCrop(image_size),
             ConvertImageDtype(torch.float),
-            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+            Normalize(
+                (0.48145466, 0.4578275, 0.40821073),
+                (0.26862954, 0.26130258, 0.27577711),
+            ),
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -225,7 +251,7 @@ class ImageTextDataset(VisionDataset):
         self.image_paths = []
 
         for example in examples:
-            captions_subset = example["captions"][:captions_per_image]
+            captions_subset = example["captions"][:captions_per_image]
             self.captions.extend(captions_subset)
             self.image_paths.extend([example["image_path"]] * len(captions_subset))
 
@@ -253,7 +279,9 @@ class TrainState(train_state.TrainState):
     dropout_rng: jnp.ndarray
 
     def replicate(self):
-        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
+        return jax_utils.replicate(self).replace(
+            dropout_rng=shard_prng_key(self.dropout_rng)
+        )
 
 
 def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
@@ -270,25 +298,39 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 
 def create_learning_rate_fn(
-    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+    train_ds_size: int,
+    train_batch_size: int,
+    num_train_epochs: int,
+    num_warmup_steps: int,
+    learning_rate: float,
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
-    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    warmup_fn = optax.linear_schedule(
+        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
+    )
     decay_fn = optax.linear_schedule(
-        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+        init_value=learning_rate,
+        end_value=0,
+        transition_steps=num_train_steps - num_warmup_steps,
+    )
+    schedule_fn = optax.join_schedules(
+        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
     )
-    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
     return schedule_fn
 
 
 def main():
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    parser = HfArgumentParser(
+        (ModelArguments, DataTrainingArguments, TrainingArguments)
+    )
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
-        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, training_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1])
+        )
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
@@ -321,11 +363,15 @@ def main():
 
     if model_args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(
-            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+            model_args.tokenizer_name,
+            cache_dir=model_args.cache_dir,
+            use_fast=model_args.use_fast_tokenizer,
        )
     elif model_args.text_model_name_or_path:
         tokenizer = AutoTokenizer.from_pretrained(
-            model_args.text_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+            model_args.text_model_name_or_path,
+            cache_dir=model_args.cache_dir,
+            use_fast=model_args.use_fast_tokenizer,
         )
     else:
         raise ValueError(
@@ -366,16 +412,28 @@ def main():
 
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
-    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+    train_batch_size = (
+        int(training_args.per_device_train_batch_size) * jax.device_count()
+    )
     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
     steps_per_epoch = len(train_dataset) // train_batch_size
     total_train_steps = steps_per_epoch * num_epochs
 
     # Use collate function to tokenizer the text and convert the processed images to numpy
     def collate_fn(examples):
-        pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
+        pixel_values = (
+            torch.stack([example[0] for example in examples])
+            .permute(0, 2, 3, 1)
+            .numpy()
+        )
         captions = [example[1] for example in examples]
-        inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True, return_tensors="np")
+        inputs = tokenizer(
+            captions,
+            max_length=data_args.max_seq_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="np",
+        )
 
         batch = {
             "pixel_values": pixel_values,
@@ -408,7 +466,9 @@
 
     # Enable tensorboard only on the master node
     if has_tensorboard and jax.process_index() == 0:
-        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
+        summary_writer = SummaryWriter(
+            log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()
+        )
 
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
@@ -433,7 +493,9 @@
     )
 
     # Setup train state
-    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+    state = TrainState.create(
+        apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng
+    )
 
     def cross_entropy(logits, axis):
         logprobs = jax.nn.log_softmax(logits, axis=axis)
@@ -442,7 +504,9 @@
         return ce
 
     def clip_loss(similarity):
-        loss = (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2
+        loss = (
+            cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)
+        ) / 2
         return loss
 
     # Define gradient update step fn
@@ -450,7 +514,9 @@
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
         def compute_loss(params):
-            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+            logits = state.apply_fn(
+                **batch, params=params, dropout_rng=dropout_rng, train=True
+            )[0]
             loss = clip_loss(logits)
             return loss
 
@@ -460,7 +526,10 @@
 
         new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
 
-        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+        metrics = {
+            "loss": loss,
+            "learning_rate": linear_decay_lr_schedule_fn(state.step),
+        }
         metrics = jax.lax.pmean(metrics, axis_name="batch")
 
         return new_state, metrics
@@ -485,8 +554,12 @@
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len(train_dataset)}")
     logger.info(f" Num Epochs = {num_epochs}")
-    logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
-    logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
+    logger.info(
+        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
+    )
+    logger.info(
+        f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
+    )
     logger.info(f" Total optimization steps = {total_train_steps}")
 
     train_time = 0
@@ -504,7 +577,9 @@
         train_metrics = []
 
         steps_per_epoch = len(train_dataset) // train_batch_size
-        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+        train_step_progress_bar = tqdm(
+            total=steps_per_epoch, desc="Training...", position=1, leave=False
+        )
         # train
         for batch in train_loader:
             batch = shard(batch)
@@ -525,7 +600,9 @@
         # ======================== Evaluating ==============================
         eval_metrics = []
         eval_steps = len(eval_dataset) // eval_batch_size
-        eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
+        eval_step_progress_bar = tqdm(
+            total=eval_steps, desc="Evaluating...", position=2, leave=False
+        )
         for batch in eval_loader:
             # Model forward
             batch = shard(batch)
@@ -541,14 +618,18 @@
 
         # Print metrics and update progress bar
        eval_step_progress_bar.close()
-        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
+        desc = (
+            f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
+        )
         epochs.write(desc)
         epochs.desc = desc
 
         # Save metrics
        if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(train_dataset) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_metric(
+                summary_writer, train_metrics, eval_metrics, train_time, cur_step
+            )
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
scale_convert.py CHANGED
@@ -1,12 +1,11 @@
 import glob
 import itertools
-from argparse import ArgumentParser
-from joblib import Parallel, delayed
 import os
 import subprocess
+from argparse import ArgumentParser
 from collections import Counter
-import shutil
 
+from joblib import Parallel, delayed
 
 parser = ArgumentParser()
 parser.add_argument("in_dir")
@@ -26,17 +25,16 @@ files = itertools.chain(
     glob.iglob(f"{args.in_dir}/*/*.SVG"),
 )
 
+
 def process_file(path):
     basename = os.path.basename(path)
-    ext = os.path.splitext(basename)[1]
     name = os.path.splitext(basename)[0]
 
-    dirname = os.path.dirname(path)
     try:
         r = subprocess.run(
             f'convert {path} -resize "224^>" -colorspace RGB -density 1200 {args.out_dir}/{name}.jpg',
             shell=True,
-            timeout=10
+            timeout=10,
        )
         rcode = r.returncode
     except subprocess.TimeoutExpired:
@@ -48,6 +46,8 @@ def process_file(path):
 
     return rcode
 
-codes = Parallel(n_jobs=32, prefer="threads", verbose=1)(delayed(process_file)(f) for f in files)
-print(Counter(codes))
 
+codes = Parallel(n_jobs=32, prefer="threads", verbose=1)(
+    delayed(process_file)(f) for f in files
+)
+print(Counter(codes))
test_on_image.py CHANGED
@@ -17,13 +17,21 @@ def prepare_image(image_path, model):
     pixel_values = torch.stack([preprocessed_image]).permute(0, 2, 3, 1).numpy()
     return pixel_values
 
+
 def prepare_text(text, tokenizer):
     return tokenizer(text, return_tensors="np")
 
+
 def run_inference(image_path, text, model, tokenizer):
     pixel_values = prepare_image(image_path, model)
     input_text = prepare_text(text, tokenizer)
-    model_output = model(input_text["input_ids"], pixel_values, attention_mask=input_text["attention_mask"], train=False, return_dict=True)
+    model_output = model(
+        input_text["input_ids"],
+        pixel_values,
+        attention_mask=input_text["attention_mask"],
+        train=False,
+        return_dict=True,
+    )
     logits = model_output["logits_per_image"]
     score = jax.nn.sigmoid(logits)[0][0]
     return score
@@ -31,9 +39,11 @@ def run_inference(image_path, text, model, tokenizer):
 
 if __name__ == "__main__":
     model = FlaxHybridCLIP.from_pretrained("./")
-    tokenizer = AutoTokenizer.from_pretrained("bertin-project/bertin-roberta-base-spanish")
+    tokenizer = AutoTokenizer.from_pretrained(
+        "bertin-project/bertin-roberta-base-spanish"
+    )
 
     image_path = f"/home/{os.environ['USER']}/data/wit_scale_converted/Santuar.jpg"
     text = "Fachada del Santuario"
 
-    print(run_inference(image_path, text, model, tokenizer))
+    print(run_inference(image_path, text, model, tokenizer))