|
import jax |
|
import flax.linen as nn |
|
|
|
import random |
|
import numpy as np |
|
from PIL import Image |
|
|
|
from transformers import BartConfig, BartTokenizer |
|
|
|
from transformers.models.bart.modeling_flax_bart import ( |
|
FlaxBartModule, |
|
FlaxBartForConditionalGenerationModule, |
|
FlaxBartForConditionalGeneration, |
|
FlaxBartEncoder, |
|
FlaxBartDecoder |
|
) |
|
|
|
|
|
from vqgan_jax.modeling_flax_vqgan import VQModel |
|
|
|
|
|
|
|
|
|
# Output-side (decoder) token space. Generated ids after the first position
# are handed to the VQGAN decoder as codes; the single extra id beyond
# those 16384 codes is BOS_TOKEN_ID.
BOS_TOKEN_ID = 16384
OUTPUT_VOCAB_SIZE = BOS_TOKEN_ID + 1  # 16384 code ids + one BOS id

# Decoder sequence length: 256 code positions plus the leading BOS position.
OUTPUT_LENGTH = 256 + 1

# Presumably the text checkpoint this model was derived from; it is not
# referenced elsewhere in this chunk.
BASE_MODEL = 'facebook/bart-large-cnn'
|
|
|
class CustomFlaxBartModule(FlaxBartModule):
    """BART encoder-decoder whose decoder has its own output vocabulary.

    The encoder reads input tokens through the ``shared`` embedding of
    ``config.vocab_size`` rows. The decoder gets a separate embedding table of
    ``config.vocab_size_output`` rows and a positional range of
    ``config.max_position_embeddings_decoder`` positions.
    """

    def setup(self):
        # Default the output-side sizes when the loaded config does not carry
        # them. This writes the attributes back onto the (shared) config object.
        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
        self.config.max_position_embeddings_decoder = getattr(
            self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH
        )

        # Input-token embedding, consumed by the encoder.
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )

        # Output-token embedding, consumed only by the decoder.
        self.decoder_embed = nn.Embed(
            self.config.vocab_size_output,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )

        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

        # The decoder is built from a full copy of the model config with only the
        # output-side sizes overridden. BartConfig's first positional parameter
        # is ``vocab_size``, so ``BartConfig(self.config.to_dict())`` would store
        # the dict as vocab_size and reset every other hyper-parameter (d_model,
        # decoder_layers, dropout, ...) to BartConfig defaults. ``from_dict``
        # round-trips ``to_dict`` output faithfully.
        decoder_config = BartConfig.from_dict(self.config.to_dict())
        decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
        decoder_config.vocab_size = self.config.vocab_size_output
        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
|
|
|
class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """Conditional-generation module sized for the output vocabulary.

    Swaps the backbone for ``CustomFlaxBartModule``. The LM projection and
    ``final_logits_bias`` both get ``config.vocab_size_output`` columns rather
    than ``config.vocab_size``.
    """

    def setup(self):
        cfg = self.config
        # Default the output vocabulary size (written back onto the config).
        cfg.vocab_size_output = getattr(cfg, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
        n_out = cfg.vocab_size_output
        weight_init = jax.nn.initializers.normal(cfg.init_std, self.dtype)

        self.model = CustomFlaxBartModule(config=cfg, dtype=self.dtype)

        # NOTE(review): if the inherited ``__call__`` ties lm_head to the input
        # ``shared`` embedding when ``config.tie_word_embeddings`` is true, the
        # logits would have ``vocab_size`` columns, not ``n_out``. Confirm that
        # the shipped config disables tying.
        self.lm_head = nn.Dense(
            n_out,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=weight_init,
        )
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, n_out))
|
|
|
class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    """Pretrained-model wrapper (``from_pretrained`` / ``generate``) whose
    underlying Flax module is ``CustomFlaxBartForConditionalGenerationModule``."""

    # The wrapper instantiates this module class instead of the stock BART one.
    module_class = CustomFlaxBartForConditionalGenerationModule
|
|
|
class PreTrainedPipeline:
    """Text-to-image pipeline: BART samples image codes, a VQGAN decodes them."""

    def __init__(self, path=""):
        """Load the VQGAN decoder, the BART tokenizer and the text-to-code model.

        Args:
            path (:obj:`str`):
                location passed to ``from_pretrained`` for both the tokenizer
                and the BART model.
        """
        # The VQGAN weights are pinned to an exact revision.
        self.vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384", revision="90cc46addd2dd8f5be21586a9a23e1b95aa506a9")
        self.tokenizer = BartTokenizer.from_pretrained(path)
        self.model = CustomFlaxBartForConditionalGeneration.from_pretrained(path)

    def __call__(self, inputs: str, seed=None):
        """
        Args:
            inputs (:obj:`str`):
                a string containing some text
            seed (:obj:`int`, `optional`):
                seed for the sampling PRNG key. When ``None`` (default), a fresh
                seed is drawn from Python's ``random`` module, so repeated calls
                give different images; pass an int for reproducible output.

        Return:
            A :obj:`PIL.Image` with the raw image representation as PIL.
        """
        # Prompts are padded / truncated to exactly 128 tokens.
        tokenized_prompt = self.tokenizer(inputs, return_tensors='jax', padding='max_length', truncation=True, max_length=128)

        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        key = jax.random.PRNGKey(seed)

        # Pure sampling (no beam search). The output length comes from the model's
        # generation settings — presumably OUTPUT_LENGTH; TODO confirm in config.
        encoded_image = self.model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=key)

        # Drop the first position (presumably the BOS / decoder-start token) so
        # only code ids reach the VQGAN.
        # NOTE(review): nothing here prevents sampling BOS_TOKEN_ID later in the
        # sequence; that id is presumably outside the VQGAN codebook — confirm.
        encoded_image = encoded_image.sequences[..., 1:]
        decoded_image = self.vqgan.decode_code(encoded_image)
        clipped_image = decoded_image.squeeze().clip(0., 1.)

        # Scale [0, 1] floats to 8-bit; the uint8 conversion truncates.
        return Image.fromarray(np.asarray(clipped_image * 255, dtype=np.uint8))
|
|