dalle-mini-fork / pipeline.py
osanseviero's picture
osanseviero HF staff
Update pipeline.py
d5b4d55
raw
history blame
4.38 kB
import jax
import flax.linen as nn
import random
import numpy as np
from PIL import Image
from transformers import BartConfig, BartTokenizer
from transformers.models.bart.modeling_flax_bart import (
FlaxBartModule,
FlaxBartForConditionalGenerationModule,
FlaxBartForConditionalGeneration,
FlaxBartEncoder,
FlaxBartDecoder
)
from vqgan_jax.modeling_flax_vqgan import VQModel
# Model hyperparameters, for convenience.
# The BOS token takes the first id after the encoded image token space.
BOS_TOKEN_ID = 16384
# encoded image token space + 1 for bos
OUTPUT_VOCAB_SIZE = BOS_TOKEN_ID + 1
# 1 for bos + number of encoded tokens
OUTPUT_LENGTH = 1 + 256
# we currently have issues with bart-large
BASE_MODEL = "facebook/bart-large-cnn"
class CustomFlaxBartModule(FlaxBartModule):
    """BART encoder/decoder pair whose decoder uses its own vocabulary.

    The encoder keeps the standard ``shared`` embedding (sized by
    ``config.vocab_size``), while the decoder gets a separate embedding sized
    by ``config.vocab_size_output`` and a position limit of
    ``config.max_position_embeddings_decoder``.
    """

    def setup(self):
        """Build the embeddings, the encoder and the decoder.

        Missing ``vocab_size_output`` / ``max_position_embeddings_decoder``
        entries are written onto ``self.config`` with the module-level
        defaults ``OUTPUT_VOCAB_SIZE`` and ``OUTPUT_LENGTH``.
        """
        cfg = self.config

        # check config is valid, otherwise set default values
        cfg.vocab_size_output = getattr(cfg, "vocab_size_output", OUTPUT_VOCAB_SIZE)
        cfg.max_position_embeddings_decoder = getattr(
            cfg, "max_position_embeddings_decoder", OUTPUT_LENGTH
        )

        # we keep shared to easily load pre-trained weights
        self.shared = nn.Embed(
            cfg.vocab_size,
            cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.init_std, self.dtype),
            dtype=self.dtype,
        )
        # a separate embedding is used for the decoder
        self.decoder_embed = nn.Embed(
            cfg.vocab_size_output,
            cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.init_std, self.dtype),
            dtype=self.dtype,
        )

        self.encoder = FlaxBartEncoder(cfg, dtype=self.dtype, embed_tokens=self.shared)

        # The decoder has a different config. The dict must be unpacked as
        # keyword arguments: passed positionally it would bind to BartConfig's
        # first parameter (vocab_size) and none of the fields (d_model, layer
        # counts, dropout, ...) would be copied from the encoder's config.
        decoder_config = BartConfig(**cfg.to_dict())
        decoder_config.max_position_embeddings = cfg.max_position_embeddings_decoder
        decoder_config.vocab_size = cfg.vocab_size_output
        self.decoder = FlaxBartDecoder(
            decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
        )
class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """Conditional-generation head on top of ``CustomFlaxBartModule``.

    The language-model head and the final logits bias are both sized by
    ``config.vocab_size_output`` rather than by ``config.vocab_size``.
    """

    def setup(self):
        """Create the wrapped model, the LM head and the logits bias."""
        cfg = self.config
        # Fall back to the module-level default when the config lacks the field.
        cfg.vocab_size_output = getattr(cfg, "vocab_size_output", OUTPUT_VOCAB_SIZE)
        output_vocab = cfg.vocab_size_output

        self.model = CustomFlaxBartModule(config=cfg, dtype=self.dtype)

        head_init = jax.nn.initializers.normal(cfg.init_std, self.dtype)
        self.lm_head = nn.Dense(
            output_vocab,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=head_init,
        )

        bias_shape = (1, output_vocab)
        self.final_logits_bias = self.param(
            "final_logits_bias", self.bias_init, bias_shape
        )
class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    """``FlaxBartForConditionalGeneration`` bound to the custom module class."""

    # Swap in the module whose decoder, LM head and bias use vocab_size_output.
    module_class = CustomFlaxBartForConditionalGenerationModule
class PreTrainedPipeline():
    """Text-to-image pipeline: tokenizer -> custom BART -> VQGAN decoder."""

    def __init__(self, path=""):
        """Load everything needed at inference time.

        Args:
            path (:obj:`str`):
                location passed to ``from_pretrained`` for both the tokenizer
                and the custom BART model.
        """
        # All heavy loading happens here so that __call__ only runs inference.
        self.tokenizer = BartTokenizer.from_pretrained(path)
        self.model = CustomFlaxBartForConditionalGeneration.from_pretrained(path)
        # The VQGAN checkpoint is pinned to a fixed revision.
        self.vqgan = VQModel.from_pretrained(
            "flax-community/vqgan_f16_16384",
            revision="90cc46addd2dd8f5be21586a9a23e1b95aa506a9",
        )

    def __call__(self, inputs: str):
        """
        Args:
            inputs (:obj:`str`):
                a string containing some text
        Return:
            A :obj:`PIL.Image` with the raw image representation as PIL.
        """
        # Pad/truncate the prompt to a fixed 128 tokens.
        tokenized_prompt = self.tokenizer(
            inputs,
            return_tensors='jax',
            padding='max_length',
            truncation=True,
            max_length=128,
        )
        # Fresh random seed per call so repeated prompts sample different images.
        key = jax.random.PRNGKey(random.randint(0, 2**32-1))
        encoded_image = self.model.generate(
            **tokenized_prompt, do_sample=True, num_beams=1, prng_key=key
        )
        # remove first token (BOS)
        encoded_image = encoded_image.sequences[..., 1:]
        # The decoder lives on the instance; the bare name `vqgan` is not
        # defined in this scope and raised NameError on every call.
        decoded_image = self.vqgan.decode_code(encoded_image)
        # Drop singleton axes and clamp to [0, 1] before scaling to 8-bit.
        clipped_image = decoded_image.squeeze().clip(0., 1.)
        return Image.fromarray(np.asarray(clipped_image * 255, dtype=np.uint8))