import io
import base64

import gradio as gr
import numpy as np
import spaces
import jax
import jax.numpy as jnp
from jax import Array as Tensor  # keeps the torch-style annotations below
from flax import nnx
from transformers import (
    FlaxCLIPTextModel,
    CLIPTokenizer,
    FlaxT5EncoderModel,
    T5Tokenizer,
)

# Module-level registry so the ZeroGPU-decorated functions can share state.
models = {}


class HFEmbedder(nnx.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        # CLIP contributes the pooled vector, T5 the per-token hidden states.
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module, params = FlaxCLIPTextModel.from_pretrained(version, _do_init=False, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module, params = FlaxT5EncoderModel.from_pretrained(version, _do_init=False, **hf_kwargs)
        # `_do_init=False` returns an uninitialized module plus the checkpoint
        # params; mark the module initialized and move the params to the GPU.
        self.hf_module._is_initialized = True
        self.hf_module.params = jax.tree.map(
            lambda x: jax.device_put(x, jax.devices("cuda")[0]), params
        )

    def tokenize(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="jax",
        )
        return batch_encoding["input_ids"]

    def __call__(self, input_ids: Tensor) -> Tensor:
        outputs = self.hf_module(
            input_ids=input_ids,
            attention_mask=None,
            output_hidden_states=False,
            train=False,
        )
        return outputs[self.output_key]


def load_t5(device: str = "cuda", max_length: int = 512) -> HFEmbedder:
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    return HFEmbedder("lnyan/t5-v1_1-xxl-encoder", max_length=max_length, dtype=jnp.bfloat16)


def load_clip(device: str = "cuda") -> HFEmbedder:
    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, dtype=jnp.bfloat16)


@spaces.GPU(duration=60)
def load_encoders():
    is_schnell = True
    t5 = load_t5("cuda", max_length=256 if is_schnell else 512)
    clip = load_clip("cuda")
    return t5, clip


def b64(txt, vec):
    # Serialize both embeddings into an in-memory npz archive and
    # base64-encode it so it can travel through the Gradio API as text.
    buffer = io.BytesIO()
    np.savez(buffer, txt=np.asarray(txt), vec=np.asarray(vec))
    buffer.seek(0)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


@spaces.GPU(duration=20)
def convert(prompt):
    t5, clip = models["t5"], models["clip"]
    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5.tokenize(prompt)
    txt = t5(txt)
    vec = clip.tokenize(prompt)
    vec = clip(vec)
    return b64(txt, vec)
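
# Sketch (illustrative only, never called by the app): how a client would turn
# the base64 string returned by `convert` back into arrays. Assumes numpy on
# the client side; `decode_embeddings` is a hypothetical helper name.
def decode_embeddings(encoded: str) -> tuple[np.ndarray, np.ndarray]:
    buffer = io.BytesIO(base64.b64decode(encoded))
    data = np.load(buffer)  # npz archive with "txt" (T5) and "vec" (CLIP)
    return data["txt"], data["vec"]
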
def _to_embed(t5, clip, txt, vec):
    # Rebuild the modules from their (graphdef, state) splits inside the
    # traced function; this is the standard nnx split/merge pattern for jit.
    t5 = nnx.merge(*t5)
    clip = nnx.merge(*clip)
    return t5(txt), clip(vec)


to_embed = jax.jit(_to_embed)


@spaces.GPU(duration=120)
def compile_prompt(prompt):
    t5, clip, t5_tuple, clip_tuple = (
        models["t5"],
        models["clip"],
        models["t5_tuple"],
        models["clip_tuple"],
    )
    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5.tokenize(prompt)
    vec = clip.tokenize(prompt)
    txt, vec = to_embed(t5_tuple, clip_tuple, txt, vec)
    return b64(txt, vec)


@spaces.GPU(duration=120)
def load(prompt):  # `prompt` is unused but required by the Gradio signature
    is_schnell = True
    t5 = load_t5("cuda", max_length=256 if is_schnell else 512)
    clip = load_clip("cuda")
    models["t5"] = t5
    models["clip"] = clip
    models["t5_tuple"] = nnx.split(t5)
    models["clip_tuple"] = nnx.split(clip)
    return "Loaded"


# Warm up: load the encoders once at startup.
print(load(""))

with gr.Blocks() as demo:
    gr.Markdown("""A workaround to fit the flux-flax text encoders into 40GB of VRAM""")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="prompt")
            convert_btn = gr.Button(value="Convert")
            compile_btn = gr.Button(value="Compile")
            load_btn = gr.Button(value="Load")
        with gr.Column():
            output = gr.Textbox(label="output")
    load_btn.click(load, inputs=prompt, outputs=output, api_name="load")
    convert_btn.click(convert, inputs=prompt, outputs=output, api_name="convert")
    compile_btn.click(compile_prompt, inputs=prompt, outputs=output, api_name="compile")

demo.launch()
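
# Usage sketch (run from a separate client process, not inside this script):
# calling the endpoints exposed above with gradio_client. The Space id
# "user/flux-flax-encoders" is a placeholder assumption.
#
# from gradio_client import Client
# client = Client("user/flux-flax-encoders")
# client.predict("", api_name="/load")                      # load encoders once
# payload = client.predict("a photo of a cat", api_name="/convert")
# txt, vec = decode_embeddings(payload)  # see the sketch after `convert`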