import random

import jax
import flax.linen as nn
import numpy as np
from PIL import Image
from transformers import BartConfig, BartTokenizer
from transformers.models.bart.modeling_flax_bart import (
    FlaxBartModule,
    FlaxBartForConditionalGenerationModule,
    FlaxBartForConditionalGeneration,
    FlaxBartEncoder,
    FlaxBartDecoder,
)
from vqgan_jax.modeling_flax_vqgan import VQModel

# Model hyperparameters, for convenience
OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
OUTPUT_LENGTH = 256 + 1        # number of encoded tokens + 1 for bos
BOS_TOKEN_ID = 16384
BASE_MODEL = 'facebook/bart-large-cnn'  # we currently have issues with bart-large


class CustomFlaxBartModule(FlaxBartModule):
    def setup(self):
        # check config is valid, otherwise set default values
        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
        self.config.max_position_embeddings_decoder = getattr(
            self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)

        # we keep shared to easily load pre-trained weights
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        # a separate embedding is used for the decoder, which predicts image tokens
        self.decoder_embed = nn.Embed(
            self.config.vocab_size_output,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

        # the decoder has a different config (image-token vocabulary and sequence length)
        decoder_config = BartConfig(**self.config.to_dict())
        decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
        decoder_config.vocab_size = self.config.vocab_size_output
        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)


class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    def setup(self):
        # check config is valid, otherwise set default values
        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)

        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
        # the LM head projects onto the image-token vocabulary
        self.lm_head = nn.Dense(
            self.config.vocab_size_output,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
        )
        self.final_logits_bias = self.param(
            "final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))


class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    module_class = CustomFlaxBartForConditionalGenerationModule


class PreTrainedPipeline:
    def __init__(self, path=""):
        self.vqgan = VQModel.from_pretrained(
            "flax-community/vqgan_f16_16384",
            revision="90cc46addd2dd8f5be21586a9a23e1b95aa506a9",
        )
        self.tokenizer = BartTokenizer.from_pretrained(path)
        self.model = CustomFlaxBartForConditionalGeneration.from_pretrained(path)

    def __call__(self, inputs: str):
        """
        Args:
            inputs (:obj:`str`):
                a string containing some text
        Return:
            A :obj:`PIL.Image.Image` containing the generated image.
        """
        tokenized_prompt = self.tokenizer(
            inputs, return_tensors='jax', padding='max_length', truncation=True, max_length=128)
        # sample a sequence of image tokens from the seq2seq model
        key = jax.random.PRNGKey(random.randint(0, 2**32 - 1))
        encoded_image = self.model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=key)
        # remove first token (BOS)
        encoded_image = encoded_image.sequences[..., 1:]
        # decode the image tokens back to pixels with the VQGAN decoder
        decoded_image = self.vqgan.decode_code(encoded_image)
        clipped_image = decoded_image.squeeze().clip(0., 1.)
        return Image.fromarray(np.asarray(clipped_image * 255, dtype=np.uint8))
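

# A minimal usage sketch. The checkpoint path "flax-community/dalle-mini" below
# is an assumption for illustration; point `path` at whichever Hub repo or local
# directory holds the BART weights this pipeline expects. Sampling is stochastic,
# so each run produces a different image.
if __name__ == "__main__":
    pipeline = PreTrainedPipeline(path="flax-community/dalle-mini")  # hypothetical path
    image = pipeline("an armchair in the shape of an avocado")
    image.save("generated.png")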