diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..e2e68308705c59d0e047f13c1116ed223f9b7fec
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,33 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..75c45fdcb59a4c94649d0e5fd317f22969ecaf3d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+.idea
+pretrained_models
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b4b944d4a3d42d900933fa1cad13c099ad8c1a07
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,400 @@
+
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More_considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d1aae787b021c47ad6914bdc857cf7016244a496
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+---
+title: Diffusion Transformers (DiT)
+emoji: 🚀
+colorFrom: yellow
+colorTo: green
+sdk: gradio
+sdk_version: 3.6
+app_file: app.py
+pinned: false
+license: cc-by-nc-4.0
+---
+
+The code and model weights are licensed under CC-BY-NC. See LICENSE.txt for details.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffac77c274d3eed27e0728636331397710993921
--- /dev/null
+++ b/app.py
@@ -0,0 +1,136 @@
+import torch
+from torchvision.utils import make_grid
+import math
+from PIL import Image
+from diffusion import create_diffusion
+from diffusers.models import AutoencoderKL
+import gradio as gr
+from imagenet_class_data import IMAGENET_1K_CLASSES
+from download import find_model
+from models import DiT_XL_2
+
+
+def load_model(image_size=256):
+ assert image_size in [256, 512]
+ latent_size = image_size // 8
+ model = DiT_XL_2(input_size=latent_size).to(device)
+ state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+
+torch.set_grad_enabled(False)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = load_model(image_size=256)
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
+current_image_size = 256
+current_vae_model = "stabilityai/sd-vae-ft-mse"
+
+
+def generate(image_size, vae_model, class_label, cfg_scale, num_sampling_steps, n, seed):
+ image_size = int(image_size.split("x")[0])
+ global current_image_size
+ if image_size != current_image_size:
+ global model
+ del model
+ # if device == "cuda":
+ # torch.cuda.empty_cache()
+ model = load_model(image_size=image_size)
+ current_image_size = image_size
+
+ global current_vae_model
+ if vae_model != current_vae_model:
+ global vae
+ if device == "cuda":
+ vae.to("cpu")
+ del vae
+ vae = AutoencoderKL.from_pretrained(vae_model).to(device)
+
+ # Seed PyTorch:
+ torch.manual_seed(seed)
+
+ # Setup diffusion
+ diffusion = create_diffusion(str(num_sampling_steps))
+
+ # Create sampling noise:
+ latent_size = image_size // 8
+ z = torch.randn(n, 4, latent_size, latent_size, device=device)
+ y = torch.tensor([class_label] * n, device=device)
+
+ # Setup classifier-free guidance:
+ z = torch.cat([z, z], 0)
+ y_null = torch.tensor([1000] * n, device=device)
+ y = torch.cat([y, y_null], 0)
+ model_kwargs = dict(y=y, cfg_scale=cfg_scale)
+
+ # Sample images:
+ samples = diffusion.p_sample_loop(
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
+ )
+ samples, _ = samples.chunk(2, dim=0) # Remove null class samples
+ samples = vae.decode(samples / 0.18215).sample
+
+ # Convert to PIL.Image format:
+ samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
+ samples = [Image.fromarray(sample) for sample in samples]
+ return samples
+
+
+description = '''This is a demo of our DiT image generation models. DiTs are a new class of diffusion models with
+transformer backbones. They are class-conditional models trained on ImageNet-1K, and they outperform prior DDPMs.'''
+
+project_links = '''
+
+Project Page ·
+Colab ·
+Paper ·
+GitHub
'''
+
+examples = [
+ ["512x512", "stabilityai/sd-vae-ft-mse", "golden retriever", 4.0, 200, 4, 1000],
+ ["512x512", "stabilityai/sd-vae-ft-mse", "macaw", 4.0, 200, 4, 1],
+ ["512x512", "stabilityai/sd-vae-ft-mse", "balloon", 4.0, 200, 4, 1],
+ ["512x512", "stabilityai/sd-vae-ft-mse", "cliff, drop, drop-off", 4.0, 200, 4, 7],
+ ["512x512", "stabilityai/sd-vae-ft-mse", "Pembroke, Pembroke Welsh corgi", 4.0, 200, 4, 0],
+ ["256x256", "stabilityai/sd-vae-ft-mse", "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", 4.0, 200,
+ 4, 1],
+ ["256x256", "stabilityai/sd-vae-ft-mse", "teddy, teddy bear", 4.0, 200, 4, 3],
+ ["256x256", "stabilityai/sd-vae-ft-mse", "cheeseburger", 4.0, 200, 4, 2],
+
+]
+
+with gr.Blocks() as demo:
+ gr.Markdown("Scalable Diffusion Models with Transformers (DiT)
")
+ gr.Markdown(project_links)
+ gr.Markdown(description)
+
+ with gr.Tabs():
+ with gr.TabItem('Generate'):
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ image_size = gr.inputs.Radio(choices=["256x256", "512x512"], default="256x256", label='DiT Model Resolution')
+ vae_model = gr.inputs.Radio(choices=["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"],
+ default="stabilityai/sd-vae-ft-mse", label='VAE Decoder')
+ with gr.Row():
+ i1k_class = gr.inputs.Dropdown(
+ list(IMAGENET_1K_CLASSES.values()),
+ default='golden retriever',
+ type="index", label='ImageNet-1K Class'
+ )
+ cfg_scale = gr.inputs.Slider(minimum=1, maximum=25, step=0.1, default=4.0, label='Classifier-free Guidance Scale')
+ steps = gr.inputs.Slider(minimum=4, maximum=1000, step=1, default=75, label='Sampling Steps')
+ n = gr.inputs.Slider(minimum=1, maximum=16, step=1, default=1, label='Number of Samples')
+ seed = gr.inputs.Number(default=0, label='Seed')
+ button = gr.Button("Generate", variant="primary")
+ with gr.Column():
+ output = gr.Gallery(label='Generated Images').style(grid=[2], height="auto")
+ button.click(generate, inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, n, seed], outputs=[output])
+ with gr.Row():
+ ex = gr.Examples(examples=examples, fn=generate,
+ inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, n, seed],
+ outputs=[output],
+ cache_examples=True)
+
+ demo.launch()
diff --git a/diffusion/__init__.py b/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c536a98da92c4d051458803737661e5ecf974c2
--- /dev/null
+++ b/diffusion/__init__.py
@@ -0,0 +1,46 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from . import gaussian_diffusion as gd
+from .respace import SpacedDiffusion, space_timesteps
+
+
+def create_diffusion(
+ timestep_respacing,
+ noise_schedule="linear",
+ use_kl=False,
+ sigma_small=False,
+ predict_xstart=False,
+ learn_sigma=True,
+ rescale_learned_sigmas=False,
+ diffusion_steps=1000
+):
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
+ if use_kl:
+ loss_type = gd.LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = gd.LossType.RESCALED_MSE
+ else:
+ loss_type = gd.LossType.MSE
+ if timestep_respacing is None or timestep_respacing == "":
+ timestep_respacing = [diffusion_steps]
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
+ ),
+ model_var_type=(
+ (
+ gd.ModelVarType.FIXED_LARGE
+ if not sigma_small
+ else gd.ModelVarType.FIXED_SMALL
+ )
+ if not learn_sigma
+ else gd.ModelVarType.LEARNED_RANGE
+ ),
+ loss_type=loss_type
+ # rescale_timesteps=rescale_timesteps,
+ )
diff --git a/diffusion/__pycache__/__init__.cpython-39.pyc b/diffusion/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b5c8ed4caa8afc6c000283303604810c8f21509
Binary files /dev/null and b/diffusion/__pycache__/__init__.cpython-39.pyc differ
diff --git a/diffusion/__pycache__/diffusion_utils.cpython-39.pyc b/diffusion/__pycache__/diffusion_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..568e49b4392362148d812797b29ff8625a22d9fd
Binary files /dev/null and b/diffusion/__pycache__/diffusion_utils.cpython-39.pyc differ
diff --git a/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc b/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5577bd197c50b380f1d4ddd3443e9ca6d87a924c
Binary files /dev/null and b/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc differ
diff --git a/diffusion/__pycache__/respace.cpython-39.pyc b/diffusion/__pycache__/respace.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8cfa72a7618953768035871077c702652937b3c9
Binary files /dev/null and b/diffusion/__pycache__/respace.cpython-39.pyc differ
diff --git a/diffusion/diffusion_utils.py b/diffusion/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060
--- /dev/null
+++ b/diffusion/diffusion_utils.py
@@ -0,0 +1,88 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import torch as th
+import numpy as np
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, th.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + th.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def continuous_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a continuous Gaussian distribution.
+ :param x: the targets
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ normalized_x = centered_x * inv_stdv
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
+ return log_probs
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccbcefeca4348e2e627d22723b50492c894e66ba
--- /dev/null
+++ b/diffusion/gaussian_diffusion.py
@@ -0,0 +1,873 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+
+import math
+
+import numpy as np
+import torch as th
+import enum
+
+from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+ return betas
+
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ """
+ This is the deprecated API for creating beta schedules.
+ See get_named_beta_schedule() for the new library of schedules.
+ """
+ if beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start ** 0.5,
+ beta_end ** 0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "warmup10":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+ elif beta_schedule == "warmup50":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ return get_beta_schedule(
+ "linear",
+ beta_start=scale * 0.0001,
+ beta_end=scale * 0.02,
+ num_diffusion_timesteps=num_diffusion_timesteps,
+ )
+ elif schedule_name == "squaredcos_cap_v2":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+ Original ported from this codebase:
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type
+ ):
+
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ ) if len(self.posterior_variance) > 1 else np.array([])
+
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+ )
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+ In other words, sample from q(x_t | x_0).
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+ q(x_{t-1} | x_t, x_0)
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, C = x.shape[:2]
+ assert t.shape == (B,)
+ model_output = model(x, t, **model_kwargs)
+ if isinstance(model_output, tuple):
+ model_output, extra = model_output
+ else:
+ extra = None
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ "extra": extra,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, **model_kwargs)
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+ See condition_mean() for details on cond_fn.
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ """
+ Get a term for the variational lower-bound.
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+ """
+ Compute training losses for a single timestep.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+
+ terms = {}
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+ model_output = model(x_t, t, **model_kwargs)
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ target = {
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ assert model_output.shape == target.shape == x_start.shape
+ terms["mse"] = mean_flat((target - model_output) ** 2)
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
diff --git a/diffusion/respace.py b/diffusion/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a2cc0435d1ace54466585db9043b284973d454e
--- /dev/null
+++ b/diffusion/respace.py
@@ -0,0 +1,129 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def training_losses(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(
+ model, self.timestep_map, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ # self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ # if self.rescale_timesteps:
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
diff --git a/diffusion/timestep_sampler.py b/diffusion/timestep_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f
--- /dev/null
+++ b/diffusion/timestep_sampler.py
@@ -0,0 +1,150 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch as th
+import torch.distributed as dist
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ elif name == "loss-second-moment":
+ return LossSecondMomentResampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+ :param batch_size: the number of timesteps.
+ :param device: the torch device to save to.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / np.sum(w)
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+ indices = th.from_numpy(indices_np).long().to(device)
+ weights_np = 1 / (len(p) * p[indices_np])
+ weights = th.from_numpy(weights_np).float().to(device)
+ return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ self.diffusion = diffusion
+ self._weights = np.ones([diffusion.num_timesteps])
+
+ def weights(self):
+ return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+ def update_with_local_losses(self, local_ts, local_losses):
+ """
+ Update the reweighting using losses from a model.
+ Call this method from each rank with a batch of timesteps and the
+ corresponding losses for each of those timesteps.
+ This method will perform synchronization to make sure all of the ranks
+ maintain the exact same reweighting.
+ :param local_ts: an integer Tensor of timesteps.
+ :param local_losses: a 1D Tensor of losses.
+ """
+ batch_sizes = [
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
+ for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(
+ batch_sizes,
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+ )
+
+ # Pad all_gather batches to be the maximum batch size.
+ batch_sizes = [x.item() for x in batch_sizes]
+ max_bs = max(batch_sizes)
+
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+ dist.all_gather(timestep_batches, local_ts)
+ dist.all_gather(loss_batches, local_losses)
+ timesteps = [
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+ ]
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+ self.update_with_all_losses(timesteps, losses)
+
+ @abstractmethod
+ def update_with_all_losses(self, ts, losses):
+ """
+ Update the reweighting using losses from a model.
+ Sub-classes should override this method to update the reweighting
+ using losses from the model.
+ This method directly updates the reweighting without synchronizing
+ between workers. It is called by update_with_local_losses from all
+ ranks with identical arguments. Thus, it should have deterministic
+ behavior to maintain state across workers.
+ :param ts: a list of int timesteps.
+ :param losses: a list of float losses, one per timestep.
+ """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+ self.diffusion = diffusion
+ self.history_per_term = history_per_term
+ self.uniform_prob = uniform_prob
+ self._loss_history = np.zeros(
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
+ )
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
+
+ def weights(self):
+ if not self._warmed_up():
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+ weights /= np.sum(weights)
+ weights *= 1 - self.uniform_prob
+ weights += self.uniform_prob / len(weights)
+ return weights
+
+ def update_with_all_losses(self, ts, losses):
+ for t, loss in zip(ts, losses):
+ if self._loss_counts[t] == self.history_per_term:
+ # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
+ self._loss_history[t, -1] = loss
+ else:
+ self._loss_history[t, self._loss_counts[t]] = loss
+ self._loss_counts[t] += 1
+
+ def _warmed_up(self):
+ return (self._loss_counts == self.history_per_term).all()
diff --git a/download.py b/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfdcf3b3e7481a84b7ff30a5f579195f271d21ef
--- /dev/null
+++ b/download.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Functions for downloading pre-trained DiT models
+"""
+from torchvision.datasets.utils import download_url
+import torch
+import os
+
+
+pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'}
+
+
+def find_model(model_name):
+ """
+ Finds a pre-trained G.pt model, downloading it if necessary. Alternatively, loads a model from a local path.
+ """
+ if model_name in pretrained_models: # Find/download our pre-trained G.pt checkpoints
+ return download_model(model_name)
+ else: # Load a custom DiT checkpoint:
+ assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}'
+ return torch.load(model_name, map_location=lambda storage, loc: storage)
+
+
+def download_model(model_name):
+ """
+ Downloads a pre-trained DiT model from the web.
+ """
+ assert model_name in pretrained_models
+ local_path = f'pretrained_models/{model_name}'
+ if not os.path.isfile(local_path):
+ os.makedirs('pretrained_models', exist_ok=True)
+ web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}'
+ download_url(web_path, 'pretrained_models')
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
+ return model
+
+
+if __name__ == "__main__":
+ # Download all DiT checkpoints
+ for model in pretrained_models:
+ download_model(model)
+ print('Done.')
diff --git a/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/captions.json b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/captions.json
new file mode 100644
index 0000000000000000000000000000000000000000..42a01d277fc24c6b3ff121343348579bd6e75479
--- /dev/null
+++ b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/captions.json
@@ -0,0 +1 @@
+{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpe9n9k0az54l_f3gw.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpmoaadwkx_iwdl820.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpurgtzs1y9ie5h2s2.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpjtpzxdb26dcckrfs.png": null}
\ No newline at end of file
diff --git a/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpe9n9k0az54l_f3gw.png b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpe9n9k0az54l_f3gw.png
new file mode 100644
index 0000000000000000000000000000000000000000..929ccf2b10b96d8651c23ec766d3866a99aeff46
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpe9n9k0az54l_f3gw.png differ
diff --git a/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpjtpzxdb26dcckrfs.png b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpjtpzxdb26dcckrfs.png
new file mode 100644
index 0000000000000000000000000000000000000000..d02d4de55875dc68e2af562798efc3ce9910efb4
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpjtpzxdb26dcckrfs.png differ
diff --git a/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpmoaadwkx_iwdl820.png b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpmoaadwkx_iwdl820.png
new file mode 100644
index 0000000000000000000000000000000000000000..1bccab369fb801a6b20e9e18cf82cd59eced3bb9
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpmoaadwkx_iwdl820.png differ
diff --git a/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpurgtzs1y9ie5h2s2.png b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpurgtzs1y9ie5h2s2.png
new file mode 100644
index 0000000000000000000000000000000000000000..dc29f02a7d6bd3a4d71f7c2d17121aff8692b577
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec/tmpurgtzs1y9ie5h2s2.png differ
diff --git a/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/captions.json b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/captions.json
new file mode 100644
index 0000000000000000000000000000000000000000..1893450dbb2cd025d2b1c10aaf47bd642d142151
--- /dev/null
+++ b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/captions.json
@@ -0,0 +1 @@
+{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpfz82ezx8is9yjfdi.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpff9tyvtyjf15zmm7.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpz96a09w08a8kcl84.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmppl9m3bza1t5k15ta.png": null}
\ No newline at end of file
diff --git a/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpff9tyvtyjf15zmm7.png b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpff9tyvtyjf15zmm7.png
new file mode 100644
index 0000000000000000000000000000000000000000..fe8e9dded1f9107d6c0ac527b18fa555a7dbf96d
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpff9tyvtyjf15zmm7.png differ
diff --git a/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpfz82ezx8is9yjfdi.png b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpfz82ezx8is9yjfdi.png
new file mode 100644
index 0000000000000000000000000000000000000000..ca65fb2db579c355ff026bc657f7ea0ec381b779
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpfz82ezx8is9yjfdi.png differ
diff --git a/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmppl9m3bza1t5k15ta.png b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmppl9m3bza1t5k15ta.png
new file mode 100644
index 0000000000000000000000000000000000000000..d6dac6bf8ff5203dd7229adcfbdfdcbb7f0a053b
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmppl9m3bza1t5k15ta.png differ
diff --git a/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpz96a09w08a8kcl84.png b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpz96a09w08a8kcl84.png
new file mode 100644
index 0000000000000000000000000000000000000000..7324c35c02e1a2d66951827533b8db49feb176b6
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e/tmpz96a09w08a8kcl84.png differ
diff --git a/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/captions.json b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/captions.json
new file mode 100644
index 0000000000000000000000000000000000000000..45a65e26742bc3cb66cc8eb085e895606ccb90c1
--- /dev/null
+++ b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/captions.json
@@ -0,0 +1 @@
+{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpkhiwf5_97pu17qfx.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp8pyeu6a8l_en8xim.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpipmoifyv87sepx5o.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp63ria_p2997pqmtd.png": null}
\ No newline at end of file
diff --git a/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp63ria_p2997pqmtd.png b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp63ria_p2997pqmtd.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a30e3d619f7d3f8c9ae93a644fb2329819e8a3b
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp63ria_p2997pqmtd.png differ
diff --git a/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp8pyeu6a8l_en8xim.png b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp8pyeu6a8l_en8xim.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc06ee1f0b48593fd8170164a51ce11b9a0ef212
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmp8pyeu6a8l_en8xim.png differ
diff --git a/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpipmoifyv87sepx5o.png b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpipmoifyv87sepx5o.png
new file mode 100644
index 0000000000000000000000000000000000000000..e8ff9b2de5150afebcff7c48d1de1f82c77255bf
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpipmoifyv87sepx5o.png differ
diff --git a/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpkhiwf5_97pu17qfx.png b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpkhiwf5_97pu17qfx.png
new file mode 100644
index 0000000000000000000000000000000000000000..596665ee0e55e9822e5fc3d34c369aa8552d93f7
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63/tmpkhiwf5_97pu17qfx.png differ
diff --git a/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/captions.json b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/captions.json
new file mode 100644
index 0000000000000000000000000000000000000000..cc2544c2fe11d2f7368858348cb99772daf2eab3
--- /dev/null
+++ b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/captions.json
@@ -0,0 +1 @@
+{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpgnbz7k1qi372rdol.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpv4smnu57thdhxlo5.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp3oqwhs41q0ygkzgo.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp1j0z6hrnm0vw0rai.png": null}
\ No newline at end of file
diff --git a/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp1j0z6hrnm0vw0rai.png b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp1j0z6hrnm0vw0rai.png
new file mode 100644
index 0000000000000000000000000000000000000000..c7a7cdf3f95315f8cd87c4abd8d5b8c03e30273b
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp1j0z6hrnm0vw0rai.png differ
diff --git a/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp3oqwhs41q0ygkzgo.png b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp3oqwhs41q0ygkzgo.png
new file mode 100644
index 0000000000000000000000000000000000000000..0f0458f72d60b5622e1ebb73ac7b448a148bc531
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmp3oqwhs41q0ygkzgo.png differ
diff --git a/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpgnbz7k1qi372rdol.png b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpgnbz7k1qi372rdol.png
new file mode 100644
index 0000000000000000000000000000000000000000..f2b7c1d78f55e9a31ce1cbcde57f3cbcc24df0f2
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpgnbz7k1qi372rdol.png differ
diff --git a/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpv4smnu57thdhxlo5.png b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpv4smnu57thdhxlo5.png
new file mode 100644
index 0000000000000000000000000000000000000000..d8813c0322c0a6744454eddc0cd2689f5a32d2d0
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61/tmpv4smnu57thdhxlo5.png differ
diff --git a/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/captions.json b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/captions.json
new file mode 100644
index 0000000000000000000000000000000000000000..cf5f8231d7fe826c59e62505a484fc49222d5f9b
--- /dev/null
+++ b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/captions.json
@@ -0,0 +1 @@
+{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpamh42w2dmr65v467.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp2xnh42_y5utj14_w.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp7th5inqczdoynzrl.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpnleysn66i6zycwi5.png": null}
\ No newline at end of file
diff --git a/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp2xnh42_y5utj14_w.png b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp2xnh42_y5utj14_w.png
new file mode 100644
index 0000000000000000000000000000000000000000..04fafa07b652ddc3b21ed6c535a20fe983483ffb
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp2xnh42_y5utj14_w.png differ
diff --git a/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp7th5inqczdoynzrl.png b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp7th5inqczdoynzrl.png
new file mode 100644
index 0000000000000000000000000000000000000000..4a6a3d26ccfe7c65e69e43e0dc4cf55d48da09c3
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmp7th5inqczdoynzrl.png differ
diff --git a/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpamh42w2dmr65v467.png b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpamh42w2dmr65v467.png
new file mode 100644
index 0000000000000000000000000000000000000000..ace17cf260b0726e968f6f44d3600ba17d57c79e
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpamh42w2dmr65v467.png differ
diff --git a/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpnleysn66i6zycwi5.png b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpnleysn66i6zycwi5.png
new file mode 100644
index 0000000000000000000000000000000000000000..fd043f10f30cb37029c0947e2c87b7cce58f5920
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf/tmpnleysn66i6zycwi5.png differ
diff --git a/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/captions.json b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/captions.json
new file mode 100644
index 0000000000000000000000000000000000000000..3c4754345c0b8531ad80977366f5a907fd94545c
--- /dev/null
+++ b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/captions.json
@@ -0,0 +1 @@
+{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmp_haqfqy2_0t9ntvn.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpwy2tstw8dlphar0w.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpe7h1fv6ojayij55x.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpvza140f8v_pikrxf.png": null}
\ No newline at end of file
diff --git a/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmp_haqfqy2_0t9ntvn.png b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmp_haqfqy2_0t9ntvn.png
new file mode 100644
index 0000000000000000000000000000000000000000..f3ad3af3fc31c9570d2d92829ef2f9d6d2b46349
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmp_haqfqy2_0t9ntvn.png differ
diff --git a/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpe7h1fv6ojayij55x.png b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpe7h1fv6ojayij55x.png
new file mode 100644
index 0000000000000000000000000000000000000000..84ca66ecb7c6ae9f23e9c790037e37653a86bde0
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpe7h1fv6ojayij55x.png differ
diff --git a/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpvza140f8v_pikrxf.png b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpvza140f8v_pikrxf.png
new file mode 100644
index 0000000000000000000000000000000000000000..a246622dd0beb9632ae76fdf033e9fcaa3c9ad8c
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpvza140f8v_pikrxf.png differ
diff --git a/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpwy2tstw8dlphar0w.png b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpwy2tstw8dlphar0w.png
new file mode 100644
index 0000000000000000000000000000000000000000..5e1ab2903964c4f7b964a01ff8a28f1450b3117f
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd/tmpwy2tstw8dlphar0w.png differ
diff --git a/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/captions.json b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/captions.json
new file mode 100644
index 0000000000000000000000000000000000000000..225a3463fd33f2eb8726076b6347e21039a801de
--- /dev/null
+++ b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/captions.json
@@ -0,0 +1 @@
+{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp2u8kljaep52iold0.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmpbbapqj_eo3ty6ub2.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp3wbfuj7t3s4jrwxg.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp5a1jyli246vbchhw.png": null}
\ No newline at end of file
diff --git a/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp2u8kljaep52iold0.png b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp2u8kljaep52iold0.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e23bab873c45546d249fecb845b7f5c84661742
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp2u8kljaep52iold0.png differ
diff --git a/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp3wbfuj7t3s4jrwxg.png b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp3wbfuj7t3s4jrwxg.png
new file mode 100644
index 0000000000000000000000000000000000000000..4413095f76aa125986448f93aef1971a61fa8ae0
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp3wbfuj7t3s4jrwxg.png differ
diff --git a/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp5a1jyli246vbchhw.png b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp5a1jyli246vbchhw.png
new file mode 100644
index 0000000000000000000000000000000000000000..554e274a289b0df3f0a19d57b1d37e6dfc86cd6f
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmp5a1jyli246vbchhw.png differ
diff --git a/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmpbbapqj_eo3ty6ub2.png b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmpbbapqj_eo3ty6ub2.png
new file mode 100644
index 0000000000000000000000000000000000000000..9c1bac4e9627743b5c9b2f6b620a18b244d84398
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998/tmpbbapqj_eo3ty6ub2.png differ
diff --git a/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/captions.json b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/captions.json
new file mode 100644
index 0000000000000000000000000000000000000000..51926d3808e002ed81f38a16100fd05d59dc5658
--- /dev/null
+++ b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/captions.json
@@ -0,0 +1 @@
+{"/home/peebles/demo/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmprezpqwr5516qiohq.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpgt7c7p6l8xiv4p83.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpt_dcbfedfr0d8mnn.png": null, "/home/peebles/demo/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpbrh6kbrfgw02swpp.png": null}
\ No newline at end of file
diff --git a/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpbrh6kbrfgw02swpp.png b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpbrh6kbrfgw02swpp.png
new file mode 100644
index 0000000000000000000000000000000000000000..70dc335566745a187df069aa849668217a32182c
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpbrh6kbrfgw02swpp.png differ
diff --git a/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpgt7c7p6l8xiv4p83.png b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpgt7c7p6l8xiv4p83.png
new file mode 100644
index 0000000000000000000000000000000000000000..be9c1c9e4d29a996ccd9e1220463229385c1e8bd
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpgt7c7p6l8xiv4p83.png differ
diff --git a/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmprezpqwr5516qiohq.png b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmprezpqwr5516qiohq.png
new file mode 100644
index 0000000000000000000000000000000000000000..171e084f4fd9ccc91fd90bb64f861e6ef56c17a6
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmprezpqwr5516qiohq.png differ
diff --git a/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpt_dcbfedfr0d8mnn.png b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpt_dcbfedfr0d8mnn.png
new file mode 100644
index 0000000000000000000000000000000000000000..665b1eef870aacdde6e3d7b55ec8e3da2d377bde
Binary files /dev/null and b/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8/tmpt_dcbfedfr0d8mnn.png differ
diff --git a/gradio_cached_examples/24/log.csv b/gradio_cached_examples/24/log.csv
new file mode 100644
index 0000000000000000000000000000000000000000..e42b6baeea589b126c02d031fa10b2fbb5c600ef
--- /dev/null
+++ b/gradio_cached_examples/24/log.csv
@@ -0,0 +1,9 @@
+Generated Images,flag,username,timestamp
+/home/peebles/demo/gradio_cached_examples/24/Generated Images/231fe3e1-c89b-4f63-82a9-440acf3a7cec,,,2022-12-21 22:19:52.975486
+/home/peebles/demo/gradio_cached_examples/24/Generated Images/d27e1dc9-db99-41dd-8b49-64c258de32dd,,,2022-12-21 22:21:49.400813
+/home/peebles/demo/gradio_cached_examples/24/Generated Images/9911670d-5269-42e8-88b3-6b14ef619d63,,,2022-12-21 22:23:45.918671
+/home/peebles/demo/gradio_cached_examples/24/Generated Images/38eaef4c-3ed3-4733-b410-2f2c7392898e,,,2022-12-21 22:25:42.360174
+/home/peebles/demo/gradio_cached_examples/24/Generated Images/e087b228-ae0b-4c27-92d4-6843133aa998,,,2022-12-21 22:27:38.970601
+/home/peebles/demo/gradio_cached_examples/24/Generated Images/c4f23c2e-829d-4261-b0de-f48c17bf2f61,,,2022-12-21 22:28:17.384826
+/home/peebles/demo/gradio_cached_examples/24/Generated Images/ed43f675-4837-43b1-b1ff-7081b5df0bf8,,,2022-12-21 22:28:44.680872
+/home/peebles/demo/gradio_cached_examples/24/Generated Images/cfcf58d6-cf80-432e-9da1-a30fc69440bf,,,2022-12-21 22:29:12.090413
diff --git a/imagenet_class_data.py b/imagenet_class_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e52b050f2fffaed156eae333a245cf660e8da0f
--- /dev/null
+++ b/imagenet_class_data.py
@@ -0,0 +1,1003 @@
+# https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt
+
+IMAGENET_1K_CLASSES = \
+{0: 'tench, Tinca tinca',
+ 1: 'goldfish, Carassius auratus',
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
+ 3: 'tiger shark, Galeocerdo cuvieri',
+ 4: 'hammerhead, hammerhead shark',
+ 5: 'electric ray, crampfish, numbfish, torpedo',
+ 6: 'stingray',
+ 7: 'cock',
+ 8: 'hen',
+ 9: 'ostrich, Struthio camelus',
+ 10: 'brambling, Fringilla montifringilla',
+ 11: 'goldfinch, Carduelis carduelis',
+ 12: 'house finch, linnet, Carpodacus mexicanus',
+ 13: 'junco, snowbird',
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
+ 15: 'robin, American robin, Turdus migratorius',
+ 16: 'bulbul',
+ 17: 'jay',
+ 18: 'magpie',
+ 19: 'chickadee',
+ 20: 'water ouzel, dipper',
+ 21: 'kite',
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
+ 23: 'vulture',
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
+ 25: 'European fire salamander, Salamandra salamandra',
+ 26: 'common newt, Triturus vulgaris',
+ 27: 'eft',
+ 28: 'spotted salamander, Ambystoma maculatum',
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
+ 30: 'bullfrog, Rana catesbeiana',
+ 31: 'tree frog, tree-frog',
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
+ 35: 'mud turtle',
+ 36: 'terrapin',
+ 37: 'box turtle, box tortoise',
+ 38: 'banded gecko',
+ 39: 'common iguana, iguana, Iguana iguana',
+ 40: 'American chameleon, anole, Anolis carolinensis',
+ 41: 'whiptail, whiptail lizard',
+ 42: 'agama',
+ 43: 'frilled lizard, Chlamydosaurus kingi',
+ 44: 'alligator lizard',
+ 45: 'Gila monster, Heloderma suspectum',
+ 46: 'green lizard, Lacerta viridis',
+ 47: 'African chameleon, Chamaeleo chamaeleon',
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
+ 50: 'American alligator, Alligator mississipiensis',
+ 51: 'triceratops',
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
+ 53: 'ringneck snake, ring-necked snake, ring snake',
+ 54: 'hognose snake, puff adder, sand viper',
+ 55: 'green snake, grass snake',
+ 56: 'king snake, kingsnake',
+ 57: 'garter snake, grass snake',
+ 58: 'water snake',
+ 59: 'vine snake',
+ 60: 'night snake, Hypsiglena torquata',
+ 61: 'boa constrictor, Constrictor constrictor',
+ 62: 'rock python, rock snake, Python sebae',
+ 63: 'Indian cobra, Naja naja',
+ 64: 'green mamba',
+ 65: 'sea snake',
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
+ 69: 'trilobite',
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
+ 71: 'scorpion',
+ 72: 'black and gold garden spider, Argiope aurantia',
+ 73: 'barn spider, Araneus cavaticus',
+ 74: 'garden spider, Aranea diademata',
+ 75: 'black widow, Latrodectus mactans',
+ 76: 'tarantula',
+ 77: 'wolf spider, hunting spider',
+ 78: 'tick',
+ 79: 'centipede',
+ 80: 'black grouse',
+ 81: 'ptarmigan',
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
+ 84: 'peacock',
+ 85: 'quail',
+ 86: 'partridge',
+ 87: 'African grey, African gray, Psittacus erithacus',
+ 88: 'macaw',
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
+ 90: 'lorikeet',
+ 91: 'coucal',
+ 92: 'bee eater',
+ 93: 'hornbill',
+ 94: 'hummingbird',
+ 95: 'jacamar',
+ 96: 'toucan',
+ 97: 'drake',
+ 98: 'red-breasted merganser, Mergus serrator',
+ 99: 'goose',
+ 100: 'black swan, Cygnus atratus',
+ 101: 'tusker',
+ 102: 'echidna, spiny anteater, anteater',
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
+ 104: 'wallaby, brush kangaroo',
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
+ 106: 'wombat',
+ 107: 'jellyfish',
+ 108: 'sea anemone, anemone',
+ 109: 'brain coral',
+ 110: 'flatworm, platyhelminth',
+ 111: 'nematode, nematode worm, roundworm',
+ 112: 'conch',
+ 113: 'snail',
+ 114: 'slug',
+ 115: 'sea slug, nudibranch',
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
+ 118: 'Dungeness crab, Cancer magister',
+ 119: 'rock crab, Cancer irroratus',
+ 120: 'fiddler crab',
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
+ 125: 'hermit crab',
+ 126: 'isopod',
+ 127: 'white stork, Ciconia ciconia',
+ 128: 'black stork, Ciconia nigra',
+ 129: 'spoonbill',
+ 130: 'flamingo',
+ 131: 'little blue heron, Egretta caerulea',
+ 132: 'American egret, great white heron, Egretta albus',
+ 133: 'bittern',
+ 134: 'crane',
+ 135: 'limpkin, Aramus pictus',
+ 136: 'European gallinule, Porphyrio porphyrio',
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
+ 138: 'bustard',
+ 139: 'ruddy turnstone, Arenaria interpres',
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
+ 141: 'redshank, Tringa totanus',
+ 142: 'dowitcher',
+ 143: 'oystercatcher, oyster catcher',
+ 144: 'pelican',
+ 145: 'king penguin, Aptenodytes patagonica',
+ 146: 'albatross, mollymawk',
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
+ 149: 'dugong, Dugong dugon',
+ 150: 'sea lion',
+ 151: 'Chihuahua',
+ 152: 'Japanese spaniel',
+ 153: 'Maltese dog, Maltese terrier, Maltese',
+ 154: 'Pekinese, Pekingese, Peke',
+ 155: 'Shih-Tzu',
+ 156: 'Blenheim spaniel',
+ 157: 'papillon',
+ 158: 'toy terrier',
+ 159: 'Rhodesian ridgeback',
+ 160: 'Afghan hound, Afghan',
+ 161: 'basset, basset hound',
+ 162: 'beagle',
+ 163: 'bloodhound, sleuthhound',
+ 164: 'bluetick',
+ 165: 'black-and-tan coonhound',
+ 166: 'Walker hound, Walker foxhound',
+ 167: 'English foxhound',
+ 168: 'redbone',
+ 169: 'borzoi, Russian wolfhound',
+ 170: 'Irish wolfhound',
+ 171: 'Italian greyhound',
+ 172: 'whippet',
+ 173: 'Ibizan hound, Ibizan Podenco',
+ 174: 'Norwegian elkhound, elkhound',
+ 175: 'otterhound, otter hound',
+ 176: 'Saluki, gazelle hound',
+ 177: 'Scottish deerhound, deerhound',
+ 178: 'Weimaraner',
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
+ 181: 'Bedlington terrier',
+ 182: 'Border terrier',
+ 183: 'Kerry blue terrier',
+ 184: 'Irish terrier',
+ 185: 'Norfolk terrier',
+ 186: 'Norwich terrier',
+ 187: 'Yorkshire terrier',
+ 188: 'wire-haired fox terrier',
+ 189: 'Lakeland terrier',
+ 190: 'Sealyham terrier, Sealyham',
+ 191: 'Airedale, Airedale terrier',
+ 192: 'cairn, cairn terrier',
+ 193: 'Australian terrier',
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
+ 195: 'Boston bull, Boston terrier',
+ 196: 'miniature schnauzer',
+ 197: 'giant schnauzer',
+ 198: 'standard schnauzer',
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
+ 200: 'Tibetan terrier, chrysanthemum dog',
+ 201: 'silky terrier, Sydney silky',
+ 202: 'soft-coated wheaten terrier',
+ 203: 'West Highland white terrier',
+ 204: 'Lhasa, Lhasa apso',
+ 205: 'flat-coated retriever',
+ 206: 'curly-coated retriever',
+ 207: 'golden retriever',
+ 208: 'Labrador retriever',
+ 209: 'Chesapeake Bay retriever',
+ 210: 'German short-haired pointer',
+ 211: 'vizsla, Hungarian pointer',
+ 212: 'English setter',
+ 213: 'Irish setter, red setter',
+ 214: 'Gordon setter',
+ 215: 'Brittany spaniel',
+ 216: 'clumber, clumber spaniel',
+ 217: 'English springer, English springer spaniel',
+ 218: 'Welsh springer spaniel',
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
+ 220: 'Sussex spaniel',
+ 221: 'Irish water spaniel',
+ 222: 'kuvasz',
+ 223: 'schipperke',
+ 224: 'groenendael',
+ 225: 'malinois',
+ 226: 'briard',
+ 227: 'kelpie',
+ 228: 'komondor',
+ 229: 'Old English sheepdog, bobtail',
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
+ 231: 'collie',
+ 232: 'Border collie',
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
+ 234: 'Rottweiler',
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
+ 236: 'Doberman, Doberman pinscher',
+ 237: 'miniature pinscher',
+ 238: 'Greater Swiss Mountain dog',
+ 239: 'Bernese mountain dog',
+ 240: 'Appenzeller',
+ 241: 'EntleBucher',
+ 242: 'boxer',
+ 243: 'bull mastiff',
+ 244: 'Tibetan mastiff',
+ 245: 'French bulldog',
+ 246: 'Great Dane',
+ 247: 'Saint Bernard, St Bernard',
+ 248: 'Eskimo dog, husky',
+ 249: 'malamute, malemute, Alaskan malamute',
+ 250: 'Siberian husky',
+ 251: 'dalmatian, coach dog, carriage dog',
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
+ 253: 'basenji',
+ 254: 'pug, pug-dog',
+ 255: 'Leonberg',
+ 256: 'Newfoundland, Newfoundland dog',
+ 257: 'Great Pyrenees',
+ 258: 'Samoyed, Samoyede',
+ 259: 'Pomeranian',
+ 260: 'chow, chow chow',
+ 261: 'keeshond',
+ 262: 'Brabancon griffon',
+ 263: 'Pembroke, Pembroke Welsh corgi',
+ 264: 'Cardigan, Cardigan Welsh corgi',
+ 265: 'toy poodle',
+ 266: 'miniature poodle',
+ 267: 'standard poodle',
+ 268: 'Mexican hairless',
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
+ 273: 'dingo, warrigal, warragal, Canis dingo',
+ 274: 'dhole, Cuon alpinus',
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
+ 276: 'hyena, hyaena',
+ 277: 'red fox, Vulpes vulpes',
+ 278: 'kit fox, Vulpes macrotis',
+ 279: 'Arctic fox, white fox, Alopex lagopus',
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
+ 281: 'tabby, tabby cat',
+ 282: 'tiger cat',
+ 283: 'Persian cat',
+ 284: 'Siamese cat, Siamese',
+ 285: 'Egyptian cat',
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
+ 287: 'lynx, catamount',
+ 288: 'leopard, Panthera pardus',
+ 289: 'snow leopard, ounce, Panthera uncia',
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
+ 291: 'lion, king of beasts, Panthera leo',
+ 292: 'tiger, Panthera tigris',
+ 293: 'cheetah, chetah, Acinonyx jubatus',
+ 294: 'brown bear, bruin, Ursus arctos',
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
+ 298: 'mongoose',
+ 299: 'meerkat, mierkat',
+ 300: 'tiger beetle',
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
+ 302: 'ground beetle, carabid beetle',
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
+ 304: 'leaf beetle, chrysomelid',
+ 305: 'dung beetle',
+ 306: 'rhinoceros beetle',
+ 307: 'weevil',
+ 308: 'fly',
+ 309: 'bee',
+ 310: 'ant, emmet, pismire',
+ 311: 'grasshopper, hopper',
+ 312: 'cricket',
+ 313: 'walking stick, walkingstick, stick insect',
+ 314: 'cockroach, roach',
+ 315: 'mantis, mantid',
+ 316: 'cicada, cicala',
+ 317: 'leafhopper',
+ 318: 'lacewing, lacewing fly',
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
+ 320: 'damselfly',
+ 321: 'admiral',
+ 322: 'ringlet, ringlet butterfly',
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
+ 324: 'cabbage butterfly',
+ 325: 'sulphur butterfly, sulfur butterfly',
+ 326: 'lycaenid, lycaenid butterfly',
+ 327: 'starfish, sea star',
+ 328: 'sea urchin',
+ 329: 'sea cucumber, holothurian',
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
+ 331: 'hare',
+ 332: 'Angora, Angora rabbit',
+ 333: 'hamster',
+ 334: 'porcupine, hedgehog',
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
+ 336: 'marmot',
+ 337: 'beaver',
+ 338: 'guinea pig, Cavia cobaya',
+ 339: 'sorrel',
+ 340: 'zebra',
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
+ 342: 'wild boar, boar, Sus scrofa',
+ 343: 'warthog',
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
+ 345: 'ox',
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
+ 347: 'bison',
+ 348: 'ram, tup',
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
+ 350: 'ibex, Capra ibex',
+ 351: 'hartebeest',
+ 352: 'impala, Aepyceros melampus',
+ 353: 'gazelle',
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
+ 355: 'llama',
+ 356: 'weasel',
+ 357: 'mink',
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
+ 360: 'otter',
+ 361: 'skunk, polecat, wood pussy',
+ 362: 'badger',
+ 363: 'armadillo',
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
+ 366: 'gorilla, Gorilla gorilla',
+ 367: 'chimpanzee, chimp, Pan troglodytes',
+ 368: 'gibbon, Hylobates lar',
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
+ 370: 'guenon, guenon monkey',
+ 371: 'patas, hussar monkey, Erythrocebus patas',
+ 372: 'baboon',
+ 373: 'macaque',
+ 374: 'langur',
+ 375: 'colobus, colobus monkey',
+ 376: 'proboscis monkey, Nasalis larvatus',
+ 377: 'marmoset',
+ 378: 'capuchin, ringtail, Cebus capucinus',
+ 379: 'howler monkey, howler',
+ 380: 'titi, titi monkey',
+ 381: 'spider monkey, Ateles geoffroyi',
+ 382: 'squirrel monkey, Saimiri sciureus',
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
+ 385: 'Indian elephant, Elephas maximus',
+ 386: 'African elephant, Loxodonta africana',
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
+ 389: 'barracouta, snoek',
+ 390: 'eel',
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
+ 392: 'rock beauty, Holocanthus tricolor',
+ 393: 'anemone fish',
+ 394: 'sturgeon',
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
+ 396: 'lionfish',
+ 397: 'puffer, pufferfish, blowfish, globefish',
+ 398: 'abacus',
+ 399: 'abaya',
+ 400: "academic gown, academic robe, judge's robe",
+ 401: 'accordion, piano accordion, squeeze box',
+ 402: 'acoustic guitar',
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
+ 404: 'airliner',
+ 405: 'airship, dirigible',
+ 406: 'altar',
+ 407: 'ambulance',
+ 408: 'amphibian, amphibious vehicle',
+ 409: 'analog clock',
+ 410: 'apiary, bee house',
+ 411: 'apron',
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
+ 413: 'assault rifle, assault gun',
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
+ 415: 'bakery, bakeshop, bakehouse',
+ 416: 'balance beam, beam',
+ 417: 'balloon',
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
+ 419: 'Band Aid',
+ 420: 'banjo',
+ 421: 'bannister, banister, balustrade, balusters, handrail',
+ 422: 'barbell',
+ 423: 'barber chair',
+ 424: 'barbershop',
+ 425: 'barn',
+ 426: 'barometer',
+ 427: 'barrel, cask',
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
+ 429: 'baseball',
+ 430: 'basketball',
+ 431: 'bassinet',
+ 432: 'bassoon',
+ 433: 'bathing cap, swimming cap',
+ 434: 'bath towel',
+ 435: 'bathtub, bathing tub, bath, tub',
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
+ 437: 'beacon, lighthouse, beacon light, pharos',
+ 438: 'beaker',
+ 439: 'bearskin, busby, shako',
+ 440: 'beer bottle',
+ 441: 'beer glass',
+ 442: 'bell cote, bell cot',
+ 443: 'bib',
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
+ 445: 'bikini, two-piece',
+ 446: 'binder, ring-binder',
+ 447: 'binoculars, field glasses, opera glasses',
+ 448: 'birdhouse',
+ 449: 'boathouse',
+ 450: 'bobsled, bobsleigh, bob',
+ 451: 'bolo tie, bolo, bola tie, bola',
+ 452: 'bonnet, poke bonnet',
+ 453: 'bookcase',
+ 454: 'bookshop, bookstore, bookstall',
+ 455: 'bottlecap',
+ 456: 'bow',
+ 457: 'bow tie, bow-tie, bowtie',
+ 458: 'brass, memorial tablet, plaque',
+ 459: 'brassiere, bra, bandeau',
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
+ 461: 'breastplate, aegis, egis',
+ 462: 'broom',
+ 463: 'bucket, pail',
+ 464: 'buckle',
+ 465: 'bulletproof vest',
+ 466: 'bullet train, bullet',
+ 467: 'butcher shop, meat market',
+ 468: 'cab, hack, taxi, taxicab',
+ 469: 'caldron, cauldron',
+ 470: 'candle, taper, wax light',
+ 471: 'cannon',
+ 472: 'canoe',
+ 473: 'can opener, tin opener',
+ 474: 'cardigan',
+ 475: 'car mirror',
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
+ 477: "carpenter's kit, tool kit",
+ 478: 'carton',
+ 479: 'car wheel',
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
+ 481: 'cassette',
+ 482: 'cassette player',
+ 483: 'castle',
+ 484: 'catamaran',
+ 485: 'CD player',
+ 486: 'cello, violoncello',
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
+ 488: 'chain',
+ 489: 'chainlink fence',
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
+ 491: 'chain saw, chainsaw',
+ 492: 'chest',
+ 493: 'chiffonier, commode',
+ 494: 'chime, bell, gong',
+ 495: 'china cabinet, china closet',
+ 496: 'Christmas stocking',
+ 497: 'church, church building',
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
+ 499: 'cleaver, meat cleaver, chopper',
+ 500: 'cliff dwelling',
+ 501: 'cloak',
+ 502: 'clog, geta, patten, sabot',
+ 503: 'cocktail shaker',
+ 504: 'coffee mug',
+ 505: 'coffeepot',
+ 506: 'coil, spiral, volute, whorl, helix',
+ 507: 'combination lock',
+ 508: 'computer keyboard, keypad',
+ 509: 'confectionery, confectionary, candy store',
+ 510: 'container ship, containership, container vessel',
+ 511: 'convertible',
+ 512: 'corkscrew, bottle screw',
+ 513: 'cornet, horn, trumpet, trump',
+ 514: 'cowboy boot',
+ 515: 'cowboy hat, ten-gallon hat',
+ 516: 'cradle',
+ 517: 'crane',
+ 518: 'crash helmet',
+ 519: 'crate',
+ 520: 'crib, cot',
+ 521: 'Crock Pot',
+ 522: 'croquet ball',
+ 523: 'crutch',
+ 524: 'cuirass',
+ 525: 'dam, dike, dyke',
+ 526: 'desk',
+ 527: 'desktop computer',
+ 528: 'dial telephone, dial phone',
+ 529: 'diaper, nappy, napkin',
+ 530: 'digital clock',
+ 531: 'digital watch',
+ 532: 'dining table, board',
+ 533: 'dishrag, dishcloth',
+ 534: 'dishwasher, dish washer, dishwashing machine',
+ 535: 'disk brake, disc brake',
+ 536: 'dock, dockage, docking facility',
+ 537: 'dogsled, dog sled, dog sleigh',
+ 538: 'dome',
+ 539: 'doormat, welcome mat',
+ 540: 'drilling platform, offshore rig',
+ 541: 'drum, membranophone, tympan',
+ 542: 'drumstick',
+ 543: 'dumbbell',
+ 544: 'Dutch oven',
+ 545: 'electric fan, blower',
+ 546: 'electric guitar',
+ 547: 'electric locomotive',
+ 548: 'entertainment center',
+ 549: 'envelope',
+ 550: 'espresso maker',
+ 551: 'face powder',
+ 552: 'feather boa, boa',
+ 553: 'file, file cabinet, filing cabinet',
+ 554: 'fireboat',
+ 555: 'fire engine, fire truck',
+ 556: 'fire screen, fireguard',
+ 557: 'flagpole, flagstaff',
+ 558: 'flute, transverse flute',
+ 559: 'folding chair',
+ 560: 'football helmet',
+ 561: 'forklift',
+ 562: 'fountain',
+ 563: 'fountain pen',
+ 564: 'four-poster',
+ 565: 'freight car',
+ 566: 'French horn, horn',
+ 567: 'frying pan, frypan, skillet',
+ 568: 'fur coat',
+ 569: 'garbage truck, dustcart',
+ 570: 'gasmask, respirator, gas helmet',
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
+ 572: 'goblet',
+ 573: 'go-kart',
+ 574: 'golf ball',
+ 575: 'golfcart, golf cart',
+ 576: 'gondola',
+ 577: 'gong, tam-tam',
+ 578: 'gown',
+ 579: 'grand piano, grand',
+ 580: 'greenhouse, nursery, glasshouse',
+ 581: 'grille, radiator grille',
+ 582: 'grocery store, grocery, food market, market',
+ 583: 'guillotine',
+ 584: 'hair slide',
+ 585: 'hair spray',
+ 586: 'half track',
+ 587: 'hammer',
+ 588: 'hamper',
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
+ 590: 'hand-held computer, hand-held microcomputer',
+ 591: 'handkerchief, hankie, hanky, hankey',
+ 592: 'hard disc, hard disk, fixed disk',
+ 593: 'harmonica, mouth organ, harp, mouth harp',
+ 594: 'harp',
+ 595: 'harvester, reaper',
+ 596: 'hatchet',
+ 597: 'holster',
+ 598: 'home theater, home theatre',
+ 599: 'honeycomb',
+ 600: 'hook, claw',
+ 601: 'hoopskirt, crinoline',
+ 602: 'horizontal bar, high bar',
+ 603: 'horse cart, horse-cart',
+ 604: 'hourglass',
+ 605: 'iPod',
+ 606: 'iron, smoothing iron',
+ 607: "jack-o'-lantern",
+ 608: 'jean, blue jean, denim',
+ 609: 'jeep, landrover',
+ 610: 'jersey, T-shirt, tee shirt',
+ 611: 'jigsaw puzzle',
+ 612: 'jinrikisha, ricksha, rickshaw',
+ 613: 'joystick',
+ 614: 'kimono',
+ 615: 'knee pad',
+ 616: 'knot',
+ 617: 'lab coat, laboratory coat',
+ 618: 'ladle',
+ 619: 'lampshade, lamp shade',
+ 620: 'laptop, laptop computer',
+ 621: 'lawn mower, mower',
+ 622: 'lens cap, lens cover',
+ 623: 'letter opener, paper knife, paperknife',
+ 624: 'library',
+ 625: 'lifeboat',
+ 626: 'lighter, light, igniter, ignitor',
+ 627: 'limousine, limo',
+ 628: 'liner, ocean liner',
+ 629: 'lipstick, lip rouge',
+ 630: 'Loafer',
+ 631: 'lotion',
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
+ 633: "loupe, jeweler's loupe",
+ 634: 'lumbermill, sawmill',
+ 635: 'magnetic compass',
+ 636: 'mailbag, postbag',
+ 637: 'mailbox, letter box',
+ 638: 'maillot',
+ 639: 'maillot, tank suit',
+ 640: 'manhole cover',
+ 641: 'maraca',
+ 642: 'marimba, xylophone',
+ 643: 'mask',
+ 644: 'matchstick',
+ 645: 'maypole',
+ 646: 'maze, labyrinth',
+ 647: 'measuring cup',
+ 648: 'medicine chest, medicine cabinet',
+ 649: 'megalith, megalithic structure',
+ 650: 'microphone, mike',
+ 651: 'microwave, microwave oven',
+ 652: 'military uniform',
+ 653: 'milk can',
+ 654: 'minibus',
+ 655: 'miniskirt, mini',
+ 656: 'minivan',
+ 657: 'missile',
+ 658: 'mitten',
+ 659: 'mixing bowl',
+ 660: 'mobile home, manufactured home',
+ 661: 'Model T',
+ 662: 'modem',
+ 663: 'monastery',
+ 664: 'monitor',
+ 665: 'moped',
+ 666: 'mortar',
+ 667: 'mortarboard',
+ 668: 'mosque',
+ 669: 'mosquito net',
+ 670: 'motor scooter, scooter',
+ 671: 'mountain bike, all-terrain bike, off-roader',
+ 672: 'mountain tent',
+ 673: 'mouse, computer mouse',
+ 674: 'mousetrap',
+ 675: 'moving van',
+ 676: 'muzzle',
+ 677: 'nail',
+ 678: 'neck brace',
+ 679: 'necklace',
+ 680: 'nipple',
+ 681: 'notebook, notebook computer',
+ 682: 'obelisk',
+ 683: 'oboe, hautboy, hautbois',
+ 684: 'ocarina, sweet potato',
+ 685: 'odometer, hodometer, mileometer, milometer',
+ 686: 'oil filter',
+ 687: 'organ, pipe organ',
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
+ 689: 'overskirt',
+ 690: 'oxcart',
+ 691: 'oxygen mask',
+ 692: 'packet',
+ 693: 'paddle, boat paddle',
+ 694: 'paddlewheel, paddle wheel',
+ 695: 'padlock',
+ 696: 'paintbrush',
+ 697: "pajama, pyjama, pj's, jammies",
+ 698: 'palace',
+ 699: 'panpipe, pandean pipe, syrinx',
+ 700: 'paper towel',
+ 701: 'parachute, chute',
+ 702: 'parallel bars, bars',
+ 703: 'park bench',
+ 704: 'parking meter',
+ 705: 'passenger car, coach, carriage',
+ 706: 'patio, terrace',
+ 707: 'pay-phone, pay-station',
+ 708: 'pedestal, plinth, footstall',
+ 709: 'pencil box, pencil case',
+ 710: 'pencil sharpener',
+ 711: 'perfume, essence',
+ 712: 'Petri dish',
+ 713: 'photocopier',
+ 714: 'pick, plectrum, plectron',
+ 715: 'pickelhaube',
+ 716: 'picket fence, paling',
+ 717: 'pickup, pickup truck',
+ 718: 'pier',
+ 719: 'piggy bank, penny bank',
+ 720: 'pill bottle',
+ 721: 'pillow',
+ 722: 'ping-pong ball',
+ 723: 'pinwheel',
+ 724: 'pirate, pirate ship',
+ 725: 'pitcher, ewer',
+ 726: "plane, carpenter's plane, woodworking plane",
+ 727: 'planetarium',
+ 728: 'plastic bag',
+ 729: 'plate rack',
+ 730: 'plow, plough',
+ 731: "plunger, plumber's helper",
+ 732: 'Polaroid camera, Polaroid Land camera',
+ 733: 'pole',
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
+ 735: 'poncho',
+ 736: 'pool table, billiard table, snooker table',
+ 737: 'pop bottle, soda bottle',
+ 738: 'pot, flowerpot',
+ 739: "potter's wheel",
+ 740: 'power drill',
+ 741: 'prayer rug, prayer mat',
+ 742: 'printer',
+ 743: 'prison, prison house',
+ 744: 'projectile, missile',
+ 745: 'projector',
+ 746: 'puck, hockey puck',
+ 747: 'punching bag, punch bag, punching ball, punchball',
+ 748: 'purse',
+ 749: 'quill, quill pen',
+ 750: 'quilt, comforter, comfort, puff',
+ 751: 'racer, race car, racing car',
+ 752: 'racket, racquet',
+ 753: 'radiator',
+ 754: 'radio, wireless',
+ 755: 'radio telescope, radio reflector',
+ 756: 'rain barrel',
+ 757: 'recreational vehicle, RV, R.V.',
+ 758: 'reel',
+ 759: 'reflex camera',
+ 760: 'refrigerator, icebox',
+ 761: 'remote control, remote',
+ 762: 'restaurant, eating house, eating place, eatery',
+ 763: 'revolver, six-gun, six-shooter',
+ 764: 'rifle',
+ 765: 'rocking chair, rocker',
+ 766: 'rotisserie',
+ 767: 'rubber eraser, rubber, pencil eraser',
+ 768: 'rugby ball',
+ 769: 'rule, ruler',
+ 770: 'running shoe',
+ 771: 'safe',
+ 772: 'safety pin',
+ 773: 'saltshaker, salt shaker',
+ 774: 'sandal',
+ 775: 'sarong',
+ 776: 'sax, saxophone',
+ 777: 'scabbard',
+ 778: 'scale, weighing machine',
+ 779: 'school bus',
+ 780: 'schooner',
+ 781: 'scoreboard',
+ 782: 'screen, CRT screen',
+ 783: 'screw',
+ 784: 'screwdriver',
+ 785: 'seat belt, seatbelt',
+ 786: 'sewing machine',
+ 787: 'shield, buckler',
+ 788: 'shoe shop, shoe-shop, shoe store',
+ 789: 'shoji',
+ 790: 'shopping basket',
+ 791: 'shopping cart',
+ 792: 'shovel',
+ 793: 'shower cap',
+ 794: 'shower curtain',
+ 795: 'ski',
+ 796: 'ski mask',
+ 797: 'sleeping bag',
+ 798: 'slide rule, slipstick',
+ 799: 'sliding door',
+ 800: 'slot, one-armed bandit',
+ 801: 'snorkel',
+ 802: 'snowmobile',
+ 803: 'snowplow, snowplough',
+ 804: 'soap dispenser',
+ 805: 'soccer ball',
+ 806: 'sock',
+ 807: 'solar dish, solar collector, solar furnace',
+ 808: 'sombrero',
+ 809: 'soup bowl',
+ 810: 'space bar',
+ 811: 'space heater',
+ 812: 'space shuttle',
+ 813: 'spatula',
+ 814: 'speedboat',
+ 815: "spider web, spider's web",
+ 816: 'spindle',
+ 817: 'sports car, sport car',
+ 818: 'spotlight, spot',
+ 819: 'stage',
+ 820: 'steam locomotive',
+ 821: 'steel arch bridge',
+ 822: 'steel drum',
+ 823: 'stethoscope',
+ 824: 'stole',
+ 825: 'stone wall',
+ 826: 'stopwatch, stop watch',
+ 827: 'stove',
+ 828: 'strainer',
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
+ 830: 'stretcher',
+ 831: 'studio couch, day bed',
+ 832: 'stupa, tope',
+ 833: 'submarine, pigboat, sub, U-boat',
+ 834: 'suit, suit of clothes',
+ 835: 'sundial',
+ 836: 'sunglass',
+ 837: 'sunglasses, dark glasses, shades',
+ 838: 'sunscreen, sunblock, sun blocker',
+ 839: 'suspension bridge',
+ 840: 'swab, swob, mop',
+ 841: 'sweatshirt',
+ 842: 'swimming trunks, bathing trunks',
+ 843: 'swing',
+ 844: 'switch, electric switch, electrical switch',
+ 845: 'syringe',
+ 846: 'table lamp',
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
+ 848: 'tape player',
+ 849: 'teapot',
+ 850: 'teddy, teddy bear',
+ 851: 'television, television system',
+ 852: 'tennis ball',
+ 853: 'thatch, thatched roof',
+ 854: 'theater curtain, theatre curtain',
+ 855: 'thimble',
+ 856: 'thresher, thrasher, threshing machine',
+ 857: 'throne',
+ 858: 'tile roof',
+ 859: 'toaster',
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
+ 861: 'toilet seat',
+ 862: 'torch',
+ 863: 'totem pole',
+ 864: 'tow truck, tow car, wrecker',
+ 865: 'toyshop',
+ 866: 'tractor',
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
+ 868: 'tray',
+ 869: 'trench coat',
+ 870: 'tricycle, trike, velocipede',
+ 871: 'trimaran',
+ 872: 'tripod',
+ 873: 'triumphal arch',
+ 874: 'trolleybus, trolley coach, trackless trolley',
+ 875: 'trombone',
+ 876: 'tub, vat',
+ 877: 'turnstile',
+ 878: 'typewriter keyboard',
+ 879: 'umbrella',
+ 880: 'unicycle, monocycle',
+ 881: 'upright, upright piano',
+ 882: 'vacuum, vacuum cleaner',
+ 883: 'vase',
+ 884: 'vault',
+ 885: 'velvet',
+ 886: 'vending machine',
+ 887: 'vestment',
+ 888: 'viaduct',
+ 889: 'violin, fiddle',
+ 890: 'volleyball',
+ 891: 'waffle iron',
+ 892: 'wall clock',
+ 893: 'wallet, billfold, notecase, pocketbook',
+ 894: 'wardrobe, closet, press',
+ 895: 'warplane, military plane',
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
+ 897: 'washer, automatic washer, washing machine',
+ 898: 'water bottle',
+ 899: 'water jug',
+ 900: 'water tower',
+ 901: 'whiskey jug',
+ 902: 'whistle',
+ 903: 'wig',
+ 904: 'window screen',
+ 905: 'window shade',
+ 906: 'Windsor tie',
+ 907: 'wine bottle',
+ 908: 'wing',
+ 909: 'wok',
+ 910: 'wooden spoon',
+ 911: 'wool, woolen, woollen',
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
+ 913: 'wreck',
+ 914: 'yawl',
+ 915: 'yurt',
+ 916: 'web site, website, internet site, site',
+ 917: 'comic book',
+ 918: 'crossword puzzle, crossword',
+ 919: 'street sign',
+ 920: 'traffic light, traffic signal, stoplight',
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
+ 922: 'menu',
+ 923: 'plate',
+ 924: 'guacamole',
+ 925: 'consomme',
+ 926: 'hot pot, hotpot',
+ 927: 'trifle',
+ 928: 'ice cream, icecream',
+ 929: 'ice lolly, lolly, lollipop, popsicle',
+ 930: 'French loaf',
+ 931: 'bagel, beigel',
+ 932: 'pretzel',
+ 933: 'cheeseburger',
+ 934: 'hotdog, hot dog, red hot',
+ 935: 'mashed potato',
+ 936: 'head cabbage',
+ 937: 'broccoli',
+ 938: 'cauliflower',
+ 939: 'zucchini, courgette',
+ 940: 'spaghetti squash',
+ 941: 'acorn squash',
+ 942: 'butternut squash',
+ 943: 'cucumber, cuke',
+ 944: 'artichoke, globe artichoke',
+ 945: 'bell pepper',
+ 946: 'cardoon',
+ 947: 'mushroom',
+ 948: 'Granny Smith',
+ 949: 'strawberry',
+ 950: 'orange',
+ 951: 'lemon',
+ 952: 'fig',
+ 953: 'pineapple, ananas',
+ 954: 'banana',
+ 955: 'jackfruit, jak, jack',
+ 956: 'custard apple',
+ 957: 'pomegranate',
+ 958: 'hay',
+ 959: 'carbonara',
+ 960: 'chocolate sauce, chocolate syrup',
+ 961: 'dough',
+ 962: 'meat loaf, meatloaf',
+ 963: 'pizza, pizza pie',
+ 964: 'potpie',
+ 965: 'burrito',
+ 966: 'red wine',
+ 967: 'espresso',
+ 968: 'cup',
+ 969: 'eggnog',
+ 970: 'alp',
+ 971: 'bubble',
+ 972: 'cliff, drop, drop-off',
+ 973: 'coral reef',
+ 974: 'geyser',
+ 975: 'lakeside, lakeshore',
+ 976: 'promontory, headland, head, foreland',
+ 977: 'sandbar, sand bar',
+ 978: 'seashore, coast, seacoast, sea-coast',
+ 979: 'valley, vale',
+ 980: 'volcano',
+ 981: 'ballplayer, baseball player',
+ 982: 'groom, bridegroom',
+ 983: 'scuba diver',
+ 984: 'rapeseed',
+ 985: 'daisy',
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
+ 987: 'corn',
+ 988: 'acorn',
+ 989: 'hip, rose hip, rosehip',
+ 990: 'buckeye, horse chestnut, conker',
+ 991: 'coral fungus',
+ 992: 'agaric',
+ 993: 'gyromitra',
+ 994: 'stinkhorn, carrion fungus',
+ 995: 'earthstar',
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
+ 997: 'bolete',
+ 998: 'ear, spike, capitulum',
+ 999: 'toilet tissue, toilet paper, bathroom tissue'}
\ No newline at end of file
diff --git a/models.py b/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbb9a01d9aab53ef1f5c5b8d9dcf231d1e49ce95
--- /dev/null
+++ b/models.py
@@ -0,0 +1,350 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class LabelEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0]) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+#################################################################################
+# Core DiT Model #
+#################################################################################
+
+class DiTBlock(nn.Module):
+ """
+ A DiT block with gated adaptive layer norm (adaLN) conditioning.
+ """
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of DiT.
+ """
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class DiT(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+ def __init__(
+ self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+ num_patches = self.x_embedder.num_patches
+ # Will use fixed sin-cos embedding:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+
+ self.blocks = nn.ModuleList([
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
+ ])
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize (and freeze) pos_embed by sin-cos embedding
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ p = self.x_embedder.patch_size[0]
+ h = w = int(x.shape[1] ** 0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
+ return imgs
+
+ def forward(self, x, t, y):
+ """
+ Forward pass of DiT.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ """
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+ t = self.t_embedder(t) # (N, D)
+ y = self.y_embedder(y, self.training) # (N, D)
+ c = t + y # (N, D)
+ for block in self.blocks:
+ x = block(x, c) # (N, T, D)
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
+ x = self.unpatchify(x) # (N, out_channels, H, W)
+ return x
+
+ def forward_with_cfg(self, x, t, y, cfg_scale):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[: len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.forward(combined, t, y)
+ eps, rest = model_out[:, :3], model_out[:, 3:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+
+#################################################################################
+# Sine/Cosine Positional Embedding Functions #
+#################################################################################
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+#################################################################################
+# DiT Configs #
+#################################################################################
+
+def DiT_XL_2(**kwargs):
+ return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
+
+def DiT_XL_4(**kwargs):
+ return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
+
+def DiT_XL_8(**kwargs):
+ return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
+
+def DiT_L_2(**kwargs):
+ return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
+
+def DiT_L_4(**kwargs):
+ return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
+
+def DiT_L_8(**kwargs):
+ return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
+
+def DiT_B_2(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
+
+def DiT_B_4(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
+
+def DiT_B_8(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
+
+def DiT_S_2(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
+
+def DiT_S_4(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
+
+def DiT_S_8(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0ebd02af4bc54eb452dbaae36fca4c5621b1b99b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+torch
+torchvision
+numpy
+tqdm
+timm
+pillow
+diffusers
+accelerate