Flux Qwen Neutered

Qwen is used as a lightweight alternative to the T5 text encoder, for use with the Flux Dev model.

For numerical stability, it requires both the Qwen and the Flux tokenizers, which adds about 10 MB of data.

This repo is experimental work, not a final replacement for the built-in text encoder.

Compared to a standalone version, this demo has better accuracy and a shorter training time, mainly because it reuses a pre-trained model.

Inference

from diffusers import FluxPipeline, FluxTransformer2DModel
from text_encoder import PretrainedTextEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Optional, Union

def setup_qwen(pipe,
               qwen_path='Qwen/Qwen2.5-0.5B',
               device=None,
               dtype=torch.bfloat16):
    """Attach a Qwen tokenizer and Qwen backbone to ``pipe`` in place.

    Args:
        pipe: the pipeline object to augment; it gains the attributes
            ``qwen_tokenizer`` and ``qwen_model``.
        qwen_path: Hugging Face repo id or local path of the Qwen checkpoint.
        device: forwarded to ``from_pretrained`` as ``device_map``.
        dtype: forwarded to ``from_pretrained`` as ``torch_dtype``.

    Returns:
        The same ``pipe`` object, for call chaining.
    """
    # The tokenizer is attached before the (much larger) model is loaded,
    # exactly as in the original ordering.
    pipe.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_path)

    load_options = dict(device_map=device, torch_dtype=dtype)
    causal_lm = AutoModelForCausalLM.from_pretrained(qwen_path, **load_options)

    # encode_qwen only reads `last_hidden_state`, so only the backbone
    # (`.model`) is kept on the pipeline.
    backbone = causal_lm.model
    pipe.qwen_model = backbone
    return pipe

class FluxQwenPipeline(FluxPipeline):
    """FluxPipeline variant whose "T5" prompt embeddings are produced from Qwen.

    Expects ``qwen_tokenizer`` and ``qwen_model`` to be attached to the
    pipeline (``setup_qwen`` does this). The Flux tokenizer ``tokenizer_2`` is
    still used for token ids and the padding mask fed to the adapter.

    The adapter module (called with the Qwen hidden states and
    ``adapter.shared(input_ids)``) is taken from ``self.qwen_encoder`` when
    that attribute is set, otherwise from a module-level global named
    ``encoder``. The global is what the ``__main__`` demo defines.
    """

    def _get_t5_prompt_embeds(self,
                              prompt: Union[str, List[str]] = None,
                              num_images_per_prompt: int = 1,
                              max_sequence_length: int = 512,
                              device: Optional[torch.device] = None,
                              dtype: Optional[torch.dtype] = None):
        """Build the prompt embeddings that FluxPipeline would get from T5.

        Args:
            prompt: a single prompt or a list of prompts.
            num_images_per_prompt: each prompt's embedding is repeated this
                many times along the batch axis (consecutive copies per prompt).
            max_sequence_length: pad/truncate length for both tokenizers.
            device: device for the token tensors; defaults to the pipeline's
                execution device.
            dtype: if given, the result is cast to this dtype.

        Returns:
            The adapter output with padding positions (per the Flux
            tokenizer's attention mask) zeroed, with batch size
            ``len(prompts) * num_images_per_prompt``.

        Raises:
            RuntimeError: if no adapter is available.
        """
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt

        adapter = getattr(self, 'qwen_encoder', None)
        if adapter is None:
            # Backward compatibility: the demo script exposes the adapter as a
            # module-level global called `encoder`.
            adapter = globals().get('encoder')
        if adapter is None:
            raise RuntimeError(
                "No Qwen->Flux text adapter found: set `pipe.qwen_encoder` "
                "or define a module-level `encoder`.")

        qwen_out = self.encode_qwen(prompt, max_sequence_length, device)
        inputs = self.tokenizer_2(prompt,
                                  return_tensors='pt',
                                  padding='max_length',
                                  truncation=True,
                                  max_length=max_sequence_length)
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        output = adapter(qwen_out,
                         adapter.shared(input_ids),
                         max_length=max_sequence_length)

        # Zero out embeddings at padding positions of the Flux tokenizer.
        output = output * attention_mask.unsqueeze(-1).to(output.dtype)
        if dtype is not None:
            output = output.to(dtype=dtype)

        # Duplicate per requested image so the batch size matches
        # num_images_per_prompt (copies of the same prompt stay adjacent).
        if num_images_per_prompt > 1:
            output = output.repeat_interleave(num_images_per_prompt, dim=0)
        return output

    def encode_qwen(self, prompt, max_sequence_length=256, device=None):
        """Tokenize ``prompt`` with Qwen and return the backbone's last hidden state.

        Args:
            prompt: a single prompt or a list of prompts.
            max_sequence_length: pad/truncate length for the Qwen tokenizer.
            device: device the token tensors are moved to.

        Returns:
            ``last_hidden_state`` of ``self.qwen_model``.
        """
        inputs = self.qwen_tokenizer(prompt,
                                     return_tensors='pt',
                                     padding='max_length',
                                     truncation=True,
                                     max_length=max_sequence_length)
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        output = self.qwen_model(input_ids=input_ids,
                                 attention_mask=attention_mask)

        return output.last_hidden_state

if __name__ == '__main__':
    # Demo: generate one image with the Qwen-based text path.
    # NOTE: keep `encoder` at module scope under this exact name;
    # FluxQwenPipeline can resolve the adapter through that global.
    encoder = PretrainedTextEncoder.from_pretrained(
        'twodgirl/flux-qwen-neutered',
        device_map='cuda',
        torch_dtype=torch.bfloat16,
    )

    # The built-in T5 encoder is dropped (text_encoder_2=None); its
    # embeddings are produced by the Qwen path instead.
    flux = FluxQwenPipeline.from_pretrained(
        'black-forest-labs/FLUX.1-dev',
        text_encoder_2=None,
        torch_dtype=torch.bfloat16,
    )
    setup_qwen(flux, device='cuda')
    flux.enable_model_cpu_offload()

    generated = flux('a black cat wearing a Pikachu cosplay')
    generated.images[0].save('cat.png')

Disclaimer

Use of this code and the model requires citation and attribution to the author via a link to their Hugging Face profile in all resulting work.

Downloads last month
12
Safetensors
Model size
264M params
Tensor type
BF16
·
Inference Examples
Unable to determine this model's library. Check the docs.

Model tree for twodgirl/flux-qwen-neutered

Finetuned
(269)
this model