boris commited on
Commit
7939874
·
unverified ·
1 Parent(s): 803ccbf

feat(data): super conditioning (#141)

Browse files

* feat(data): online filtering
* feat(generate): super conditioning
* feat: add processor

README.md CHANGED
@@ -35,7 +35,6 @@ To generate sample predictions and understand the inference pipeline step by ste
35
  Join the community on the [DALLE-Pytorch Discord](https://discord.gg/xBPBXfcFHd).
36
  Any contribution is welcome, from reporting issues to proposing fixes/improvements or testing the model with cool prompts!
37
 
38
-
39
  ## Development
40
 
41
  ### Dependencies Installation
@@ -95,6 +94,7 @@ Many thanks to the people who helped make it better:
95
 
96
  - the [DALLE-Pytorch](https://discord.gg/xBPBXfcFHd) and [EleutherAI](https://www.eleuther.ai/) communities for testing and exchanging cool ideas
97
  - [Rohan Anil](https://github.com/rohan-anil) for adding Distributed Shampoo optimizer
 
98
 
99
  ## Citing DALL·E mini
100
 
 
35
  Join the community on the [DALLE-Pytorch Discord](https://discord.gg/xBPBXfcFHd).
36
  Any contribution is welcome, from reporting issues to proposing fixes/improvements or testing the model with cool prompts!
37
 
 
38
  ## Development
39
 
40
  ### Dependencies Installation
 
94
 
95
  - the [DALLE-Pytorch](https://discord.gg/xBPBXfcFHd) and [EleutherAI](https://www.eleuther.ai/) communities for testing and exchanging cool ideas
96
  - [Rohan Anil](https://github.com/rohan-anil) for adding Distributed Shampoo optimizer
97
+ - [Katherine Crowson](https://github.com/crowsonkb) for [super conditioning](https://twitter.com/RiversHaveWings/status/1478093658716966912)
98
 
99
  ## Citing DALL·E mini
100
 
src/dalle_mini/__init__.py CHANGED
@@ -1 +1,3 @@
1
- __version__ = "0.0.2"
 
 
 
1
+ __version__ = "0.0.3"
2
+
3
+ from .model import DalleBart, DalleBartProcessor
src/dalle_mini/data.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
7
  from braceexpand import braceexpand
8
  from datasets import Dataset, load_dataset
9
 
10
- from .text import TextNormalizer
11
 
12
 
13
  @dataclass
@@ -28,6 +28,11 @@ class Dataset:
28
  seed_dataset: int = None
29
  shard_by_host: bool = False
30
  blank_caption_prob: float = 0.0
 
 
 
 
 
31
  train_dataset: Dataset = field(init=False)
32
  eval_dataset: Dataset = field(init=False)
33
  rng_dataset: jnp.ndarray = field(init=False)
@@ -36,6 +41,7 @@ class Dataset:
36
  def __post_init__(self):
37
  self.multi_hosts = jax.process_count() > 1
38
  # feed blank captions only in streaming mode for now
 
39
  if self.blank_caption_prob:
40
  assert (
41
  self.streaming is True
@@ -107,23 +113,30 @@ class Dataset:
107
  self.seed_dataset = np.random.get_state()[1][0]
108
  self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
109
 
110
- # blank captions
111
- if self.blank_caption_prob:
112
- partial_blank_caption_function = partial(
113
- blank_caption_function,
114
- text_column=self.text_column,
115
- blank_caption_prob=self.blank_caption_prob,
116
- )
117
- if hasattr(self, "train_dataset"):
118
- self.train_dataset = (
119
- self.train_dataset.map(partial_blank_caption_function)
120
- if self.streaming
121
- else self.train_dataset.map(
122
- partial_blank_caption_function,
123
- num_proc=self.preprocessing_num_workers,
124
- load_from_cache_file=False,
125
- desc="Blanking some captions",
126
- )
 
 
 
 
 
 
 
127
  )
128
 
129
  # normalize text
@@ -151,6 +164,25 @@ class Dataset:
151
  ),
152
  )
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  # preprocess
155
  partial_preprocess_function = partial(
156
  preprocess_function,
@@ -230,8 +262,8 @@ class Dataset:
230
  dataset.set_epoch(epoch)
231
  epoch += 1
232
  for item in dataset:
233
- for k, v in item.items():
234
- batch[k].append(v)
235
  if len(batch[keys[0]]) == batch_size:
236
  batch = {k: jnp.array(v) for k, v in batch.items()}
237
  yield batch
@@ -292,6 +324,23 @@ def normalize_function(example, text_column, text_normalizer):
292
  return example
293
 
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  def preprocess_function(
296
  examples,
297
  tokenizer,
 
7
  from braceexpand import braceexpand
8
  from datasets import Dataset, load_dataset
9
 
10
+ from .model.text import TextNormalizer
11
 
12
 
13
  @dataclass
 
28
  seed_dataset: int = None
29
  shard_by_host: bool = False
30
  blank_caption_prob: float = 0.0
31
+ clip_score_column: str = "clip_score"
32
+ min_clip_score: float = None
33
+ max_clip_score: float = None
34
+ filter_column: str = None
35
+ filter_value: str = None
36
  train_dataset: Dataset = field(init=False)
37
  eval_dataset: Dataset = field(init=False)
38
  rng_dataset: jnp.ndarray = field(init=False)
 
41
  def __post_init__(self):
42
  self.multi_hosts = jax.process_count() > 1
43
  # feed blank captions only in streaming mode for now
44
+ # otherwise dataset could be cached with same blanked captions
45
  if self.blank_caption_prob:
46
  assert (
47
  self.streaming is True
 
113
  self.seed_dataset = np.random.get_state()[1][0]
114
  self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
115
 
116
+ # filter data
117
+ partial_filter_function = partial(
118
+ filter_function,
119
+ filter_column=self.filter_column,
120
+ filter_value=self.filter_value,
121
+ clip_score_column=self.clip_score_column,
122
+ min_clip_score=self.min_clip_score,
123
+ max_clip_score=self.max_clip_score,
124
+ )
125
+ for ds in ["train_dataset", "eval_dataset"]:
126
+ if hasattr(self, ds):
127
+ setattr(
128
+ self,
129
+ ds,
130
+ (
131
+ getattr(self, ds).filter(partial_filter_function)
132
+ if self.streaming
133
+ else getattr(self, ds).filter(
134
+ partial_filter_function,
135
+ num_proc=self.preprocessing_num_workers,
136
+ load_from_cache_file=not self.overwrite_cache,
137
+ desc="Filtering datasets",
138
+ )
139
+ ),
140
  )
141
 
142
  # normalize text
 
164
  ),
165
  )
166
 
167
+ # blank captions
168
+ if self.blank_caption_prob:
169
+ partial_blank_caption_function = partial(
170
+ blank_caption_function,
171
+ text_column=self.text_column,
172
+ blank_caption_prob=self.blank_caption_prob,
173
+ )
174
+ if hasattr(self, "train_dataset"):
175
+ self.train_dataset = (
176
+ self.train_dataset.map(partial_blank_caption_function)
177
+ if self.streaming
178
+ else self.train_dataset.map(
179
+ partial_blank_caption_function,
180
+ num_proc=self.preprocessing_num_workers,
181
+ load_from_cache_file=False,
182
+ desc="Blanking some captions",
183
+ )
184
+ )
185
+
186
  # preprocess
187
  partial_preprocess_function = partial(
188
  preprocess_function,
 
262
  dataset.set_epoch(epoch)
263
  epoch += 1
264
  for item in dataset:
265
+ for k in keys:
266
+ batch[k].append(item[k])
267
  if len(batch[keys[0]]) == batch_size:
268
  batch = {k: jnp.array(v) for k, v in batch.items()}
269
  yield batch
 
324
  return example
325
 
326
 
327
+ def filter_function(
328
+ example,
329
+ min_clip_score,
330
+ max_clip_score,
331
+ clip_score_column,
332
+ filter_column,
333
+ filter_value,
334
+ ):
335
+ if min_clip_score is not None and example[clip_score_column] < min_clip_score:
336
+ return False
337
+ if max_clip_score is not None and example[clip_score_column] > max_clip_score:
338
+ return False
339
+ if filter_column is not None and example[filter_column] != filter_value:
340
+ return False
341
+ return True
342
+
343
+
344
  def preprocess_function(
345
  examples,
346
  tokenizer,
src/dalle_mini/model/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
  from .configuration import DalleBartConfig
2
  from .modeling import DalleBart
3
  from .partitions import set_partitions
 
4
  from .tokenizer import DalleBartTokenizer
 
1
  from .configuration import DalleBartConfig
2
  from .modeling import DalleBart
3
  from .partitions import set_partitions
4
+ from .processor import DalleBartProcessor
5
  from .tokenizer import DalleBartTokenizer
src/dalle_mini/model/modeling.py CHANGED
@@ -18,8 +18,9 @@ import math
18
  import os
19
  from functools import partial
20
  from pickle import UnpicklingError
21
- from typing import Optional, Tuple, Union
22
 
 
23
  import flax.linen as nn
24
  import jax
25
  import jax.numpy as jnp
@@ -39,6 +40,7 @@ from transformers.file_utils import (
39
  is_offline_mode,
40
  is_remote_url,
41
  )
 
42
  from transformers.modeling_flax_outputs import (
43
  FlaxCausalLMOutputWithCrossAttentions,
44
  FlaxSeq2SeqLMOutput,
@@ -691,6 +693,17 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
691
  )
692
 
693
 
 
 
 
 
 
 
 
 
 
 
 
694
  class DalleBart(
695
  PretrainedFromWandbMixin, FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration
696
  ):
@@ -702,6 +715,7 @@ class DalleBart(
702
  - no bias in decode method
703
  - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
704
  related to position embedding during model.generate()
 
705
  """
706
 
707
  module_class = FlaxBartForConditionalGenerationModule
@@ -872,3 +886,325 @@ class DalleBart(
872
  "decoder_attention_mask": extended_attention_mask,
873
  "decoder_position_ids": position_ids,
874
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  import os
19
  from functools import partial
20
  from pickle import UnpicklingError
21
+ from typing import Dict, Optional, Tuple, Union
22
 
23
+ import flax
24
  import flax.linen as nn
25
  import jax
26
  import jax.numpy as jnp
 
40
  is_offline_mode,
41
  is_remote_url,
42
  )
43
+ from transformers.generation_flax_utils import FlaxSampleOutput
44
  from transformers.modeling_flax_outputs import (
45
  FlaxCausalLMOutputWithCrossAttentions,
46
  FlaxSeq2SeqLMOutput,
 
693
  )
694
 
695
 
696
+ @flax.struct.dataclass
697
+ class SampleState:
698
+ cur_len: jnp.ndarray
699
+ sequences: jnp.ndarray
700
+ running_token: jnp.ndarray
701
+ is_sent_finished: jnp.ndarray
702
+ prng_key: jnp.ndarray
703
+ model_kwargs: Dict[str, jnp.ndarray]
704
+ model_kwargs_uncond: Dict[str, jnp.ndarray]
705
+
706
+
707
  class DalleBart(
708
  PretrainedFromWandbMixin, FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration
709
  ):
 
715
  - no bias in decode method
716
  - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
717
  related to position embedding during model.generate()
718
+ - custom generate method to allow super conditions
719
  """
720
 
721
  module_class = FlaxBartForConditionalGenerationModule
 
886
  "decoder_attention_mask": extended_attention_mask,
887
  "decoder_position_ids": position_ids,
888
  }
889
+
890
+ def generate(
891
+ self,
892
+ input_ids: jnp.ndarray,
893
+ attention_mask: Optional[jnp.ndarray] = None,
894
+ max_length: Optional[int] = None,
895
+ pad_token_id: Optional[int] = None,
896
+ bos_token_id: Optional[int] = None,
897
+ eos_token_id: Optional[int] = None,
898
+ decoder_start_token_id: Optional[int] = None,
899
+ do_sample: Optional[bool] = None,
900
+ prng_key: Optional[jnp.ndarray] = None,
901
+ top_k: Optional[int] = None,
902
+ top_p: Optional[float] = None,
903
+ temperature: Optional[float] = None,
904
+ num_beams: Optional[int] = None,
905
+ no_repeat_ngram_size: Optional[int] = None,
906
+ min_length: Optional[int] = None,
907
+ forced_bos_token_id: Optional[int] = None,
908
+ forced_eos_token_id: Optional[int] = None,
909
+ length_penalty: Optional[float] = None,
910
+ early_stopping: Optional[bool] = None,
911
+ trace: bool = True,
912
+ params: Optional[Dict[str, jnp.ndarray]] = None,
913
+ condition_scale: Optional[float] = 1.0,
914
+ input_ids_uncond: Optional[jnp.ndarray] = None,
915
+ attention_mask_uncond: Optional[jnp.ndarray] = None,
916
+ **model_kwargs,
917
+ ):
918
+ """Edit: Allow super conditioning."""
919
+
920
+ # set init values
921
+ max_length = max_length if max_length is not None else self.config.max_length
922
+ bos_token_id = (
923
+ bos_token_id if bos_token_id is not None else self.config.bos_token_id
924
+ )
925
+ pad_token_id = (
926
+ pad_token_id if pad_token_id is not None else self.config.pad_token_id
927
+ )
928
+ eos_token_id = (
929
+ eos_token_id if eos_token_id is not None else self.config.eos_token_id
930
+ )
931
+ decoder_start_token_id = (
932
+ decoder_start_token_id
933
+ if decoder_start_token_id
934
+ else self.config.decoder_start_token_id
935
+ )
936
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
937
+
938
+ if decoder_start_token_id is None and self.config.is_encoder_decoder:
939
+ raise ValueError(
940
+ "`decoder_start_token_id` has to be defined for encoder-decoder generation."
941
+ )
942
+
943
+ do_sample = do_sample if do_sample is not None else self.config.do_sample
944
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
945
+
946
+ if self.config.is_encoder_decoder:
947
+ # add encoder_outputs to model_kwargs
948
+ if model_kwargs.get("encoder_outputs") is None:
949
+ model_kwargs_input = dict(model_kwargs)
950
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
951
+ input_ids,
952
+ params,
953
+ {"attention_mask": attention_mask, **model_kwargs_input},
954
+ )
955
+ if condition_scale != 1.0:
956
+ assert (
957
+ input_ids_uncond is not None
958
+ ), "`input_ids_uncond` has to be defined for super conditioning."
959
+ assert (
960
+ do_sample is True
961
+ ), "`do_sample` has to be True for super conditioning."
962
+ assert (
963
+ num_beams == 1
964
+ ), "`num_beams` has to be 1 for super conditioning."
965
+ model_kwargs_uncond = (
966
+ self._prepare_encoder_decoder_kwargs_for_generation(
967
+ input_ids_uncond,
968
+ params,
969
+ {
970
+ "attention_mask": attention_mask_uncond,
971
+ **model_kwargs_input,
972
+ },
973
+ )
974
+ )
975
+ else:
976
+ model_kwargs_uncond = None
977
+ # prepare decoder_input_ids for generation
978
+ input_ids = (
979
+ jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
980
+ )
981
+
982
+ if not do_sample and num_beams == 1:
983
+ logits_processor = self._get_logits_processor(
984
+ no_repeat_ngram_size,
985
+ min_length,
986
+ max_length,
987
+ eos_token_id,
988
+ forced_bos_token_id,
989
+ forced_eos_token_id,
990
+ )
991
+ return self._greedy_search(
992
+ input_ids,
993
+ max_length,
994
+ pad_token_id,
995
+ eos_token_id,
996
+ logits_processor=logits_processor,
997
+ trace=trace,
998
+ params=params,
999
+ model_kwargs=model_kwargs,
1000
+ )
1001
+ elif do_sample and num_beams == 1:
1002
+ logits_warper = self._get_logits_warper(
1003
+ top_k=top_k, top_p=top_p, temperature=temperature
1004
+ )
1005
+ logits_processor = self._get_logits_processor(
1006
+ no_repeat_ngram_size,
1007
+ min_length,
1008
+ max_length,
1009
+ eos_token_id,
1010
+ forced_bos_token_id,
1011
+ forced_eos_token_id,
1012
+ )
1013
+ return self._sample(
1014
+ input_ids,
1015
+ max_length,
1016
+ pad_token_id,
1017
+ eos_token_id,
1018
+ prng_key,
1019
+ logits_warper=logits_warper,
1020
+ logits_processor=logits_processor,
1021
+ trace=trace,
1022
+ params=params,
1023
+ model_kwargs=model_kwargs,
1024
+ condition_scale=condition_scale,
1025
+ model_kwargs_uncond=model_kwargs_uncond,
1026
+ )
1027
+ elif not do_sample and num_beams > 1:
1028
+ # broadcast input_ids & encoder_outputs
1029
+ input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
1030
+
1031
+ if "encoder_outputs" in model_kwargs:
1032
+ model_kwargs["encoder_outputs"][
1033
+ "last_hidden_state"
1034
+ ] = self._expand_to_num_beams(
1035
+ model_kwargs["encoder_outputs"]["last_hidden_state"],
1036
+ num_beams=num_beams,
1037
+ )
1038
+
1039
+ if "attention_mask" in model_kwargs:
1040
+ model_kwargs["attention_mask"] = self._expand_to_num_beams(
1041
+ model_kwargs["attention_mask"], num_beams=num_beams
1042
+ )
1043
+
1044
+ logits_processor = self._get_logits_processor(
1045
+ no_repeat_ngram_size,
1046
+ min_length,
1047
+ max_length,
1048
+ eos_token_id,
1049
+ forced_bos_token_id,
1050
+ forced_eos_token_id,
1051
+ )
1052
+
1053
+ return self._beam_search(
1054
+ input_ids,
1055
+ max_length,
1056
+ pad_token_id,
1057
+ eos_token_id,
1058
+ length_penalty=length_penalty,
1059
+ early_stopping=early_stopping,
1060
+ logits_processor=logits_processor,
1061
+ trace=trace,
1062
+ params=params,
1063
+ model_kwargs=model_kwargs,
1064
+ )
1065
+ else:
1066
+ raise NotImplementedError("`Beam sampling is currently not implemented.")
1067
+
1068
+ def _sample(
1069
+ self,
1070
+ input_ids: None,
1071
+ max_length: Optional[int] = None,
1072
+ pad_token_id: Optional[int] = None,
1073
+ eos_token_id: Optional[int] = None,
1074
+ prng_key: Optional[jnp.ndarray] = None,
1075
+ logits_processor=None,
1076
+ logits_warper=None,
1077
+ trace: bool = True,
1078
+ params: Optional[Dict[str, jnp.ndarray]] = None,
1079
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
1080
+ condition_scale: float = 1.0,
1081
+ model_kwargs_uncond: Optional[Dict[str, jnp.ndarray]] = None,
1082
+ ):
1083
+ # init values
1084
+ max_length = max_length if max_length is not None else self.config.max_length
1085
+ pad_token_id = (
1086
+ pad_token_id if pad_token_id is not None else self.config.pad_token_id
1087
+ )
1088
+ eos_token_id = (
1089
+ eos_token_id if eos_token_id is not None else self.config.eos_token_id
1090
+ )
1091
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
1092
+
1093
+ batch_size, cur_len = input_ids.shape
1094
+
1095
+ eos_token_id = jnp.array(eos_token_id)
1096
+ pad_token_id = jnp.array(pad_token_id)
1097
+ cur_len = jnp.array(cur_len)
1098
+
1099
+ # per batch-item holding current token in loop.
1100
+ sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
1101
+ sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
1102
+
1103
+ # per batch-item state bit indicating if sentence has finished.
1104
+ is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
1105
+
1106
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
1107
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
1108
+ model = self.decode if self.config.is_encoder_decoder else self
1109
+
1110
+ # initialize model specific kwargs
1111
+ model_kwargs = self.prepare_inputs_for_generation(
1112
+ input_ids, max_length, **model_kwargs
1113
+ )
1114
+ if condition_scale != 1.0:
1115
+ model_kwargs_uncond = self.prepare_inputs_for_generation(
1116
+ input_ids, max_length, **model_kwargs_uncond
1117
+ )
1118
+
1119
+ # initialize state
1120
+ state = SampleState(
1121
+ cur_len=cur_len,
1122
+ sequences=sequences,
1123
+ running_token=input_ids,
1124
+ is_sent_finished=is_sent_finished,
1125
+ prng_key=prng_key,
1126
+ model_kwargs=model_kwargs,
1127
+ model_kwargs_uncond=model_kwargs_uncond,
1128
+ )
1129
+
1130
+ def sample_search_cond_fn(state):
1131
+ """state termination condition fn."""
1132
+ has_reached_max_length = state.cur_len == max_length
1133
+ all_sequence_finished = jnp.all(state.is_sent_finished)
1134
+ finish_generation = jnp.logical_or(
1135
+ has_reached_max_length, all_sequence_finished
1136
+ )
1137
+ return ~finish_generation
1138
+
1139
+ def sample_search_body_fn(state):
1140
+ """state update fn."""
1141
+ prng_key, prng_key_next = jax.random.split(state.prng_key)
1142
+ model_outputs = model(
1143
+ state.running_token, params=params, **state.model_kwargs
1144
+ )
1145
+
1146
+ logits = model_outputs.logits[:, -1]
1147
+
1148
+ # perform super conditioning
1149
+ # Source: @RiversHaveWings - https://twitter.com/RiversHaveWings/status/1478093658716966912?s=20&t=xdm-wZ61Wf7OLnE_NJHZ1w
1150
+ if condition_scale != 1.0:
1151
+ model_outputs_uncond = model(
1152
+ state.running_token, params=params, **state.model_kwargs_uncond
1153
+ )
1154
+ logits_uncond = model_outputs_uncond.logits[:, -1]
1155
+ logits = logits_uncond + condition_scale * (logits - logits_uncond)
1156
+ else:
1157
+ model_outputs_uncond = None
1158
+
1159
+ # apply min_length, ...
1160
+ logits = logits_processor(state.sequences, logits, state.cur_len)
1161
+ # apply top_k, top_k, temperature
1162
+ logits = logits_warper(logits, logits, state.cur_len)
1163
+
1164
+ next_token = jax.random.categorical(prng_key, logits, axis=-1)
1165
+
1166
+ next_is_sent_finished = state.is_sent_finished | (
1167
+ next_token == eos_token_id
1168
+ )
1169
+ next_token = (
1170
+ next_token * ~next_is_sent_finished
1171
+ + pad_token_id * next_is_sent_finished
1172
+ )
1173
+ next_token = next_token[:, None]
1174
+
1175
+ next_sequences = lax.dynamic_update_slice(
1176
+ state.sequences, next_token, (0, state.cur_len)
1177
+ )
1178
+ next_model_kwargs = self.update_inputs_for_generation(
1179
+ model_outputs, state.model_kwargs
1180
+ )
1181
+ next_model_kwargs_uncond = (
1182
+ self.update_inputs_for_generation(
1183
+ model_outputs_uncond, state.model_kwargs_uncond
1184
+ )
1185
+ if condition_scale != 1.0
1186
+ else None
1187
+ )
1188
+
1189
+ return SampleState(
1190
+ cur_len=state.cur_len + 1,
1191
+ sequences=next_sequences,
1192
+ running_token=next_token,
1193
+ is_sent_finished=next_is_sent_finished,
1194
+ model_kwargs=next_model_kwargs,
1195
+ model_kwargs_uncond=next_model_kwargs_uncond,
1196
+ prng_key=prng_key_next,
1197
+ )
1198
+
1199
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
1200
+ if input_ids.shape[1] > 1:
1201
+ state = sample_search_body_fn(state)
1202
+
1203
+ if not trace:
1204
+ state = self._run_loop_in_debug(
1205
+ sample_search_cond_fn, sample_search_body_fn, state
1206
+ )
1207
+ else:
1208
+ state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
1209
+
1210
+ return FlaxSampleOutput(sequences=state.sequences)
src/dalle_mini/model/processor.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ DalleBart processor """
2
+
3
+ import jax.numpy as jnp
4
+
5
+ from .configuration import DalleBartConfig
6
+ from .text import TextNormalizer
7
+ from .tokenizer import DalleBartTokenizer
8
+ from .utils import PretrainedFromWandbMixin
9
+
10
+
11
+ class DalleBartProcessorBase:
12
+ def __init__(
13
+ self, tokenizer: DalleBartTokenizer, normalize_text: bool, max_text_length: int
14
+ ):
15
+ self.tokenizer = tokenizer
16
+ self.normalize_text = normalize_text
17
+ self.max_text_length = max_text_length
18
+ if normalize_text:
19
+ self.text_processor = TextNormalizer()
20
+ # create unconditional tokens
21
+ uncond = self.tokenizer(
22
+ "",
23
+ return_tensors="jax",
24
+ padding="max_length",
25
+ truncation=True,
26
+ max_length=self.max_text_length,
27
+ ).data
28
+ self.input_ids_uncond = uncond["input_ids"]
29
+ self.attention_mask_uncond = uncond["attention_mask"]
30
+
31
+ def __call__(self, text: str = None):
32
+ # check that text is not a string
33
+ assert not isinstance(text, str), "text must be a list of strings"
34
+
35
+ if self.normalize_text:
36
+ text = [self.text_processor(t) for t in text]
37
+ res = self.tokenizer(
38
+ text,
39
+ return_tensors="jax",
40
+ padding="max_length",
41
+ truncation=True,
42
+ max_length=self.max_text_length,
43
+ ).data
44
+ # tokens used only with super conditioning
45
+ n = len(text)
46
+ res["input_ids_uncond"] = jnp.repeat(self.input_ids_uncond, n, axis=0)
47
+ res["attention_mask_uncond"] = jnp.repeat(self.attention_mask_uncond, n, axis=0)
48
+ return res
49
+
50
+ @classmethod
51
+ def from_pretrained(cls, *args, **kwargs):
52
+ tokenizer = DalleBartTokenizer.from_pretrained(*args, **kwargs)
53
+ config = DalleBartConfig.from_pretrained(*args, **kwargs)
54
+ return cls(tokenizer, config.normalize_text, config.max_text_length)
55
+
56
+
57
+ class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase):
58
+ pass
src/dalle_mini/{text.py → model/text.py} RENAMED
File without changes
tools/inference/inference_pipeline.ipynb CHANGED
@@ -75,7 +75,7 @@
75
  "# Model references\n",
76
  "\n",
77
  "# dalle-mini\n",
78
- "DALLE_MODEL = \"dalle-mini/dalle-mini/model-1reghx5l:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
79
  "DALLE_COMMIT_ID = None\n",
80
  "\n",
81
  "# VQGAN model\n",
@@ -126,7 +126,7 @@
126
  "outputs": [],
127
  "source": [
128
  "# Load models & tokenizer\n",
129
- "from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
130
  "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
131
  "from transformers import CLIPProcessor, FlaxCLIPModel\n",
132
  "\n",
@@ -134,14 +134,13 @@
134
  "model = DalleBart.from_pretrained(\n",
135
  " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
136
  ")\n",
137
- "tokenizer = DalleBartTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
138
  "\n",
139
  "# Load VQGAN\n",
140
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
141
  "\n",
142
  "# Load CLIP\n",
143
  "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
144
- "processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
145
  ]
146
  },
147
  {
@@ -192,17 +191,18 @@
192
  "from functools import partial\n",
193
  "\n",
194
  "# model inference\n",
195
- "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
196
- "def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
 
 
197
  " return model.generate(\n",
198
  " **tokenized_prompt,\n",
199
- " do_sample=True,\n",
200
- " num_beams=1,\n",
201
  " prng_key=key,\n",
202
  " params=params,\n",
203
  " top_k=top_k,\n",
204
  " top_p=top_p,\n",
205
- " max_length=257\n",
 
206
  " )\n",
207
  "\n",
208
  "\n",
@@ -258,7 +258,7 @@
258
  "id": "rsmj0Aj5OQox"
259
  },
260
  "source": [
261
- "Our model may require to normalize the prompt."
262
  ]
263
  },
264
  {
@@ -269,9 +269,9 @@
269
  },
270
  "outputs": [],
271
  "source": [
272
- "from dalle_mini.text import TextNormalizer\n",
273
  "\n",
274
- "text_normalizer = TextNormalizer() if model.config.normalize_text else None"
275
  ]
276
  },
277
  {
@@ -291,7 +291,7 @@
291
  },
292
  "outputs": [],
293
  "source": [
294
- "prompt = \"view of the beach during sunset\""
295
  ]
296
  },
297
  {
@@ -302,34 +302,7 @@
302
  },
303
  "outputs": [],
304
  "source": [
305
- "processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt\n",
306
- "processed_prompt"
307
- ]
308
- },
309
- {
310
- "cell_type": "markdown",
311
- "metadata": {
312
- "id": "QUzYACWxOe5z"
313
- },
314
- "source": [
315
- "We tokenize the prompt."
316
- ]
317
- },
318
- {
319
- "cell_type": "code",
320
- "execution_count": null,
321
- "metadata": {
322
- "id": "n8e7MvGwOe5z"
323
- },
324
- "outputs": [],
325
- "source": [
326
- "tokenized_prompt = tokenizer(\n",
327
- " processed_prompt,\n",
328
- " return_tensors=\"jax\",\n",
329
- " padding=\"max_length\",\n",
330
- " truncation=True,\n",
331
- " max_length=128,\n",
332
- ").data\n",
333
  "tokenized_prompt"
334
  ]
335
  },
@@ -390,7 +363,9 @@
390
  "\n",
391
  "# We can customize top_k/top_p used for generating samples\n",
392
  "gen_top_k = None\n",
393
- "gen_top_p = None"
 
 
394
  ]
395
  },
396
  {
@@ -413,7 +388,13 @@
413
  " key, subkey = jax.random.split(key)\n",
414
  " # generate images\n",
415
  " encoded_images = p_generate(\n",
416
- " tokenized_prompt, shard_prng_key(subkey), model.params, gen_top_k, gen_top_p\n",
 
 
 
 
 
 
417
  " )\n",
418
  " # remove BOS\n",
419
  " encoded_images = encoded_images.sequences[..., 1:]\n",
@@ -444,7 +425,7 @@
444
  "from flax.training.common_utils import shard\n",
445
  "\n",
446
  "# get clip scores\n",
447
- "clip_inputs = processor(\n",
448
  " text=[prompt] * jax.device_count(),\n",
449
  " images=images,\n",
450
  " return_tensors=\"np\",\n",
 
75
  "# Model references\n",
76
  "\n",
77
  "# dalle-mini\n",
78
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/model-2vm4itcx:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
79
  "DALLE_COMMIT_ID = None\n",
80
  "\n",
81
  "# VQGAN model\n",
 
126
  "outputs": [],
127
  "source": [
128
  "# Load models & tokenizer\n",
129
+ "from dalle_mini import DalleBart, DalleBartProcessor\n",
130
  "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
131
  "from transformers import CLIPProcessor, FlaxCLIPModel\n",
132
  "\n",
 
134
  "model = DalleBart.from_pretrained(\n",
135
  " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
136
  ")\n",
 
137
  "\n",
138
  "# Load VQGAN\n",
139
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
140
  "\n",
141
  "# Load CLIP\n",
142
  "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
143
+ "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
144
  ]
145
  },
146
  {
 
191
  "from functools import partial\n",
192
  "\n",
193
  "# model inference\n",
194
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
195
+ "def p_generate(\n",
196
+ " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
197
+ "):\n",
198
  " return model.generate(\n",
199
  " **tokenized_prompt,\n",
 
 
200
  " prng_key=key,\n",
201
  " params=params,\n",
202
  " top_k=top_k,\n",
203
  " top_p=top_p,\n",
204
+ " temperature=temperature,\n",
205
+ " condition_scale=condition_scale,\n",
206
  " )\n",
207
  "\n",
208
  "\n",
 
258
  "id": "rsmj0Aj5OQox"
259
  },
260
  "source": [
261
+ "Our model requires processing prompts."
262
  ]
263
  },
264
  {
 
269
  },
270
  "outputs": [],
271
  "source": [
272
+ "from dalle_mini import DalleBartProcessor\n",
273
  "\n",
274
+ "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
275
  ]
276
  },
277
  {
 
291
  },
292
  "outputs": [],
293
  "source": [
294
+ "prompt = \"a blue table\""
295
  ]
296
  },
297
  {
 
302
  },
303
  "outputs": [],
304
  "source": [
305
+ "tokenized_prompt = processor([prompt])\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  "tokenized_prompt"
307
  ]
308
  },
 
363
  "\n",
364
  "# We can customize top_k/top_p used for generating samples\n",
365
  "gen_top_k = None\n",
366
+ "gen_top_p = None\n",
367
+ "temperature = 0.85\n",
368
+ "cond_scale = 3.0"
369
  ]
370
  },
371
  {
 
388
  " key, subkey = jax.random.split(key)\n",
389
  " # generate images\n",
390
  " encoded_images = p_generate(\n",
391
+ " tokenized_prompt,\n",
392
+ " shard_prng_key(subkey),\n",
393
+ " model.params,\n",
394
+ " gen_top_k,\n",
395
+ " gen_top_p,\n",
396
+ " temperature,\n",
397
+ " cond_scale,\n",
398
  " )\n",
399
  " # remove BOS\n",
400
  " encoded_images = encoded_images.sequences[..., 1:]\n",
 
425
  "from flax.training.common_utils import shard\n",
426
  "\n",
427
  "# get clip scores\n",
428
+ "clip_inputs = clip_processor(\n",
429
  " text=[prompt] * jax.device_count(),\n",
430
  " images=images,\n",
431
  " return_tensors=\"np\",\n",
tools/train/train.py CHANGED
@@ -103,7 +103,7 @@ class ModelArguments:
103
 
104
  def __post_init__(self):
105
  if self.tokenizer_name is None:
106
- self.tokenizer_name == self.model_name_or_path
107
  assert (
108
  self.tokenizer_name is not None
109
  ), "Tokenizer name or model name/path needs to be specified"
@@ -209,6 +209,26 @@ class DataTrainingArguments:
209
  "help": "Probability of removing some captions for classifier-free guidance."
210
  },
211
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  max_train_samples: Optional[int] = field(
213
  default=None,
214
  metadata={
 
103
 
104
  def __post_init__(self):
105
  if self.tokenizer_name is None:
106
+ self.tokenizer_name = self.model_name_or_path
107
  assert (
108
  self.tokenizer_name is not None
109
  ), "Tokenizer name or model name/path needs to be specified"
 
209
  "help": "Probability of removing some captions for classifier-free guidance."
210
  },
211
  )
212
+ clip_score_column: Optional[str] = field(
213
+ default="clip_score",
214
+ metadata={"help": "Column that containts clip score for filtering."},
215
+ )
216
+ min_clip_score: Optional[float] = field(
217
+ default=None,
218
+ metadata={"help": "Minimum clip score required."},
219
+ )
220
+ max_clip_score: Optional[float] = field(
221
+ default=None,
222
+ metadata={"help": "Maximum clip score required."},
223
+ )
224
+ filter_column: Optional[str] = field(
225
+ default=None,
226
+ metadata={"help": "Column that containts classes to be filtered."},
227
+ )
228
+ filter_value: Optional[str] = field(
229
+ default=None,
230
+ metadata={"help": "Class value to be kept during filtering."},
231
+ )
232
  max_train_samples: Optional[int] = field(
233
  default=None,
234
  metadata={