File size: 5,993 Bytes
29d660f
 
 
 
45ca5aa
98efed6
9045130
98efed6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
---
library_name: diffusers
---

# yujiepan/FLUX.1-dev-tiny-random

This pipeline is intended for debugging. It is adapted from [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev), with a smaller size and randomly initialized parameters.

## Usage
```python
import torch
from diffusers import FluxPipeline

# Load the tiny debugging pipeline in bfloat16.
pipe = FluxPipeline.from_pretrained("yujiepan/FLUX.1-dev-tiny-random", torch_dtype=torch.bfloat16)
# Save some VRAM by offloading the model to CPU.
# Remove this call if you have enough GPU memory.
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
# Run one generation; the CPU generator is seeded with 0.
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
# Uncomment to write the generated image to disk:
# image.save("flux-dev.png")
```

## Code
```python
import importlib

import torch
import transformers

import diffusers
import rich


def get_original_model_configs(
    pipeline_cls: type[diffusers.FluxPipeline],
    pipeline_id: str
):
    """Gather the per-component model configs of an existing pipeline.

    The pipeline-level config is read with ``pipeline_cls.load_config``. For
    every entry whose key does not start with ``"_"``, the class named by the
    entry's import strings is imported and its config is loaded from the
    subfolder of ``pipeline_id`` that has the same name as the key.

    Args:
        pipeline_cls: Pipeline class whose ``load_config`` is called.
        pipeline_id: Identifier or path passed on to the config loaders.

    Returns:
        A dict keyed by subfolder name. For ``transformers.PreTrainedModel``
        subclasses the value comes from ``config_class.from_pretrained``; for
        diffusers ``ModelMixin``/``ConfigMixin`` classes it is whatever
        ``cls.load_config`` returns. Scheduler and tokenizer subfolders are
        skipped.

    Raises:
        NotImplementedError: If a component is neither a recognised model
            class nor one of the skipped subfolders.
    """
    # Components that carry no model config and are deliberately ignored.
    skipped_subfolders = ['scheduler', 'tokenizer', 'tokenizer_2', 'tokenizer_3']

    index: dict[str, list[str]] = pipeline_cls.load_config(pipeline_id)
    collected = {}

    for name, spec in index.items():
        # Keys with a leading underscore are not components; skip them.
        if name.startswith("_"):
            continue

        # All but the last import string form the module path; the last one
        # is the class name inside that module.
        owner = importlib.import_module(".".join(spec[:-1]))
        component_cls = getattr(owner, spec[-1])

        if issubclass(component_cls, transformers.PreTrainedModel):
            collected[name] = component_cls.config_class.from_pretrained(
                pipeline_id, subfolder=name)
            continue

        if issubclass(component_cls, diffusers.ModelMixin) and issubclass(component_cls, diffusers.ConfigMixin):
            collected[name] = component_cls.load_config(pipeline_id, subfolder=name)
            continue

        if name not in skipped_subfolders:
            raise NotImplementedError(f"unknown {name}: {spec}")

    return collected


def load_pipeline(pipeline_cls: type[diffusers.DiffusionPipeline], pipeline_id: str, model_configs: dict[str, dict]):
    """Assemble a pipeline whose model components are built from configs.

    The pipeline-level config is read with ``pipeline_cls.load_config``. Each
    entry whose key does not start with ``"_"`` names a component class via
    its import strings; the component is created according to its class:

    * ``transformers.PreTrainedModel``: ``cls(config)`` using the entry of
      ``model_configs`` for that subfolder (no pretrained weights are loaded
      by this function).
    * ``transformers.PreTrainedTokenizerBase``: ``cls.from_pretrained`` from
      the matching subfolder of ``pipeline_id``.
    * diffusers ``ModelMixin`` + ``ConfigMixin``: ``cls.from_config`` using
      the entry of ``model_configs`` for that subfolder.
    * diffusers ``SchedulerMixin`` + ``ConfigMixin``: ``cls.from_pretrained``
      from the matching subfolder of ``pipeline_id``.

    Args:
        pipeline_cls: Pipeline class to read the config from and instantiate.
        pipeline_id: Identifier or path used for the pipeline config and for
            the tokenizer/scheduler loads.
        model_configs: Mapping from subfolder name to the config used to
            build that model component.

    Returns:
        ``pipeline_cls(**components)`` built from the created components.

    Raises:
        NotImplementedError: If a component class matches none of the
            categories above.
        KeyError: If ``model_configs`` lacks an entry for a model component.
    """
    pipeline_config: dict[str, list[str]] = pipeline_cls.load_config(pipeline_id)
    components = {}
    for subfolder, import_strings in pipeline_config.items():
        # Keys with a leading underscore are not components; skip them.
        if subfolder.startswith("_"):
            continue
        # All but the last import string form the module path; the last one
        # is the class name inside that module.
        module = importlib.import_module(".".join(import_strings[:-1]))
        cls = getattr(module, import_strings[-1])
        print("Loading:", ".".join(import_strings))
        if issubclass(cls, transformers.PreTrainedModel):
            config = model_configs[subfolder]
            component = cls(config)
        elif issubclass(cls, transformers.PreTrainedTokenizerBase):
            component = cls.from_pretrained(pipeline_id, subfolder=subfolder)
        elif issubclass(cls, diffusers.ModelMixin) and issubclass(cls, diffusers.ConfigMixin):
            config = model_configs[subfolder]
            component = cls.from_config(config)
        elif issubclass(cls, diffusers.SchedulerMixin) and issubclass(cls, diffusers.ConfigMixin):
            component = cls.from_pretrained(pipeline_id, subfolder=subfolder)
        else:
            # Raising a bare string is itself a TypeError in Python 3, which
            # would hide this message; raise a real exception instead.
            raise NotImplementedError(f"unknown {subfolder}: {import_strings}")
        components[subfolder] = component
        # Print any component whose class name contains "transformer".
        if 'transformer' in component.__class__.__name__.lower():
            print(component)
    pipeline = pipeline_cls(**components)
    return pipeline


def get_pipeline():
    """Build a FLUX pipeline from shrunken copies of the FLUX.1-dev configs.

    Seeds torch with 42, fetches the component configs of
    ``black-forest-labs/FLUX.1-dev`` via ``get_original_model_configs``,
    overwrites their size-related fields with small values, and hands the
    result to ``load_pipeline``.

    Returns:
        The pipeline object returned by ``load_pipeline``.
    """
    torch.manual_seed(42)
    pipeline_id = "black-forest-labs/FLUX.1-dev"
    pipeline_cls = diffusers.FluxPipeline
    model_configs = get_original_model_configs(pipeline_cls, pipeline_id)

    hidden = 8

    # The two text-encoder configs are objects, so their fields are set as
    # attributes.
    attribute_overrides = {
        "text_encoder": {
            "hidden_size": hidden,
            "intermediate_size": hidden * 2,
            "num_attention_heads": 2,
            "num_hidden_layers": 2,
            "projection_dim": hidden,
        },
        "text_encoder_2": {
            "d_model": hidden,
            "d_ff": hidden * 2,
            "d_kv": hidden // 2,
            "num_heads": 2,
            "num_layers": 2,
        },
    }
    for component, overrides in attribute_overrides.items():
        config = model_configs[component]
        for field, value in overrides.items():
            setattr(config, field, value)

    # The transformer and VAE configs are mappings, so their fields are set
    # by key.
    item_overrides = {
        "transformer": {
            "num_layers": 2,
            "num_single_layers": 4,
            "num_attention_heads": 2,
            "attention_head_dim": hidden,
            "pooled_projection_dim": hidden,
            "joint_attention_dim": hidden,
            "axes_dims_rope": (4, 2, 2),
            # "caption_projection_dim": hidden,
        },
        "vae": {
            "layers_per_block": 1,
            "block_out_channels": [hidden] * 4,
            "norm_num_groups": 2,
            "latent_channels": 16,
        },
    }
    for component, overrides in item_overrides.items():
        config = model_configs[component]
        for key, value in overrides.items():
            config[key] = value

    return load_pipeline(pipeline_cls, pipeline_id, model_configs)


# Build the shrunken pipeline and cast it to bfloat16.
pipe = get_pipeline()
pipe = pipe.to(torch.bfloat16)

# Write the pipeline to a local folder, creating the folder if needed.
from pathlib import Path
save_folder = '/tmp/yujiepan/FLUX.1-dev-tiny-random'
Path(save_folder).mkdir(parents=True, exist_ok=True)
pipe.save_pretrained(save_folder)

# Reload the saved pipeline from disk and run one generation with it.
# The resulting image is not used afterwards.
pipe = diffusers.FluxPipeline.from_pretrained(save_folder, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]

# Print the component configs of the saved pipeline.
configs = get_original_model_configs(diffusers.FluxPipeline, save_folder)
rich.print(configs)

# Upload under the folder path with its '/tmp/' prefix stripped, i.e.
# 'yujiepan/FLUX.1-dev-tiny-random'.
pipe.push_to_hub(save_folder.removeprefix('/tmp/'))
```