Santiago committed on
Commit
98c1ffc
1 Parent(s): 69cfc30
README.md ADDED
@@ -0,0 +1,43 @@
1
+ ---
2
+ language:
3
+ - en
4
+ tags:
5
+ - vision
6
+ license: apache-2.0
7
+ ---
8
+
9
+ # MedCLIP
10
+
11
+ ## Model description
12
+
13
+
14
+ ## Intended uses & limitations
15
+
16
+ #### How to use
17
+
18
+ ```python
19
+ # You can include sample code which will be formatted
20
+ ```
21
+
22
+ #### Limitations and bias
23
+
24
+ Provide examples of latent issues and potential remediations.
25
+
26
+ ## Training data
27
+
28
+ Describe the data you used to train the model.
29
+ If you initialized it with pre-trained weights, add a link to the pre-trained model card or repository with description of the pre-training data.
30
+
31
+ ## Training procedure
32
+
33
+ Preprocessing, hardware used, hyperparameters...
34
+
35
+ ## Eval results
36
+
37
+ ### BibTeX entry and citation info
38
+
39
+ ```bibtex
40
+ @inproceedings{...,
41
+ year={2020}
42
+ }
43
+ ```
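The `How to use` block in the README above is left as a placeholder in this commit. A minimal, untested sketch of what it could contain, assuming the trained weights are published under this repo id (`flax-community/medclip`, the output dir used in `train_model.sh`) and that `FlaxMedCLIP` exposes the usual CLIP-style call interface (`input_ids`, `attention_mask`, `pixel_values` in, `logits_per_image` out); the image preprocessing here is simplified relative to the `Transform` used during training:

```python
import numpy as np
import jax.numpy as jnp
from PIL import Image
from transformers import AutoTokenizer

from src.modeling_medclip import FlaxMedCLIP

# Assumption: weights are available locally (./model) or on the Hub under this repo id.
model = FlaxMedCLIP.from_pretrained("flax-community/medclip")
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")

texts = ["no acute cardiopulmonary process", "left lower lobe pneumonia"]
inputs = tokenizer(texts, max_length=128, padding="max_length", truncation=True, return_tensors="np")

# Images are fed channels-last; 224 matches the CLIP ViT-B/32 vision tower.
# NOTE: training also applies the Normalize step from Transform in run_medclip.py, skipped here.
image = Image.open("chest_xray.jpg").convert("RGB").resize((224, 224))
pixel_values = jnp.asarray(np.array(image), dtype=jnp.float32)[None] / 255.0  # (1, 224, 224, 3)

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    pixel_values=pixel_values,
)
print(outputs.logits_per_image)  # image-text similarity scores for the two candidate captions
```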
prepare_data.py DELETED
@@ -1,141 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- from typing import Dict, List
5
-
6
- import argparse
7
- import json
8
- from functools import partial
9
- import pathlib
10
- import shutil
11
- import re
12
-
13
- from tqdm import tqdm
14
- from PIL import Image
15
- import pandas as pd
16
-
17
- ImageCaptionMap = Dict[str, Dict[str, str]]
18
-
19
- def _get_image_path(row: pd.Series, root_dir: str = '.') -> str:
20
- path = [
21
- root_dir,
22
- 'files',
23
- f'p{row.subject_id}'[:3],
24
- f'p{row.subject_id}',
25
- f's{row.study_id}',
26
- f'{row.dicom_id}.jpg'
27
- ]
28
-
29
- return '/'.join(path)
30
-
31
- def _prepare_dataframe(
32
- captions: pd.DataFrame,
33
- metadata: pd.DataFrame,
34
- row: pd.Series
35
- ) -> pd.Series:
36
- if f's{row.study_id}' in captions.index:
37
- row[captions.columns] = (
38
- captions
39
- .loc[f's{row.study_id}']
40
- .apply(lambda text: (
41
- re.sub('_+', '_', text)
42
- .replace('\n', ' ')
43
- .lower().rstrip('.')
44
- ))
45
- )
46
-
47
- if row.dicom_id in metadata.index:
48
- row['view_position'] = metadata.loc[row.dicom_id, 'ViewPosition']
49
-
50
- return row
51
-
52
- def copy_image(
53
- row: pd.Series,
54
- target_path: pathlib.Path,
55
- split: str,
56
- size: int = 224
57
- ) -> str:
58
- target_img_path = target_path / split / f'{row.dicom_id}.jpg'
59
- target_img_path = str(target_img_path.resolve())
60
-
61
- img = Image.open(row.path)
62
- img = img.resize((size, size))
63
- img.save(target_img_path)
64
-
65
- return target_img_path
66
-
67
- def generate_dataset(
68
- root_dir: pathlib.Path,
69
- target_dir: pathlib.Path,
70
- split: str = 'validate'
71
- ) -> ImageCaptionMap:
72
- meta_dir = root_dir / 'metadata'
73
-
74
- metadata = pd.read_csv(meta_dir / 'mimic-cxr-2.0.0-metadata.csv')
75
- df_split = pd.read_csv(meta_dir / 'mimic-cxr-2.0.0-split.csv')
76
- captions = pd.read_csv(meta_dir / 'mimic_cxr_sectioned.csv')
77
-
78
- captions = captions.where(~captions.isna(), '').set_index('study')
79
- metadata = metadata.set_index('dicom_id')
80
-
81
- if split in df_split.split.unique():
82
- current_split = df_split[df_split.split == split]
83
- get_abs_path = partial(_get_image_path, root_dir=str(root_dir.resolve()))
84
-
85
- current_split['path'] = current_split.apply(get_abs_path, axis=1)
86
- current_split['view_position'] = ''
87
- for col in captions.columns:
88
- current_split[col] = ''
89
-
90
- preprocess_func = partial(_prepare_dataframe, captions, metadata)
91
-
92
- df = current_split.apply(preprocess_func, axis=1)
93
-
94
- else:
95
- raise ValueError('bad split')
96
-
97
- image_path_to_caption = {}
98
- (target_dir / split).mkdir(exist_ok=True, parents=True)
99
-
100
- for _, element in tqdm(df.iterrows()):
101
- caption = {
102
- 'impression': element['impression'],
103
- 'findings': element['findings'],
104
- 'last_paragraph': element['last_paragraph'],
105
- 'comparison': element['comparison'],
106
- 'view_position': element['view_position'],
107
- }
108
-
109
- image_path = copy_image(element, target_dir, split)
110
-
111
- image_path_to_caption[image_path] = caption
112
-
113
- return image_path_to_caption
114
-
115
- def dump_dataset(image_path_to_caption: ImageCaptionMap) -> List[str]:
116
- lines = []
117
-
118
- for image_path, captions in image_path_to_caption.items():
119
- lines.append(json.dumps({
120
- 'image_path': image_path,
121
- 'caption': captions,
122
- }))
123
-
124
- return lines
125
-
126
- if __name__ == '__main__':
127
- parser = argparse.ArgumentParser(description='Preprocess MIMIC-CXR dataset')
128
- parser.add_argument('--data_dir', help='MIMIC-CXR path')
129
- parser.add_argument('--target_dir', help='output path')
130
-
131
- args = parser.parse_args()
132
-
133
- data_dir = pathlib.Path(args.data_dir)
134
- target_dir = pathlib.Path(args.target_dir)
135
-
136
- for split in ['test', 'validate', 'train']:
137
- image_path_to_caption = generate_dataset(data_dir, target_dir, split)
138
- lines = dump_dataset(image_path_to_caption)
139
-
140
- with open(target_dir / f'{split}_dataset.json', 'w') as f:
141
- f.write('\n'.join(lines))
requirements.txt DELETED
@@ -1,8 +0,0 @@
1
- jax>=0.2.8
2
- jaxlib>=0.1.59
3
- flax>=0.3.4
4
- optax>=0.0.8
5
- -f https://download.pytorch.org/whl/torch_stable.html
6
- torch==1.9.0+cpu
7
- -f https://download.pytorch.org/whl/torch_stable.html
8
- torchvision==0.10.0+cpu
run_medclip.py CHANGED
@@ -1,6 +1,6 @@
1
  #!/usr/bin/env python
2
  # coding=utf-8
3
- # Copyright 2021 The HuggingFace Team All rights reserved.
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
@@ -23,7 +23,6 @@ Vision models: ViT(https://huggingface.co/models?filter=vit), CLIP (https://hugg
23
  Text models: BERT, ROBERTa (https://huggingface.co/models?filter=masked-lm)
24
  """
25
 
26
- import json
27
  import logging
28
  import os
29
  import sys
@@ -34,24 +33,26 @@ from pathlib import Path
34
  from typing import Callable, Optional
35
 
36
  import torch
37
- from torchvision.datasets import VisionDataset
38
- from torchvision.io import ImageReadMode, read_image
39
  from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
40
  from torchvision.transforms.functional import InterpolationMode
41
  from tqdm import tqdm
42
 
43
  import jax
44
  import jax.numpy as jnp
 
45
  import optax
46
  import transformers
47
  from flax import jax_utils
48
  from flax.jax_utils import unreplicate
49
  from flax.training import train_state
50
  from flax.training.common_utils import get_metrics, shard, shard_prng_key
51
- from src.modeling_medclip import FlaxHybridCLIP
52
  from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
53
  import wandb
54
 
 
 
 
55
  logger = logging.getLogger(__name__)
56
 
57
  # Cache the result
@@ -69,7 +70,6 @@ else:
69
  "Please run pip install tensorboard to enable."
70
  )
71
 
72
-
73
  @dataclass
74
  class ModelArguments:
75
  """
@@ -119,16 +119,18 @@ class DataTrainingArguments:
119
  Arguments pertaining to what data we are going to input our model for training and eval.
120
  """
121
 
122
- data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
123
- train_file: Optional[str] = field(
124
  default=None, metadata={"help": "The input training data file (a jsonlines file)."}
125
  )
126
- validation_file: Optional[str] = field(
127
  default=None,
128
  metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
129
  )
 
 
130
  max_seq_length: Optional[int] = field(
131
- default=72,
132
  metadata={
133
  "help": "The maximum total input sequence length after tokenization. Sequences longer "
134
  "than this will be truncated, sequences shorter will be padded."
@@ -155,19 +157,19 @@ class DataTrainingArguments:
155
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
156
  )
157
  preprocessing_num_workers: Optional[int] = field(
158
- default=None,
159
  metadata={"help": "The number of processes to use for the preprocessing."},
160
  )
161
 
162
  def __post_init__(self):
163
- if self.train_file is None and self.validation_file is None:
164
  raise ValueError("Need either a dataset name or a training/validation file.")
165
  else:
166
- if self.train_file is not None:
167
- extension = self.train_file.split(".")[-1]
168
  assert extension == "json", "`train_file` should be a json file."
169
- if self.validation_file is not None:
170
- extension = self.validation_file.split(".")[-1]
171
  assert extension == "json", "`validation_file` should be a json file."
172
 
173
 
@@ -188,71 +190,6 @@ class Transform(torch.nn.Module):
188
  x = self.transforms(x)
189
  return x
190
 
191
-
192
- class ImageTextDataset(VisionDataset):
193
- """
194
- Dataset for loading image-text data for tasks like CLIP training, Image Captioning.
195
-
196
- Args:
197
- root: (string): The root path where the dataset is stored
198
- file_path: (string): Path to the file containing the image_paths and associated captions.
199
- The expected format is jsonlines where each line is a json object containing two keys.
200
- `image_path`: The path to the image.
201
- `captions`: An `array` of captions.
202
- transform (callable, optional): A function/transform that takes in an PIL image
203
- and returns a transformed version. E.g, ``transforms.ToTensor``
204
- target_transform (callable, optional): A function/transform that takes in the
205
- target and transforms it.
206
- transforms (callable, optional): A function/transform that takes input sample and its target as entry
207
- and returns a transformed version.
208
- """
209
-
210
- def __init__(
211
- self,
212
- root: str,
213
- file_path: str,
214
- transform: Optional[Callable] = None,
215
- target_transform: Optional[Callable] = None,
216
- transforms: Optional[Callable] = None,
217
- ):
218
- super().__init__(root, transforms, transform, target_transform)
219
-
220
- with open(file_path, "r") as f:
221
- examples = [json.loads(line) for line in f.readlines()]
222
-
223
- self.captions = []
224
- self.image_paths = []
225
-
226
- for example in examples:
227
- self.captions.append(example["caption"])
228
- self.image_paths.append(f'{root}/{example["image_path"]}')
229
-
230
- def _load_image(self, idx: int):
231
- path = self.image_paths[idx]
232
- return read_image(path, mode=ImageReadMode.RGB)
233
-
234
- def _load_target(self, idx):
235
- sections = self.captions[idx]
236
- longest_section = max(
237
- filter(lambda x: isinstance(x, str), sections.values()),
238
- key=len
239
- )
240
-
241
- return longest_section
242
-
243
- def __getitem__(self, index: int):
244
- image = self._load_image(index)
245
- target = self._load_target(index)
246
-
247
- if self.transforms is not None:
248
- image, target = self.transforms(image, target)
249
-
250
- return image, target
251
-
252
- def __len__(self) -> int:
253
- return len(self.captions)
254
-
255
-
256
  class TrainState(train_state.TrainState):
257
  dropout_rng: jnp.ndarray
258
 
@@ -348,7 +285,7 @@ def main():
348
  "You can do it from another script, save it, and load it from here, using --tokenizer_name."
349
  )
350
 
351
- model = FlaxHybridCLIP.from_text_vision_pretrained(
352
  model_args.text_model_name_or_path,
353
  model_args.vision_model_name_or_path,
354
  seed=training_args.seed,
@@ -364,18 +301,51 @@ def main():
364
  preprocess = Transform(config.vision_config.image_size)
365
  preprocess = torch.jit.script(preprocess)
366
 
367
- # Initialize the image-text dataset
368
- train_dataset = ImageTextDataset(
369
- data_args.data_dir,
370
- data_args.train_file,
371
- transform=preprocess,
372
- )
 
374
- eval_dataset = ImageTextDataset(
375
- data_args.data_dir,
376
- data_args.validation_file,
377
- transform=preprocess,
378
- )
 
379
 
380
  # Store some constant
381
  num_epochs = int(training_args.num_train_epochs)
@@ -387,8 +357,15 @@ def main():
387
  # Use collate function to tokenizer the text and convert the processed images to numpy
388
  def collate_fn(examples):
389
  pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
390
- captions = [example[1] for example in examples]
391
- inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np")
392
 
393
  batch = {
394
  "pixel_values": pixel_values,
@@ -406,6 +383,7 @@ def main():
406
  num_workers=data_args.preprocessing_num_workers,
407
  persistent_workers=True,
408
  drop_last=True,
 
409
  collate_fn=collate_fn,
410
  )
411
 
@@ -416,17 +394,20 @@ def main():
416
  num_workers=data_args.preprocessing_num_workers,
417
  persistent_workers=True,
418
  drop_last=True,
 
419
  collate_fn=collate_fn,
420
  )
421
 
422
  # Enable tensorboard only on the master node
423
  if has_tensorboard and jax.process_index() == 0:
424
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
 
425
 
426
  # Initialize our training
427
  rng = jax.random.PRNGKey(training_args.seed)
428
  rng, dropout_rng = jax.random.split(rng)
429
 
 
430
  # Create learning rate schedule
431
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
432
  len(train_dataset),
@@ -435,10 +416,17 @@ def main():
435
  training_args.warmup_steps,
436
  training_args.learning_rate,
437
  )
438
 
439
  # create adam optimizer
440
- adamw = optax.adamw(
441
- learning_rate=linear_decay_lr_schedule_fn,
442
  b1=training_args.adam_beta1,
443
  b2=training_args.adam_beta2,
444
  eps=training_args.adam_epsilon,
@@ -473,7 +461,7 @@ def main():
473
 
474
  new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
475
 
476
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
477
  metrics = jax.lax.pmean(metrics, axis_name="batch")
478
 
479
  return new_state, metrics
@@ -506,6 +494,7 @@ def main():
506
  # Create sampling rng
507
  rng, input_rng = jax.random.split(rng)
508
 
 
509
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
510
  for epoch in epochs:
511
  # ======================== Training ================================
@@ -572,6 +561,11 @@ def main():
572
  commit_message=f"Saving weights and logs of epoch {epoch+1}",
573
  )
574
 
 
 
 
 
575
 
576
  if __name__ == "__main__":
577
- main()
 
 
1
  #!/usr/bin/env python
2
  # coding=utf-8
3
+ # Copyright 2021 Santiago Hincapie-Potes & The HuggingFace Team All rights reserved.
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
 
23
  Text models: BERT, ROBERTa (https://huggingface.co/models?filter=masked-lm)
24
  """
25
 
 
26
  import logging
27
  import os
28
  import sys
 
33
  from typing import Callable, Optional
34
 
35
  import torch
36
+ from torch.utils.data import ConcatDataset
 
37
  from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
38
  from torchvision.transforms.functional import InterpolationMode
39
  from tqdm import tqdm
40
 
41
  import jax
42
  import jax.numpy as jnp
43
+ import numpy as onp
44
  import optax
45
  import transformers
46
  from flax import jax_utils
47
  from flax.jax_utils import unreplicate
48
  from flax.training import train_state
49
  from flax.training.common_utils import get_metrics, shard, shard_prng_key
 
50
  from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
51
  import wandb
52
 
53
+ from src.modeling_medclip import FlaxMedCLIP
54
+ from src.datasets_medclip import MIMICDataset, ROCODataset
55
+
56
  logger = logging.getLogger(__name__)
57
 
58
  # Cache the result
 
70
  "Please run pip install tensorboard to enable."
71
  )
72
 
 
73
  @dataclass
74
  class ModelArguments:
75
  """
 
119
  Arguments pertaining to what data we are going to input our model for training and eval.
120
  """
121
 
122
+ mimic_data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing the MIMIC-CXR dataset."})
123
+ mimic_train_file: Optional[str] = field(
124
  default=None, metadata={"help": "The input training data file (a jsonlines file)."}
125
  )
126
+ mimic_validation_file: Optional[str] = field(
127
  default=None,
128
  metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
129
  )
130
+ mimic_mode: Optional[str] = field(default=None, metadata={"help": "longest or docs"})
131
+ roco_data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing the ROCO dataset."})
132
  max_seq_length: Optional[int] = field(
133
+ default=128,
134
  metadata={
135
  "help": "The maximum total input sequence length after tokenization. Sequences longer "
136
  "than this will be truncated, sequences shorter will be padded."
 
157
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
158
  )
159
  preprocessing_num_workers: Optional[int] = field(
160
+ default=32,
161
  metadata={"help": "The number of processes to use for the preprocessing."},
162
  )
163
 
164
  def __post_init__(self):
165
+ if self.mimic_train_file is None and self.mimic_validation_file is None:
166
  raise ValueError("Need either a dataset name or a training/validation file.")
167
  else:
168
+ if self.mimic_train_file is not None:
169
+ extension = self.mimic_train_file.split(".")[-1]
170
  assert extension == "json", "`train_file` should be a json file."
171
+ if self.mimic_validation_file is not None:
172
+ extension = self.mimic_validation_file.split(".")[-1]
173
  assert extension == "json", "`validation_file` should be a json file."
174
 
175
 
 
190
  x = self.transforms(x)
191
  return x
192
 
193
  class TrainState(train_state.TrainState):
194
  dropout_rng: jnp.ndarray
195
 
 
285
  "You can do it from another script, save it, and load it from here, using --tokenizer_name."
286
  )
287
 
288
+ model = FlaxMedCLIP.from_text_vision_pretrained(
289
  model_args.text_model_name_or_path,
290
  model_args.vision_model_name_or_path,
291
  seed=training_args.seed,
 
301
  preprocess = Transform(config.vision_config.image_size)
302
  preprocess = torch.jit.script(preprocess)
303
 
304
+ _train_datasets = []
305
+ _eval_datasets = []
306
+
307
+ if data_args.mimic_data_dir is not None:
308
+ # Initialize the image-text dataset
309
+ _train_datasets.append(
310
+ MIMICDataset(
311
+ data_args.mimic_data_dir,
312
+ data_args.mimic_train_file,
313
+ transform=preprocess,
314
+ mode=data_args.mimic_mode,
315
+ )
316
+ )
317
 
318
+ _eval_datasets.append(
319
+ MIMICDataset(
320
+ data_args.mimic_data_dir,
321
+ data_args.mimic_validation_file,
322
+ transform=preprocess,
323
+ mode=data_args.mimic_mode,
324
+ )
325
+ )
326
+
327
+ if data_args.roco_data_dir is not None:
328
+ _train_datasets.append(
329
+ ROCODataset(
330
+ data_args.roco_data_dir,
331
+ split="train",
332
+ transform=preprocess,
333
+ )
334
+ )
335
+
336
+ _eval_datasets.append(
337
+ ROCODataset(
338
+ data_args.roco_data_dir,
339
+ split="validate",
340
+ transform=preprocess,
341
+ )
342
+ )
343
+
344
+ if not _train_datasets or not _eval_datasets:
345
+ raise ValueError("At least one of --mimic_data_dir or --roco_data_dir must be provided.")
346
+ else:
347
+ train_dataset = ConcatDataset(_train_datasets)
348
+ eval_dataset = ConcatDataset(_eval_datasets)
349
 
350
  # Store some constant
351
  num_epochs = int(training_args.num_train_epochs)
 
357
  # Use collate function to tokenizer the text and convert the processed images to numpy
358
  def collate_fn(examples):
359
  pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
360
+ texts = [example[1] for example in examples]
361
+
362
+ inputs = tokenizer(
363
+ texts,
364
+ max_length=data_args.max_seq_length,
365
+ padding="max_length",
366
+ return_tensors="np",
367
+ truncation=True,
368
+ )
369
 
370
  batch = {
371
  "pixel_values": pixel_values,
 
383
  num_workers=data_args.preprocessing_num_workers,
384
  persistent_workers=True,
385
  drop_last=True,
386
+ pin_memory=True,
387
  collate_fn=collate_fn,
388
  )
389
 
 
394
  num_workers=data_args.preprocessing_num_workers,
395
  persistent_workers=True,
396
  drop_last=True,
397
+ pin_memory=True,
398
  collate_fn=collate_fn,
399
  )
400
 
401
  # Enable tensorboard only on the master node
402
  if has_tensorboard and jax.process_index() == 0:
403
+ log_dir = Path(training_args.output_dir).joinpath("logs").as_posix()
404
+ summary_writer = SummaryWriter(log_dir=log_dir)
405
 
406
  # Initialize our training
407
  rng = jax.random.PRNGKey(training_args.seed)
408
  rng, dropout_rng = jax.random.split(rng)
409
 
410
+ """
411
  # Create learning rate schedule
412
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
413
  len(train_dataset),
 
416
  training_args.warmup_steps,
417
  training_args.learning_rate,
418
  )
419
+ """
420
+
421
+ cosine_decay_lr_schedule_fn = optax.cosine_decay_schedule(
422
+ training_args.learning_rate,
423
+ training_args.warmup_steps,
424
+ training_args.learning_rate / 1000,  # note: optax treats this `alpha` as a fraction of the initial value, not an absolute final LR
425
+ )
426
 
427
  # create adam optimizer
428
+ adamw = optax.lamb(
429
+ learning_rate=cosine_decay_lr_schedule_fn, #linear_decay_lr_schedule_fn,
430
  b1=training_args.adam_beta1,
431
  b2=training_args.adam_beta2,
432
  eps=training_args.adam_epsilon,
 
461
 
462
  new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
463
 
464
+ metrics = {"loss": loss, "learning_rate": cosine_decay_lr_schedule_fn(state.step)}
465
  metrics = jax.lax.pmean(metrics, axis_name="batch")
466
 
467
  return new_state, metrics
 
494
  # Create sampling rng
495
  rng, input_rng = jax.random.split(rng)
496
 
497
+ #jax.profiler.start_trace(log_dir)
498
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
499
  for epoch in epochs:
500
  # ======================== Training ================================
 
561
  commit_message=f"Saving weights and logs of epoch {epoch+1}",
562
  )
563
 
564
+ #jax.profiler.stop_trace()
565
+
566
+ return model, params
567
+
568
 
569
  if __name__ == "__main__":
570
+ model, params = main()
571
+ model.save_pretrained("model", params=params)
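The optimizer hunk above swaps the linear schedule for `optax.cosine_decay_schedule` feeding `optax.lamb` (the variable is still named `adamw`). As a standalone sketch of the schedule semantics, with illustrative values mirroring `train_model.sh`: optax treats the third argument, `alpha`, as a fraction of the initial value, so passing a fraction like `1e-3` gives a floor of `learning_rate / 1000`, whereas passing `learning_rate / 1000` itself (as the hunk does) yields the much smaller floor `learning_rate**2 / 1000`; the decay also runs only for `warmup_steps` and then stays at that floor:

```python
import optax

# Illustrative values taken from train_model.sh.
learning_rate, warmup_steps = 3e-4, 3_000

schedule = optax.cosine_decay_schedule(
    init_value=learning_rate,
    decay_steps=warmup_steps,  # cosine decay runs for this many steps, then the value stays flat
    alpha=1e-3,                # final value = init_value * alpha (a fraction, not an absolute LR)
)

optimizer = optax.lamb(
    learning_rate=schedule,
    b1=0.9,
    b2=0.999,
    eps=1e-8,
    weight_decay=0.1,
)

print(float(schedule(0)), float(schedule(warmup_steps)))  # ~3e-4 at step 0, ~3e-7 after decay_steps
```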
src/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (132 Bytes).
 
src/__pycache__/configuration_medclip.cpython-38.pyc ADDED
Binary file (4.14 kB).
 
src/__pycache__/datasets_medclip.cpython-38.pyc ADDED
Binary file (6.04 kB).
 
src/__pycache__/modeling_medclip.cpython-38.pyc ADDED
Binary file (12.9 kB).
 
src/configuration_medclip.py CHANGED
@@ -7,10 +7,10 @@ from transformers.utils import logging
7
  logger = logging.get_logger(__name__)
8
 
9
 
10
- class HybridCLIPConfig(PretrainedConfig):
11
  r"""
12
- :class:`HybridCLIPConfig` is the configuration class to store the configuration of a
13
- :class:`~HybridCLIPModel`. It is used to instantiate HybridCLIPModel model according to the specified arguments,
14
  defining the text model and vision model configs.
15
 
16
  Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
@@ -28,13 +28,13 @@ class HybridCLIPConfig(PretrainedConfig):
28
 
29
  Examples::
30
 
31
- >>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
32
 
33
  >>> # Initializing a BERT and CLIP configuration
34
  >>> config_text = BertConfig()
35
  >>> config_vision = CLIPConfig()
36
 
37
- >>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
38
 
39
  >>> # Initializing a BERT and CLIPVision model
40
  >>> model = EncoderDecoderModel(config=config)
@@ -47,8 +47,8 @@ class HybridCLIPConfig(PretrainedConfig):
47
  >>> model.save_pretrained('my-model')
48
 
49
  >>> # loading model and config from pretrained folder
50
- >>> encoder_decoder_config = HybridCLIPConfig.from_pretrained('my-model')
51
- >>> model = FlaxHybridCLIP.from_pretrained('my-model', config=encoder_decoder_config)
52
  """
53
 
54
  model_type = "hybrid-clip"
@@ -84,11 +84,11 @@ class HybridCLIPConfig(PretrainedConfig):
84
  @classmethod
85
  def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
86
  r"""
87
- Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
88
  vision model configuration.
89
 
90
  Returns:
91
- :class:`HybridCLIPConfig`: An instance of a configuration object
92
  """
93
 
94
  return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
 
7
  logger = logging.get_logger(__name__)
8
 
9
 
10
+ class MedCLIPConfig(PretrainedConfig):
11
  r"""
12
+ :class:`MedCLIPConfig` is the configuration class to store the configuration of a
13
+ :class:`~FlaxMedCLIP`. It is used to instantiate a MedCLIP model according to the specified arguments,
14
  defining the text model and vision model configs.
15
 
16
  Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
 
28
 
29
  Examples::
30
 
31
+ >>> from transformers import BertConfig, CLIPConfig, MedCLIPConfig, FlaxMedCLIP
32
 
33
  >>> # Initializing a BERT and CLIP configuration
34
  >>> config_text = BertConfig()
35
  >>> config_vision = CLIPConfig()
36
 
37
+ >>> config = MedCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
38
 
39
  >>> # Initializing a BERT and CLIPVision model
40
  >>> model = EncoderDecoderModel(config=config)
 
47
  >>> model.save_pretrained('my-model')
48
 
49
  >>> # loading model and config from pretrained folder
50
+ >>> encoder_decoder_config = MedCLIPConfig.from_pretrained('my-model')
51
+ >>> model = FlaxMedCLIP.from_pretrained('my-model', config=encoder_decoder_config)
52
  """
53
 
54
  model_type = "hybrid-clip"
 
84
  @classmethod
85
  def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
86
  r"""
87
+ Instantiate a :class:`MedCLIPConfig` (or a derived class) from text model configuration and
88
  vision model configuration.
89
 
90
  Returns:
91
+ :class:`MedCLIPConfig`: An instance of a configuration object
92
  """
93
 
94
  return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
src/datasets_medclip.py ADDED
@@ -0,0 +1,182 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Santiago Hincapie-Potes & The HuggingFace Team All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import csv
17
+ import json
18
+ import random
19
+ from pathlib import Path
20
+ from typing import Callable, Dict, Optional, Union
21
+
22
+ from torchvision.datasets import VisionDataset
23
+ from torchvision.io import ImageReadMode, read_image
24
+
25
+ class MIMICDataset(VisionDataset):
26
+ """
27
+ Dataset for loading image-text data for tasks like CLIP training, Image Captioning.
28
+
29
+ Args:
30
+ root: (string): The root path where the dataset is stored
31
+ file_path: (string): Path to the file containing the image_paths and associated captions.
32
+ The expected format is jsonlines where each line is a json object containing two keys:
33
+ `image_path`: The path to the image.
34
+ `caption`: A dict mapping report section names (e.g. 'impression', 'findings') to their text.
35
+ mode: (string): target format:
36
+ * 'longest': return the longest sections
37
+ * 'docs': return findings and impressions
38
+ transform (callable, optional): A function/transform that takes in an PIL image
39
+ and returns a transformed version. E.g, ``transforms.ToTensor``
40
+ target_transform (callable, optional): A function/transform that takes in the
41
+ target and transforms it.
42
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
43
+ and returns a transformed version.
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ root: str,
49
+ file_path: str,
50
+ mode: str = 'longest',
51
+ transform: Optional[Callable] = None,
52
+ target_transform: Optional[Callable] = None,
53
+ transforms: Optional[Callable] = None,
54
+ ):
55
+ super().__init__(root, transforms, transform, target_transform)
56
+
57
+ root = Path(root)
58
+
59
+ if mode not in {'longest', 'docs'}:
60
+ raise ValueError('Invalid mode')
61
+
62
+ self.mode = mode
63
+
64
+ with open(root / file_path, "r") as f:
65
+ examples = [json.loads(line) for line in f.readlines()]
66
+
67
+ self.captions = []
68
+ self.image_paths = []
69
+
70
+ for example in examples:
71
+ img_path = root / example["image_path"]
72
+ if img_path.exists():
73
+ self.captions.append(example["caption"])
74
+ self.image_paths.append(str(img_path))
75
+
76
+ def _load_image(self, idx: int):
77
+ path = self.image_paths[idx]
78
+ return read_image(path, mode=ImageReadMode.RGB)
79
+
80
+ def _load_target(self, idx) -> str:
81
+ sections = self.captions[idx]
82
+
83
+ if self.mode == 'docs':
84
+ _collection = []
85
+ if 'impression' in sections:
86
+ _collection.append(sections['impression'])
87
+
88
+ if 'findings' in sections:
89
+ _collection.append(sections['findings'])
90
+
91
+ if len(_collection) == 1:
92
+ output = _collection[0]
93
+ if len(_collection) == 2:
94
+ output = random.choice(_collection)
95
+
96
+ if self.mode == 'longest' or len(_collection) == 0:
97
+ longest_section = max(
98
+ filter(lambda x: isinstance(x, str), sections.values()),
99
+ key=len
100
+ )
101
+
102
+ output = longest_section
103
+
104
+ return output
105
+
106
+ def __getitem__(self, index: int):
107
+ image = self._load_image(index)
108
+ target = self._load_target(index)
109
+
110
+ if self.transforms is not None:
111
+ image, target = self.transforms(image, target)
112
+
113
+ return image, target
114
+
115
+ def __len__(self) -> int:
116
+ return len(self.captions)
117
+
118
+
119
+ class ROCODataset(VisionDataset):
120
+ """
121
+ Dataset for loading image-text data for tasks like CLIP training, Image Captioning.
122
+
123
+ Args:
124
+ root: (string): The root path where the dataset is stored
125
+ split: (string): Which split to load ('train', 'validate' or 'test').
126
+ Captions are read from `<root>/<split>/radiology/<split>.csv` (as produced by
127
+ `tasks/prepare_roco.py`) and images from the `images/` directory next to it.
128
+ Each CSV row is expected to be (id, file name, caption).
129
+ transform (callable, optional): A function/transform that takes in an PIL image
130
+ and returns a transformed version. E.g, ``transforms.ToTensor``
131
+ target_transform (callable, optional): A function/transform that takes in the
132
+ target and transforms it.
133
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
134
+ and returns a transformed version.
135
+ """
136
+
137
+ def __init__(
138
+ self,
139
+ root: str,
140
+ split: str,
141
+ transform: Optional[Callable] = None,
142
+ target_transform: Optional[Callable] = None,
143
+ transforms: Optional[Callable] = None,
144
+ ):
145
+ super().__init__(root, transforms, transform, target_transform)
146
+
147
+ root = Path(root) / f"{split}/radiology/"
148
+ file_path = f"{split}.csv"
149
+
150
+ self.captions = []
151
+ self.image_paths = []
152
+
153
+ with open((root / file_path).resolve(), 'r') as buf:
154
+ csv_reader = csv.reader(buf)
155
+ next(csv_reader) # skip header
156
+
157
+ for row in csv_reader:
158
+ if len(row) == 3:
159
+ _, fname, caption = row
160
+ else:
161
+ print(row)
162
+ self.captions.append(caption.strip())
163
+ self.image_paths.append(str(root / 'images' / fname.strip()))
164
+
165
+ def _load_image(self, idx: int):
166
+ path = self.image_paths[idx]
167
+ return read_image(path, mode=ImageReadMode.RGB)
168
+
169
+ def _load_target(self, idx: int) -> str:
170
+ return self.captions[idx]
171
+
172
+ def __getitem__(self, index: int):
173
+ image = self._load_image(index)
174
+ target = self._load_target(index)
175
+
176
+ if self.transforms is not None:
177
+ image, target = self.transforms(image, target)
178
+
179
+ return image, target
180
+
181
+ def __len__(self) -> int:
182
+ return len(self.captions)
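For reference, a minimal sketch of how these two dataset classes get combined for training, mirroring the `ConcatDataset` wiring in `run_medclip.py`; the paths are placeholders and the collate function is stubbed (the real one tokenizes the captions and stacks pixel values):

```python
from torch.utils.data import ConcatDataset, DataLoader

from src.datasets_medclip import MIMICDataset, ROCODataset

# Placeholder paths; in run_medclip.py these come from --mimic_data_dir / --roco_data_dir.
mimic = MIMICDataset(
    root="/path/to/mimic-cxr",
    file_path="train_dataset.json",  # jsonlines produced by the (now removed) prepare_data.py
    mode="docs",                     # sample from 'impression'/'findings', else fall back to the longest section
)
roco = ROCODataset(root="/path/to/roco-dataset", split="train")

train_dataset = ConcatDataset([mimic, roco])
image, caption = train_dataset[0]    # each item is a (uint8 image tensor, caption string) pair

loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    collate_fn=lambda batch: batch,  # stub; see collate_fn in run_medclip.py for the real batching
)
```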
src/modeling_medclip.py CHANGED
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2021 The HuggingFace Team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ from typing import Optional, Tuple
18
  import flax.linen as nn
19
  import jax
20
  import jax.numpy as jnp
21
- from src.configuration_medclip import HybridCLIPConfig
22
  from flax.core.frozen_dict import FrozenDict
23
  from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
24
  from transformers.modeling_flax_utils import FlaxPreTrainedModel
@@ -29,8 +29,8 @@ from transformers.utils import logging
29
  logger = logging.get_logger(__name__)
30
 
31
 
32
- class FlaxHybridCLIPModule(nn.Module):
33
- config: HybridCLIPConfig
34
  dtype: jnp.dtype = jnp.float32
35
 
36
  def setup(self):
@@ -122,13 +122,13 @@ class FlaxHybridCLIPModule(nn.Module):
122
  )
123
 
124
 
125
- class FlaxHybridCLIP(FlaxPreTrainedModel):
126
- config_class = HybridCLIPConfig
127
- module_class = FlaxHybridCLIPModule
128
 
129
  def __init__(
130
  self,
131
- config: HybridCLIPConfig,
132
  input_shape: Optional[Tuple] = None,
133
  seed: int = 0,
134
  dtype: jnp.dtype = jnp.float32,
@@ -347,14 +347,14 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
347
 
348
  Example::
349
 
350
- >>> from transformers import FlaxHybridCLIP
351
  >>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
352
  >>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
353
- >>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
354
  >>> # saving model after fine-tuning
355
  >>> model.save_pretrained("./bert-clip")
356
  >>> # load fine-tuned model
357
- >>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
358
  """
359
 
360
  kwargs_text = {
@@ -404,7 +404,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
404
 
405
  # instantiate config with corresponding kwargs
406
  dtype = kwargs.pop("dtype", jnp.float32)
407
- config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
408
 
409
  # init model
410
  model = cls(config, *model_args, dtype=dtype, **kwargs)
 
1
  # coding=utf-8
2
+ # Copyright 2021 Santiago Hincapie-Potes & The HuggingFace Team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
18
  import flax.linen as nn
19
  import jax
20
  import jax.numpy as jnp
21
+ from src.configuration_medclip import MedCLIPConfig
22
  from flax.core.frozen_dict import FrozenDict
23
  from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
24
  from transformers.modeling_flax_utils import FlaxPreTrainedModel
 
29
  logger = logging.get_logger(__name__)
30
 
31
 
32
+ class FlaxMedCLIPModule(nn.Module):
33
+ config: MedCLIPConfig
34
  dtype: jnp.dtype = jnp.float32
35
 
36
  def setup(self):
 
122
  )
123
 
124
 
125
+ class FlaxMedCLIP(FlaxPreTrainedModel):
126
+ config_class = MedCLIPConfig
127
+ module_class = FlaxMedCLIPModule
128
 
129
  def __init__(
130
  self,
131
+ config: MedCLIPConfig,
132
  input_shape: Optional[Tuple] = None,
133
  seed: int = 0,
134
  dtype: jnp.dtype = jnp.float32,
 
347
 
348
  Example::
349
 
350
+ >>> from transformers import FlaxMedCLIP
351
  >>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
352
  >>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
353
+ >>> model = FlaxMedCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
354
  >>> # saving model after fine-tuning
355
  >>> model.save_pretrained("./bert-clip")
356
  >>> # load fine-tuned model
357
+ >>> model = FlaxMedCLIP.from_pretrained("./bert-clip")
358
  """
359
 
360
  kwargs_text = {
 
404
 
405
  # instantiate config with corresponding kwargs
406
  dtype = kwargs.pop("dtype", jnp.float32)
407
+ config = MedCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
408
 
409
  # init model
410
  model = cls(config, *model_args, dtype=dtype, **kwargs)
tasks/prepare_roco.py ADDED
@@ -0,0 +1,43 @@
1
+ import csv
2
+ from pathlib import Path
3
+ import torchvision
4
+
5
+ def main(roco_root: str):
6
+ root = Path(roco_root)
7
+
8
+ check_images(
9
+ root / 'train/radiology', 'traindata.csv', 'train.csv'
10
+ )
11
+
12
+ check_images(
13
+ root / 'validate/radiology', 'valdata.csv', 'validate.csv'
14
+ )
15
+
16
+ check_images(
17
+ root / 'test/radiology', 'testdata.csv', 'test.csv'
18
+ )
19
+
20
+ def check_images(split_dir: Path, input_csv: str, target_output: str):
21
+ with open(split_dir / input_csv, 'r') as buf:
22
+ csv_reader = csv.reader(buf)
23
+ next(csv_reader, None)
24
+
25
+ filtered_csv = []
26
+
27
+ for row in csv_reader:
28
+ image_path = split_dir / 'images' / row[1]
29
+ try:
30
+ torchvision.io.read_image(str(image_path))
31
+ except Exception:  # skip images that torchvision cannot decode
32
+ continue
33
+ filtered_csv.append(row)
34
+
35
+ with open(split_dir / target_output, 'w') as csvfile:
36
+ spamwriter = csv.writer(csvfile)
37
+ for row in filtered_csv:
38
+ spamwriter.writerow(row)
39
+
40
+
41
+ if __name__ == '__main__':
42
+ main('/home/shpotes/medclip/data/roco-dataset')
train_model.sh CHANGED
@@ -1,15 +1,18 @@
1
  python run_medclip.py \
2
- --output_dir model \
 
3
  --text_model_name_or_path="allenai/scibert_scivocab_uncased" \
4
  --vision_model_name_or_path="openai/clip-vit-base-patch32" \
5
  --tokenizer_name="allenai/scibert_scivocab_uncased" \
6
- --data_dir="/home/shared/data/mimic-cxr" \
7
- --train_file="/home/shared/data/mimic-cxr/train_dataset.json" \
8
- --validation_file="/home/shared/data/mimic-cxr/validate_dataset.json" \
 
 
9
  --do_train --do_eval \
10
- --num_train_epochs="40" --max_seq_length 512 \
11
- --per_device_train_batch_size="64" \
12
- --per_device_eval_batch_size="64" \
13
- --learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
14
- --overwrite_output_dir \
15
- --preprocessing_num_workers 32 \
 
1
  python run_medclip.py \
2
+ --output_dir "flax-community/medclip" \
3
+ --overwrite_output_dir \
4
  --text_model_name_or_path="allenai/scibert_scivocab_uncased" \
5
  --vision_model_name_or_path="openai/clip-vit-base-patch32" \
6
  --tokenizer_name="allenai/scibert_scivocab_uncased" \
7
+ --mimic_data_dir="/home/shpotes/medclip/data/mimic-cxr/" \
8
+ --mimic_train_file="train_dataset.json" \
9
+ --mimic_validation_file="validate_dataset.json" \
10
+ --mimic_mode="docs" \
11
+ --roco_data_dir="/home/shpotes/medclip/data/roco-dataset/" \
12
  --do_train --do_eval \
13
+ --num_train_epochs="20" \
14
+ --preprocessing_num_workers=32 \
15
+ --per_device_train_batch_size=64 \
16
+ --per_device_eval_batch_size=64 \
17
+ --warmup_steps=3000 \
18
+ --learning_rate="3e-4"