# dalle-mini-fork / pipeline.py

import jax
import flax.linen as nn
import random
import numpy as np
from PIL import Image
from transformers import BartConfig, BartTokenizer
from transformers.models.bart.modeling_flax_bart import (
FlaxBartModule,
FlaxBartForConditionalGenerationModule,
FlaxBartForConditionalGeneration,
FlaxBartEncoder,
FlaxBartDecoder
)
from vqgan_jax.modeling_flax_vqgan import VQModel

# Model hyperparameters, for convenience
OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
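# (256 image tokens correspond to a 16x16 latent grid from the f16 VQGAN on 256x256 images)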
BOS_TOKEN_ID = 16384
BASE_MODEL = 'facebook/bart-large-cnn'  # we currently have issues with bart-large


class CustomFlaxBartModule(FlaxBartModule):
def setup(self):
# check config is valid, otherwise set default values
self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
# we keep shared to easily load pre-trained weights
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
)
# a separate embedding is used for the decoder
self.decoder_embed = nn.Embed(
self.config.vocab_size_output,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
)
self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
# the decoder has a different config
        decoder_config = BartConfig(**self.config.to_dict())
decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
decoder_config.vocab_size = self.config.vocab_size_output
self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)


class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
def setup(self):
# check config is valid, otherwise set default values
self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
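        # project decoder hidden states to logits over the image-token vocabulary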
self.lm_head = nn.Dense(
self.config.vocab_size_output,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
)
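        # BART adds a static bias to the LM logits; size it to the image-token vocabulary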
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))


class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
module_class = CustomFlaxBartForConditionalGenerationModule


class PreTrainedPipeline:
def __init__(self, path=""):
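        # VQGAN image decoder (f16, 16384-entry codebook), pinned to a known-good revision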
self.vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384", revision="90cc46addd2dd8f5be21586a9a23e1b95aa506a9")
self.tokenizer = BartTokenizer.from_pretrained(path)
self.model = CustomFlaxBartForConditionalGeneration.from_pretrained(path)

    def __call__(self, inputs: str):
        """
        Args:
            inputs (:obj:`str`):
                the text prompt to generate an image from
        Return:
            A :obj:`PIL.Image.Image` containing the generated image.
        """
tokenized_prompt = self.tokenizer(inputs, return_tensors='jax', padding='max_length', truncation=True, max_length=128)
key = jax.random.PRNGKey(random.randint(0, 2**32-1))
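        # sample image tokens autoregressively (pure sampling, no beam search)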
encoded_image = self.model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=key)
# remove first token (BOS)
encoded_image = encoded_image.sequences[..., 1:]
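        # map the 16x16 grid of VQGAN codes back to an RGB image with values in [0, 1]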
decoded_image = self.vqgan.decode_code(encoded_image)
clipped_image = decoded_image.squeeze().clip(0., 1.)
return Image.fromarray(np.asarray(clipped_image * 255, dtype=np.uint8))
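

# Minimal usage sketch: `./bart-checkpoint` is a placeholder for a local
# directory (or hub repo id) holding the tokenizer and the custom BART
# weights this pipeline expects; swap in a real path before running.
if __name__ == "__main__":
    pipeline = PreTrainedPipeline(path="./bart-checkpoint")
    image = pipeline("a painting of a lighthouse at sunset")
    image.save("generated.png")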