boris committed on
Commit c7776fb
1 parent: 1d8a799

feat(wandb-examples): use model file

Files changed (1):
  1. dev/predictions/wandb-examples.py +15 -56
dev/predictions/wandb-examples.py CHANGED
@@ -4,16 +4,14 @@
 import random
 
 import jax
-import flax.linen as nn
 from flax.training.common_utils import shard
 from flax.jax_utils import replicate, unreplicate
 
 from transformers.models.bart.modeling_flax_bart import *
 from transformers import BartTokenizer, FlaxBartForConditionalGeneration
 
-import io
+import os
 
-import requests
 from PIL import Image
 import numpy as np
 import matplotlib.pyplot as plt
@@ -23,58 +21,24 @@ import torchvision.transforms as T
 import torchvision.transforms.functional as TF
 from torchvision.transforms import InterpolationMode
 
+from dalle_mini.model import CustomFlaxBartForConditionalGeneration
 from vqgan_jax.modeling_flax_vqgan import VQModel
 
-# TODO: set those args in a config file
-OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
-OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
-BOS_TOKEN_ID = 16384
-BASE_MODEL = 'facebook/bart-large-cnn'
-
-class CustomFlaxBartModule(FlaxBartModule):
-    def setup(self):
-        # we keep shared to easily load pre-trained weights
-        self.shared = nn.Embed(
-            self.config.vocab_size,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        # a separate embedding is used for the decoder
-        self.decoder_embed = nn.Embed(
-            OUTPUT_VOCAB_SIZE,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-            dtype=self.dtype,
-        )
-        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
-
-        # the decoder has a different config
-        decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = OUTPUT_LENGTH
-        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
-        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
-
-class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
-    def setup(self):
-        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
-        self.lm_head = nn.Dense(
-            OUTPUT_VOCAB_SIZE,
-            use_bias=False,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
-        )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
-
-class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
-    module_class = CustomFlaxBartForConditionalGenerationModule
-
+# ## CLIP Scoring
+from transformers import CLIPProcessor, FlaxCLIPModel
 
 import wandb
 import os
+
+from dalle_mini.helpers import captioned_strip
+
+
 os.environ["WANDB_SILENT"] = "true"
 os.environ["WANDB_CONSOLE"] = "off"
 
+# TODO: used for legacy support
+BASE_MODEL = 'facebook/bart-large-cnn'
+
 # set id to None so our latest images don't get overwritten
 id = None
 run = wandb.init(id=id,
@@ -87,8 +51,10 @@ artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-4oh3u7ca:latest', ty
 artifact_dir = artifact.download()
 
 # create our model
-tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
 model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
+
+# TODO: legacy support (earlier models)
+tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
 model.config.force_bos_token_to_be_generated = False
 model.config.forced_bos_token_id = None
 model.config.forced_eos_token_id = None
@@ -143,9 +109,6 @@ p_get_images = jax.pmap(get_images, "batch")
 bart_params = replicate(model.params)
 vqgan_params = replicate(vqgan.params)
 
-# ## CLIP Scoring
-from transformers import CLIPProcessor, FlaxCLIPModel
-
 clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
@@ -170,16 +133,12 @@ def hallucinate(prompt, num_images=64):
 
 def clip_top_k(prompt, images, k=8):
     inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
+    # FIXME: image should be resized and normalized prior to being processed by CLIP
     outputs = clip(**inputs)
     logits = outputs.logits_per_text
     scores = np.array(logits[0]).argsort()[-k:][::-1]
     return [images[score] for score in scores]
 
-
-# ## Log to wandb
-
-from dalle_mini.helpers import captioned_strip
-
 def log_to_wandb(prompts):
     strips = []
     for prompt in prompts:
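
Note: after this change the custom BART classes are no longer defined in the script but imported from the shared `dalle_mini.model` module. The sketch below assembles the relevant lines from this diff into a minimal loading path; it is illustrative only, and the `wandb.init(...)` arguments and artifact `type` are simplified because they are not shown in full in the hunks above.

import wandb
from transformers import BartTokenizer
from dalle_mini.model import CustomFlaxBartForConditionalGeneration

# download the trained checkpoint from the wandb artifact named in the diff
run = wandb.init()  # simplified; the script also passes id and other arguments
artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-4oh3u7ca:latest')
artifact_dir = artifact.download()

# the model class now comes from dalle_mini.model instead of being redefined inline
model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)

# legacy support: earlier checkpoints did not ship a tokenizer, so the base BART one is reused
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')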