In [None]:
import tempfile
from functools import partial
import random
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard, shard_prng_key
from flax.jax_utils import replicate
import wandb
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
from dalle_mini.text import TextNormalizer

In [None]:
# Training run(s) whose model artifacts are evaluated.
run_ids = ['63otg87g']

# W&B entity / project hosting the training run (used only for training run).
ENTITY = 'dalle-mini'
PROJECT = 'dalle-mini'

# VQGAN checkpoint; the commit id is passed as `revision` when loading.
VQGAN_REPO = 'dalle-mini/vqgan_imagenet_f16_16384'
VQGAN_COMMIT_ID = 'e93a26e7707683d349bf5d5c41c5b0ef69b677a9'

# Log only the latest model version (True) or all versions (False).
latest_only = True

# Appended to the inference run id; mainly for duplicate inference runs with a
# deleted version.
suffix = ''

# When True, a CLIP ViT-B/32 model is loaded as well and its ranking is logged
# to a separate `-clip32` inference run.
add_clip_32 = False

In [None]:
# model.generate sampling parameters.
# gen_top_k and gen_top_p are forwarded as `top_k` / `top_p` by the pmapped
# generation function defined in this notebook; `temperature` is not referenced
# by any code visible here yet.
gen_top_k = None
gen_top_p = None
temperature = None

In [None]:
# Sizes used by the generation / ranking loop.
batch_size = 8    # captions per batch (the caption list is padded to a multiple of this)
num_images = 128  # images generated per caption
top_k = 8         # best-scoring images kept per caption in the results table

# Placeholder caption used to fill the last batch; rows for it are not logged.
padding_item = 'NONE'

# Applied to captions when the loaded model's config enables text normalization.
text_normalizer = TextNormalizer()

# A fresh random seed is drawn on every execution of this cell.
seed = random.randint(0, 2 ** 32 - 1)
key = jax.random.PRNGKey(seed)

# Public W&B API client used to look up runs and artifacts.
api = wandb.Api()

In [None]:
# VQGAN used to decode generated image tokens; loaded at the pinned revision
# and its parameters replicated for the pmapped decode function.
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
vqgan_params = replicate(vqgan.params)

# CLIP ViT-B/16 (always loaded): processor, model and replicated parameters.
processor16 = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
clip16 = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch16")
clip16_params = replicate(clip16.params)

# CLIP ViT-B/32 (optional): only loaded when `add_clip_32` is enabled.
if add_clip_32:
    processor32 = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip32 = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip32_params = replicate(clip32.params)

In [None]:
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    """Decode VQGAN code indices into images (pmapped over the leading axis).

    Args:
        indices: code indices passed to `vqgan.decode_code`.
        params: VQGAN parameters, forwarded as `params=`.

    Returns:
        Whatever `vqgan.decode_code` returns for these indices.
    """
    decoded = vqgan.decode_code(indices, params=params)
    return decoded

@partial(jax.pmap, axis_name="batch")
def p_clip16(inputs, params):
    """Run the CLIP ViT-B/16 model (pmapped) and return its per-image logits.

    Args:
        inputs: mapping of model inputs, expanded as keyword arguments.
        params: CLIP parameters, forwarded as `params=`.

    Returns:
        The `logits_per_image` field of the model output.
    """
    return clip16(params=params, **inputs).logits_per_image

# The ViT-B/32 scorer only exists when the optional second CLIP model is enabled.
if add_clip_32:

    @partial(jax.pmap, axis_name="batch")
    def p_clip32(inputs, params):
        """Run the CLIP ViT-B/32 model (pmapped) and return its per-image logits.

        Args:
            inputs: mapping of model inputs, expanded as keyword arguments.
            params: CLIP parameters, forwarded as `params=`.

        Returns:
            The `logits_per_image` field of the model output.
        """
        return clip32(params=params, **inputs).logits_per_image

In [None]:
# Read one caption per line. Blank lines are dropped so that an empty line in
# the file does not turn into an empty prompt that gets generated and logged.
with open('samples.txt', encoding='utf8') as f:
    samples = [line.strip() for line in f]
samples = [caption for caption in samples if caption]

