""" Generate and return image adapted from DALL-E mini's playground """

import random
from functools import partial
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key, shard
from vqgan_jax.modeling_flax_vqgan import VQModel
import numpy as np
from PIL import Image
from tqdm.auto import trange  # auto falls back to a console bar outside notebooks
from dalle_mini import DalleBart, DalleBartProcessor
from transformers import CLIPProcessor, FlaxCLIPModel
import wandb
import os

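# Authenticate with Weights & Biases; the API key is expected in the
# `wandb` environment variable (the generator model below is stored as a
# wandb artifact)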
wandb.login(key=os.environ["wandb"])

# Model to generate image tokens
MODEL = "fedorajuandy/tugas-akhir/model-jhhchemc:v11"
MODEL_COMMIT_ID = None  # None resolves to the artifact version pinned in MODEL

# VQGAN to decode image tokens
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# Number of predictions; split across devices
N_PREDICTIONS = 8
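# (e.g. with a single device the generation loop below runs N_PREDICTIONS
# times, one image per pass; with 8 devices it runs once and produces all
# 8 images in parallel)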

# Generation parameters; None falls back to the model's defaults
GEN_TOP_K = None
GEN_TOP_P = None
TEMPERATURE = None
COND_SCALE = 10.0
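# condition_scale > 1 enables super conditioning (classifier-free
# guidance): higher values follow the prompt more closely at the cost of
# variety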

# CLIP
CLIP_REPO = "openai/clip-vit-base-patch32"
CLIP_COMMIT_ID = None


# Load pretrained models; _do_init=False skips random initialisation and returns the parameters separately
model, model_params = DalleBart.from_pretrained(
    MODEL, revision=MODEL_COMMIT_ID, dtype=jnp.float32, _do_init=False
)
# To process text
processor = DalleBartProcessor.from_pretrained(
    MODEL, revision=MODEL_COMMIT_ID
)

vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)

clip, clip_params = FlaxCLIPModel.from_pretrained(
    CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
# To process text and images
clip_processor = CLIPProcessor.from_pretrained(
    CLIP_REPO, revision=CLIP_COMMIT_ID
)


# Replicate parameters to each device (adds a leading device axis)
model_params = replicate(model_params)
vqgan_params = replicate(vqgan_params)
clip_params = replicate(clip_params)


# The functions below are compiled and parallelised across devices
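# jax.pmap compiles each function once per device and runs the copies in
# parallel; argument positions 3-6 (top_k, top_p, temperature,
# condition_scale) are marked static, so changing their values triggers a
# recompile instead of being traced as array inputs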
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
    """ Model inference """
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )

@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    """ Decode image tokens """
    return vqgan.decode_code(indices, params=params)

# Score images
@partial(jax.pmap, axis_name="batch")
def p_clip(inputs, params):
    """ Return logits, wutever dat is """
    logits = clip(params=params, **inputs).logits_per_image
    return logits


def generate_image(text_prompt):
    """ Take text prompt and return generated image """

    # Create a PRNG key; it is split per iteration and sharded per device so each image differs
    seed = random.randint(0, 2**32 - 1)
    key = jax.random.PRNGKey(seed)

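    # Tokenize the prompt and replicate it so every device receives the
    # same input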
    texts = [text_prompt]
    tokenized_prompts = processor(texts)
    tokenized_prompt = replicate(tokenized_prompts)

    # Generate images
    images = []
    for _ in trange(max(N_PREDICTIONS // jax.device_count(), 1)):
        # Split the key so each iteration (and each device) gets fresh randomness
        key, subkey = jax.random.split(key)
        encoded_images = p_generate(
            tokenized_prompt,
            shard_prng_key(subkey),
            model_params,
            GEN_TOP_K,
            GEN_TOP_P,
            TEMPERATURE,
            COND_SCALE,
        )
        # Remove BOS token
        encoded_images = encoded_images.sequences[..., 1:]
        decoded_images = p_decode(encoded_images, vqgan_params)
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
        for decoded_img in decoded_images:
            # Scale [0, 1] floats to 8-bit values and convert to a PIL image
            img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
            images.append(img)

    # Get scores
    clip_inputs = clip_processor(
        text=texts * jax.device_count(),
        images=images,
        return_tensors="np",
        padding="max_length",
        max_length=77,
        truncation=True,
    ).data
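    # .data exposes the underlying NumPy arrays so they can be sharded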
    # Shard for each device
    logits = p_clip(shard(clip_inputs), clip_params)

    # Organise scores: flatten the per-device logits into one score per
    # image, matching the order of `images` (single prompt, so there is no
    # per-prompt interleaving to undo)
    scores = np.asarray(logits).reshape(-1)

    # Rank images by CLIP score, best first, and return only the top one
    ranked = [images[int(idx)] for idx in scores.argsort()[::-1]]
    result = [ranked[0]]

    return result
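

# Minimal usage sketch (hypothetical prompt and output path; assumes the
# `wandb` env var is set and the models above loaded successfully)
if __name__ == "__main__":
    best_image = generate_image("a watercolor painting of a lighthouse")[0]
    best_image.save("generated.png")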