Yuanshi committed on
Commit
6ed1db6
·
1 Parent(s): f34bc73
README.md CHANGED
@@ -1,12 +1,127 @@
1
- ---
2
- title: OminiControl
3
- emoji: 🌍
4
- colorFrom: blue
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.6.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # OminiControl
2
+
3
+
4
+ <img src='./assets/demo/demo_this_is_omini_control.jpg' width='100%' />
5
+ <br>
6
+
7
+ <a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/arXiv-2411.15098-A42C25.svg" alt="arXiv"></a>
8
+ <a href="https://huggingface.co/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
9
+ <a href="https://github.com/Yuanshi9815/Subjects200K"><img src="https://img.shields.io/badge/GitHub-Subjects200K dataset-blue.svg?logo=github&" alt="GitHub"></a>
10
+
11
+ > **OminiControl: Minimal and Universal Control for Diffusion Transformer**
12
+ > <br>
13
+ > Zhenxiong Tan,
14
+ > [Songhua Liu](http://121.37.94.87/),
15
+ > [Xingyi Yang](https://adamdad.github.io/),
16
+ > Qiaochu Xue,
17
+ > and
18
+ > [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
19
+ > <br>
20
+ > [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore
21
+ > <br>
22
+
23
+
24
+ ## Features
25
+
26
+ OminiControl is a minimal yet powerful universal control framework for Diffusion Transformer models like [FLUX](https://github.com/black-forest-labs/flux).
27
+
28
+ * **Universal Control 🌐**: A unified control framework that supports both subject-driven control and spatial control (such as edge-guided and in-painting generation).
29
+
30
+ * **Minimal Design 🚀**: Injects control signals while preserving the original model structure, adding only 0.1% additional parameters to the base model.
31
+
32
+ ## Quick Start
33
+ ### Setup (Optional)
34
+ 1. **Environment setup**
35
+ ```bash
36
+ conda create -n omini python=3.10
37
+ conda activate omini
38
+ ```
39
+ 2. **Requirements installation**
40
+ ```bash
41
+ pip install -r requirements.txt
42
+ ```
43
+ ### Usage example
44
+ 1. Subject-driven generation: `examples/subject.ipynb`
45
+ 2. In-painting: `examples/inpainting.ipynb`
46
+ 3. Canny edge to image, depth to image, colorization, deblurring: `examples/spatial.ipynb`
47
+
48
+ ## Generated samples
49
+ ### Subject-driven generation
50
+ **Demos** (Left: condition image; Right: generated image)
51
+
52
+ <div float="left">
53
+ <img src='./assets/demo/oranges_omini.jpg' width='48%'/>
54
+ <img src='./assets/demo/rc_car_omini.jpg' width='48%' />
55
+ <img src='./assets/demo/clock_omini.jpg' width='48%' />
56
+ <img src='./assets/demo/shirt_omini.jpg' width='48%' />
57
+ </div>
58
+
59
+ <details>
60
+ <summary>Text Prompts</summary>
61
+
62
+ - Prompt1: *A close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!.'*
63
+ - Prompt2: *A film style shot. On the moon, this item drives across the moon surface. A flag on it reads 'Omini'. The background is that Earth looms large in the foreground.*
64
+ - Prompt3: *In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.*
65
+ - Prompt4: *In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.*
66
+ </details>
67
+ <details>
68
+ <summary>More results</summary>
69
+
70
+ * Try on:
71
+ <img src='./assets/demo/try_on.jpg'/>
72
+ * Scene variations:
73
+ <img src='./assets/demo/scene_variation.jpg'/>
74
+ * Dreambooth dataset:
75
+ <img src='./assets/demo/dreambooth_res.jpg'/>
76
+ </details>
77
+
78
+ ### Spatially aligned control
79
+ 1. **Image Inpainting** (Left: original image; Center: masked image; Right: filled image)
80
+ - Prompt: *The Mona Lisa is wearing a white VR headset with 'Omini' written on it.*
81
+ <br>
82
+ <img src='./assets/demo/monalisa_omini.jpg' width='700px' />
83
+ - Prompt: *A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.*
84
+ <br>
85
+ <img src='./assets/demo/book_omini.jpg' width='700px' />
86
+ 2. **Other spatially aligned tasks** (Canny edge to image, depth to image, colorization, deblurring)
87
+ <br>
88
+ <details>
89
+ <summary>Click to show</summary>
90
+ <div float="left">
91
+ <img src='./assets/demo/room_corner_canny.jpg' width='48%'/>
92
+ <img src='./assets/demo/room_corner_depth.jpg' width='48%' />
93
+ <img src='./assets/demo/room_corner_coloring.jpg' width='48%' />
94
+ <img src='./assets/demo/room_corner_deblurring.jpg' width='48%' />
95
+ </div>
96
+
97
+ Prompt: *A light gray sofa stands against a white wall, featuring a black and white geometric patterned pillow. A white side table sits next to the sofa, topped with a white adjustable desk lamp and some books. Dark hardwood flooring contrasts with the pale walls and furniture.*
98
+ </details>
99
+
100
+
101
+
102
+
103
+ ## Models
104
+
105
+ **Subject-driven control:**
106
+ | Model | Base model | Description | Resolution |
107
+ | ------------------------------------------------------------------------------------------------ | -------------- | -------------------------------------------------------------------------------------------------------- | ------------ |
108
+ | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `subject` | FLUX.1-schnell | The model used in the paper. | (512, 512) |
109
+ | [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_512` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset. | (512, 512) |
110
+ | [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_1024` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset and accommodates higher resolution. (To be released) | (1024, 1024) |
111
+
112
+ **Spatially aligned control:**
113
+ | Model | Base model | Description | Resolution |
114
+ | --------------------------------------------------------------------------------------------------------- | ---------- | -------------------------------------------------------------------------- | ------------ |
115
+ | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `<task_name>` | FLUX.1 | Canny edge to image, depth to image, colorization, deblurring, in-painting | (512, 512) |
116
+ | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `<task_name>_1024` | FLUX.1 | Supports higher resolution. (To be released) | (1024, 1024) |
117
+
118
+ ## Citation
119
+ ```
120
+ @article{
121
+ tan2024omini,
122
+ title={OminiControl: Minimal and Universal Control for Diffusion Transformer},
123
+ author={Zhenxiong Tan and Songhua Liu and Xingyi Yang and Qiaochu Xue and Xinchao Wang},
124
+ journal={arXiv preprint arXiv:2411.15098},
125
+ year={2024}
126
+ }
127
+ ```
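
The README's "Minimal Design" claim (about 0.1% additional parameters) is consistent with how low-rank adapters scale. The notebooks in this commit load the control weights through `load_lora_weights`, and a rank-`r` LoRA adapter on a linear layer of shape `d_in x d_out` adds `r * (d_in + d_out)` parameters. The sketch below works through that arithmetic. The helper names, hidden width, rank, number of adapted layers, and base parameter count are all illustrative assumptions, not values read from the OminiControl checkpoints.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters added by one LoRA adapter.

    The adapter factors the weight update into A (rank x d_in) and
    B (d_out x rank), so it adds rank * (d_in + d_out) parameters.
    """
    return rank * (d_in + d_out)


def overhead_fraction(added: int, base: int) -> float:
    """Added parameters as a fraction of the base model size."""
    return added / base


if __name__ == "__main__":
    # Hypothetical numbers, chosen only to illustrate the order of magnitude.
    hidden = 3072                   # assumed width of the adapted linear layers
    rank = 4                        # assumed LoRA rank
    adapted_layers = 400            # assumed number of adapted linear layers
    base_params = 12_000_000_000    # assumed size of the base model

    added = adapted_layers * lora_param_count(hidden, hidden, rank)
    print(f"added parameters: {added:,}")
    print(f"overhead: {overhead_fraction(added, base_params):.4%}")
```

Because the added count grows linearly with the rank while the base model stays fixed, small ranks keep the overhead well below one percent under these assumptions.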
assets/book.jpg ADDED

Git LFS Details

  • SHA256: c4a0168de6842d12bd0adefdfc9f9791ca2963649db86af6774055be588cfcb4
  • Pointer size: 130 Bytes
  • Size of remote file: 62.7 kB
assets/clock.jpg ADDED

Git LFS Details

  • SHA256: 41235973f26152ac92d32bfc166fb5f9f1e352c5e16807920238473316ec462b
  • Pointer size: 131 Bytes
  • Size of remote file: 289 kB
assets/demo/book_omini.jpg ADDED

Git LFS Details

  • SHA256: f58783c277876d11c10f5da15da659445af4a5074d5f6025508d7716c1998cde
  • Pointer size: 130 Bytes
  • Size of remote file: 61.7 kB
assets/demo/clock_omini.jpg ADDED

Git LFS Details

  • SHA256: 0a58342f8e8751a1b27cad2e207fa089e3d09ee5d518f791db4b39bc957308d0
  • Pointer size: 130 Bytes
  • Size of remote file: 79.4 kB
assets/demo/demo_this_is_omini_control.jpg ADDED

Git LFS Details

  • SHA256: 798b7c25be6be118dc0de97c444c840869afca633a0d48f99d940aec040a7518
  • Pointer size: 131 Bytes
  • Size of remote file: 129 kB
assets/demo/dreambooth_res.jpg ADDED

Git LFS Details

  • SHA256: ba36bd861989564dc679acf3b5e56f382f1a11b1596e6f611ea0bd7d81b89680
  • Pointer size: 132 Bytes
  • Size of remote file: 1.94 MB
assets/demo/monalisa_omini.jpg ADDED

Git LFS Details

  • SHA256: e5ca6c2bf44f19d216b2eb16dcc67d19f11d87220d3ee80f5e5e1ad98a5536dc
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
assets/demo/oranges_omini.jpg ADDED

Git LFS Details

  • SHA256: 56ee72d3e38695e3babd01f4909c9303b3d4079b5017ea84bd8e7e835764f18b
  • Pointer size: 130 Bytes
  • Size of remote file: 44 kB
assets/demo/penguin_omini.jpg ADDED

Git LFS Details

  • SHA256: 6de9eb67e96baaff560cde978ef0c1cd43d8ccec949e3b5898e5613c2f78d01e
  • Pointer size: 130 Bytes
  • Size of remote file: 50.6 kB
assets/demo/rc_car_omini.jpg ADDED

Git LFS Details

  • SHA256: fcedbe922d92f5736cca92798283af4db6da0f73e0df8e3a38fa16870d80dcac
  • Pointer size: 130 Bytes
  • Size of remote file: 65.9 kB
assets/demo/room_corner_canny.jpg ADDED

Git LFS Details

  • SHA256: f73418bc5a794b2d43c2e4a19ae574794586f7cede0024983f19c31ea7d4e061
  • Pointer size: 130 Bytes
  • Size of remote file: 57.8 kB
assets/demo/room_corner_coloring.jpg ADDED

Git LFS Details

  • SHA256: 8005eab40ecb99db41e6bd448dda3595d265dd1a4875bbd7b7e63dd713843f19
  • Pointer size: 130 Bytes
  • Size of remote file: 50.3 kB
assets/demo/room_corner_deblurring.jpg ADDED

Git LFS Details

  • SHA256: f59b5931a2235d4ede7a04aed56b3b08ce31e88ec454e92f67d2be11e95cfadf
  • Pointer size: 130 Bytes
  • Size of remote file: 37.4 kB
assets/demo/room_corner_depth.jpg ADDED

Git LFS Details

  • SHA256: f6a42c4e0cef91a8893510885a1ed8917ef42029a790193d77780d5f3d48de00
  • Pointer size: 130 Bytes
  • Size of remote file: 34.6 kB
assets/demo/scene_variation.jpg ADDED

Git LFS Details

  • SHA256: 39e4e16d2eeb58b3775b6d34c8b3e125d0d19cc36fa90b07c6c8d57624ad4333
  • Pointer size: 131 Bytes
  • Size of remote file: 958 kB
assets/demo/shirt_omini.jpg ADDED

Git LFS Details

  • SHA256: 0bd5f57c687196e8777da9b96c7dc29664d1f9b37ea9fea108278dfe37e8183f
  • Pointer size: 130 Bytes
  • Size of remote file: 81.3 kB
assets/demo/try_on.jpg ADDED

Git LFS Details

  • SHA256: 6adce5194329a83f0109b4375e00667c341879e64fb55831c70ea3f3b2f99f7e
  • Pointer size: 131 Bytes
  • Size of remote file: 774 kB
assets/monalisa.jpg ADDED

Git LFS Details

  • SHA256: 188b8b6499e4541f9dfef2a9daf6f1eb920079c9208f587fd97566d6aa4a9719
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
assets/oranges.jpg ADDED

Git LFS Details

  • SHA256: 5ce1a31c0de86987e896e57467f5457e7b2a9d5ee4f585ee4cb138f41d3987cf
  • Pointer size: 130 Bytes
  • Size of remote file: 53.9 kB
assets/penguin.jpg ADDED

Git LFS Details

  • SHA256: 4731ab69d614d55080435355d9a07d20dbadff695138c2a4374c483400674241
  • Pointer size: 130 Bytes
  • Size of remote file: 48.6 kB
assets/rc_car.jpg ADDED

Git LFS Details

  • SHA256: ae8aed11029fa3b084deb286c07a8cab5056840c9c123816fe2b504e94233e95
  • Pointer size: 131 Bytes
  • Size of remote file: 254 kB
assets/room_corner.jpg ADDED

Git LFS Details

  • SHA256: f97bd63df05f5f15ad5dd1a2ccef803e74e12caadd8fe145493fd6d5219045e7
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
assets/tshirt.jpg ADDED

Git LFS Details

  • SHA256: cb1803315765302113a9e7a64dedd4ecba2672028cf093cbc33ef2edd2247c39
  • Pointer size: 131 Bytes
  • Size of remote file: 301 kB
examples/inpainting.ipynb ADDED
@@ -0,0 +1,143 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from src.condition import Condition\n",
23
+ "from PIL import Image\n",
24
+ "\n",
25
+ "from src.generate import generate, seed_everything"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "pipe = FluxPipeline.from_pretrained(\n",
35
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
36
+ ")\n",
37
+ "pipe = pipe.to(\"cuda\")"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "pipe.load_lora_weights(\n",
47
+ " \"Yuanshi/OminiControl\",\n",
48
+ " weight_name=\"experimental/fill.safetensors\",\n",
49
+ " adapter_name=\"fill\",\n",
50
+ ")"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "image = Image.open(\"assets/monalisa.jpg\").convert(\"RGB\").resize((512, 512))\n",
60
+ "\n",
61
+ "masked_image = image.copy()\n",
62
+ "masked_image.paste((0, 0, 0), (128, 100, 384, 220))\n",
63
+ "\n",
64
+ "condition = Condition(\"fill\", masked_image)\n",
65
+ "\n",
66
+ "seed_everything()\n",
67
+ "result_img = generate(\n",
68
+ " pipe,\n",
69
+ " prompt=\"The Mona Lisa is wearing a white VR headset with 'Omini' written on it.\",\n",
70
+ " conditions=[condition],\n",
71
+ ").images[0]\n",
72
+ "\n",
73
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
74
+ "concat_image.paste(image, (0, 0))\n",
75
+ "concat_image.paste(condition.condition, (512, 0))\n",
76
+ "concat_image.paste(result_img, (1024, 0))\n",
77
+ "concat_image"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "image = Image.open(\"assets/book.jpg\").convert(\"RGB\")\n",
87
+ "\n",
88
+ "w, h, min_dim = image.size + (min(image.size),)\n",
89
+ "image = image.crop(\n",
90
+ " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
91
+ ").resize((512, 512))\n",
92
+ "\n",
93
+ "\n",
94
+ "masked_image = image.copy()\n",
95
+ "masked_image.paste((0, 0, 0), (150, 150, 350, 250))\n",
96
+ "masked_image.paste((0, 0, 0), (200, 380, 320, 420))\n",
97
+ "\n",
98
+ "condition = Condition(\"fill\", masked_image)\n",
99
+ "\n",
100
+ "seed_everything()\n",
101
+ "result_img = generate(\n",
102
+ " pipe,\n",
103
+ " prompt=\"A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.\",\n",
104
+ " conditions=[condition],\n",
105
+ ").images[0]\n",
106
+ "\n",
107
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
108
+ "concat_image.paste(image, (0, 0))\n",
109
+ "concat_image.paste(condition.condition, (512, 0))\n",
110
+ "concat_image.paste(result_img, (1024, 0))\n",
111
+ "concat_image"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": []
120
+ }
121
+ ],
122
+ "metadata": {
123
+ "kernelspec": {
124
+ "display_name": "base",
125
+ "language": "python",
126
+ "name": "python3"
127
+ },
128
+ "language_info": {
129
+ "codemirror_mode": {
130
+ "name": "ipython",
131
+ "version": 3
132
+ },
133
+ "file_extension": ".py",
134
+ "mimetype": "text/x-python",
135
+ "name": "python",
136
+ "nbconvert_exporter": "python",
137
+ "pygments_lexer": "ipython3",
138
+ "version": "3.12.7"
139
+ }
140
+ },
141
+ "nbformat": 4,
142
+ "nbformat_minor": 2
143
+ }
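
The second in-painting cell center-crops the input to a square before resizing it to 512x512. The box arithmetic it uses can be checked without loading an image. In the sketch below, `center_crop_box` is a hypothetical helper name that is not part of the repository; it mirrors the tuple the notebook passes to `image.crop(...)`.

```python
from typing import Tuple


def center_crop_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, upper, right, lower) for the largest centered square."""
    min_dim = min(width, height)
    left = (width - min_dim) // 2
    upper = (height - min_dim) // 2
    right = (width + min_dim) // 2
    lower = (height + min_dim) // 2
    return (left, upper, right, lower)


if __name__ == "__main__":
    # A landscape 800x600 image keeps its full height and loses 100 px per side.
    print(center_crop_box(800, 600))  # (100, 0, 700, 600)
```

With Pillow, the returned tuple can be passed directly to `Image.crop`, after which `resize((512, 512))` no longer distorts the aspect ratio.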
examples/spatial.ipynb ADDED
@@ -0,0 +1,184 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from src.condition import Condition\n",
23
+ "from PIL import Image\n",
24
+ "\n",
25
+ "from src.generate import generate, seed_everything"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "pipe = FluxPipeline.from_pretrained(\n",
35
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
36
+ ")\n",
37
+ "pipe = pipe.to(\"cuda\")"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "for condition_type in [\"canny\", \"depth\", \"coloring\", \"deblurring\"]:\n",
47
+ " pipe.load_lora_weights(\n",
48
+ " \"Yuanshi/OminiControl\",\n",
49
+ " weight_name=f\"experimental/{condition_type}.safetensors\",\n",
50
+ " adapter_name=condition_type,\n",
51
+ " )"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "image = Image.open(\"assets/coffee.png\").convert(\"RGB\")\n",
61
+ "\n",
62
+ "w, h, min_dim = image.size + (min(image.size),)\n",
63
+ "image = image.crop(\n",
64
+ " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
65
+ ").resize((512, 512))\n",
66
+ "\n",
67
+ "prompt = \"In a bright room. A cup of coffee with some beans on the side. They are placed on a dark wooden table.\""
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "condition = Condition(\"canny\", image)\n",
77
+ "\n",
78
+ "seed_everything()\n",
79
+ "\n",
80
+ "result_img = generate(\n",
81
+ " pipe,\n",
82
+ " prompt=prompt,\n",
83
+ " conditions=[condition],\n",
84
+ ").images[0]\n",
85
+ "\n",
86
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
87
+ "concat_image.paste(image, (0, 0))\n",
88
+ "concat_image.paste(condition.condition, (512, 0))\n",
89
+ "concat_image.paste(result_img, (1024, 0))\n",
90
+ "concat_image"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "condition = Condition(\"depth\", image)\n",
100
+ "\n",
101
+ "seed_everything()\n",
102
+ "\n",
103
+ "result_img = generate(\n",
104
+ " pipe,\n",
105
+ " prompt=prompt,\n",
106
+ " conditions=[condition],\n",
107
+ ").images[0]\n",
108
+ "\n",
109
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
110
+ "concat_image.paste(image, (0, 0))\n",
111
+ "concat_image.paste(condition.condition, (512, 0))\n",
112
+ "concat_image.paste(result_img, (1024, 0))\n",
113
+ "concat_image"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "condition = Condition(\"deblurring\", image)\n",
123
+ "\n",
124
+ "seed_everything()\n",
125
+ "\n",
126
+ "result_img = generate(\n",
127
+ " pipe,\n",
128
+ " prompt=prompt,\n",
129
+ " conditions=[condition],\n",
130
+ ").images[0]\n",
131
+ "\n",
132
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
133
+ "concat_image.paste(image, (0, 0))\n",
134
+ "concat_image.paste(condition.condition, (512, 0))\n",
135
+ "concat_image.paste(result_img, (1024, 0))\n",
136
+ "concat_image"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "condition = Condition(\"coloring\", image)\n",
146
+ "\n",
147
+ "seed_everything()\n",
148
+ "\n",
149
+ "result_img = generate(\n",
150
+ " pipe,\n",
151
+ " prompt=prompt,\n",
152
+ " conditions=[condition],\n",
153
+ ").images[0]\n",
154
+ "\n",
155
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
156
+ "concat_image.paste(image, (0, 0))\n",
157
+ "concat_image.paste(condition.condition, (512, 0))\n",
158
+ "concat_image.paste(result_img, (1024, 0))\n",
159
+ "concat_image"
160
+ ]
161
+ }
162
+ ],
163
+ "metadata": {
164
+ "kernelspec": {
165
+ "display_name": "base",
166
+ "language": "python",
167
+ "name": "python3"
168
+ },
169
+ "language_info": {
170
+ "codemirror_mode": {
171
+ "name": "ipython",
172
+ "version": 3
173
+ },
174
+ "file_extension": ".py",
175
+ "mimetype": "text/x-python",
176
+ "name": "python",
177
+ "nbconvert_exporter": "python",
178
+ "pygments_lexer": "ipython3",
179
+ "version": "3.12.7"
180
+ }
181
+ },
182
+ "nbformat": 4,
183
+ "nbformat_minor": 2
184
+ }
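
Each generation cell in `examples/spatial.ipynb` pastes the input, the condition image, and the result side by side into a 1536x512 strip. The canvas size and paste offsets follow a simple pattern, sketched below. `strip_layout` is a hypothetical helper introduced here for illustration, and it assumes square panels of equal size as in the notebook.

```python
from typing import List, Tuple


def strip_layout(
    n_panels: int, panel_size: int = 512
) -> Tuple[Tuple[int, int], List[Tuple[int, int]]]:
    """Return the canvas size and the top-left paste offset of each panel."""
    canvas = (n_panels * panel_size, panel_size)
    offsets = [(i * panel_size, 0) for i in range(n_panels)]
    return canvas, offsets


if __name__ == "__main__":
    canvas, offsets = strip_layout(3)
    print(canvas)   # (1536, 512)
    print(offsets)  # [(0, 0), (512, 0), (1024, 0)]
```

The canvas tuple matches the size passed to `Image.new("RGB", ...)`, and each offset matches the second argument of the corresponding `paste` call.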
examples/subject.ipynb ADDED
@@ -0,0 +1,214 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from src.condition import Condition\n",
23
+ "from PIL import Image\n",
24
+ "\n",
25
+ "from src.generate import generate, seed_everything"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "pipe = FluxPipeline.from_pretrained(\n",
35
+ " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
36
+ ")\n",
37
+ "pipe = pipe.to(\"cuda\")\n",
38
+ "pipe.load_lora_weights(\n",
39
+ " \"Yuanshi/OminiControl\",\n",
40
+ " weight_name=\"omini/subject_512.safetensors\",\n",
41
+ " adapter_name=\"subject\",\n",
42
+ ")"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
52
+ "\n",
53
+ "condition = Condition(\"subject\", image)\n",
54
+ "\n",
55
+ "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
56
+ "\n",
57
+ "\n",
58
+ "seed_everything(0)\n",
59
+ "\n",
60
+ "result_img = generate(\n",
61
+ " pipe,\n",
62
+ " prompt=prompt,\n",
63
+ " conditions=[condition],\n",
64
+ " num_inference_steps=8,\n",
65
+ " height=512,\n",
66
+ " width=512,\n",
67
+ ").images[0]\n",
68
+ "\n",
69
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
70
+ "concat_image.paste(image, (0, 0))\n",
71
+ "concat_image.paste(result_img, (512, 0))\n",
72
+ "concat_image"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
82
+ "\n",
83
+ "condition = Condition(\"subject\", image)\n",
84
+ "\n",
85
+ "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard behind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
86
+ "\n",
87
+ "\n",
88
+ "seed_everything()\n",
89
+ "\n",
90
+ "result_img = generate(\n",
91
+ " pipe,\n",
92
+ " prompt=prompt,\n",
93
+ " conditions=[condition],\n",
94
+ " num_inference_steps=8,\n",
95
+ " height=512,\n",
96
+ " width=512,\n",
97
+ ").images[0]\n",
98
+ "\n",
99
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
100
+ "concat_image.paste(condition.condition, (0, 0))\n",
101
+ "concat_image.paste(result_img, (512, 0))\n",
102
+ "concat_image"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
112
+ "\n",
113
+ "condition = Condition(\"subject\", image)\n",
114
+ "\n",
115
+ "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
116
+ "\n",
117
+ "seed_everything()\n",
118
+ "\n",
119
+ "result_img = generate(\n",
120
+ " pipe,\n",
121
+ " prompt=prompt,\n",
122
+ " conditions=[condition],\n",
123
+ " num_inference_steps=8,\n",
124
+ " height=512,\n",
125
+ " width=512,\n",
126
+ ").images[0]\n",
127
+ "\n",
128
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
129
+ "concat_image.paste(condition.condition, (0, 0))\n",
130
+ "concat_image.paste(result_img, (512, 0))\n",
131
+ "concat_image"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
141
+ "\n",
142
+ "condition = Condition(\"subject\", image)\n",
143
+ "\n",
144
+ "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
145
+ "\n",
146
+ "seed_everything()\n",
147
+ "\n",
148
+ "result_img = generate(\n",
149
+ " pipe,\n",
150
+ " prompt=prompt,\n",
151
+ " conditions=[condition],\n",
152
+ " num_inference_steps=8,\n",
153
+ " height=512,\n",
154
+ " width=512,\n",
155
+ ").images[0]\n",
156
+ "\n",
157
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
158
+ "concat_image.paste(condition.condition, (0, 0))\n",
159
+ "concat_image.paste(result_img, (512, 0))\n",
160
+ "concat_image"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
170
+ "\n",
171
+ "condition = Condition(\"subject\", image)\n",
172
+ "\n",
173
+ "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
174
+ "\n",
175
+ "seed_everything()\n",
176
+ "\n",
177
+ "result_img = generate(\n",
178
+ " pipe,\n",
179
+ " prompt=prompt,\n",
180
+ " conditions=[condition],\n",
181
+ " num_inference_steps=8,\n",
182
+ " height=512,\n",
183
+ " width=512,\n",
184
+ ").images[0]\n",
185
+ "\n",
186
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
187
+ "concat_image.paste(condition.condition, (0, 0))\n",
188
+ "concat_image.paste(result_img, (512, 0))\n",
189
+ "concat_image"
190
+ ]
191
+ }
192
+ ],
193
+ "metadata": {
194
+ "kernelspec": {
195
+ "display_name": "base",
196
+ "language": "python",
197
+ "name": "python3"
198
+ },
199
+ "language_info": {
200
+ "codemirror_mode": {
201
+ "name": "ipython",
202
+ "version": 3
203
+ },
204
+ "file_extension": ".py",
205
+ "mimetype": "text/x-python",
206
+ "name": "python",
207
+ "nbconvert_exporter": "python",
208
+ "pygments_lexer": "ipython3",
209
+ "version": "3.12.7"
210
+ }
211
+ },
212
+ "nbformat": 4,
213
+ "nbformat_minor": 2
214
+ }
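
Across the three notebooks, LoRA checkpoints in the `Yuanshi/OminiControl` repository are addressed by a `<folder>/<model>.safetensors` path, for example `omini/subject_512.safetensors` or `experimental/canny.safetensors`. The sketch below builds those names in one place. `lora_weight_name` and the `CHECKPOINTS` mapping are hypothetical conveniences for illustration, and the mapping lists only names that appear in this commit's notebooks.

```python
from typing import Dict


def lora_weight_name(folder: str, model: str) -> str:
    """Build the checkpoint path used as `weight_name` in the notebooks."""
    return f"{folder}/{model}.safetensors"


# Adapter name -> checkpoint path, limited to names seen in the notebooks.
CHECKPOINTS: Dict[str, str] = {
    "subject": lora_weight_name("omini", "subject_512"),
    "fill": lora_weight_name("experimental", "fill"),
    "canny": lora_weight_name("experimental", "canny"),
    "depth": lora_weight_name("experimental", "depth"),
    "coloring": lora_weight_name("experimental", "coloring"),
    "deblurring": lora_weight_name("experimental", "deblurring"),
}

if __name__ == "__main__":
    for adapter_name, weight_name in CHECKPOINTS.items():
        print(f"{adapter_name}: {weight_name}")
```

Each pair corresponds to the `adapter_name` and `weight_name` arguments the notebooks pass to `pipe.load_lora_weights("Yuanshi/OminiControl", ...)`.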
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ transformers
2
+ diffusers
3
+ peft
4
+ opencv-python
5
+ protobuf
6
+ sentencepiece
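
After `pip install -r requirements.txt`, it can be useful to confirm which of the listed packages are actually present in the active environment. The sketch below uses only the standard library's `importlib.metadata`. `installed_versions` is a hypothetical helper name, and the package list is copied from the `requirements.txt` added in this commit.

```python
from importlib import metadata
from typing import Dict, List, Optional

# Distribution names as listed in requirements.txt.
REQUIREMENTS: List[str] = [
    "transformers",
    "diffusers",
    "peft",
    "opencv-python",
    "protobuf",
    "sentencepiece",
]


def installed_versions(packages: List[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(REQUIREMENTS).items():
        print(f"{name}: {version or 'not installed'}")
```

Since the requirements file pins no versions, printing the resolved versions alongside any reported results makes a run easier to reproduce later.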
src/block.py ADDED
@@ -0,0 +1,333 @@
+ import torch
+ import torch.nn.functional as F
+ from typing import List, Union, Optional, Dict, Any, Callable
+ from diffusers.models.attention_processor import Attention
+ from diffusers.models.embeddings import apply_rotary_emb
+ from .lora_controller import enable_lora
+
+
+ def attn_forward(
+     attn: Attention,
+     hidden_states: torch.FloatTensor,
+     encoder_hidden_states: torch.FloatTensor = None,
+     condition_latents: torch.FloatTensor = None,
+     attention_mask: Optional[torch.FloatTensor] = None,
+     image_rotary_emb: Optional[torch.Tensor] = None,
+     cond_rotary_emb: Optional[torch.Tensor] = None,
+     model_config: Optional[Dict[str, Any]] = {},
+ ) -> torch.FloatTensor:
+     batch_size, _, _ = (
+         hidden_states.shape
+         if encoder_hidden_states is None
+         else encoder_hidden_states.shape
+     )
+
+     with enable_lora(
+         (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
+     ):
+         # `sample` projections.
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(hidden_states)
+         value = attn.to_v(hidden_states)
+
+     inner_dim = key.shape[-1]
+     head_dim = inner_dim // attn.heads
+
+     # Reshape to (batch, heads, seq_len, head_dim).
+     query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+     key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+     value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+     if attn.norm_q is not None:
+         query = attn.norm_q(query)
+     if attn.norm_k is not None:
+         key = attn.norm_k(key)
+
+     # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+     if encoder_hidden_states is not None:
+         # `context` projections.
+         encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+         encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+         encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+         encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+
+         if attn.norm_added_q is not None:
+             encoder_hidden_states_query_proj = attn.norm_added_q(
+                 encoder_hidden_states_query_proj
+             )
+         if attn.norm_added_k is not None:
+             encoder_hidden_states_key_proj = attn.norm_added_k(
+                 encoder_hidden_states_key_proj
+             )
+
+         # attention
+         query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+         key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+         value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+     if image_rotary_emb is not None:
+         query = apply_rotary_emb(query, image_rotary_emb)
+         key = apply_rotary_emb(key, image_rotary_emb)
+
+     if condition_latents is not None:
+         cond_query = attn.to_q(condition_latents)
+         cond_key = attn.to_k(condition_latents)
+         cond_value = attn.to_v(condition_latents)
+
+         cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
+             1, 2
+         )
+         cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
+             1, 2
+         )
+         if attn.norm_q is not None:
+             cond_query = attn.norm_q(cond_query)
+         if attn.norm_k is not None:
+             cond_key = attn.norm_k(cond_key)
+
+     if cond_rotary_emb is not None:
+         cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
+         cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
+
+     if condition_latents is not None:
+         # Append the condition tokens after the existing tokens.
+         query = torch.cat([query, cond_query], dim=2)
+         key = torch.cat([key, cond_key], dim=2)
+         value = torch.cat([value, cond_value], dim=2)
+
+     if not model_config.get("union_cond_attn", True):
+         # If we don't want to use the union condition attention, we need to mask the attention
+         # between the hidden states and the condition latents
+         attention_mask = torch.ones(
+             query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
+         )
+         condition_n = cond_query.shape[2]
+         attention_mask[-condition_n:, :-condition_n] = False
+         attention_mask[:-condition_n, -condition_n:] = False
+     if hasattr(attn, "c_factor"):
+         # Additive bias log(c_factor) on attention between condition tokens and the rest.
+         attention_mask = torch.zeros(
+             query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
+         )
+         condition_n = cond_query.shape[2]
+         bias = torch.log(attn.c_factor[0])
+         attention_mask[-condition_n:, :-condition_n] = bias
+         attention_mask[:-condition_n, -condition_n:] = bias
+     hidden_states = F.scaled_dot_product_attention(
+         query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
125
+ )
126
+ hidden_states = hidden_states.transpose(1, 2).reshape(
127
+ batch_size, -1, attn.heads * head_dim
128
+ )
129
+ hidden_states = hidden_states.to(query.dtype)
130
+
131
+ if encoder_hidden_states is not None:
132
+ if condition_latents is not None:
133
+ encoder_hidden_states, hidden_states, condition_latents = (
134
+ hidden_states[:, : encoder_hidden_states.shape[1]],
135
+ hidden_states[
136
+ :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
137
+ ],
138
+ hidden_states[:, -condition_latents.shape[1] :],
139
+ )
140
+ else:
141
+ encoder_hidden_states, hidden_states = (
142
+ hidden_states[:, : encoder_hidden_states.shape[1]],
143
+ hidden_states[:, encoder_hidden_states.shape[1] :],
144
+ )
145
+
146
+ with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
147
+ # linear proj
148
+ hidden_states = attn.to_out[0](hidden_states)
149
+ # dropout
150
+ hidden_states = attn.to_out[1](hidden_states)
151
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
152
+
153
+ if condition_latents is not None:
154
+ condition_latents = attn.to_out[0](condition_latents)
155
+ condition_latents = attn.to_out[1](condition_latents)
156
+
157
+ return (
158
+ (hidden_states, encoder_hidden_states, condition_latents)
159
+ if condition_latents is not None
160
+ else (hidden_states, encoder_hidden_states)
161
+ )
162
+ elif condition_latents is not None:
163
+ # if there are condition_latents, we need to separate the hidden_states and the condition_latents
164
+ hidden_states, condition_latents = (
165
+ hidden_states[:, : -condition_latents.shape[1]],
166
+ hidden_states[:, -condition_latents.shape[1] :],
167
+ )
168
+ return hidden_states, condition_latents
169
+ else:
170
+ return hidden_states
171
+
172
+
173
+ def block_forward(
174
+ self,
175
+ hidden_states: torch.FloatTensor,
176
+ encoder_hidden_states: torch.FloatTensor,
177
+ condition_latents: torch.FloatTensor,
178
+ temb: torch.FloatTensor,
179
+ cond_temb: torch.FloatTensor,
180
+ cond_rotary_emb=None,
181
+ image_rotary_emb=None,
182
+ model_config: Optional[Dict[str, Any]] = {},
183
+ ):
184
+ use_cond = condition_latents is not None
185
+ with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
186
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
187
+ hidden_states, emb=temb
188
+ )
189
+
190
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
191
+ self.norm1_context(encoder_hidden_states, emb=temb)
192
+ )
193
+
194
+ if use_cond:
195
+ (
196
+ norm_condition_latents,
197
+ cond_gate_msa,
198
+ cond_shift_mlp,
199
+ cond_scale_mlp,
200
+ cond_gate_mlp,
201
+ ) = self.norm1(condition_latents, emb=cond_temb)
202
+
203
+ # Attention.
204
+ result = attn_forward(
205
+ self.attn,
206
+ model_config=model_config,
207
+ hidden_states=norm_hidden_states,
208
+ encoder_hidden_states=norm_encoder_hidden_states,
209
+ condition_latents=norm_condition_latents if use_cond else None,
210
+ image_rotary_emb=image_rotary_emb,
211
+ cond_rotary_emb=cond_rotary_emb if use_cond else None,
212
+ )
213
+ attn_output, context_attn_output = result[:2]
214
+ cond_attn_output = result[2] if use_cond else None
215
+
216
+ # Process attention outputs for the `hidden_states`.
217
+ # 1. hidden_states
218
+ attn_output = gate_msa.unsqueeze(1) * attn_output
219
+ hidden_states = hidden_states + attn_output
220
+ # 2. encoder_hidden_states
221
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
222
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
223
+ # 3. condition_latents
224
+ if use_cond:
225
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
226
+ condition_latents = condition_latents + cond_attn_output
227
+ if model_config.get("add_cond_attn", False):
228
+ hidden_states += cond_attn_output
229
+
230
+ # LayerNorm + MLP.
231
+ # 1. hidden_states
232
+ norm_hidden_states = self.norm2(hidden_states)
233
+ norm_hidden_states = (
234
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
235
+ )
236
+ # 2. encoder_hidden_states
237
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
238
+ norm_encoder_hidden_states = (
239
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
240
+ )
241
+ # 3. condition_latents
242
+ if use_cond:
243
+ norm_condition_latents = self.norm2(condition_latents)
244
+ norm_condition_latents = (
245
+ norm_condition_latents * (1 + cond_scale_mlp[:, None])
246
+ + cond_shift_mlp[:, None]
247
+ )
248
+
249
+ # Feed-forward.
250
+ with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
251
+ # 1. hidden_states
252
+ ff_output = self.ff(norm_hidden_states)
253
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
254
+ # 2. encoder_hidden_states
255
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
256
+ context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
257
+ # 3. condition_latents
258
+ if use_cond:
259
+ cond_ff_output = self.ff(norm_condition_latents)
260
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
261
+
262
+ # Process feed-forward outputs.
263
+ hidden_states = hidden_states + ff_output
264
+ encoder_hidden_states = encoder_hidden_states + context_ff_output
265
+ if use_cond:
266
+ condition_latents = condition_latents + cond_ff_output
267
+
268
+ # Clip to avoid overflow.
269
+ if encoder_hidden_states.dtype == torch.float16:
270
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
271
+
272
+ return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
273
+
274
+
275
+ def single_block_forward(
276
+ self,
277
+ hidden_states: torch.FloatTensor,
278
+ temb: torch.FloatTensor,
279
+ image_rotary_emb=None,
280
+ condition_latents: torch.FloatTensor = None,
281
+ cond_temb: torch.FloatTensor = None,
282
+ cond_rotary_emb=None,
283
+ model_config: Optional[Dict[str, Any]] = {},
284
+ ):
285
+
286
+ using_cond = condition_latents is not None
287
+ residual = hidden_states
288
+ with enable_lora(
289
+ (
290
+ self.norm.linear,
291
+ self.proj_mlp,
292
+ ),
293
+ model_config.get("latent_lora", False),
294
+ ):
295
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
296
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
297
+ if using_cond:
298
+ residual_cond = condition_latents
299
+ norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
300
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
301
+
302
+ attn_output = attn_forward(
303
+ self.attn,
304
+ model_config=model_config,
305
+ hidden_states=norm_hidden_states,
306
+ image_rotary_emb=image_rotary_emb,
307
+ **(
308
+ {
309
+ "condition_latents": norm_condition_latents,
310
+ "cond_rotary_emb": cond_rotary_emb if using_cond else None,
311
+ }
312
+ if using_cond
313
+ else {}
314
+ ),
315
+ )
316
+ if using_cond:
317
+ attn_output, cond_attn_output = attn_output
318
+
319
+ with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
320
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
321
+ gate = gate.unsqueeze(1)
322
+ hidden_states = gate * self.proj_out(hidden_states)
323
+ hidden_states = residual + hidden_states
324
+ if using_cond:
325
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
326
+ cond_gate = cond_gate.unsqueeze(1)
327
+ condition_latents = cond_gate * self.proj_out(condition_latents)
328
+ condition_latents = residual_cond + condition_latents
329
+
330
+ if hidden_states.dtype == torch.float16:
331
+ hidden_states = hidden_states.clip(-65504, 65504)
332
+
333
+ return hidden_states if not using_cond else (hidden_states, condition_latents)
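The `c_factor` branch in `attn_forward` above adds `log(c_factor)` to the attention logits that connect condition tokens with the remaining tokens. The sketch below is not the repository's code: it uses plain Python instead of `torch`, and the helper names `softmax` and `biased_attention_weights` are hypothetical. It models a single query row to show why a log-space bias rescales the unnormalized attention weight of the biased keys by `c_factor`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def biased_attention_weights(logits, is_condition, c_factor):
    """Add log(c_factor) to the logits of condition keys, then normalize.

    Hypothetical helper: it mirrors the additive float mask handed to
    scaled dot-product attention, but for one query row only.
    """
    bias = math.log(c_factor)
    shifted = [x + bias if cond else x for x, cond in zip(logits, is_condition)]
    return softmax(shifted)


# One image-token query attending to two image keys and two condition keys.
logits = [0.0, 0.0, 0.0, 0.0]
is_condition = [False, False, True, True]
weights = biased_attention_weights(logits, is_condition, c_factor=2.0)
```

With equal logits, each condition key ends up with twice the weight of each image key, and the weights still sum to one. A `c_factor` below 1 would instead weaken the condition's influence.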
src/condition.py ADDED
@@ -0,0 +1,124 @@
1
+ import torch
2
+ from typing import Optional, Union, List, Tuple
3
+ from diffusers.pipelines import FluxPipeline
4
+ from PIL import Image, ImageFilter
5
+ import numpy as np
6
+ import cv2
7
+
8
+ condition_dict = {
9
+ "depth": 0,
10
+ "canny": 1,
11
+ "subject": 4,
12
+ "coloring": 6,
13
+ "deblurring": 7,
14
+ "fill": 9,
15
+ }
16
+
17
+
18
+ class Condition(object):
19
+ def __init__(
20
+ self,
21
+ condition_type: str,
22
+ raw_img: Union[Image.Image, torch.Tensor] = None,
23
+ condition: Union[Image.Image, torch.Tensor] = None,
24
+ mask=None,
25
+ ) -> None:
26
+ self.condition_type = condition_type
27
+ assert raw_img is not None or condition is not None
28
+ if raw_img is not None:
29
+ self.condition = self.get_condition(condition_type, raw_img)
30
+ else:
31
+ self.condition = condition
32
+ # TODO: Add mask support
33
+ assert mask is None, "Mask not supported yet"
34
+
35
+ def get_condition(
36
+ self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
37
+ ) -> Union[Image.Image, torch.Tensor]:
38
+ """
39
+ Returns the condition image.
40
+ """
41
+ if condition_type == "depth":
42
+ from transformers import pipeline
43
+
44
+ depth_pipe = pipeline(
45
+ task="depth-estimation",
46
+ model="LiheYoung/depth-anything-small-hf",
47
+ device="cuda",
48
+ )
49
+ source_image = raw_img.convert("RGB")
50
+ condition_img = depth_pipe(source_image)["depth"].convert("RGB")
51
+ return condition_img
52
+ elif condition_type == "canny":
53
+ img = np.array(raw_img)
54
+ edges = cv2.Canny(img, 100, 200)
55
+ edges = Image.fromarray(edges).convert("RGB")
56
+ return edges
57
+ elif condition_type == "subject":
58
+ return raw_img
59
+ elif condition_type == "coloring":
60
+ return raw_img.convert("L").convert("RGB")
61
+ elif condition_type == "deblurring":
62
+ condition_image = (
63
+ raw_img.convert("RGB")
64
+ .filter(ImageFilter.GaussianBlur(10))
65
+ .convert("RGB")
66
+ )
67
+ return condition_image
68
+ elif condition_type == "fill":
69
+ return raw_img.convert("RGB")
70
+ return self.condition
71
+
72
+ @property
73
+ def type_id(self) -> int:
74
+ """
75
+ Returns the type id of the condition.
76
+ """
77
+ return condition_dict[self.condition_type]
78
+
79
+ @classmethod
80
+ def get_type_id(cls, condition_type: str) -> int:
81
+ """
82
+ Returns the type id of the condition.
83
+ """
84
+ return condition_dict[condition_type]
85
+
86
+ def _encode_image(self, pipe: FluxPipeline, cond_img: Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
87
+ """
88
+ Encodes an image condition into tokens using the pipeline.
89
+ """
90
+ cond_img = pipe.image_processor.preprocess(cond_img)
91
+ cond_img = cond_img.to(pipe.device).to(pipe.dtype)
92
+ cond_img = pipe.vae.encode(cond_img).latent_dist.sample()
93
+ cond_img = (
94
+ cond_img - pipe.vae.config.shift_factor
95
+ ) * pipe.vae.config.scaling_factor
96
+ cond_tokens = pipe._pack_latents(cond_img, *cond_img.shape)
97
+ cond_ids = pipe._prepare_latent_image_ids(
98
+ cond_img.shape[0],
99
+ cond_img.shape[2],
100
+ cond_img.shape[3],
101
+ pipe.device,
102
+ pipe.dtype,
103
+ )
104
+ return cond_tokens, cond_ids
105
+
106
+ def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]:
107
+ """
108
+ Encodes the condition into tokens, ids and type_id.
109
+ """
110
+ if self.condition_type in [
111
+ "depth",
112
+ "canny",
113
+ "subject",
114
+ "coloring",
115
+ "deblurring",
116
+ "fill",
117
+ ]:
118
+ tokens, ids = self._encode_image(pipe, self.condition)
119
+ else:
120
+ raise NotImplementedError(
121
+ f"Condition type {self.condition_type} not implemented"
122
+ )
123
+ type_id = torch.ones_like(ids[:, :1]) * self.type_id
124
+ return tokens, ids, type_id
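`_encode_image` above normalizes the VAE latents as `(latent - shift_factor) * scaling_factor`, and the decoding step in `src/generate.py` applies the inverse before calling the VAE decoder. The sketch below is an illustration in plain Python: the function names are hypothetical and the numeric `SHIFT` and `SCALE` values are made up, not the actual FLUX VAE configuration. It only checks that the two formulas are inverses of each other.

```python
def normalize_latent(value, shift_factor, scaling_factor):
    """Forward mapping used when encoding a condition image."""
    return (value - shift_factor) * scaling_factor


def denormalize_latent(value, shift_factor, scaling_factor):
    """Inverse mapping used before decoding latents back to pixels."""
    return value / scaling_factor + shift_factor


# Illustrative values only; the real ones come from `pipe.vae.config`.
SHIFT, SCALE = 0.1, 0.4

raw = [-1.5, 0.0, 0.25, 2.0]
normalized = [normalize_latent(v, SHIFT, SCALE) for v in raw]
restored = [denormalize_latent(v, SHIFT, SCALE) for v in normalized]
```

Keeping both directions next to each other makes it harder to swap the order of the shift and the scale by accident, which would silently distort the condition tokens.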
src/generate.py ADDED
@@ -0,0 +1,286 @@
1
+ import torch
2
+ import yaml, os
3
+ from diffusers.pipelines import FluxPipeline
4
+ from typing import List, Union, Optional, Dict, Any, Callable
5
+ from .transformer import tranformer_forward
6
+ from .condition import Condition
7
+
8
+ from diffusers.pipelines.flux.pipeline_flux import (
9
+ FluxPipelineOutput,
10
+ calculate_shift,
11
+ retrieve_timesteps,
12
+ np,
13
+ )
14
+
15
+
16
+ def prepare_params(
17
+ prompt: Union[str, List[str]] = None,
18
+ prompt_2: Optional[Union[str, List[str]]] = None,
19
+ height: Optional[int] = 512,
20
+ width: Optional[int] = 512,
21
+ num_inference_steps: int = 28,
22
+ timesteps: List[int] = None,
23
+ guidance_scale: float = 3.5,
24
+ num_images_per_prompt: Optional[int] = 1,
25
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
26
+ latents: Optional[torch.FloatTensor] = None,
27
+ prompt_embeds: Optional[torch.FloatTensor] = None,
28
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
29
+ output_type: Optional[str] = "pil",
30
+ return_dict: bool = True,
31
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
32
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
33
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
34
+ max_sequence_length: int = 512,
35
+ **kwargs: dict,
36
+ ):
37
+ return (
38
+ prompt,
39
+ prompt_2,
40
+ height,
41
+ width,
42
+ num_inference_steps,
43
+ timesteps,
44
+ guidance_scale,
45
+ num_images_per_prompt,
46
+ generator,
47
+ latents,
48
+ prompt_embeds,
49
+ pooled_prompt_embeds,
50
+ output_type,
51
+ return_dict,
52
+ joint_attention_kwargs,
53
+ callback_on_step_end,
54
+ callback_on_step_end_tensor_inputs,
55
+ max_sequence_length,
56
+ )
57
+
58
+
59
+ def seed_everything(seed: int = 42):
60
+ torch.backends.cudnn.deterministic = True
61
+ torch.manual_seed(seed)
62
+ np.random.seed(seed)
63
+
64
+
65
+ @torch.no_grad()
66
+ def generate(
67
+ pipeline: FluxPipeline,
68
+ conditions: List[Condition] = None,
69
+ model_config: Optional[Dict[str, Any]] = {},
70
+ condition_scale: float = 1.0,
71
+ **params: dict,
72
+ ):
73
+ # model_config = model_config or get_config(config_path).get("model", {})
74
+ if condition_scale != 1:
75
+ for name, module in pipeline.transformer.named_modules():
76
+ if not name.endswith(".attn"):
77
+ continue
78
+ module.c_factor = torch.ones(1, 1) * condition_scale
79
+
80
+ self = pipeline
81
+ (
82
+ prompt,
83
+ prompt_2,
84
+ height,
85
+ width,
86
+ num_inference_steps,
87
+ timesteps,
88
+ guidance_scale,
89
+ num_images_per_prompt,
90
+ generator,
91
+ latents,
92
+ prompt_embeds,
93
+ pooled_prompt_embeds,
94
+ output_type,
95
+ return_dict,
96
+ joint_attention_kwargs,
97
+ callback_on_step_end,
98
+ callback_on_step_end_tensor_inputs,
99
+ max_sequence_length,
100
+ ) = prepare_params(**params)
101
+
102
+ height = height or self.default_sample_size * self.vae_scale_factor
103
+ width = width or self.default_sample_size * self.vae_scale_factor
104
+
105
+ # 1. Check inputs. Raise error if not correct
106
+ self.check_inputs(
107
+ prompt,
108
+ prompt_2,
109
+ height,
110
+ width,
111
+ prompt_embeds=prompt_embeds,
112
+ pooled_prompt_embeds=pooled_prompt_embeds,
113
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
114
+ max_sequence_length=max_sequence_length,
115
+ )
116
+
117
+ self._guidance_scale = guidance_scale
118
+ self._joint_attention_kwargs = joint_attention_kwargs
119
+ self._interrupt = False
120
+
121
+ # 2. Define call parameters
122
+ if prompt is not None and isinstance(prompt, str):
123
+ batch_size = 1
124
+ elif prompt is not None and isinstance(prompt, list):
125
+ batch_size = len(prompt)
126
+ else:
127
+ batch_size = prompt_embeds.shape[0]
128
+
129
+ device = self._execution_device
130
+
131
+ lora_scale = (
132
+ self.joint_attention_kwargs.get("scale", None)
133
+ if self.joint_attention_kwargs is not None
134
+ else None
135
+ )
136
+ (
137
+ prompt_embeds,
138
+ pooled_prompt_embeds,
139
+ text_ids,
140
+ ) = self.encode_prompt(
141
+ prompt=prompt,
142
+ prompt_2=prompt_2,
143
+ prompt_embeds=prompt_embeds,
144
+ pooled_prompt_embeds=pooled_prompt_embeds,
145
+ device=device,
146
+ num_images_per_prompt=num_images_per_prompt,
147
+ max_sequence_length=max_sequence_length,
148
+ lora_scale=lora_scale,
149
+ )
150
+
151
+ # 4. Prepare latent variables
152
+ num_channels_latents = self.transformer.config.in_channels // 4
153
+ latents, latent_image_ids = self.prepare_latents(
154
+ batch_size * num_images_per_prompt,
155
+ num_channels_latents,
156
+ height,
157
+ width,
158
+ prompt_embeds.dtype,
159
+ device,
160
+ generator,
161
+ latents,
162
+ )
163
+
164
+ # 4.1. Prepare conditions
165
+ condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
166
+ use_condition = bool(conditions)
167
+ if use_condition:
168
+ assert len(conditions) <= 1, "Only one condition is supported for now."
169
+ pipeline.set_adapters(conditions[0].condition_type)
170
+ for condition in conditions:
171
+ tokens, ids, type_id = condition.encode(self)
172
+ condition_latents.append(tokens) # [batch_size, token_n, token_dim]
173
+ condition_ids.append(ids) # [token_n, id_dim(3)]
174
+ condition_type_ids.append(type_id) # [token_n, 1]
175
+ condition_latents = torch.cat(condition_latents, dim=1)
176
+ condition_ids = torch.cat(condition_ids, dim=0)
177
+ if condition.condition_type == "subject":
178
+ condition_ids[:, 2] += width // 16
179
+ condition_type_ids = torch.cat(condition_type_ids, dim=0)
180
+
181
+ # 5. Prepare timesteps
182
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
183
+ image_seq_len = latents.shape[1]
184
+ mu = calculate_shift(
185
+ image_seq_len,
186
+ self.scheduler.config.base_image_seq_len,
187
+ self.scheduler.config.max_image_seq_len,
188
+ self.scheduler.config.base_shift,
189
+ self.scheduler.config.max_shift,
190
+ )
191
+ timesteps, num_inference_steps = retrieve_timesteps(
192
+ self.scheduler,
193
+ num_inference_steps,
194
+ device,
195
+ timesteps,
196
+ sigmas,
197
+ mu=mu,
198
+ )
199
+ num_warmup_steps = max(
200
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
201
+ )
202
+ self._num_timesteps = len(timesteps)
203
+
204
+ # 6. Denoising loop
205
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
206
+ for i, t in enumerate(timesteps):
207
+ if self.interrupt:
208
+ continue
209
+
210
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
211
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
212
+
213
+ # handle guidance
214
+ if self.transformer.config.guidance_embeds:
215
+ guidance = torch.tensor([guidance_scale], device=device)
216
+ guidance = guidance.expand(latents.shape[0])
217
+ else:
218
+ guidance = None
219
+ noise_pred = tranformer_forward(
220
+ self.transformer,
221
+ model_config=model_config,
222
+ # Inputs of the condition (new feature)
223
+ condition_latents=condition_latents if use_condition else None,
224
+ condition_ids=condition_ids if use_condition else None,
225
+ condition_type_ids=condition_type_ids if use_condition else None,
226
+ # Inputs to the original transformer
227
+ hidden_states=latents,
228
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
229
+ timestep=timestep / 1000,
230
+ guidance=guidance,
231
+ pooled_projections=pooled_prompt_embeds,
232
+ encoder_hidden_states=prompt_embeds,
233
+ txt_ids=text_ids,
234
+ img_ids=latent_image_ids,
235
+ joint_attention_kwargs=self.joint_attention_kwargs,
236
+ return_dict=False,
237
+ )[0]
238
+
239
+ # compute the previous noisy sample x_t -> x_t-1
240
+ latents_dtype = latents.dtype
241
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
242
+
243
+ if latents.dtype != latents_dtype:
244
+ if torch.backends.mps.is_available():
245
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
246
+ latents = latents.to(latents_dtype)
247
+
248
+ if callback_on_step_end is not None:
249
+ callback_kwargs = {}
250
+ for k in callback_on_step_end_tensor_inputs:
251
+ callback_kwargs[k] = locals()[k]
252
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
253
+
254
+ latents = callback_outputs.pop("latents", latents)
255
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
256
+
257
+ # update the progress bar
258
+ if i == len(timesteps) - 1 or (
259
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
260
+ ):
261
+ progress_bar.update()
262
+
263
+ if output_type == "latent":
264
+ image = latents
265
+
266
+ else:
267
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
268
+ latents = (
269
+ latents / self.vae.config.scaling_factor
270
+ ) + self.vae.config.shift_factor
271
+ image = self.vae.decode(latents, return_dict=False)[0]
272
+ image = self.image_processor.postprocess(image, output_type=output_type)
273
+
274
+ # Offload all models
275
+ self.maybe_free_model_hooks()
276
+
277
+ if condition_scale != 1:
278
+ for name, module in pipeline.transformer.named_modules():
279
+ if not name.endswith(".attn"):
280
+ continue
281
+ del module.c_factor
282
+
283
+ if not return_dict:
284
+ return (image,)
285
+
286
+ return FluxPipelineOutput(images=image)
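In `generate` above, a `"subject"` condition has the last column of its position ids shifted by `width // 16`. The sketch below is an assumption-laden illustration, not the pipeline's code: it assumes each token id is a `(0, row, col)` triple and that a `width`-pixel image yields `width // 16` tokens per side. The helpers `grid_ids` and `shift_columns` are hypothetical. Under those assumptions the shift places the subject tokens directly to the right of the image grid, so no condition token shares a position with an image token.

```python
def grid_ids(rows, cols):
    """Build (0, row, col) position ids for a rows x cols token grid."""
    return [(0, r, c) for r in range(rows) for c in range(cols)]


def shift_columns(ids, offset):
    """Shift the column coordinate of every id by `offset`."""
    return [(t, r, c + offset) for (t, r, c) in ids]


width = 512
tokens_per_side = width // 16  # assumed: one token per 16 pixels

image_ids = grid_ids(tokens_per_side, tokens_per_side)
subject_ids = shift_columns(
    grid_ids(tokens_per_side, tokens_per_side), width // 16
)

# Positions occupied by both the image and the shifted subject condition.
overlap = set(image_ids) & set(subject_ids)
```

Spatially aligned conditions such as `"canny"` or `"depth"` keep their ids unshifted in the code above, which lets each condition token share the position of the image token it controls.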
src/lora_controller.py ADDED
@@ -0,0 +1,75 @@
1
+ from peft.tuners.tuners_utils import BaseTunerLayer
2
+ from typing import List, Any, Optional, Type
3
+
4
+
5
+ class enable_lora:
6
+ def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
7
+ self.activated: bool = activated
8
+ if activated:
9
+ return
10
+ self.lora_modules: List[BaseTunerLayer] = [
11
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
12
+ ]
13
+ self.scales = [
14
+ {
15
+ active_adapter: lora_module.scaling[active_adapter]
16
+ for active_adapter in lora_module.active_adapters
17
+ }
18
+ for lora_module in self.lora_modules
19
+ ]
20
+
21
+ def __enter__(self) -> None:
22
+ if self.activated:
23
+ return
24
+
25
+ for lora_module in self.lora_modules:
26
+ if not isinstance(lora_module, BaseTunerLayer):
27
+ continue
28
+ lora_module.scale_layer(0)
29
+
30
+ def __exit__(
31
+ self,
32
+ exc_type: Optional[Type[BaseException]],
33
+ exc_val: Optional[BaseException],
34
+ exc_tb: Optional[Any],
35
+ ) -> None:
36
+ if self.activated:
37
+ return
38
+ for i, lora_module in enumerate(self.lora_modules):
39
+ if not isinstance(lora_module, BaseTunerLayer):
40
+ continue
41
+ for active_adapter in lora_module.active_adapters:
42
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
43
+
44
+
45
+ class set_lora_scale:
46
+ def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
47
+ self.lora_modules: List[BaseTunerLayer] = [
48
+ each for each in lora_modules if isinstance(each, BaseTunerLayer)
49
+ ]
50
+ self.scales = [
51
+ {
52
+ active_adapter: lora_module.scaling[active_adapter]
53
+ for active_adapter in lora_module.active_adapters
54
+ }
55
+ for lora_module in self.lora_modules
56
+ ]
57
+ self.scale = scale
58
+
59
+ def __enter__(self) -> None:
60
+ for lora_module in self.lora_modules:
61
+ if not isinstance(lora_module, BaseTunerLayer):
62
+ continue
63
+ lora_module.scale_layer(self.scale)
64
+
65
+ def __exit__(
66
+ self,
67
+ exc_type: Optional[Type[BaseException]],
68
+ exc_val: Optional[BaseException],
69
+ exc_tb: Optional[Any],
70
+ ) -> None:
71
+ for i, lora_module in enumerate(self.lora_modules):
72
+ if not isinstance(lora_module, BaseTunerLayer):
73
+ continue
74
+ for active_adapter in lora_module.active_adapters:
75
+ lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
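Both context managers above follow the same pattern: record each adapter's scaling factor, rescale the layer on entry, and write the recorded values back on exit. The sketch below reproduces that pattern without `peft`. `FakeLoraLayer` is a hypothetical stand-in whose `scale_layer` multiplies the current scaling, and `temporarily_disable` is an illustrative name rather than part of the repository.

```python
class FakeLoraLayer:
    """Hypothetical stand-in for a PEFT tuner layer (illustration only)."""

    def __init__(self, scaling):
        self.scaling = dict(scaling)
        self.active_adapters = list(scaling)

    def scale_layer(self, factor):
        # Multiply the scaling of every active adapter by `factor`.
        for name in self.active_adapters:
            self.scaling[name] *= factor


class temporarily_disable:
    """Zero the LoRA scaling inside the block, restore it afterwards."""

    def __init__(self, layers):
        self.layers = list(layers)
        # Snapshot the scales up front so they can be restored exactly.
        self.saved = [dict(layer.scaling) for layer in self.layers]

    def __enter__(self):
        for layer in self.layers:
            layer.scale_layer(0)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for layer, saved in zip(self.layers, self.saved):
            layer.scaling.update(saved)
        return False  # never swallow exceptions


layer = FakeLoraLayer({"subject": 0.5})
with temporarily_disable([layer]):
    scale_inside = layer.scaling["subject"]
scale_after = layer.scaling["subject"]
```

Restoring from a snapshot rather than dividing by the factor is what makes a scale of 0 reversible, and it is the same choice the `__exit__` methods above make.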
src/transformer.py ADDED
@@ -0,0 +1,270 @@
1
+ import torch
2
+ from diffusers.pipelines import FluxPipeline
3
+ from typing import List, Union, Optional, Dict, Any, Callable
4
+ from .block import block_forward, single_block_forward
5
+ from .lora_controller import enable_lora
6
+ from diffusers.models.transformers.transformer_flux import (
7
+ FluxTransformer2DModel,
8
+ Transformer2DModelOutput,
9
+ USE_PEFT_BACKEND,
10
+ is_torch_version,
11
+ scale_lora_layers,
12
+ unscale_lora_layers,
13
+ logger,
14
+ )
15
+ import numpy as np
16
+
17
+
18
+ def prepare_params(
19
+ hidden_states: torch.Tensor,
20
+ encoder_hidden_states: torch.Tensor = None,
21
+ pooled_projections: torch.Tensor = None,
22
+ timestep: torch.LongTensor = None,
23
+ img_ids: torch.Tensor = None,
24
+ txt_ids: torch.Tensor = None,
25
+ guidance: torch.Tensor = None,
26
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
27
+ controlnet_block_samples=None,
28
+ controlnet_single_block_samples=None,
29
+ return_dict: bool = True,
30
+ **kwargs: dict,
31
+ ):
32
+ return (
33
+ hidden_states,
34
+ encoder_hidden_states,
35
+ pooled_projections,
36
+ timestep,
37
+ img_ids,
38
+ txt_ids,
39
+ guidance,
40
+ joint_attention_kwargs,
41
+ controlnet_block_samples,
42
+ controlnet_single_block_samples,
43
+ return_dict,
44
+ )
45
+
46
+
47
+ def tranformer_forward(
48
+ transformer: FluxTransformer2DModel,
49
+ condition_latents: torch.Tensor,
50
+ condition_ids: torch.Tensor,
51
+ condition_type_ids: torch.Tensor,
52
+ model_config: Optional[Dict[str, Any]] = {},
53
+ return_conditional_latents: bool = False,
54
+ c_t=0,
55
+ **params: dict,
56
+ ):
57
+ self = transformer
58
+ use_condition = condition_latents is not None
59
+ use_condition_in_single_blocks = model_config.get(
60
+ "use_condition_in_single_blocks", True
61
+ )
62
+ # if return_conditional_latents is True, use_condition and use_condition_in_single_blocks must be True
63
+ assert not return_conditional_latents or (
64
+ use_condition and use_condition_in_single_blocks
65
+ ), "If `return_conditional_latents` is True, `use_condition` and `use_condition_in_single_blocks` must also be True"
66
+
67
+ (
68
+ hidden_states,
69
+ encoder_hidden_states,
70
+ pooled_projections,
71
+ timestep,
72
+ img_ids,
73
+ txt_ids,
74
+ guidance,
75
+ joint_attention_kwargs,
76
+ controlnet_block_samples,
77
+ controlnet_single_block_samples,
78
+ return_dict,
79
+ ) = prepare_params(**params)
80
+
81
+ if joint_attention_kwargs is not None:
82
+ joint_attention_kwargs = joint_attention_kwargs.copy()
83
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
84
+ else:
85
+ lora_scale = 1.0
86
+
87
+ if USE_PEFT_BACKEND:
88
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
89
+ scale_lora_layers(self, lora_scale)
90
+ else:
91
+ if (
92
+ joint_attention_kwargs is not None
93
+ and joint_attention_kwargs.get("scale", None) is not None
94
+ ):
95
+ logger.warning(
96
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
97
+ )
98
+ with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
99
+ hidden_states = self.x_embedder(hidden_states)
100
+ condition_latents = self.x_embedder(condition_latents) if use_condition else None
101
+
102
+ timestep = timestep.to(hidden_states.dtype) * 1000
103
+ if guidance is not None:
104
+ guidance = guidance.to(hidden_states.dtype) * 1000
105
+ else:
106
+ guidance = None
107
+ temb = (
108
+ self.time_text_embed(timestep, pooled_projections)
109
+ if guidance is None
110
+ else self.time_text_embed(timestep, guidance, pooled_projections)
111
+ )
112
+ cond_temb = (
113
+ self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
114
+ if guidance is None
115
+ else self.time_text_embed(
116
+ torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections
117
+ )
118
+ )
119
+ if hasattr(self, "cond_type_embed") and condition_type_ids is not None:
120
+ cond_type_proj = self.time_text_embed.time_proj(condition_type_ids[0])
121
+ cond_type_emb = self.cond_type_embed(cond_type_proj.to(dtype=cond_temb.dtype))
122
+ cond_temb = cond_temb + cond_type_emb
123
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
124
+
125
+ if txt_ids.ndim == 3:
126
+ logger.warning(
127
+ "Passing `txt_ids` 3d torch.Tensor is deprecated. "
128
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
129
+ )
130
+ txt_ids = txt_ids[0]
131
+ if img_ids.ndim == 3:
132
+ logger.warning(
133
+ "Passing `img_ids` 3d torch.Tensor is deprecated. "
134
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
135
+ )
136
+ img_ids = img_ids[0]
137
+
+    ids = torch.cat((txt_ids, img_ids), dim=0)
+    image_rotary_emb = self.pos_embed(ids)
+    if use_condition:
+        cond_ids = condition_ids
+        cond_rotary_emb = self.pos_embed(cond_ids)
+
+    # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
+
+    for index_block, block in enumerate(self.transformer_blocks):
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = (
+                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+            )
+            # NOTE: this checkpointed path calls the block directly rather than
+            # `block_forward`, so the condition tensors are not used here.
+            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(block),
+                hidden_states,
+                encoder_hidden_states,
+                temb,
+                image_rotary_emb,
+                **ckpt_kwargs,
+            )
+
+        else:
+            encoder_hidden_states, hidden_states, condition_latents = block_forward(
+                block,
+                model_config=model_config,
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                condition_latents=condition_latents if use_condition else None,
+                temb=temb,
+                cond_temb=cond_temb if use_condition else None,
+                cond_rotary_emb=cond_rotary_emb if use_condition else None,
+                image_rotary_emb=image_rotary_emb,
+            )
+
+        # controlnet residual
+        if controlnet_block_samples is not None:
+            interval_control = len(self.transformer_blocks) / len(
+                controlnet_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            hidden_states = (
+                hidden_states
+                + controlnet_block_samples[index_block // interval_control]
+            )
+    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+    for index_block, block in enumerate(self.single_transformer_blocks):
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = (
+                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+            )
+            # NOTE: this checkpointed path calls the block directly rather than
+            # `single_block_forward`, so the condition tensors are not used here.
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(block),
+                hidden_states,
+                temb,
+                image_rotary_emb,
+                **ckpt_kwargs,
+            )
+
+        else:
+            result = single_block_forward(
+                block,
+                model_config=model_config,
+                hidden_states=hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                **(
+                    {
+                        "condition_latents": condition_latents,
+                        "cond_temb": cond_temb,
+                        "cond_rotary_emb": cond_rotary_emb,
+                    }
+                    if use_condition_in_single_blocks and use_condition
+                    else {}
+                ),
+            )
+            if use_condition_in_single_blocks and use_condition:
+                hidden_states, condition_latents = result
+            else:
+                hidden_states = result
+
+        # controlnet residual
+        if controlnet_single_block_samples is not None:
+            interval_control = len(self.single_transformer_blocks) / len(
+                controlnet_single_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                + controlnet_single_block_samples[index_block // interval_control]
+            )
+
+    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+    hidden_states = self.norm_out(hidden_states, temb)
+    output = self.proj_out(hidden_states)
+    if return_conditional_latents:
+        condition_latents = (
+            self.norm_out(condition_latents, cond_temb) if use_condition else None
+        )
+        condition_output = self.proj_out(condition_latents) if use_condition else None
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (
+            (output,) if not return_conditional_latents else (output, condition_output)
+        )
+
+    return Transformer2DModelOutput(sample=output)
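
The two `# controlnet residual` steps in the function above spread a short list of controlnet samples across a longer stack of transformer blocks. Below is a minimal sketch of that index mapping alone, using plain integers instead of tensors. The helper name `controlnet_sample_index` and the block and sample counts are illustrative assumptions, not part of the commit, and `math.ceil` stands in for `np.ceil`.

```python
import math


def controlnet_sample_index(index_block: int, num_blocks: int, num_samples: int) -> int:
    """Return the index of the controlnet sample added to a given block.

    Hypothetical helper: it mirrors the `interval_control` arithmetic used in
    the loops above, with math.ceil standing in for np.ceil.
    """
    # Number of consecutive blocks that share one controlnet sample (rounded up).
    interval_control = int(math.ceil(num_blocks / num_samples))
    return index_block // interval_control


# Illustrative sizes only: 19 blocks sharing 4 controlnet samples.
mapping = [controlnet_sample_index(i, 19, 4) for i in range(19)]
print(mapping)
```

Because the interval is rounded up, consecutive blocks share one sample and the computed index never runs past the last sample in the list.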