# make list multiple of batch_size by adding elements
samples_to_add = [padding_item] * (-len(samples) % batch_size)
samples.extend(samples_to_add)

# reshape into a list of batches, each holding `batch_size` captions
samples = [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]

In [None]:
def get_artifact_versions(run_id, latest_only=False):
    """Look up the `bart_model` artifact(s) named `model-{run_id}`.

    Args:
        run_id: id of the training run; the artifact is looked up as
            `{ENTITY}/{PROJECT}/model-{run_id}`.
        latest_only: when True, fetch only the `:latest` alias and return it in
            a one-element list; otherwise return the result of
            `api.artifact_versions` for that artifact name.

    Returns:
        A one-element list, the object returned by `api.artifact_versions`, or
        an empty list if the lookup raised an error.
    """
    artifact_name = f'{ENTITY}/{PROJECT}/model-{run_id}'
    try:
        if latest_only:
            return [api.artifact(type='bart_model', name=f'{artifact_name}:latest')]
        return api.artifact_versions(type_name='bart_model', name=artifact_name, per_page=10000)
    except Exception as err:
        # Best effort: a failed lookup yields no artifacts. Catching Exception
        # (not a bare except) lets KeyboardInterrupt / SystemExit propagate, and
        # the reason is reported instead of being silently discarded.
        print(f'Could not retrieve artifacts for run {run_id}: {err!r}')
        return []

In [None]:
def get_training_config(run_id):
    """Return the `config` of training run `{ENTITY}/{PROJECT}/{run_id}`.

    Args:
        run_id: id of the training run to look up through the W&B API client.

    Returns:
        The `config` attribute of the run object returned by `api.run`.
    """
    run_path = f'{ENTITY}/{PROJECT}/{run_id}'
    return api.run(run_path).config

In [None]:
# retrieve inference run details
def get_last_inference_version(run_id):
    """Return the `version` stored in the summary of the CLIP-16 inference run.

    The inference run is looked up as
    `dalle-mini/dalle-mini/{run_id}-clip16{suffix}`.

    Args:
        run_id: id of the training run the inference run is derived from.

    Returns:
        The summary's `version` entry, or None when that entry is absent or the
        run lookup raised an error.
    """
    try:
        inference_run = api.run(f'dalle-mini/dalle-mini/{run_id}-clip16{suffix}')
        return inference_run.summary.get('version', None)
    except Exception:
        # A failed lookup is treated as "nothing logged yet". Catching Exception
        # (not a bare except) lets KeyboardInterrupt / SystemExit propagate.
        return None

In [None]:
# compile functions - needed only once per run
def pmap_model_function(model):
    """Wrap `model.generate` in a pmapped sampling function.

    Args:
        model: object exposing `generate`; it is captured by the returned
            function.

    Returns:
        A function `(tokenized_prompt, key, params)` mapped with `jax.pmap`
        (axis name "batch") that calls `model.generate` with sampling enabled,
        a single beam, and the notebook-level `gen_top_k` / `gen_top_p`.
    """

    def _generate(tokenized_prompt, key, params):
        sampling_kwargs = dict(
            do_sample=True,
            num_beams=1,
            prng_key=key,
            params=params,
            top_k=gen_top_k,
            top_p=gen_top_p,
        )
        return model.generate(**tokenized_prompt, **sampling_kwargs)

    return jax.pmap(_generate, axis_name="batch")

In [None]:
# Only the first entry of `run_ids` is processed for now.
run_id = run_ids[0]
# TODO: loop over runs

In [None]:
# For each selected model version: download the model, generate `num_images`
# images per caption, rank them with CLIP and log the `top_k` best per caption
# to a W&B inference run.
artifact_versions = get_artifact_versions(run_id, latest_only)
last_inference_version = get_last_inference_version(run_id)
training_config = get_training_config(run_id)

# The version checks below only make sense when versions are visited in
# increasing order (v0 first, then exactly last + 1), so enforce that order
# instead of relying on the order the API returns.
artifact_versions = sorted(artifact_versions, key=lambda a: int(a.version[1:]))

