# text-to-image-generation/inference.py
""" Generate and return image adapted from DALL-E mini's playground """
import random
from functools import partial
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key, shard
from vqgan_jax.modeling_flax_vqgan import VQModel
import numpy as np
from PIL import Image
from tqdm.notebook import trange
from dalle_mini import DalleBart, DalleBartProcessor
from transformers import CLIPProcessor, FlaxCLIPModel
import wandb
import os
# Log in to W&B so the trained model artifact referenced below can be pulled
wandb.login(key=os.environ["wandb"])
# Model to generate image tokens
MODEL = "fedorajuandy/tugas-akhir/model-jhhchemc:v11"
MODEL_COMMIT_ID = None  # no specific revision pinned
# VQGAN to decode image tokens
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
# number of predictions; split per device
N_PREDICTIONS = 8
# Generation parameters
GEN_TOP_K = None
GEN_TOP_P = None
TEMPERATURE = None
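# condition_scale > 1 applies super conditioning (classifier-free guidance);
# higher values follow the prompt more closely at the cost of diversity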
COND_SCALE = 10.0
# CLIP
CLIP_REPO = "openai/clip-vit-base-patch32"
CLIP_COMMIT_ID = None
# Load pretrained models; _do_init=False skips random initialisation and
# returns the weights separately from the model definition
model, model_params = DalleBart.from_pretrained(
MODEL, revision=MODEL_COMMIT_ID, dtype=jnp.float32, _do_init=False
)
# To process text
processor = DalleBartProcessor.from_pretrained(
MODEL, revision=MODEL_COMMIT_ID
)
vqgan, vqgan_params = VQModel.from_pretrained(
VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)
clip, clip_params = FlaxCLIPModel.from_pretrained(
CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
# To process text and image
clip_processor = CLIPProcessor.from_pretrained(
CLIP_REPO, revision=CLIP_COMMIT_ID
)
# Replicate parameters to each device
model_params = replicate(model_params)
vqgan_params = replicate(vqgan_params)
clip_params = replicate(clip_params)
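# Note: replicate() adds a leading device axis (shape (...) -> (n_devices, ...)),
# so each pmapped call below receives its own per-device copy of the weights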
# Functions below are JIT-compiled and parallelised across devices with pmap;
# p_generate's sampling hyperparameters are static arguments (compile-time constants)
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
""" Model inference """
return model.generate(
**tokenized_prompt,
prng_key=key,
params=params,
top_k=top_k,
top_p=top_p,
temperature=temperature,
condition_scale=condition_scale,
)
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
""" Decode image tokens """
return vqgan.decode_code(indices, params=params)
# Score images
@partial(jax.pmap, axis_name="batch")
def p_clip(inputs, params):
""" Return logits, wutever dat is """
logits = clip(params=params, **inputs).logits_per_image
return logits
def generate_image(text_prompt):
""" Take text prompt and return generated image """
    # Create a PRNG key; it is split and sharded so each device generates different images
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)
texts = [text_prompt]
tokenized_prompts = processor(texts)
tokenized_prompt = replicate(tokenized_prompts)
# Generate images
images = []
for i in trange(max(N_PREDICTIONS // jax.device_count(), 1)):
# Get a new key
key, subkey = jax.random.split(key)
encoded_images = p_generate(
tokenized_prompt,
shard_prng_key(subkey),
model_params,
GEN_TOP_K,
GEN_TOP_P,
TEMPERATURE,
COND_SCALE,
)
# Remove BOS token
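        # (each remaining sequence is 256 VQGAN code indices, a 16x16 grid that
        # decodes to a 256x256 RGB image with this VQGAN checkpoint)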
encoded_images = encoded_images.sequences[..., 1:]
decoded_images = p_decode(encoded_images, vqgan_params)
decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
for decoded_img in decoded_images:
            # Convert the [0, 1] float array into an 8-bit PIL image
img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
images.append(img)
# Get scores
clip_inputs = clip_processor(
text=texts * jax.device_count(),
images=images,
return_tensors="np",
padding="max_length",
max_length=77,
truncation=True,
).data
# Shard for each device
logits = p_clip(shard(clip_inputs), clip_params)
    # Organise CLIP scores per prompt (only one prompt here, so p == 1)
    p = len(texts)
    logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()
    imgs = []
    for i, _ in enumerate(texts):
        # Rank this prompt's images from highest to lowest CLIP score
        for idx in logits[i].argsort()[::-1]:
            imgs.append(images[idx * p + i])
# print(f"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\n")
    # Return only the highest-scoring image
    return [imgs[0]]
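

# A minimal usage sketch, assuming this module is run directly with a valid "wandb"
# API key in the environment; the prompt and output filename are illustrative
# placeholders, not part of the original app.
if __name__ == "__main__":
    result_images = generate_image("a watercolor painting of a fox in a forest")
    result_images[0].save("generated.png")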