ovi054 committed
Commit 449d6db · 0 Parent(s)

Initial commit with LFS
.gitattributes ADDED
@@ -0,0 +1,43 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ cat.png filter=lfs diff=lfs merge=lfs -text
+ flowers.png filter=lfs diff=lfs merge=lfs -text
+ monster.png filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: FLUX.1 Kontext
+ emoji: ⚡
+ colorFrom: green
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 5.34.0
+ app_file: app.py
+ pinned: true
+ license: mit
+ short_description: 'Kontext image editing on FLUX[dev]'
+ ---
+ 
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,345 @@
+ # PyTorch 2.8 (temporary hack)
+ import os
+ os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
+ 
+ # Actual demo code
+ import gradio as gr
+ import numpy as np
+ import spaces
+ import torch
+ import random
+ from PIL import Image, ImageOps
+ 
+ from diffusers import FluxKontextPipeline
+ from diffusers.utils import load_image
+ 
+ # from optimization import optimize_pipeline_
+ 
+ MAX_SEED = np.iinfo(np.int32).max
+ 
+ pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
+ pipe.load_lora_weights("ovi054/Draw2Photo")
+ pipe.fuse_lora()
+ # optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')
+ 
+ 
+ def add_overlay(base_img, overlay_img, margin=20):
+     """
+     Pastes an overlay image onto the top-right corner of a base image.
+ 
+     The overlay is resized to 1/5th of the width of the base image,
+     maintaining its aspect ratio.
+ 
+     Args:
+         base_img (PIL.Image.Image): The main image.
+         overlay_img (PIL.Image.Image): The image to place on top.
+         margin (int, optional): The pixel margin from the top and right edges. Defaults to 20.
+ 
+     Returns:
+         PIL.Image.Image: The combined image.
+     """
+     if base_img is None or overlay_img is None:
+         return base_img
+ 
+     base = base_img.convert("RGBA")
+     overlay = overlay_img.convert("RGBA")
+ 
+     # Calculate the target width to be 1/5th of the base image's width
+     target_width = base.width // 5
+ 
+     # Keep aspect ratio, resize overlay to the newly calculated target width
+     w, h = overlay.size
+ 
+     # Prevent division by zero if the overlay image has no width
+     if w == 0:
+         return base
+ 
+     new_height = int(h * (target_width / w))
+     overlay = overlay.resize((target_width, new_height), Image.LANCZOS)
+ 
+     # Position: top-right corner with a margin
+     x = base.width - overlay.width - margin
+     y = margin
+ 
+     # Paste the resized overlay onto the base image using its alpha channel for transparency
+     base.paste(overlay, (x, y), overlay)
+     return base
+ 
+ 
+ # def add_overlay(base_img, overlay_img, margin=20, target_width=200):
+ #     if base_img is None or overlay_img is None:
+ #         return base_img
+ 
+ #     base = base_img.convert("RGBA")
+ #     overlay = overlay_img.convert("RGBA")
+ 
+ #     # Keep aspect ratio, resize overlay to target width
+ #     w, h = overlay.size
+ #     new_height = int(h * (target_width / w))
+ #     overlay = overlay.resize((target_width, new_height), Image.LANCZOS)
+ 
+ #     # Position: top-right with margin
+ #     x = base.width - overlay.width - margin
+ #     y = margin
+ 
+ #     # Paste overlay on base with transparency
+ #     base.paste(overlay, (x, y), overlay)
+ #     return base
+ 
+ 
+ # @spaces.GPU
+ # def infer(input_image, input_image_upload, overlay_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
+ #     """
+ #     Perform image editing using the FLUX.1 Kontext pipeline.
+ 
+ #     This function takes an input image and a text prompt to generate a modified version
+ #     of the image based on the provided instructions. It uses the FLUX.1 Kontext model
+ #     for contextual image editing tasks.
+ 
+ #     Args:
+ #         input_image (PIL.Image.Image): The input image to be edited. Will be converted
+ #             to RGB format if not already in that format.
+ #         prompt (str): Text description of the desired edit to apply to the image.
+ #             Examples: "Remove glasses", "Add a hat", "Change background to beach".
+ #         seed (int, optional): Random seed for reproducible generation. Defaults to 42.
+ #             Must be between 0 and MAX_SEED (2^31 - 1).
+ #         randomize_seed (bool, optional): If True, generates a random seed instead of
+ #             using the provided seed value. Defaults to False.
+ #         guidance_scale (float, optional): Controls how closely the model follows the
+ #             prompt. Higher values mean stronger adherence to the prompt but may reduce
+ #             image quality. Range: 1.0-10.0. Defaults to 2.5.
+ #         steps (int, optional): Controls how many steps to run the diffusion model for.
+ #             Range: 1-30. Defaults to 28.
+ #         progress (gr.Progress, optional): Gradio progress tracker for monitoring
+ #             generation progress. Defaults to gr.Progress(track_tqdm=True).
+ 
+ #     Returns:
+ #         tuple: A 3-tuple containing:
+ #             - PIL.Image.Image: The generated/edited image
+ #             - int: The seed value used for generation (useful when randomize_seed=True)
+ #             - gr.update: Gradio update object to make the reuse button visible
+ 
+ #     Example:
+ #         >>> edited_image, used_seed, button_update = infer(
+ #         ...     input_image=my_image,
+ #         ...     prompt="Add sunglasses",
+ #         ...     seed=123,
+ #         ...     randomize_seed=False,
+ #         ...     guidance_scale=2.5
+ #         ... )
+ #     """
+ #     if randomize_seed:
+ #         seed = random.randint(0, MAX_SEED)
+ 
+ #     if input_image_upload is not None:
+ #         input_image_upload = input_image
+ #     elif "composite" in input_image and input_image["composite"] is not None:
+ #         input_image = input_image["composite"]
+ #     elif "background" in input_image and input_image["background"] is not None:
+ #         input_image = input_image["background"]
+ #     else:
+ #         raise ValueError("No valid image found in EditorValue dict (both 'composite' and 'background' are None)")
+ 
+ 
+ #     if input_image is not None:
+ #         if overlay_image is not None:
+ #             input_image = add_overlay(input_image, overlay_image)
+ 
+ #         input_image = input_image.convert("RGB")
+ #         image = pipe(
+ #             image=input_image,
+ #             prompt=prompt,
+ #             guidance_scale=guidance_scale,
+ #             width=input_image.size[0],
+ #             height=input_image.size[1],
+ #             num_inference_steps=steps,
+ #             generator=torch.Generator().manual_seed(seed),
+ #         ).images[0]
+ #     else:
+ #         image = pipe(
+ #             prompt=prompt,
+ #             guidance_scale=guidance_scale,
+ #             num_inference_steps=steps,
+ #             generator=torch.Generator().manual_seed(seed),
+ #         ).images[0]
+ #     return image, input_image, seed, gr.Button(visible=True)
+ 
+ 
+ @spaces.GPU
+ def infer(input_image, input_image_upload, overlay_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
+     """
+     Perform image editing using the FLUX.1 Kontext pipeline.
+ 
+     This function takes an input image and a text prompt to generate a modified version
+     of the image based on the provided instructions. It uses the FLUX.1 Kontext model
+     for contextual image editing tasks.
+ 
+     Args:
+         input_image (dict or PIL.Image.Image): The input from the gr.Paint component.
+         input_image_upload (PIL.Image.Image): The input from the gr.Image upload component.
+         overlay_image (PIL.Image.Image): The face photo to overlay.
+         prompt (str): Text description of the desired edit to apply to the image.
+         seed (int, optional): Random seed for reproducible generation.
+         randomize_seed (bool, optional): If True, generates a random seed.
+         guidance_scale (float, optional): Controls how closely the model follows the prompt.
+         steps (int, optional): Controls how many steps to run the diffusion model for.
+         progress (gr.Progress, optional): Gradio progress tracker.
+ 
+     Returns:
+         tuple: A 4-tuple containing the result image, the processed input image, the seed, and a gr.Button update.
+     """
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+ 
+     # 1. Prioritize the uploaded image. If it exists, it becomes the main input image.
+     if input_image_upload is not None:
+         processed_input_image = input_image_upload
+     # 2. If no image was uploaded, check the drawing canvas.
+     elif isinstance(input_image, dict):
+         # Extract the actual image from the dictionary provided by gr.Paint
+         if "composite" in input_image and input_image["composite"] is not None:
+             processed_input_image = input_image["composite"]
+         elif "background" in input_image and input_image["background"] is not None:
+             processed_input_image = input_image["background"]
+         else:
+             # The canvas is empty, so there's no input image.
+             processed_input_image = None
+     else:
+         # Fallback in case the input is neither from upload nor a valid canvas dict.
+         processed_input_image = None
+ 
+     # From this point on, 'processed_input_image' is either a PIL Image or None.
+     if processed_input_image is not None:
+         if overlay_image is not None:
+             # add_overlay is guaranteed to receive a PIL Image here.
+             processed_input_image = add_overlay(processed_input_image, overlay_image)
+ 
+         processed_input_image = processed_input_image.convert("RGB")
+         image = pipe(
+             image=processed_input_image,
+             prompt=prompt,
+             guidance_scale=guidance_scale,
+             width=processed_input_image.size[0],
+             height=processed_input_image.size[1],
+             num_inference_steps=steps,
+             generator=torch.Generator().manual_seed(seed),
+         ).images[0]
+     else:
+         # Handle the text-to-image case where no input image was provided.
+         image = pipe(
+             prompt=prompt,
+             guidance_scale=guidance_scale,
+             num_inference_steps=steps,
+             generator=torch.Generator().manual_seed(seed),
+         ).images[0]
+ 
+     return image, processed_input_image, seed, gr.Button(visible=True)
+ 
+ @spaces.GPU
+ def infer_example(input_image, prompt):
+     # infer() takes (input_image, input_image_upload, overlay_image, prompt, ...) and returns a
+     # 4-tuple, so pass None for the unused inputs and keep only the image and the seed.
+     image, _, seed, _ = infer(input_image, None, None, prompt)
+     return image, seed
+ 
+ css="""
+ #col-container {
+     margin: 0 auto;
+     max-width: 960px;
+ }
+ """
+ 
+ with gr.Blocks(css=css) as demo:
+ 
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown(f"""# FLUX.1 Kontext [dev]
+ Image editing and manipulation model guidance-distilled from FLUX.1 Kontext [pro], [[blog]](https://bfl.ai/announcements/flux-1-kontext-dev) [[model]](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev)
+         """)
+         with gr.Row():
+             with gr.Column():
+                 # input_image = gr.Image(label="Upload the image for editing", type="pil")
+                 with gr.Row():
+                     with gr.Tabs() as tabs:
+                         with gr.TabItem("Draw"):
+                             input_image = gr.Paint(
+                                 type="pil",
+                                 brush=gr.Brush(default_size=6, colors=["#000000"], color_mode="fixed"),
+                                 canvas_size=(1200, 1200),
+                                 layers=False
+                             )
+                         with gr.TabItem("Upload"):
+                             input_image_upload = gr.Image(label="Upload the drawing", type="pil")
+                 with gr.Row():
+                     overlay_image = gr.Image(label="Upload face photo", type="pil")
+                 with gr.Row():
+                     prompt = gr.Text(
+                         label="Prompt",
+                         show_label=False,
+                         max_lines=1,
+                         value="make it real",
+                         placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
+                         container=False,
+                     )
+                     run_button = gr.Button("Run", scale=0)
+                 with gr.Accordion("Advanced Settings", open=False):
+ 
+                     seed = gr.Slider(
+                         label="Seed",
+                         minimum=0,
+                         maximum=MAX_SEED,
+                         step=1,
+                         value=0,
+                     )
+ 
+                     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+ 
+                     guidance_scale = gr.Slider(
+                         label="Guidance Scale",
+                         minimum=1,
+                         maximum=10,
+                         step=0.1,
+                         value=2.5,
+                     )
+ 
+                     steps = gr.Slider(
+                         label="Steps",
+                         minimum=1,
+                         maximum=30,
+                         value=28,
+                         step=1
+                     )
+ 
+             with gr.Column():
+                 result = gr.Image(label="Result", show_label=False, interactive=False)
+                 result_input = gr.Image(label="Result", show_label=False, interactive=False)
+                 reuse_button = gr.Button("Reuse this image", visible=False)
+ 
+ 
+         examples = gr.Examples(
+             examples=[
+                 ["flowers.png", "turn the flowers into sunflowers"],
+                 ["monster.png", "make this monster ride a skateboard on the beach"],
+                 ["cat.png", "make this cat happy"]
+             ],
+             inputs=[input_image, prompt],
+             outputs=[result, seed],
+             fn=infer_example,
+             cache_examples="lazy"
+         )
+ 
+     gr.on(
+         triggers=[run_button.click, prompt.submit],
+         fn=infer,
+         inputs=[input_image, input_image_upload, overlay_image, prompt, seed, randomize_seed, guidance_scale, steps],
+         outputs=[result, result_input, seed, reuse_button]
+     )
+     reuse_button.click(
+         fn=lambda image: image,
+         inputs=[result],
+         outputs=[input_image]
+     )
+ 
+ demo.launch(mcp_server=True)
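
A quick standalone illustration of the preprocessing step that infer() performs before calling the pipeline: the face photo is scaled to one fifth of the drawing's width and pasted into the top-right corner, mirroring add_overlay() above. This is a minimal sketch assuming only Pillow is installed; the two Image.new() calls are synthetic stand-ins for the real drawing and face photo.

    from PIL import Image

    drawing = Image.new("RGB", (1200, 1200), "white")      # stand-in for the Draw/Upload input
    face = Image.new("RGB", (400, 500), (200, 160, 140))   # stand-in for the uploaded face photo

    base = drawing.convert("RGBA")
    overlay = face.convert("RGBA")
    target_width = base.width // 5                          # overlay becomes 1/5 of the base width
    new_height = int(overlay.height * target_width / overlay.width)
    overlay = overlay.resize((target_width, new_height), Image.LANCZOS)
    base.paste(overlay, (base.width - overlay.width - 20, 20), overlay)  # top-right, 20 px margin

    composited = base.convert("RGB")                        # this is what pipe(image=...) receives
    composited.save("composited_preview.png")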
cat.png ADDED

Git LFS Details

  • SHA256: a23d3036df9a9a47b458f0b5fd1d3b46f2061b20e2055c4f797de9cf9a1efd33
  • Pointer size: 131 Bytes
  • Size of remote file: 545 kB
examples/base/01.png ADDED

Git LFS Details

  • SHA256: 8973a9ea7fbaf9def410a3c1d3ce648cd628fb632d93978f1e0eb3a78e15ef6b
  • Pointer size: 131 Bytes
  • Size of remote file: 996 kB
examples/base/02.png ADDED

Git LFS Details

  • SHA256: 20e60386cb376312509ae29b7f3ecd958885c3d09c7c76fd73fdd64ab89db317
  • Pointer size: 131 Bytes
  • Size of remote file: 622 kB
examples/base/04.png ADDED

Git LFS Details

  • SHA256: ff53a6066b6585fe8c3f85e42b2459ecc4f7c172a748e414ec97134da14ad0ed
  • Pointer size: 131 Bytes
  • Size of remote file: 668 kB
examples/base/07.png ADDED

Git LFS Details

  • SHA256: d0a02127ad5b31332a98a94fb6116198fea0e8fe8664a6f454182cded76723f6
  • Pointer size: 131 Bytes
  • Size of remote file: 725 kB
examples/base/08.png ADDED

Git LFS Details

  • SHA256: 7c9c0677ebf7a558128fe765085471641fbb5fb6f86f950cfd11b7bac358042d
  • Pointer size: 131 Bytes
  • Size of remote file: 670 kB
examples/base/22.png ADDED

Git LFS Details

  • SHA256: 1a735c310c8318c6a374d9bff51ef634d5f170e1ca91d3992bb5ce710bfc1662
  • Pointer size: 131 Bytes
  • Size of remote file: 519 kB
examples/base/25.png ADDED

Git LFS Details

  • SHA256: f19aa7ab547fd6276768a9836e4adf489c7cefc691d889e7e78ca21a8d88d7a2
  • Pointer size: 131 Bytes
  • Size of remote file: 643 kB
examples/base/6.png ADDED

Git LFS Details

  • SHA256: 5e0e0510f6bb258751d63df06b7645fc0a35cd01185a195057c44b857bcd7ad8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.72 MB
examples/face/09 11.png ADDED

Git LFS Details

  • SHA256: bd7d47d59ba41878f683ef79f8d899937c23443cfc35752471a996856c065e9a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
flowers.png ADDED

Git LFS Details

  • SHA256: c97ca8d8e8932d8753915b5f1c5985cfaadb8c7be492d125f6a2a592a278eca1
  • Pointer size: 131 Bytes
  • Size of remote file: 559 kB
monster.png ADDED

Git LFS Details

  • SHA256: c00e55fc9a976868765c39c994f1efd999d94819ce29ab1fb6719189a1bd55e9
  • Pointer size: 131 Bytes
  • Size of remote file: 364 kB
optimization.py ADDED
@@ -0,0 +1,60 @@
+ """
+ ZeroGPU ahead-of-time (AOT) compilation helper for the FLUX Kontext transformer.
+ """
+ 
+ from typing import Any
+ from typing import Callable
+ from typing import ParamSpec
+ 
+ import spaces
+ import torch
+ from torch.utils._pytree import tree_map_only
+ 
+ from optimization_utils import capture_component_call
+ from optimization_utils import aoti_compile
+ 
+ 
+ P = ParamSpec('P')
+ 
+ 
+ TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=4096, max=8212)
+ 
+ TRANSFORMER_DYNAMIC_SHAPES = {
+     'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
+     'img_ids': {0: TRANSFORMER_HIDDEN_DIM},
+ }
+ 
+ INDUCTOR_CONFIGS = {
+     'conv_1x1_as_mm': True,
+     'epilogue_fusion': False,
+     'coordinate_descent_tuning': True,
+     'coordinate_descent_check_all_directions': True,
+     'max_autotune': True,
+     'triton.cudagraphs': True,
+ }
+ 
+ 
+ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
+ 
+     @spaces.GPU(duration=1500)
+     def compile_transformer():
+ 
+         # Run the pipeline once to capture the exact args/kwargs of the transformer call
+         with capture_component_call(pipeline, 'transformer') as call:
+             pipeline(*args, **kwargs)
+ 
+         dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
+         dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
+ 
+         pipeline.transformer.fuse_qkv_projections()
+ 
+         exported = torch.export.export(
+             mod=pipeline.transformer,
+             args=call.args,
+             kwargs=call.kwargs,
+             dynamic_shapes=dynamic_shapes,
+         )
+ 
+         return aoti_compile(exported, INDUCTOR_CONFIGS)
+ 
+     transformer_config = pipeline.transformer.config
+     pipeline.transformer = compile_transformer()
+     pipeline.transformer.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
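
optimize_pipeline_ is imported but left commented out in app.py. For reference, this is roughly how it would be wired in, mirroring the commented-out call there (a sketch that assumes the pipe object created in app.py and a ZeroGPU Space where the 1500-second compile window applies):

    from PIL import Image
    from optimization import optimize_pipeline_

    # Trace one warm-up call on a dummy 512x512 image and a placeholder prompt; the captured
    # transformer call is exported with dynamic shapes, AOT-compiled, and swapped back in.
    optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt="prompt")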
optimization_utils.py ADDED
@@ -0,0 +1,96 @@
+ """
+ Utilities for AOT-compiling pipeline components on ZeroGPU and for capturing the
+ arguments of a component call so it can be exported.
+ """
+ import contextlib
+ from contextvars import ContextVar
+ from io import BytesIO
+ from typing import Any
+ from typing import cast
+ from unittest.mock import patch
+ 
+ import torch
+ from torch._inductor.package.package import package_aoti
+ from torch.export.pt2_archive._package import AOTICompiledModel
+ from torch.export.pt2_archive._package_weights import TensorProperties
+ from torch.export.pt2_archive._package_weights import Weights
+ 
+ 
+ INDUCTOR_CONFIGS_OVERRIDES = {
+     'aot_inductor.package_constants_in_so': False,
+     'aot_inductor.package_constants_on_disk': True,
+     'aot_inductor.package': True,
+ }
+ 
+ 
+ class ZeroGPUCompiledModel:
+     def __init__(self, archive_file: torch.types.FileLike, weights: Weights, cuda: bool = False):
+         self.archive_file = archive_file
+         self.weights = weights
+         if cuda:
+             self.weights_to_cuda_()
+         self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
+     def weights_to_cuda_(self):
+         for name in self.weights:
+             tensor, properties = self.weights.get_weight(name)
+             self.weights[name] = (tensor.to('cuda'), properties)
+     def __call__(self, *args, **kwargs):
+         if (compiled_model := self.compiled_model.get()) is None:
+             constants_map = {name: value[0] for name, value in self.weights.items()}
+             compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
+             compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
+             self.compiled_model.set(compiled_model)
+         return compiled_model(*args, **kwargs)
+     def __reduce__(self):
+         weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]] = {}
+         for name in self.weights:
+             tensor, properties = self.weights.get_weight(name)
+             tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
+             weight_dict[name] = (tensor_.copy_(tensor).detach().share_memory_(), properties)
+         return ZeroGPUCompiledModel, (self.archive_file, Weights(weight_dict), True)
+ 
+ 
+ def aoti_compile(
+     exported_program: torch.export.ExportedProgram,
+     inductor_configs: dict[str, Any] | None = None,
+ ):
+     inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
+     gm = cast(torch.fx.GraphModule, exported_program.module())
+     assert exported_program.example_inputs is not None
+     args, kwargs = exported_program.example_inputs
+     artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
+     archive_file = BytesIO()
+     files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
+     package_aoti(archive_file, files)
+     weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
+     return ZeroGPUCompiledModel(archive_file, weights)
+ 
+ 
+ @contextlib.contextmanager
+ def capture_component_call(
+     pipeline: Any,
+     component_name: str,
+     component_method='forward',
+ ):
+ 
+     class CapturedCallException(Exception):
+         def __init__(self, *args, **kwargs):
+             super().__init__()
+             self.args = args
+             self.kwargs = kwargs
+ 
+     class CapturedCall:
+         def __init__(self):
+             self.args: tuple[Any, ...] = ()
+             self.kwargs: dict[str, Any] = {}
+ 
+     component = getattr(pipeline, component_name)
+     captured_call = CapturedCall()
+ 
+     def capture_call(*args, **kwargs):
+         raise CapturedCallException(*args, **kwargs)
+ 
+     with patch.object(component, component_method, new=capture_call):
+         try:
+             yield captured_call
+         except CapturedCallException as e:
+             captured_call.args = e.args
+             captured_call.kwargs = e.kwargs
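
capture_component_call is easiest to see on a toy object: the named component's method is patched to raise immediately, so the surrounding call stops early and the exact args/kwargs it would have received are kept for export. A minimal sketch, assuming the repository's PyTorch environment (importing optimization_utils pulls in private torch.export/_inductor modules); ToyTransformer and ToyPipeline are made-up stand-ins:

    from optimization_utils import capture_component_call

    class ToyTransformer:
        def forward(self, hidden_states, guidance=None):
            return hidden_states

    class ToyPipeline:
        def __init__(self):
            self.transformer = ToyTransformer()
        def __call__(self, x):
            return self.transformer.forward(x, guidance=3.5)

    pipe_toy = ToyPipeline()
    with capture_component_call(pipe_toy, "transformer") as call:
        pipe_toy("latents")           # stops inside the patched forward()

    print(call.args)                  # ('latents',)
    print(call.kwargs)                # {'guidance': 3.5}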
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ transformers
+ git+https://github.com/huggingface/diffusers.git
+ accelerate
+ safetensors
+ sentencepiece
+ peft