run = None
p_generate = None
model_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']
for artifact in artifact_versions:
    print(f'Processing artifact: {artifact.name}')
    version = int(artifact.version[1:])  # 'v12' -> 12
    results16, results32 = [], []
    columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)]

    if latest_only:
        assert last_inference_version is None or version > last_inference_version
    else:
        if last_inference_version is None:
            # we should start from v0
            assert version == 0
        elif version <= last_inference_version:
            print(f'v{version} has already been logged (versions logged up to v{last_inference_version})')
            # skip it: without this the version would be generated and logged again
            continue
        else:
            # check we are logging the correct version
            assert version == last_inference_version + 1

    # start/resume corresponding run
    if run is None:
        run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip16{suffix}', resume='allow')

    # work in temporary directory
    with tempfile.TemporaryDirectory() as tmp:

        # download model files
        artifact = run.use_artifact(artifact)
        for f in model_files:
            artifact.get_path(f).download(tmp)

        # load tokenizer and model
        tokenizer = BartTokenizer.from_pretrained(tmp)
        model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)
        model_params = replicate(model.params)

        # pmap model function needs to happen only once per model config
        if p_generate is None:
            p_generate = pmap_model_function(model)

        # process one batch of captions
        for batch in tqdm(samples):
            processed_prompts = [text_normalizer(x) for x in batch] if model.config.normalize_text else list(batch)

            # repeat the prompts to distribute over each device and tokenize
            processed_prompts = processed_prompts * jax.device_count()
            tokenized_prompt = tokenizer(processed_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data
            tokenized_prompt = shard(tokenized_prompt)

            # generate images
            images = []
            pbar = tqdm(range(num_images // jax.device_count()), desc='Generating Images', leave=True)
            for _ in pbar:
                key, subkey = jax.random.split(key)
                encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)
                # drop the first token of every generated sequence before decoding
                encoded_images = encoded_images.sequences[..., 1:]
                decoded_images = p_decode(encoded_images, vqgan_params)
                decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))
                # convert the whole step to uint8 on the host once, then wrap each image
                pixels = np.asarray(decoded_images * 255, dtype=np.uint8)
                images.extend(Image.fromarray(img) for img in pixels)

            def add_clip_results(results, processor, p_clip, clip_params):
                """Score the current batch's images with CLIP and append one row per caption.

                Uses `batch` and `images` from the enclosing loop iteration. Each
                appended row is the caption followed by its `top_k` best images
                (highest score first); rows for `padding_item` are skipped.
                """
                clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data
                # each shard will have one prompt, images need to be reorganized to be associated to the correct shard
                # NOTE(review): the reshape to (-1, num_images) below appears to rely on
                # batch_size == jax.device_count() (one prompt per shard) — confirm.
                images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))
                clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + offset] for offset in range(batch_size)))
                clip_inputs = shard(clip_inputs)
                logits = p_clip(clip_inputs, clip_params)
                logits = logits.reshape(-1, num_images)
                # indices of the top_k highest logits per row, best first
                top_scores = logits.argsort()[:, -top_k:][..., ::-1]
                logits = jax.device_get(logits)
                # add to results table
                for prompt_idx, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):
                    if sample == padding_item:
                        continue
                    cur_images = [images[x] for x in images_per_prompt_indices + prompt_idx]
                    top_images = [wandb.Image(cur_images[x], caption=f'Score: {scores[x]:.2f}') for x in idx]
                    results.append([sample] + top_images)

            # get clip scores
            pbar.set_description('Calculating CLIP 16 scores')
            add_clip_results(results16, processor16, p_clip16, clip16_params)

            # get clip 32 scores
            if add_clip_32:
                pbar.set_description('Calculating CLIP 32 scores')
                add_clip_results(results32, processor32, p_clip32, clip32_params)

            pbar.close()

        # log results
        table = wandb.Table(columns=columns, data=results16)
        run.log({'Samples': table, 'version': version})
        wandb.finish()

        if add_clip_32:
            run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip32{suffix}', resume='allow')
            table = wandb.Table(columns=columns, data=results32)
            run.log({'Samples': table, 'version': version})
            wandb.finish()

    # Every run opened for this version has been finished: drop the handle so
    # the next version re-opens (resumes) the clip16 run instead of reusing a
    # finished one or logging on the clip32 run.
    run = None
    # remember what was just logged so the next version passes the `+ 1` check
    last_inference_version = version