aatir committed on
Commit 39fc645 · 1 Parent(s): 9ec92e2

added files

Files changed (7)
  1. LICENSE +201 -0
  2. README copy.md +14 -0
  3. app.py +396 -0
  4. chat_interface.py +641 -0
  5. lib_omost/canvas.py +248 -0
  6. lib_omost/pipeline.py +445 -0
  7. requirements.txt +12 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README copy.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Omost
+ emoji: 😻
+ colorFrom: red
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 4.32.2
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ fullWidth: true
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,396 @@
+ import os
+ import spaces
+
+ os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
+ HF_TOKEN = os.environ['hf_token'] if 'hf_token' in os.environ else None
+
+ import uuid
+ import time
+ import torch
+ import numpy as np
+ import gradio as gr
+ import tempfile
+
+ gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio')
+ os.makedirs(gradio_temp_dir, exist_ok=True)
+
+ from threading import Thread
+
+ # Phi3 Hijack
+ from transformers.models.phi3.modeling_phi3 import Phi3PreTrainedModel
+
+ Phi3PreTrainedModel._supports_sdpa = True
+
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from diffusers import AutoencoderKL, UNet2DConditionModel
+ from diffusers.models.attention_processor import AttnProcessor2_0
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from lib_omost.pipeline import StableDiffusionXLOmostPipeline
+ from chat_interface import ChatInterface
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
+
+ import lib_omost.canvas as omost_canvas
+
+
+ # SDXL
+
+ sdxl_name = 'SG161222/RealVisXL_V4.0'
+ # sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'
+
+ tokenizer = CLIPTokenizer.from_pretrained(
+ sdxl_name, subfolder="tokenizer")
+ tokenizer_2 = CLIPTokenizer.from_pretrained(
+ sdxl_name, subfolder="tokenizer_2")
+ text_encoder = CLIPTextModel.from_pretrained(
+ sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+ text_encoder_2 = CLIPTextModel.from_pretrained(
+ sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+ vae = AutoencoderKL.from_pretrained(
+ sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16", device_map="auto") # bfloat16 vae
+ unet = UNet2DConditionModel.from_pretrained(
+ sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+
+ unet.set_attn_processor(AttnProcessor2_0())
+ vae.set_attn_processor(AttnProcessor2_0())
+
+ pipeline = StableDiffusionXLOmostPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=None, # We completely give up diffusers sampling system and use A1111's method
+ )
+
+ # LLM
+
+ # model_name = 'lllyasviel/omost-phi-3-mini-128k'
+ llm_name = 'lllyasviel/omost-llama-3-8b'
+ # model_name = 'lllyasviel/omost-dolphin-2.9-llama3-8b'
+
+ llm_model = AutoModelForCausalLM.from_pretrained(
+ llm_name,
+ torch_dtype="auto",
+ token=HF_TOKEN,
+ device_map="auto",
+ trust_remote_code=True,
+ )
+
+ llm_tokenizer = AutoTokenizer.from_pretrained(
+ llm_name,
+ token=HF_TOKEN
+ )
+
+
+ @torch.inference_mode()
+ def pytorch2numpy(imgs):
+ results = []
+ for x in imgs:
+ y = x.movedim(0, -1)
+ y = y * 127.5 + 127.5
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
+ results.append(y)
+ return results
+
+
+ @torch.inference_mode()
+ def numpy2pytorch(imgs):
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
+ h = h.movedim(-1, 1)
+ return h
+
+
+ def resize_without_crop(image, target_width, target_height):
+ pil_image = Image.fromarray(image)
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
+ return np.array(resized_image)
+
+
+ @spaces.GPU(duration=120)
+ @torch.inference_mode()
+ def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: float, max_new_tokens: int) -> str:
+ print('Chat begin:', message)
+ time_stamp = time.time()
+
+ np.random.seed(int(seed))
+ torch.manual_seed(int(seed))
+
+ conversation = [{"role": "system", "content": omost_canvas.system_prompt}]
+
+ for user, assistant in history:
+ if isinstance(user, str) and isinstance(assistant, str):
+ if len(user) > 0 and len(assistant) > 0:
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+
+ conversation.append({"role": "user", "content": message})
+
+ input_ids = llm_tokenizer.apply_chat_template(
+ conversation, return_tensors="pt", add_generation_prompt=True).to(llm_model.device)
+
+ streamer = TextIteratorStreamer(llm_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+ def interactive_stopping_criteria(*args, **kwargs) -> bool:
+ if getattr(streamer, 'user_interrupted', False):
+ print('User stopped generation:', message)
+ return True
+ else:
+ return False
+
+ stopping_criteria = StoppingCriteriaList([interactive_stopping_criteria])
+
+ def interrupter():
+ streamer.user_interrupted = True
+ return
+
+ generate_kwargs = dict(
+ input_ids=input_ids,
+ streamer=streamer,
+ stopping_criteria=stopping_criteria,
+ max_new_tokens=max_new_tokens,
+ do_sample=True,
+ temperature=temperature,
+ top_p=top_p,
+ )
+
+ if temperature == 0:
+ generate_kwargs['do_sample'] = False
+
+ Thread(target=llm_model.generate, kwargs=generate_kwargs).start()
+
+ outputs = []
+ for text in streamer:
+ outputs.append(text)
+ # print(outputs)
+ yield "".join(outputs), None
+
+ print(f'Chat end at {time.time() - time_stamp:.2f} seconds:', message)
+ return
+
+
+ @torch.inference_mode()
+ def post_chat(history):
+ canvas_outputs = None
+
+ try:
+ if history:
+ history = [(user, assistant) for user, assistant in history if isinstance(user, str) and isinstance(assistant, str)]
+ last_assistant = history[-1][1] if len(history) > 0 else None
+ canvas = omost_canvas.Canvas.from_bot_response(last_assistant)
+ canvas_outputs = canvas.process()
+ except Exception as e:
+ print('Last assistant response is not valid canvas:', e)
+
+ return canvas_outputs, gr.update(visible=canvas_outputs is not None), gr.update(interactive=len(history) > 0)
+
+ def preprocess_product_image(image, target_width, target_height):
+ image = image.convert("RGB")
+ image = image.resize((target_width, target_height), Image.LANCZOS)
+ image = np.array(image)
+ image = image.astype(np.float32) / 127.5 - 1.0
+ image = np.transpose(image, (2, 0, 1))
+ return torch.from_numpy(image).unsqueeze(0)
+
+ @spaces.GPU
+ @torch.inference_mode()
+ def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_height,
+ highres_scale, steps, cfg, highres_steps, highres_denoise, negative_prompt, product_image):
+
+ use_initial_latent = False
+ eps = 0.05
+
+ image_width, image_height = int(image_width // 64) * 64, int(image_height // 64) * 64
+
+ if product_image is not None:
+ product_image = preprocess_product_image(product_image, image_width, image_height)
+
+ rng = torch.Generator(unet.device).manual_seed(seed)
+
+ positive_cond, positive_pooler, negative_cond, negative_pooler = pipeline.all_conds_from_canvas(canvas_outputs, negative_prompt)
+
+ if use_initial_latent:
+ initial_latent = torch.from_numpy(canvas_outputs['initial_latent'])[None].movedim(-1, 1) / 127.5 - 1.0
+ initial_latent_blur = 40
+ initial_latent = torch.nn.functional.avg_pool2d(
+ torch.nn.functional.pad(initial_latent, (initial_latent_blur,) * 4, mode='reflect'),
+ kernel_size=(initial_latent_blur * 2 + 1,) * 2, stride=(1, 1))
+ initial_latent = torch.nn.functional.interpolate(initial_latent, (image_height, image_width))
+ initial_latent = initial_latent.to(dtype=vae.dtype, device=vae.device)
+ initial_latent = vae.encode(initial_latent).latent_dist.mode() * vae.config.scaling_factor
+ else:
+ initial_latent = torch.zeros(size=(num_samples, 4, image_height // 8, image_width // 8), dtype=torch.float32)
+
+ initial_latent = initial_latent.to(dtype=unet.dtype, device=unet.device)
+
+ latents = pipeline(
+ initial_latent=initial_latent,
+ strength=1.0,
+ num_inference_steps=int(steps),
+ batch_size=num_samples,
+ prompt_embeds=positive_cond,
+ negative_prompt_embeds=negative_cond,
+ pooled_prompt_embeds=positive_pooler,
+ negative_pooled_prompt_embeds=negative_pooler,
+ generator=rng,
+ guidance_scale=float(cfg),
+ product_image=product_image,
+ ).images
+
+ latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
+ pixels = vae.decode(latents).sample
+ B, C, H, W = pixels.shape
+ pixels = pytorch2numpy(pixels)
+
+ if highres_scale > 1.0 + eps:
+ pixels = [
+ resize_without_crop(
+ image=p,
+ target_width=int(round(W * highres_scale / 64.0) * 64),
+ target_height=int(round(H * highres_scale / 64.0) * 64)
+ ) for p in pixels
+ ]
+
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
+
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
+
+ latents = pipeline(
+ initial_latent=latents,
+ strength=highres_denoise,
+ num_inference_steps=highres_steps,
+ batch_size=num_samples,
+ prompt_embeds=positive_cond,
+ negative_prompt_embeds=negative_cond,
+ pooled_prompt_embeds=positive_pooler,
+ negative_pooled_prompt_embeds=negative_pooler,
+ generator=rng,
+ guidance_scale=float(cfg),
+ ).images
+
+ latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
+ pixels = vae.decode(latents).sample
+ pixels = pytorch2numpy(pixels)
+
+ for i in range(len(pixels)):
+ unique_hex = uuid.uuid4().hex
+ image_path = os.path.join(gradio_temp_dir, f"{unique_hex}_{i}.png")
+ image = Image.fromarray(pixels[i])
+ image.save(image_path)
+ chatbot = chatbot + [(None, (image_path, 'image'))]
+
+ return chatbot
+
+
+ css = '''
+ code {white-space: pre-wrap !important;}
+ .gradio-container {max-width: none !important;}
+ .outer_parent {flex: 1;}
+ .inner_parent {flex: 1;}
+ footer {display: none !important; visibility: hidden !important;}
+ .translucent {display: none !important; visibility: hidden !important;}
+ '''
+
+ from gradio.themes.utils import colors
+
+ with gr.Blocks(
+ fill_height=True, css=css,
+ theme=gr.themes.Default(primary_hue=colors.blue, secondary_hue=colors.cyan, neutral_hue=colors.gray)
+ ) as demo:
+ with gr.Row(elem_classes='outer_parent'):
+ with gr.Column(scale=25):
+ product_image = gr.Image(label="Product Image", type="pil")
+ with gr.Row():
+ clear_btn = gr.Button("➕ New Chat", variant="secondary", size="sm", min_width=60)
+ retry_btn = gr.Button("Retry", variant="secondary", size="sm", min_width=60, visible=False)
+ undo_btn = gr.Button("✏️️ Edit Last Input", variant="secondary", size="sm", min_width=60, interactive=False)
+
+ seed = gr.Number(label="Random Seed", value=123456, precision=0)
+
+ with gr.Accordion(open=True, label='Language Model'):
+ with gr.Group():
+ with gr.Row():
+ temperature = gr.Slider(
+ minimum=0.0,
+ maximum=2.0,
+ step=0.01,
+ value=0.6,
+ label="Temperature")
+ top_p = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ step=0.01,
+ value=0.9,
+ label="Top P")
+ max_new_tokens = gr.Slider(
+ minimum=128,
+ maximum=4096,
+ step=1,
+ value=4096,
+ label="Max New Tokens")
+ with gr.Accordion(open=True, label='Image Diffusion Model'):
+ with gr.Group():
+ with gr.Row():
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=2048, value=896, step=64)
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=2048, value=1152, step=64)
+
+ with gr.Row():
+ num_samples = gr.Slider(label="Image Number", minimum=1, maximum=12, value=1, step=1)
+ steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=100, value=25, step=1)
+
+ with gr.Accordion(open=False, label='Advanced'):
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=5.0, step=0.01)
+ highres_scale = gr.Slider(label="HR-fix Scale (\"1\" is disabled)", minimum=1.0, maximum=2.0, value=1.0, step=0.01)
+ highres_steps = gr.Slider(label="Highres Fix Steps", minimum=1, maximum=100, value=20, step=1)
+ highres_denoise = gr.Slider(label="Highres Fix Denoise", minimum=0.1, maximum=1.0, value=0.4, step=0.01)
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
+
+ render_button = gr.Button("Render the Image!", size='lg', variant="primary", visible=False)
+
+ examples = gr.Dataset(
+ samples=[
+ ['generate an image of the fierce battle of warriors and the dragon'],
+ ['change the dragon to a dinosaur']
+ ],
+ components=[gr.Textbox(visible=False)],
+ label='Quick Prompts'
+ )
+
+ with gr.Row():
+ gr.Markdown("Omost: converting LLM's coding capability to image compositing capability.")
+ with gr.Row():
+ gr.Markdown("Local version (8GB VRAM): https://github.com/lllyasviel/Omost")
+ # with gr.Row():
+ # gr.Markdown("Hint: You can [duplicate this space](https://huggingface.co/spaces/lllyasviel/Omost?duplicate=true) to your private account to bypass the waiting queue.")
+
+ with gr.Column(scale=75, elem_classes='inner_parent'):
+ canvas_state = gr.State(None)
+ chatbot = gr.Chatbot(label='Omost', scale=1, show_copy_button=True, layout="panel", render=False)
+ chatInterface = ChatInterface(
+ fn=chat_fn,
+ post_fn=post_chat,
+ post_fn_kwargs=dict(inputs=[chatbot], outputs=[canvas_state, render_button, undo_btn]),
+ pre_fn=lambda: gr.update(visible=False),
+ pre_fn_kwargs=dict(outputs=[render_button]),
+ chatbot=chatbot,
+ retry_btn=retry_btn,
+ undo_btn=undo_btn,
+ clear_btn=clear_btn,
+ additional_inputs=[seed, temperature, top_p, max_new_tokens],
+ examples=examples,
+ show_stop_button=False
+ )
+
+ render_button.click(
+ fn=diffusion_fn, inputs=[
+ chatInterface.chatbot, canvas_state,
+ num_samples, seed, image_width, image_height, highres_scale,
+ steps, cfg, highres_steps, highres_denoise, n_prompt, product_image
+ ], outputs=[chatInterface.chatbot]).then(
+ fn=lambda x: x, inputs=[
+ chatInterface.chatbot
+ ], outputs=[chatInterface.chatbot_state])
+
+ if __name__ == "__main__":
+ demo.queue().launch()
chat_interface.py ADDED
@@ -0,0 +1,641 @@
+ """
+ This file defines a useful high-level abstraction to build Gradio chatbots: ChatInterface.
+ """
+
+ from __future__ import annotations
+
+ import inspect
+ from typing import AsyncGenerator, Callable, Literal, Union, cast
+
+ import anyio
+ from gradio_client.documentation import document
+
+ from gradio.blocks import Blocks
+ from gradio.components import (
+ Button,
+ Chatbot,
+ Component,
+ Markdown,
+ MultimodalTextbox,
+ State,
+ Textbox,
+ get_component_instance,
+ Dataset,
+ )
+ from gradio.events import Dependency, on
+ from gradio.helpers import special_args
+ from gradio.layouts import Accordion, Group, Row
+ from gradio.routes import Request
+ from gradio.themes import ThemeClass as Theme
+ from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda
+
+
+ @document()
+ class ChatInterface(Blocks):
+ """
+ ChatInterface is Gradio's high-level abstraction for creating chatbot UIs, and allows you to create
+ a web-based demo around a chatbot model in a few lines of code. Only one parameter is required: fn, which
+ takes a function that governs the response of the chatbot based on the user input and chat history. Additional
+ parameters can be used to control the appearance and behavior of the demo.
+
+ Example:
+ import gradio as gr
+
+ def echo(message, history):
+ return message
+
+ demo = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="Echo Bot")
+ demo.launch()
+ Demos: chatinterface_multimodal, chatinterface_random_response, chatinterface_streaming_echo
+ Guides: creating-a-chatbot-fast, sharing-your-app
+ """
+
+ def __init__(
+ self,
+ fn: Callable,
+ post_fn: Callable,
+ pre_fn: Callable,
+ chatbot: Chatbot,
+ *,
+ show_stop_button=True,
+ post_fn_kwargs: dict = None,
+ pre_fn_kwargs: dict = None,
+ multimodal: bool = False,
+ textbox: Textbox | MultimodalTextbox | None = None,
+ additional_inputs: str | Component | list[str | Component] | None = None,
+ additional_inputs_accordion_name: str | None = None,
+ additional_inputs_accordion: str | Accordion | None = None,
+ examples: Dataset = None,
+ title: str | None = None,
+ description: str | None = None,
+ theme: Theme | str | None = None,
+ css: str | None = None,
+ js: str | None = None,
+ head: str | None = None,
+ analytics_enabled: bool | None = None,
+ submit_btn: str | None | Button = "Submit",
+ stop_btn: str | None | Button = "Stop",
+ retry_btn: str | None | Button = "🔄 Retry",
+ undo_btn: str | None | Button = "↩️ Undo",
+ clear_btn: str | None | Button = "🗑️ Clear",
+ autofocus: bool = True,
+ concurrency_limit: int | None | Literal["default"] = "default",
+ fill_height: bool = True,
+ delete_cache: tuple[int, int] | None = None,
+ ):
+ super().__init__(
+ analytics_enabled=analytics_enabled,
+ mode="chat_interface",
+ css=css,
+ title=title or "Gradio",
+ theme=theme,
+ js=js,
+ head=head,
+ fill_height=fill_height,
+ delete_cache=delete_cache,
+ )
+
+ if post_fn_kwargs is None:
+ post_fn_kwargs = []
+
+ self.post_fn = post_fn
+ self.post_fn_kwargs = post_fn_kwargs
+
+ self.pre_fn = pre_fn
+ self.pre_fn_kwargs = pre_fn_kwargs
+
+ self.show_stop_button = show_stop_button
+
+ self.interrupter = State(None)
+
+ self.multimodal = multimodal
+ self.concurrency_limit = concurrency_limit
+ self.fn = fn
+ self.is_async = inspect.iscoroutinefunction(
+ self.fn
+ ) or inspect.isasyncgenfunction(self.fn)
+ self.is_generator = inspect.isgeneratorfunction(
+ self.fn
+ ) or inspect.isasyncgenfunction(self.fn)
+
+ if additional_inputs:
+ if not isinstance(additional_inputs, list):
+ additional_inputs = [additional_inputs]
+ self.additional_inputs = [
+ get_component_instance(i)
+ for i in additional_inputs # type: ignore
+ ]
+ else:
+ self.additional_inputs = []
+ if additional_inputs_accordion_name is not None:
+ print(
+ "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
+ )
+ self.additional_inputs_accordion_params = {
+ "label": additional_inputs_accordion_name
+ }
+ if additional_inputs_accordion is None:
+ self.additional_inputs_accordion_params = {
+ "label": "Additional Inputs",
+ "open": False,
+ }
+ elif isinstance(additional_inputs_accordion, str):
+ self.additional_inputs_accordion_params = {
+ "label": additional_inputs_accordion
+ }
+ elif isinstance(additional_inputs_accordion, Accordion):
+ self.additional_inputs_accordion_params = (
+ additional_inputs_accordion.recover_kwargs(
+ additional_inputs_accordion.get_config()
+ )
+ )
+ else:
+ raise ValueError(
+ f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
+ )
+
+ with self:
+ if title:
+ Markdown(
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
+ )
+ if description:
+ Markdown(description)
+
+ self.chatbot = chatbot.render()
+
+ self.buttons = [retry_btn, undo_btn, clear_btn]
+
+ with Group():
+ with Row():
+ if textbox:
+ if self.multimodal:
+ submit_btn = None
+ else:
+ textbox.container = False
+ textbox.show_label = False
+ textbox_ = textbox.render()
+ if not isinstance(textbox_, (Textbox, MultimodalTextbox)):
+ raise TypeError(
+ f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {type(textbox_)}"
+ )
+ self.textbox = textbox_
+ elif self.multimodal:
+ submit_btn = None
+ self.textbox = MultimodalTextbox(
+ show_label=False,
+ label="Message",
+ placeholder="Type a message...",
+ scale=7,
+ autofocus=autofocus,
+ )
+ else:
+ self.textbox = Textbox(
+ container=False,
+ show_label=False,
+ label="Message",
+ placeholder="Type a message...",
+ scale=7,
+ autofocus=autofocus,
+ )
+ if submit_btn is not None and not multimodal:
+ if isinstance(submit_btn, Button):
+ submit_btn.render()
+ elif isinstance(submit_btn, str):
+ submit_btn = Button(
+ submit_btn,
+ variant="primary",
+ scale=1,
+ min_width=150,
+ )
+ else:
+ raise ValueError(
+ f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
+ )
+ if stop_btn is not None:
+ if isinstance(stop_btn, Button):
+ stop_btn.visible = False
+ stop_btn.render()
+ elif isinstance(stop_btn, str):
+ stop_btn = Button(
+ stop_btn,
+ variant="stop",
+ visible=False,
+ scale=1,
+ min_width=150,
+ )
+ else:
+ raise ValueError(
+ f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
+ )
+ self.buttons.extend([submit_btn, stop_btn]) # type: ignore
+
+ self.fake_api_btn = Button("Fake API", visible=False)
+ self.fake_response_textbox = Textbox(label="Response", visible=False)
+ (
+ self.retry_btn,
+ self.undo_btn,
+ self.clear_btn,
+ self.submit_btn,
+ self.stop_btn,
+ ) = self.buttons
+
+ any_unrendered_inputs = any(
+ not inp.is_rendered for inp in self.additional_inputs
+ )
+ if self.additional_inputs and any_unrendered_inputs:
+ with Accordion(**self.additional_inputs_accordion_params): # type: ignore
+ for input_component in self.additional_inputs:
+ if not input_component.is_rendered:
+ input_component.render()
+
+ self.saved_input = State()
+ self.chatbot_state = (
+ State(self.chatbot.value) if self.chatbot.value else State([])
+ )
+
+ self._setup_events()
+ self._setup_api()
+
+ if examples:
+ examples.click(lambda x: x[0], inputs=[examples], outputs=self.textbox, show_progress=False, queue=False)
+
+ def _setup_events(self) -> None:
+ submit_fn = self._stream_fn if self.is_generator else self._submit_fn
+ submit_triggers = (
+ [self.textbox.submit, self.submit_btn.click]
+ if self.submit_btn
+ else [self.textbox.submit]
+ )
+ submit_event = (
+ on(
+ submit_triggers,
+ self._clear_and_save_textbox,
+ [self.textbox],
+ [self.textbox, self.saved_input],
+ show_api=False,
+ queue=False,
+ )
+ .then(
+ self.pre_fn,
+ **self.pre_fn_kwargs,
+ show_api=False,
+ queue=False,
+ )
+ .then(
+ self._display_input,
+ [self.saved_input, self.chatbot_state],
+ [self.chatbot, self.chatbot_state],
+ show_api=False,
+ queue=False,
+ )
+ .then(
+ submit_fn,
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
+ [self.chatbot, self.chatbot_state, self.interrupter],
+ show_api=False,
+ concurrency_limit=cast(
+ Union[int, Literal["default"], None], self.concurrency_limit
+ ),
+ ).then(
+ self.post_fn,
+ **self.post_fn_kwargs,
+ show_api=False,
+ concurrency_limit=cast(
+ Union[int, Literal["default"], None], self.concurrency_limit
+ ),
+ )
+ )
+ self._setup_stop_events(submit_triggers, submit_event)
+
+ if self.retry_btn:
+ retry_event = (
+ self.retry_btn.click(
+ self._delete_prev_fn,
+ [self.saved_input, self.chatbot_state],
+ [self.chatbot, self.saved_input, self.chatbot_state],
+ show_api=False,
+ queue=False,
+ )
+ .then(
+ self.pre_fn,
+ **self.pre_fn_kwargs,
+ show_api=False,
+ queue=False,
+ )
+ .then(
+ self._display_input,
+ [self.saved_input, self.chatbot_state],
+ [self.chatbot, self.chatbot_state],
+ show_api=False,
+ queue=False,
+ )
+ .then(
+ submit_fn,
+ [self.saved_input, self.chatbot_state] + self.additional_inputs,
+ [self.chatbot, self.chatbot_state],
+ show_api=False,
+ concurrency_limit=cast(
+ Union[int, Literal["default"], None], self.concurrency_limit
+ ),
+ ).then(
+ self.post_fn,
+ **self.post_fn_kwargs,
+ show_api=False,
+ concurrency_limit=cast(
+ Union[int, Literal["default"], None], self.concurrency_limit
+ ),
+ )
+ )
+ self._setup_stop_events([self.retry_btn.click], retry_event)
+
+ if self.undo_btn:
+ self.undo_btn.click(
+ self._delete_prev_fn,
+ [self.saved_input, self.chatbot_state],
+ [self.chatbot, self.saved_input, self.chatbot_state],
+ show_api=False,
+ queue=False,
+ ).then(
+ self.pre_fn,
+ **self.pre_fn_kwargs,
+ show_api=False,
+ queue=False,
+ ).then(
+ async_lambda(lambda x: x),
+ [self.saved_input],
+ [self.textbox],
+ show_api=False,
+ queue=False,
+ ).then(
+ self.post_fn,
+ **self.post_fn_kwargs,
+ show_api=False,
+ concurrency_limit=cast(
+ Union[int, Literal["default"], None], self.concurrency_limit
+ ),
+ )
+
+ if self.clear_btn:
+ self.clear_btn.click(
+ async_lambda(lambda: ([], [], None)),
+ None,
+ [self.chatbot, self.chatbot_state, self.saved_input],
+ queue=False,
+ show_api=False,
+ ).then(
+ self.pre_fn,
+ **self.pre_fn_kwargs,
+ show_api=False,
+ queue=False,
+ ).then(
+ self.post_fn,
+ **self.post_fn_kwargs,
+ show_api=False,
+ concurrency_limit=cast(
+ Union[int, Literal["default"], None], self.concurrency_limit
+ ),
+ )
+
+ def _setup_stop_events(
+ self, event_triggers: list[Callable], event_to_cancel: Dependency
+ ) -> None:
+ def perform_interrupt(ipc):
+ if ipc is not None:
+ ipc()
+ return
+
+ if self.stop_btn and self.is_generator:
+ if self.submit_btn:
+ for event_trigger in event_triggers:
+ event_trigger(
+ async_lambda(
+ lambda: (
+ Button(visible=False),
+ Button(visible=self.show_stop_button),
+ )
+ ),
+ None,
+ [self.submit_btn, self.stop_btn],
+ show_api=False,
+ queue=False,
+ )
+ event_to_cancel.then(
+ async_lambda(lambda: (Button(visible=True), Button(visible=False))),
+ None,
+ [self.submit_btn, self.stop_btn],
+ show_api=False,
+ queue=False,
+ )
+ else:
+ for event_trigger in event_triggers:
+ event_trigger(
+ async_lambda(lambda: Button(visible=self.show_stop_button)),
+ None,
+ [self.stop_btn],
+ show_api=False,
+ queue=False,
+ )
+ event_to_cancel.then(
+ async_lambda(lambda: Button(visible=False)),
+ None,
+ [self.stop_btn],
+ show_api=False,
+ queue=False,
+ )
+ self.stop_btn.click(
+ fn=perform_interrupt,
+ inputs=[self.interrupter],
+ cancels=event_to_cancel,
+ show_api=False,
+ )
+
+ def _setup_api(self) -> None:
+ api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
+
+ self.fake_api_btn.click(
+ api_fn,
+ [self.textbox, self.chatbot_state] + self.additional_inputs,
+ [self.textbox, self.chatbot_state],
+ api_name="chat",
+ concurrency_limit=cast(
+ Union[int, Literal["default"], None], self.concurrency_limit
+ ),
+ )
+
+ def _clear_and_save_textbox(self, message: str) -> tuple[str | dict, str]:
+ if self.multimodal:
+ return {"text": "", "files": []}, message
+ else:
+ return "", message
+
+ def _append_multimodal_history(
+ self,
+ message: dict[str, list],
+ response: str | None,
+ history: list[list[str | tuple | None]],
+ ):
+ for x in message["files"]:
+ history.append([(x,), None])
+ if message["text"] is None or not isinstance(message["text"], str):
+ return
+ elif message["text"] == "" and message["files"] != []:
+ history.append([None, response])
+ else:
+ history.append([message["text"], response])
+
+ async def _display_input(
+ self, message: str | dict[str, list], history: list[list[str | tuple | None]]
+ ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
+ if self.multimodal and isinstance(message, dict):
+ self._append_multimodal_history(message, None, history)
+ elif isinstance(message, str):
+ history.append([message, None])
+ return history, history
+
+ async def _submit_fn(
+ self,
+ message: str | dict[str, list],
+ history_with_input: list[list[str | tuple | None]],
+ request: Request,
+ *args,
+ ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
+ if self.multimodal and isinstance(message, dict):
+ remove_input = (
+ len(message["files"]) + 1
+ if message["text"] is not None
+ else len(message["files"])
+ )
+ history = history_with_input[:-remove_input]
+ else:
+ history = history_with_input[:-1]
+ inputs, _, _ = special_args(
+ self.fn, inputs=[message, history, *args], request=request
+ )
+
+ if self.is_async:
+ response = await self.fn(*inputs)
+ else:
+ response = await anyio.to_thread.run_sync(
+ self.fn, *inputs, limiter=self.limiter
+ )
+
+ if self.multimodal and isinstance(message, dict):
+ self._append_multimodal_history(message, response, history)
+ elif isinstance(message, str):
+ history.append([message, response])
+ return history, history
+
+ async def _stream_fn(
+ self,
+ message: str | dict[str, list],
+ history_with_input: list[list[str | tuple | None]],
+ request: Request,
+ *args,
+ ) -> AsyncGenerator:
+ if self.multimodal and isinstance(message, dict):
+ remove_input = (
+ len(message["files"]) + 1
+ if message["text"] is not None
+ else len(message["files"])
+ )
+ history = history_with_input[:-remove_input]
+ else:
+ history = history_with_input[:-1]
+ inputs, _, _ = special_args(
+ self.fn, inputs=[message, history, *args], request=request
+ )
+
+ if self.is_async:
+ generator = self.fn(*inputs)
+ else:
+ generator = await anyio.to_thread.run_sync(
+ self.fn, *inputs, limiter=self.limiter
+ )
+ generator = SyncToAsyncIterator(generator, self.limiter)
+ try:
+ first_response, first_interrupter = await async_iteration(generator)
+ if self.multimodal and isinstance(message, dict):
+ for x in message["files"]:
+ history.append([(x,), None])
+ update = history + [[message["text"], first_response]]
+ yield update, update
+ else:
+ update = history + [[message, first_response]]
+ yield update, update, first_interrupter
+ except StopIteration:
+ if self.multimodal and isinstance(message, dict):
+ self._append_multimodal_history(message, None, history)
+ yield history, history
+ else:
+ update = history + [[message, None]]
+ yield update, update, first_interrupter
+ async for response, interrupter in generator:
+ if self.multimodal and isinstance(message, dict):
+ update = history + [[message["text"], response]]
+ yield update, update
+ else:
+ update = history + [[message, response]]
+ yield update, update, interrupter
+
+ async def _api_submit_fn(
+ self, message: str, history: list[list[str | None]], request: Request, *args
+ ) -> tuple[str, list[list[str | None]]]:
+ inputs, _, _ = special_args(
+ self.fn, inputs=[message, history, *args], request=request
+ )
+
+ if self.is_async:
+ response = await self.fn(*inputs)
+ else:
+ response = await anyio.to_thread.run_sync(
+ self.fn, *inputs, limiter=self.limiter
+ )
+ history.append([message, response])
+ return response, history
+
+ async def _api_stream_fn(
+ self, message: str, history: list[list[str | None]], request: Request, *args
+ ) -> AsyncGenerator:
+ inputs, _, _ = special_args(
+ self.fn, inputs=[message, history, *args], request=request
+ )
+
+ if self.is_async:
+ generator = self.fn(*inputs)
+ else:
+ generator = await anyio.to_thread.run_sync(
+ self.fn, *inputs, limiter=self.limiter
+ )
+ generator = SyncToAsyncIterator(generator, self.limiter)
+ try:
+ first_response = await async_iteration(generator)
+ yield first_response, history + [[message, first_response]]
+ except StopIteration:
+ yield None, history + [[message, None]]
+ async for response in generator:
+ yield response, history + [[message, response]]
+
+ async def _delete_prev_fn(
+ self,
+ message: str | dict[str, list],
+ history: list[list[str | tuple | None]],
+ ) -> tuple[
+ list[list[str | tuple | None]],
+ str | dict[str, list],
+ list[list[str | tuple | None]],
+ ]:
+ if self.multimodal and isinstance(message, dict):
+ remove_input = (
+ len(message["files"]) + 1
+ if message["text"] is not None
+ else len(message["files"])
+ )
+ history = history[:-remove_input]
+ else:
+ while history:
+ deleted_a, deleted_b = history[-1]
+ history = history[:-1]
+ if isinstance(deleted_a, str) and isinstance(deleted_b, str):
+ break
+ return history, message or "", history
lib_omost/canvas.py ADDED
@@ -0,0 +1,248 @@
1
+ import re
2
+ import difflib
3
+ import numpy as np
4
+
5
+ system_prompt = r'''You are a helpful AI assistant to compose images using the below python class `Canvas`:
6
+
7
+ ```python
8
+ class Canvas:
9
+ def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
10
+ pass
11
+
12
+ def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
13
+ assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
14
+ assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
15
+ assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
16
+ assert distance_to_viewer > 0
17
+ pass
18
+ ```'''
19
+
20
+ valid_colors = { # r, g, b
21
+ 'aliceblue': (240, 248, 255), 'antiquewhite': (250, 235, 215), 'aqua': (0, 255, 255),
22
+ 'aquamarine': (127, 255, 212), 'azure': (240, 255, 255), 'beige': (245, 245, 220),
23
+ 'bisque': (255, 228, 196), 'black': (0, 0, 0), 'blanchedalmond': (255, 235, 205), 'blue': (0, 0, 255),
24
+ 'blueviolet': (138, 43, 226), 'brown': (165, 42, 42), 'burlywood': (222, 184, 135),
25
+ 'cadetblue': (95, 158, 160), 'chartreuse': (127, 255, 0), 'chocolate': (210, 105, 30),
26
+ 'coral': (255, 127, 80), 'cornflowerblue': (100, 149, 237), 'cornsilk': (255, 248, 220),
27
+ 'crimson': (220, 20, 60), 'cyan': (0, 255, 255), 'darkblue': (0, 0, 139), 'darkcyan': (0, 139, 139),
28
+ 'darkgoldenrod': (184, 134, 11), 'darkgray': (169, 169, 169), 'darkgrey': (169, 169, 169),
29
+ 'darkgreen': (0, 100, 0), 'darkkhaki': (189, 183, 107), 'darkmagenta': (139, 0, 139),
30
+ 'darkolivegreen': (85, 107, 47), 'darkorange': (255, 140, 0), 'darkorchid': (153, 50, 204),
31
+ 'darkred': (139, 0, 0), 'darksalmon': (233, 150, 122), 'darkseagreen': (143, 188, 143),
32
+ 'darkslateblue': (72, 61, 139), 'darkslategray': (47, 79, 79), 'darkslategrey': (47, 79, 79),
33
+ 'darkturquoise': (0, 206, 209), 'darkviolet': (148, 0, 211), 'deeppink': (255, 20, 147),
34
+ 'deepskyblue': (0, 191, 255), 'dimgray': (105, 105, 105), 'dimgrey': (105, 105, 105),
35
+ 'dodgerblue': (30, 144, 255), 'firebrick': (178, 34, 34), 'floralwhite': (255, 250, 240),
36
+ 'forestgreen': (34, 139, 34), 'fuchsia': (255, 0, 255), 'gainsboro': (220, 220, 220),
37
+ 'ghostwhite': (248, 248, 255), 'gold': (255, 215, 0), 'goldenrod': (218, 165, 32),
38
+ 'gray': (128, 128, 128), 'grey': (128, 128, 128), 'green': (0, 128, 0), 'greenyellow': (173, 255, 47),
39
+ 'honeydew': (240, 255, 240), 'hotpink': (255, 105, 180), 'indianred': (205, 92, 92),
40
+ 'indigo': (75, 0, 130), 'ivory': (255, 255, 240), 'khaki': (240, 230, 140), 'lavender': (230, 230, 250),
41
+ 'lavenderblush': (255, 240, 245), 'lawngreen': (124, 252, 0), 'lemonchiffon': (255, 250, 205),
42
+     'lightblue': (173, 216, 230), 'lightcoral': (240, 128, 128), 'lightcyan': (224, 255, 255),
+     'lightgoldenrodyellow': (250, 250, 210), 'lightgray': (211, 211, 211), 'lightgrey': (211, 211, 211),
+     'lightgreen': (144, 238, 144), 'lightpink': (255, 182, 193), 'lightsalmon': (255, 160, 122),
+     'lightseagreen': (32, 178, 170), 'lightskyblue': (135, 206, 250), 'lightslategray': (119, 136, 153),
+     'lightslategrey': (119, 136, 153), 'lightsteelblue': (176, 196, 222), 'lightyellow': (255, 255, 224),
+     'lime': (0, 255, 0), 'limegreen': (50, 205, 50), 'linen': (250, 240, 230), 'magenta': (255, 0, 255),
+     'maroon': (128, 0, 0), 'mediumaquamarine': (102, 205, 170), 'mediumblue': (0, 0, 205),
+     'mediumorchid': (186, 85, 211), 'mediumpurple': (147, 112, 219), 'mediumseagreen': (60, 179, 113),
+     'mediumslateblue': (123, 104, 238), 'mediumspringgreen': (0, 250, 154),
+     'mediumturquoise': (72, 209, 204), 'mediumvioletred': (199, 21, 133), 'midnightblue': (25, 25, 112),
+     'mintcream': (245, 255, 250), 'mistyrose': (255, 228, 225), 'moccasin': (255, 228, 181),
+     'navajowhite': (255, 222, 173), 'navy': (0, 0, 128), 'navyblue': (0, 0, 128),
+     'oldlace': (253, 245, 230), 'olive': (128, 128, 0), 'olivedrab': (107, 142, 35),
+     'orange': (255, 165, 0), 'orangered': (255, 69, 0), 'orchid': (218, 112, 214),
+     'palegoldenrod': (238, 232, 170), 'palegreen': (152, 251, 152), 'paleturquoise': (175, 238, 238),
+     'palevioletred': (219, 112, 147), 'papayawhip': (255, 239, 213), 'peachpuff': (255, 218, 185),
+     'peru': (205, 133, 63), 'pink': (255, 192, 203), 'plum': (221, 160, 221), 'powderblue': (176, 224, 230),
+     'purple': (128, 0, 128), 'rebeccapurple': (102, 51, 153), 'red': (255, 0, 0),
+     'rosybrown': (188, 143, 143), 'royalblue': (65, 105, 225), 'saddlebrown': (139, 69, 19),
+     'salmon': (250, 128, 114), 'sandybrown': (244, 164, 96), 'seagreen': (46, 139, 87),
+     'seashell': (255, 245, 238), 'sienna': (160, 82, 45), 'silver': (192, 192, 192),
+     'skyblue': (135, 206, 235), 'slateblue': (106, 90, 205), 'slategray': (112, 128, 144),
+     'slategrey': (112, 128, 144), 'snow': (255, 250, 250), 'springgreen': (0, 255, 127),
+     'steelblue': (70, 130, 180), 'tan': (210, 180, 140), 'teal': (0, 128, 128), 'thistle': (216, 191, 216),
+     'tomato': (255, 99, 71), 'turquoise': (64, 224, 208), 'violet': (238, 130, 238),
+     'wheat': (245, 222, 179), 'white': (255, 255, 255), 'whitesmoke': (245, 245, 245),
+     'yellow': (255, 255, 0), 'yellowgreen': (154, 205, 50)
+ }
+
+ valid_locations = {  # x, y in 90*90
+     'in the center': (45, 45),
+     'on the left': (15, 45),
+     'on the right': (75, 45),
+     'on the top': (45, 15),
+     'on the bottom': (45, 75),
+     'on the top-left': (15, 15),
+     'on the top-right': (75, 15),
+     'on the bottom-left': (15, 75),
+     'on the bottom-right': (75, 75)
+ }
+
+ valid_offsets = {  # x, y in 90*90
+     'no offset': (0, 0),
+     'slightly to the left': (-10, 0),
+     'slightly to the right': (10, 0),
+     'slightly to the upper': (0, -10),
+     'slightly to the lower': (0, 10),
+     'slightly to the upper-left': (-10, -10),
+     'slightly to the upper-right': (10, -10),
+     'slightly to the lower-left': (-10, 10),
+     'slightly to the lower-right': (10, 10)}
+
+ valid_areas = {  # w, h in 90*90
+     "a small square area": (50, 50),
+     "a small vertical area": (40, 60),
+     "a small horizontal area": (60, 40),
+     "a medium-sized square area": (60, 60),
+     "a medium-sized vertical area": (50, 80),
+     "a medium-sized horizontal area": (80, 50),
+     "a large square area": (70, 70),
+     "a large vertical area": (60, 90),
+     "a large horizontal area": (90, 60)
+ }
+
+
+ def closest_name(input_str, options):
+     input_str = input_str.lower()
+
+     closest_match = difflib.get_close_matches(input_str, list(options.keys()), n=1, cutoff=0.5)
+     assert isinstance(closest_match, list) and len(closest_match) > 0, f'The value [{input_str}] is not valid!'
+     result = closest_match[0]
+
+     if result != input_str:
+         print(f'Automatically corrected [{input_str}] -> [{result}].')
+
+     return result
+
+
+ def safe_str(x):
+     return x.strip(',. ') + '.'
+
+
+ def binary_nonzero_positions(n, offset=0):
+     binary_str = bin(n)[2:]
+     positions = [i + offset for i, bit in enumerate(reversed(binary_str)) if bit == '1']
+     return positions
+
+
+ class Canvas:
+     @staticmethod
+     def from_bot_response(response: str):
+         matched = re.search(r'```python\n(.*?)\n```', response, re.DOTALL)
+         assert matched, 'Response does not contain codes!'
+         code_content = matched.group(1)
+         assert 'canvas = Canvas()' in code_content, 'Code block must include valid canvas var!'
+         local_vars = {'Canvas': Canvas}
+         exec(code_content, {}, local_vars)
+         canvas = local_vars.get('canvas', None)
+         assert isinstance(canvas, Canvas), 'Code block must produce valid canvas var!'
+         return canvas
+
+     def __init__(self):
+         self.components = []
+         self.color = None
+         self.record_tags = True
+         self.prefixes = []
+         self.suffixes = []
+         return
+
+     def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str,
+                                HTML_web_color_name: str):
+         assert isinstance(description, str), 'Global description is not valid!'
+         assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
+             'Global detailed_descriptions is not valid!'
+         assert isinstance(tags, str), 'Global tags is not valid!'
+
+         HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
+         self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
+
+         self.prefixes = ["Generate a suitable background for a product:"] + [description]
+         self.suffixes = detailed_descriptions
+
+         if self.record_tags:
+             self.suffixes = self.suffixes + [tags]
+
+         self.prefixes = [safe_str(x) for x in self.prefixes]
+         self.suffixes = [safe_str(x) for x in self.suffixes]
+
+         return
+
+     def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
+                               detailed_descriptions: list[str], tags: str, atmosphere: str, style: str,
+                               quality_meta: str, HTML_web_color_name: str):
+         assert isinstance(description, str), 'Local description is wrong!'
+         assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
+             f'The distance_to_viewer for [{description}] is not a positive float number!'
+         assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
+             f'The detailed_descriptions for [{description}] is not valid!'
+         assert isinstance(tags, str), f'The tags for [{description}] is not valid!'
+         assert isinstance(atmosphere, str), f'The atmosphere for [{description}] is not valid!'
+         assert isinstance(style, str), f'The style for [{description}] is not valid!'
+         assert isinstance(quality_meta, str), f'The quality_meta for [{description}] is not valid!'
+
+         location = closest_name(location, valid_locations)
+         offset = closest_name(offset, valid_offsets)
+         area = closest_name(area, valid_areas)
+         HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
+
+         xb, yb = valid_locations[location]
+         xo, yo = valid_offsets[offset]
+         w, h = valid_areas[area]
+         rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
+         rect = [max(0, min(90, i)) for i in rect]
+         color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
+
+         prefixes = self.prefixes + [description]
+         suffixes = detailed_descriptions
+
+         if self.record_tags:
+             suffixes = suffixes + [tags, atmosphere, style, quality_meta]
+
+         prefixes = [safe_str(x) for x in prefixes]
+         suffixes = [safe_str(x) for x in suffixes]
+
+         self.components.append(dict(
+             rect=rect,
+             distance_to_viewer=distance_to_viewer,
+             color=color,
+             prefixes=prefixes,
+             suffixes=suffixes
+         ))
+
+         return
+
+     def process(self):
+         # sort components
+         self.components = sorted(self.components, key=lambda x: x['distance_to_viewer'], reverse=True)
+
+         # compute initial latent
+         initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
+
+         for component in self.components:
+             a, b, c, d = component['rect']
+             initial_latent[a:b, c:d] = 0.7 * component['color'] + 0.3 * initial_latent[a:b, c:d]
+
+         initial_latent = initial_latent.clip(0, 255).astype(np.uint8)
+
+         # compute conditions
+
+         bag_of_conditions = [
+             dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes)
+         ]
+
+         for i, component in enumerate(self.components):
+             a, b, c, d = component['rect']
+             m = np.zeros(shape=(90, 90), dtype=np.float32)
+             m[a:b, c:d] = 1.0
+             bag_of_conditions.append(dict(
+                 mask=m,
+                 prefixes=component['prefixes'],
+                 suffixes=component['suffixes']
+             ))
+
+         return dict(
+             initial_latent=initial_latent,
+             bag_of_conditions=bag_of_conditions,
+         )
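For orientation, the Canvas API above is typically driven as in the following minimal sketch. This is illustrative only: the descriptions, tags and color names are placeholder values, not strings taken from this repository.

# Illustrative usage sketch of lib_omost/canvas.py (placeholder values throughout).
from lib_omost.canvas import Canvas

canvas = Canvas()
canvas.set_global_description(
    description='A perfume bottle on a marble table',
    detailed_descriptions=['soft window light', 'shallow depth of field'],
    tags='product photo, studio, marble',
    HTML_web_color_name='lightgray',
)
canvas.add_local_description(
    location='in the center', offset='no offset', area='a medium-sized square area',
    distance_to_viewer=1.0, description='the perfume bottle',
    detailed_descriptions=['a glass bottle with a gold cap'],
    tags='perfume, glass, gold', atmosphere='calm and elegant', style='studio photography',
    quality_meta='high quality', HTML_web_color_name='lightyellow',
)
canvas_outputs = canvas.process()  # dict with 'initial_latent' (90x90x3 uint8) and 'bag_of_conditions'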
lib_omost/pipeline.py ADDED
@@ -0,0 +1,445 @@
+ import numpy as np
+ import copy
+
+ from tqdm.auto import trange
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import *
+ from diffusers.models.transformers import Transformer2DModel
+
+
+ original_Transformer2DModel_forward = Transformer2DModel.forward
+
+
+ def hacked_Transformer2DModel_forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         timestep: Optional[torch.LongTensor] = None,
+         added_cond_kwargs: Dict[str, torch.Tensor] = None,
+         class_labels: Optional[torch.LongTensor] = None,
+         cross_attention_kwargs: Dict[str, Any] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+ ):
+     cross_attention_kwargs = cross_attention_kwargs or {}
+     cross_attention_kwargs['hidden_states_original_shape'] = hidden_states.shape
+     return original_Transformer2DModel_forward(
+         self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, class_labels, cross_attention_kwargs,
+         attention_mask, encoder_attention_mask, return_dict)
+
+
+ Transformer2DModel.forward = hacked_Transformer2DModel_forward
+
+
+ @torch.no_grad()
+ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
+     """DPM-Solver++(2M)."""
+     extra_args = {} if extra_args is None else extra_args
+     s_in = x.new_ones([x.shape[0]])
+     sigma_fn = lambda t: t.neg().exp()
+     t_fn = lambda sigma: sigma.log().neg()
+     old_denoised = None
+
+     for i in trange(len(sigmas) - 1, disable=disable):
+         denoised = model(x, sigmas[i] * s_in, **extra_args)
+         if callback is not None:
+             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+         t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
+         h = t_next - t
+         if old_denoised is None or sigmas[i + 1] == 0:
+             x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
+         else:
+             h_last = t - t_fn(sigmas[i - 1])
+             r = h_last / h
+             denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
+             x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
+         old_denoised = denoised
+     return x
+
+
+ class KModel:
+     def __init__(self, unet, timesteps=1000, linear_start=0.00085, linear_end=0.012):
+         betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=torch.float64) ** 2
+         alphas = 1. - betas
+         alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
+
+         self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+         self.log_sigmas = self.sigmas.log()
+         self.sigma_data = 1.0
+         self.unet = unet
+         return
+
+     @property
+     def sigma_min(self):
+         return self.sigmas[0]
+
+     @property
+     def sigma_max(self):
+         return self.sigmas[-1]
+
+     def timestep(self, sigma):
+         log_sigma = sigma.log()
+         dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+         return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
+
+     def get_sigmas_karras(self, n, rho=7.):
+         ramp = torch.linspace(0, 1, n)
+         min_inv_rho = self.sigma_min ** (1 / rho)
+         max_inv_rho = self.sigma_max ** (1 / rho)
+         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+         return torch.cat([sigmas, sigmas.new_zeros([1])])
+
+     def __call__(self, x, sigma, **extra_args):
+         x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5
+         t = self.timestep(sigma)
+         cfg_scale = extra_args['cfg_scale']
+         eps_positive = self.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0]
+         eps_negative = self.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0]
+         noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative)
+         return x - noise_pred * sigma[:, None, None, None]
+
+
+ class OmostSelfAttnProcessor:
+     def __call__(self, attn, hidden_states, encoder_hidden_states, hidden_states_original_shape, *args, **kwargs):
+         batch_size, sequence_length, _ = hidden_states.shape
+
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(hidden_states)
+         value = attn.to_v(hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         hidden_states = torch.nn.functional.scaled_dot_product_attention(
+             query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+         hidden_states = attn.to_out[0](hidden_states)
+         hidden_states = attn.to_out[1](hidden_states)
+
+         return hidden_states
+
+
+ class OmostCrossAttnProcessor:
+     def __call__(self, attn, hidden_states, encoder_hidden_states, hidden_states_original_shape, *args, **kwargs):
+         B, C, H, W = hidden_states_original_shape
+
+         conds = []
+         masks = []
+
+         for m, c in encoder_hidden_states:
+             m = torch.nn.functional.interpolate(m[None, None, :, :], (H, W), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, c.size(1))
+             conds.append(c)
+             masks.append(m)
+
+         conds = torch.cat(conds, dim=1)
+         masks = torch.cat(masks, dim=1)
+
+         mask_bool = masks > 0.5
+         mask_scale = (H * W) / torch.sum(masks, dim=0, keepdim=True)
+
+         batch_size, sequence_length, _ = conds.shape
+
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(conds)
+         value = attn.to_v(conds)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         mask_bool = mask_bool[None, None, :, :].repeat(query.size(0), query.size(1), 1, 1)
+         mask_scale = mask_scale[None, None, :, :].repeat(query.size(0), query.size(1), 1, 1)
+
+         sim = query @ key.transpose(-2, -1) * attn.scale
+         sim = sim * mask_scale.to(sim)
+         sim.masked_fill_(mask_bool.logical_not(), float("-inf"))
+         sim = sim.softmax(dim=-1)
+
+         h = sim @ value
+         h = h.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+         h = attn.to_out[0](h)
+         h = attn.to_out[1](h)
+         return h
+
+
+ class StableDiffusionXLOmostPipeline(StableDiffusionXLImg2ImgPipeline):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.k_model = KModel(unet=self.unet)
+
+         attn_procs = {}
+         for name in self.unet.attn_processors.keys():
+             if name.endswith("attn2.processor"):
+                 attn_procs[name] = OmostCrossAttnProcessor()
+             else:
+                 attn_procs[name] = OmostSelfAttnProcessor()
+
+         self.unet.set_attn_processor(attn_procs)
+         return
+
+     @torch.inference_mode()
+     def encode_bag_of_subprompts_greedy(self, prefixes: list[str], suffixes: list[str]):
+         device = self.text_encoder.device
+
+         @torch.inference_mode()
+         def greedy_partition(items, max_sum):
+             bags = []
+             current_bag = []
+             current_sum = 0
+
+             for item in items:
+                 num = item['length']
+                 if current_sum + num > max_sum:
+                     if current_bag:
+                         bags.append(current_bag)
+                     current_bag = [item]
+                     current_sum = num
+                 else:
+                     current_bag.append(item)
+                     current_sum += num
+
+             if current_bag:
+                 bags.append(current_bag)
+
+             return bags
+
+         @torch.inference_mode()
+         def get_77_tokens_in_torch(subprompt_inds, tokenizer):
+             # Note that all subprompts are theoretically less than 75 tokens (without bos/eos)
+             result = [tokenizer.bos_token_id] + subprompt_inds[:75] + [tokenizer.eos_token_id] + [tokenizer.pad_token_id] * 75
+             result = result[:77]
+             result = torch.tensor([result]).to(device=device, dtype=torch.int64)
+             return result
+
+         @torch.inference_mode()
+         def merge_with_prefix(bag):
+             merged_ids_t1 = copy.deepcopy(prefix_ids_t1)
+             merged_ids_t2 = copy.deepcopy(prefix_ids_t2)
+
+             for item in bag:
+                 merged_ids_t1.extend(item['ids_t1'])
+                 merged_ids_t2.extend(item['ids_t2'])
+
+             return dict(
+                 ids_t1=get_77_tokens_in_torch(merged_ids_t1, self.tokenizer),
+                 ids_t2=get_77_tokens_in_torch(merged_ids_t2, self.tokenizer_2)
+             )
+
+         @torch.inference_mode()
+         def double_encode(pair_of_inds):
+             inds = [pair_of_inds['ids_t1'], pair_of_inds['ids_t2']]
+             text_encoders = [self.text_encoder, self.text_encoder_2]
+
+             pooled_prompt_embeds = None
+             prompt_embeds_list = []
+
+             for text_input_ids, text_encoder in zip(inds, text_encoders):
+                 prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)
+
+                 # Only last pooler_output is needed
+                 pooled_prompt_embeds = prompt_embeds.pooler_output
+
+                 # "2" because SDXL always indexes from the penultimate layer.
+                 prompt_embeds = prompt_embeds.hidden_states[-2]
+                 prompt_embeds_list.append(prompt_embeds)
+
+             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+             return prompt_embeds, pooled_prompt_embeds
+
+         # Begin with tokenizing prefixes
+
+         prefix_length = 0
+         prefix_ids_t1 = []
+         prefix_ids_t2 = []
+
+         for prefix in prefixes:
+             ids_t1 = self.tokenizer(prefix, truncation=False, add_special_tokens=False).input_ids
+             ids_t2 = self.tokenizer_2(prefix, truncation=False, add_special_tokens=False).input_ids
+             assert len(ids_t1) == len(ids_t2)
+             prefix_length += len(ids_t1)
+             prefix_ids_t1 += ids_t1
+             prefix_ids_t2 += ids_t2
+
+         # Then tokenizing suffixes
+
+         allowed_suffix_length = 75 - prefix_length
+         suffix_targets = []
+
+         for subprompt in suffixes:
+             # Note that all subprompts are theoretically less than 75 tokens (without bos/eos)
+             # So we can safely just crop them to 75
+             ids_t1 = self.tokenizer(subprompt, truncation=False, add_special_tokens=False).input_ids[:75]
+             ids_t2 = self.tokenizer_2(subprompt, truncation=False, add_special_tokens=False).input_ids[:75]
+             assert len(ids_t1) == len(ids_t2)
+             suffix_targets.append(dict(
+                 length=len(ids_t1),
+                 ids_t1=ids_t1,
+                 ids_t2=ids_t2
+             ))
+
+         # Then merge prefix and suffix tokens
+
+         suffix_targets = greedy_partition(suffix_targets, max_sum=allowed_suffix_length)
+         targets = [merge_with_prefix(b) for b in suffix_targets]
+
+         # Encode!
+
+         conds, poolers = [], []
+
+         for target in targets:
+             cond, pooler = double_encode(target)
+             conds.append(cond)
+             poolers.append(pooler)
+
+         conds_merged = torch.concat(conds, dim=1)
+         poolers_merged = poolers[0]
+
+         return dict(cond=conds_merged, pooler=poolers_merged)
+
+     @torch.inference_mode()
+     def all_conds_from_canvas(self, canvas_outputs, negative_prompt):
+         mask_all = torch.ones(size=(90, 90), dtype=torch.float32)
+         negative_cond, negative_pooler = self.encode_cropped_prompt_77tokens(negative_prompt)
+         negative_result = [(mask_all, negative_cond)]
+
+         positive_result = []
+         positive_pooler = None
+
+         for item in canvas_outputs['bag_of_conditions']:
+             current_mask = torch.from_numpy(item['mask']).to(torch.float32)
+             current_prefixes = item['prefixes']
+             current_suffixes = item['suffixes']
+
+             current_cond = self.encode_bag_of_subprompts_greedy(prefixes=current_prefixes, suffixes=current_suffixes)
+
+             if positive_pooler is None:
+                 positive_pooler = current_cond['pooler']
+
+             positive_result.append((current_mask, current_cond['cond']))
+
+         return positive_result, positive_pooler, negative_result, negative_pooler
+
+     @torch.inference_mode()
+     def encode_cropped_prompt_77tokens(self, prompt: str):
+         device = self.text_encoder.device
+         tokenizers = [self.tokenizer, self.tokenizer_2]
+         text_encoders = [self.text_encoder, self.text_encoder_2]
+
+         pooled_prompt_embeds = None
+         prompt_embeds_list = []
+
+         for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+             text_input_ids = tokenizer(
+                 prompt,
+                 padding="max_length",
+                 max_length=tokenizer.model_max_length,
+                 truncation=True,
+                 return_tensors="pt",
+             ).input_ids
+
+             prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+             # Only last pooler_output is needed
+             pooled_prompt_embeds = prompt_embeds.pooler_output
+
+             # "2" because SDXL always indexes from the penultimate layer.
+             prompt_embeds = prompt_embeds.hidden_states[-2]
+             prompt_embeds_list.append(prompt_embeds)
+
+         prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+         prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+         return prompt_embeds, pooled_prompt_embeds
+
+     @torch.inference_mode()
+     def __call__(
+             self,
+             initial_latent: torch.FloatTensor = None,
+             strength: float = 1.0,
+             num_inference_steps: int = 25,
+             guidance_scale: float = 5.0,
+             batch_size: Optional[int] = 1,
+             generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+             prompt_embeds: Optional[torch.FloatTensor] = None,
+             negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+             pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+             negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+             cross_attention_kwargs: Optional[dict] = None,
+             product_image: Optional[torch.FloatTensor] = None,
+     ):
+
+         device = self.unet.device
+         cross_attention_kwargs = cross_attention_kwargs or {}
+
+         # Sigmas
+
+         sigmas = self.k_model.get_sigmas_karras(int(num_inference_steps / strength))
+         sigmas = sigmas[-(num_inference_steps + 1):].to(device)
+
+         # Initial latents
+
+         _, C, H, W = initial_latent.shape
+         noise = randn_tensor((batch_size, C, H, W), generator=generator, device=device, dtype=self.unet.dtype)
+         latents = initial_latent.to(noise) + noise * sigmas[0].to(noise)
+
+         # Shape
+
+         height, width = latents.shape[-2:]
+         height = height * self.vae_scale_factor
+         width = width * self.vae_scale_factor
+
+         add_time_ids = list((height, width) + (0, 0) + (height, width))
+         add_time_ids = torch.tensor([add_time_ids], dtype=self.unet.dtype)
+         add_neg_time_ids = add_time_ids.clone()
+
+         # Batch
+
+         latents = latents.to(device)
+         add_time_ids = add_time_ids.repeat(batch_size, 1).to(device)
+         add_neg_time_ids = add_neg_time_ids.repeat(batch_size, 1).to(device)
+         prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(noise)) for k, v in prompt_embeds]
+         negative_prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(noise)) for k, v in negative_prompt_embeds]
+         pooled_prompt_embeds = pooled_prompt_embeds.repeat(batch_size, 1).to(noise)
+         negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1).to(noise)
+
+         # Feeds
+
+         sampler_kwargs = dict(
+             cfg_scale=guidance_scale,
+             positive=dict(
+                 encoder_hidden_states=prompt_embeds,
+                 added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
+                 cross_attention_kwargs=cross_attention_kwargs
+             ),
+             negative=dict(
+                 encoder_hidden_states=negative_prompt_embeds,
+                 added_cond_kwargs={"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids},
+                 cross_attention_kwargs=cross_attention_kwargs
+             )
+         )
+
+         # Sample
+
+         results = sample_dpmpp_2m(self.k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False)
+
+         if product_image is not None:
+             # Encode product image
+             product_latent = self.vae.encode(product_image.to(device)).latent_dist.sample()
+             product_latent = product_latent * self.vae.config.scaling_factor
+
+             # Combine product latent with generated background
+             alpha = 0.7  # Adjust this value to control the blending
+             results = alpha * results + (1 - alpha) * product_latent
+
+         return StableDiffusionXLPipelineOutput(images=results)
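Taken together, the two modules are wired up roughly as follows, presumably inside app.py. This is a hedged sketch rather than code from this commit: the SDXL checkpoint name, the negative prompt, and the `llm_response_text` / `initial_latent` variables are placeholder assumptions.

# Hypothetical wiring of lib_omost/canvas.py and lib_omost/pipeline.py (placeholder names throughout).
import torch
from lib_omost.canvas import Canvas
from lib_omost.pipeline import StableDiffusionXLOmostPipeline

pipeline = StableDiffusionXLOmostPipeline.from_pretrained(
    'SG161222/RealVisXL_V4.0', torch_dtype=torch.float16).to('cuda')  # any SDXL checkpoint

canvas = Canvas.from_bot_response(llm_response_text)  # LLM reply containing a fenced python block
canvas_outputs = canvas.process()

positive, positive_pooler, negative, negative_pooler = pipeline.all_conds_from_canvas(
    canvas_outputs, negative_prompt='lowres, bad anatomy, worst quality')

# initial_latent would be canvas_outputs['initial_latent'], resized and VAE-encoded to latent space.
latents = pipeline(
    initial_latent=initial_latent,
    strength=1.0,
    num_inference_steps=25,
    guidance_scale=5.0,
    prompt_embeds=positive,
    negative_prompt_embeds=negative,
    pooled_prompt_embeds=positive_pooler,
    negative_pooled_prompt_embeds=negative_pooler,
).images  # still latents; decode with pipeline.vae to get pixel images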
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ diffusers==0.28.0
+ transformers==4.41.1
+ gradio==4.31.5
+ accelerate==0.30.1
+ protobuf==3.20
+ opencv-python
+ tensorboardX
+ safetensors
+ pillow
+ einops
+ torch
+ peft
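With these pins in place, the usual setup for a Space like this is simply:

pip install -r requirements.txt

after which the Gradio app in app.py (added in this same commit) can presumably be launched with `python app.py`.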