diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a078dddb0572050e2a91fe698750ddcdb9d104f9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+hf_auth
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bb65ea1bf49577f47bb00e4b3f2eee4a4dc81823
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+---
+title: EDICT
+emoji: ⚡
+colorFrom: indigo
+colorTo: red
+sdk: gradio
+sdk_version: 3.18.0
+app_file: app.py
+pinned: false
+license: bsd-3-clause
+duplicated_from: Salesforce/EDICT
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9681c98dc0af65201967d48132b71b811ab1256
--- /dev/null
+++ b/app.py
@@ -0,0 +1,146 @@
+import gradio as gr
+import numpy as np
+# from edict_functions import EDICT_editing
+from PIL import Image
+from utils import Endpoint, get_token
+from io import BytesIO
+import requests
+
+
+endpoint = Endpoint()
+
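+# Note: local_edict runs EDICT in-process and requires uncommenting the
+# `from edict_functions import EDICT_editing` import above.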
+def local_edict(x, source_text, edit_text,
+ edit_strength, guidance_scale,
+ steps=50, mix_weight=0.93, ):
+ x = Image.fromarray(x)
+ return_im = EDICT_editing(x,
+ source_text,
+ edit_text,
+ steps=steps,
+ mix_weight=mix_weight,
+ init_image_strength=edit_strength,
+ guidance_scale=guidance_scale
+ )[0]
+ return np.array(return_im)
+
+def encode_image(image):
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG", quality=95)
+ buffered.seek(0)
+
+ return buffered
+
+
+
+def decode_image(img_obj):
+ img = Image.open(img_obj).convert("RGB")
+ return img
+
+def edict(x, source_text, edit_text,
+ edit_strength, guidance_scale,
+ steps=50, mix_weight=0.93, ):
+
+ url = endpoint.url
+ url = url + "/api/edit"
+    headers = {
+        "User-Agent": "EDICT HuggingFace Space",
+        "Auth-Token": get_token(),
+    }
+
+ data = {
+ "source_text": source_text,
+ "edit_text": edit_text,
+ "edit_strength": edit_strength,
+ "guidance_scale": guidance_scale,
+ }
+
+ image = encode_image(Image.fromarray(x))
+ files = {"image": image}
+
+ response = requests.post(url, data=data, files=files, headers=headers)
+
+ if response.status_code == 200:
+ return np.array(decode_image(BytesIO(response.content)))
+ else:
+ return "Error: " + response.text
+ # x = decode_image(response)
+ # return np.array(x)
+
+examples = [
+ ['square_ims/american_gothic.jpg', 'A painting of two people frowning', 'A painting of two people smiling', 0.5, 3],
+ ['square_ims/colloseum.jpg', 'An old ruined building', 'A new modern office building', 0.8, 3],
+ ]
+
+
+examples.append(['square_ims/scream.jpg', 'A painting of someone screaming', 'A painting of an alien', 0.5, 3])
+examples.append(['square_ims/yosemite.jpg', 'Granite forest valley', 'Granite desert valley', 0.8, 3])
+examples.append(['square_ims/einstein.jpg', 'Mouth open', 'Mouth closed', 0.8, 3])
+examples.append(['square_ims/einstein.jpg', 'A man', 'A man in K.I.S.S. facepaint', 0.8, 3])
+"""
+examples.extend([
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Chinese New Year cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Union Jack cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Nigerian flag cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Santa Claus cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'An Easter cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A hedgehog cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A rose cupcake', 0.8, 3],
+ ])
+"""
+
+for dog_i in [1, 2]:
+ for breed in ['Golden Retriever', 'Chihuahua', 'Dalmatian']:
+ examples.append([f'square_ims/imagenet_dog_{dog_i}.jpg', 'A dog', f'A {breed}', 0.8, 3])
+
+
+description = """
+**We have disabled image uploading as of March 22, 2023.**
+
+**Please try the examples provided below.**
+
+A gradio demo for [EDICT](https://arxiv.org/abs/2211.12446) (CVPR23)
+"""
+# description = gr.Markdown(description)
+
+article = """
+
+### Prompting Style
+
+As with many text-to-image methods, the prompting style of EDICT can make a big difference. When in doubt, experiment! Some guidance:
+* Keep the *Original Description* and *Edit Description* as parallel as possible. Inserting or editing a single word is often enough to effect a change while preserving much of the original structure
+* Words that affect the entire setting (e.g. "A photo of " vs. "A painting of") can make a big difference. Playing around with them can help a lot
+
+### Parameters
+Both `edit_strength` and `guidance_scale` have similar properties qualitatively: the higher the value the more the image will change. We suggest
+* Increasing/decreasing `edit_strength` first, particularly to alter/preserve more of the original structure/content
+* Then changing `guidance_scale` to make the change in the edited region more or less pronounced.
+
+Usually we find changing `edit_strength` to be enough, but feel free to play around (and report any interesting results)!
+
+### Misc.
+
+Having difficulty coming up with a caption? Try [BLIP](https://huggingface.co/spaces/Salesforce/BLIP2) to automatically generate one!
+
+As with most Stable Diffusion approaches, faces and text are often problematic to render, especially if they're small. Having these in the foreground will help keep them cleaner.
+
+A returned black image means that the [Safety Checker](https://huggingface.co/CompVis/stable-diffusion-safety-checker) triggered on the photo. This sometimes happens in odd cases (it often rejects
+the Hugging Face logo or variations of it), but we need to keep it in for obvious reasons.
+"""
+# article = gr.Markdown(description)
+
+iface = gr.Interface(fn=edict, inputs=[gr.Image(interactive=False),
+ gr.Textbox(label="Original Description", interactive=False),
+ gr.Textbox(label="Edit Description", interactive=False),
+ # 50, # gr.Slider(5, 50, value=20, step=1),
+ # 0.93, # gr.Slider(0.5, 1, value=0.7, step=0.05),
+ gr.Slider(0.0, 1, value=0.8, step=0.05),
+ gr.Slider(0, 10, value=3, step=0.5),
+ ],
+ examples = examples,
+ outputs="image",
+ description=description,
+ article=article,
+ cache_examples=True
+ )
+iface.launch()
diff --git a/app_fully_disabled.py b/app_fully_disabled.py
new file mode 100644
index 0000000000000000000000000000000000000000..7da4e8ee37f180579a1b98add2fecb496587d4fe
--- /dev/null
+++ b/app_fully_disabled.py
@@ -0,0 +1,285 @@
+from io import BytesIO
+
+import string
+import gradio as gr
+import requests
+from utils import Endpoint, get_token
+
+
+def encode_image(image):
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG")
+ buffered.seek(0)
+
+ return buffered
+
+
+def query_chat_api(
+ image, prompt, decoding_method, temperature, len_penalty, repetition_penalty
+):
+
+ url = endpoint.url
+ url = url + "/api/generate"
+
+ headers = {
+ "User-Agent": "BLIP-2 HuggingFace Space",
+ "Auth-Token": get_token(),
+ }
+
+ data = {
+ "prompt": prompt,
+ "use_nucleus_sampling": decoding_method == "Nucleus sampling",
+ "temperature": temperature,
+ "length_penalty": len_penalty,
+ "repetition_penalty": repetition_penalty,
+ }
+
+ image = encode_image(image)
+ files = {"image": image}
+
+ response = requests.post(url, data=data, files=files, headers=headers)
+
+ if response.status_code == 200:
+ return response.json()
+ else:
+ return "Error: " + response.text
+
+
+def query_caption_api(
+ image, decoding_method, temperature, len_penalty, repetition_penalty
+):
+
+ url = endpoint.url
+ url = url + "/api/caption"
+
+ headers = {
+ "User-Agent": "BLIP-2 HuggingFace Space",
+ "Auth-Token": get_token(),
+ }
+
+ data = {
+ "use_nucleus_sampling": decoding_method == "Nucleus sampling",
+ "temperature": temperature,
+ "length_penalty": len_penalty,
+ "repetition_penalty": repetition_penalty,
+ }
+
+ image = encode_image(image)
+ files = {"image": image}
+
+ response = requests.post(url, data=data, files=files, headers=headers)
+
+ if response.status_code == 200:
+ return response.json()
+ else:
+ return "Error: " + response.text
+
+
+def postprocess_output(output):
+ # if last character is not a punctuation, add a full stop
+    if output[0][-1] not in string.punctuation:
+ output[0] += "."
+
+ return output
+
+
+def inference_chat(
+ image,
+ text_input,
+ decoding_method,
+ temperature,
+ length_penalty,
+ repetition_penalty,
+ history=[],
+):
+ text_input = text_input
+ history.append(text_input)
+
+ prompt = " ".join(history)
+
+ output = query_chat_api(
+ image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
+ )
+ output = postprocess_output(output)
+ history += output
+
+ chat = [
+ (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)
+ ] # convert to tuples of list
+
+ return {chatbot: chat, state: history}
+
+
+def inference_caption(
+ image,
+ decoding_method,
+ temperature,
+ length_penalty,
+ repetition_penalty,
+):
+ output = query_caption_api(
+ image, decoding_method, temperature, length_penalty, repetition_penalty
+ )
+
+ return output[0]
+
+
+title = """
BLIP-2
"""
+description = """Gradio demo for BLIP-2, image-to-text generation from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them.
+
Disclaimer: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected."""
+article = """Paper: BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models
+
Code: BLIP2 is now integrated into GitHub repo: LAVIS: a One-stop Library for Language and Vision
+
🤗 `transformers` integration: You can now use `transformers` to use our BLIP-2 models! Check out the official docs
+ Project Page: BLIP2 on LAVIS
+
Description: Captioning results from BLIP2_OPT_6.7B. Chat results from BLIP2_FlanT5xxl.
+
+
For safety and ethical considerations, we have disabled image uploading from March 21. 2023.
+
Please try examples provided below.
+"""
+
+endpoint = Endpoint()
+
+examples = [
+ ["house.png", "How could someone get out of the house?"],
+ ["flower.jpg", "Question: What is this flower and where is it's origin? Answer:"],
+ ["pizza.jpg", "What are steps to cook it?"],
+ ["sunset.jpg", "Here is a romantic message going along the photo:"],
+ ["forbidden_city.webp", "In what dynasties was this place built?"],
+]
+
+with gr.Blocks(
+ css="""
+ .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
+ #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
+ """
+) as iface:
+ state = gr.State([])
+
+ gr.Markdown(title)
+ gr.Markdown(description)
+ gr.Markdown(article)
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ image_input = gr.Image(type="pil", interactive=False)
+
+ # with gr.Row():
+ sampling = gr.Radio(
+ choices=["Beam search", "Nucleus sampling"],
+ value="Beam search",
+ label="Text Decoding Method",
+ interactive=True,
+ )
+
+ temperature = gr.Slider(
+ minimum=0.5,
+ maximum=1.0,
+ value=1.0,
+ step=0.1,
+ interactive=True,
+ label="Temperature (used with nucleus sampling)",
+ )
+
+ len_penalty = gr.Slider(
+ minimum=-1.0,
+ maximum=2.0,
+ value=1.0,
+ step=0.2,
+ interactive=True,
+ label="Length Penalty (set to larger for longer sequence, used with beam search)",
+ )
+
+ rep_penalty = gr.Slider(
+ minimum=1.0,
+ maximum=5.0,
+ value=1.5,
+ step=0.5,
+ interactive=True,
+ label="Repeat Penalty (larger value prevents repetition)",
+ )
+
+ with gr.Column(scale=1.8):
+
+ with gr.Column():
+ caption_output = gr.Textbox(lines=1, label="Caption Output")
+ caption_button = gr.Button(
+ value="Caption it!", interactive=True, variant="primary"
+ )
+ caption_button.click(
+ inference_caption,
+ [
+ image_input,
+ sampling,
+ temperature,
+ len_penalty,
+ rep_penalty,
+ ],
+ [caption_output],
+ )
+
+            gr.Markdown("""Try prompting your input for chat; e.g., an example prompt for QA is \"Question: {} Answer:\". Use proper punctuation (e.g., a question mark).""")
+ with gr.Row():
+ with gr.Column(
+ scale=1.5,
+ ):
+ chatbot = gr.Chatbot(
+ label="Chat Output (from FlanT5)",
+ )
+
+ # with gr.Row():
+ with gr.Column(scale=1):
+ chat_input = gr.Textbox(lines=1, label="Chat Input")
+ chat_input.submit(
+ inference_chat,
+ [
+ image_input,
+ chat_input,
+ sampling,
+ temperature,
+ len_penalty,
+ rep_penalty,
+ state,
+ ],
+ [chatbot, state],
+ )
+
+ with gr.Row():
+ clear_button = gr.Button(value="Clear", interactive=True)
+ clear_button.click(
+ lambda: ("", [], []),
+ [],
+ [chat_input, chatbot, state],
+ queue=False,
+ )
+
+ submit_button = gr.Button(
+ value="Submit", interactive=True, variant="primary"
+ )
+ submit_button.click(
+ inference_chat,
+ [
+ image_input,
+ chat_input,
+ sampling,
+ temperature,
+ len_penalty,
+ rep_penalty,
+ state,
+ ],
+ [chatbot, state],
+ )
+
+ image_input.change(
+ lambda: ("", "", []),
+ [],
+ [chatbot, caption_output, state],
+ queue=False,
+ )
+
+ examples = gr.Examples(
+ examples=examples,
+ inputs=[image_input, chat_input],
+ )
+
+iface.queue(concurrency_count=1, api_open=False, max_size=10)
+iface.launch(enable_queue=True)
diff --git a/edict_functions.py b/edict_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..c68fc76f7f474002e6622ef2ee2bebdb7ca37b76
--- /dev/null
+++ b/edict_functions.py
@@ -0,0 +1,997 @@
+import torch
+from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer
+from omegaconf import OmegaConf
+import math
+import imageio
+from PIL import Image
+import torchvision
+import torch.nn.functional as F
+import numpy as np
+import time
+import datetime
+import sys
+import os
+from torchvision import datasets
+import pickle
+
+
+
+# StableDiffusion P2P implementation originally from https://github.com/bloc97/CrossAttentionControl
+use_half_prec = True
+if use_half_prec:
+ from my_half_diffusers import AutoencoderKL, UNet2DConditionModel
+ from my_half_diffusers.schedulers.scheduling_utils import SchedulerOutput
+ from my_half_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
+else:
+ from my_diffusers import AutoencoderKL, UNet2DConditionModel
+ from my_diffusers.schedulers.scheduling_utils import SchedulerOutput
+ from my_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
+torch_dtype = torch.float16 if use_half_prec else torch.float64
+np_dtype = np.float16 if use_half_prec else np.float64
+
+
+
+import random
+from tqdm.auto import tqdm
+from torch import autocast
+from difflib import SequenceMatcher
+
+# Build our CLIP model
+model_path_clip = "openai/clip-vit-large-patch14"
+clip_tokenizer = CLIPTokenizer.from_pretrained(model_path_clip)
+clip_model = CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch_dtype)
+clip = clip_model.text_model
+
+
+# Getting our HF Auth token
+auth_token = os.environ.get('auth_token')
+if auth_token is None:
+ with open('hf_auth', 'r') as f:
+ auth_token = f.readlines()[0].strip()
+model_path_diffusion = "CompVis/stable-diffusion-v1-4"
+# Build our SD model
+unet = UNet2DConditionModel.from_pretrained(model_path_diffusion, subfolder="unet", use_auth_token=auth_token, revision="fp16", torch_dtype=torch_dtype)
+vae = AutoencoderKL.from_pretrained(model_path_diffusion, subfolder="vae", use_auth_token=auth_token, revision="fp16", torch_dtype=torch_dtype)
+
+# Push models to device (half precision if use_half_prec, else double precision)
+device = 'cuda'
+if use_half_prec:
+ unet.to(device)
+ vae.to(device)
+ clip.to(device)
+else:
+ unet.double().to(device)
+ vae.double().to(device)
+ clip.double().to(device)
+print("Loaded all models")
+
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+# load safety model
+safety_model_id = "CompVis/stable-diffusion-safety-checker"
+safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+def load_replacement(x):
+ try:
+ hwc = x.shape
+ y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
+ y = (np.array(y)/255.0).astype(x.dtype)
+ assert y.shape == x.shape
+ return y
+ except Exception:
+ return x
+def check_safety(x_image):
+ safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+ x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+ assert x_checked_image.shape[0] == len(has_nsfw_concept)
+ for i in range(len(has_nsfw_concept)):
+ if has_nsfw_concept[i]:
+ # x_checked_image[i] = load_replacement(x_checked_image[i])
+ x_checked_image[i] *= 0 # load_replacement(x_checked_image[i])
+ return x_checked_image, has_nsfw_concept
+
+
+def EDICT_editing(im_path,
+ base_prompt,
+ edit_prompt,
+ use_p2p=False,
+ steps=50,
+ mix_weight=0.93,
+ init_image_strength=0.8,
+ guidance_scale=3,
+ run_baseline=False,
+ width=512, height=512):
+ """
+ Main call of our research, performs editing with either EDICT or DDIM
+
+ Args:
+ im_path: path to image to run on
+ base_prompt: conditional prompt to deterministically noise with
+        edit_prompt: desired text conditioning
+ steps: ddim steps
+ mix_weight: Weight of mixing layers.
+            Higher means more consistent generations but greater divergence in inversion;
+            lower means the opposite.
+            The default is fairly well tuned and generally gives good results
+ init_image_strength: Editing strength. Higher = more dramatic edit.
+ Typically [0.6, 0.9] is good range.
+ Definitely tunable per-image/maybe best results are at a different value
+ guidance_scale: classifier-free guidance scale
+ 3 I've found is the best for both our method and basic DDIM inversion
+ Higher can result in more distorted results
+ run_baseline:
+ VERY IMPORTANT
+            True runs the DDIM baseline, False runs EDICT
+ Output:
+ PAIR of Images (tuple)
+ If run_baseline=True then [0] will be edit and [1] will be original
+ If run_baseline=False then they will be two nearly identical edited versions
+ """
+ # Resize/center crop to 512x512 (Can do higher res. if desired)
+ if isinstance(im_path, str):
+ orig_im = load_im_into_format_from_path(im_path)
+ elif Image.isImageType(im_path):
+ width, height = im_path.size
+
+
+        # cap the maximum dimension at 1024 for the sake of memory
+ max_dim = max(width, height)
+ if max_dim > 1024:
+ factor = 1024 / max_dim
+ width *= factor
+ height *= factor
+ width = int(width)
+ height = int(height)
+ im_path = im_path.resize((width, height))
+
+ min_dim = min(width, height)
+ if min_dim < 512:
+ factor = 512 / min_dim
+ width *= factor
+ height *= factor
+ width = int(width)
+ height = int(height)
+ im_path = im_path.resize((width, height))
+
+ width = width - (width%64)
+ height = height - (height%64)
+
+ orig_im = im_path # general_crop(im_path, width, height)
+ else:
+ orig_im = im_path
+
+ # compute latent pair (second one will be original latent if run_baseline=True)
+ latents = coupled_stablediffusion(base_prompt,
+ reverse=True,
+ init_image=orig_im,
+ init_image_strength=init_image_strength,
+ steps=steps,
+ mix_weight=mix_weight,
+ guidance_scale=guidance_scale,
+ run_baseline=run_baseline,
+ width=width, height=height)
+ # Denoise intermediate state with new conditioning
+ gen = coupled_stablediffusion(edit_prompt if (not use_p2p) else base_prompt,
+ None if (not use_p2p) else edit_prompt,
+ fixed_starting_latent=latents,
+ init_image_strength=init_image_strength,
+ steps=steps,
+ mix_weight=mix_weight,
+ guidance_scale=guidance_scale,
+ run_baseline=run_baseline,
+ width=width, height=height)
+
+ return gen
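+# Example (sketch): calling EDICT_editing directly with the same "scream" example
+# used by the demo; any RGB image path or PIL.Image can be substituted.
+#
+#   edited_im, companion_im = EDICT_editing('square_ims/scream.jpg',
+#                                           'A painting of someone screaming',
+#                                           'A painting of an alien',
+#                                           init_image_strength=0.5,
+#                                           guidance_scale=3)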
+
+
+def img2img_editing(im_path,
+ edit_prompt,
+ steps=50,
+ init_image_strength=0.7,
+ guidance_scale=3):
+ """
+ Basic SDEdit/img2img, given an image add some noise and denoise with prompt
+ """
+ orig_im = load_im_into_format_from_path(im_path)
+
+ return baseline_stablediffusion(edit_prompt,
+ init_image_strength=init_image_strength,
+ steps=steps,
+ init_image=orig_im,
+ guidance_scale=guidance_scale)
+
+
+def center_crop(im):
+ width, height = im.size # Get dimensions
+ min_dim = min(width, height)
+ left = (width - min_dim)/2
+ top = (height - min_dim)/2
+ right = (width + min_dim)/2
+ bottom = (height + min_dim)/2
+
+ # Crop the center of the image
+ im = im.crop((left, top, right, bottom))
+ return im
+
+
+
+def general_crop(im, target_w, target_h):
+ width, height = im.size # Get dimensions
+ min_dim = min(width, height)
+ left = target_w / 2 # (width - min_dim)/2
+ top = target_h / 2 # (height - min_dim)/2
+ right = width - (target_w / 2) # (width + min_dim)/2
+ bottom = height - (target_h / 2) # (height + min_dim)/2
+
+ # Crop the center of the image
+ im = im.crop((left, top, right, bottom))
+ return im
+
+
+
+def load_im_into_format_from_path(im_path):
+ return center_crop(Image.open(im_path)).resize((512,512))
+
+
+#### P2P STUFF ####
+def init_attention_weights(weight_tuples):
+ tokens_length = clip_tokenizer.model_max_length
+ weights = torch.ones(tokens_length)
+
+ for i, w in weight_tuples:
+ if i < tokens_length and i >= 0:
+ weights[i] = w
+
+
+ for name, module in unet.named_modules():
+ module_name = type(module).__name__
+ if module_name == "CrossAttention" and "attn2" in name:
+ module.last_attn_slice_weights = weights.to(device)
+ if module_name == "CrossAttention" and "attn1" in name:
+ module.last_attn_slice_weights = None
+
+
+def init_attention_edit(tokens, tokens_edit):
+ tokens_length = clip_tokenizer.model_max_length
+ mask = torch.zeros(tokens_length)
+ indices_target = torch.arange(tokens_length, dtype=torch.long)
+ indices = torch.zeros(tokens_length, dtype=torch.long)
+
+ tokens = tokens.input_ids.numpy()[0]
+ tokens_edit = tokens_edit.input_ids.numpy()[0]
+
+ for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
+ if b0 < tokens_length:
+ if name == "equal" or (name == "replace" and a1-a0 == b1-b0):
+ mask[b0:b1] = 1
+ indices[b0:b1] = indices_target[a0:a1]
+
+ for name, module in unet.named_modules():
+ module_name = type(module).__name__
+ if module_name == "CrossAttention" and "attn2" in name:
+ module.last_attn_slice_mask = mask.to(device)
+ module.last_attn_slice_indices = indices.to(device)
+ if module_name == "CrossAttention" and "attn1" in name:
+ module.last_attn_slice_mask = None
+ module.last_attn_slice_indices = None
+
+
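+# init_attention_func monkey-patches the _attention method of every CrossAttention
+# module in the UNet so attention maps can be saved on one pass and re-injected
+# (masked / re-weighted) on a later pass -- the hook behind the Prompt-to-Prompt
+# editing options above.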
+def init_attention_func():
+ def new_attention(self, query, key, value, sequence_length, dim):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+ attn_slice = (
+ torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
+ )
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ if self.use_last_attn_slice:
+ if self.last_attn_slice_mask is not None:
+ new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
+ attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask
+ else:
+ attn_slice = self.last_attn_slice
+
+ self.use_last_attn_slice = False
+
+ if self.save_last_attn_slice:
+ self.last_attn_slice = attn_slice
+ self.save_last_attn_slice = False
+
+ if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
+ attn_slice = attn_slice * self.last_attn_slice_weights
+ self.use_last_attn_weights = False
+
+ attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ for name, module in unet.named_modules():
+ module_name = type(module).__name__
+ if module_name == "CrossAttention":
+ module.last_attn_slice = None
+ module.use_last_attn_slice = False
+ module.use_last_attn_weights = False
+ module.save_last_attn_slice = False
+ module._attention = new_attention.__get__(module, type(module))
+
+def use_last_tokens_attention(use=True):
+ for name, module in unet.named_modules():
+ module_name = type(module).__name__
+ if module_name == "CrossAttention" and "attn2" in name:
+ module.use_last_attn_slice = use
+
+def use_last_tokens_attention_weights(use=True):
+ for name, module in unet.named_modules():
+ module_name = type(module).__name__
+ if module_name == "CrossAttention" and "attn2" in name:
+ module.use_last_attn_weights = use
+
+def use_last_self_attention(use=True):
+ for name, module in unet.named_modules():
+ module_name = type(module).__name__
+ if module_name == "CrossAttention" and "attn1" in name:
+ module.use_last_attn_slice = use
+
+def save_last_tokens_attention(save=True):
+ for name, module in unet.named_modules():
+ module_name = type(module).__name__
+ if module_name == "CrossAttention" and "attn2" in name:
+ module.save_last_attn_slice = save
+
+def save_last_self_attention(save=True):
+ for name, module in unet.named_modules():
+ module_name = type(module).__name__
+ if module_name == "CrossAttention" and "attn1" in name:
+ module.save_last_attn_slice = save
+####################################
+
+
+##### BASELINE ALGORITHM, ONLY USED NOW FOR SDEDIT #####
+
+@torch.no_grad()
+def baseline_stablediffusion(prompt="",
+ prompt_edit=None,
+ null_prompt='',
+ prompt_edit_token_weights=[],
+ prompt_edit_tokens_start=0.0,
+ prompt_edit_tokens_end=1.0,
+ prompt_edit_spatial_start=0.0,
+ prompt_edit_spatial_end=1.0,
+ clip_start=0.0,
+ clip_end=1.0,
+ guidance_scale=7,
+ steps=50,
+ seed=1,
+ width=512, height=512,
+ init_image=None, init_image_strength=0.5,
+ fixed_starting_latent = None,
+ prev_image= None,
+ grid=None,
+ clip_guidance=None,
+ clip_guidance_scale=1,
+ num_cutouts=4,
+ cut_power=1,
+ scheduler_str='lms',
+ return_latent=False,
+ one_pass=False,
+ normalize_noise_pred=False):
+ width = width - width % 64
+ height = height - height % 64
+
+ #If seed is None, randomly select seed from 0 to 2^32-1
+ if seed is None: seed = random.randrange(2**32 - 1)
+ generator = torch.cuda.manual_seed(seed)
+
+ #Set inference timesteps to scheduler
+ scheduler_dict = {'ddim':DDIMScheduler,
+ 'lms':LMSDiscreteScheduler,
+ 'pndm':PNDMScheduler,
+ 'ddpm':DDPMScheduler}
+ scheduler_call = scheduler_dict[scheduler_str]
+ if scheduler_str == 'ddim':
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False, set_alpha_to_one=False)
+ else:
+ scheduler = scheduler_call(beta_schedule="scaled_linear",
+ num_train_timesteps=1000)
+
+ scheduler.set_timesteps(steps)
+ if prev_image is not None:
+ prev_scheduler = LMSDiscreteScheduler(beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ num_train_timesteps=1000)
+ prev_scheduler.set_timesteps(steps)
+
+ #Preprocess image if it exists (img2img)
+ if init_image is not None:
+ init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)
+ init_image = np.array(init_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
+ init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))
+
+ #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
+ if init_image.shape[1] > 3:
+ init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])
+
+ #Move image to GPU
+ init_image = init_image.to(device)
+
+ #Encode image
+ with autocast(device):
+ init_latent = vae.encode(init_image).latent_dist.sample(generator=generator) * 0.18215
+
+ t_start = steps - int(steps * init_image_strength)
+
+ else:
+ init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
+ t_start = 0
+
+ #Generate random normal noise
+ if fixed_starting_latent is None:
+ noise = torch.randn(init_latent.shape, generator=generator, device=device, dtype=unet.dtype)
+ if scheduler_str == 'ddim':
+ if init_image is not None:
+            raise NotImplementedError
+ latent = scheduler.add_noise(init_latent, noise,
+ 1000 - int(1000 * init_image_strength)).to(device)
+ else:
+ latent = noise
+ else:
+ latent = scheduler.add_noise(init_latent, noise,
+ t_start).to(device)
+ else:
+ latent = fixed_starting_latent
+ t_start = steps - int(steps * init_image_strength)
+
+ if prev_image is not None:
+ #Resize and prev_image for numpy b h w c -> torch b c h w
+ prev_image = prev_image.resize((width, height), resample=Image.Resampling.LANCZOS)
+ prev_image = np.array(prev_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
+ prev_image = torch.from_numpy(prev_image[np.newaxis, ...].transpose(0, 3, 1, 2))
+
+ #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
+ if prev_image.shape[1] > 3:
+ prev_image = prev_image[:, :3] * prev_image[:, 3:] + (1 - prev_image[:, 3:])
+
+ #Move image to GPU
+ prev_image = prev_image.to(device)
+
+ #Encode image
+ with autocast(device):
+ prev_init_latent = vae.encode(prev_image).latent_dist.sample(generator=generator) * 0.18215
+
+ t_start = steps - int(steps * init_image_strength)
+
+ prev_latent = prev_scheduler.add_noise(prev_init_latent, noise, t_start).to(device)
+ else:
+ prev_latent = None
+
+
+ #Process clip
+ with autocast(device):
+ tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
+ embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state
+
+ tokens_conditional = clip_tokenizer(prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
+ embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state
+
+ #Process prompt editing
+ assert not ((prompt_edit is not None) and (prev_image is not None))
+ if prompt_edit is not None:
+ tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
+ embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state
+ init_attention_edit(tokens_conditional, tokens_conditional_edit)
+ elif prev_image is not None:
+ init_attention_edit(tokens_conditional, tokens_conditional)
+
+
+ init_attention_func()
+ init_attention_weights(prompt_edit_token_weights)
+
+ timesteps = scheduler.timesteps[t_start:]
+ # print(timesteps)
+
+ assert isinstance(guidance_scale, int)
+ num_cycles = 1 # guidance_scale + 1
+
+ last_noise_preds = None
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
+ t_index = t_start + i
+
+ latent_model_input = latent
+ if scheduler_str=='lms':
+ sigma = scheduler.sigmas[t_index] # last is first and first is last
+ latent_model_input = (latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
+ else:
+ assert scheduler_str in ['ddim', 'pndm', 'ddpm']
+
+ #Predict the unconditional noise residual
+
+ if len(t.shape) == 0:
+ t = t[None].to(unet.device)
+ noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=embedding_unconditional,
+ ).sample
+
+ if prev_latent is not None:
+ prev_latent_model_input = prev_latent
+ prev_latent_model_input = (prev_latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
+ prev_noise_pred_uncond = unet(prev_latent_model_input, t,
+ encoder_hidden_states=embedding_unconditional,
+ ).sample
+ # noise_pred_uncond = unet(latent_model_input, t,
+ # encoder_hidden_states=embedding_unconditional)['sample']
+
+ #Prepare the Cross-Attention layers
+ if prompt_edit is not None or prev_latent is not None:
+ save_last_tokens_attention()
+ save_last_self_attention()
+ else:
+ #Use weights on non-edited prompt when edit is None
+ use_last_tokens_attention_weights()
+
+ #Predict the conditional noise residual and save the cross-attention layer activations
+ if prev_latent is not None:
+ raise NotImplementedError # I totally lost track of what this is
+ prev_noise_pred_cond = unet(prev_latent_model_input, t, encoder_hidden_states=embedding_conditional,
+ ).sample
+ else:
+ noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional,
+ ).sample
+
+ #Edit the Cross-Attention layer activations
+ t_scale = t / scheduler.num_train_timesteps
+ if prompt_edit is not None or prev_latent is not None:
+ if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
+ use_last_tokens_attention()
+ if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
+ use_last_self_attention()
+
+ #Use weights on edited prompt
+ use_last_tokens_attention_weights()
+
+ #Predict the edited conditional noise residual using the cross-attention masks
+ if prompt_edit is not None:
+ noise_pred_cond = unet(latent_model_input, t,
+ encoder_hidden_states=embedding_conditional_edit).sample
+
+ #Perform guidance
+ # if i%(num_cycles)==0: # cycle_i+1==num_cycles:
+ """
+ if cycle_i+1==num_cycles:
+ noise_pred = noise_pred_uncond
+ else:
+ noise_pred = noise_pred_cond - noise_pred_uncond
+
+ """
+ if last_noise_preds is not None:
+ # print( (last_noise_preds[0]*noise_pred_uncond).sum(), (last_noise_preds[1]*noise_pred_cond).sum())
+ # print(F.cosine_similarity(last_noise_preds[0].flatten(), noise_pred_uncond.flatten(), dim=0),
+ # F.cosine_similarity(last_noise_preds[1].flatten(), noise_pred_cond.flatten(), dim=0))
+ last_grad= last_noise_preds[1] - last_noise_preds[0]
+ new_grad = noise_pred_cond - noise_pred_uncond
+ # print( F.cosine_similarity(last_grad.flatten(), new_grad.flatten(), dim=0))
+ last_noise_preds = (noise_pred_uncond, noise_pred_cond)
+
+ use_cond_guidance = True
+ if use_cond_guidance:
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ else:
+ noise_pred = noise_pred_uncond
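+        # NOTE: new_cond_fn (a CLIP-guidance helper) is not defined in this file,
+        # so clip_guidance must be left as None when calling this function.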
+ if clip_guidance is not None and t_scale >= clip_start and t_scale <= clip_end:
+ noise_pred, latent = new_cond_fn(latent, t, t_index,
+ embedding_conditional, noise_pred,clip_guidance,
+ clip_guidance_scale,
+ num_cutouts,
+ scheduler, unet,use_cutouts=True,
+ cut_power=cut_power)
+ if normalize_noise_pred:
+ noise_pred = noise_pred * noise_pred_uncond.norm() / noise_pred.norm()
+ if scheduler_str == 'ddim':
+ latent = forward_step(scheduler, noise_pred,
+ t,
+ latent).prev_sample
+ else:
+ latent = scheduler.step(noise_pred,
+ t_index,
+ latent).prev_sample
+
+ if prev_latent is not None:
+ prev_noise_pred = prev_noise_pred_uncond + guidance_scale * (prev_noise_pred_cond - prev_noise_pred_uncond)
+ prev_latent = prev_scheduler.step(prev_noise_pred, t_index, prev_latent).prev_sample
+ if one_pass: break
+
+ #scale and decode the image latents with vae
+ if return_latent: return latent
+ latent = latent / 0.18215
+ image = vae.decode(latent.to(vae.dtype)).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ image, _ = check_safety(image)
+
+ image = (image[0] * 255).round().astype("uint8")
+ return Image.fromarray(image)
+####################################
+
+#### HELPER FUNCTIONS FOR OUR METHOD #####
+
+def get_alpha_and_beta(t, scheduler):
+    # want to run this for both the current and previous timestep
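+    # Integer timesteps index alphas_cumprod directly; fractional timesteps
+    # (arising from prev_timestep = t - num_train_timesteps / num_inference_steps)
+    # get an alpha interpolated between the two neighboring integer steps.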
+ if t.dtype==torch.long:
+ alpha = scheduler.alphas_cumprod[t]
+ return alpha, 1-alpha
+
+ if t<0:
+ return scheduler.final_alpha_cumprod, 1 - scheduler.final_alpha_cumprod
+
+
+ low = t.floor().long()
+ high = t.ceil().long()
+ rem = t - low
+
+ low_alpha = scheduler.alphas_cumprod[low]
+ high_alpha = scheduler.alphas_cumprod[high]
+ interpolated_alpha = low_alpha * rem + high_alpha * (1-rem)
+ interpolated_beta = 1 - interpolated_alpha
+ return interpolated_alpha, interpolated_beta
+
+
+# A DDIM forward step function
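+# Implements the deterministic DDIM update from timestep t to the previous one:
+#   x_prev = sqrt(a_prev / a_t) * (x_t - sqrt(1 - a_t) * eps) + sqrt(1 - a_prev) * eps
+# where a_* are cumulative alphas and eps is the model's noise prediction.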
+def forward_step(
+ self,
+ model_output,
+ timestep: int,
+ sample,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ return_dict: bool = True,
+ use_double=False,
+) :
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps
+
+ if timestep > self.timesteps.max():
+ raise NotImplementedError("Need to double check what the overflow is")
+
+ alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self)
+ alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self)
+
+
+ alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5)
+ first_term = (1./alpha_quotient) * sample
+ second_term = (1./alpha_quotient) * (beta_prod_t ** 0.5) * model_output
+ third_term = ((1 - alpha_prod_t_prev)**0.5) * model_output
+ return first_term - second_term + third_term
+
+# A DDIM reverse step function, the inverse of above
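+# Solves the forward update for x_t given x_prev and the same noise prediction eps:
+#   x_t = sqrt(a_t / a_prev) * (x_prev - sqrt(1 - a_prev) * eps) + sqrt(1 - a_t) * eps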
+def reverse_step(
+ self,
+ model_output,
+ timestep: int,
+ sample,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ return_dict: bool = True,
+ use_double=False,
+) :
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps
+
+    if timestep > self.timesteps.max():
+        raise NotImplementedError
+
+ alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self)
+ alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self)
+
+ alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5)
+
+ first_term = alpha_quotient * sample
+ second_term = ((beta_prod_t)**0.5) * model_output
+ third_term = alpha_quotient * ((1 - alpha_prod_t_prev)**0.5) * model_output
+ return first_term + second_term - third_term
+
+
+
+
+@torch.no_grad()
+def latent_to_image(latent):
+ image = vae.decode(latent.to(vae.dtype)/0.18215).sample
+ image = prep_image_for_return(image)
+ return image
+
+def prep_image_for_return(image):
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ image = (image[0] * 255).round().astype("uint8")
+ image = Image.fromarray(image)
+ return image
+
+#############################
+
+##### MAIN EDICT FUNCTION #######
+# Use EDICT_editing to perform calls
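+# EDICT keeps two coupled latent sequences that alternately denoise each other
+# (each step conditions on the partner latent) and blends them with an invertible
+# mixing layer of weight mix_weight; running the same updates with reverse=True
+# exactly inverts the process, which is what enables faithful editing of real images.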
+
+@torch.no_grad()
+def coupled_stablediffusion(prompt="",
+ prompt_edit=None,
+ null_prompt='',
+ prompt_edit_token_weights=[],
+ prompt_edit_tokens_start=0.0,
+ prompt_edit_tokens_end=1.0,
+ prompt_edit_spatial_start=0.0,
+ prompt_edit_spatial_end=1.0,
+ guidance_scale=7.0, steps=50,
+ seed=1, width=512, height=512,
+ init_image=None, init_image_strength=1.0,
+ run_baseline=False,
+ use_lms=False,
+ leapfrog_steps=True,
+ reverse=False,
+ return_latents=False,
+ fixed_starting_latent=None,
+ beta_schedule='scaled_linear',
+ mix_weight=0.93):
+ #If seed is None, randomly select seed from 0 to 2^32-1
+ if seed is None: seed = random.randrange(2**32 - 1)
+ generator = torch.cuda.manual_seed(seed)
+
+ def image_to_latent(im):
+ if isinstance(im, torch.Tensor):
+ # assume it's the latent
+ # used to avoid clipping new generation before inversion
+ init_latent = im.to(device)
+ else:
+ #Resize and transpose for numpy b h w c -> torch b c h w
+ im = im.resize((width, height), resample=Image.Resampling.LANCZOS)
+ im = np.array(im).astype(np_dtype) / 255.0 * 2.0 - 1.0
+ # check if black and white
+ if len(im.shape) < 3:
+ im = np.stack([im for _ in range(3)], axis=2) # putting at end b/c channels
+
+ im = torch.from_numpy(im[np.newaxis, ...].transpose(0, 3, 1, 2))
+
+ #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
+ if im.shape[1] > 3:
+ im = im[:, :3] * im[:, 3:] + (1 - im[:, 3:])
+
+ #Move image to GPU
+ im = im.to(device)
+ #Encode image
+ if use_half_prec:
+ init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215
+ else:
+ with autocast(device):
+ init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215
+ return init_latent
+ assert not use_lms, "Can't invert LMS the same as DDIM"
+ if run_baseline: leapfrog_steps=False
+ #Change size to multiple of 64 to prevent size mismatches inside model
+ width = width - width % 64
+ height = height - height % 64
+
+
+ #Preprocess image if it exists (img2img)
+ if init_image is not None:
+ assert reverse # want to be performing deterministic noising
+ # can take either pair (output of generative process) or single image
+ if isinstance(init_image, list):
+ if isinstance(init_image[0], torch.Tensor):
+ init_latent = [t.clone() for t in init_image]
+ else:
+ init_latent = [image_to_latent(im) for im in init_image]
+ else:
+ init_latent = image_to_latent(init_image)
+ # this is t_start for forward, t_end for reverse
+ t_limit = steps - int(steps * init_image_strength)
+ else:
+ assert not reverse, 'Need image to reverse from'
+ init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
+ t_limit = 0
+
+ if reverse:
+ latent = init_latent
+ else:
+ #Generate random normal noise
+ noise = torch.randn(init_latent.shape,
+ generator=generator,
+ device=device,
+ dtype=torch_dtype)
+ if fixed_starting_latent is None:
+ latent = noise
+ else:
+ if isinstance(fixed_starting_latent, list):
+ latent = [l.clone() for l in fixed_starting_latent]
+ else:
+ latent = fixed_starting_latent.clone()
+ t_limit = steps - int(steps * init_image_strength)
+ if isinstance(latent, list): # initializing from pair of images
+ latent_pair = latent
+ else: # initializing from noise
+ latent_pair = [latent.clone(), latent.clone()]
+
+
+ if steps==0:
+ if init_image is not None:
+ return image_to_latent(init_image)
+ else:
+ image = vae.decode(latent.to(vae.dtype) / 0.18215).sample
+ return prep_image_for_return(image)
+
+ #Set inference timesteps to scheduler
+ schedulers = []
+ for i in range(2):
+ # num_raw_timesteps = max(1000, steps)
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
+ beta_schedule=beta_schedule,
+ num_train_timesteps=1000,
+ clip_sample=False,
+ set_alpha_to_one=False)
+ scheduler.set_timesteps(steps)
+ schedulers.append(scheduler)
+
+ with autocast(device):
+ # CLIP Text Embeddings
+ tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length",
+ max_length=clip_tokenizer.model_max_length,
+ truncation=True, return_tensors="pt",
+ return_overflowing_tokens=True)
+ embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state
+
+ tokens_conditional = clip_tokenizer(prompt, padding="max_length",
+ max_length=clip_tokenizer.model_max_length,
+ truncation=True, return_tensors="pt",
+ return_overflowing_tokens=True)
+ embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state
+
+ #Process prompt editing (if running Prompt-to-Prompt)
+ if prompt_edit is not None:
+ tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length",
+ max_length=clip_tokenizer.model_max_length,
+ truncation=True, return_tensors="pt",
+ return_overflowing_tokens=True)
+ embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state
+
+ init_attention_edit(tokens_conditional, tokens_conditional_edit)
+
+ init_attention_func()
+ init_attention_weights(prompt_edit_token_weights)
+
+ timesteps = schedulers[0].timesteps[t_limit:]
+ if reverse: timesteps = timesteps.flip(0)
+
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
+ t_scale = t / schedulers[0].num_train_timesteps
+
+ if (reverse) and (not run_baseline):
+ # Reverse mixing layer
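+            # Exactly undoes the generation-time mixing below by inverting the two
+            # affine updates in the opposite order (with p = mix_weight):
+            #   x1 = (x1' - (1 - p) * x0') / p,  then  x0 = (x0' - (1 - p) * x1) / p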
+ new_latents = [l.clone() for l in latent_pair]
+ new_latents[1] = (new_latents[1].clone() - (1-mix_weight)*new_latents[0].clone()) / mix_weight
+ new_latents[0] = (new_latents[0].clone() - (1-mix_weight)*new_latents[1].clone()) / mix_weight
+ latent_pair = new_latents
+
+ # alternate EDICT steps
+ for latent_i in range(2):
+ if run_baseline and latent_i==1: continue # just have one sequence for baseline
+ # this modifies latent_pair[i] while using
+ # latent_pair[(i+1)%2]
+ if reverse and (not run_baseline):
+ if leapfrog_steps:
+ # what i would be from going other way
+ orig_i = len(timesteps) - (i+1)
+ offset = (orig_i+1) % 2
+ latent_i = (latent_i + offset) % 2
+ else:
+ # Do 1 then 0
+ latent_i = (latent_i+1)%2
+ else:
+ if leapfrog_steps:
+ offset = i%2
+ latent_i = (latent_i + offset) % 2
+
+ latent_j = ((latent_i+1) % 2) if not run_baseline else latent_i
+
+ latent_model_input = latent_pair[latent_j]
+ latent_base = latent_pair[latent_i]
+
+ #Predict the unconditional noise residual
+ noise_pred_uncond = unet(latent_model_input, t,
+ encoder_hidden_states=embedding_unconditional).sample
+
+ #Prepare the Cross-Attention layers
+ if prompt_edit is not None:
+ save_last_tokens_attention()
+ save_last_self_attention()
+ else:
+ #Use weights on non-edited prompt when edit is None
+ use_last_tokens_attention_weights()
+
+ #Predict the conditional noise residual and save the cross-attention layer activations
+ noise_pred_cond = unet(latent_model_input, t,
+ encoder_hidden_states=embedding_conditional).sample
+
+ #Edit the Cross-Attention layer activations
+ if prompt_edit is not None:
+ t_scale = t / schedulers[0].num_train_timesteps
+ if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
+ use_last_tokens_attention()
+ if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
+ use_last_self_attention()
+
+ #Use weights on edited prompt
+ use_last_tokens_attention_weights()
+
+ #Predict the edited conditional noise residual using the cross-attention masks
+ noise_pred_cond = unet(latent_model_input,
+ t,
+ encoder_hidden_states=embedding_conditional_edit).sample
+
+ #Perform guidance
+ grad = (noise_pred_cond - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * grad
+
+
+ step_call = reverse_step if reverse else forward_step
+ new_latent = step_call(schedulers[latent_i],
+ noise_pred,
+ t,
+ latent_base)# .prev_sample
+ new_latent = new_latent.to(latent_base.dtype)
+
+ latent_pair[latent_i] = new_latent
+
+ if (not reverse) and (not run_baseline):
+ # Mixing layer (contraction) during generative process
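+            # With p = mix_weight: x0' = p * x0 + (1 - p) * x1, then
+            # x1' = (1 - p) * x0' + p * x1 (the second update uses the new x0'),
+            # keeping the two sequences close while remaining exactly invertible.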
+ new_latents = [l.clone() for l in latent_pair]
+ new_latents[0] = (mix_weight*new_latents[0] + (1-mix_weight)*new_latents[1]).clone()
+ new_latents[1] = ((1-mix_weight)*new_latents[0] + (mix_weight)*new_latents[1]).clone()
+ latent_pair = new_latents
+
+ #scale and decode the image latents with vae, can return latents instead of images
+ if reverse or return_latents:
+ results = [latent_pair]
+ return results if len(results)>1 else results[0]
+
+    # decode latents to images
+ images = []
+ for latent_i in range(2):
+ latent = latent_pair[latent_i] / 0.18215
+ image = vae.decode(latent.to(vae.dtype)).sample
+ images.append(image)
+
+ # Return images
+ return_arr = []
+ for image in images:
+ image = prep_image_for_return(image)
+ return_arr.append(image)
+ results = [return_arr]
+ return results if len(results)>1 else results[0]
+
+
diff --git a/local_app.py b/local_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..992af324323c6d949df74b74c2adebe216a7c4d2
--- /dev/null
+++ b/local_app.py
@@ -0,0 +1,33 @@
+import gradio as gr
+import numpy as np
+from edict_functions import EDICT_editing
+from PIL import Image
+
+def greet(name):
+ return "Hello " + name + "!!"
+
+
+def edict(x, source_text, edit_text,
+ edit_strength, guidance_scale,
+ steps=50, mix_weight=0.93, ):
+ x = Image.fromarray(x)
+ return_im = EDICT_editing(x,
+ source_text,
+ edit_text,
+ steps=steps,
+ mix_weight=mix_weight,
+ init_image_strength=edit_strength,
+ guidance_scale=guidance_scale
+ )[0]
+ return np.array(return_im)
+
+iface = gr.Interface(fn=edict, inputs=["image",
+ gr.Textbox(label="Original Description"),
+ gr.Textbox(label="Edit Description"),
+ # 50, # gr.Slider(5, 50, value=20, step=1),
+ # 0.93, # gr.Slider(0.5, 1, value=0.7, step=0.05),
+ gr.Slider(0.0, 1, value=0.8, step=0.05),
+ gr.Slider(0, 10, value=3, step=0.5),
+ ],
+ outputs="image")
+iface.launch()
diff --git a/my_diffusers/__init__.py b/my_diffusers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf2f183c9b5dc45a3cb40a3b2408833f6966ac96
--- /dev/null
+++ b/my_diffusers/__init__.py
@@ -0,0 +1,60 @@
+from .utils import (
+ is_inflect_available,
+ is_onnx_available,
+ is_scipy_available,
+ is_transformers_available,
+ is_unidecode_available,
+)
+
+
+__version__ = "0.3.0"
+
+from .configuration_utils import ConfigMixin
+from .modeling_utils import ModelMixin
+from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from .onnx_utils import OnnxRuntimeModel
+from .optimization import (
+ get_constant_schedule,
+ get_constant_schedule_with_warmup,
+ get_cosine_schedule_with_warmup,
+ get_cosine_with_hard_restarts_schedule_with_warmup,
+ get_linear_schedule_with_warmup,
+ get_polynomial_decay_schedule_with_warmup,
+ get_scheduler,
+)
+from .pipeline_utils import DiffusionPipeline
+from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+from .schedulers import (
+ DDIMScheduler,
+ DDPMScheduler,
+ KarrasVeScheduler,
+ PNDMScheduler,
+ SchedulerMixin,
+ ScoreSdeVeScheduler,
+)
+from .utils import logging
+
+
+if is_scipy_available():
+ from .schedulers import LMSDiscreteScheduler
+else:
+ from .utils.dummy_scipy_objects import * # noqa F403
+
+from .training_utils import EMAModel
+
+
+if is_transformers_available():
+ from .pipelines import (
+ LDMTextToImagePipeline,
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+ )
+else:
+ from .utils.dummy_transformers_objects import * # noqa F403
+
+
+if is_transformers_available() and is_onnx_available():
+ from .pipelines import StableDiffusionOnnxPipeline
+else:
+ from .utils.dummy_transformers_and_onnx_objects import * # noqa F403
diff --git a/my_diffusers/__pycache__/__init__.cpython-38.pyc b/my_diffusers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d35e42199ba8b89068d9acde4816d72f6ef745d
Binary files /dev/null and b/my_diffusers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/__pycache__/configuration_utils.cpython-38.pyc b/my_diffusers/__pycache__/configuration_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..157b853298c2383fa6f1140ae63296e4b802fc34
Binary files /dev/null and b/my_diffusers/__pycache__/configuration_utils.cpython-38.pyc differ
diff --git a/my_diffusers/__pycache__/modeling_utils.cpython-38.pyc b/my_diffusers/__pycache__/modeling_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a720fc52796afbc3d71e084a1973f20d9eb4bde
Binary files /dev/null and b/my_diffusers/__pycache__/modeling_utils.cpython-38.pyc differ
diff --git a/my_diffusers/__pycache__/onnx_utils.cpython-38.pyc b/my_diffusers/__pycache__/onnx_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..213d6cc532937038d64eb9af3fd3a7ead1a28acb
Binary files /dev/null and b/my_diffusers/__pycache__/onnx_utils.cpython-38.pyc differ
diff --git a/my_diffusers/__pycache__/optimization.cpython-38.pyc b/my_diffusers/__pycache__/optimization.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d30dd304d60ede69e854b86d2357ba5244ce6eee
Binary files /dev/null and b/my_diffusers/__pycache__/optimization.cpython-38.pyc differ
diff --git a/my_diffusers/__pycache__/pipeline_utils.cpython-38.pyc b/my_diffusers/__pycache__/pipeline_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..300ab3f2be6446e14d9f4de76a7cb64fa8b1fc68
Binary files /dev/null and b/my_diffusers/__pycache__/pipeline_utils.cpython-38.pyc differ
diff --git a/my_diffusers/__pycache__/training_utils.cpython-38.pyc b/my_diffusers/__pycache__/training_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5aa29ba3d3d07365a0d75d08c5b5ebdd2b1c0127
Binary files /dev/null and b/my_diffusers/__pycache__/training_utils.cpython-38.pyc differ
diff --git a/my_diffusers/commands/__init__.py b/my_diffusers/commands/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..902bd46cedc6f2df785c1dc5d2e6bd8ef7c69ca6
--- /dev/null
+++ b/my_diffusers/commands/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from argparse import ArgumentParser
+
+
+class BaseDiffusersCLICommand(ABC):
+ @staticmethod
+ @abstractmethod
+ def register_subcommand(parser: ArgumentParser):
+ raise NotImplementedError()
+
+ @abstractmethod
+ def run(self):
+ raise NotImplementedError()
diff --git a/my_diffusers/commands/diffusers_cli.py b/my_diffusers/commands/diffusers_cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..30084e55ba4eeec79c87a99eae3e60a6233dc556
--- /dev/null
+++ b/my_diffusers/commands/diffusers_cli.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from argparse import ArgumentParser
+
+from .env import EnvironmentCommand
+
+
+def main():
+    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
+
+ # Register commands
+ EnvironmentCommand.register_subcommand(commands_parser)
+
+ # Let's go
+ args = parser.parse_args()
+
+ if not hasattr(args, "func"):
+ parser.print_help()
+ exit(1)
+
+ # Run
+ service = args.func(args)
+ service.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/my_diffusers/commands/env.py b/my_diffusers/commands/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..81a878bff6688d3c510b53c60ac9d0e51e4aebcc
--- /dev/null
+++ b/my_diffusers/commands/env.py
@@ -0,0 +1,70 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import platform
+from argparse import ArgumentParser
+
+import huggingface_hub
+
+from .. import __version__ as version
+from ..utils import is_torch_available, is_transformers_available
+from . import BaseDiffusersCLICommand
+
+
+def info_command_factory(_):
+ return EnvironmentCommand()
+
+
+class EnvironmentCommand(BaseDiffusersCLICommand):
+ @staticmethod
+ def register_subcommand(parser: ArgumentParser):
+ download_parser = parser.add_parser("env")
+ download_parser.set_defaults(func=info_command_factory)
+
+ def run(self):
+ hub_version = huggingface_hub.__version__
+
+ pt_version = "not installed"
+ pt_cuda_available = "NA"
+ if is_torch_available():
+ import torch
+
+ pt_version = torch.__version__
+ pt_cuda_available = torch.cuda.is_available()
+
+ transformers_version = "not installed"
+ if is_transformers_available():
+ import transformers
+
+ transformers_version = transformers.__version__
+
+ info = {
+ "`diffusers` version": version,
+ "Platform": platform.platform(),
+ "Python version": platform.python_version(),
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
+ "Huggingface_hub version": hub_version,
+ "Transformers version": transformers_version,
+ "Using GPU in script?": "",
+ "Using distributed or parallel set-up in script?": "",
+ }
+
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last two points.\n")
+ print(self.format_dict(info))
+
+ return info
+
+ @staticmethod
+ def format_dict(d):
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
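+
+
+# Usage sketch: the command can also be exercised without the CLI plumbing; `run()`
+# prints the environment report and returns it as a dict, so the commented call
+# below is equivalent to `diffusers-cli env`:
+#
+# info = EnvironmentCommand().run()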
diff --git a/my_diffusers/configuration_utils.py b/my_diffusers/configuration_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a5c40f001dec427e6158fa59d92a0d4e226c302
--- /dev/null
+++ b/my_diffusers/configuration_utils.py
@@ -0,0 +1,403 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" ConfigMixinuration base class and utilities."""
+import functools
+import inspect
+import json
+import os
+import re
+from collections import OrderedDict
+from typing import Any, Dict, Tuple, Union
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from . import __version__
+from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
+
+
+logger = logging.get_logger(__name__)
+
+_re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+class ConfigMixin:
+ r"""
+ Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles all
+ methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
+ - [`~ConfigMixin.from_config`]
+ - [`~ConfigMixin.save_config`]
+
+ Class attributes:
+ - **config_name** (`str`) -- A filename under which the config should be stored when calling
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+ overridden by parent class).
+ """
+ config_name = None
+ ignore_for_config = []
+
+ def register_to_config(self, **kwargs):
+ if self.config_name is None:
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+ kwargs["_class_name"] = self.__class__.__name__
+ kwargs["_diffusers_version"] = __version__
+
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ if not hasattr(self, "_internal_dict"):
+ internal_dict = kwargs
+ else:
+ previous_dict = dict(self._internal_dict)
+ internal_dict = {**self._internal_dict, **kwargs}
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+ self._internal_dict = FrozenDict(internal_dict)
+
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~ConfigMixin.from_config`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
+
+ self.to_json_file(output_config_file)
+ logger.info(f"Configuration saved in {output_config_file}")
+
+ @classmethod
+ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+ r"""
+ Instantiate a Python class from a pre-defined JSON-file.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+ Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+ as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+ checkpoint with 3 labels).
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+ <Tip>
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+ </Tip>
+
+ <Tip>
+
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+ use this method in a firewalled environment.
+
+ </Tip>
+
+ """
+ config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+
+ init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
+
+ model = cls(**init_dict)
+
+ if return_unused_kwargs:
+ return model, unused_kwargs
+ else:
+ return model
+
+ @classmethod
+ def get_config_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {"file_type": "config"}
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ if cls.config_name is None:
+ raise ValueError(
+ "`self.config_name` is not defined. Note that one should not load a config from "
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+ )
+
+ if os.path.isfile(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+ # Load from a PyTorch checkpoint
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ ):
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ config_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=cls.config_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+ " login` and pass `use_auth_token=True`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+ " this model name. Check the model page at"
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a {cls.config_name} file.\nCheck out your internet connection or see how to"
+ " run the library in offline mode at"
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a {cls.config_name} file"
+ )
+
+ try:
+ # Load config dict
+ config_dict = cls._dict_from_json_file(config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+ return config_dict
+
+ @classmethod
+ def extract_init_dict(cls, config_dict, **kwargs):
+ expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
+ expected_keys.remove("self")
+ # remove general kwargs if present in dict
+ if "kwargs" in expected_keys:
+ expected_keys.remove("kwargs")
+ # remove keys to be ignored
+ if len(cls.ignore_for_config) > 0:
+ expected_keys = expected_keys - set(cls.ignore_for_config)
+ init_dict = {}
+ for key in expected_keys:
+ if key in kwargs:
+ # overwrite key
+ init_dict[key] = kwargs.pop(key)
+ elif key in config_dict:
+ # use value from config dict
+ init_dict[key] = config_dict.pop(key)
+
+ # Note: `dict.update` returns None, so merge explicitly to actually return the unused kwargs.
+ unused_kwargs = {**config_dict, **kwargs}
+
+ passed_keys = set(init_dict.keys())
+ if len(expected_keys - passed_keys) > 0:
+ logger.warning(
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
+ )
+
+ return init_dict, unused_kwargs
+
+ @classmethod
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @property
+ def config(self) -> Dict[str, Any]:
+ return self._internal_dict
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this configuration instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+
+class FrozenDict(OrderedDict):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for key, value in self.items():
+ setattr(self, key, value)
+
+ self.__frozen = True
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __setattr__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
+ super().__setitem__(name, value)
+
+
+def register_to_config(init):
+ r"""
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+ automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
+
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
+ """
+
+ @functools.wraps(init)
+ def inner_init(self, *args, **kwargs):
+ # Ignore private kwargs in the init.
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+ init(self, *args, **init_kwargs)
+ if not isinstance(self, ConfigMixin):
+ raise RuntimeError(
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+ "not inherit from `ConfigMixin`."
+ )
+
+ ignore = getattr(self, "ignore_for_config", [])
+ # Get positional arguments aligned with kwargs
+ new_kwargs = {}
+ signature = inspect.signature(init)
+ parameters = {
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+ }
+ for arg, name in zip(args, parameters.keys()):
+ new_kwargs[name] = arg
+
+ # Then add all kwargs
+ new_kwargs.update(
+ {
+ k: init_kwargs.get(k, default)
+ for k, default in parameters.items()
+ if k not in ignore and k not in new_kwargs
+ }
+ )
+ getattr(self, "register_to_config")(**new_kwargs)
+
+ return inner_init
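+
+
+# Usage sketch: how `ConfigMixin` and `@register_to_config` fit together. The class
+# below and its arguments are invented for illustration only -- decorated __init__
+# arguments land in `self.config`, and `save_config`/`from_config` round-trip them
+# through the `config_name` JSON file inside the save directory:
+#
+# class MyScheduler(ConfigMixin):
+#     config_name = "scheduler_config.json"
+#
+#     @register_to_config
+#     def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
+#         pass
+#
+# sched = MyScheduler(num_train_timesteps=500)
+# sched.save_config("./my_scheduler")                   # writes ./my_scheduler/scheduler_config.json
+# reloaded = MyScheduler.from_config("./my_scheduler")
+# assert reloaded.config["num_train_timesteps"] == 500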
diff --git a/my_diffusers/dependency_versions_check.py b/my_diffusers/dependency_versions_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf863222a52fd60a15a95be0fbd6391acd3ba6d
--- /dev/null
+++ b/my_diffusers/dependency_versions_check.py
@@ -0,0 +1,47 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from .dependency_versions_table import deps
+from .utils.versions import require_version, require_version_core
+
+
+# define which module versions we always want to check at run time
+# (usually the ones defined in `install_requires` in setup.py)
+#
+# order specific notes:
+# - tqdm must be checked before tokenizers
+
+pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
+if sys.version_info < (3, 7):
+ pkgs_to_check_at_runtime.append("dataclasses")
+if sys.version_info < (3, 8):
+ pkgs_to_check_at_runtime.append("importlib_metadata")
+
+for pkg in pkgs_to_check_at_runtime:
+ if pkg in deps:
+ if pkg == "tokenizers":
+ # must be loaded here, or else tqdm check may fail
+ from .utils import is_tokenizers_available
+
+ if not is_tokenizers_available():
+ continue # not required, check version only if installed
+
+ require_version_core(deps[pkg])
+ else:
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
+def dep_version_check(pkg, hint=None):
+ require_version(deps[pkg], hint)
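+
+
+# Usage sketch: `dep_version_check` compares an installed package against the pin
+# recorded in `dependency_versions_table.py`; for example, the call below should
+# raise if the installed torch does not satisfy "torch>=1.4":
+#
+# dep_version_check("torch")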
diff --git a/my_diffusers/dependency_versions_table.py b/my_diffusers/dependency_versions_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..74c5331e5af63fbab6e583da377c811e00791391
--- /dev/null
+++ b/my_diffusers/dependency_versions_table.py
@@ -0,0 +1,26 @@
+# THIS FILE HAS BEEN AUTOGENERATED. To update:
+# 1. modify the `_deps` dict in setup.py
+# 2. run `make deps_table_update`
+deps = {
+ "Pillow": "Pillow",
+ "accelerate": "accelerate>=0.11.0",
+ "black": "black==22.3",
+ "datasets": "datasets",
+ "filelock": "filelock",
+ "flake8": "flake8>=3.8.3",
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
+ "huggingface-hub": "huggingface-hub>=0.8.1",
+ "importlib_metadata": "importlib_metadata",
+ "isort": "isort>=5.5.4",
+ "modelcards": "modelcards==0.1.4",
+ "numpy": "numpy",
+ "pytest": "pytest",
+ "pytest-timeout": "pytest-timeout",
+ "pytest-xdist": "pytest-xdist",
+ "scipy": "scipy",
+ "regex": "regex!=2019.12.17",
+ "requests": "requests",
+ "tensorboard": "tensorboard",
+ "torch": "torch>=1.4",
+ "transformers": "transformers>=4.21.0",
+}
diff --git a/my_diffusers/dynamic_modules_utils.py b/my_diffusers/dynamic_modules_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ebf916e7af5768be3d3dc9984e5c2a066c5b4a2
--- /dev/null
+++ b/my_diffusers/dynamic_modules_utils.py
@@ -0,0 +1,335 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities to dynamically load objects from the Hub."""
+
+import importlib
+import os
+import re
+import shutil
+import sys
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+from huggingface_hub import cached_download
+
+from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def init_hf_modules():
+ """
+ Creates the cache directory for modules with an init, and adds it to the Python path.
+ """
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
+ if HF_MODULES_CACHE in sys.path:
+ return
+
+ sys.path.append(HF_MODULES_CACHE)
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def create_dynamic_module(name: Union[str, os.PathLike]):
+ """
+ Creates a dynamic module in the cache directory for modules.
+ """
+ init_hf_modules()
+ dynamic_module_path = Path(HF_MODULES_CACHE) / name
+ # If the parent module does not exist yet, recursively create it.
+ if not dynamic_module_path.parent.exists():
+ create_dynamic_module(dynamic_module_path.parent)
+ os.makedirs(dynamic_module_path, exist_ok=True)
+ init_path = dynamic_module_path / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def get_relative_imports(module_file):
+ """
+ Get the list of modules that are relatively imported in a module file.
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ with open(module_file, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import .xxx`
+ relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
+ # Imports of the form `from .xxx import yyy`
+ relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
+ # Unique-ify
+ return list(set(relative_imports))
+
+
+def get_relative_import_files(module_file):
+ """
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
+ imports (if a imports b and b imports c, it will return module files for b and c).
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ no_change = False
+ files_to_check = [module_file]
+ all_relative_imports = []
+
+ # Let's recurse through all relative imports
+ while not no_change:
+ new_imports = []
+ for f in files_to_check:
+ new_imports.extend(get_relative_imports(f))
+
+ module_path = Path(module_file).parent
+ new_import_files = [str(module_path / m) for m in new_imports]
+ new_import_files = [f for f in new_import_files if f not in all_relative_imports]
+ files_to_check = [f"{f}.py" for f in new_import_files]
+
+ no_change = len(new_import_files) == 0
+ all_relative_imports.extend(files_to_check)
+
+ return all_relative_imports
+
+
+def check_imports(filename):
+ """
+ Check if the current Python environment contains all the libraries that are imported in a file.
+ """
+ with open(filename, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import xxx`
+ imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+ # Imports of the form `from xxx import yyy`
+ imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+ # Only keep the top-level module
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
+
+ # Unique-ify and test we got them all
+ imports = list(set(imports))
+ missing_packages = []
+ for imp in imports:
+ try:
+ importlib.import_module(imp)
+ except ImportError:
+ missing_packages.append(imp)
+
+ if len(missing_packages) > 0:
+ raise ImportError(
+ "This modeling file requires the following packages that were not found in your environment: "
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
+ )
+
+ return get_relative_imports(filename)
+
+
+def get_class_in_module(class_name, module_path):
+ """
+ Import a module on the cache directory for modules and extract a class from it.
+ """
+ module_path = module_path.replace(os.path.sep, ".")
+ module = importlib.import_module(module_path)
+ return getattr(module, class_name)
+
+
+def get_cached_module_file(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+):
+ """
+ Prepares and downloads a module from a local folder or a remote repo and returns its path inside the
+ dynamic modules cache.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+
+ <Tip>
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+ </Tip>
+
+ Returns:
+ `str`: The path to the module inside the cache.
+ """
+ # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
+ submodule = "local"
+
+ if os.path.isfile(module_file_or_url):
+ resolved_module_file = module_file_or_url
+ else:
+ try:
+ # Load from URL or cache if already cached
+ resolved_module_file = cached_download(
+ module_file_or_url,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ )
+
+ except EnvironmentError:
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+ raise
+
+ # Check we have all the requirements in our environment
+ modules_needed = check_imports(resolved_module_file)
+
+ # Now we move the module inside our cached dynamic modules.
+ full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
+ create_dynamic_module(full_submodule)
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
+ # We always copy local files (we could hash the file to see if there was a change, and give them the name of
+ # that hash, to only copy when there is a modification but it seems overkill for now).
+ # The only reason we do the copy is to avoid putting too many folders in sys.path.
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ for module_needed in modules_needed:
+ module_needed = f"{module_needed}.py"
+ shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+ return os.path.join(full_submodule, module_file)
+
+
+def get_class_from_dynamic_module(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ class_name: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ **kwargs,
+):
+ """
+ Extracts a class from a module file, present in the local folder or repository of a model.
+
+ <Tip warning={true}>
+
+ Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
+ therefore only be called on trusted repos.
+
+ </Tip>
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ class_name (`str`):
+ The name of the class to import in the module.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+
+ <Tip>
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+ </Tip>
+
+ Returns:
+ `type`: The class, dynamically imported from the module.
+
+ Examples:
+
+ ```python
+ # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
+ # Download module `modeling.py` from huggingface.co, cache it, then extract the class `MyBertModel` from
+ # this module.
+ ```"""
+ # And lastly we get the class inside our newly created module
+ final_module = get_cached_module_file(
+ pretrained_model_name_or_path,
+ module_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
+ return get_class_in_module(class_name, final_module.replace(".py", ""))
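+
+
+# Usage sketch: the same helper resolves local folders as well as Hub repos. The
+# folder and file names below are invented -- any directory containing a
+# `my_pipeline.py` that defines `MyPipeline` (plus the files it relatively imports)
+# would work, with the same caveat that the module's code gets executed:
+#
+# cls = get_class_from_dynamic_module("./custom_pipeline_dir", "my_pipeline.py", "MyPipeline")
+# pipe = cls()  # construct it like any locally defined class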
diff --git a/my_diffusers/hub_utils.py b/my_diffusers/hub_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c07329e36fe7a8826b0f1fb22396819b220e1b58
--- /dev/null
+++ b/my_diffusers/hub_utils.py
@@ -0,0 +1,197 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+from pathlib import Path
+from typing import Optional
+
+from huggingface_hub import HfFolder, Repository, whoami
+
+from .pipeline_utils import DiffusionPipeline
+from .utils import is_modelcards_available, logging
+
+
+if is_modelcards_available():
+ from modelcards import CardData, ModelCard
+
+
+logger = logging.get_logger(__name__)
+
+
+MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
+def init_git_repo(args, at_init: bool = False):
+ """
+ Initializes a git repo in `args.hub_model_id`.
+
+ Args:
+ at_init (`bool`, *optional*, defaults to `False`):
+ Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
+ and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
+ """
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+ return
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
+ use_auth_token = True if hub_token is None else hub_token
+ if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
+ repo_name = Path(args.output_dir).absolute().name
+ else:
+ repo_name = args.hub_model_id
+ if "/" not in repo_name:
+ repo_name = get_full_repo_name(repo_name, token=hub_token)
+
+ try:
+ repo = Repository(
+ args.output_dir,
+ clone_from=repo_name,
+ use_auth_token=use_auth_token,
+ private=args.hub_private_repo,
+ )
+ except EnvironmentError:
+ if args.overwrite_output_dir and at_init:
+ # Try again after wiping output_dir
+ shutil.rmtree(args.output_dir)
+ repo = Repository(
+ args.output_dir,
+ clone_from=repo_name,
+ use_auth_token=use_auth_token,
+ )
+ else:
+ raise
+
+ repo.git_pull()
+
+ # By default, ignore the checkpoint folders
+ if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
+ with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
+ writer.writelines(["checkpoint-*/"])
+
+ return repo
+
+
+def push_to_hub(
+ args,
+ pipeline: DiffusionPipeline,
+ repo: Repository,
+ commit_message: Optional[str] = "End of training",
+ blocking: bool = True,
+ **kwargs,
+) -> str:
+ """
+ Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
+
+ Parameters:
+ commit_message (`str`, *optional*, defaults to `"End of training"`):
+ Message to commit while pushing.
+ blocking (`bool`, *optional*, defaults to `True`):
+ Whether the function should return only when the `git push` has finished.
+ kwargs:
+ Additional keyword arguments passed along to [`create_model_card`].
+ Returns:
+ The url of the commit of your model in the given repository if `blocking=False`, or a tuple with the url of
+ the commit and an object to track the progress of the commit if `blocking=True`.
+ """
+
+ if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
+ model_name = Path(args.output_dir).name
+ else:
+ model_name = args.hub_model_id.split("/")[-1]
+
+ output_dir = args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ logger.info(f"Saving pipeline checkpoint to {output_dir}")
+ pipeline.save_pretrained(output_dir)
+
+ # Only push from one node.
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+ return
+
+ # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
+ if (
+ blocking
+ and len(repo.command_queue) > 0
+ and repo.command_queue[-1] is not None
+ and not repo.command_queue[-1].is_done
+ ):
+ repo.command_queue[-1]._process.kill()
+
+ git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
+ # push separately the model card to be independent from the rest of the model
+ create_model_card(args, model_name=model_name)
+ try:
+ repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
+ except EnvironmentError as exc:
+ logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
+
+ return git_head_commit_url
+
+
+def create_model_card(args, model_name):
+ if not is_modelcards_available():
+ raise ValueError(
+ "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
+ " install the package with `pip install modelcards`."
+ )
+
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+ return
+
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
+ repo_name = get_full_repo_name(model_name, token=hub_token)
+
+ model_card = ModelCard.from_template(
+ card_data=CardData( # Card metadata object that will be converted to YAML block
+ language="en",
+ license="apache-2.0",
+ library_name="diffusers",
+ tags=[],
+ datasets=args.dataset_name,
+ metrics=[],
+ ),
+ template_path=MODEL_CARD_TEMPLATE_PATH,
+ model_name=model_name,
+ repo_name=repo_name,
+ dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
+ learning_rate=args.learning_rate,
+ train_batch_size=args.train_batch_size,
+ eval_batch_size=args.eval_batch_size,
+ gradient_accumulation_steps=args.gradient_accumulation_steps
+ if hasattr(args, "gradient_accumulation_steps")
+ else None,
+ adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
+ adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
+ adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
+ adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
+ lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
+ lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
+ ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
+ ema_power=args.ema_power if hasattr(args, "ema_power") else None,
+ ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
+ mixed_precision=args.mixed_precision,
+ )
+
+ card_path = os.path.join(args.output_dir, "README.md")
+ model_card.save(card_path)
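+
+
+# Usage sketch: these helpers only read attributes off an `args` object, so any
+# namespace with the expected fields works. The values below are invented;
+# `push_to_hub` additionally expects the training hyperparameters consumed by
+# `create_model_card` above (learning_rate, train_batch_size, and so on):
+#
+# from argparse import Namespace
+#
+# args = Namespace(
+#     output_dir="./ddpm-demo",
+#     hub_model_id=None,        # fall back to the output_dir name
+#     hub_token=None,           # fall back to the cached `huggingface-cli login` token
+#     hub_private_repo=False,
+#     overwrite_output_dir=False,
+# )
+# repo = init_git_repo(args, at_init=True)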
diff --git a/my_diffusers/modeling_utils.py b/my_diffusers/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb613614a8782bf2eba2a2e7c2dc2af987088d6f
--- /dev/null
+++ b/my_diffusers/modeling_utils.py
@@ -0,0 +1,542 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, device
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
+
+
+WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_parameter_device(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).device
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).dtype
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].dtype
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+ """
+ Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
+ """
+ try:
+ return torch.load(checkpoint_file, map_location="cpu")
+ except Exception as e:
+ try:
+ with open(checkpoint_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+ "you cloned."
+ )
+ else:
+ raise ValueError(
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+ "model. Make sure you have saved the model properly."
+ ) from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(
+ f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
+ f"at '{checkpoint_file}'. "
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
+ )
+
+
+def _load_state_dict_into_model(model_to_load, state_dict):
+ # Convert old format to new format if needed from a PyTorch state_dict
+ # copy state_dict so _load_from_state_dict can modify it
+ state_dict = state_dict.copy()
+ error_msgs = []
+
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+ # so we need to apply the function recursively.
+ def load(module: torch.nn.Module, prefix=""):
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
+ module._load_from_state_dict(*args)
+
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ load(model_to_load)
+
+ return error_msgs
+
+
+class ModelMixin(torch.nn.Module):
+ r"""
+ Base class for all models.
+
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
+ and saving models.
+
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
+ [`~modeling_utils.ModelMixin.save_pretrained`].
+ """
+ config_name = CONFIG_NAME
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+
+ def __init__(self):
+ super().__init__()
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = torch.save,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ [`~modeling_utils.ModelMixin.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training (e.g. on
+ TPUs), where this function may be called on all processes. In this case, set `is_main_process=True` only
+ on the main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training (e.g. on TPUs), when
+ one needs to replace `torch.save` with another method.
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = self
+
+ # Attach architecture to the config
+ # Save the config
+ if is_main_process:
+ model_to_save.save_config(save_directory)
+
+ # Save the model
+ state_dict = model_to_save.state_dict()
+
+ # Clean the folder from a previous save
+ for filename in os.listdir(save_directory):
+ full_filename = os.path.join(save_directory, filename)
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+ # in distributed settings to avoid race conditions.
+ if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
+ os.remove(full_filename)
+
+ # Save the model
+ save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
+
+ logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+ <Tip>
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+ </Tip>
+
+ <Tip>
+
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+ this method in a firewalled environment.
+
+ </Tip>
+
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+ model, unused_kwargs = cls.from_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ **kwargs,
+ )
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ model = model.to(torch_dtype)
+
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ # Load model
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+ # Load from a PyTorch checkpoint
+ model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+ ):
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=WEIGHTS_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+ "login` and pass `use_auth_token=True`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+ "this model name. Check the model page at "
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a file named {WEIGHTS_NAME}."
+ " \nCheck out your internet connection or see how to run the library in"
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a file named {WEIGHTS_NAME}"
+ )
+
+ # restore default dtype
+ state_dict = load_state_dict(model_file)
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+ model,
+ state_dict,
+ model_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ )
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+
+ if output_loading_info:
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+ return model, loading_info
+
+ return model
+
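+    # Illustrative usage sketch (hypothetical repo id; assumes the target repo provides
+    # the expected weights file under the given subfolder):
+    #
+    #   unet = UNet2DConditionModel.from_pretrained(
+    #       "CompVis/stable-diffusion-v1-4", subfolder="unet"
+    #   )
+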
+ @classmethod
+ def _load_pretrained_model(
+ cls,
+ model,
+ state_dict,
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
+ ):
+ # Retrieve missing & unexpected_keys
+ model_state_dict = model.state_dict()
+        loaded_keys = list(state_dict.keys())
+
+ expected_keys = list(model_state_dict.keys())
+
+ original_loaded_keys = loaded_keys
+
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
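+        # For example, a parameter present in the checkpoint but absent from the model's
+        # state dict lands in `unexpected_keys`, while a model parameter missing from the
+        # checkpoint lands in `missing_keys`.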
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ model_to_load = model
+
+ def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ if state_dict is not None:
+ # Whole checkpoint
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ original_loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if len(error_msgs) > 0:
+ error_msg = "\n\t".join(error_msgs)
+ if "size mismatch" in error_msg:
+ error_msg += (
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ )
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+ " identical (initializing a BertForSequenceClassification model from a"
+ " BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+ " without further training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
+ )
+
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
+ @property
+ def device(self) -> device:
+ """
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+ device).
+ """
+ return get_parameter_device(self)
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ return get_parameter_dtype(self)
+
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+ """
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+ Args:
+ only_trainable (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of trainable parameters
+
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of non-embeddings parameters
+
+ Returns:
+ `int`: The number of parameters.
+ """
+
+ if exclude_embeddings:
+ embedding_param_names = [
+ f"{name}.weight"
+ for name, module_type in self.named_modules()
+ if isinstance(module_type, torch.nn.Embedding)
+ ]
+ non_embedding_parameters = [
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+ ]
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+ else:
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
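+
+    # Illustrative example (not from the original docs): for a module holding a 10x10
+    # `nn.Embedding` (100 weights) and a 10x10 `nn.Linear` (100 weights + 10 biases),
+    # `num_parameters()` returns 210, while `num_parameters(exclude_embeddings=True)`
+    # returns 110.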
+
+
+def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
+ """
+ Recursively unwraps a model from potential containers (as used in distributed training).
+
+ Args:
+ model (`torch.nn.Module`): The model to unwrap.
+ """
+ # since there could be multiple levels of wrapping, unwrap recursively
+ if hasattr(model, "module"):
+ return unwrap_model(model.module)
+ else:
+ return model
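+
+
+# Illustrative example (an assumption about typical usage): for
+# `wrapped = torch.nn.DataParallel(net)`, `unwrap_model(wrapped) is net` holds, while a
+# plain, unwrapped module is returned unchanged.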
diff --git a/my_diffusers/models/__init__.py b/my_diffusers/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0ac5c8d548b4ec2f7b9c84d5c6d884fd470385b
--- /dev/null
+++ b/my_diffusers/models/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .unet_2d import UNet2DModel
+from .unet_2d_condition import UNet2DConditionModel
+from .vae import AutoencoderKL, VQModel
diff --git a/my_diffusers/models/__pycache__/__init__.cpython-38.pyc b/my_diffusers/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5c1fc70098001731874eae1691ac6a5a95018b0
Binary files /dev/null and b/my_diffusers/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/models/__pycache__/attention.cpython-38.pyc b/my_diffusers/models/__pycache__/attention.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26b14cda61f7108f2e7f4be446ee5dc9271b62fc
Binary files /dev/null and b/my_diffusers/models/__pycache__/attention.cpython-38.pyc differ
diff --git a/my_diffusers/models/__pycache__/embeddings.cpython-38.pyc b/my_diffusers/models/__pycache__/embeddings.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74b608dacb3d37eb020bc3c0d4d17419500c0811
Binary files /dev/null and b/my_diffusers/models/__pycache__/embeddings.cpython-38.pyc differ
diff --git a/my_diffusers/models/__pycache__/resnet.cpython-38.pyc b/my_diffusers/models/__pycache__/resnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..feeeaa0f297034929304c46c415716b1c397d3d0
Binary files /dev/null and b/my_diffusers/models/__pycache__/resnet.cpython-38.pyc differ
diff --git a/my_diffusers/models/__pycache__/unet_2d.cpython-38.pyc b/my_diffusers/models/__pycache__/unet_2d.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41f5fa7a74bb67845f02c4f6d1cdf4c17c6aad53
Binary files /dev/null and b/my_diffusers/models/__pycache__/unet_2d.cpython-38.pyc differ
diff --git a/my_diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc b/my_diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01b2ea217b1f76ad85fdd501e6b4ae8bfc94ff86
Binary files /dev/null and b/my_diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc differ
diff --git a/my_diffusers/models/__pycache__/unet_blocks.cpython-38.pyc b/my_diffusers/models/__pycache__/unet_blocks.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdf9eb1f0b21f5c8456f8ec765c6034c3aee3b48
Binary files /dev/null and b/my_diffusers/models/__pycache__/unet_blocks.cpython-38.pyc differ
diff --git a/my_diffusers/models/__pycache__/vae.cpython-38.pyc b/my_diffusers/models/__pycache__/vae.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0729703a5e310db4a33bfb65868fb366dd0ed5c1
Binary files /dev/null and b/my_diffusers/models/__pycache__/vae.cpython-38.pyc differ
diff --git a/my_diffusers/models/attention.py b/my_diffusers/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e5ab9ace7c6ffbf048f6ddd3cfc8e4482fac61f
--- /dev/null
+++ b/my_diffusers/models/attention.py
@@ -0,0 +1,333 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+ to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ Uses three q, k, v linear layers to compute attention.
+
+ Parameters:
+ channels (:obj:`int`): The number of channels in the input and output.
+ num_head_channels (:obj:`int`, *optional*):
+ The number of channels in each head. If None, then `num_heads` = 1.
+ num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+ rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+ eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
+ """
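+
+    # A minimal usage sketch (shapes below are illustrative assumptions, not defaults):
+    #
+    #   attn = AttentionBlock(channels=64, num_head_channels=8)
+    #   out = attn(torch.randn(2, 64, 16, 16))  # -> (2, 64, 16, 16)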
+
+ def __init__(
+ self,
+ channels: int,
+ num_head_channels: Optional[int] = None,
+ num_groups: int = 32,
+        rescale_output_factor: float = 1.0,
+        eps: float = 1e-5,
+ ):
+ super().__init__()
+ self.channels = channels
+
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
+ self.num_head_size = num_head_channels
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+ # define q,k,v as linear layers
+ self.query = nn.Linear(channels, channels)
+ self.key = nn.Linear(channels, channels)
+ self.value = nn.Linear(channels, channels)
+
+ self.rescale_output_factor = rescale_output_factor
+        self.proj_attn = nn.Linear(channels, channels, bias=True)
+
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+ return new_projection
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ batch, channel, height, width = hidden_states.shape
+
+ # norm
+ hidden_states = self.group_norm(hidden_states)
+
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+ # proj to q, k, v
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(hidden_states)
+ value_proj = self.value(hidden_states)
+
+ # transpose
+ query_states = self.transpose_for_scores(query_proj)
+ key_states = self.transpose_for_scores(key_proj)
+ value_states = self.transpose_for_scores(value_proj)
+
+ # get scores
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
+
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+ attention_probs = torch.softmax(attention_scores.double(), dim=-1).type(attention_scores.dtype)
+
+ # compute attention output
+ hidden_states = torch.matmul(attention_probs, value_states)
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+ hidden_states = hidden_states.view(new_hidden_states_shape)
+
+ # compute next hidden_states
+ hidden_states = self.proj_attn(hidden_states)
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+ # res connect and rescale
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
+ return hidden_states
+
+
+class SpatialTransformer(nn.Module):
+ """
+    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
+    standard transformer blocks. Finally, reshape back to an image.
+
+ Parameters:
+ in_channels (:obj:`int`): The number of channels in the input and output.
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+ d_head (:obj:`int`): The number of channels in each head.
+ depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
+ context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
+ """
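+
+    # A minimal usage sketch (values are illustrative assumptions; the forward pass
+    # reshapes with the input channel count, so `n_heads * d_head` should equal
+    # `in_channels` here):
+    #
+    #   st = SpatialTransformer(in_channels=64, n_heads=8, d_head=8, context_dim=768)
+    #   out = st(torch.randn(1, 64, 16, 16), context=torch.randn(1, 77, 768))
+    #   # out has the same shape as the input: (1, 64, 16, 16)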
+
+ def __init__(
+ self,
+ in_channels: int,
+ n_heads: int,
+ d_head: int,
+ depth: int = 1,
+ dropout = 0.0,
+ context_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.n_heads = n_heads
+ self.d_head = d_head
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for d in range(depth)
+ ]
+ )
+
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def _set_attention_slice(self, slice_size):
+ for block in self.transformer_blocks:
+ block._set_attention_slice(slice_size)
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+ x = self.proj_out(x)
+ return x + x_in
+
+
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (:obj:`int`): The number of channels in the input and output.
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+ d_head (:obj:`int`): The number of channels in each head.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
+        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
+        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ n_heads: int,
+ d_head: int,
+ dropout=0.0,
+ context_dim: Optional[int] = None,
+ gated_ff: bool = True,
+ checkpoint: bool = True,
+ ):
+ super().__init__()
+ self.attn1 = CrossAttention(
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(
+ query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def _set_attention_slice(self, slice_size):
+ self.attn1._slice_size = slice_size
+ self.attn2._slice_size = slice_size
+
+ def forward(self, x, context=None):
+ x = x.contiguous() if x.device.type == "mps" else x
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class CrossAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (:obj:`int`): The number of channels in the query.
+ context_dim (:obj:`int`, *optional*):
+ The number of channels in the context. If not given, defaults to `query_dim`.
+ heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
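+
+    # A minimal usage sketch (shapes are illustrative assumptions):
+    #
+    #   attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
+    #   out = attn(torch.randn(1, 4096, 320), context=torch.randn(1, 77, 768))
+    #   # out: (1, 4096, 320). Setting `_slice_size` (via the enclosing blocks'
+    #   # `_set_attention_slice`) trades speed for lower peak memory in `_attention`.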
+
+ def __init__(
+        self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = context_dim if context_dim is not None else query_dim
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self._slice_size = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def forward(self, x, context=None, mask=None):
+ batch_size, sequence_length, dim = x.shape
+
+ q = self.to_q(x)
+ context = context if context is not None else x
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q = self.reshape_heads_to_batch_dim(q)
+ k = self.reshape_heads_to_batch_dim(k)
+ v = self.reshape_heads_to_batch_dim(v)
+
+ # TODO(PVP) - mask is currently never used. Remember to re-implement when used
+
+ # attention, what we cannot get enough of
+ hidden_states = self._attention(q, k, v, sequence_length, dim)
+
+ return self.to_out(hidden_states)
+
+ def _attention(self, query, key, value, sequence_length, dim):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+ attn_slice = (
+ torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
+ )
+ attn_slice = attn_slice.softmax(dim=-1)
+ attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (:obj:`int`): The number of channels in the input.
+ dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
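+
+    # A minimal usage sketch (illustrative; note that `project_in` is always a GEGLU in
+    # this implementation, so the `glu` flag currently has no effect):
+    #
+    #   ff = FeedForward(dim=320)          # hidden width: 320 * 4 = 1280
+    #   out = ff(torch.randn(1, 77, 320))  # -> (1, 77, 320)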
+
+ def __init__(
+ self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout = 0.0
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+ project_in = GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+# feedforward
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (:obj:`int`): The number of channels in the input.
+ dim_out (:obj:`int`): The number of channels in the output.
+ """
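+
+    # A minimal usage sketch (illustrative): the projection doubles the channel count and
+    # the second half gates the first half through GELU.
+    #
+    #   geglu = GEGLU(dim_in=320, dim_out=1280)
+    #   out = geglu(torch.randn(1, 77, 320))  # -> (1, 77, 1280)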
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
diff --git a/my_diffusers/models/embeddings.py b/my_diffusers/models/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..734be6068b7817efd51a508b0e42bc1c8f99d289
--- /dev/null
+++ b/my_diffusers/models/embeddings.py
@@ -0,0 +1,116 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+    """
+    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+    :param embedding_dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
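+    # Example (illustrative): `get_timestep_embedding(torch.tensor([0, 10, 999]), 320)`
+    # returns a float64 tensor of shape (3, 320) (sin features followed by cos features).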
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float64)
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent).to(device=timesteps.device)
+ emb = timesteps[:, None].double() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(channel, time_embed_dim)
+ self.act = None
+ if act_fn == "silu":
+ self.act = nn.SiLU()
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+ def forward(self, sample):
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ return t_emb
+
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(self, embedding_size: int = 256, scale: float = 1.0):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ # to delete later
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ self.weight = self.W
+
+ def forward(self, x):
+ x = torch.log(x)
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
diff --git a/my_diffusers/models/resnet.py b/my_diffusers/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd7428eb58f1e22180a1acef7453ded281db5eb6
--- /dev/null
+++ b/my_diffusers/models/resnet.py
@@ -0,0 +1,483 @@
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Upsample2D(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param use_conv_transpose: a bool determining if a transposed convolution is used instead of interpolation.
+    :param out_channels: number of output channels; defaults to `channels`.
+ """
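+
+    # A minimal usage sketch (illustrative shapes):
+    #
+    #   up = Upsample2D(channels=64, use_conv=True)
+    #   out = up(torch.randn(1, 64, 16, 16))  # -> (1, 64, 32, 32)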
+
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv_transpose:
+ return self.conv(x)
+
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if self.use_conv:
+ if self.name == "conv":
+ x = self.conv(x)
+ else:
+ x = self.Conv2d_0(x)
+
+ return x
+
+
+class Downsample2D(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param out_channels: number of output channels; defaults to `channels`.
+    :param padding: padding used by the downsampling convolution.
+ """
+
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ pad = (0, 1, 0, 1)
+ x = F.pad(x, pad, mode="constant", value=0)
+
+ assert x.shape[1] == self.channels
+ x = self.conv(x)
+
+ return x
+
+
+class FirUpsample2D(nn.Module):
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.use_conv = use_conv
+ self.fir_kernel = fir_kernel
+ self.out_channels = out_channels
+
+ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+        """Fused `upsample_2d()` followed by `Conv2d()`.
+
+        Padding is performed only once at the beginning, not between the operations. The fused op is considerably
+        more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+        arbitrary order.
+
+        Args:
+            x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution
+                can be performed by `inChannels = x.shape[0] // numGroups`.
+            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
+                which corresponds to nearest-neighbor upsampling.
+            factor: Integer upsampling factor (default: 2).
+            gain: Scaling factor for signal magnitude (default: 1.0).
+
+        Returns:
+            Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+            datatype as `x`.
+        """
+
+ assert isinstance(factor, int) and factor >= 1
+
+ # Setup filter kernel.
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = np.asarray(kernel, dtype=np.float64)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+
+ if self.use_conv:
+ convH = weight.shape[2]
+ convW = weight.shape[3]
+ inC = weight.shape[1]
+
+ p = (kernel.shape[0] - factor) - (convW - 1)
+
+ stride = (factor, factor)
+ # Determine data dimensions.
+ stride = [1, 1, factor, factor]
+ output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
+ output_padding = (
+ output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+ output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+ )
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
+ inC = weight.shape[1]
+ num_groups = x.shape[1] // inC
+
+ # Transpose weights.
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
+ weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
+
+ x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
+
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+ else:
+ p = kernel.shape[0] - factor
+ x = upfirdn2d_native(
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+ )
+
+ return x
+
+ def forward(self, x):
+ if self.use_conv:
+ height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
+
+ return height
+
+
+class FirDownsample2D(nn.Module):
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.fir_kernel = fir_kernel
+ self.use_conv = use_conv
+ self.out_channels = out_channels
+
+ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+        """Fused `Conv2d()` followed by `downsample_2d()`.
+
+        Padding is performed only once at the beginning, not between the operations. The fused op is considerably
+        more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+        arbitrary order.
+
+        Args:
+            x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution
+                can be performed by `inChannels = x.shape[0] // numGroups`.
+            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
+                which corresponds to average pooling.
+            factor: Integer downsampling factor (default: 2).
+            gain: Scaling factor for signal magnitude (default: 1.0).
+
+        Returns:
+            Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
+            datatype as `x`.
+        """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = np.asarray(kernel, dtype=np.float64)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * gain
+
+ if self.use_conv:
+ _, _, convH, convW = weight.shape
+ p = (kernel.shape[0] - factor) + (convW - 1)
+ s = [factor, factor]
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
+ x = F.conv2d(x, weight, stride=s, padding=0)
+ else:
+ p = kernel.shape[0] - factor
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+
+ return x
+
+ def forward(self, x):
+ if self.use_conv:
+ x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+ x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
+
+ return x
+
+
+class ResnetBlock2D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ kernel=None,
+ output_scale_factor=1.0,
+ use_nin_shortcut=None,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.up = up
+ self.down = down
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.upsample = self.downsample = None
+ if self.up:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+ else:
+ self.upsample = Upsample2D(in_channels, use_conv=False)
+ elif self.down:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+ else:
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
+
+ self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
+
+ self.conv_shortcut = None
+ if self.use_nin_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, temb):
+ hidden_states = x
+
+ # make sure hidden states is in float32
+ # when running in half-precision
+ hidden_states = self.norm1(hidden_states.double()).type(hidden_states.dtype)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ x = self.upsample(x)
+ hidden_states = self.upsample(hidden_states)
+ elif self.downsample is not None:
+ x = self.downsample(x)
+ hidden_states = self.downsample(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+ hidden_states = hidden_states + temb
+
+ # make sure hidden states is in float32
+ # when running in half-precision
+ hidden_states = self.norm2(hidden_states.double()).type(hidden_states.dtype)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ x = self.conv_shortcut(x)
+
+ out = (x + hidden_states) / self.output_scale_factor
+
+ return out
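+
+    # A minimal usage sketch (illustrative; since `forward` upcasts activations to float64
+    # before the group norms, the module and its inputs are assumed to already be in
+    # double precision):
+    #
+    #   block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512).double()
+    #   out = block(torch.randn(1, 64, 32, 32, dtype=torch.float64),
+    #               torch.randn(1, 512, dtype=torch.float64))  # -> (1, 128, 32, 32)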
+
+
+class Mish(torch.nn.Module):
+ def forward(self, x):
+ return x * torch.tanh(torch.nn.functional.softplus(x))
+
+
+def upsample_2d(x, kernel=None, factor=2, gain=1):
+    r"""Upsample2D a batch of 2D images with the given filter.
+
+    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the
+    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that
+    its shape is a multiple of the upsampling factor.
+
+    Args:
+        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
+            which corresponds to nearest-neighbor upsampling.
+        factor: Integer upsampling factor (default: 2).
+        gain: Scaling factor for signal magnitude (default: 1.0).
+
+    Returns:
+        Tensor of the shape `[N, C, H * factor, W * factor]`
+    """
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = np.asarray(kernel, dtype=np.float64)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+ p = kernel.shape[0] - factor
+ return upfirdn2d_native(
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+ )
+
+
+def downsample_2d(x, kernel=None, factor=2, gain=1):
+    r"""Downsample2D a batch of 2D images with the given filter.
+
+    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
+    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that
+    its shape is a multiple of the downsampling factor.
+
+    Args:
+        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
+            which corresponds to average pooling.
+        factor: Integer downsampling factor (default: 2).
+        gain: Scaling factor for signal magnitude (default: 1.0).
+
+    Returns:
+        Tensor of the shape `[N, C, H // factor, W // factor]`
+    """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = np.asarray(kernel, dtype=np.float64)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * gain
+ p = kernel.shape[0] - factor
+ return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+
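+# Illustrative shape check (an assumption about typical use; inputs are float64 so that
+# they match the float64 FIR kernel built inside these helpers):
+#
+#   x = torch.randn(1, 3, 16, 16, dtype=torch.float64)
+#   upsample_2d(x).shape    # (1, 3, 32, 32)
+#   downsample_2d(x).shape  # (1, 3, 8, 8)
+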
+
+def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
+ up_x = up_y = up
+ down_x = down_y = down
+ pad_x0 = pad_y0 = pad[0]
+ pad_x1 = pad_y1 = pad[1]
+
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+
+ # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
+ if input.device.type == "mps":
+ out = out.to("cpu")
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out.to(input.device) # Move back to mps if necessary
+ out = out[
+ :,
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
diff --git a/my_diffusers/models/unet_2d.py b/my_diffusers/models/unet_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a51ecf79e6ac5da400c97f0b38e2593ae86ed70
--- /dev/null
+++ b/my_diffusers/models/unet_2d.py
@@ -0,0 +1,246 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+
+@dataclass
+class UNet2DOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Hidden states output. Output of last layer of model.
+ """
+
+ sample: torch.DoubleTensor
+
+
+class UNet2DModel(ModelMixin, ConfigMixin):
+ r"""
+ UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+        sample_size (`int`, *optional*): Input sample size.
+        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
+        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+        time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
+        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+            Whether to flip sin to cos for fourier time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
+            Tuple of block output channels.
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
+ """
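+
+    # A minimal construction/usage sketch (all values below are illustrative assumptions;
+    # because the forward pass upcasts to float64, the model and inputs are kept in
+    # double precision):
+    #
+    #   model = UNet2DModel(
+    #       sample_size=32, in_channels=3, out_channels=3,
+    #       block_out_channels=(32, 64), layers_per_block=1,
+    #       down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+    #       up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+    #   ).double()
+    #   out = model(torch.randn(1, 3, 32, 32, dtype=torch.float64), timestep=10).sample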
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ center_input_sample: bool = False,
+ time_embedding_type: str = "positional",
+ freq_shift: int = 0,
+ flip_sin_to_cos: bool = True,
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
+ layers_per_block: int = 2,
+ mid_block_scale_factor = 1,
+ downsample_padding: int = 1,
+ act_fn: str = "silu",
+ attention_head_dim: int = 8,
+ norm_num_groups: int = 32,
+ norm_eps = 1e-5,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ if time_embedding_type == "fourier":
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
+ timestep_input_dim = 2 * block_out_channels[0]
+ elif time_embedding_type == "positional":
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=attention_head_dim,
+ downsample_padding=downsample_padding,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=attention_head_dim,
+ resnet_groups=norm_num_groups,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=attention_head_dim,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def forward(
+ self,
+ sample: torch.DoubleTensor,
+ timestep: Union[torch.Tensor, float, int],
+ return_dict: bool = True,
+ ) -> Union[UNet2DOutput, Tuple]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
+ otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+ """
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
+
+ t_emb = self.time_proj(timesteps)
+ emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ skip_sample = sample
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "skip_conv"):
+ sample, res_samples, skip_sample = downsample_block(
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb)
+
+ # 5. up
+ skip_sample = None
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "skip_conv"):
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
+ else:
+ sample = upsample_block(sample, res_samples, emb)
+
+ # 6. post-process
+ # make sure hidden states is in float32
+ # when running in half-precision
+ sample = self.conv_norm_out(sample.double()).type(sample.dtype)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if skip_sample is not None:
+ sample += skip_sample
+
+ if self.config.time_embedding_type == "fourier":
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
+ sample = sample / timesteps
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DOutput(sample=sample)
diff --git a/my_diffusers/models/unet_2d_condition.py b/my_diffusers/models/unet_2d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..f951e8457fd6d207eed31488ff49863143923d67
--- /dev/null
+++ b/my_diffusers/models/unet_2d_condition.py
@@ -0,0 +1,273 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .embeddings import TimestepEmbedding, Timesteps
+from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin):
+ r"""
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
+ and returns sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ sample_size (`int`, *optional*): The size of the input sample.
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ """
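+
+    # A rough usage sketch (illustrative; with the default config this UNet is roughly
+    # Stable Diffusion sized, and `encoder_hidden_states` must have `cross_attention_dim`
+    # features, i.e. 1280 unless configured otherwise):
+    #
+    #   unet = UNet2DConditionModel(sample_size=64)
+    #   noise_pred = unet(
+    #       sample=latents,                     # (B, 4, 64, 64)
+    #       timestep=t,
+    #       encoder_hidden_states=text_embeds,  # (B, 77, 1280) with the defaults
+    #   ).sample                                # (B, 4, 64, 64)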
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: int = 8,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ downsample_padding=downsample_padding,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift="default",
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ resnet_groups=norm_num_groups,
+ )
+
+ # up
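+        # the up path mirrors the down path: channel widths are traversed in reverse
+        # and every block additionally consumes skip activations saved on the way down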
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def set_attention_slice(self, slice_size):
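+        # sliced attention processes `slice_size` attention heads at a time,
+        # trading a little speed for a lower peak memory footprint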
+ if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+ )
+ if slice_size is not None and slice_size > self.config.attention_head_dim:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+ )
+
+ for block in self.down_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
+ self.mid_block.set_attention_slice(slice_size)
+
+ for block in self.up_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ """r
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps.to(dtype=torch.float64)
+ timesteps = timesteps[None].to(device=sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+ t_emb = t_emb.to(sample.dtype).to(sample.device)
+ emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
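+        # each down block returns its intermediate activations; they are collected
+        # here and popped in reverse order as skip connections during the up pass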
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+ sample, res_samples = downsample_block(
+ hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+ # 5. up
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ else:
+ sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
+
+        # 6. post-process
+        # run the final group norm in double precision for numerical accuracy,
+        # then cast back to the working dtype
+        sample = self.conv_norm_out(sample.double()).type(sample.dtype)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
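+
+
+# ---------------------------------------------------------------------------
+# Minimal usage sketch (illustrative only, not part of the upstream file): the
+# conditional UNet predicts noise for a latent sample given a timestep and a
+# cross-attention context such as text-encoder hidden states. With the default
+# config the context's last dimension must equal cross_attention_dim (1280).
+#
+#   unet = UNet2DConditionModel()
+#   latents = torch.randn(1, 4, 64, 64)       # (batch, in_channels, height, width)
+#   t = torch.tensor([10])                    # diffusion timestep(s)
+#   context = torch.randn(1, 77, 1280)        # (batch, seq_len, cross_attention_dim)
+#   noise_pred = unet(latents, t, encoder_hidden_states=context).sample
+# ---------------------------------------------------------------------------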
diff --git a/my_diffusers/models/unet_blocks.py b/my_diffusers/models/unet_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e062165357c33d9b2f0bec13a66204c2e7e7833
--- /dev/null
+++ b/my_diffusers/models/unet_blocks.py
@@ -0,0 +1,1481 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+from torch import nn
+
+from .attention import AttentionBlock, SpatialTransformer
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ cross_attention_dim=None,
+ downsample_padding=None,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock2D":
+ return DownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnDownBlock2D":
+ return AttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "CrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
+ return CrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "SkipDownBlock2D":
+ return SkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnSkipDownBlock2D":
+ return AttnSkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "DownEncoderBlock2D":
+ return DownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+    raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ cross_attention_dim=None,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock2D":
+ return UpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "CrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
+ return CrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "AttnUpBlock2D":
+ return AttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "SkipUpBlock2D":
+ return SkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "AttnSkipUpBlock2D":
+ return AttnSkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "UpDecoderBlock2D":
+ return UpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ AttentionBlock(
+ in_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if self.attention_type == "default":
+ hidden_states = attn(hidden_states)
+ else:
+ hidden_states = attn(hidden_states, encoder_states)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ SpatialTransformer(
+ in_channels,
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
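+        # one leading resnet, then `num_layers` (cross-attention, resnet) pairs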
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class AttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ SpatialTransformer(
+ out_channels,
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ output_states = ()
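+        # record the output of every (resnet, cross-attention) pair, plus the final
+        # downsample, so the UNet can reuse them as skip connections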
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnDownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
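+            # `skip_sample` is the progressively downsampled RGB input; it is projected
+            # to the feature width by a 1x1 conv and added to the hidden states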
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class SkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class AttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_type="default",
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ for resnet, attn in zip(self.resnets, self.attentions):
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ SpatialTransformer(
+ out_channels,
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
+ for resnet, attn in zip(self.resnets, self.attentions):
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ for resnet in self.resnets:
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnUpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ upsample_padding=1,
+ add_upsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = self.attentions[0](hidden_states)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_upsample=True,
+ upsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
diff --git a/my_diffusers/models/vae.py b/my_diffusers/models/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..82748cb5b60c0241cc3ca96f9016f07650e44a54
--- /dev/null
+++ b/my_diffusers/models/vae.py
@@ -0,0 +1,581 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+
+@dataclass
+class DecoderOutput(BaseOutput):
+ """
+ Output of decoding method.
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Decoded output sample of the model. Output of the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+ """
+ Output of VQModel encoding method.
+
+ Args:
+ latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Encoded output sample of the model. Output of the last layer of the model.
+ """
+
+ latents: torch.FloatTensor
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`DiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "DiagonalGaussianDistribution"
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ act_fn="silu",
+ double_z=True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=self.layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=32,
+ temb_channels=None,
+ )
+
+ # out
+ num_groups_out = 32
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
+ self.conv_act = nn.SiLU()
+
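+        # with double_z the encoder emits twice the latent channels so the output can
+        # later be split into the mean and log-variance of a diagonal Gaussian posterior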
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
+
+ def forward(self, x):
+ sample = x
+ sample = self.conv_in(sample)
+
+ # down
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ act_fn="silu",
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
+
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=32,
+ temb_channels=None,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ prev_output_channel=None,
+ add_upsample=not is_final_block,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = 32
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def forward(self, z):
+ sample = z
+ sample = self.conv_in(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = up_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class VectorQuantizer(nn.Module):
+ """
+    Improved version of the original VectorQuantizer that can be used as a drop-in replacement. It mostly avoids
+    costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+
+    # NOTE: due to a bug the beta term was applied to the wrong term. For
+    # backwards compatibility we use the buggy version by default, but you can
+    # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = (
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight**2, dim=1)
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
+ )
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
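+        # (straight-through estimator: the forward pass uses the quantized z_q while
+        # gradients flow to the encoder output z as if quantization were the identity)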
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+ device = self.parameters.device
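+        # noise is drawn on the CPU when the model sits on an MPS device (a workaround
+        # for generator-based sampling there), then moved to the target device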
+ sample_device = "cpu" if device.type == "mps" else device
+ sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other=None):
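+        # closed-form KL divergence between diagonal Gaussians; against the standard
+        # normal it reduces to 0.5 * sum(mean^2 + var - 1 - logvar) over non-batch dims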
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+
+ def mode(self):
+ return self.mean
+
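+# Usage sketch for DiagonalGaussianDistribution (illustrative only; the tensor shape is
+# an assumption): `parameters` concatenates the mean and log-variance along dim=1, so
+# the channel count must be even.
+#
+#   moments = torch.randn(1, 8, 32, 32)           # 4 mean channels + 4 logvar channels
+#   posterior = DiagonalGaussianDistribution(moments)
+#   z = posterior.sample()                        # stochastic latent draw
+#   z_det = posterior.mode()                      # deterministic latent (the mean)
+#   kl = posterior.kl()                           # KL divergence to N(0, I), one value per sample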
+
+class VQModel(ModelMixin, ConfigMixin):
+ r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
+ Kavukcuoglu.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all models (such as downloading or saving, etc.).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): TODO
+ num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 3,
+ sample_size: int = 32,
+ num_vq_embeddings: int = 256,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ double_z=False,
+ )
+
+ self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+ self.quantize = VectorQuantizer(
+ num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
+ )
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ )
+
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+
+ if not return_dict:
+ return (h,)
+
+ return VQEncoderOutput(latents=h)
+
+ def decode(
+ self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ h = self.encode(x).latents
+ dec = self.decode(h).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
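+# Usage sketch for VQModel (illustrative only; the input resolution is an assumption):
+#
+#   vq = VQModel()                                # default config: 3 latent channels, 256 codebook entries
+#   img = torch.randn(1, 3, 256, 256)
+#   latents = vq.encode(img).latents
+#   recon = vq.decode(latents).sample             # quantization happens inside decode()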
+
+class AutoencoderKL(ModelMixin, ConfigMixin):
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
+ and Max Welling.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all models (such as downloading or saving, etc.).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): TODO
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 4,
+ sample_size: int = 32,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ double_z=True,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ )
+
+ self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
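+
+
+# Usage sketch for AutoencoderKL (illustrative only; the input resolution is an
+# assumption, and any model-specific latent scaling factor is applied outside this class):
+#
+#   vae = AutoencoderKL(latent_channels=4)
+#   img = torch.randn(1, 3, 256, 256)
+#   posterior = vae.encode(img).latent_dist
+#   z = posterior.sample()                        # or posterior.mode() for a deterministic encode
+#   recon = vae.decode(z).sample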
diff --git a/my_diffusers/onnx_utils.py b/my_diffusers/onnx_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e840565dd5c1b9bd17422aba5af6dc0d045c4682
--- /dev/null
+++ b/my_diffusers/onnx_utils.py
@@ -0,0 +1,189 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+from pathlib import Path
+from typing import Optional, Union
+
+import numpy as np
+
+from huggingface_hub import hf_hub_download
+
+from .utils import is_onnx_available, logging
+
+
+if is_onnx_available():
+ import onnxruntime as ort
+
+
+ONNX_WEIGHTS_NAME = "model.onnx"
+
+
+logger = logging.get_logger(__name__)
+
+
+class OnnxRuntimeModel:
+ base_model_prefix = "onnx_model"
+
+ def __init__(self, model=None, **kwargs):
+ logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
+ self.model = model
+ self.model_save_dir = kwargs.get("model_save_dir", None)
+ self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
+
+ def __call__(self, **kwargs):
+ inputs = {k: np.array(v) for k, v in kwargs.items()}
+ return self.model.run(None, inputs)
+
+ @staticmethod
+ def load_model(path: Union[str, Path], provider=None):
+ """
+ Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
+
+ Arguments:
+ path (`str` or `Path`):
+ Directory from which to load
+ provider(`str`, *optional*):
+ Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
+ """
+ if provider is None:
+ logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
+ provider = "CPUExecutionProvider"
+
+ return ort.InferenceSession(path, providers=[provider])
+
+ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
+ latest_model_name.
+
+ Arguments:
+ save_directory (`str` or `Path`):
+ Directory where to save the model file.
+ file_name(`str`, *optional*):
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
+ model with a different name.
+ """
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+
+ src_path = self.model_save_dir.joinpath(self.latest_model_name)
+ dst_path = Path(save_directory).joinpath(model_file_name)
+ if not src_path.samefile(dst_path):
+ shutil.copyfile(src_path, dst_path)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ **kwargs,
+ ):
+ """
+ Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
+ method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # saving model weights/files
+ self._save_pretrained(save_directory, **kwargs)
+
+ @classmethod
+ def _from_pretrained(
+ cls,
+ model_id: Union[str, Path],
+ use_auth_token: Optional[Union[bool, str, None]] = None,
+ revision: Optional[Union[str, None]] = None,
+ force_download: bool = False,
+ cache_dir: Optional[str] = None,
+ file_name: Optional[str] = None,
+ provider: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Load a model from a directory or the HF Hub.
+
+ Arguments:
+ model_id (`str` or `Path`):
+ Directory from which to load
+ use_auth_token (`str` or `bool`):
+ Required to load models from a private or gated repository
+ revision (`str`):
+ Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
+ cache_dir (`Union[str, Path]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ file_name(`str`):
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
+ different model files from the same repository or directory.
+ provider(`str`):
+ The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
+ kwargs (`Dict`, *optional*):
+ kwargs will be passed to the model during initialization
+ """
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+ # load model from local directory
+ if os.path.isdir(model_id):
+ model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
+ kwargs["model_save_dir"] = Path(model_id)
+ # load model from hub
+ else:
+ # download model
+ model_cache_path = hf_hub_download(
+ repo_id=model_id,
+ filename=model_file_name,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ )
+ kwargs["model_save_dir"] = Path(model_cache_path).parent
+ kwargs["latest_model_name"] = Path(model_cache_path).name
+ model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
+ return cls(model=model, **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_id: Union[str, Path],
+ force_download: bool = True,
+ use_auth_token: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ **model_kwargs,
+ ):
+ revision = None
+ if len(str(model_id).split("@")) == 2:
+ model_id, revision = model_id.split("@")
+
+ return cls._from_pretrained(
+ model_id=model_id,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ use_auth_token=use_auth_token,
+ **model_kwargs,
+ )
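+
+
+# Usage sketch (illustrative only; the repo id and the input name below are placeholders,
+# not a real checkpoint, and the expected inputs depend on the exported ONNX graph):
+#
+#   model = OnnxRuntimeModel.from_pretrained("some-org/some-onnx-model", provider="CPUExecutionProvider")
+#   outputs = model(sample=np.zeros((1, 4, 64, 64), dtype=np.float32))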
diff --git a/my_diffusers/optimization.py b/my_diffusers/optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7b836b4a69bffb61c15967ef9b1736201721f1b
--- /dev/null
+++ b/my_diffusers/optimization.py
@@ -0,0 +1,275 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch optimization for diffusion models."""
+
+import math
+from enum import Enum
+from typing import Optional, Union
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR
+
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SchedulerType(Enum):
+ LINEAR = "linear"
+ COSINE = "cosine"
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
+ POLYNOMIAL = "polynomial"
+ CONSTANT = "constant"
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
+
+
+def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
+
+
+def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+ increases linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1.0, num_warmup_steps))
+ return 1.0
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+ """
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ return max(
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+ )
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
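+# Worked example (illustrative only): with num_warmup_steps=100 and num_training_steps=1000,
+# the multiplier returned by lr_lambda is 0.5 at step 50 (halfway through warmup) and
+# (1000 - 550) / (1000 - 100) = 0.5 at step 550 (halfway through the linear decay).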
+
+def get_cosine_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+ following a half-cosine).
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
+ linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`int`, *optional*, defaults to 1):
+ The number of hard restarts to use.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ if progress >= 1.0:
+ return 0.0
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_polynomial_decay_schedule_with_warmup(
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
+):
+ """
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ lr_end (`float`, *optional*, defaults to 1e-7):
+ The end LR.
+ power (`float`, *optional*, defaults to 1.0):
+ Power factor.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
+ implementation at
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+
+ """
+
+ lr_init = optimizer.defaults["lr"]
+ if not (lr_init > lr_end):
+ raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ elif current_step > num_training_steps:
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
+ else:
+ lr_range = lr_init - lr_end
+ decay_steps = num_training_steps - num_warmup_steps
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+ decay = lr_range * pct_remaining**power + lr_end
+ return decay / lr_init # as LambdaLR multiplies by lr_init
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
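+# Worked example (illustrative only): with lr_init=1e-4, lr_end=1e-7, power=1.0,
+# num_warmup_steps=0 and num_training_steps=1000, the multiplier at step 500 is
+# ((1e-4 - 1e-7) * 0.5 + 1e-7) / 1e-4 ≈ 0.5005, i.e. an effective lr of about 5.005e-5.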
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
+ SchedulerType.CONSTANT: get_constant_schedule,
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+}
+
+
+def get_scheduler(
+ name: Union[str, SchedulerType],
+ optimizer: Optimizer,
+ num_warmup_steps: Optional[int] = None,
+ num_training_steps: Optional[int] = None,
+):
+ """
+ Unified API to get any scheduler from its name.
+
+ Args:
+ name (`str` or `SchedulerType`):
+ The name of the scheduler to use.
+ optimizer (`torch.optim.Optimizer`):
+ The optimizer that will be used during training.
+ num_warmup_steps (`int`, *optional*):
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+ optional); the function will raise an error if it's unset and the scheduler type requires it.
+ num_training_steps (`int`, *optional*):
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
+ optional); the function will raise an error if it's unset and the scheduler type requires it.
+ """
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+ if name == SchedulerType.CONSTANT:
+ return schedule_func(optimizer)
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
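+
+
+# Usage sketch (illustrative only; the optimizer, model and step counts are placeholders):
+#
+#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+#   lr_scheduler = get_scheduler(
+#       "cosine", optimizer, num_warmup_steps=500, num_training_steps=10_000
+#   )
+#   # call lr_scheduler.step() once per optimizer step during training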
diff --git a/my_diffusers/pipeline_utils.py b/my_diffusers/pipeline_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..84ee9e20f1107a54dcdaf2799d805cf9e4f3b0a7
--- /dev/null
+++ b/my_diffusers/pipeline_utils.py
@@ -0,0 +1,417 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+import diffusers
+import PIL
+from huggingface_hub import snapshot_download
+from PIL import Image
+from tqdm.auto import tqdm
+
+from .configuration_utils import ConfigMixin
+from .utils import DIFFUSERS_CACHE, BaseOutput, logging
+
+
+INDEX_FILE = "diffusion_pytorch_model.bin"
+
+
+logger = logging.get_logger(__name__)
+
+
+LOADABLE_CLASSES = {
+ "diffusers": {
+ "ModelMixin": ["save_pretrained", "from_pretrained"],
+ "SchedulerMixin": ["save_config", "from_config"],
+ "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
+ "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
+ },
+ "transformers": {
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
+ "PreTrainedModel": ["save_pretrained", "from_pretrained"],
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
+ },
+}
+
+ALL_IMPORTABLE_CLASSES = {}
+for library in LOADABLE_CLASSES:
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
+
+
+@dataclass
+class ImagePipelineOutput(BaseOutput):
+ """
+ Output class for image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`):
+ List of denoised PIL images of length `batch_size` or a numpy array of shape `(batch_size, height, width,
+ num_channels)`. The PIL images or numpy array represent the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+class DiffusionPipeline(ConfigMixin):
+ r"""
+ Base class for all diffusion pipelines.
+
+ [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
+ and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
+
+ - move all PyTorch modules to the device of your choice
+ - enable or disable the progress bar for the denoising iteration
+
+ Class attributes:
+
+ - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
+ components of the diffusion pipeline.
+ """
+ config_name = "model_index.json"
+
+ def register_modules(self, **kwargs):
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ for name, module in kwargs.items():
+ # retrieve the library name
+ library = module.__module__.split(".")[0]
+
+ # check if the module is a pipeline module
+ pipeline_dir = module.__module__.split(".")[-2]
+ path = module.__module__.split(".")
+ is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+
+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
+ # Or if it's a pipeline module, then the module is inside the pipeline
+ # folder so we set the library to module name.
+ if library not in LOADABLE_CLASSES or is_pipeline_module:
+ library = pipeline_dir
+
+ # retrieve the class name
+ class_name = module.__class__.__name__
+
+ register_dict = {name: (library, class_name)}
+
+ # save model index config
+ self.register_to_config(**register_dict)
+
+ # set models
+ setattr(self, name, module)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+ """
+ Save all variables of the pipeline that can be saved and loaded, as well as the pipeline's configuration file, to
+ a directory. A pipeline variable can be saved and loaded if its class implements both a save and a loading
+ method. The pipeline can easily be re-loaded using the [`~DiffusionPipeline.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ """
+ self.save_config(save_directory)
+
+ model_index_dict = dict(self.config)
+ model_index_dict.pop("_class_name")
+ model_index_dict.pop("_diffusers_version")
+ model_index_dict.pop("_module", None)
+
+ for pipeline_component_name in model_index_dict.keys():
+ sub_model = getattr(self, pipeline_component_name)
+ model_cls = sub_model.__class__
+
+ save_method_name = None
+ # search for the model's base class in LOADABLE_CLASSES
+ for library_name, library_classes in LOADABLE_CLASSES.items():
+ library = importlib.import_module(library_name)
+ for base_class, save_load_methods in library_classes.items():
+ class_candidate = getattr(library, base_class)
+ if issubclass(model_cls, class_candidate):
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
+ save_method_name = save_load_methods[0]
+ break
+ if save_method_name is not None:
+ break
+
+ save_method = getattr(sub_model, save_method_name)
+ save_method(os.path.join(save_directory, pipeline_component_name))
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+ if torch_device is None:
+ return self
+
+ module_names, _ = self.extract_init_dict(dict(self.config))
+ for name in module_names.keys():
+ module = getattr(self, name)
+ if isinstance(module, torch.nn.Module):
+ module.to(torch_device)
+ return self
+
+ @property
+ def device(self) -> torch.device:
+ r"""
+ Returns:
+ `torch.device`: The torch device on which the pipeline is located.
+ """
+ module_names, _ = self.extract_init_dict(dict(self.config))
+ for name in module_names.keys():
+ module = getattr(self, name)
+ if isinstance(module, torch.nn.Module):
+ return module.device
+ return torch.device("cpu")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
+
+ The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
+ https://huggingface.co/. Valid repo ids have to be located under a user or organization name, like
+ `CompVis/ldm-text2im-large-256`.
+ - A path to a *directory* containing pipeline weights saved using
+ [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load- and saveable variables, *i.e.* the pipeline components, of the
+ specific pipeline class. The overwritten components are then passed directly to the pipeline's `__init__`
+ method. See the example below for more information.
+
+
+
+ Passing `use_auth_token=True` is required when you want to use a private model, *e.g.*
+ `"CompVis/stable-diffusion-v1-4"`
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+ this method in a firewalled environment.
+
+
+
+ Examples:
+
+ ```py
+ >>> from diffusers import DiffusionPipeline
+
+ >>> # Download pipeline from huggingface.co and cache.
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+
+ >>> # Download pipeline that requires an authorization token
+ >>> # For more information on access tokens, please refer to this section
+ >>> # of the documentation: https://huggingface.co/docs/hub/security-tokens
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+
+ >>> # Download pipeline, but overwrite scheduler
+ >>> from diffusers import LMSDiscreteScheduler
+
+ >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+ >>> pipeline = DiffusionPipeline.from_pretrained(
+ ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
+ ... )
+ ```
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ provider = kwargs.pop("provider", None)
+
+ # 1. Download the checkpoints and configs
+ # use snapshot download here to get it working from from_pretrained
+ if not os.path.isdir(pretrained_model_name_or_path):
+ cached_folder = snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ else:
+ cached_folder = pretrained_model_name_or_path
+
+ config_dict = cls.get_config_dict(cached_folder)
+
+ # 2. Load the pipeline class, if using custom module then load it from the hub
+ # if we load from explicit class, let's use it
+ if cls != DiffusionPipeline:
+ pipeline_class = cls
+ else:
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+ pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+
+ # some modules can be passed directly to the init
+ # in this case they are already instantiated in `kwargs`
+ # extract them here
+ expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+
+ init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+
+ init_kwargs = {}
+
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ # 3. Load each module in the pipeline
+ for name, (library_name, class_name) in init_dict.items():
+ is_pipeline_module = hasattr(pipelines, library_name)
+ loaded_sub_model = None
+
+ # if the component was passed in directly, use it instead of loading it
+ if name in passed_class_obj:
+ # 1. check that passed_class_obj has correct parent class
+ if not is_pipeline_module:
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+ expected_class_obj = None
+ for class_name, class_candidate in class_candidates.items():
+ if issubclass(class_obj, class_candidate):
+ expected_class_obj = class_candidate
+
+ if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+ raise ValueError(
+ f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
+ f" {expected_class_obj}"
+ )
+ else:
+ logger.warn(
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+ " has the correct type"
+ )
+
+ # set passed class object
+ loaded_sub_model = passed_class_obj[name]
+ elif is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+ class_obj = getattr(pipeline_module, class_name)
+ importable_classes = ALL_IMPORTABLE_CLASSES
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
+ else:
+ # else we just import it from the library.
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+ if loaded_sub_model is None:
+ load_method_name = None
+ for class_name, class_candidate in class_candidates.items():
+ if issubclass(class_obj, class_candidate):
+ load_method_name = importable_classes[class_name][1]
+
+ load_method = getattr(class_obj, load_method_name)
+
+ loading_kwargs = {}
+ if issubclass(class_obj, torch.nn.Module):
+ loading_kwargs["torch_dtype"] = torch_dtype
+ if issubclass(class_obj, diffusers.OnnxRuntimeModel):
+ loading_kwargs["provider"] = provider
+
+ # check if the module is in a subdirectory
+ if os.path.isdir(os.path.join(cached_folder, name)):
+ loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
+ else:
+ # else load from the root directory
+ loaded_sub_model = load_method(cached_folder, **loading_kwargs)
+
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
+
+ # 4. Instantiate the pipeline
+ model = pipeline_class(**init_kwargs)
+ return model
+
+ @staticmethod
+ def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+ def progress_bar(self, iterable):
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ elif not isinstance(self._progress_bar_config, dict):
+ raise ValueError(
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ )
+
+ return tqdm(iterable, **self._progress_bar_config)
+
+ def set_progress_bar_config(self, **kwargs):
+ self._progress_bar_config = kwargs
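+
+
+# Usage sketch (illustrative only): any keyword arguments set here are forwarded to
+# `tqdm` by `progress_bar()` during the denoising loop.
+#
+#   pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+#   pipe.set_progress_bar_config(disable=True)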
diff --git a/my_diffusers/pipelines/__init__.py b/my_diffusers/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e2aeb4fb2b7f1315adb3a2ddea6aec42e806779
--- /dev/null
+++ b/my_diffusers/pipelines/__init__.py
@@ -0,0 +1,19 @@
+from ..utils import is_onnx_available, is_transformers_available
+from .ddim import DDIMPipeline
+from .ddpm import DDPMPipeline
+from .latent_diffusion_uncond import LDMPipeline
+from .pndm import PNDMPipeline
+from .score_sde_ve import ScoreSdeVePipeline
+from .stochastic_karras_ve import KarrasVePipeline
+
+
+if is_transformers_available():
+ from .latent_diffusion import LDMTextToImagePipeline
+ from .stable_diffusion import (
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+ )
+
+if is_transformers_available() and is_onnx_available():
+ from .stable_diffusion import StableDiffusionOnnxPipeline
diff --git a/my_diffusers/pipelines/__pycache__/__init__.cpython-38.pyc b/my_diffusers/pipelines/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e72179ff7b0130c130095bd2c003d24673965479
Binary files /dev/null and b/my_diffusers/pipelines/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/ddim/__init__.py b/my_diffusers/pipelines/ddim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fd31868a88ac0d9ec7118574f21a9d8a1d4069b
--- /dev/null
+++ b/my_diffusers/pipelines/ddim/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_ddim import DDIMPipeline
diff --git a/my_diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc b/my_diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cff5a2582d722d554451ecb1a08d539d56f17048
Binary files /dev/null and b/my_diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc b/my_diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c32aa5ff231c85a761aaf356f303a6ae2b54a206
Binary files /dev/null and b/my_diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/ddim/pipeline_ddim.py b/my_diffusers/pipelines/ddim/pipeline_ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f6064dbba347dc82a941edac42e178a3e7df8a
--- /dev/null
+++ b/my_diffusers/pipelines/ddim/pipeline_ddim.py
@@ -0,0 +1,117 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# limitations under the License.
+
+
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class DDIMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ eta (`float`, *optional*, defaults to 0.0):
+ The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ # eta corresponds to η in paper and should be between [0, 1]
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # do x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, eta).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
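+
+
+# Usage sketch (illustrative only; "google/ddpm-cifar10-32" is an assumed checkpoint id):
+#
+#   pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32").to("cuda")
+#   images = pipe(batch_size=4, num_inference_steps=50, eta=0.0).images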
diff --git a/my_diffusers/pipelines/ddpm/__init__.py b/my_diffusers/pipelines/ddpm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8889bdae1224e91916e0f8454bafba0ee566f3b9
--- /dev/null
+++ b/my_diffusers/pipelines/ddpm/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_ddpm import DDPMPipeline
diff --git a/my_diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc b/my_diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a318e5892bab1c85b74a2c48fc0c514a501093a1
Binary files /dev/null and b/my_diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc b/my_diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7383da1b9b4a4c1d76845bada22cb4ff1f8ec314
Binary files /dev/null and b/my_diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/ddpm/pipeline_ddpm.py b/my_diffusers/pipelines/ddpm/pipeline_ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..71103bbe4d051e94f3fca9122460464fb8b1a4f7
--- /dev/null
+++ b/my_diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -0,0 +1,106 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# limitations under the License.
+
+
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class DDPMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(1000)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. compute previous image: x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
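+
+
+# Usage sketch (illustrative only; "google/ddpm-cat-256" is an assumed checkpoint id):
+#
+#   pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")
+#   image = pipe(batch_size=1).images[0]          # runs the full 1000-step sampling loop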
diff --git a/my_diffusers/pipelines/latent_diffusion/__init__.py b/my_diffusers/pipelines/latent_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c481b38cf5e0a1c4e24f7e0edf944efb68e1f979
--- /dev/null
+++ b/my_diffusers/pipelines/latent_diffusion/__init__.py
@@ -0,0 +1,6 @@
+# flake8: noqa
+from ...utils import is_transformers_available
+
+
+if is_transformers_available():
+ from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline
diff --git a/my_diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc b/my_diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26c628073117f055f965e3ca7a0276de27144b74
Binary files /dev/null and b/my_diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc b/my_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ba596bc20919b68b86ef353bf595a0166a3181c
Binary files /dev/null and b/my_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/my_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..b39840f2436b1deda0443fe0883eb4d1f6b73957
--- /dev/null
+++ b/my_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -0,0 +1,705 @@
+import inspect
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.utils import logging
+
+from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+
+
+class LDMTextToImagePipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ bert ([`LDMBertModel`]):
+ Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
+ tokenizer (`transformers.BertTokenizer`):
+ Tokenizer of class
+ [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vqvae: Union[VQModel, AutoencoderKL],
+ bert: PreTrainedModel,
+ tokenizer: PreTrainedTokenizer,
+ unet: Union[UNet2DModel, UNet2DConditionModel],
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 256,
+ width: Optional[int] = 256,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 1.0,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 256):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 256):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 1.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get unconditional embeddings for classifier free guidance
+ if guidance_scale != 1.0:
+ uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+ uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+ text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
+
+ latents = torch.randn(
+ (batch_size, self.unet.in_channels, height // 8, width // 8),
+ generator=generator,
+ )
+ latents = latents.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ if guidance_scale == 1.0:
+ # guidance_scale of 1 means no guidance
+ latents_input = latents
+ context = text_embeddings
+ else:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = torch.cat([latents] * 2)
+ context = torch.cat([uncond_embeddings, text_embeddings])
+
+ # predict the noise residual
+ noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
+ # perform guidance
+ if guidance_scale != 1.0:
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vqvae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
+
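+# Usage sketch for the pipeline above (assumptions: `LDMTextToImagePipeline` is re-exported
+# from the package root as in upstream diffusers, and the public checkpoint
+# "CompVis/ldm-text2im-large-256" is available):
+#
+#     import torch
+#     from my_diffusers import LDMTextToImagePipeline
+#
+#     pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+#     pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+#     result = pipe("a painting of a squirrel eating a burger",
+#                   num_inference_steps=50, guidance_scale=6.0)
+#     result.images[0].save("ldm_text2im_sample.png")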
+
+################################################################################
+# Code for the text transformer model
+################################################################################
+""" PyTorch LDMBERT model."""
+
+
+logger = logging.get_logger(__name__)
+
+LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "ldm-bert",
+ # See all LDMBert models at https://huggingface.co/models?filter=ldmbert
+]
+
+
+LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json",
+}
+
+
+""" LDMBERT model configuration"""
+
+
+class LDMBertConfig(PretrainedConfig):
+ model_type = "ldmbert"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ max_position_embeddings=77,
+ encoder_layers=32,
+ encoder_ffn_dim=5120,
+ encoder_attention_heads=8,
+ head_dim=64,
+ encoder_layerdrop=0.0,
+ activation_function="gelu",
+ d_model=1280,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ classifier_dropout=0.0,
+ scale_embedding=False,
+ use_cache=True,
+ pad_token_id=0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.head_dim = head_dim
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.classifier_dropout = classifier_dropout
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
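+
+# Shape sketch for `_expand_mask` (illustrative values):
+#
+#     mask = torch.tensor([[1, 1, 0]])              # [bsz=1, src_len=3], 1 = keep, 0 = pad
+#     additive = _expand_mask(mask, torch.float32)  # [1, 1, 3, 3]
+#
+# Kept positions become 0.0 and padded positions become `torch.finfo(torch.float32).min`, so
+# the result can be added directly to the raw attention scores before the softmax.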
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert
+class LDMBertAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ head_dim: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = False,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = head_dim
+ self.inner_dim = head_dim * num_heads
+
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.out_proj = nn.Linear(self.inner_dim, embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
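+# Shape sketch for `LDMBertAttention` (illustrative values matching the default config above):
+#
+#     attn = LDMBertAttention(embed_dim=1280, num_heads=8, head_dim=64)
+#     x = torch.randn(2, 77, 1280)                  # [bsz, seq_len, d_model]
+#     out, weights, _ = attn(x)                     # out: [2, 77, 1280], weights: None
+#
+# Queries, keys and values are projected to num_heads * head_dim = 512 internally and mapped
+# back to embed_dim by `out_proj`, which is why `inner_dim` may differ from `embed_dim`.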
+
+class LDMBertEncoderLayer(nn.Module):
+ def __init__(self, config: LDMBertConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.self_attn = LDMBertAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ head_dim=config.head_dim,
+ dropout=config.attention_dropout,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ layer_head_mask: torch.FloatTensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert
+class LDMBertPreTrainedModel(PreTrainedModel):
+ config_class = LDMBertConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (LDMBertEncoder,)):
+ module.gradient_checkpointing = value
+
+ @property
+ def dummy_inputs(self):
+ pad_token = self.config.pad_token_id
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+ dummy_inputs = {
+ "attention_mask": input_ids.ne(pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+
+class LDMBertEncoder(LDMBertPreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`LDMBertEncoderLayer`].
+
+ Args:
+ config: LDMBertConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: LDMBertConfig):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+
+ embed_dim = config.d_model
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
+ self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim)
+ self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.layer_norm = nn.LayerNorm(embed_dim)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ seq_len = input_shape[1]
+ if position_ids is None:
+ position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1))
+ embed_pos = self.embed_positions(position_ids)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class LDMBertModel(LDMBertPreTrainedModel):
+ def __init__(self, config: LDMBertConfig):
+ super().__init__(config)
+ self.model = LDMBertEncoder(config)
+ self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ return outputs
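+
+# Construction sketch (illustrative, with a deliberately tiny config so it is cheap to build):
+#
+#     cfg = LDMBertConfig(vocab_size=1000, d_model=64, encoder_layers=2,
+#                         encoder_attention_heads=4, head_dim=16, encoder_ffn_dim=128)
+#     model = LDMBertModel(cfg)
+#     ids = torch.randint(0, 1000, (1, 77))
+#     last_hidden = model(ids)[0]                   # [1, 77, 64]
+#
+# In `LDMTextToImagePipeline.__call__` above, this last hidden state is what is passed to the
+# UNet as `encoder_hidden_states`.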
diff --git a/my_diffusers/pipelines/latent_diffusion_uncond/__init__.py b/my_diffusers/pipelines/latent_diffusion_uncond/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0826ca7536c706f9bc1f310c157068efbca7f0b3
--- /dev/null
+++ b/my_diffusers/pipelines/latent_diffusion_uncond/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_latent_diffusion_uncond import LDMPipeline
diff --git a/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc b/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9f467ecc3f0ae53349349edebd91178b0220ab7
Binary files /dev/null and b/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc b/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b65f7add1435e6ce7b940ccbbf1c82be5351a1b5
Binary files /dev/null and b/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/my_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
new file mode 100644
index 0000000000000000000000000000000000000000..4979d88feee933483ac49c5cf71eef590d8fb34c
--- /dev/null
+++ b/my_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -0,0 +1,108 @@
+import inspect
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel, VQModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler
+
+
+class LDMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+            [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
+ """
+
+ def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ Number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ latents = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ latents = latents.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # predict the noise residual
+ noise_prediction = self.unet(latents, t).sample
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
+
+ # decode the image latents with the VAE
+ image = self.vqvae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
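+
+# Usage sketch (assumptions: `LDMPipeline` is re-exported from the package root and the
+# unconditional checkpoint "CompVis/ldm-celebahq-256" is available):
+#
+#     pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
+#     image = pipe(batch_size=1, num_inference_steps=50, eta=0.3).images[0]
+#     image.save("ldm_uncond_sample.png")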
diff --git a/my_diffusers/pipelines/pndm/__init__.py b/my_diffusers/pipelines/pndm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc46aaab9fa26e83b49c26843d854e217742664
--- /dev/null
+++ b/my_diffusers/pipelines/pndm/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_pndm import PNDMPipeline
diff --git a/my_diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc b/my_diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..563d8026b625ac7b68315be5f705c79500b2270d
Binary files /dev/null and b/my_diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc b/my_diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c070c4e1455c52b6c8d8e7f3480a036910201e2
Binary files /dev/null and b/my_diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/pndm/pipeline_pndm.py b/my_diffusers/pipelines/pndm/pipeline_pndm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3dff1a9a9416ef7592200c7dbb2ee092bd524d5
--- /dev/null
+++ b/my_diffusers/pipelines/pndm/pipeline_pndm.py
@@ -0,0 +1,111 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import PNDMScheduler
+
+
+class PNDMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image.
+ """
+
+ unet: UNet2DModel
+ scheduler: PNDMScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+            batch_size (`int`, *optional*, defaults to 1): The number of images to generate.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            generator (`torch.Generator`, *optional*): A [torch
+                generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose
+                between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+            return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a
+                [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ # For more information on the sampling method you can take a look at Algorithm 2 of
+ # the official paper: https://arxiv.org/pdf/2202.09778.pdf
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+ for t in self.progress_bar(self.scheduler.timesteps):
+ model_output = self.unet(image, t).sample
+
+ image = self.scheduler.step(model_output, t, image).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
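+
+# Usage sketch (illustrative: the pipeline is assembled from freshly initialised components
+# rather than a specific Hub checkpoint, so the output is pure noise; the `UNet2DModel`
+# arguments below are placeholders, and `UNet2DModel`/`PNDMScheduler` are assumed to be
+# exported from the package root as in upstream diffusers):
+#
+#     from my_diffusers import PNDMPipeline, PNDMScheduler, UNet2DModel
+#
+#     unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
+#     pipe = PNDMPipeline(unet=unet, scheduler=PNDMScheduler())
+#     image = pipe(batch_size=1, num_inference_steps=50).images[0]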
diff --git a/my_diffusers/pipelines/score_sde_ve/__init__.py b/my_diffusers/pipelines/score_sde_ve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..000d61f6e9b183728cb6fc137e7180cac3a616df
--- /dev/null
+++ b/my_diffusers/pipelines/score_sde_ve/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_score_sde_ve import ScoreSdeVePipeline
diff --git a/my_diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc b/my_diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..339733ef55211beeb324fe75d5289bfdf2cd48a7
Binary files /dev/null and b/my_diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc b/my_diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29c6fa6f4711141882dee4a18b54280b0c4d2935
Binary files /dev/null and b/my_diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/my_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..604e2b54cc1766ff446a23235ae4b40f790eadc5
--- /dev/null
+++ b/my_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import ScoreSdeVeScheduler
+
+
+class ScoreSdeVePipeline(DiffusionPipeline):
+ r"""
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Parameters:
+        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+        scheduler ([`SchedulerMixin`]):
+            The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image.
+    """
+ unet: UNet2DModel
+ scheduler: ScoreSdeVeScheduler
+
+    def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 2000,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ img_size = self.unet.config.sample_size
+ shape = (batch_size, 3, img_size, img_size)
+
+ model = self.unet
+
+ sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
+ sample = sample.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+ self.scheduler.set_sigmas(num_inference_steps)
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
+
+ # correction step
+ for _ in range(self.scheduler.correct_steps):
+ model_output = self.unet(sample, sigma_t).sample
+ sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample
+
+ # prediction step
+ model_output = model(sample, sigma_t).sample
+ output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
+
+ sample, sample_mean = output.prev_sample, output.prev_sample_mean
+
+ sample = sample_mean.clamp(0, 1)
+ sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ sample = self.numpy_to_pil(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return ImagePipelineOutput(images=sample)
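+
+# Usage sketch (assumptions: `ScoreSdeVePipeline` is re-exported from the package root and a
+# score-SDE checkpoint such as "google/ncsnpp-celebahq-256" is available; the default 2000
+# steps are slow, so fewer steps are used here purely for illustration):
+#
+#     sde_pipe = ScoreSdeVePipeline.from_pretrained("google/ncsnpp-celebahq-256")
+#     sample = sde_pipe(num_inference_steps=200).images[0]
+#     sample.save("sde_ve_sample.png")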
diff --git a/my_diffusers/pipelines/stable_diffusion/__init__.py b/my_diffusers/pipelines/stable_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ffda93f172142c03298972177b9a74b85867be6
--- /dev/null
+++ b/my_diffusers/pipelines/stable_diffusion/__init__.py
@@ -0,0 +1,37 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+
+import PIL
+from PIL import Image
+
+from ...utils import BaseOutput, is_onnx_available, is_transformers_available
+
+
+@dataclass
+class StableDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or NumPy arrays represent the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: List[bool]
+
+
+if is_transformers_available():
+ from .pipeline_stable_diffusion import StableDiffusionPipeline
+ from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
+ from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
+ from .safety_checker import StableDiffusionSafetyChecker
+
+if is_transformers_available() and is_onnx_available():
+ from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-38.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..350dbcdb3a453b594f4e04de1b5398cea62b1ebf
Binary files /dev/null and b/my_diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d48810afb070ccb1155024f787e12c3ab762a6b
Binary files /dev/null and b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a901b086a179bd45a5458e1b755fa21b347cf637
Binary files /dev/null and b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..608cd830f921109282e2261cf13c7b44edeeff1d
Binary files /dev/null and b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_onnx.cpython-38.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_onnx.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3c1e4709310c046c01684f6b4b96b0bdaa40988
Binary files /dev/null and b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_onnx.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b56dc15a72ae022a316069fc5e8cbd18104450e
Binary files /dev/null and b/my_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f02fa114a8e1607136fd1c8247e3cabb763b4415
--- /dev/null
+++ b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -0,0 +1,279 @@
+import inspect
+import warnings
+from typing import List, Optional, Union
+
+import torch
+
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+class StableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
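+
+    # Illustration of the slicing trade-off (numbers are illustrative): with
+    # `attention_head_dim = 8`, `enable_attention_slicing("auto")` picks `slice_size = 4`,
+    # i.e. attention is computed in 2 chunks; `enable_attention_slicing(1)` would use 8
+    # chunks (lowest memory, slowest), and `disable_attention_slicing()` restores the
+    # single full-size pass:
+    #
+    #     pipe.enable_attention_slicing()           # "auto" -> attention_head_dim // 2
+    #     image = pipe("an astronaut riding a horse").images[0]
+    #     pipe.disable_attention_slicing()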
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_device = "cpu" if self.device.type == "mps" else self.device
+ latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+ if latents is None:
+ latents = torch.randn(
+ latents_shape,
+ generator=generator,
+ device=latents_device,
+ )
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[i]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ # run safety checker
+        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
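
For orientation, a minimal usage sketch of the text-to-image pipeline above. The top-level `my_diffusers` import and the `CompVis/stable-diffusion-v1-4` checkpoint id are assumptions carried over from upstream `diffusers` conventions and the model card linked in the docstrings, not something this diff defines.

```python
# Sketch only, under the assumptions stated above.
import torch
from my_diffusers import StableDiffusionPipeline  # assumed re-export, as in upstream diffusers

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

out = pipe(
    "a photograph of an astronaut riding a horse",
    num_inference_steps=50,  # more steps: higher quality, slower inference
    guidance_scale=7.5,      # > 1 enables the classifier-free guidance branch above
)
out.images[0].save("txt2img.png")
```
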
diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..475ceef4f002f80842c4b25352a504f6b957db55
--- /dev/null
+++ b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -0,0 +1,291 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+def preprocess(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image to image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `set_attention_slice`
+        self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ init_image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+ `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ offset = 0
+ if accepts_offset:
+ offset = 1
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+ if not isinstance(init_image, torch.FloatTensor):
+ init_image = preprocess(init_image)
+
+ # encode the init image into latents and scale the latents
+ init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+
+ # expand init_latents for batch_size
+ init_latents = torch.cat([init_latents] * batch_size)
+
+ # get the original timestep using init_timestep
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ timesteps = torch.tensor(
+ [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+ )
+ else:
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
+ t_index = t_start + i
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+            # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[t_index]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+ latent_model_input = latent_model_input.to(self.unet.dtype)
+ t = t.to(self.unet.dtype)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents.to(self.vae.dtype)).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ # run safety checker
+        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
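
The same kind of hedged sketch for the img2img pipeline above: the import path and checkpoint id are assumptions, `init_image` can be any RGB `PIL.Image` (its sides are rounded down to multiples of 32 by `preprocess`), and `strength` controls how far the result drifts from it.

```python
# Sketch only: import path and checkpoint id are assumptions mirroring upstream diffusers usage.
import torch
from PIL import Image
from my_diffusers import StableDiffusionImg2ImgPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

init_image = Image.open("sketch.png").convert("RGB").resize((512, 512))
out = pipe(
    prompt="a fantasy landscape, matte painting",
    init_image=init_image,
    strength=0.8,            # low values stay close to init_image; 1.0 essentially ignores it
    num_inference_steps=50,
    guidance_scale=7.5,
)
out.images[0].save("img2img.png")
```
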
diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..05ea84ae0326231fa2ffbd4ad936f8747a9fed2c
--- /dev/null
+++ b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -0,0 +1,309 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, PNDMScheduler
+from ...utils import logging
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__)
+
+
+def preprocess_image(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+    mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension; the (0, 1, 2, 3) transpose is a no-op kept for parity with preprocess_image
+ mask = 1 - mask # repaint white, keep black
+ mask = torch.from_numpy(mask)
+ return mask
+
+
+class StableDiffusionInpaintPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `set_attention_slice`
+        self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ init_image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. This is the image whose masked region will be inpainted.
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
+ converted to a single channel (luminance) before use.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
+ in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ offset = 0
+ if accepts_offset:
+ offset = 1
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+ # preprocess image
+ init_image = preprocess_image(init_image).to(self.device)
+
+ # encode the init image into latents and scale the latents
+ init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+
+ init_latents = 0.18215 * init_latents
+
+ # Expand init_latents for batch_size
+ init_latents = torch.cat([init_latents] * batch_size)
+ init_latents_orig = init_latents
+
+ # preprocess mask
+ mask = preprocess_mask(mask_image).to(self.device)
+ mask = torch.cat([mask] * batch_size)
+
+ # check sizes
+        if mask.shape != init_latents.shape:
+ raise ValueError("The mask and init_image should be the same size!")
+
+ # get the original timestep using init_timestep
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # masking
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ # run safety checker
+        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
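
A corresponding sketch for the inpainting pipeline: white regions of `mask_image` are repainted and black regions preserved, as `preprocess_mask` above implies. The import path and checkpoint id are again assumptions, not part of this diff.

```python
# Sketch only, under the assumptions stated above.
import torch
from PIL import Image
from my_diffusers import StableDiffusionInpaintPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
mask_image = Image.open("mask.png").convert("L").resize((512, 512))  # white = repaint, black = keep

out = pipe(
    prompt="a vase of flowers on the table",
    init_image=init_image,
    mask_image=mask_image,
    strength=0.8,
    num_inference_steps=50,
    guidance_scale=7.5,
)
out.images[0].save("inpaint.png")
```
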
diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ff3ff22fc21014fa7b6c12fba96a2ca36fc9cc4
--- /dev/null
+++ b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
@@ -0,0 +1,165 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+
+from transformers import CLIPFeatureExtractor, CLIPTokenizer
+
+from ...onnx_utils import OnnxRuntimeModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
+
+
+class StableDiffusionOnnxPipeline(DiffusionPipeline):
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+ safety_checker: OnnxRuntimeModel
+ feature_extractor: CLIPFeatureExtractor
+
+ def __init__(
+ self,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("np")
+ self.register_modules(
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ latents: Optional[np.ndarray] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ):
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+ )
+ uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+ latents_shape = (batch_size, 4, height // 8, width // 8)
+ if latents is None:
+ latents = np.random.randn(*latents_shape).astype(np.float32)
+ elif latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[i]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+ )
+ noise_pred = noise_pred[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae_decoder(latent_sample=latents)[0]
+
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+
+ # run safety checker
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
+ image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
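
The comments in this ONNX variant (and in the torch pipelines above) all describe the same classifier-free guidance trick: batch the unconditional and text embeddings, run one UNet pass, split the prediction, and blend the two halves with `guidance_scale`. A toy numpy illustration with made-up shapes:

```python
# Toy illustration of the guidance arithmetic; the shapes are arbitrary stand-ins.
import numpy as np

guidance_scale = 7.5
noise_pred = np.random.randn(2, 4, 64, 64).astype(np.float32)  # stands in for the batched UNet output

noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

assert guided.shape == (1, 4, 64, 64)
# With guidance_scale = 1 the blend collapses to the text-conditioned prediction alone,
# which is why guidance_scale > 1 is the condition for running the extra unconditional pass.
```
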
diff --git a/my_diffusers/pipelines/stable_diffusion/safety_checker.py b/my_diffusers/pipelines/stable_diffusion/safety_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..09de92eeb1ec7e64863839012b1eddba444ad80a
--- /dev/null
+++ b/my_diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -0,0 +1,106 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def cosine_distance(image_embeds, text_embeds):
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+class StableDiffusionSafetyChecker(PreTrainedModel):
+ config_class = CLIPConfig
+
+ def __init__(self, config: CLIPConfig):
+ super().__init__(config)
+
+ self.vision_model = CLIPVisionModel(config.vision_config)
+ self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+ self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+ self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+ self.register_buffer("concept_embeds_weights", torch.ones(17))
+ self.register_buffer("special_care_embeds_weights", torch.ones(3))
+
+ @torch.no_grad()
+ def forward(self, clip_input, images):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
+
+ result = []
+ batch_size = image_embeds.shape[0]
+ for i in range(batch_size):
+ result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+
+            # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+            for concept_idx in range(len(special_cos_dist[0])):
+                concept_cos = special_cos_dist[i][concept_idx]
+                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["special_scores"][concept_idx] > 0:
+                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
+                    adjustment = 0.01
+
+            for concept_idx in range(len(cos_dist[0])):
+                concept_cos = cos_dist[i][concept_idx]
+                concept_threshold = self.concept_embeds_weights[concept_idx].item()
+                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["concept_scores"][concept_idx] > 0:
+                    result_img["bad_concepts"].append(concept_idx)
+
+ result.append(result_img)
+
+ has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ if has_nsfw_concept:
+ images[idx] = np.zeros(images[idx].shape) # black image
+
+ if any(has_nsfw_concepts):
+ logger.warning(
+ "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+ " Try again with a different prompt and/or seed."
+ )
+
+ return images, has_nsfw_concepts
+
+ @torch.inference_mode()
+ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+ # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
+ # special_scores = special_scores.round(decimals=3)
+ special_care = torch.any(special_scores > 0, dim=1)
+ special_adjustment = special_care * 0.01
+ special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
+
+ concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+ # concept_scores = concept_scores.round(decimals=3)
+ has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+ images[has_nsfw_concepts] = 0.0 # black image
+
+ return images, has_nsfw_concepts
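
To make the safety checker's scoring concrete, here is a self-contained toy version of the vectorized path in `forward_onnx`, with random stand-ins for the learned concept embeddings and thresholds:

```python
# Toy illustration with made-up shapes and random weights; the real thresholds live in
# concept_embeds_weights / special_care_embeds_weights of the pretrained checker.
import torch
import torch.nn as nn

def cosine_distance(image_embeds, text_embeds):
    return torch.mm(nn.functional.normalize(image_embeds), nn.functional.normalize(text_embeds).t())

image_embeds = torch.randn(2, 768)           # projected CLIP image embeddings for 2 images
concept_embeds = torch.randn(17, 768)        # 17 filtered concepts
special_embeds = torch.randn(3, 768)         # 3 "special care" concepts
concept_thresholds = torch.full((17,), 0.3)
special_thresholds = torch.full((3,), 0.3)

special_scores = cosine_distance(image_embeds, special_embeds) - special_thresholds
special_adjustment = torch.any(special_scores > 0, dim=1).float() * 0.01  # stricter filtering on special hits
concept_scores = cosine_distance(image_embeds, concept_embeds) - concept_thresholds + special_adjustment[:, None]

has_nsfw = torch.any(concept_scores > 0, dim=1)  # one flag per image; flagged images get blacked out
print(has_nsfw)
```
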
diff --git a/my_diffusers/pipelines/stochastic_karras_ve/__init__.py b/my_diffusers/pipelines/stochastic_karras_ve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..db2582043781130794e01b96b3e6beecbfe9f369
--- /dev/null
+++ b/my_diffusers/pipelines/stochastic_karras_ve/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_stochastic_karras_ve import KarrasVePipeline
diff --git a/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc b/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1705af44a67a33fcedcda87228267bd07a23bff
Binary files /dev/null and b/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc b/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8dd1b9f25f404b4c6a1e84d0ca2dd6916650c6ab
Binary files /dev/null and b/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc differ
diff --git a/my_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/my_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..15266544db7c8bc7448405955d74396eef7fe950
--- /dev/null
+++ b/my_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import KarrasVeScheduler
+
+
+class KarrasVePipeline(DiffusionPipeline):
+ r"""
+    Stochastic sampling from Karras et al. [1] tailored to Variance-Exploding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`KarrasVeScheduler`]):
+ Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image.
+ """
+
+ # add type hints for linting
+ unet: UNet2DModel
+ scheduler: KarrasVeScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ img_size = self.unet.config.sample_size
+ shape = (batch_size, 3, img_size, img_size)
+
+ model = self.unet
+
+ # sample x_0 ~ N(0, sigma_0^2 * I)
+ sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+ sample = sample.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # here sigma_t == t_i from the paper
+ sigma = self.scheduler.schedule[t]
+ sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
+
+ # 1. Select temporarily increased noise level sigma_hat
+ # 2. Add new noise to move from sample_i to sample_hat
+ sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)
+
+ # 3. Predict the noise residual given the noise magnitude `sigma_hat`
+ # The model inputs and output are adjusted by following eq. (213) in [1].
+ model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
+
+ # 4. Evaluate dx/dt at sigma_hat
+ # 5. Take Euler step from sigma to sigma_prev
+ step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)
+
+ if sigma_prev != 0:
+ # 6. Apply 2nd order correction
+ # The model inputs and output are adjusted by following eq. (213) in [1].
+ model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
+ step_output = self.scheduler.step_correct(
+ model_output,
+ sigma_hat,
+ sigma_prev,
+ sample_hat,
+ step_output.prev_sample,
+ step_output["derivative"],
+ )
+ sample = step_output.prev_sample
+
+ sample = (sample / 2 + 0.5).clamp(0, 1)
+ image = sample.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
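
A hedged sketch of driving the Karras VE pipeline; that `KarrasVePipeline` is re-exported at the package root and that a VE checkpoint such as `google/ncsnpp-celebahq-256` is used are assumptions based on upstream `diffusers`, not this diff.

```python
# Sketch only, under the assumptions stated above.
import torch
from my_diffusers import KarrasVePipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = KarrasVePipeline.from_pretrained("google/ncsnpp-celebahq-256").to(device)

out = pipe(batch_size=1, num_inference_steps=50, generator=torch.manual_seed(0))
out.images[0].save("karras_ve.png")
```
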
diff --git a/my_diffusers/schedulers/__init__.py b/my_diffusers/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..20c25f35183faeeef2cd7b5095f80a70a9edac01
--- /dev/null
+++ b/my_diffusers/schedulers/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils import is_scipy_available
+from .scheduling_ddim import DDIMScheduler
+from .scheduling_ddpm import DDPMScheduler
+from .scheduling_karras_ve import KarrasVeScheduler
+from .scheduling_pndm import PNDMScheduler
+from .scheduling_sde_ve import ScoreSdeVeScheduler
+from .scheduling_sde_vp import ScoreSdeVpScheduler
+from .scheduling_utils import SchedulerMixin
+
+
+if is_scipy_available():
+ from .scheduling_lms_discrete import LMSDiscreteScheduler
+else:
+ from ..utils.dummy_scipy_objects import * # noqa F403
diff --git a/my_diffusers/schedulers/__pycache__/__init__.cpython-38.pyc b/my_diffusers/schedulers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..845880add4f4539f3a9623ece683ed0d5a0d3493
Binary files /dev/null and b/my_diffusers/schedulers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc b/my_diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0673a691b8f3f4de828e7639d101c38a423088f9
Binary files /dev/null and b/my_diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc differ
diff --git a/my_diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc b/my_diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b97762f291166153973d74e085a05afd208979a
Binary files /dev/null and b/my_diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc differ
diff --git a/my_diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc b/my_diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..36cb698ea93d65c8cc8897f756b214eacdaf4c06
Binary files /dev/null and b/my_diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc differ
diff --git a/my_diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc b/my_diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6294a53b090c35d29c740499af77dd1c1297eb4
Binary files /dev/null and b/my_diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc differ
diff --git a/my_diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc b/my_diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0efcdd84948d3019cb66ca2cb00f53bf1103125
Binary files /dev/null and b/my_diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc differ
diff --git a/my_diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc b/my_diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0f034cb27c0e595b7c9ee87e0b2a587cef30381
Binary files /dev/null and b/my_diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc differ
diff --git a/my_diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc b/my_diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df1c210cffcd29846486915ea474aead0a032a96
Binary files /dev/null and b/my_diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc differ
diff --git a/my_diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc b/my_diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d12a94d69276ef2ac9efcb2165a154b114a67f69
Binary files /dev/null and b/my_diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc differ
diff --git a/my_diffusers/schedulers/scheduling_ddim.py b/my_diffusers/schedulers/scheduling_ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccfb0f7e648acc81750a98d317a03de715633588
--- /dev/null
+++ b/my_diffusers/schedulers/scheduling_ddim.py
@@ -0,0 +1,270 @@
+# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float64)
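+
+# For intuition, a quick sanity check of the cosine alpha_bar schedule above
+# (illustrative only; not used anywhere in the module):
+#
+#     >>> betas = betas_for_alpha_bar(10)
+#     >>> betas.shape
+#     (10,)
+#     >>> bool(np.all((betas > 0) & (betas <= 0.999)))
+#     True
+#
+# The betas grow toward max_beta as alpha_bar decays toward 0 near t = 1.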
+
+
+class DDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
+ diffusion probabilistic models (DDPMs) with non-Markovian guidance.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional): TODO
+ timestep_values (`np.ndarray`, optional): TODO
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ set_alpha_to_one (`bool`, default `True`):
+ if alpha for final step is 1 or the final alpha of the "non-previous" one.
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ timestep_values: Optional[np.ndarray] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ tensor_format: str = "pt",
+ ):
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float64)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float64) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ offset (`int`): TODO
+ """
+ self.num_inference_steps = num_inference_steps
+ if num_inference_steps <= 1000:
+ self.timesteps = np.arange(
+ 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+ )[::-1].copy()
+ else:
+ print("Hitting new logic, allowing fractional timesteps")
+ self.timesteps = np.linspace(
+ 0, self.config.num_train_timesteps-1, self.num_inference_steps, endpoint=True
+ )[::-1].copy()
+ self.timesteps += offset
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+            use_clipped_model_output (`bool`):
+                if `True`, re-derive the model output from the clipped predicted original sample (`x_0`), as done
+                in GLIDE.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+        # Ideally, read the DDIM paper in detail to understand the following steps
+
+        # Notation (<variable name> -> <name in paper>)
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+        # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+ # 4. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = self.clip(pred_original_sample, -1, 1)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the model_output is always re-derived from the clipped x_0 in Glide
+ model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ device = model_output.device if torch.is_tensor(model_output) else "cpu"
+ noise = torch.randn(model_output.shape, generator=generator).to(device)
+ variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
+
+ if not torch.is_tensor(model_output):
+ variance = variance.numpy()
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> Union[torch.FloatTensor, np.ndarray]:
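+        # Closed-form forward diffusion (editor's comment):
+        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
+        # so a clean sample can be noised to any timestep in a single step.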
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/my_diffusers/schedulers/scheduling_ddpm.py b/my_diffusers/schedulers/scheduling_ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fbfb90383361ece4e82aa10a499c8dc58113794
--- /dev/null
+++ b/my_diffusers/schedulers/scheduling_ddpm.py
@@ -0,0 +1,264 @@
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float32)
+
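+# Editor's note (not part of the original file): the cosine schedule above yields betas that start
+# near zero and grow towards `max_beta`, e.g.
+#
+#     betas = betas_for_alpha_bar(1000)
+#     assert betas[0] < betas[-1] <= 0.999
+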
+
+class DDPMScheduler(SchedulerMixin, ConfigMixin):
+ """
+    Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
+ Langevin dynamics sampling.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional):
+            an array of betas to use directly instead of computing them from `beta_start`, `beta_end` and
+            `beta_schedule`.
+ variance_type (`str`):
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ variance_type: str = "fixed_small",
+ clip_sample: bool = True,
+ tensor_format: str = "pt",
+ ):
+
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+ self.one = np.array(1.0)
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ self.variance_type = variance_type
+
+ def set_timesteps(self, num_inference_steps: int):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.arange(
+ 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+ )[::-1].copy()
+ self.set_format(tensor_format=self.tensor_format)
+
+ def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # and sample from it to get previous sample
+ # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+
+ if variance_type is None:
+ variance_type = self.config.variance_type
+
+        # hacks - were probably added for training stability
+ if variance_type == "fixed_small":
+ variance = self.clip(variance, min_value=1e-20)
+ # for rl-diffuser https://arxiv.org/abs/2205.09991
+ elif variance_type == "fixed_small_log":
+ variance = self.log(self.clip(variance, min_value=1e-20))
+ elif variance_type == "fixed_large":
+ variance = self.betas[t]
+ elif variance_type == "fixed_large_log":
+ # Glide max_log
+ variance = self.log(self.betas[t])
+ elif variance_type == "learned":
+ return predicted_variance
+ elif variance_type == "learned_range":
+ min_log = variance
+ max_log = self.betas[t]
+ frac = (predicted_variance + 1) / 2
+ variance = frac * max_log + (1 - frac) * min_log
+
+ return variance
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ predict_epsilon=True,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ predict_epsilon (`bool`):
+ optional flag to use when model predicts the samples directly instead of the noise, epsilon.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ t = timestep
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if predict_epsilon:
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ else:
+ pred_original_sample = model_output
+
+ # 3. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = self.clip(pred_original_sample, -1, 1)
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
+ current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ variance = 0
+ if t > 0:
+ noise = self.randn_like(model_output, generator=generator)
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ if not return_dict:
+ return (pred_prev_sample,)
+
+ return SchedulerOutput(prev_sample=pred_prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> Union[torch.FloatTensor, np.ndarray]:
+
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/my_diffusers/schedulers/scheduling_karras_ve.py b/my_diffusers/schedulers/scheduling_karras_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a2370cfc3e0523dfba48703bcd0c3e9a42b2381
--- /dev/null
+++ b/my_diffusers/schedulers/scheduling_karras_ve.py
@@ -0,0 +1,208 @@
+# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class KarrasVeOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Derivative of predicted original image sample (x_0).
+ """
+
+ prev_sample: torch.FloatTensor
+ derivative: torch.FloatTensor
+
+
+class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
+ """
+    Stochastic sampling from Karras et al. [1] tailored to the Variance Exploding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+ Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+ optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+
+ Args:
+ sigma_min (`float`): minimum noise magnitude
+ sigma_max (`float`): maximum noise magnitude
+ s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+ A reasonable range is [1.000, 1.011].
+ s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+ A reasonable range is [0, 100].
+ s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+ A reasonable range is [0, 10].
+ s_max (`float`): the end value of the sigma range where we add noise.
+ A reasonable range is [0.2, 80].
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ sigma_min: float = 0.02,
+ sigma_max: float = 100,
+ s_noise: float = 1.007,
+ s_churn: float = 80,
+ s_min: float = 0.05,
+ s_max: float = 50,
+ tensor_format: str = "pt",
+ ):
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = None
+ self.schedule = None # sigma(t_i)
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def set_timesteps(self, num_inference_steps: int):
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+
+ """
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+ self.schedule = [
+ (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1)))
+ for i in self.timesteps
+ ]
+ self.schedule = np.array(self.schedule, dtype=np.float32)
+
+ self.set_format(tensor_format=self.tensor_format)
+
+ def add_noise_to_input(
+ self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
+ ) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
+ """
+ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+ higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+
+ TODO Args:
+ """
+ if self.s_min <= sigma <= self.s_max:
+ gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
+ else:
+ gamma = 0
+
+ # sample eps ~ N(0, S_noise^2 * I)
+ eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
+ sigma_hat = sigma + gamma * sigma
+ sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
+
+ return sample_hat, sigma_hat
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ sigma_hat (`float`): TODO
+ sigma_prev (`float`): TODO
+ sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
+ [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+
+ pred_original_sample = sample_hat + sigma_hat * model_output
+ derivative = (sample_hat - pred_original_sample) / sigma_hat
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
+
+ def step_correct(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: Union[torch.FloatTensor, np.ndarray],
+ sample_prev: Union[torch.FloatTensor, np.ndarray],
+ derivative: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. TODO complete description
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ sigma_hat (`float`): TODO
+ sigma_prev (`float`): TODO
+ sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+ sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
+ derivative (`torch.FloatTensor` or `np.ndarray`): TODO
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
+
+ """
+ pred_original_sample = sample_prev + sigma_prev * model_output
+ derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
+
+ def add_noise(self, original_samples, noise, timesteps):
+ raise NotImplementedError()
diff --git a/my_diffusers/schedulers/scheduling_lms_discrete.py b/my_diffusers/schedulers/scheduling_lms_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..1381587febf16d9c774b5f2574653c962e031a46
--- /dev/null
+++ b/my_diffusers/schedulers/scheduling_lms_discrete.py
@@ -0,0 +1,193 @@
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from scipy import integrate
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
+ Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+        trained_betas (`np.ndarray`, optional):
+            an array of betas to use directly instead of computing them from `beta_start`, `beta_end` and
+            `beta_schedule`.
+        timestep_values (`np.ndarray`, optional): TODO
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ timestep_values: Optional[np.ndarray] = None,
+ tensor_format: str = "pt",
+ ):
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self.derivatives = []
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def get_lms_coefficient(self, order, t, current_order):
+ """
+ Compute a linear multistep coefficient.
+
+ Args:
+ order (TODO):
+ t (TODO):
+ current_order (TODO):
+ """
+
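+        # Editor's note: `lms_derivative` is the Lagrange basis polynomial associated with the stored
+        # derivative at offset `current_order`; integrating it from sigma_t to sigma_{t+1} yields the
+        # Adams-Bashforth-style multistep coefficient returned below.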
+ def lms_derivative(tau):
+ prod = 1.0
+ for k in range(order):
+ if current_order == k:
+ continue
+ prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
+ return prod
+
+ integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]
+
+ return integrated_coeff
+
+ def set_timesteps(self, num_inference_steps: int):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
+
+ low_idx = np.floor(self.timesteps).astype(int)
+ high_idx = np.ceil(self.timesteps).astype(int)
+ frac = np.mod(self.timesteps, 1.0)
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+ self.sigmas = np.concatenate([sigmas, [0.0]])
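+        # Editor's note: sigmas are interpolated at the (possibly fractional) timesteps above, and a
+        # trailing 0.0 is appended so the last multistep update integrates down to sigma = 0.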
+
+ self.derivatives = []
+
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ order: int = 4,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ order: coefficient for multi-step inference.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ sigma = self.sigmas[timestep]
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ pred_original_sample = sample - sigma * model_output
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma
+ self.derivatives.append(derivative)
+ if len(self.derivatives) > order:
+ self.derivatives.pop(0)
+
+ # 3. Compute linear multistep coefficients
+ order = min(timestep + 1, order)
+ lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+
+ # 4. Compute previous sample based on the derivatives path
+ prev_sample = sample + sum(
+ coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
+ )
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> Union[torch.FloatTensor, np.ndarray]:
+ sigmas = self.match_shape(self.sigmas[timesteps], noise)
+ noisy_samples = original_samples + noise * sigmas
+
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/my_diffusers/schedulers/scheduling_pndm.py b/my_diffusers/schedulers/scheduling_pndm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b43d88bbab7745e3e8579cc66f2ee2ed246e52d7
--- /dev/null
+++ b/my_diffusers/schedulers/scheduling_pndm.py
@@ -0,0 +1,378 @@
+# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float32)
+
+
+class PNDMScheduler(SchedulerMixin, ConfigMixin):
+ """
+    Pseudo numerical methods for diffusion models (PNDM) propose using more advanced ODE integration techniques,
+    namely the Runge-Kutta method and a linear multi-step method.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional):
+            an array of betas to use directly instead of computing them from `beta_start`, `beta_end` and
+            `beta_schedule`.
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
+ skip_prk_steps (`bool`):
+ allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+ before plms steps; defaults to `False`.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ tensor_format: str = "pt",
+ skip_prk_steps: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ self.one = np.array(1.0)
+
+        # For now we only support F-PNDM, i.e. the Runge-Kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formula (9), (12), (13) and the Algorithm 2.
+ self.pndm_order = 4
+
+ # running values
+ self.cur_model_output = 0
+ self.counter = 0
+ self.cur_sample = None
+ self.ets = []
+
+ # setable values
+ self.num_inference_steps = None
+ self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self._offset = 0
+ self.prk_timesteps = None
+ self.plms_timesteps = None
+ self.timesteps = None
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+            offset (`int`): an optional offset added to each inference timestep (defaults to 0).
+ """
+ self.num_inference_steps = num_inference_steps
+ self._timesteps = list(
+ range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
+ )
+ self._offset = offset
+ self._timesteps = np.array([t + self._offset for t in self._timesteps])
+
+ if self.config.skip_prk_steps:
+ # for some models like stable diffusion the prk steps can/should be skipped to
+ # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
+ # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+ self.prk_timesteps = np.array([])
+ self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
+ ::-1
+ ].copy()
+ else:
+ prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
+ np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
+ )
+ self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
+ self.plms_timesteps = self._timesteps[:-3][
+ ::-1
+ ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
+
+ self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
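+        # Editor's note: with `skip_prk_steps=True` the second-highest timestep appears twice in
+        # `self.timesteps`; that duplicate drives the PLMS warm-up handled by the `counter == 1`
+        # branch of `step_plms`.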
+
+ self.ets = []
+ self.counter = 0
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
+ return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+ else:
+ return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+
+ def step_prk(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+ solution to the differential equation.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
+ prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
+ timestep = self.prk_timesteps[self.counter // 4 * 4]
+
+ if self.counter % 4 == 0:
+ self.cur_model_output += 1 / 6 * model_output
+ self.ets.append(model_output)
+ self.cur_sample = sample
+ elif (self.counter - 1) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 2) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 3) % 4 == 0:
+ model_output = self.cur_model_output + 1 / 6 * model_output
+ self.cur_model_output = 0
+
+ # cur_sample should not be `None`
+ cur_sample = self.cur_sample if self.cur_sample is not None else sample
+
+ prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def step_plms(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+        Step function propagating the sample with the linear multi-step method. Each step needs only one forward
+        pass and reuses the model outputs stored from previous steps to approximate the solution.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if not self.config.skip_prk_steps and len(self.ets) < 3:
+ raise ValueError(
+ f"{self.__class__} can only be run AFTER scheduler has been run "
+ "in 'prk' mode for at least 12 iterations "
+ "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
+ "for more information."
+ )
+
+ prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
+
+ if self.counter != 1:
+ self.ets.append(model_output)
+ else:
+ prev_timestep = timestep
+ timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
+
+ if len(self.ets) == 1 and self.counter == 0:
+ model_output = model_output
+ self.cur_sample = sample
+ elif len(self.ets) == 1 and self.counter == 1:
+ model_output = (model_output + self.ets[-1]) / 2
+ sample = self.cur_sample
+ self.cur_sample = None
+ elif len(self.ets) == 2:
+ model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
+ elif len(self.ets) == 3:
+ model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+ else:
+ model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+ prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
+ # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+ # this function computes x_(t−δ) using the formula of (9)
+ # Note that x_t needs to be added to both sides of the equation
+
+        # Notation (<variable name> -> <name in paper>)
+ # alpha_prod_t -> α_t
+ # alpha_prod_t_prev -> α_(t−δ)
+ # beta_prod_t -> (1 - α_t)
+ # beta_prod_t_prev -> (1 - α_(t−δ))
+ # sample -> x_t
+ # model_output -> e_θ(x_t, t)
+ # prev_sample -> x_(t−δ)
+ alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
+ alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        # corresponds to (α_(t−δ) - α_t) divided by
+        # the denominator of x_t in formula (9), plus 1
+        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
+        #       sqrt(α_(t−δ)) / sqrt(α_t)
+ sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+ # corresponds to denominator of e_θ(x_t, t) in formula (9)
+ model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+ alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+ ) ** (0.5)
+
+ # full formula (9)
+ prev_sample = (
+ sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+ )
+
+ return prev_sample
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> torch.Tensor:
+ # mps requires indices to be in the same device, so we use cpu as is the default with cuda
+ timesteps = timesteps.to(self.alphas_cumprod.device)
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/my_diffusers/schedulers/scheduling_sde_ve.py b/my_diffusers/schedulers/scheduling_sde_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..e187f079688723c991b4b80fa1fd4f358896bb4f
--- /dev/null
+++ b/my_diffusers/schedulers/scheduling_sde_ve.py
@@ -0,0 +1,283 @@
+# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+@dataclass
+class SdeVeOutput(BaseOutput):
+ """
+ Output class for the ScoreSdeVeScheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
+ """
+
+ prev_sample: torch.FloatTensor
+ prev_sample_mean: torch.FloatTensor
+
+
+class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The variance exploding stochastic differential equation (SDE) scheduler.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ Args:
+ snr (`float`):
+ coefficient weighting the step from the model_output sample (from the network) to the random noise.
+ sigma_min (`float`):
+ initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+ distribution of the data.
+ sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
+        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
+ epsilon.
+ correct_steps (`int`): number of correction steps performed on a produced sample.
+ tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 2000,
+ snr: float = 0.15,
+ sigma_min: float = 0.01,
+ sigma_max: float = 1348.0,
+ sampling_eps: float = 1e-5,
+ correct_steps: int = 1,
+ tensor_format: str = "pt",
+ ):
+ # setable values
+ self.timesteps = None
+
+ self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
+ elif tensor_format == "pt":
+ self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
+ else:
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def set_sigmas(
+ self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
+ ):
+ """
+ Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
+
+ The sigmas control the weight of the `drift` and `diffusion` components of sample update.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sigma_min (`float`, optional):
+ initial noise scale value (overrides value given at Scheduler instantiation).
+ sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
+ sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
+ sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
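+        # Editor's note: `discrete_sigmas` form a geometric progression from `sigma_min` to `sigma_max`,
+        # while `self.sigmas` evaluates the continuous schedule
+        # sigma(t) = sigma_min * (sigma_max / sigma_min) ** t at the sampling timesteps.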
+ if self.timesteps is None:
+ self.set_timesteps(num_inference_steps, sampling_eps)
+
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
+ self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
+ elif tensor_format == "pt":
+ self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
+ self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
+ else:
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def get_adjacent_sigma(self, timesteps, t):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
+ elif tensor_format == "pt":
+ return torch.where(
+ timesteps == 0,
+ torch.zeros_like(t.to(timesteps.device)),
+ self.discrete_sigmas[timesteps - 1].to(timesteps.device),
+ )
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def set_seed(self, seed):
+ warnings.warn(
+ "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
+ " generator instead.",
+ DeprecationWarning,
+ )
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ np.random.seed(seed)
+ elif tensor_format == "pt":
+ torch.manual_seed(seed)
+ else:
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def step_pred(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[SdeVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if "seed" in kwargs and kwargs["seed"] is not None:
+ self.set_seed(kwargs["seed"])
+
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ timestep = timestep * torch.ones(
+ sample.shape[0], device=sample.device
+ ) # torch.repeat_interleave(timestep, sample.shape[0])
+ timesteps = (timestep * (len(self.timesteps) - 1)).long()
+
+ # mps requires indices to be in the same device, so we use cpu as is the default with cuda
+ timesteps = timesteps.to(self.discrete_sigmas.device)
+
+ sigma = self.discrete_sigmas[timesteps].to(sample.device)
+ adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
+ drift = self.zeros_like(sample)
+ diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
+
+ # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
+ # also equation 47 shows the analog from SDE models to ancestral sampling methods
+ drift = drift - diffusion[:, None, None, None] ** 2 * model_output
+
+ # equation 6: sample noise for the diffusion term of
+ noise = self.randn_like(sample, generator=generator)
+ prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
+ # TODO is the variable diffusion the correct scaling term for the noise?
+ prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
+
+ if not return_dict:
+ return (prev_sample, prev_sample_mean)
+
+ return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
+
+ def step_correct(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ sample: Union[torch.FloatTensor, np.ndarray],
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
+ after making the prediction for the previous timestep.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.SchedulerOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if "seed" in kwargs and kwargs["seed"] is not None:
+ self.set_seed(kwargs["seed"])
+
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+        # For small batch sizes, the paper suggests "replacing norm(z) with sqrt(d), where d is the dim. of z"
+ # sample noise for correction
+ noise = self.randn_like(sample, generator=generator)
+
+ # compute step size from the model_output, the noise, and the snr
+ grad_norm = self.norm(model_output)
+ noise_norm = self.norm(noise)
+ step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
+ step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
+ # self.repeat_scalar(step_size, sample.shape[0])
+
+ # compute corrected sample: model_output term and noise term
+ prev_sample_mean = sample + step_size[:, None, None, None] * model_output
+ prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
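+
+
+# A minimal predictor-corrector sampling sketch for the scheduler above (illustrative only:
+# `model` is a placeholder score model returning an object with a `.sample` attribute, and the
+# `set_sigmas`/`sigmas`/`correct_steps` members are assumed to be the ones defined earlier in
+# this file and in the scheduler config):
+#
+#     scheduler.set_timesteps(num_inference_steps)
+#     scheduler.set_sigmas(num_inference_steps)
+#     for i, t in enumerate(scheduler.timesteps):
+#         sigma_t = scheduler.sigmas[i] * torch.ones(sample.shape[0], device=sample.device)
+#         for _ in range(scheduler.config.correct_steps):              # Langevin corrector
+#             model_output = model(sample, sigma_t).sample
+#             sample = scheduler.step_correct(model_output, sample).prev_sample
+#         model_output = model(sample, sigma_t).sample                 # reverse-SDE predictor
+#         output = scheduler.step_pred(model_output, t, sample)
+#         sample, sample_mean = output.prev_sample, output.prev_sample_mean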
diff --git a/my_diffusers/schedulers/scheduling_sde_vp.py b/my_diffusers/schedulers/scheduling_sde_vp.py
new file mode 100644
index 0000000000000000000000000000000000000000..66e6ec6616ab01e5ae988b21e9599a0422a9714a
--- /dev/null
+++ b/my_diffusers/schedulers/scheduling_sde_vp.py
@@ -0,0 +1,81 @@
+# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean up a bit
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin
+
+
+class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The variance preserving stochastic differential equation (SDE) scheduler.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ UNDER CONSTRUCTION
+
+ """
+
+ @register_to_config
+ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
+
+ self.sigmas = None
+ self.discrete_sigmas = None
+ self.timesteps = None
+
+ def set_timesteps(self, num_inference_steps):
+ self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+
+ def step_pred(self, score, x, t):
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # TODO(Patrick) better comments + non-PyTorch
+ # postprocess model score
+ log_mean_coeff = (
+ -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
+ )
+ std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
+ score = -score / std[:, None, None, None]
+
+ # compute
+ dt = -1.0 / len(self.timesteps)
+
+ beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
+ drift = -0.5 * beta_t[:, None, None, None] * x
+ diffusion = torch.sqrt(beta_t)
+ drift = drift - diffusion[:, None, None, None] ** 2 * score
+ x_mean = x + drift * dt
+
+ # add noise
+ noise = torch.randn_like(x)
+ x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise
+
+ return x, x_mean
+
+ def __len__(self):
+ return self.config.num_train_timesteps
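+
+
+# For reference, `step_pred` above is one Euler-Maruyama step of the reverse-time VP SDE from
+# Song et al. (https://arxiv.org/abs/2011.13456). In the notation of the code, with
+# dt = -1 / num_inference_steps and z ~ N(0, I), the incoming `score` (a noise prediction) is
+# rescaled and the update is:
+#
+#     beta(t) = beta_min + t * (beta_max - beta_min)
+#     std(t)  = sqrt(1 - exp(2 * log_mean_coeff(t)))
+#     score   = -score / std(t)
+#     x_mean  = x + (-0.5 * beta(t) * x - beta(t) * score) * dt
+#     x       = x_mean + sqrt(beta(t)) * sqrt(-dt) * z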
diff --git a/my_diffusers/schedulers/scheduling_utils.py b/my_diffusers/schedulers/scheduling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2bcd73acf32c1e152a5d8708479731996731c6d
--- /dev/null
+++ b/my_diffusers/schedulers/scheduling_utils.py
@@ -0,0 +1,125 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+import torch
+
+from ..utils import BaseOutput
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+@dataclass
+class SchedulerOutput(BaseOutput):
+ """
+ Base class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class SchedulerMixin:
+ """
+ Mixin containing common functions for the schedulers.
+ """
+
+ config_name = SCHEDULER_CONFIG_NAME
+ ignore_for_config = ["tensor_format"]
+
+ def set_format(self, tensor_format="pt"):
+ self.tensor_format = tensor_format
+ if tensor_format == "pt":
+ for key, value in vars(self).items():
+ if isinstance(value, np.ndarray):
+ setattr(self, key, torch.from_numpy(value))
+
+ return self
+
+ def clip(self, tensor, min_value=None, max_value=None):
+ tensor_format = getattr(self, "tensor_format", "pt")
+
+ if tensor_format == "np":
+ return np.clip(tensor, min_value, max_value)
+ elif tensor_format == "pt":
+ return torch.clamp(tensor, min_value, max_value)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def log(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+
+ if tensor_format == "np":
+ return np.log(tensor)
+ elif tensor_format == "pt":
+ return torch.log(tensor)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
+ """
+ Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
+
+ Args:
+ values: an array or tensor of values to extract.
+ broadcast_array: an array with a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ Returns:
+ a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+
+ tensor_format = getattr(self, "tensor_format", "pt")
+ values = values.flatten()
+
+ while len(values.shape) < len(broadcast_array.shape):
+ values = values[..., None]
+ if tensor_format == "pt":
+ values = values.to(broadcast_array.device)
+
+ return values
+
+ def norm(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.linalg.norm(tensor)
+ elif tensor_format == "pt":
+ return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def randn_like(self, tensor, generator=None):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.random.randn(*np.shape(tensor))
+ elif tensor_format == "pt":
+ # return torch.randn_like(tensor)
+ return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def zeros_like(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.zeros_like(tensor)
+ elif tensor_format == "pt":
+ return torch.zeros_like(tensor)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
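+
+
+# Illustrative sketch of how a scheduler typically uses these mixin helpers (the `scheduler`,
+# `sigmas`, `timesteps` and `sample` names are placeholders, not part of this module):
+#
+#     scheduler.set_format("pt")                                  # convert np.ndarray attributes to torch tensors
+#     sigma = scheduler.match_shape(sigmas[timesteps], sample)    # broadcastable (batch, 1, 1, 1) view
+#     noise = scheduler.randn_like(sample, generator=generator)
+#     sample = scheduler.clip(sample, -1.0, 1.0)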
diff --git a/my_diffusers/testing_utils.py b/my_diffusers/testing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff8b6aa9b41c45b0ab77f343904bffc53fa9e9cb
--- /dev/null
+++ b/my_diffusers/testing_utils.py
@@ -0,0 +1,61 @@
+import os
+import random
+import unittest
+from distutils.util import strtobool
+
+import torch
+
+from packaging import version
+
+
+global_rng = random.Random()
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
+
+if is_torch_higher_equal_than_1_12:
+ torch_device = "mps" if torch.backends.mps.is_available() else torch_device
+
+
+def parse_flag_from_env(key, default=False):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ # KEY isn't set, default to `default`.
+ _value = default
+ else:
+ # KEY is set, convert it to True or False.
+ try:
+ _value = strtobool(value)
+ except ValueError:
+ # More values are supported, but let's keep the message simple.
+ raise ValueError(f"If set, {key} must be yes or no.")
+ return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+
+
+def floats_tensor(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ total_dims = 1
+ for dim in shape:
+ total_dims *= dim
+
+ values = []
+ for _ in range(total_dims):
+ values.append(rng.random() * scale)
+
+ return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
+
+
+def slow(test_case):
+ """
+ Decorator marking a test as slow.
+
+ Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
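+
+
+# Illustrative usage (hypothetical test case, not part of this module). Tests decorated with
+# `@slow` only run when the RUN_SLOW environment variable is truthy, e.g. `RUN_SLOW=1 pytest`:
+#
+#     class SchedulerIntegrationTests(unittest.TestCase):
+#         @slow
+#         def test_full_loop(self):
+#             sample = floats_tensor((1, 3, 32, 32), scale=0.5).to(torch_device)
+#             self.assertEqual(tuple(sample.shape), (1, 3, 32, 32))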
diff --git a/my_diffusers/training_utils.py b/my_diffusers/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa1694161fc54c7fd097abf3bcbf44c498daad4b
--- /dev/null
+++ b/my_diffusers/training_utils.py
@@ -0,0 +1,125 @@
+import copy
+import os
+import random
+
+import numpy as np
+import torch
+
+
+def enable_full_determinism(seed: int):
+ """
+ Helper function for reproducible behavior during distributed training. See
+    - https://pytorch.org/docs/stable/notes/randomness.html for PyTorch
+ """
+ # set seed first
+ set_seed(seed)
+
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def set_seed(seed: int):
+ """
+    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
+
+    Args:
+        seed (`int`): The seed to set.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # ^^ safe to call this function even if cuda is not available
+
+
+class EMAModel:
+ """
+ Exponential Moving Average of models weights
+ """
+
+ def __init__(
+ self,
+ model,
+ update_after_step=0,
+ inv_gamma=1.0,
+ power=2 / 3,
+ min_value=0.0,
+ max_value=0.9999,
+ device=None,
+ ):
+ """
+ @crowsonkb's notes on EMA Warmup:
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+ at 215.4k steps).
+ Args:
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+ power (float): Exponential factor of EMA warmup. Default: 2/3.
+ min_value (float): The minimum EMA decay rate. Default: 0.
+ """
+
+ self.averaged_model = copy.deepcopy(model).eval()
+ self.averaged_model.requires_grad_(False)
+
+ self.update_after_step = update_after_step
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.min_value = min_value
+ self.max_value = max_value
+
+ if device is not None:
+ self.averaged_model = self.averaged_model.to(device=device)
+
+ self.decay = 0.0
+ self.optimization_step = 0
+
+ def get_decay(self, optimization_step):
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ step = max(0, optimization_step - self.update_after_step - 1)
+ value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+ if step <= 0:
+ return 0.0
+
+ return max(self.min_value, min(value, self.max_value))
+
+ @torch.no_grad()
+ def step(self, new_model):
+ ema_state_dict = {}
+ ema_params = self.averaged_model.state_dict()
+
+ self.decay = self.get_decay(self.optimization_step)
+
+ for key, param in new_model.named_parameters():
+ if isinstance(param, dict):
+ continue
+ try:
+ ema_param = ema_params[key]
+ except KeyError:
+ ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
+ ema_params[key] = ema_param
+
+ if not param.requires_grad:
+ ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
+ ema_param = ema_params[key]
+ else:
+ ema_param.mul_(self.decay)
+ ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
+
+ ema_state_dict[key] = ema_param
+
+ for key, param in new_model.named_buffers():
+ ema_state_dict[key] = param
+
+ self.averaged_model.load_state_dict(ema_state_dict, strict=False)
+ self.optimization_step += 1
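+
+
+# Minimal training-loop sketch for `EMAModel` (the `model`, `optimizer`, `dataloader` and
+# `compute_loss` names are placeholders, not part of this module):
+#
+#     set_seed(0)
+#     ema = EMAModel(model, inv_gamma=1.0, power=2 / 3)
+#     for batch in dataloader:
+#         loss = compute_loss(model, batch)
+#         loss.backward()
+#         optimizer.step()
+#         optimizer.zero_grad()
+#         ema.step(model)                 # updates ema.averaged_model in place
+#
+# Sanity check on the warmup schedule: with the defaults, get_decay(31_600) is roughly
+# 1 - 31_600 ** (-2 / 3) ≈ 0.999, matching the note in the constructor docstring.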
diff --git a/my_diffusers/utils/__init__.py b/my_diffusers/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c00a28e1058fbd47451bfe48e23865876c08ed69
--- /dev/null
+++ b/my_diffusers/utils/__init__.py
@@ -0,0 +1,53 @@
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+from .import_utils import (
+ ENV_VARS_TRUE_AND_AUTO_VALUES,
+ ENV_VARS_TRUE_VALUES,
+ USE_JAX,
+ USE_TF,
+ USE_TORCH,
+ DummyObject,
+ is_flax_available,
+ is_inflect_available,
+ is_modelcards_available,
+ is_onnx_available,
+ is_scipy_available,
+ is_tf_available,
+ is_torch_available,
+ is_transformers_available,
+ is_unidecode_available,
+ requires_backends,
+)
+from .logging import get_logger
+from .outputs import BaseOutput
+
+
+logger = get_logger(__name__)
+
+
+hf_cache_home = os.path.expanduser(
+ os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
+)
+default_cache_path = os.path.join(hf_cache_home, "diffusers")
+
+
+CONFIG_NAME = "config.json"
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
+DIFFUSERS_CACHE = default_cache_path
+DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
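+
+# With the defaults above (no HF_HOME / XDG_CACHE_HOME overrides), DIFFUSERS_CACHE resolves to
+# ~/.cache/huggingface/diffusers and HF_MODULES_CACHE to ~/.cache/huggingface/modules.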
diff --git a/my_diffusers/utils/__pycache__/__init__.cpython-38.pyc b/my_diffusers/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9916fed2bc21ccf8fe62223b22636f14ecee08ac
Binary files /dev/null and b/my_diffusers/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_diffusers/utils/__pycache__/import_utils.cpython-38.pyc b/my_diffusers/utils/__pycache__/import_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3065c05535492893920b6292ba0856ccbea11b8c
Binary files /dev/null and b/my_diffusers/utils/__pycache__/import_utils.cpython-38.pyc differ
diff --git a/my_diffusers/utils/__pycache__/logging.cpython-38.pyc b/my_diffusers/utils/__pycache__/logging.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e748b571156f84f8b1086e678e06fa1e505a38b
Binary files /dev/null and b/my_diffusers/utils/__pycache__/logging.cpython-38.pyc differ
diff --git a/my_diffusers/utils/__pycache__/outputs.cpython-38.pyc b/my_diffusers/utils/__pycache__/outputs.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f9ad0fd0c4964787bb7d2f558a98b494dec6ae1c
Binary files /dev/null and b/my_diffusers/utils/__pycache__/outputs.cpython-38.pyc differ
diff --git a/my_diffusers/utils/dummy_scipy_objects.py b/my_diffusers/utils/dummy_scipy_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..3706c57541c1b7d9004957422b52cd1e2191ae68
--- /dev/null
+++ b/my_diffusers/utils/dummy_scipy_objects.py
@@ -0,0 +1,11 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class LMSDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["scipy"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["scipy"])
diff --git a/my_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py b/my_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c2aec218c40190bd2d078bfb36fc34fd4ef16c2
--- /dev/null
+++ b/my_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
@@ -0,0 +1,10 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+from ..utils import DummyObject, requires_backends
+
+
+class GradTTSPipeline(metaclass=DummyObject):
+ _backends = ["transformers", "inflect", "unidecode"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers", "inflect", "unidecode"])
diff --git a/my_diffusers/utils/dummy_transformers_and_onnx_objects.py b/my_diffusers/utils/dummy_transformers_and_onnx_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e34b5ce0b69472df7e2c41de40476619d53dee9
--- /dev/null
+++ b/my_diffusers/utils/dummy_transformers_and_onnx_objects.py
@@ -0,0 +1,11 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class StableDiffusionOnnxPipeline(metaclass=DummyObject):
+ _backends = ["transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers", "onnx"])
diff --git a/my_diffusers/utils/dummy_transformers_objects.py b/my_diffusers/utils/dummy_transformers_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..e05eb814d17b3a49eb550a89dfd13ee24fdda134
--- /dev/null
+++ b/my_diffusers/utils/dummy_transformers_objects.py
@@ -0,0 +1,32 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class LDMTextToImagePipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
+
+
+class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
+
+
+class StableDiffusionInpaintPipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
+
+
+class StableDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
diff --git a/my_diffusers/utils/import_utils.py b/my_diffusers/utils/import_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f5e95ada51da97ac67e1dc62538b6eed8784bce
--- /dev/null
+++ b/my_diffusers/utils/import_utils.py
@@ -0,0 +1,274 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Import utilities: Utilities related to imports and our lazy inits.
+"""
+import importlib.util
+import os
+import sys
+from collections import OrderedDict
+
+from packaging import version
+
+from . import logging
+
+
+# The package importlib_metadata is in a different place, depending on the python version.
+if sys.version_info < (3, 8):
+ import importlib_metadata
+else:
+ import importlib.metadata as importlib_metadata
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
+
+USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+
+_torch_version = "N/A"
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+ _torch_available = importlib.util.find_spec("torch") is not None
+ if _torch_available:
+ try:
+ _torch_version = importlib_metadata.version("torch")
+ logger.info(f"PyTorch version {_torch_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _torch_available = False
+else:
+ logger.info("Disabling PyTorch because USE_TF is set")
+ _torch_available = False
+
+
+_tf_version = "N/A"
+if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+ _tf_available = importlib.util.find_spec("tensorflow") is not None
+ if _tf_available:
+ candidates = (
+ "tensorflow",
+ "tensorflow-cpu",
+ "tensorflow-gpu",
+ "tf-nightly",
+ "tf-nightly-cpu",
+ "tf-nightly-gpu",
+ "intel-tensorflow",
+ "intel-tensorflow-avx512",
+ "tensorflow-rocm",
+ "tensorflow-macos",
+ "tensorflow-aarch64",
+ )
+ _tf_version = None
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+ for pkg in candidates:
+ try:
+ _tf_version = importlib_metadata.version(pkg)
+ break
+ except importlib_metadata.PackageNotFoundError:
+ pass
+ _tf_available = _tf_version is not None
+ if _tf_available:
+ if version.parse(_tf_version) < version.parse("2"):
+ logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.")
+ _tf_available = False
+ else:
+ logger.info(f"TensorFlow version {_tf_version} available.")
+else:
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
+ _tf_available = False
+
+
+if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+ _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
+ if _flax_available:
+ try:
+ _jax_version = importlib_metadata.version("jax")
+ _flax_version = importlib_metadata.version("flax")
+ logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _flax_available = False
+else:
+ _flax_available = False
+
+
+_transformers_available = importlib.util.find_spec("transformers") is not None
+try:
+ _transformers_version = importlib_metadata.version("transformers")
+ logger.debug(f"Successfully imported transformers version {_transformers_version}")
+except importlib_metadata.PackageNotFoundError:
+ _transformers_available = False
+
+
+_inflect_available = importlib.util.find_spec("inflect") is not None
+try:
+ _inflect_version = importlib_metadata.version("inflect")
+ logger.debug(f"Successfully imported inflect version {_inflect_version}")
+except importlib_metadata.PackageNotFoundError:
+ _inflect_available = False
+
+
+_unidecode_available = importlib.util.find_spec("unidecode") is not None
+try:
+ _unidecode_version = importlib_metadata.version("unidecode")
+ logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
+except importlib_metadata.PackageNotFoundError:
+ _unidecode_available = False
+
+
+_modelcards_available = importlib.util.find_spec("modelcards") is not None
+try:
+ _modelcards_version = importlib_metadata.version("modelcards")
+ logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
+except importlib_metadata.PackageNotFoundError:
+ _modelcards_available = False
+
+
+_onnx_available = importlib.util.find_spec("onnxruntime") is not None
+try:
+ _onnxruntime_version = importlib_metadata.version("onnxruntime")
+ logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
+except importlib_metadata.PackageNotFoundError:
+ _onnx_available = False
+
+
+_scipy_available = importlib.util.find_spec("scipy") is not None
+try:
+ _scipy_version = importlib_metadata.version("scipy")
+    logger.debug(f"Successfully imported scipy version {_scipy_version}")
+except importlib_metadata.PackageNotFoundError:
+ _scipy_available = False
+
+
+def is_torch_available():
+ return _torch_available
+
+
+def is_tf_available():
+ return _tf_available
+
+
+def is_flax_available():
+ return _flax_available
+
+
+def is_transformers_available():
+ return _transformers_available
+
+
+def is_inflect_available():
+ return _inflect_available
+
+
+def is_unidecode_available():
+ return _unidecode_available
+
+
+def is_modelcards_available():
+ return _modelcards_available
+
+
+def is_onnx_available():
+ return _onnx_available
+
+
+def is_scipy_available():
+ return _scipy_available
+
+
+# docstyle-ignore
+FLAX_IMPORT_ERROR = """
+{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
+installation page: https://github.com/google/flax and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+INFLECT_IMPORT_ERROR = """
+{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
+inflect`
+"""
+
+# docstyle-ignore
+PYTORCH_IMPORT_ERROR = """
+{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
+installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+ONNX_IMPORT_ERROR = """
+{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
+install onnxruntime`
+"""
+
+# docstyle-ignore
+SCIPY_IMPORT_ERROR = """
+{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
+scipy`
+"""
+
+# docstyle-ignore
+TENSORFLOW_IMPORT_ERROR = """
+{0} requires the TensorFlow library but it was not found in your environment. Check out the instructions on the
+installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+TRANSFORMERS_IMPORT_ERROR = """
+{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
+install transformers`
+"""
+
+# docstyle-ignore
+UNIDECODE_IMPORT_ERROR = """
+{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
+Unidecode`
+"""
+
+
+BACKENDS_MAPPING = OrderedDict(
+ [
+ ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
+ ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
+ ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
+ ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
+ ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
+ ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
+ ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
+ ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
+ ]
+)
+
+
+def requires_backends(obj, backends):
+ if not isinstance(backends, (list, tuple)):
+ backends = [backends]
+
+ name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+ checks = (BACKENDS_MAPPING[backend] for backend in backends)
+ failed = [msg.format(name) for available, msg in checks if not available()]
+ if failed:
+ raise ImportError("".join(failed))
+
+
+class DummyObject(type):
+ """
+ Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
+ `requires_backend` each time a user tries to access any method of that class.
+ """
+
+ def __getattr__(cls, key):
+ if key.startswith("_"):
+ return super().__getattr__(cls, key)
+ requires_backends(cls, cls._backends)
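+
+
+# Illustrative sketch of the dummy-object machinery (the class below is hypothetical; the real
+# dummies live in the autogenerated dummy_*_objects.py modules). Instantiating such a class or
+# accessing any of its non-underscore attributes while the backend is missing raises an
+# ImportError built from the matching *_IMPORT_ERROR template above:
+#
+#     class MyScipyOnlyScheduler(metaclass=DummyObject):
+#         _backends = ["scipy"]
+#
+#         def __init__(self, *args, **kwargs):
+#             requires_backends(self, ["scipy"])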
diff --git a/my_diffusers/utils/logging.py b/my_diffusers/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f2d0227b87c66205ceb3391a8e98f5f33285dc4
--- /dev/null
+++ b/my_diffusers/utils/logging.py
@@ -0,0 +1,344 @@
+# coding=utf-8
+# Copyright 2020 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Logging utilities."""
+
+import logging
+import os
+import sys
+import threading
+from logging import CRITICAL # NOQA
+from logging import DEBUG # NOQA
+from logging import ERROR # NOQA
+from logging import FATAL # NOQA
+from logging import INFO # NOQA
+from logging import NOTSET # NOQA
+from logging import WARN # NOQA
+from logging import WARNING # NOQA
+from typing import Optional
+
+from tqdm import auto as tqdm_lib
+
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+log_levels = {
+ "debug": logging.DEBUG,
+ "info": logging.INFO,
+ "warning": logging.WARNING,
+ "error": logging.ERROR,
+ "critical": logging.CRITICAL,
+}
+
+_default_log_level = logging.WARNING
+
+_tqdm_active = True
+
+
+def _get_default_logging_level():
+ """
+ If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
+ not - fall back to `_default_log_level`
+ """
+ env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
+ if env_level_str:
+ if env_level_str in log_levels:
+ return log_levels[env_level_str]
+ else:
+ logging.getLogger().warning(
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
+ f"has to be one of: { ', '.join(log_levels.keys()) }"
+ )
+ return _default_log_level
+
+
+def _get_library_name() -> str:
+
+ return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+
+ return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+
+ global _default_handler
+
+ with _lock:
+ if _default_handler:
+ # This library has already configured the library root logger.
+ return
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
+ _default_handler.flush = sys.stderr.flush
+
+ # Apply our default configuration to the library root logger.
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_get_default_logging_level())
+ library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+
+ global _default_handler
+
+ with _lock:
+ if not _default_handler:
+ return
+
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.removeHandler(_default_handler)
+ library_root_logger.setLevel(logging.NOTSET)
+ _default_handler = None
+
+
+def get_log_levels_dict():
+ return log_levels
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+ """
+ Return a logger with the specified name.
+
+ This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
+ """
+
+ if name is None:
+ name = _get_library_name()
+
+ _configure_library_root_logger()
+ return logging.getLogger(name)
+
+
+def get_verbosity() -> int:
+ """
+ Return the current level for the 🤗 Diffusers' root logger as an int.
+
+ Returns:
+ `int`: The logging level.
+
+
+
+    🤗 Diffusers has the following logging levels:
+
+ - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+ - 40: `diffusers.logging.ERROR`
+ - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+ - 20: `diffusers.logging.INFO`
+ - 10: `diffusers.logging.DEBUG`
+
+ """
+
+ _configure_library_root_logger()
+ return _get_library_root_logger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity: int) -> None:
+ """
+ Set the verbosity level for the 🤗 Diffusers' root logger.
+
+ Args:
+ verbosity (`int`):
+ Logging level, e.g., one of:
+
+ - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+ - `diffusers.logging.ERROR`
+ - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+ - `diffusers.logging.INFO`
+ - `diffusers.logging.DEBUG`
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_info():
+ """Set the verbosity to the `INFO` level."""
+ return set_verbosity(INFO)
+
+
+def set_verbosity_warning():
+ """Set the verbosity to the `WARNING` level."""
+ return set_verbosity(WARNING)
+
+
+def set_verbosity_debug():
+ """Set the verbosity to the `DEBUG` level."""
+ return set_verbosity(DEBUG)
+
+
+def set_verbosity_error():
+ """Set the verbosity to the `ERROR` level."""
+ return set_verbosity(ERROR)
+
+
+def disable_default_handler() -> None:
+ """Disable the default handler of the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().removeHandler(_default_handler)
+
+
+def enable_default_handler() -> None:
+ """Enable the default handler of the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().addHandler(_default_handler)
+
+
+def add_handler(handler: logging.Handler) -> None:
+ """adds a handler to the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None
+ _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+ """removes given handler from the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None and handler not in _get_library_root_logger().handlers
+ _get_library_root_logger().removeHandler(handler)
+
+
+def disable_propagation() -> None:
+ """
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = False
+
+
+def enable_propagation() -> None:
+ """
+ Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
+ double logging if the root logger has been configured.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = True
+
+
+def enable_explicit_format() -> None:
+ """
+ Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows:
+ ```
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
+ ```
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
+ handler.setFormatter(formatter)
+
+
+def reset_format() -> None:
+ """
+ Resets the formatting for HuggingFace Diffusers' loggers.
+
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ handler.setFormatter(None)
+
+
+def warning_advice(self, *args, **kwargs):
+ """
+    This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
+ warning will not be printed
+ """
+ no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
+ if no_advisory_warnings:
+ return
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_advice = warning_advice
+
+
+class EmptyTqdm:
+ """Dummy tqdm which doesn't do anything."""
+
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
+ self._iterator = args[0] if args else None
+
+ def __iter__(self):
+ return iter(self._iterator)
+
+ def __getattr__(self, _):
+ """Return empty function."""
+
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
+ return
+
+ return empty_fn
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ return
+
+
+class _tqdm_cls:
+ def __call__(self, *args, **kwargs):
+ if _tqdm_active:
+ return tqdm_lib.tqdm(*args, **kwargs)
+ else:
+ return EmptyTqdm(*args, **kwargs)
+
+ def set_lock(self, *args, **kwargs):
+ self._lock = None
+ if _tqdm_active:
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
+
+ def get_lock(self):
+ if _tqdm_active:
+ return tqdm_lib.tqdm.get_lock()
+
+
+tqdm = _tqdm_cls()
+
+
+def is_progress_bar_enabled() -> bool:
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
+ global _tqdm_active
+ return bool(_tqdm_active)
+
+
+def enable_progress_bar():
+ """Enable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = True
+
+
+def disable_progress_bar():
+ """Disable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = False
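+
+
+# Illustrative usage (not part of this module), mirroring the transformers-style logging API:
+#
+#     from my_diffusers.utils import logging
+#
+#     logging.set_verbosity_info()        # or logging.set_verbosity(logging.INFO)
+#     logging.enable_explicit_format()    # [LEVEL|file.py:line] time >> message
+#     logger = logging.get_logger(__name__)
+#     logger.info("Hello from diffusers logging")
+#     logging.disable_progress_bar()      # make the exported `tqdm` wrapper a no-op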
diff --git a/my_diffusers/utils/model_card_template.md b/my_diffusers/utils/model_card_template.md
new file mode 100644
index 0000000000000000000000000000000000000000..f19c85b0fcf2f7b07e9c3f950a9657b3f2053f21
--- /dev/null
+++ b/my_diffusers/utils/model_card_template.md
@@ -0,0 +1,50 @@
+---
+{{ card_data }}
+---
+
+
+
+# {{ model_name | default("Diffusion Model") }}
+
+## Model description
+
+This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
+on the `{{ dataset_name }}` dataset.
+
+## Intended uses & limitations
+
+#### How to use
+
+```python
+# TODO: add an example code snippet for running this diffusion pipeline
+```
+
+#### Limitations and bias
+
+[TODO: provide examples of latent issues and potential remediations]
+
+## Training data
+
+[TODO: describe the data used to train the model]
+
+### Training hyperparameters
+
+The following hyperparameters were used during training:
+- learning_rate: {{ learning_rate }}
+- train_batch_size: {{ train_batch_size }}
+- eval_batch_size: {{ eval_batch_size }}
+- gradient_accumulation_steps: {{ gradient_accumulation_steps }}
+- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
+- lr_scheduler: {{ lr_scheduler }}
+- lr_warmup_steps: {{ lr_warmup_steps }}
+- ema_inv_gamma: {{ ema_inv_gamma }}
+- ema_power: {{ ema_power }}
+- ema_max_decay: {{ ema_max_decay }}
+- mixed_precision: {{ mixed_precision }}
+
+### Training results
+
+📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)
+
+
diff --git a/my_diffusers/utils/outputs.py b/my_diffusers/utils/outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..b02f62d02d0322401fd9926aca9f792a4696cc1e
--- /dev/null
+++ b/my_diffusers/utils/outputs.py
@@ -0,0 +1,109 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Generic utilities
+"""
+
+import warnings
+from collections import OrderedDict
+from dataclasses import fields
+from typing import Any, Tuple
+
+import numpy as np
+
+from .import_utils import is_torch_available
+
+
+def is_tensor(x):
+ """
+ Tests if `x` is a `torch.Tensor` or `np.ndarray`.
+ """
+ if is_torch_available():
+ import torch
+
+ if isinstance(x, torch.Tensor):
+ return True
+
+ return isinstance(x, np.ndarray)
+
+
+class BaseOutput(OrderedDict):
+ """
+ Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
+ tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
+ python dictionary.
+
+
+
+ You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
+ before.
+
+
+ """
+
+ def __post_init__(self):
+ class_fields = fields(self)
+
+ # Safety and consistency checks
+ if not len(class_fields):
+ raise ValueError(f"{self.__class__.__name__} has no fields.")
+
+ for field in class_fields:
+ v = getattr(self, field.name)
+ if v is not None:
+ self[field.name] = v
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __getitem__(self, k):
+ if isinstance(k, str):
+ inner_dict = {k: v for (k, v) in self.items()}
+ if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample":
+ warnings.warn(
+                    "The keyword 'sample' is deprecated and will be removed in version 0.4.0. Please use `.images` or
+ " `'images'` instead.",
+ DeprecationWarning,
+ )
+ return inner_dict["images"]
+ return inner_dict[k]
+ else:
+ return self.to_tuple()[k]
+
+ def __setattr__(self, name, value):
+ if name in self.keys() and value is not None:
+ # Don't call self.__setitem__ to avoid recursion errors
+ super().__setitem__(name, value)
+ super().__setattr__(name, value)
+
+ def __setitem__(self, key, value):
+ # Will raise a KeyException if needed
+ super().__setitem__(key, value)
+ # Don't call self.__setattr__ to avoid recursion errors
+ super().__setattr__(key, value)
+
+ def to_tuple(self) -> Tuple[Any]:
+ """
+ Convert self to a tuple containing all the attributes/keys that are not `None`.
+ """
+ return tuple(self[k] for k in self.keys())
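+
+
+# Minimal sketch of a `BaseOutput` subclass (a hypothetical output class, for illustration only):
+#
+#     from dataclasses import dataclass
+#     from typing import Optional
+#
+#     @dataclass
+#     class MyPipelineOutput(BaseOutput):
+#         images: np.ndarray
+#         nsfw_content_detected: Optional[list] = None
+#
+#     out = MyPipelineOutput(images=np.zeros((1, 64, 64, 3)))
+#     out.images, out["images"], out[0]   # attribute, key and integer indexing all work
+#     out.to_tuple()                      # -> (images,) because the `None` field is dropped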
diff --git a/my_half_diffusers/__init__.py b/my_half_diffusers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf2f183c9b5dc45a3cb40a3b2408833f6966ac96
--- /dev/null
+++ b/my_half_diffusers/__init__.py
@@ -0,0 +1,60 @@
+from .utils import (
+ is_inflect_available,
+ is_onnx_available,
+ is_scipy_available,
+ is_transformers_available,
+ is_unidecode_available,
+)
+
+
+__version__ = "0.3.0"
+
+from .configuration_utils import ConfigMixin
+from .modeling_utils import ModelMixin
+from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from .onnx_utils import OnnxRuntimeModel
+from .optimization import (
+ get_constant_schedule,
+ get_constant_schedule_with_warmup,
+ get_cosine_schedule_with_warmup,
+ get_cosine_with_hard_restarts_schedule_with_warmup,
+ get_linear_schedule_with_warmup,
+ get_polynomial_decay_schedule_with_warmup,
+ get_scheduler,
+)
+from .pipeline_utils import DiffusionPipeline
+from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+from .schedulers import (
+ DDIMScheduler,
+ DDPMScheduler,
+ KarrasVeScheduler,
+ PNDMScheduler,
+ SchedulerMixin,
+ ScoreSdeVeScheduler,
+)
+from .utils import logging
+
+
+if is_scipy_available():
+ from .schedulers import LMSDiscreteScheduler
+else:
+ from .utils.dummy_scipy_objects import * # noqa F403
+
+from .training_utils import EMAModel
+
+
+if is_transformers_available():
+ from .pipelines import (
+ LDMTextToImagePipeline,
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+ )
+else:
+ from .utils.dummy_transformers_objects import * # noqa F403
+
+
+if is_transformers_available() and is_onnx_available():
+ from .pipelines import StableDiffusionOnnxPipeline
+else:
+ from .utils.dummy_transformers_and_onnx_objects import * # noqa F403
diff --git a/my_half_diffusers/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4a340c7ac2f492270860c712c1cc844de062995
Binary files /dev/null and b/my_half_diffusers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/__pycache__/configuration_utils.cpython-38.pyc b/my_half_diffusers/__pycache__/configuration_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cddea59df5762c516faef7acbc1001dddba061cd
Binary files /dev/null and b/my_half_diffusers/__pycache__/configuration_utils.cpython-38.pyc differ
diff --git a/my_half_diffusers/__pycache__/modeling_utils.cpython-38.pyc b/my_half_diffusers/__pycache__/modeling_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6ea4d9b06efa90c24a6c007d98c0109a820417c
Binary files /dev/null and b/my_half_diffusers/__pycache__/modeling_utils.cpython-38.pyc differ
diff --git a/my_half_diffusers/__pycache__/onnx_utils.cpython-38.pyc b/my_half_diffusers/__pycache__/onnx_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09e521b4c33b617753ee815f7b11956e6bb63e3e
Binary files /dev/null and b/my_half_diffusers/__pycache__/onnx_utils.cpython-38.pyc differ
diff --git a/my_half_diffusers/__pycache__/optimization.cpython-38.pyc b/my_half_diffusers/__pycache__/optimization.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d09eb1e7cf1f82d60c3df53baf549881338c2292
Binary files /dev/null and b/my_half_diffusers/__pycache__/optimization.cpython-38.pyc differ
diff --git a/my_half_diffusers/__pycache__/pipeline_utils.cpython-38.pyc b/my_half_diffusers/__pycache__/pipeline_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a881377671b089e8343f2c719b3ad96ae926acf
Binary files /dev/null and b/my_half_diffusers/__pycache__/pipeline_utils.cpython-38.pyc differ
diff --git a/my_half_diffusers/__pycache__/training_utils.cpython-38.pyc b/my_half_diffusers/__pycache__/training_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc3cc500b84874c9270f0f7df38158028bb5bf27
Binary files /dev/null and b/my_half_diffusers/__pycache__/training_utils.cpython-38.pyc differ
diff --git a/my_half_diffusers/commands/__init__.py b/my_half_diffusers/commands/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..902bd46cedc6f2df785c1dc5d2e6bd8ef7c69ca6
--- /dev/null
+++ b/my_half_diffusers/commands/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from argparse import ArgumentParser
+
+
+class BaseDiffusersCLICommand(ABC):
+ @staticmethod
+ @abstractmethod
+ def register_subcommand(parser: ArgumentParser):
+ raise NotImplementedError()
+
+ @abstractmethod
+ def run(self):
+ raise NotImplementedError()
diff --git a/my_half_diffusers/commands/diffusers_cli.py b/my_half_diffusers/commands/diffusers_cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..30084e55ba4eeec79c87a99eae3e60a6233dc556
--- /dev/null
+++ b/my_half_diffusers/commands/diffusers_cli.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from argparse import ArgumentParser
+
+from .env import EnvironmentCommand
+
+
+def main():
+    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
+
+ # Register commands
+ EnvironmentCommand.register_subcommand(commands_parser)
+
+ # Let's go
+ args = parser.parse_args()
+
+ if not hasattr(args, "func"):
+ parser.print_help()
+ exit(1)
+
+ # Run
+ service = args.func(args)
+ service.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/my_half_diffusers/commands/env.py b/my_half_diffusers/commands/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..81a878bff6688d3c510b53c60ac9d0e51e4aebcc
--- /dev/null
+++ b/my_half_diffusers/commands/env.py
@@ -0,0 +1,70 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import platform
+from argparse import ArgumentParser
+
+import huggingface_hub
+
+from .. import __version__ as version
+from ..utils import is_torch_available, is_transformers_available
+from . import BaseDiffusersCLICommand
+
+
+def info_command_factory(_):
+ return EnvironmentCommand()
+
+
+class EnvironmentCommand(BaseDiffusersCLICommand):
+ @staticmethod
+ def register_subcommand(parser: ArgumentParser):
+ download_parser = parser.add_parser("env")
+ download_parser.set_defaults(func=info_command_factory)
+
+ def run(self):
+ hub_version = huggingface_hub.__version__
+
+ pt_version = "not installed"
+ pt_cuda_available = "NA"
+ if is_torch_available():
+ import torch
+
+ pt_version = torch.__version__
+ pt_cuda_available = torch.cuda.is_available()
+
+ transformers_version = "not installed"
+        if is_transformers_available():
+ import transformers
+
+ transformers_version = transformers.__version__
+
+ info = {
+ "`diffusers` version": version,
+ "Platform": platform.platform(),
+ "Python version": platform.python_version(),
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
+ "Huggingface_hub version": hub_version,
+ "Transformers version": transformers_version,
+ "Using GPU in script?": "",
+ "Using distributed or parallel set-up in script?": "",
+ }
+
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
+ print(self.format_dict(info))
+
+ return info
+
+ @staticmethod
+ def format_dict(d):
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
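+
+
+# Running `diffusers-cli env` (the subcommand registered in diffusers_cli.py above) prints the
+# output of `format_dict`, one `- key: value` line per entry, for example:
+#
+#     - `diffusers` version: 0.3.0
+#     - Platform: Linux-...
+#     - Python version: 3.8...
+#
+# The exact values depend on the local environment.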
diff --git a/my_half_diffusers/configuration_utils.py b/my_half_diffusers/configuration_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbe75f3f1441d3df5e2fe1a88aa758c51040c05c
--- /dev/null
+++ b/my_half_diffusers/configuration_utils.py
@@ -0,0 +1,403 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Configuration base class and utilities."""
+import functools
+import inspect
+import json
+import os
+import re
+from collections import OrderedDict
+from typing import Any, Dict, Tuple, Union
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from . import __version__
+from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
+
+
+logger = logging.get_logger(__name__)
+
+_re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+class ConfigMixin:
+ r"""
+ Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
+ methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
+ - [`~ConfigMixin.from_config`]
+ - [`~ConfigMixin.save_config`]
+
+ Class attributes:
+        - **config_name** (`str`) -- A filename under which the config should be stored when calling
+          [`~ConfigMixin.save_config`] (should be overridden by parent class).
+        - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+          overridden by parent class).
+ """
+ config_name = None
+ ignore_for_config = []
+
+ def register_to_config(self, **kwargs):
+ if self.config_name is None:
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+ kwargs["_class_name"] = self.__class__.__name__
+ kwargs["_diffusers_version"] = __version__
+
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ if not hasattr(self, "_internal_dict"):
+ internal_dict = kwargs
+ else:
+ previous_dict = dict(self._internal_dict)
+ internal_dict = {**self._internal_dict, **kwargs}
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+ self._internal_dict = FrozenDict(internal_dict)
+
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~ConfigMixin.from_config`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
+
+ self.to_json_file(output_config_file)
+ logger.info(f"ConfigMixinuration saved in {output_config_file}")
+
+ @classmethod
+ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+ r"""
+ Instantiate a Python class from a pre-defined JSON-file.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+ Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+ as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+ checkpoint with 3 labels).
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+
+
+        Passing `use_auth_token=True` is required when you want to use a private model.
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+ use this method in a firewalled environment.
+
+
+
+ """
+ config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+
+ init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
+
+ model = cls(**init_dict)
+
+ if return_unused_kwargs:
+ return model, unused_kwargs
+ else:
+ return model
+
+ @classmethod
+ def get_config_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {"file_type": "config"}
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ if cls.config_name is None:
+ raise ValueError(
+ "`self.config_name` is not defined. Note that one should not load a config from "
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+ )
+
+ if os.path.isfile(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+ # Load from a PyTorch checkpoint
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ ):
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ config_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=cls.config_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+ " login` and pass `use_auth_token=True`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+ " this model name. Check the model page at"
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+ " run the library in offline mode at"
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a {cls.config_name} file"
+ )
+
+ try:
+ # Load config dict
+ config_dict = cls._dict_from_json_file(config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+ return config_dict
+
+ @classmethod
+ def extract_init_dict(cls, config_dict, **kwargs):
+ expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
+ expected_keys.remove("self")
+ # remove general kwargs if present in dict
+ if "kwargs" in expected_keys:
+ expected_keys.remove("kwargs")
+ # remove keys to be ignored
+ if len(cls.ignore_for_config) > 0:
+ expected_keys = expected_keys - set(cls.ignore_for_config)
+ init_dict = {}
+ for key in expected_keys:
+ if key in kwargs:
+ # overwrite key
+ init_dict[key] = kwargs.pop(key)
+ elif key in config_dict:
+ # use value from config dict
+ init_dict[key] = config_dict.pop(key)
+
+        # `dict.update` returns None, so merge the leftovers explicitly to surface the unused kwargs
+        unused_kwargs = {**config_dict, **kwargs}
+
+ passed_keys = set(init_dict.keys())
+ if len(expected_keys - passed_keys) > 0:
+ logger.warning(
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
+ )
+
+ return init_dict, unused_kwargs
+
+ @classmethod
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @property
+ def config(self) -> Dict[str, Any]:
+ return self._internal_dict
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this configuration instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+
+class FrozenDict(OrderedDict):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for key, value in self.items():
+ setattr(self, key, value)
+
+ self.__frozen = True
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __setattr__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setitem__(name, value)
+
+
+def register_to_config(init):
+ r"""
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
+    shouldn't be registered in the config, use the `ignore_for_config` class variable.
+
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
+ """
+
+ @functools.wraps(init)
+ def inner_init(self, *args, **kwargs):
+ # Ignore private kwargs in the init.
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+ init(self, *args, **init_kwargs)
+ if not isinstance(self, ConfigMixin):
+ raise RuntimeError(
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+ "not inherit from `ConfigMixin`."
+ )
+
+ ignore = getattr(self, "ignore_for_config", [])
+ # Get positional arguments aligned with kwargs
+ new_kwargs = {}
+ signature = inspect.signature(init)
+ parameters = {
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+ }
+ for arg, name in zip(args, parameters.keys()):
+ new_kwargs[name] = arg
+
+ # Then add all kwargs
+ new_kwargs.update(
+ {
+ k: init_kwargs.get(k, default)
+ for k, default in parameters.items()
+ if k not in ignore and k not in new_kwargs
+ }
+ )
+ getattr(self, "register_to_config")(**new_kwargs)
+
+ return inner_init
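To make the pieces above concrete, here is a minimal, hedged sketch of how a class plugs into `ConfigMixin` and the `register_to_config` decorator, assuming this package is importable as `my_half_diffusers`; the class and file names below are illustrative, not part of this repo.

```python
from my_half_diffusers.configuration_utils import ConfigMixin, register_to_config


class MyScheduler(ConfigMixin):
    # Filename used by save_config / from_config (illustrative name).
    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        # The decorator records these arguments in self.config (a FrozenDict),
        # so nothing else is needed here for the sketch.
        pass


scheduler = MyScheduler(num_train_timesteps=500)
scheduler.save_config("./my_scheduler")            # writes ./my_scheduler/scheduler_config.json
restored = MyScheduler.from_config("./my_scheduler")
print(restored.config["num_train_timesteps"])      # 500
```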
diff --git a/my_half_diffusers/dependency_versions_check.py b/my_half_diffusers/dependency_versions_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf863222a52fd60a15a95be0fbd6391acd3ba6d
--- /dev/null
+++ b/my_half_diffusers/dependency_versions_check.py
@@ -0,0 +1,47 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from .dependency_versions_table import deps
+from .utils.versions import require_version, require_version_core
+
+
+# define which module versions we always want to check at run time
+# (usually the ones defined in `install_requires` in setup.py)
+#
+# order specific notes:
+# - tqdm must be checked before tokenizers
+
+pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
+if sys.version_info < (3, 7):
+ pkgs_to_check_at_runtime.append("dataclasses")
+if sys.version_info < (3, 8):
+ pkgs_to_check_at_runtime.append("importlib_metadata")
+
+for pkg in pkgs_to_check_at_runtime:
+ if pkg in deps:
+ if pkg == "tokenizers":
+ # must be loaded here, or else tqdm check may fail
+ from .utils import is_tokenizers_available
+
+ if not is_tokenizers_available():
+ continue # not required, check version only if installed
+
+ require_version_core(deps[pkg])
+ else:
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
+def dep_version_check(pkg, hint=None):
+ require_version(deps[pkg], hint)
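`dep_version_check` is a thin wrapper that forwards the pinned spec from the deps table to `require_version`. A small, hedged usage sketch of the underlying call, assuming the package and `torch` are installed:

```python
from my_half_diffusers.utils.versions import require_version

# Raises if the installed torch does not satisfy the spec; this mirrors what
# dep_version_check("torch") would forward from the deps table below.
require_version("torch>=1.4", hint="Try: pip install --upgrade torch")
```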
diff --git a/my_half_diffusers/dependency_versions_table.py b/my_half_diffusers/dependency_versions_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..74c5331e5af63fbab6e583da377c811e00791391
--- /dev/null
+++ b/my_half_diffusers/dependency_versions_table.py
@@ -0,0 +1,26 @@
+# THIS FILE HAS BEEN AUTOGENERATED. To update:
+# 1. modify the `_deps` dict in setup.py
+# 2. run `make deps_table_update`
+deps = {
+ "Pillow": "Pillow",
+ "accelerate": "accelerate>=0.11.0",
+ "black": "black==22.3",
+ "datasets": "datasets",
+ "filelock": "filelock",
+ "flake8": "flake8>=3.8.3",
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
+ "huggingface-hub": "huggingface-hub>=0.8.1",
+ "importlib_metadata": "importlib_metadata",
+ "isort": "isort>=5.5.4",
+ "modelcards": "modelcards==0.1.4",
+ "numpy": "numpy",
+ "pytest": "pytest",
+ "pytest-timeout": "pytest-timeout",
+ "pytest-xdist": "pytest-xdist",
+ "scipy": "scipy",
+ "regex": "regex!=2019.12.17",
+ "requests": "requests",
+ "tensorboard": "tensorboard",
+ "torch": "torch>=1.4",
+ "transformers": "transformers>=4.21.0",
+}
diff --git a/my_half_diffusers/dynamic_modules_utils.py b/my_half_diffusers/dynamic_modules_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ebf916e7af5768be3d3dc9984e5c2a066c5b4a2
--- /dev/null
+++ b/my_half_diffusers/dynamic_modules_utils.py
@@ -0,0 +1,335 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities to dynamically load objects from the Hub."""
+
+import importlib
+import os
+import re
+import shutil
+import sys
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+from huggingface_hub import cached_download
+
+from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def init_hf_modules():
+ """
+ Creates the cache directory for modules with an init, and adds it to the Python path.
+ """
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
+ if HF_MODULES_CACHE in sys.path:
+ return
+
+ sys.path.append(HF_MODULES_CACHE)
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def create_dynamic_module(name: Union[str, os.PathLike]):
+ """
+ Creates a dynamic module in the cache directory for modules.
+ """
+ init_hf_modules()
+ dynamic_module_path = Path(HF_MODULES_CACHE) / name
+ # If the parent module does not exist yet, recursively create it.
+ if not dynamic_module_path.parent.exists():
+ create_dynamic_module(dynamic_module_path.parent)
+ os.makedirs(dynamic_module_path, exist_ok=True)
+ init_path = dynamic_module_path / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def get_relative_imports(module_file):
+ """
+ Get the list of modules that are relatively imported in a module file.
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ with open(module_file, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import .xxx`
+    relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
+    # Imports of the form `from .xxx import yyy`
+    relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
+ # Unique-ify
+ return list(set(relative_imports))
+
+
+def get_relative_import_files(module_file):
+ """
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
+ imports (if a imports b and b imports c, it will return module files for b and c).
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ no_change = False
+ files_to_check = [module_file]
+ all_relative_imports = []
+
+ # Let's recurse through all relative imports
+ while not no_change:
+ new_imports = []
+ for f in files_to_check:
+ new_imports.extend(get_relative_imports(f))
+
+ module_path = Path(module_file).parent
+ new_import_files = [str(module_path / m) for m in new_imports]
+ new_import_files = [f for f in new_import_files if f not in all_relative_imports]
+ files_to_check = [f"{f}.py" for f in new_import_files]
+
+ no_change = len(new_import_files) == 0
+ all_relative_imports.extend(files_to_check)
+
+ return all_relative_imports
+
+
+def check_imports(filename):
+ """
+ Check if the current Python environment contains all the libraries that are imported in a file.
+ """
+ with open(filename, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import xxx`
+    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+    # Imports of the form `from xxx import yyy`
+    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+ # Only keep the top-level module
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
+
+ # Unique-ify and test we got them all
+ imports = list(set(imports))
+ missing_packages = []
+ for imp in imports:
+ try:
+ importlib.import_module(imp)
+ except ImportError:
+ missing_packages.append(imp)
+
+ if len(missing_packages) > 0:
+ raise ImportError(
+ "This modeling file requires the following packages that were not found in your environment: "
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
+ )
+
+ return get_relative_imports(filename)
+
+
+def get_class_in_module(class_name, module_path):
+ """
+ Import a module on the cache directory for modules and extract a class from it.
+ """
+ module_path = module_path.replace(os.path.sep, ".")
+ module = importlib.import_module(module_path)
+ return getattr(module, class_name)
+
+
+def get_cached_module_file(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+):
+ """
+    Prepares and downloads a module from a local folder or a remote repo and returns its path inside the dynamic
+    module cache.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+
+
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+
+
+ Returns:
+ `str`: The path to the module inside the cache.
+ """
+    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
+ submodule = "local"
+
+ if os.path.isfile(module_file_or_url):
+ resolved_module_file = module_file_or_url
+ else:
+ try:
+ # Load from URL or cache if already cached
+ resolved_module_file = cached_download(
+ module_file_or_url,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ )
+
+ except EnvironmentError:
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+ raise
+
+ # Check we have all the requirements in our environment
+ modules_needed = check_imports(resolved_module_file)
+
+ # Now we move the module inside our cached dynamic modules.
+ full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
+ create_dynamic_module(full_submodule)
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
+ # We always copy local files (we could hash the file to see if there was a change, and give them the name of
+ # that hash, to only copy when there is a modification but it seems overkill for now).
+ # The only reason we do the copy is to avoid putting too many folders in sys.path.
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ for module_needed in modules_needed:
+ module_needed = f"{module_needed}.py"
+ shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+ return os.path.join(full_submodule, module_file)
+
+
+def get_class_from_dynamic_module(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ class_name: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ **kwargs,
+):
+ """
+ Extracts a class from a module file, present in the local folder or repository of a model.
+
+
+
+ Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
+ therefore only be called on trusted repos.
+
+
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ class_name (`str`):
+ The name of the class to import in the module.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+
+
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+
+
+ Returns:
+ `type`: The class, dynamically imported from the module.
+
+ Examples:
+
+ ```python
+ # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
+ # module.
+ cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
+ ```"""
+ # And lastly we get the class inside our newly created module
+ final_module = get_cached_module_file(
+ pretrained_model_name_or_path,
+ module_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
+ return get_class_in_module(class_name, final_module.replace(".py", ""))
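A quick, hedged illustration of the import scanners above, run from the repository root against this package's own `models/__init__.py` (whose relative imports appear later in this diff):

```python
from my_half_diffusers.dynamic_modules_utils import check_imports, get_relative_imports

module_file = "my_half_diffusers/models/__init__.py"
print(get_relative_imports(module_file))
# e.g. ['unet_2d', 'unet_2d_condition', 'vae']  (set order may vary)

# check_imports() additionally verifies that the file's absolute imports resolve
# in the current environment, then returns the same relative-import list.
print(check_imports(module_file))
```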
diff --git a/my_half_diffusers/hub_utils.py b/my_half_diffusers/hub_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c07329e36fe7a8826b0f1fb22396819b220e1b58
--- /dev/null
+++ b/my_half_diffusers/hub_utils.py
@@ -0,0 +1,197 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+from pathlib import Path
+from typing import Optional
+
+from huggingface_hub import HfFolder, Repository, whoami
+
+from .pipeline_utils import DiffusionPipeline
+from .utils import is_modelcards_available, logging
+
+
+if is_modelcards_available():
+ from modelcards import CardData, ModelCard
+
+
+logger = logging.get_logger(__name__)
+
+
+MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
+def init_git_repo(args, at_init: bool = False):
+ """
+    Initializes a git repo in `args.hub_model_id`.
+
+    Args:
+ at_init (`bool`, *optional*, defaults to `False`):
+            Whether this function is called before any training or not. If `args.overwrite_output_dir` is `True`
+            and `at_init` is `True`, the path to the repo (which is `args.output_dir`) might be wiped out.
+ """
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+ return
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
+ use_auth_token = True if hub_token is None else hub_token
+ if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
+ repo_name = Path(args.output_dir).absolute().name
+ else:
+ repo_name = args.hub_model_id
+ if "/" not in repo_name:
+ repo_name = get_full_repo_name(repo_name, token=hub_token)
+
+ try:
+ repo = Repository(
+ args.output_dir,
+ clone_from=repo_name,
+ use_auth_token=use_auth_token,
+ private=args.hub_private_repo,
+ )
+ except EnvironmentError:
+ if args.overwrite_output_dir and at_init:
+ # Try again after wiping output_dir
+ shutil.rmtree(args.output_dir)
+ repo = Repository(
+ args.output_dir,
+ clone_from=repo_name,
+ use_auth_token=use_auth_token,
+ )
+ else:
+ raise
+
+ repo.git_pull()
+
+ # By default, ignore the checkpoint folders
+ if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
+ with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
+ writer.writelines(["checkpoint-*/"])
+
+ return repo
+
+
+def push_to_hub(
+ args,
+ pipeline: DiffusionPipeline,
+ repo: Repository,
+ commit_message: Optional[str] = "End of training",
+ blocking: bool = True,
+ **kwargs,
+) -> str:
+ """
+    Uploads the `pipeline` checkpoint to the 🤗 model hub on the repo *args.hub_model_id*.
+
+    Parameters:
+ commit_message (`str`, *optional*, defaults to `"End of training"`):
+ Message to commit while pushing.
+ blocking (`bool`, *optional*, defaults to `True`):
+ Whether the function should return only when the `git push` has finished.
+ kwargs:
+ Additional keyword arguments passed along to [`create_model_card`].
+ Returns:
+ The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
+ commit and an object to track the progress of the commit if `blocking=True`
+ """
+
+ if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
+ model_name = Path(args.output_dir).name
+ else:
+ model_name = args.hub_model_id.split("/")[-1]
+
+ output_dir = args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ logger.info(f"Saving pipeline checkpoint to {output_dir}")
+ pipeline.save_pretrained(output_dir)
+
+ # Only push from one node.
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+ return
+
+ # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
+ if (
+ blocking
+ and len(repo.command_queue) > 0
+ and repo.command_queue[-1] is not None
+ and not repo.command_queue[-1].is_done
+ ):
+ repo.command_queue[-1]._process.kill()
+
+ git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
+ # push separately the model card to be independent from the rest of the model
+ create_model_card(args, model_name=model_name)
+ try:
+ repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
+ except EnvironmentError as exc:
+ logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
+
+ return git_head_commit_url
+
+
+def create_model_card(args, model_name):
+    if not is_modelcards_available():
+ raise ValueError(
+ "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
+ " install the package with `pip install modelcards`."
+ )
+
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+ return
+
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
+ repo_name = get_full_repo_name(model_name, token=hub_token)
+
+ model_card = ModelCard.from_template(
+ card_data=CardData( # Card metadata object that will be converted to YAML block
+ language="en",
+ license="apache-2.0",
+ library_name="diffusers",
+ tags=[],
+ datasets=args.dataset_name,
+ metrics=[],
+ ),
+ template_path=MODEL_CARD_TEMPLATE_PATH,
+ model_name=model_name,
+ repo_name=repo_name,
+ dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
+ learning_rate=args.learning_rate,
+ train_batch_size=args.train_batch_size,
+ eval_batch_size=args.eval_batch_size,
+ gradient_accumulation_steps=args.gradient_accumulation_steps
+ if hasattr(args, "gradient_accumulation_steps")
+ else None,
+ adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
+ adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
+ adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
+ adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
+ lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
+ lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
+ ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
+ ema_power=args.ema_power if hasattr(args, "ema_power") else None,
+ ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
+ mixed_precision=args.mixed_precision,
+ )
+
+ card_path = os.path.join(args.output_dir, "README.md")
+ model_card.save(card_path)
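A hedged sketch of how a training script would wire these helpers together. The `argparse.Namespace` below only carries the attributes the helpers read, all values are placeholders, and actually running it requires a cached `huggingface-cli login` token plus network access.

```python
from argparse import Namespace

from my_half_diffusers.hub_utils import init_git_repo, push_to_hub

args = Namespace(
    output_dir="./ddpm-demo",   # local clone / checkpoint directory (placeholder)
    hub_model_id=None,          # None -> repo name falls back to the output_dir name
    hub_token=None,             # None -> use the cached login token
    hub_private_repo=False,
    overwrite_output_dir=True,
    local_rank=-1,              # only rank -1/0 touches the Hub
    dataset_name="cifar10",     # placeholders consumed by create_model_card
    learning_rate=1e-4,
    train_batch_size=16,
    eval_batch_size=16,
    mixed_precision="no",
)

repo = init_git_repo(args, at_init=True)
# ... train and build `pipeline` (a DiffusionPipeline) ...
# push_to_hub(args, pipeline, repo, commit_message="first run")
```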
diff --git a/my_half_diffusers/modeling_utils.py b/my_half_diffusers/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb613614a8782bf2eba2a2e7c2dc2af987088d6f
--- /dev/null
+++ b/my_half_diffusers/modeling_utils.py
@@ -0,0 +1,542 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, device
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
+
+
+WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_parameter_device(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).device
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).dtype
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].dtype
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+ """
+ Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
+ """
+ try:
+ return torch.load(checkpoint_file, map_location="cpu")
+ except Exception as e:
+ try:
+ with open(checkpoint_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+ "you cloned."
+ )
+ else:
+ raise ValueError(
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+ "model. Make sure you have saved the model properly."
+ ) from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(
+ f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
+ f"at '{checkpoint_file}'. "
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
+ )
+
+
+def _load_state_dict_into_model(model_to_load, state_dict):
+ # Convert old format to new format if needed from a PyTorch state_dict
+ # copy state_dict so _load_from_state_dict can modify it
+ state_dict = state_dict.copy()
+ error_msgs = []
+
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+ # so we need to apply the function recursively.
+ def load(module: torch.nn.Module, prefix=""):
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
+ module._load_from_state_dict(*args)
+
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ load(model_to_load)
+
+ return error_msgs
+
+
+class ModelMixin(torch.nn.Module):
+ r"""
+ Base class for all models.
+
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
+ and saving models.
+
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
+ [`~modeling_utils.ModelMixin.save_pretrained`].
+ """
+ config_name = CONFIG_NAME
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+
+ def __init__(self):
+ super().__init__()
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = torch.save,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        [`~modeling_utils.ModelMixin.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful when in distributed training like
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+ the main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+ need to replace `torch.save` by another method.
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = self
+
+ # Attach architecture to the config
+ # Save the config
+ if is_main_process:
+ model_to_save.save_config(save_directory)
+
+ # Save the model
+ state_dict = model_to_save.state_dict()
+
+ # Clean the folder from a previous save
+ for filename in os.listdir(save_directory):
+ full_filename = os.path.join(save_directory, filename)
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+ # in distributed settings to avoid race conditions.
+ if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
+ os.remove(full_filename)
+
+ # Save the model
+ save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
+
+ logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+
+
+        Passing `use_auth_token=True` is required when you want to use a private model.
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+ this method in a firewalled environment.
+
+
+
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+ model, unused_kwargs = cls.from_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ **kwargs,
+ )
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ model = model.to(torch_dtype)
+
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+        # Load model
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+ # Load from a PyTorch checkpoint
+ model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+ ):
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=WEIGHTS_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+ "login` and pass `use_auth_token=True`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+ "this model name. Check the model page at "
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a file named {WEIGHTS_NAME} or"
+ " \nCheckout your internet connection or see how to run the library in"
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a file named {WEIGHTS_NAME}"
+ )
+
+ # restore default dtype
+ state_dict = load_state_dict(model_file)
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+ model,
+ state_dict,
+ model_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ )
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+
+ if output_loading_info:
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+ return model, loading_info
+
+ return model
+
+ @classmethod
+ def _load_pretrained_model(
+ cls,
+ model,
+ state_dict,
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
+ ):
+ # Retrieve missing & unexpected_keys
+ model_state_dict = model.state_dict()
+ loaded_keys = [k for k in state_dict.keys()]
+
+ expected_keys = list(model_state_dict.keys())
+
+ original_loaded_keys = loaded_keys
+
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ model_to_load = model
+
+ def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ if state_dict is not None:
+ # Whole checkpoint
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ original_loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if len(error_msgs) > 0:
+ error_msg = "\n\t".join(error_msgs)
+ if "size mismatch" in error_msg:
+ error_msg += (
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ )
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+ " identical (initializing a BertForSequenceClassification model from a"
+ " BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+ " without further training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
+ )
+
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
+ @property
+ def device(self) -> device:
+ """
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+ device).
+ """
+ return get_parameter_device(self)
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ return get_parameter_dtype(self)
+
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+ """
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+ Args:
+ only_trainable (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of trainable parameters
+
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of non-embeddings parameters
+
+ Returns:
+ `int`: The number of parameters.
+ """
+
+ if exclude_embeddings:
+ embedding_param_names = [
+ f"{name}.weight"
+ for name, module_type in self.named_modules()
+ if isinstance(module_type, torch.nn.Embedding)
+ ]
+ non_embedding_parameters = [
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+ ]
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+ else:
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
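+    # Illustrative usage (a sketch, not part of the original file) for the helpers
+    # above, assuming `model` is any ModelMixin subclass instance:
+    #
+    #     model.device                                # e.g. device(type='cpu')
+    #     model.dtype                                 # e.g. torch.float16
+    #     model.num_parameters(only_trainable=True)   # count of trainable params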
+
+def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
+ """
+ Recursively unwraps a model from potential containers (as used in distributed training).
+
+ Args:
+ model (`torch.nn.Module`): The model to unwrap.
+ """
+ # since there could be multiple levels of wrapping, unwrap recursively
+ if hasattr(model, "module"):
+ return unwrap_model(model.module)
+ else:
+ return model
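+# Illustrative usage (a sketch, not from this repo): if the model has been wrapped
+# for distributed training, e.g. `wrapped = torch.nn.DataParallel(model)`, then
+# `unwrap_model(wrapped)` walks nested `.module` attributes and returns the
+# underlying `model`, even when it has been wrapped more than once.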
diff --git a/my_half_diffusers/models/__init__.py b/my_half_diffusers/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0ac5c8d548b4ec2f7b9c84d5c6d884fd470385b
--- /dev/null
+++ b/my_half_diffusers/models/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .unet_2d import UNet2DModel
+from .unet_2d_condition import UNet2DConditionModel
+from .vae import AutoencoderKL, VQModel
diff --git a/my_half_diffusers/models/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e24f6db27930d2f62b0104a819fef4d3a9028e09
Binary files /dev/null and b/my_half_diffusers/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/models/__pycache__/attention.cpython-38.pyc b/my_half_diffusers/models/__pycache__/attention.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e2b3fab6bd749e991f43db5b1d7ac4e47f2e1da
Binary files /dev/null and b/my_half_diffusers/models/__pycache__/attention.cpython-38.pyc differ
diff --git a/my_half_diffusers/models/__pycache__/embeddings.cpython-38.pyc b/my_half_diffusers/models/__pycache__/embeddings.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..039ba75033874a7289834f893f1d444986054c67
Binary files /dev/null and b/my_half_diffusers/models/__pycache__/embeddings.cpython-38.pyc differ
diff --git a/my_half_diffusers/models/__pycache__/resnet.cpython-38.pyc b/my_half_diffusers/models/__pycache__/resnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b219aca96c808cc4b2ea7aa2842dfa7e527f354c
Binary files /dev/null and b/my_half_diffusers/models/__pycache__/resnet.cpython-38.pyc differ
diff --git a/my_half_diffusers/models/__pycache__/unet_2d.cpython-38.pyc b/my_half_diffusers/models/__pycache__/unet_2d.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d254dfc791d59e6203d57c2e3874bff25aef46d
Binary files /dev/null and b/my_half_diffusers/models/__pycache__/unet_2d.cpython-38.pyc differ
diff --git a/my_half_diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc b/my_half_diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1cb9ff7817ef7d7abf4a9d4b3142c730bd1d62ae
Binary files /dev/null and b/my_half_diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc differ
diff --git a/my_half_diffusers/models/__pycache__/unet_blocks.cpython-38.pyc b/my_half_diffusers/models/__pycache__/unet_blocks.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..73bff9b2daddc0d084318bd10e87673adc0b1eda
Binary files /dev/null and b/my_half_diffusers/models/__pycache__/unet_blocks.cpython-38.pyc differ
diff --git a/my_half_diffusers/models/__pycache__/vae.cpython-38.pyc b/my_half_diffusers/models/__pycache__/vae.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78876d41e186d3dac991c583ad0f1d6f902df10f
Binary files /dev/null and b/my_half_diffusers/models/__pycache__/vae.cpython-38.pyc differ
diff --git a/my_half_diffusers/models/attention.py b/my_half_diffusers/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..3db2c9e97fae16a941704c3155cc89d8269679f3
--- /dev/null
+++ b/my_half_diffusers/models/attention.py
@@ -0,0 +1,333 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+ to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ Uses three q, k, v linear layers to compute attention.
+
+ Parameters:
+ channels (:obj:`int`): The number of channels in the input and output.
+ num_head_channels (:obj:`int`, *optional*):
+ The number of channels in each head. If None, then `num_heads` = 1.
+ num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+ rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+ eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ num_head_channels: Optional[int] = None,
+ num_groups: int = 32,
+        rescale_output_factor: float = 1.0,
+        eps: float = 1e-5,
+ ):
+ super().__init__()
+ self.channels = channels
+
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
+ self.num_head_size = num_head_channels
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+ # define q,k,v as linear layers
+ self.query = nn.Linear(channels, channels)
+ self.key = nn.Linear(channels, channels)
+ self.value = nn.Linear(channels, channels)
+
+ self.rescale_output_factor = rescale_output_factor
+ self.proj_attn = nn.Linear(channels, channels, 1)
+
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+ return new_projection
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ batch, channel, height, width = hidden_states.shape
+
+ # norm
+ hidden_states = self.group_norm(hidden_states)
+
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+ # proj to q, k, v
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(hidden_states)
+ value_proj = self.value(hidden_states)
+
+ # transpose
+ query_states = self.transpose_for_scores(query_proj)
+ key_states = self.transpose_for_scores(key_proj)
+ value_states = self.transpose_for_scores(value_proj)
+
+ # get scores
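+        # note: splitting the usual 1/sqrt(d_head) factor as a sqrt applied to both
+        # query and key keeps the intermediate products smaller, which helps avoid
+        # overflow when this block is run in float16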
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
+
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+ attention_probs = torch.softmax(attention_scores, dim=-1).type(attention_scores.dtype)
+
+ # compute attention output
+ hidden_states = torch.matmul(attention_probs, value_states)
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+ hidden_states = hidden_states.view(new_hidden_states_shape)
+
+ # compute next hidden_states
+ hidden_states = self.proj_attn(hidden_states)
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+ # res connect and rescale
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
+ return hidden_states
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
+ standard transformer action. Finally, reshape to image.
+
+ Parameters:
+ in_channels (:obj:`int`): The number of channels in the input and output.
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+ d_head (:obj:`int`): The number of channels in each head.
+ depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
+ context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ n_heads: int,
+ d_head: int,
+ depth: int = 1,
+        dropout: float = 0.0,
+ context_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.n_heads = n_heads
+ self.d_head = d_head
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for d in range(depth)
+ ]
+ )
+
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def _set_attention_slice(self, slice_size):
+ for block in self.transformer_blocks:
+ block._set_attention_slice(slice_size)
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+ x = self.proj_out(x)
+ return x + x_in
+
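+# Shape walk-through (illustrative, assuming n_heads=8 and d_head=40 so that
+# inner_dim == in_channels == 320): an input of shape (b, 320, 64, 64) is
+# normalized, projected, flattened to (b, 4096, 320) for the transformer blocks,
+# reshaped back to (b, 320, 64, 64), projected out, and added to the residual.
+# Note that the flatten/reshape in forward() reuses `c` (= in_channels), which
+# implicitly assumes inner_dim == in_channels.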
+
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (:obj:`int`): The number of channels in the input and output.
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+ d_head (:obj:`int`): The number of channels in each head.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
+        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
+        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ n_heads: int,
+ d_head: int,
+ dropout=0.0,
+ context_dim: Optional[int] = None,
+ gated_ff: bool = True,
+ checkpoint: bool = True,
+ ):
+ super().__init__()
+ self.attn1 = CrossAttention(
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(
+ query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def _set_attention_slice(self, slice_size):
+ self.attn1._slice_size = slice_size
+ self.attn2._slice_size = slice_size
+
+ def forward(self, x, context=None):
+ x = x.contiguous() if x.device.type == "mps" else x
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class CrossAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (:obj:`int`): The number of channels in the query.
+ context_dim (:obj:`int`, *optional*):
+ The number of channels in the context. If not given, defaults to `query_dim`.
+ heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ def __init__(
+        self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = context_dim if context_dim is not None else query_dim
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self._slice_size = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def forward(self, x, context=None, mask=None):
+ batch_size, sequence_length, dim = x.shape
+
+ q = self.to_q(x)
+ context = context if context is not None else x
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q = self.reshape_heads_to_batch_dim(q)
+ k = self.reshape_heads_to_batch_dim(k)
+ v = self.reshape_heads_to_batch_dim(v)
+
+ # TODO(PVP) - mask is currently never used. Remember to re-implement when used
+
+ # attention, what we cannot get enough of
+ hidden_states = self._attention(q, k, v, sequence_length, dim)
+
+ return self.to_out(hidden_states)
+
+ def _attention(self, query, key, value, sequence_length, dim):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+ attn_slice = (
+ torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
+ )
+ attn_slice = attn_slice.softmax(dim=-1)
+ attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
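+# Memory note (illustrative): with heads=8 and a batch of 2, q/k/v carry a leading
+# dimension of 16 after reshape_heads_to_batch_dim.  Setting `_slice_size = 4`
+# (via SpatialTransformer._set_attention_slice) materializes the (seq_len x
+# seq_len) score matrix for only 4 of those 16 slices at a time, trading some
+# speed for a roughly 4x smaller peak attention buffer.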
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (:obj:`int`): The number of channels in the input.
+ dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ def __init__(
+        self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+ project_in = GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+# feedforward
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (:obj:`int`): The number of channels in the input.
+ dim_out (:obj:`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
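+# Example (sketch): GEGLU(320, 1280) projects (b, t, 320) -> (b, t, 2560), chunks
+# the result into two (b, t, 1280) halves, and returns value * gelu(gate).  Note
+# that FeedForward above always builds a GEGLU projection, so its `glu` flag is
+# currently unused in this file.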
diff --git a/my_half_diffusers/models/embeddings.py b/my_half_diffusers/models/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..57a6d14e0d226abd5e4c3f3f506d028bffdf3b22
--- /dev/null
+++ b/my_half_diffusers/models/embeddings.py
@@ -0,0 +1,116 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ # print(timesteps)
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+    :param embedding_dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent).to(device=timesteps.device)
+ emb = timesteps[:, None] * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb.to(torch.float16)
+
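+# Quick sanity check (illustrative): for timesteps = torch.tensor([0, 10]) and
+# embedding_dim = 320, the returned embedding has shape (2, 320); the hard cast
+# to torch.float16 on the return line appears specific to this half-precision fork.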
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(channel, time_embed_dim)
+ self.act = None
+ if act_fn == "silu":
+ self.act = nn.SiLU()
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+ def forward(self, sample):
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ return t_emb
+
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(self, embedding_size: int = 256, scale: float = 1.0):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ # to delete later
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ self.weight = self.W
+
+ def forward(self, x):
+ x = torch.log(x)
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
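+# Note (sketch): unlike the integer-timestep path above, this projection expects
+# continuous noise levels; it log-transforms the input and modulates a fixed set
+# of random frequencies that are frozen via requires_grad=False.  `self.weight`
+# is re-bound to the duplicate `self.W`, which the inline comment flags for
+# later removal.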
diff --git a/my_half_diffusers/models/resnet.py b/my_half_diffusers/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0439aff823242b9e9f9e504db6fbd69702f190cc
--- /dev/null
+++ b/my_half_diffusers/models/resnet.py
@@ -0,0 +1,483 @@
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Upsample2D(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two
+        dimensions.
+ """
+
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv_transpose:
+ return self.conv(x)
+
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if self.use_conv:
+ if self.name == "conv":
+ x = self.conv(x)
+ else:
+ x = self.Conv2d_0(x)
+
+ return x
+
+
+class Downsample2D(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two
+        dimensions.
+ """
+
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ pad = (0, 1, 0, 1)
+ x = F.pad(x, pad, mode="constant", value=0)
+
+ assert x.shape[1] == self.channels
+ x = self.conv(x)
+
+ return x
+
+
+class FirUpsample2D(nn.Module):
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.use_conv = use_conv
+ self.fir_kernel = fir_kernel
+ self.out_channels = out_channels
+
+ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+ """Fused `upsample_2d()` followed by `Conv2d()`.
+
+        Padding is performed only once at the beginning, not between the operations. The fused op is considerably
+        more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+        arbitrary order.
+
+        Args:
+            x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution
+                can be performed by `inChannels = x.shape[0] // numGroups`.
+            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
+                which corresponds to nearest-neighbor upsampling.
+            factor: Integer upsampling factor (default: 2).
+            gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
+ `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+
+ # Setup filter kernel.
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = np.asarray(kernel, dtype=np.float16)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+
+ if self.use_conv:
+ convH = weight.shape[2]
+ convW = weight.shape[3]
+ inC = weight.shape[1]
+
+ p = (kernel.shape[0] - factor) - (convW - 1)
+
+ stride = (factor, factor)
+ # Determine data dimensions.
+ stride = [1, 1, factor, factor]
+ output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
+ output_padding = (
+ output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+ output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+ )
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
+ inC = weight.shape[1]
+ num_groups = x.shape[1] // inC
+
+ # Transpose weights.
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
+ weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
+
+ x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
+
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+ else:
+ p = kernel.shape[0] - factor
+ x = upfirdn2d_native(
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+ )
+
+ return x
+
+ def forward(self, x):
+ if self.use_conv:
+ height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
+
+ return height
+
+
+class FirDownsample2D(nn.Module):
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.fir_kernel = fir_kernel
+ self.use_conv = use_conv
+ self.out_channels = out_channels
+
+ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+ """Fused `Conv2d()` followed by `downsample_2d()`.
+
+        Padding is performed only once at the beginning, not between the operations. The fused op is considerably
+        more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+        arbitrary order.
+
+        Args:
+            x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution
+                can be performed by `inChannels = x.shape[0] // numGroups`.
+            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
+                which corresponds to average pooling.
+            factor: Integer downsampling factor (default: 2).
+            gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
+ datatype as `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = np.asarray(kernel, dtype=np.float16)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * gain
+
+ if self.use_conv:
+ _, _, convH, convW = weight.shape
+ p = (kernel.shape[0] - factor) + (convW - 1)
+ s = [factor, factor]
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
+ x = F.conv2d(x, weight, stride=s, padding=0)
+ else:
+ p = kernel.shape[0] - factor
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+
+ return x
+
+ def forward(self, x):
+ if self.use_conv:
+ x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+ x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
+
+ return x
+
+
+class ResnetBlock2D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ kernel=None,
+ output_scale_factor=1.0,
+ use_nin_shortcut=None,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.up = up
+ self.down = down
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.upsample = self.downsample = None
+ if self.up:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+ else:
+ self.upsample = Upsample2D(in_channels, use_conv=False)
+ elif self.down:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+ else:
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
+
+ self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
+
+ self.conv_shortcut = None
+ if self.use_nin_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, temb):
+ hidden_states = x
+
+ # make sure hidden states is in float32
+ # when running in half-precision
+ hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ x = self.upsample(x)
+ hidden_states = self.upsample(hidden_states)
+ elif self.downsample is not None:
+ x = self.downsample(x)
+ hidden_states = self.downsample(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+ hidden_states = hidden_states + temb
+
+ # make sure hidden states is in float32
+ # when running in half-precision
+ hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ x = self.conv_shortcut(x)
+
+ out = (x + hidden_states) / self.output_scale_factor
+
+ return out
+
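+# Note (sketch): the forward pass above follows a pre-activation layout --
+# norm1 -> nonlinearity -> optional up/downsample of both paths -> conv1 ->
+# add the projected time embedding -> norm2 -> nonlinearity -> dropout -> conv2 --
+# with a 1x1 shortcut conv whenever in_channels != out_channels, and the summed
+# output divided by `output_scale_factor`.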
+
+class Mish(torch.nn.Module):
+ def forward(self, x):
+ return x * torch.tanh(torch.nn.functional.softplus(x))
+
+
+def upsample_2d(x, kernel=None, factor=2, gain=1):
+ r"""Upsample2D a batch of 2D images with the given filter.
+
+    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the
+    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that
+    its shape is a multiple of the upsampling factor.
+
+    Args:
+        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
+            which corresponds to nearest-neighbor upsampling.
+        factor: Integer upsampling factor (default: 2).
+        gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]`
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = np.asarray(kernel, dtype=np.float16)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+ p = kernel.shape[0] - factor
+ return upfirdn2d_native(
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+ )
+
+
+def downsample_2d(x, kernel=None, factor=2, gain=1):
+ r"""Downsample2D a batch of 2D images with the given filter.
+
+    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
+    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that
+    its shape is a multiple of the downsampling factor.
+
+    Args:
+        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
+            which corresponds to average pooling.
+        factor: Integer downsampling factor (default: 2).
+        gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]`
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = np.asarray(kernel, dtype=np.float16)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * gain
+ p = kernel.shape[0] - factor
+ return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+
+
+def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
+ up_x = up_y = up
+ down_x = down_y = down
+ pad_x0 = pad_y0 = pad[0]
+ pad_x1 = pad_y1 = pad[1]
+
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+
+ # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
+ if input.device.type == "mps":
+ out = out.to("cpu")
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out.to(input.device) # Move back to mps if necessary
+ out = out[
+ :,
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
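+# Worked example (illustrative): ResnetBlock2D with kernel="fir" calls
+# upsample_2d with kernel=(1, 3, 3, 1) and factor=2.  For an 8x8 input that
+# means a 4x4 separable filter, up=2 and pad=(2, 1), so
+# out_h = (8*2 + 2 + 1 - 4) // 1 + 1 = 16 -- i.e. an exact 2x spatial upsample.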
diff --git a/my_half_diffusers/models/unet_2d.py b/my_half_diffusers/models/unet_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca8931b2ed6db1e5b4561b510785e5a69c20fa59
--- /dev/null
+++ b/my_half_diffusers/models/unet_2d.py
@@ -0,0 +1,246 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+
+@dataclass
+class UNet2DOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Hidden states output. Output of last layer of model.
+ """
+
+    sample: torch.FloatTensor
+
+
+class UNet2DModel(ModelMixin, ConfigMixin):
+ r"""
+ UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+        sample_size (`int`, *optional*): Input sample size.
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+            Whether to flip sin to cos for fourier time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
+            Tuple of block output channels.
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ center_input_sample: bool = False,
+ time_embedding_type: str = "positional",
+ freq_shift: int = 0,
+ flip_sin_to_cos: bool = True,
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
+ layers_per_block: int = 2,
+        mid_block_scale_factor: float = 1,
+ downsample_padding: int = 1,
+ act_fn: str = "silu",
+ attention_head_dim: int = 8,
+ norm_num_groups: int = 32,
+        norm_eps: float = 1e-5,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ if time_embedding_type == "fourier":
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
+ timestep_input_dim = 2 * block_out_channels[0]
+ elif time_embedding_type == "positional":
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=attention_head_dim,
+ downsample_padding=downsample_padding,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=attention_head_dim,
+ resnet_groups=norm_num_groups,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=attention_head_dim,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def forward(
+ self,
+        sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ return_dict: bool = True,
+ ) -> Union[UNet2DOutput, Tuple]:
+        r"""
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
+ otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+ """
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
+
+ t_emb = self.time_proj(timesteps)
+ emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ skip_sample = sample
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "skip_conv"):
+ sample, res_samples, skip_sample = downsample_block(
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb)
+
+ # 5. up
+ skip_sample = None
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "skip_conv"):
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
+ else:
+ sample = upsample_block(sample, res_samples, emb)
+
+ # 6. post-process
+ # make sure hidden states is in float32
+ # when running in half-precision
+ sample = self.conv_norm_out(sample).type(sample.dtype)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if skip_sample is not None:
+ sample += skip_sample
+
+ if self.config.time_embedding_type == "fourier":
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
+ sample = sample / timesteps
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DOutput(sample=sample)
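+# Minimal usage sketch (assumed values, not taken from this repo; since this fork
+# hard-casts timestep embeddings to float16, the model is expected to run in half
+# precision):
+#
+#     model = UNet2DModel(sample_size=64).half().cuda()
+#     noisy = torch.randn(1, 3, 64, 64, dtype=torch.float16, device="cuda")
+#     out = model(noisy, timestep=10).sample    # -> (1, 3, 64, 64)
+#
+# With return_dict=False the call returns a one-element tuple instead.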
diff --git a/my_half_diffusers/models/unet_2d_condition.py b/my_half_diffusers/models/unet_2d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..8546ea4c475ead158f9ae16a0c391c1267d6a4ec
--- /dev/null
+++ b/my_half_diffusers/models/unet_2d_condition.py
@@ -0,0 +1,273 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .embeddings import TimestepEmbedding, Timesteps
+from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin):
+ r"""
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
+ and returns sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ sample_size (`int`, *optional*): The size of the input sample.
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: int = 8,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ downsample_padding=downsample_padding,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift="default",
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ resnet_groups=norm_num_groups,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+ )
+ if slice_size is not None and slice_size > self.config.attention_head_dim:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+ )
+
+ for block in self.down_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
+ self.mid_block.set_attention_slice(slice_size)
+
+ for block in self.up_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
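+    # Example (sketch): `unet.set_attention_slice(4)` propagates the slice size to
+    # every CrossAttention layer in the down, mid and up blocks, bounding the size
+    # of the attention score buffers at some cost in speed; passing None restores
+    # the default unsliced behaviour.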
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+        r"""
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps.to(dtype=torch.float16)
+ timesteps = timesteps[None].to(device=sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+ # print(t_emb.dtype)
+ t_emb = t_emb.to(sample.dtype).to(sample.device)
+ emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+ # print(sample.dtype, emb.dtype, encoder_hidden_states.dtype)
+ sample, res_samples = downsample_block(
+ hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+ # 5. up
+ for upsample_block in self.up_blocks:
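+            # the skip activations are consumed deepest-first: each up block
+            # pops as many residuals as it has resnets and concatenates them
+            # channel-wise inside the block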
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ else:
+ sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
+
+ # 6. post-process
+        # NOTE: unlike upstream diffusers, this half-precision port does not
+        # upcast to float32 before the final GroupNorm; the sample keeps its
+        # (typically fp16) dtype throughout.
+        sample = self.conv_norm_out(sample).type(sample.dtype)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
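+
+
+# --------------------------------------------------------------------------
+# Editor's note: illustrative sketch only, not part of the upstream module.
+# Typical use of `UNet2DConditionModel.forward` during denoising, assuming a
+# model instance `unet` has already been loaded (e.g. with the
+# `from_pretrained` loader inherited from `ModelMixin`), `latents` has shape
+# (batch, in_channels, h, w), and `text_emb` has shape
+# (batch, sequence_length, cross_attention_dim):
+#
+#     noise_pred = unet(
+#         sample=latents,
+#         timestep=t,                         # int, float, or tensor
+#         encoder_hidden_states=text_emb,     # cross-attention context
+#     ).sample
+# --------------------------------------------------------------------------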
diff --git a/my_half_diffusers/models/unet_blocks.py b/my_half_diffusers/models/unet_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e062165357c33d9b2f0bec13a66204c2e7e7833
--- /dev/null
+++ b/my_half_diffusers/models/unet_blocks.py
@@ -0,0 +1,1481 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+from torch import nn
+
+from .attention import AttentionBlock, SpatialTransformer
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ cross_attention_dim=None,
+ downsample_padding=None,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock2D":
+ return DownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnDownBlock2D":
+ return AttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "CrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
+ return CrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "SkipDownBlock2D":
+ return SkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnSkipDownBlock2D":
+ return AttnSkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "DownEncoderBlock2D":
+ return DownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+        )
+    raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ cross_attention_dim=None,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock2D":
+ return UpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "CrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
+ return CrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "AttnUpBlock2D":
+ return AttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "SkipUpBlock2D":
+ return SkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "AttnSkipUpBlock2D":
+ return AttnSkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "UpDecoderBlock2D":
+ return UpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ AttentionBlock(
+ in_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if self.attention_type == "default":
+ hidden_states = attn(hidden_states)
+ else:
+ hidden_states = attn(hidden_states, encoder_states)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ SpatialTransformer(
+ in_channels,
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class AttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ SpatialTransformer(
+ out_channels,
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnDownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class SkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class AttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_type="default",
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ for resnet, attn in zip(self.resnets, self.attentions):
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ SpatialTransformer(
+ out_channels,
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
+ for resnet, attn in zip(self.resnets, self.attentions):
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ for resnet in self.resnets:
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnUpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ upsample_padding=1,
+ add_upsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = self.attentions[0](hidden_states)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_upsample=True,
+ upsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
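+
+
+if __name__ == "__main__":
+    # Editor's note: illustrative smoke test only, not part of the upstream
+    # diffusers module. The channel counts below are arbitrary assumptions
+    # chosen so the default 32-group GroupNorm divides evenly; it only checks
+    # that the shared calling convention of the blocks above works end to end.
+    hidden = torch.randn(2, 64, 16, 16)   # (batch, channels, height, width)
+    temb = torch.randn(2, 128)            # (batch, temb_channels) time embedding
+
+    mid = UNetMidBlock2D(in_channels=64, temb_channels=128)
+    print("UNetMidBlock2D ->", tuple(mid(hidden, temb).shape))  # (2, 64, 16, 16)
+
+    down = DownBlock2D(in_channels=64, out_channels=64, temb_channels=128)
+    sample, skips = down(hidden, temb)
+    print("DownBlock2D    ->", tuple(sample.shape), f"{len(skips)} skips")  # spatially halved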
diff --git a/my_half_diffusers/models/vae.py b/my_half_diffusers/models/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..82748cb5b60c0241cc3ca96f9016f07650e44a54
--- /dev/null
+++ b/my_half_diffusers/models/vae.py
@@ -0,0 +1,581 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+
+@dataclass
+class DecoderOutput(BaseOutput):
+ """
+ Output of decoding method.
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Decoded output sample of the model. Output of the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+ """
+ Output of VQModel encoding method.
+
+ Args:
+ latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Encoded output sample of the model. Output of the last layer of the model.
+ """
+
+ latents: torch.FloatTensor
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`DiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "DiagonalGaussianDistribution"
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ act_fn="silu",
+ double_z=True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=self.layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=32,
+ temb_channels=None,
+ )
+
+ # out
+ num_groups_out = 32
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
+ self.conv_act = nn.SiLU()
+
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
+
+ def forward(self, x):
+ sample = x
+ sample = self.conv_in(sample)
+
+ # down
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ act_fn="silu",
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
+
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=32,
+ temb_channels=None,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ prev_output_channel=None,
+ add_upsample=not is_final_block,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = 32
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def forward(self, z):
+ sample = z
+ sample = self.conv_in(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = up_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class VectorQuantizer(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
+ multiplications and allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = (
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight**2, dim=1)
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
+ )
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+ device = self.parameters.device
+ sample_device = "cpu" if device.type == "mps" else device
+ sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+
+ def mode(self):
+ return self.mean
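+
+    # Editor's note (illustrative): typical use downstream of the encoder,
+    # where `moments` has 2 * latent_channels channels (mean and logvar):
+    #
+    #     dist = DiagonalGaussianDistribution(moments)  # moments: (B, 2C, H, W)
+    #     z = dist.sample()                             # stochastic latent, (B, C, H, W)
+    #     z = dist.mode()                               # or the deterministic mean
+    #     kl = dist.kl()                                # (B,) KL to a standard normal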
+
+
+class VQModel(ModelMixin, ConfigMixin):
+ r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
+ Kavukcuoglu.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): TODO
+ num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 3,
+ sample_size: int = 32,
+ num_vq_embeddings: int = 256,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ double_z=False,
+ )
+
+ self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+ self.quantize = VectorQuantizer(
+ num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
+ )
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ )
+
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+
+ if not return_dict:
+ return (h,)
+
+ return VQEncoderOutput(latents=h)
+
+ def decode(
+ self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ h = self.encode(x).latents
+ dec = self.decode(h).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
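+
+    # Editor's note (illustrative): a VQModel round trip under the default
+    # single-block config looks like
+    #
+    #     vq = VQModel()
+    #     latents = vq.encode(images).latents   # continuous, pre-quantization latents
+    #     recon = vq.decode(latents).sample     # quantize (unless force_not_quantize) then decode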
+
+
+class AutoencoderKL(ModelMixin, ConfigMixin):
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
+ and Max Welling.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 4,
+ sample_size: int = 32,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ double_z=True,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ )
+
+ self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
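+
+
+def _example_kl_roundtrip():
+    # Usage sketch: encode an image batch into a diagonal Gaussian posterior and decode a
+    # sample from it, mirroring what forward() does; the batch size and 256x256 resolution
+    # are arbitrary assumptions.
+    vae = AutoencoderKL()
+    x = torch.randn(2, 3, 256, 256)
+    posterior = vae.encode(x).latent_dist
+    z = posterior.sample()                 # or posterior.mode() for a deterministic latent
+    recon = vae.decode(z).sample
+    return recon.shape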
diff --git a/my_half_diffusers/onnx_utils.py b/my_half_diffusers/onnx_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e840565dd5c1b9bd17422aba5af6dc0d045c4682
--- /dev/null
+++ b/my_half_diffusers/onnx_utils.py
@@ -0,0 +1,189 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+from pathlib import Path
+from typing import Optional, Union
+
+import numpy as np
+
+from huggingface_hub import hf_hub_download
+
+from .utils import is_onnx_available, logging
+
+
+if is_onnx_available():
+ import onnxruntime as ort
+
+
+ONNX_WEIGHTS_NAME = "model.onnx"
+
+
+logger = logging.get_logger(__name__)
+
+
+class OnnxRuntimeModel:
+ base_model_prefix = "onnx_model"
+
+ def __init__(self, model=None, **kwargs):
+ logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
+ self.model = model
+ self.model_save_dir = kwargs.get("model_save_dir", None)
+ self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
+
+ def __call__(self, **kwargs):
+ inputs = {k: np.array(v) for k, v in kwargs.items()}
+ return self.model.run(None, inputs)
+
+ @staticmethod
+ def load_model(path: Union[str, Path], provider=None):
+ """
+ Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
+
+ Arguments:
+ path (`str` or `Path`):
+ Directory from which to load
+ provider(`str`, *optional*):
+ Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
+ """
+ if provider is None:
+ logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
+ provider = "CPUExecutionProvider"
+
+ return ort.InferenceSession(path, providers=[provider])
+
+ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
+ latest_model_name.
+
+ Arguments:
+ save_directory (`str` or `Path`):
+ Directory where to save the model file.
+ file_name(`str`, *optional*):
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
+ model with a different name.
+ """
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+
+ src_path = self.model_save_dir.joinpath(self.latest_model_name)
+ dst_path = Path(save_directory).joinpath(model_file_name)
+ if not src_path.samefile(dst_path):
+ shutil.copyfile(src_path, dst_path)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ **kwargs,
+ ):
+ """
+ Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
+ method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # saving model weights/files
+ self._save_pretrained(save_directory, **kwargs)
+
+ @classmethod
+ def _from_pretrained(
+ cls,
+ model_id: Union[str, Path],
+ use_auth_token: Optional[Union[bool, str, None]] = None,
+ revision: Optional[Union[str, None]] = None,
+ force_download: bool = False,
+ cache_dir: Optional[str] = None,
+ file_name: Optional[str] = None,
+ provider: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Load a model from a directory or the HF Hub.
+
+ Arguments:
+ model_id (`str` or `Path`):
+ Directory from which to load
+ use_auth_token (`str` or `bool`):
+ Is needed to load models from a private or gated repository
+ revision (`str`):
+ Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
+ cache_dir (`Union[str, Path]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ file_name(`str`):
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
+ different model files from the same repository or directory.
+ provider(`str`):
+ The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
+ kwargs (`Dict`, *optional*):
+ kwargs will be passed to the model during initialization
+ """
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+ # load model from local directory
+ if os.path.isdir(model_id):
+ model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
+ kwargs["model_save_dir"] = Path(model_id)
+ # load model from hub
+ else:
+ # download model
+ model_cache_path = hf_hub_download(
+ repo_id=model_id,
+ filename=model_file_name,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ )
+ kwargs["model_save_dir"] = Path(model_cache_path).parent
+ kwargs["latest_model_name"] = Path(model_cache_path).name
+ model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
+ return cls(model=model, **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_id: Union[str, Path],
+ force_download: bool = True,
+ use_auth_token: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ **model_kwargs,
+ ):
+ revision = None
+ if len(str(model_id).split("@")) == 2:
+ model_id, revision = model_id.split("@")
+
+ return cls._from_pretrained(
+ model_id=model_id,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ use_auth_token=use_auth_token,
+ **model_kwargs,
+ )
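+
+
+def _example_onnx_model_usage():
+    # Usage sketch: load an exported ONNX model from a local directory and run one forward
+    # pass. The directory path and the input names ("sample", "timestep") are hypothetical
+    # and depend entirely on how the model was exported.
+    model = OnnxRuntimeModel.from_pretrained("path/to/onnx_model_dir", provider="CPUExecutionProvider")
+    sample = np.zeros((1, 4, 64, 64), dtype=np.float32)
+    timestep = np.array([1], dtype=np.int64)
+    # __call__ converts every keyword argument to a numpy array and feeds it to onnxruntime
+    outputs = model(sample=sample, timestep=timestep)
+    return outputs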
diff --git a/my_half_diffusers/optimization.py b/my_half_diffusers/optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7b836b4a69bffb61c15967ef9b1736201721f1b
--- /dev/null
+++ b/my_half_diffusers/optimization.py
@@ -0,0 +1,275 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch optimization for diffusion models."""
+
+import math
+from enum import Enum
+from typing import Optional, Union
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR
+
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SchedulerType(Enum):
+ LINEAR = "linear"
+ COSINE = "cosine"
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
+ POLYNOMIAL = "polynomial"
+ CONSTANT = "constant"
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
+
+
+def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate, using the learning rate set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
+
+
+def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+ increases linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1.0, num_warmup_steps))
+ return 1.0
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+ """
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ return max(
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+ )
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+ following a half-cosine).
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
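+
+def _example_cosine_warmup_schedule():
+    # Usage sketch: attach the cosine-with-warmup schedule to a throwaway optimizer and
+    # record the learning rate over a short run; the parameter, base lr and step counts
+    # are arbitrary assumptions.
+    import torch
+
+    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
+    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=100)
+    lrs = []
+    for _ in range(100):
+        optimizer.step()
+        scheduler.step()
+        lrs.append(scheduler.get_last_lr()[0])
+    # lrs rises roughly linearly to 1e-4 during the first 10 steps, then follows a half-cosine towards 0
+    return lrs
+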
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
+ linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`int`, *optional*, defaults to 1):
+ The number of hard restarts to use.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ if progress >= 1.0:
+ return 0.0
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_polynomial_decay_schedule_with_warmup(
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
+):
+ """
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ lr_end (`float`, *optional*, defaults to 1e-7):
+ The end LR.
+ power (`float`, *optional*, defaults to 1.0):
+ Power factor.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
+ implementation at
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+
+ """
+
+ lr_init = optimizer.defaults["lr"]
+ if not (lr_init > lr_end):
+ raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ elif current_step > num_training_steps:
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
+ else:
+ lr_range = lr_init - lr_end
+ decay_steps = num_training_steps - num_warmup_steps
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+ decay = lr_range * pct_remaining**power + lr_end
+ return decay / lr_init # as LambdaLR multiplies by lr_init
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
+ SchedulerType.CONSTANT: get_constant_schedule,
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+}
+
+
+def get_scheduler(
+ name: Union[str, SchedulerType],
+ optimizer: Optimizer,
+ num_warmup_steps: Optional[int] = None,
+ num_training_steps: Optional[int] = None,
+):
+ """
+ Unified API to get any scheduler from its name.
+
+ Args:
+ name (`str` or `SchedulerType`):
+ The name of the scheduler to use.
+ optimizer (`torch.optim.Optimizer`):
+ The optimizer that will be used during training.
+ num_warmup_steps (`int`, *optional*):
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ num_training_steps (`int`, *optional*):
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ """
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+ if name == SchedulerType.CONSTANT:
+ return schedule_func(optimizer)
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
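+
+
+def _example_get_scheduler():
+    # Usage sketch: resolve a scheduler by name through the unified API above; the
+    # optimizer, the "linear" choice and the step counts are arbitrary assumptions.
+    import torch
+
+    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-5)
+    # "constant" needs neither warmup nor training steps; every other name requires
+    # num_warmup_steps, and all but "constant_with_warmup" also require num_training_steps.
+    lr_scheduler = get_scheduler(
+        SchedulerType.LINEAR, optimizer, num_warmup_steps=500, num_training_steps=10_000
+    )
+    return lr_scheduler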
diff --git a/my_half_diffusers/pipeline_utils.py b/my_half_diffusers/pipeline_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..84ee9e20f1107a54dcdaf2799d805cf9e4f3b0a7
--- /dev/null
+++ b/my_half_diffusers/pipeline_utils.py
@@ -0,0 +1,417 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+import diffusers
+import PIL
+from huggingface_hub import snapshot_download
+from PIL import Image
+from tqdm.auto import tqdm
+
+from .configuration_utils import ConfigMixin
+from .utils import DIFFUSERS_CACHE, BaseOutput, logging
+
+
+INDEX_FILE = "diffusion_pytorch_model.bin"
+
+
+logger = logging.get_logger(__name__)
+
+
+LOADABLE_CLASSES = {
+ "diffusers": {
+ "ModelMixin": ["save_pretrained", "from_pretrained"],
+ "SchedulerMixin": ["save_config", "from_config"],
+ "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
+ "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
+ },
+ "transformers": {
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
+ "PreTrainedModel": ["save_pretrained", "from_pretrained"],
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
+ },
+}
+
+ALL_IMPORTABLE_CLASSES = {}
+for library in LOADABLE_CLASSES:
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
+
+
+@dataclass
+class ImagePipelineOutput(BaseOutput):
+ """
+ Output class for image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`):
+ List of denoised PIL images of length `batch_size` or a numpy array of shape `(batch_size, height, width,
+ num_channels)` containing the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+class DiffusionPipeline(ConfigMixin):
+ r"""
+ Base class for all diffusion pipelines.
+
+ [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
+ and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
+
+ - move all PyTorch modules to the device of your choice
+ - enable/disable the progress bar for the denoising iteration
+
+ Class attributes:
+
+ - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
+ components of the diffusion pipeline.
+ """
+ config_name = "model_index.json"
+
+ def register_modules(self, **kwargs):
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ for name, module in kwargs.items():
+ # retrieve library
+ library = module.__module__.split(".")[0]
+
+ # check if the module is a pipeline module
+ pipeline_dir = module.__module__.split(".")[-2]
+ path = module.__module__.split(".")
+ is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+
+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
+ # Or if it's a pipeline module, then the module is inside the pipeline
+ # folder so we set the library to module name.
+ if library not in LOADABLE_CLASSES or is_pipeline_module:
+ library = pipeline_dir
+
+ # retrieve class_name
+ class_name = module.__class__.__name__
+
+ register_dict = {name: (library, class_name)}
+
+ # save model index config
+ self.register_to_config(**register_dict)
+
+ # set models
+ setattr(self, name, module)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+ """
+ Save all variables of the pipeline that can be saved and loaded as well as the pipeline's configuration file to
+ a directory. A pipeline variable can be saved and loaded if its class implements both a save and a loading
+ method. The pipeline can easily be re-loaded using the [`~DiffusionPipeline.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ """
+ self.save_config(save_directory)
+
+ model_index_dict = dict(self.config)
+ model_index_dict.pop("_class_name")
+ model_index_dict.pop("_diffusers_version")
+ model_index_dict.pop("_module", None)
+
+ for pipeline_component_name in model_index_dict.keys():
+ sub_model = getattr(self, pipeline_component_name)
+ model_cls = sub_model.__class__
+
+ save_method_name = None
+ # search for the model's base class in LOADABLE_CLASSES
+ for library_name, library_classes in LOADABLE_CLASSES.items():
+ library = importlib.import_module(library_name)
+ for base_class, save_load_methods in library_classes.items():
+ class_candidate = getattr(library, base_class)
+ if issubclass(model_cls, class_candidate):
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
+ save_method_name = save_load_methods[0]
+ break
+ if save_method_name is not None:
+ break
+
+ save_method = getattr(sub_model, save_method_name)
+ save_method(os.path.join(save_directory, pipeline_component_name))
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+ if torch_device is None:
+ return self
+
+ module_names, _ = self.extract_init_dict(dict(self.config))
+ for name in module_names.keys():
+ module = getattr(self, name)
+ if isinstance(module, torch.nn.Module):
+ module.to(torch_device)
+ return self
+
+ @property
+ def device(self) -> torch.device:
+ r"""
+ Returns:
+ `torch.device`: The torch device on which the pipeline is located.
+ """
+ module_names, _ = self.extract_init_dict(dict(self.config))
+ for name in module_names.keys():
+ module = getattr(self, name)
+ if isinstance(module, torch.nn.Module):
+ return module.device
+ return torch.device("cpu")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
+
+ The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
+ https://huggingface.co/. Valid repo ids have to be located under a user or organization name, like
+ `CompVis/ldm-text2im-large-256`.
+ - A path to a *directory* containing pipeline weights saved using
+ [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id. Since we use a
+ git-based system for storing models and other artifacts on huggingface.co, `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety
+ of the mirror source; please refer to the mirror site for more information.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite loadable and saveable variables, *i.e.* the pipeline components, of the
+ specific pipeline class. The overwritten components are then passed directly to the pipeline's
+ `__init__` method. See the example below for more information.
+
+
+
+ Passing `use_auth_token=True` is required when you want to use a private model, *e.g.*
+ `"CompVis/stable-diffusion-v1-4"`
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+ this method in a firewalled environment.
+
+
+
+ Examples:
+
+ ```py
+ >>> from diffusers import DiffusionPipeline
+
+ >>> # Download pipeline from huggingface.co and cache.
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+
+ >>> # Download pipeline that requires an authorization token
+ >>> # For more information on access tokens, please refer to this section of
+ >>> # the documentation: https://huggingface.co/docs/hub/security-tokens
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+
+ >>> # Download pipeline, but overwrite scheduler
+ >>> from diffusers import LMSDiscreteScheduler
+
+ >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+ >>> pipeline = DiffusionPipeline.from_pretrained(
+ ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
+ ... )
+ ```
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ provider = kwargs.pop("provider", None)
+
+ # 1. Download the checkpoints and configs
+ # use snapshot_download here so that `from_pretrained` also works with hub repo ids
+ if not os.path.isdir(pretrained_model_name_or_path):
+ cached_folder = snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ else:
+ cached_folder = pretrained_model_name_or_path
+
+ config_dict = cls.get_config_dict(cached_folder)
+
+ # 2. Load the pipeline class, if using custom module then load it from the hub
+ # if we load from explicit class, let's use it
+ if cls != DiffusionPipeline:
+ pipeline_class = cls
+ else:
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+ pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+
+ # some modules can be passed directly to the init
+ # in this case they are already instantiated in `kwargs`
+ # extract them here
+ expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+
+ init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+
+ init_kwargs = {}
+
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ # 3. Load each module in the pipeline
+ for name, (library_name, class_name) in init_dict.items():
+ is_pipeline_module = hasattr(pipelines, library_name)
+ loaded_sub_model = None
+
+ # if the model is in a pipeline module, then we load it from the pipeline
+ if name in passed_class_obj:
+ # 1. check that passed_class_obj has correct parent class
+ if not is_pipeline_module:
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+ expected_class_obj = None
+ for class_name, class_candidate in class_candidates.items():
+ if issubclass(class_obj, class_candidate):
+ expected_class_obj = class_candidate
+
+ if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+ raise ValueError(
+ f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
+ f" {expected_class_obj}"
+ )
+ else:
+ logger.warning(
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+ " has the correct type"
+ )
+
+ # set passed class object
+ loaded_sub_model = passed_class_obj[name]
+ elif is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+ class_obj = getattr(pipeline_module, class_name)
+ importable_classes = ALL_IMPORTABLE_CLASSES
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
+ else:
+ # else we just import it from the library.
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+ if loaded_sub_model is None:
+ load_method_name = None
+ for class_name, class_candidate in class_candidates.items():
+ if issubclass(class_obj, class_candidate):
+ load_method_name = importable_classes[class_name][1]
+
+ load_method = getattr(class_obj, load_method_name)
+
+ loading_kwargs = {}
+ if issubclass(class_obj, torch.nn.Module):
+ loading_kwargs["torch_dtype"] = torch_dtype
+ if issubclass(class_obj, diffusers.OnnxRuntimeModel):
+ loading_kwargs["provider"] = provider
+
+ # check if the module is in a subdirectory
+ if os.path.isdir(os.path.join(cached_folder, name)):
+ loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
+ else:
+ # else load from the root directory
+ loaded_sub_model = load_method(cached_folder, **loading_kwargs)
+
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
+
+ # 4. Instantiate the pipeline
+ model = pipeline_class(**init_kwargs)
+ return model
+
+ @staticmethod
+ def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+ def progress_bar(self, iterable):
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ elif not isinstance(self._progress_bar_config, dict):
+ raise ValueError(
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ )
+
+ return tqdm(iterable, **self._progress_bar_config)
+
+ def set_progress_bar_config(self, **kwargs):
+ self._progress_bar_config = kwargs
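+
+
+def _example_pipeline_housekeeping(pipeline: DiffusionPipeline):
+    # Usage sketch: the small conveniences every pipeline inherits from this base class.
+    # `pipeline` is assumed to be an already-loaded DiffusionPipeline instance.
+    pipeline.set_progress_bar_config(disable=True)                 # silence tqdm during sampling
+    pipeline.to("cuda" if torch.cuda.is_available() else "cpu")    # moves every torch.nn.Module component
+    pil_images = pipeline.numpy_to_pil(np.zeros((1, 64, 64, 3)))   # NHWC floats in [0, 1] -> list of PIL images
+    return pipeline.device, pil_images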
diff --git a/my_half_diffusers/pipelines/__init__.py b/my_half_diffusers/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e2aeb4fb2b7f1315adb3a2ddea6aec42e806779
--- /dev/null
+++ b/my_half_diffusers/pipelines/__init__.py
@@ -0,0 +1,19 @@
+from ..utils import is_onnx_available, is_transformers_available
+from .ddim import DDIMPipeline
+from .ddpm import DDPMPipeline
+from .latent_diffusion_uncond import LDMPipeline
+from .pndm import PNDMPipeline
+from .score_sde_ve import ScoreSdeVePipeline
+from .stochastic_karras_ve import KarrasVePipeline
+
+
+if is_transformers_available():
+ from .latent_diffusion import LDMTextToImagePipeline
+ from .stable_diffusion import (
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+ )
+
+if is_transformers_available() and is_onnx_available():
+ from .stable_diffusion import StableDiffusionOnnxPipeline
diff --git a/my_half_diffusers/pipelines/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/pipelines/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69bc4821a0ad0ceae1c9a5717931562c9228ffce
Binary files /dev/null and b/my_half_diffusers/pipelines/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/ddim/__init__.py b/my_half_diffusers/pipelines/ddim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fd31868a88ac0d9ec7118574f21a9d8a1d4069b
--- /dev/null
+++ b/my_half_diffusers/pipelines/ddim/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_ddim import DDIMPipeline
diff --git a/my_half_diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5c19d3b02565807ddcba81a7d5238041f9bb786
Binary files /dev/null and b/my_half_diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc b/my_half_diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9467b2d0621e5f1387b06b175e1a93795529982d
Binary files /dev/null and b/my_half_diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/ddim/pipeline_ddim.py b/my_half_diffusers/pipelines/ddim/pipeline_ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f6064dbba347dc82a941edac42e178a3e7df8a
--- /dev/null
+++ b/my_half_diffusers/pipelines/ddim/pipeline_ddim.py
@@ -0,0 +1,117 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# limitations under the License.
+
+
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class DDIMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ eta (`float`, *optional*, defaults to 0.0):
+ The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/) (`PIL.Image.Image`) or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ # eta corresponds to η in paper and should be between [0, 1]
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # do x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, eta).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
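+
+
+def _example_ddim_sampling():
+    # Usage sketch: unconditional sampling with DDIM. "path/to/ddim_checkpoint" stands in
+    # for any directory or hub repo saved with save_pretrained(); the seed, batch size and
+    # step count are arbitrary assumptions.
+    pipe = DDIMPipeline.from_pretrained("path/to/ddim_checkpoint")
+    pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+    generator = torch.Generator().manual_seed(0)
+    # eta=0.0 keeps the sampling deterministic given the generator
+    images = pipe(batch_size=2, num_inference_steps=50, eta=0.0, generator=generator).images
+    return images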
diff --git a/my_half_diffusers/pipelines/ddpm/__init__.py b/my_half_diffusers/pipelines/ddpm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8889bdae1224e91916e0f8454bafba0ee566f3b9
--- /dev/null
+++ b/my_half_diffusers/pipelines/ddpm/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_ddpm import DDPMPipeline
diff --git a/my_half_diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d64ba8e7662684f8821e022803f9c8fbead68d3
Binary files /dev/null and b/my_half_diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc b/my_half_diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40b8340d335f89dd704feac24e339992a7f78aef
Binary files /dev/null and b/my_half_diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/ddpm/pipeline_ddpm.py b/my_half_diffusers/pipelines/ddpm/pipeline_ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..71103bbe4d051e94f3fca9122460464fb8b1a4f7
--- /dev/null
+++ b/my_half_diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -0,0 +1,106 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# limitations under the License.
+
+
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class DDPMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/) (`PIL.Image.Image`) or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(1000)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. compute previous image: x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
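+
+
+def _example_ddpm_numpy_output():
+    # Usage sketch: DDPM follows the same call pattern but always runs the full 1000-step
+    # chain (no num_inference_steps/eta arguments). The checkpoint path is a hypothetical
+    # directory produced by save_pretrained().
+    pipe = DDPMPipeline.from_pretrained("path/to/ddpm_checkpoint")
+    result = pipe(batch_size=1, output_type="numpy", return_dict=False)
+    images = result[0]   # numpy array of shape (batch_size, height, width, 3), values in [0, 1]
+    return images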
diff --git a/my_half_diffusers/pipelines/latent_diffusion/__init__.py b/my_half_diffusers/pipelines/latent_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c481b38cf5e0a1c4e24f7e0edf944efb68e1f979
--- /dev/null
+++ b/my_half_diffusers/pipelines/latent_diffusion/__init__.py
@@ -0,0 +1,6 @@
+# flake8: noqa
+from ...utils import is_transformers_available
+
+
+if is_transformers_available():
+ from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline
diff --git a/my_half_diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c793f1f9814e31666df86a6d57c8ee418567f939
Binary files /dev/null and b/my_half_diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc b/my_half_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..602fafb898655830afdf014ae51c97b776d0b8da
Binary files /dev/null and b/my_half_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/my_half_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..b39840f2436b1deda0443fe0883eb4d1f6b73957
--- /dev/null
+++ b/my_half_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -0,0 +1,705 @@
+import inspect
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.utils import logging
+
+from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+
+
+class LDMTextToImagePipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ bert ([`LDMBertModel`]):
+ Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
+ tokenizer (`transformers.BertTokenizer`):
+ Tokenizer of class
+ [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vqvae: Union[VQModel, AutoencoderKL],
+ bert: PreTrainedModel,
+ tokenizer: PreTrainedTokenizer,
+ unet: Union[UNet2DModel, UNet2DConditionModel],
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 256,
+ width: Optional[int] = 256,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 1.0,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 256):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 256):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 1.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/) (`PIL.Image.Image`) or `np.ndarray`.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get unconditional embeddings for classifier free guidance
+ if guidance_scale != 1.0:
+ uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+ uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+ text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
+
+ latents = torch.randn(
+ (batch_size, self.unet.in_channels, height // 8, width // 8),
+ generator=generator,
+ )
+ latents = latents.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ if guidance_scale == 1.0:
+ # guidance_scale of 1 means no guidance
+ latents_input = latents
+ context = text_embeddings
+ else:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = torch.cat([latents] * 2)
+ context = torch.cat([uncond_embeddings, text_embeddings])
+
+ # predict the noise residual
+ noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
+ # perform guidance
+ if guidance_scale != 1.0:
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vqvae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
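+# The sampling loop above batches the unconditional and text-conditional passes and
+# then recombines the two noise predictions. A minimal standalone sketch of that
+# recombination (the names `eps_uncond`, `eps_text`, and `w` are illustrative
+# assumptions, not part of this module):
+#
+#     eps_uncond, eps_text = noise_pred.chunk(2)        # split the doubled batch
+#     eps = eps_uncond + w * (eps_text - eps_uncond)    # w is `guidance_scale`
+#
+# With w == 1 the correction term vanishes, which is why the loop skips the
+# unconditional forward pass entirely in that case.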
+
+
+################################################################################
+# Code for the text transformer model
+################################################################################
+""" PyTorch LDMBERT model."""
+
+
+logger = logging.get_logger(__name__)
+
+LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "ldm-bert",
+ # See all LDMBert models at https://huggingface.co/models?filter=ldmbert
+]
+
+
+LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json",
+}
+
+
+""" LDMBERT model configuration"""
+
+
+class LDMBertConfig(PretrainedConfig):
+ model_type = "ldmbert"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ max_position_embeddings=77,
+ encoder_layers=32,
+ encoder_ffn_dim=5120,
+ encoder_attention_heads=8,
+ head_dim=64,
+ encoder_layerdrop=0.0,
+ activation_function="gelu",
+ d_model=1280,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ classifier_dropout=0.0,
+ scale_embedding=False,
+ use_cache=True,
+ pad_token_id=0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.head_dim = head_dim
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.classifier_dropout = classifier_dropout
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
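+# Hedged note: the defaults above appear to describe the text encoder used by the LDM
+# text-to-image pipeline in this repo (32 layers, d_model=1280, 8 heads of dimension 64,
+# 77 positions). A smaller variant for experiments could be constructed like this
+# (illustrative values, not shipped configs):
+#
+#     cfg = LDMBertConfig(encoder_layers=4, d_model=256, encoder_attention_heads=4,
+#                         head_dim=64, encoder_ffn_dim=1024)
+#     model = LDMBertModel(cfg)   # LDMBertModel is defined later in this file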
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
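+# Hedged example of what `_expand_mask` produces (values are illustrative):
+#
+#     mask = torch.tensor([[1, 1, 0]])
+#     bias = _expand_mask(mask, torch.float32)   # shape [1, 1, 3, 3]
+#     # bias[..., :2] == 0.0 where the key position is kept
+#     # bias[..., 2] == torch.finfo(torch.float32).min where the key position is padding
+#
+# The result can simply be added to the raw attention scores before the softmax.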
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert
+class LDMBertAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ head_dim: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = False,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = head_dim
+ self.inner_dim = head_dim * num_heads
+
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.out_proj = nn.Linear(self.inner_dim, embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
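+# Shape summary for the forward pass above (B=batch, T=target length, S=source length,
+# H=num_heads, D=head_dim):
+#   hidden_states: [B, T, embed_dim] -> q/k/v projections -> [B*H, T or S, D]
+#   attn_weights:  [B*H, T, S] after the batched matmul and softmax
+#   attn_output:   [B*H, T, D] -> reshaped back to [B, T, H*D] -> out_proj -> [B, T, embed_dim]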
+
+
+class LDMBertEncoderLayer(nn.Module):
+ def __init__(self, config: LDMBertConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.self_attn = LDMBertAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ head_dim=config.head_dim,
+ dropout=config.attention_dropout,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ layer_head_mask: torch.FloatTensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert
+class LDMBertPreTrainedModel(PreTrainedModel):
+ config_class = LDMBertConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (LDMBertEncoder,)):
+ module.gradient_checkpointing = value
+
+ @property
+ def dummy_inputs(self):
+ pad_token = self.config.pad_token_id
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+ dummy_inputs = {
+ "attention_mask": input_ids.ne(pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+
+class LDMBertEncoder(LDMBertPreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`LDMBertEncoderLayer`].
+
+ Args:
+ config: LDMBertConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: LDMBertConfig):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+
+ embed_dim = config.d_model
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
+ self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim)
+ self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.layer_norm = nn.LayerNorm(embed_dim)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ seq_len = input_shape[1]
+ if position_ids is None:
+ position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1))
+ embed_pos = self.embed_positions(position_ids)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class LDMBertModel(LDMBertPreTrainedModel):
+ def __init__(self, config: LDMBertConfig):
+ super().__init__(config)
+ self.model = LDMBertEncoder(config)
+ self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ return outputs
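+# Hedged usage sketch: the LDM pipeline earlier in this diff consumes this model as
+# `self.bert(text_input.input_ids)[0]`, i.e. it takes the first element of the encoder
+# output (the per-token hidden states) and feeds it to the UNet as
+# `encoder_hidden_states`. The variable names below are illustrative assumptions:
+#
+#     tokens = tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+#     text_embeddings = bert(tokens.input_ids)[0]   # [batch, 77, 1280] with the defaults above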
diff --git a/my_half_diffusers/pipelines/latent_diffusion_uncond/__init__.py b/my_half_diffusers/pipelines/latent_diffusion_uncond/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0826ca7536c706f9bc1f310c157068efbca7f0b3
--- /dev/null
+++ b/my_half_diffusers/pipelines/latent_diffusion_uncond/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_latent_diffusion_uncond import LDMPipeline
diff --git a/my_half_diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6fe71d34ed267abd314458a60e42cfca4e6fa84f
Binary files /dev/null and b/my_half_diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc b/my_half_diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..91ab565556d010be9b1f3d95134c72cc5ba99752
Binary files /dev/null and b/my_half_diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/my_half_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
new file mode 100644
index 0000000000000000000000000000000000000000..4979d88feee933483ac49c5cf71eef590d8fb34c
--- /dev/null
+++ b/my_half_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -0,0 +1,108 @@
+import inspect
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel, VQModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler
+
+
+class LDMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+            [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
+ """
+
+ def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ Number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
+
+ latents = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ latents = latents.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # predict the noise residual
+ noise_prediction = self.unet(latents, t).sample
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
+
+ # decode the image latents with the VAE
+ image = self.vqvae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
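+# Hedged usage sketch for the unconditional pipeline above. `from_pretrained` is the
+# standard DiffusionPipeline loader; the checkpoint name is an illustrative assumption,
+# not something this repo ships:
+#
+#     pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
+#     images = pipe(batch_size=1, num_inference_steps=50, eta=0.0).images
+#     images[0].save("sample.png")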
diff --git a/my_half_diffusers/pipelines/pndm/__init__.py b/my_half_diffusers/pipelines/pndm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc46aaab9fa26e83b49c26843d854e217742664
--- /dev/null
+++ b/my_half_diffusers/pipelines/pndm/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_pndm import PNDMPipeline
diff --git a/my_half_diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..febaf4e3ac1c7889ff8d4aa2e1040f85b8bec69b
Binary files /dev/null and b/my_half_diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc b/my_half_diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..365a852e5b5809ace74e0ae22b1f47b9affcc394
Binary files /dev/null and b/my_half_diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/pndm/pipeline_pndm.py b/my_half_diffusers/pipelines/pndm/pipeline_pndm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3dff1a9a9416ef7592200c7dbb2ee092bd524d5
--- /dev/null
+++ b/my_half_diffusers/pipelines/pndm/pipeline_pndm.py
@@ -0,0 +1,111 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import PNDMScheduler
+
+
+class PNDMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image.
+ """
+
+ unet: UNet2DModel
+ scheduler: PNDMScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+            batch_size (`int`, *optional*, defaults to 1):
+                The number of images to generate.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ # For more information on the sampling method you can take a look at Algorithm 2 of
+ # the official paper: https://arxiv.org/pdf/2202.09778.pdf
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+ for t in self.progress_bar(self.scheduler.timesteps):
+ model_output = self.unet(image, t).sample
+
+ image = self.scheduler.step(model_output, t, image).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
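+# Hedged usage sketch for the pipeline above (the checkpoint name is an illustrative
+# assumption):
+#
+#     pipe = PNDMPipeline.from_pretrained("google/ddpm-cifar10-32")
+#     image = pipe(num_inference_steps=50).images[0]
+#
+# PNDM typically reaches quality comparable to plain DDPM sampling in far fewer steps,
+# which is the point of Algorithm 2 referenced in the comment above.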
diff --git a/my_half_diffusers/pipelines/score_sde_ve/__init__.py b/my_half_diffusers/pipelines/score_sde_ve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..000d61f6e9b183728cb6fc137e7180cac3a616df
--- /dev/null
+++ b/my_half_diffusers/pipelines/score_sde_ve/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_score_sde_ve import ScoreSdeVePipeline
diff --git a/my_half_diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bcb2835c90eaefb41ed577f47d34e3e8a1b785f
Binary files /dev/null and b/my_half_diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc b/my_half_diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3d66b31cc1a120ca65aef46e49f44909aee72f5
Binary files /dev/null and b/my_half_diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/my_half_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..604e2b54cc1766ff446a23235ae4b40f790eadc5
--- /dev/null
+++ b/my_half_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import ScoreSdeVeScheduler
+
+
+class ScoreSdeVePipeline(DiffusionPipeline):
+ r"""
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Parameters:
+        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+        scheduler ([`SchedulerMixin`]):
+            The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image.
+ """
+ unet: UNet2DModel
+ scheduler: ScoreSdeVeScheduler
+
+    def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 2000,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
+
+ img_size = self.unet.config.sample_size
+ shape = (batch_size, 3, img_size, img_size)
+
+ model = self.unet
+
+ sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
+ sample = sample.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+ self.scheduler.set_sigmas(num_inference_steps)
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
+
+ # correction step
+ for _ in range(self.scheduler.correct_steps):
+ model_output = self.unet(sample, sigma_t).sample
+ sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample
+
+ # prediction step
+ model_output = model(sample, sigma_t).sample
+ output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
+
+ sample, sample_mean = output.prev_sample, output.prev_sample_mean
+
+ sample = sample_mean.clamp(0, 1)
+ sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ sample = self.numpy_to_pil(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return ImagePipelineOutput(images=sample)
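+# Hedged usage sketch for the predictor-corrector loop above (the checkpoint name is an
+# illustrative assumption):
+#
+#     pipe = ScoreSdeVePipeline.from_pretrained("google/ncsnpp-celebahq-256")
+#     image = pipe(num_inference_steps=2000).images[0]
+#
+# Each timestep runs `correct_steps` Langevin-style correction updates followed by one
+# predictor update (`step_pred`), as implemented in the loop above.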
diff --git a/my_half_diffusers/pipelines/stable_diffusion/__init__.py b/my_half_diffusers/pipelines/stable_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ffda93f172142c03298972177b9a74b85867be6
--- /dev/null
+++ b/my_half_diffusers/pipelines/stable_diffusion/__init__.py
@@ -0,0 +1,37 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+
+import PIL
+from PIL import Image
+
+from ...utils import BaseOutput, is_onnx_available, is_transformers_available
+
+
+@dataclass
+class StableDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: List[bool]
+
+
+if is_transformers_available():
+ from .pipeline_stable_diffusion import StableDiffusionPipeline
+ from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
+ from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
+ from .safety_checker import StableDiffusionSafetyChecker
+
+if is_transformers_available() and is_onnx_available():
+ from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
diff --git a/my_half_diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a2a2879830b2f7c7879a9b3bc2d9d18ebe7ffd4
Binary files /dev/null and b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d20a38a56afb3a3bceae914285e2cbbd71350f2
Binary files /dev/null and b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f58cee7c1cd862bdb1e4909f0dd2a65d6a91b285
Binary files /dev/null and b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ef7456c5c312cf39113491b863e7a986e7596a6
Binary files /dev/null and b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_onnx.cpython-38.pyc b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_onnx.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45e2f788ae6219b0bbe73d4664ea80d7bc0ec70d
Binary files /dev/null and b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_onnx.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aef909e3882608dcd1304dd8d3b9cd3b90def13b
Binary files /dev/null and b/my_half_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f02fa114a8e1607136fd1c8247e3cabb763b4415
--- /dev/null
+++ b/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -0,0 +1,279 @@
+import inspect
+import warnings
+from typing import List, Optional, Union
+
+import torch
+
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+class StableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
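+    # Hedged usage sketch for the two methods above (the checkpoint name is the one in
+    # the model card linked from the class docstring):
+    #
+    #     pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+    #     pipe.enable_attention_slicing()       # trade a little speed for lower memory use
+    #     ...
+    #     pipe.disable_attention_slicing()      # back to single-step attention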
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+            # Set device as before (to be removed in 0.3.0)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.to(device)
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_device = "cpu" if self.device.type == "mps" else self.device
+ latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+ if latents is None:
+ latents = torch.randn(
+ latents_shape,
+ generator=generator,
+ device=latents_device,
+ )
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[i]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ # run safety checker
+        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
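+# Hedged end-to-end usage sketch for the pipeline above. The checkpoint is the one
+# referenced by the model card link in the class docstring; everything else follows the
+# `__call__` signature defined here:
+#
+#     pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+#     pipe = pipe.to("cuda")
+#     out = pipe("a photograph of an astronaut riding a horse",
+#                num_inference_steps=50, guidance_scale=7.5)
+#     out.images[0].save("astronaut.png")
+#     print(out.nsfw_content_detected)   # per-image safety-checker flags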
diff --git a/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..475ceef4f002f80842c4b25352a504f6b957db55
--- /dev/null
+++ b/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -0,0 +1,291 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+def preprocess(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
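+# Hedged note on `preprocess`: a `(w, h)` PIL image comes out as a float tensor of shape
+# `[1, 3, h', w']` with `h'` and `w'` rounded down to multiples of 32, and with values
+# rescaled from [0, 255] to [-1, 1], which is the range the VAE encoder expects.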
+
+
+class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image to image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `set_attention_slice`
+        self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ init_image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+ `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ offset = 0
+ if accepts_offset:
+ offset = 1
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+ if not isinstance(init_image, torch.FloatTensor):
+ init_image = preprocess(init_image)
+
+ # encode the init image into latents and scale the latents
+ init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
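+        # 0.18215 is the fixed latent scaling factor used by the Stable Diffusion VAE;
+        # decoding at the end of this method undoes it via `latents = 1 / 0.18215 * latents`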
+
+ # expand init_latents for batch_size
+ init_latents = torch.cat([init_latents] * batch_size)
+
+ # get the original timestep using init_timestep
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
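+        # illustrative arithmetic: with num_inference_steps=50, strength=0.8 and offset=1,
+        # init_timestep = int(50 * 0.8) + 1 = 41 and t_start below becomes 50 - 41 + 1 = 10,
+        # so the last 40 of the 50 scheduled timesteps are actually denoised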
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ timesteps = torch.tensor(
+ [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+ )
+ else:
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
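+            # the batched prediction is split back into (unconditional, text) halves
+            # via `noise_pred.chunk(2)` in the denoising loop below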
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
+ t_index = t_start + i
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+            # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[t_index]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+ latent_model_input = latent_model_input.to(self.unet.dtype)
+ t = t.to(self.unet.dtype)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents.to(self.vae.dtype)).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ # run safety checker
+        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..05ea84ae0326231fa2ffbd4ad936f8747a9fed2c
--- /dev/null
+++ b/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -0,0 +1,309 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, PNDMScheduler
+from ...utils import logging
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__)
+
+
+def preprocess_image(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
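+    # the single-channel mask is tiled across the 4 latent channels so it can gate
+    # each latent channel during the masked denoising loop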
+    mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the transpose itself is an identity here)
+ mask = 1 - mask # repaint white, keep black
+ mask = torch.from_numpy(mask)
+ return mask
+
+
+class StableDiffusionInpaintPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `set_attention_slice`
+        self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ init_image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. This is the image whose masked region will be inpainted.
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
+ converted to a single channel (luminance) before use.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
+ in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ offset = 0
+ if accepts_offset:
+ offset = 1
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+ # preprocess image
+ init_image = preprocess_image(init_image).to(self.device)
+
+ # encode the init image into latents and scale the latents
+ init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+
+ init_latents = 0.18215 * init_latents
+
+ # Expand init_latents for batch_size
+ init_latents = torch.cat([init_latents] * batch_size)
+ init_latents_orig = init_latents
+
+ # preprocess mask
+ mask = preprocess_mask(mask_image).to(self.device)
+ mask = torch.cat([mask] * batch_size)
+
+ # check sizes
+ if not mask.shape == init_latents.shape:
+ raise ValueError("The mask and init_image should be the same size!")
+
+ # get the original timestep using init_timestep
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # masking
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
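+            # keep the re-noised original latents where mask == 1 (black pixels in the
+            # input mask) and the freshly denoised latents where mask == 0 (white pixels)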
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ # run safety checker
+        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ff3ff22fc21014fa7b6c12fba96a2ca36fc9cc4
--- /dev/null
+++ b/my_half_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
@@ -0,0 +1,165 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+
+from transformers import CLIPFeatureExtractor, CLIPTokenizer
+
+from ...onnx_utils import OnnxRuntimeModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
+
+
+class StableDiffusionOnnxPipeline(DiffusionPipeline):
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+ safety_checker: OnnxRuntimeModel
+ feature_extractor: CLIPFeatureExtractor
+
+ def __init__(
+ self,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("np")
+ self.register_modules(
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ latents: Optional[np.ndarray] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ):
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+ )
+ uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+ latents_shape = (batch_size, 4, height // 8, width // 8)
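+        # e.g. for a single 512x512 image (illustrative), latents_shape == (1, 4, 64, 64)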
+ if latents is None:
+ latents = np.random.randn(*latents_shape).astype(np.float32)
+ elif latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[i]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+ )
+ noise_pred = noise_pred[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae_decoder(latent_sample=latents)[0]
+
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+
+ # run safety checker
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
+ image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/my_half_diffusers/pipelines/stable_diffusion/safety_checker.py b/my_half_diffusers/pipelines/stable_diffusion/safety_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..09de92eeb1ec7e64863839012b1eddba444ad80a
--- /dev/null
+++ b/my_half_diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -0,0 +1,106 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def cosine_distance(image_embeds, text_embeds):
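+    # despite its name, this returns the matrix of cosine *similarities* between the
+    # L2-normalized image embeddings and concept embeddings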
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+class StableDiffusionSafetyChecker(PreTrainedModel):
+ config_class = CLIPConfig
+
+ def __init__(self, config: CLIPConfig):
+ super().__init__(config)
+
+ self.vision_model = CLIPVisionModel(config.vision_config)
+ self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+ self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+ self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+ self.register_buffer("concept_embeds_weights", torch.ones(17))
+ self.register_buffer("special_care_embeds_weights", torch.ones(3))
+
+ @torch.no_grad()
+ def forward(self, clip_input, images):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
+
+ result = []
+ batch_size = image_embeds.shape[0]
+ for i in range(batch_size):
+ result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+
+            # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+            for concept_idx in range(len(special_cos_dist[0])):
+                concept_cos = special_cos_dist[i][concept_idx]
+                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["special_scores"][concept_idx] > 0:
+                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
+ adjustment = 0.01
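+                    # once any special-care concept triggers, the remaining concept
+                    # thresholds below are effectively tightened by this adjustment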
+
+            for concept_idx in range(len(cos_dist[0])):
+                concept_cos = cos_dist[i][concept_idx]
+                concept_threshold = self.concept_embeds_weights[concept_idx].item()
+                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["concept_scores"][concept_idx] > 0:
+                    result_img["bad_concepts"].append(concept_idx)
+
+ result.append(result_img)
+
+ has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ if has_nsfw_concept:
+ images[idx] = np.zeros(images[idx].shape) # black image
+
+ if any(has_nsfw_concepts):
+ logger.warning(
+ "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+ " Try again with a different prompt and/or seed."
+ )
+
+ return images, has_nsfw_concepts
+
+ @torch.inference_mode()
+ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+ # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
+ # special_scores = special_scores.round(decimals=3)
+ special_care = torch.any(special_scores > 0, dim=1)
+ special_adjustment = special_care * 0.01
+ special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
+
+ concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+ # concept_scores = concept_scores.round(decimals=3)
+ has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+ images[has_nsfw_concepts] = 0.0 # black image
+
+ return images, has_nsfw_concepts
diff --git a/my_half_diffusers/pipelines/stochastic_karras_ve/__init__.py b/my_half_diffusers/pipelines/stochastic_karras_ve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..db2582043781130794e01b96b3e6beecbfe9f369
--- /dev/null
+++ b/my_half_diffusers/pipelines/stochastic_karras_ve/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_stochastic_karras_ve import KarrasVePipeline
diff --git a/my_half_diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa14fb74353a117bf78b8d019e298df3139dcbb1
Binary files /dev/null and b/my_half_diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc b/my_half_diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c3e4dfedc7edcd5ae3008a305ef494eec0eb0ef
Binary files /dev/null and b/my_half_diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc differ
diff --git a/my_half_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/my_half_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..15266544db7c8bc7448405955d74396eef7fe950
--- /dev/null
+++ b/my_half_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import KarrasVeScheduler
+
+
+class KarrasVePipeline(DiffusionPipeline):
+ r"""
+    Stochastic sampling from Karras et al. [1] tailored to the Variance Exploding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+        https://arxiv.org/abs/2206.00364
+    [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations."
+        https://arxiv.org/abs/2011.13456
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`KarrasVeScheduler`]):
+ Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image.
+ """
+
+ # add type hints for linting
+ unet: UNet2DModel
+ scheduler: KarrasVeScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ img_size = self.unet.config.sample_size
+ shape = (batch_size, 3, img_size, img_size)
+
+ model = self.unet
+
+ # sample x_0 ~ N(0, sigma_0^2 * I)
+ sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+ sample = sample.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # here sigma_t == t_i from the paper
+ sigma = self.scheduler.schedule[t]
+ sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
+
+ # 1. Select temporarily increased noise level sigma_hat
+ # 2. Add new noise to move from sample_i to sample_hat
+ sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)
+
+ # 3. Predict the noise residual given the noise magnitude `sigma_hat`
+ # The model inputs and output are adjusted by following eq. (213) in [1].
+ model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
+
+ # 4. Evaluate dx/dt at sigma_hat
+ # 5. Take Euler step from sigma to sigma_prev
+ step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)
+
+ if sigma_prev != 0:
+ # 6. Apply 2nd order correction
+ # The model inputs and output are adjusted by following eq. (213) in [1].
+ model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
+ step_output = self.scheduler.step_correct(
+ model_output,
+ sigma_hat,
+ sigma_prev,
+ sample_hat,
+ step_output.prev_sample,
+ step_output["derivative"],
+ )
+ sample = step_output.prev_sample
+
+ sample = (sample / 2 + 0.5).clamp(0, 1)
+ image = sample.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/my_half_diffusers/schedulers/__init__.py b/my_half_diffusers/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..20c25f35183faeeef2cd7b5095f80a70a9edac01
--- /dev/null
+++ b/my_half_diffusers/schedulers/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils import is_scipy_available
+from .scheduling_ddim import DDIMScheduler
+from .scheduling_ddpm import DDPMScheduler
+from .scheduling_karras_ve import KarrasVeScheduler
+from .scheduling_pndm import PNDMScheduler
+from .scheduling_sde_ve import ScoreSdeVeScheduler
+from .scheduling_sde_vp import ScoreSdeVpScheduler
+from .scheduling_utils import SchedulerMixin
+
+
+if is_scipy_available():
+ from .scheduling_lms_discrete import LMSDiscreteScheduler
+else:
+ from ..utils.dummy_scipy_objects import * # noqa F403
diff --git a/my_half_diffusers/schedulers/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/schedulers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef687b81333e055df491e729a325b8d6d8c032ef
Binary files /dev/null and b/my_half_diffusers/schedulers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc b/my_half_diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..264ac7d50ab74136fed15308377ee028ecbdfaa1
Binary files /dev/null and b/my_half_diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc differ
diff --git a/my_half_diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc b/my_half_diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..855e2732a0910fff880e1a929c414fa65d529e68
Binary files /dev/null and b/my_half_diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc differ
diff --git a/my_half_diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc b/my_half_diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2c338c9a9372cd29c4afdf0a4fa2222245c1bc5
Binary files /dev/null and b/my_half_diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc differ
diff --git a/my_half_diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc b/my_half_diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bc0bd3e8bb2a1433ae4af37d7d746a1c3fde14a
Binary files /dev/null and b/my_half_diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc differ
diff --git a/my_half_diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc b/my_half_diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95dd92d6ac4830ac48c78d4071899ca6a52bf285
Binary files /dev/null and b/my_half_diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc differ
diff --git a/my_half_diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc b/my_half_diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa0ce4941d8a5a6334cbdc69e9f40749889e4acb
Binary files /dev/null and b/my_half_diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc differ
diff --git a/my_half_diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc b/my_half_diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1edac39159a9478e4828215a17bf07408e7051f6
Binary files /dev/null and b/my_half_diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc differ
diff --git a/my_half_diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc b/my_half_diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..beb23ef10d32cfcf3c5b00a94c937836c76c64b1
Binary files /dev/null and b/my_half_diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc differ
diff --git a/my_half_diffusers/schedulers/scheduling_ddim.py b/my_half_diffusers/schedulers/scheduling_ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccfb0f7e648acc81750a98d317a03de715633588
--- /dev/null
+++ b/my_half_diffusers/schedulers/scheduling_ddim.py
@@ -0,0 +1,270 @@
+# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float64)
+
+
+class DDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+    Denoising Diffusion Implicit Models (DDIM) is a scheduler that extends the denoising procedure introduced in
+    denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional): a pre-computed array of betas to use directly instead of generating them from `beta_schedule`.
+ timestep_values (`np.ndarray`, optional): TODO
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ set_alpha_to_one (`bool`, default `True`):
+ if alpha for final step is 1 or the final alpha of the "non-previous" one.
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ timestep_values: Optional[np.ndarray] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ tensor_format: str = "pt",
+ ):
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float64)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float64) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ # print(self.alphas.shape)
+
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
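+        # i.e. (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_{t-1});
+        # eta * sqrt(variance) gives the sigma_t of Eq. (16) in the DDIM paper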
+
+ return variance
+
+ def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ offset (`int`): TODO
+ """
+ self.num_inference_steps = num_inference_steps
+ if num_inference_steps <= 1000:
+ self.timesteps = np.arange(
+ 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+ )[::-1].copy()
+ else:
+ print("Hitting new logic, allowing fractional timesteps")
+ self.timesteps = np.linspace(
+ 0, self.config.num_train_timesteps-1, self.num_inference_steps, endpoint=True
+ )[::-1].copy()
+ self.timesteps += offset
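+        # illustrative example: num_train_timesteps=1000, num_inference_steps=50 and offset=1
+        # yield timesteps [981, 961, ..., 1], i.e. every 20th training timestep in reverse order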
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`): TODO
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+        # Ideally, read the DDIM paper in detail to fully understand the following steps
+
+        # Notation (<variable name> -> <name in paper>)
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+        # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
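+        # e.g. with 1000 training timesteps and 50 inference steps (illustrative), prev_timestep = timestep - 20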
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+ # 4. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = self.clip(pred_original_sample, -1, 1)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the model_output is always re-derived from the clipped x_0 in Glide
+ model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ device = model_output.device if torch.is_tensor(model_output) else "cpu"
+ noise = torch.randn(model_output.shape, generator=generator).to(device)
+ variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
+
+ if not torch.is_tensor(model_output):
+ variance = variance.numpy()
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> Union[torch.FloatTensor, np.ndarray]:
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
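+
+# Illustrative usage sketch (not part of the original file): a bare-bones DDIM
+# denoising loop built only from the scheduler methods above. The class name
+# `DDIMScheduler` and the `unet(latents, t).sample` call are assumptions based on
+# the upstream diffusers API; substitute whatever network predicts the noise.
+#
+#   scheduler = DDIMScheduler(num_train_timesteps=1000)
+#   scheduler.set_timesteps(50)
+#   latents = torch.randn(1, 4, 64, 64)
+#   for t in scheduler.timesteps:
+#       noise_pred = unet(latents, t).sample          # predicted epsilon
+#       latents = scheduler.step(noise_pred, t, latents, eta=0.0).prev_sample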
diff --git a/my_half_diffusers/schedulers/scheduling_ddpm.py b/my_half_diffusers/schedulers/scheduling_ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fbfb90383361ece4e82aa10a499c8dc58113794
--- /dev/null
+++ b/my_half_diffusers/schedulers/scheduling_ddpm.py
@@ -0,0 +1,264 @@
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float32)
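+
+# Quick numeric check (illustrative): betas_for_alpha_bar(4) evaluates
+# 1 - alpha_bar(t2) / alpha_bar(t1) on the grid t = 0, 0.25, 0.5, 0.75, 1.0,
+# giving a monotonically increasing schedule of roughly
+# [0.153, 0.417, 0.708, 0.999] (the last value hits the max_beta clamp),
+# because the cosine alpha_bar decays fastest near t = 1.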
+
+
+class DDPMScheduler(SchedulerMixin, ConfigMixin):
+ """
+    Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
+ Langevin dynamics sampling.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional): TODO
+ variance_type (`str`):
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ variance_type: str = "fixed_small",
+ clip_sample: bool = True,
+ tensor_format: str = "pt",
+ ):
+
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+ self.one = np.array(1.0)
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ self.variance_type = variance_type
+
+ def set_timesteps(self, num_inference_steps: int):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.arange(
+ 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+ )[::-1].copy()
+ self.set_format(tensor_format=self.tensor_format)
+
+ def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # and sample from it to get previous sample
+ # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+
+ if variance_type is None:
+ variance_type = self.config.variance_type
+
+        # hacks - were probably added for training stability
+ if variance_type == "fixed_small":
+ variance = self.clip(variance, min_value=1e-20)
+ # for rl-diffuser https://arxiv.org/abs/2205.09991
+ elif variance_type == "fixed_small_log":
+ variance = self.log(self.clip(variance, min_value=1e-20))
+ elif variance_type == "fixed_large":
+ variance = self.betas[t]
+ elif variance_type == "fixed_large_log":
+ # Glide max_log
+ variance = self.log(self.betas[t])
+ elif variance_type == "learned":
+ return predicted_variance
+ elif variance_type == "learned_range":
+ min_log = variance
+ max_log = self.betas[t]
+ frac = (predicted_variance + 1) / 2
+ variance = frac * max_log + (1 - frac) * min_log
+
+ return variance
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ predict_epsilon=True,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ predict_epsilon (`bool`):
+ optional flag to use when model predicts the samples directly instead of the noise, epsilon.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ t = timestep
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if predict_epsilon:
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ else:
+ pred_original_sample = model_output
+
+ # 3. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = self.clip(pred_original_sample, -1, 1)
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
+ current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ variance = 0
+ if t > 0:
+ noise = self.randn_like(model_output, generator=generator)
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ if not return_dict:
+ return (pred_prev_sample,)
+
+ return SchedulerOutput(prev_sample=pred_prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> Union[torch.FloatTensor, np.ndarray]:
+
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
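+
+# Illustrative usage sketch (not part of the original file): ancestral sampling
+# with the DDPMScheduler defined above. `unet(latents, t).sample` is an assumed
+# stand-in for whatever network predicts the noise epsilon.
+#
+#   scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")
+#   scheduler.set_timesteps(1000)
+#   latents = torch.randn(1, 3, 256, 256)
+#   for t in scheduler.timesteps:
+#       noise_pred = unet(latents, t).sample
+#       latents = scheduler.step(noise_pred, t, latents).prev_sample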
diff --git a/my_half_diffusers/schedulers/scheduling_karras_ve.py b/my_half_diffusers/schedulers/scheduling_karras_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a2370cfc3e0523dfba48703bcd0c3e9a42b2381
--- /dev/null
+++ b/my_half_diffusers/schedulers/scheduling_karras_ve.py
@@ -0,0 +1,208 @@
+# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class KarrasVeOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Derivative of the predicted original image sample (x_0).
+ """
+
+ prev_sample: torch.FloatTensor
+ derivative: torch.FloatTensor
+
+
+class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
+ """
+    Stochastic sampling from Karras et al. [1] tailored to Variance-Exploding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+ Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+ optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+
+ Args:
+ sigma_min (`float`): minimum noise magnitude
+ sigma_max (`float`): maximum noise magnitude
+ s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+ A reasonable range is [1.000, 1.011].
+ s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+ A reasonable range is [0, 100].
+ s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+ A reasonable range is [0, 10].
+ s_max (`float`): the end value of the sigma range where we add noise.
+ A reasonable range is [0.2, 80].
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ sigma_min: float = 0.02,
+ sigma_max: float = 100,
+ s_noise: float = 1.007,
+ s_churn: float = 80,
+ s_min: float = 0.05,
+ s_max: float = 50,
+ tensor_format: str = "pt",
+ ):
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = None
+ self.schedule = None # sigma(t_i)
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def set_timesteps(self, num_inference_steps: int):
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+
+ """
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+ self.schedule = [
+ (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1)))
+ for i in self.timesteps
+ ]
+ self.schedule = np.array(self.schedule, dtype=np.float32)
+
+ self.set_format(tensor_format=self.tensor_format)
+
+ def add_noise_to_input(
+ self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
+ ) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
+ """
+ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+ higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+
+        Args:
+            sample (`torch.FloatTensor` or `np.ndarray`): the sample to which noise is added.
+            sigma (`float`): the current noise level sigma_i.
+            generator (`torch.Generator`, optional): random number generator used to sample the added noise.
+ """
+ if self.s_min <= sigma <= self.s_max:
+ gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
+ else:
+ gamma = 0
+
+ # sample eps ~ N(0, S_noise^2 * I)
+ eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
+ sigma_hat = sigma + gamma * sigma
+ sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
+
+ return sample_hat, sigma_hat
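+
+    # Worked example (illustrative): with the defaults above (s_churn=80,
+    # s_min=0.05, s_max=50) and num_inference_steps=50, a sigma of 10 lies inside
+    # [s_min, s_max], so gamma = min(80 / 50, 2**0.5 - 1) ≈ 0.414 and
+    # sigma_hat = 10 * (1 + 0.414) ≈ 14.14; outside that range gamma is 0 and the
+    # sample is returned unchanged.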
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            sigma_hat (`float`): noise level after the stochastic "churn" step (see `add_noise_to_input`).
+            sigma_prev (`float`): noise level of the next, lower step in the schedule.
+            sample_hat (`torch.FloatTensor` or `np.ndarray`): sample after noise was added via `add_noise_to_input`.
+            return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
+
+ Returns:
+ [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
+ [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+
+ pred_original_sample = sample_hat + sigma_hat * model_output
+ derivative = (sample_hat - pred_original_sample) / sigma_hat
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
+
+ def step_correct(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: Union[torch.FloatTensor, np.ndarray],
+ sample_prev: Union[torch.FloatTensor, np.ndarray],
+ derivative: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+        Correct the predicted sample based on the `model_output` of the network. This applies a second-order (Heun-like) correction to the prediction made in `step`.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            sigma_hat (`float`): noise level after the stochastic "churn" step of the current iteration.
+            sigma_prev (`float`): noise level of the next, lower step in the schedule.
+            sample_hat (`torch.FloatTensor` or `np.ndarray`): sample after noise was added via `add_noise_to_input`.
+            sample_prev (`torch.FloatTensor` or `np.ndarray`): sample produced by the preceding `step` call.
+            derivative (`torch.FloatTensor` or `np.ndarray`): derivative returned by the preceding `step` call.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+            [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
+            [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+
+ """
+ pred_original_sample = sample_prev + sigma_prev * model_output
+ derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
+
+ def add_noise(self, original_samples, noise, timesteps):
+ raise NotImplementedError()
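+
+# Illustrative usage sketch (not part of the original file): one pass of
+# Algorithm 2 from Karras et al. using only the methods above.
+# `model(sample_hat, sigma_hat).sample` is an assumed stand-in for the denoising
+# network, including any model-specific input scaling.
+#
+#   scheduler = KarrasVeScheduler()
+#   scheduler.set_timesteps(50)
+#   sample = torch.randn(1, 3, 256, 256) * scheduler.config.sigma_max
+#   for t in scheduler.timesteps:
+#       sigma = scheduler.schedule[t]
+#       sigma_prev = scheduler.schedule[t - 1] if t > 0 else 0
+#       sample_hat, sigma_hat = scheduler.add_noise_to_input(sample, sigma)
+#       model_output = model(sample_hat, sigma_hat).sample
+#       output = scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)
+#       sample = output.prev_sample   # optionally refine with step_correct(...)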
diff --git a/my_half_diffusers/schedulers/scheduling_lms_discrete.py b/my_half_diffusers/schedulers/scheduling_lms_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..1381587febf16d9c774b5f2574653c962e031a46
--- /dev/null
+++ b/my_half_diffusers/schedulers/scheduling_lms_discrete.py
@@ -0,0 +1,193 @@
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from scipy import integrate
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
+ Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+        trained_betas (`np.ndarray`, optional): TODO
+        timestep_values (`np.ndarray`, optional): TODO
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ timestep_values: Optional[np.ndarray] = None,
+ tensor_format: str = "pt",
+ ):
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self.derivatives = []
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def get_lms_coefficient(self, order, t, current_order):
+ """
+ Compute a linear multistep coefficient.
+
+ Args:
+ order (TODO):
+ t (TODO):
+ current_order (TODO):
+ """
+
+ def lms_derivative(tau):
+ prod = 1.0
+ for k in range(order):
+ if current_order == k:
+ continue
+ prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
+ return prod
+
+ integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]
+
+ return integrated_coeff
+
+ def set_timesteps(self, num_inference_steps: int):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ self.num_inference_steps = num_inference_steps
+        self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
+
+ low_idx = np.floor(self.timesteps).astype(int)
+ high_idx = np.ceil(self.timesteps).astype(int)
+ frac = np.mod(self.timesteps, 1.0)
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+ self.sigmas = np.concatenate([sigmas, [0.0]])
+
+ self.derivatives = []
+
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ order: int = 4,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+            order (`int`): the order of the linear multi-step method, i.e. how many stored derivatives are combined.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ sigma = self.sigmas[timestep]
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ pred_original_sample = sample - sigma * model_output
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma
+ self.derivatives.append(derivative)
+ if len(self.derivatives) > order:
+ self.derivatives.pop(0)
+
+ # 3. Compute linear multistep coefficients
+ order = min(timestep + 1, order)
+ lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+
+ # 4. Compute previous sample based on the derivatives path
+ prev_sample = sample + sum(
+ coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
+ )
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> Union[torch.FloatTensor, np.ndarray]:
+ sigmas = self.match_shape(self.sigmas[timesteps], noise)
+ noisy_samples = original_samples + noise * sigmas
+
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
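+
+# Illustrative usage sketch (not part of the original file). As the sigma
+# indexing in `step` suggests, `timestep` is the position in the inference
+# schedule rather than a raw training timestep. `unet(...)` is an assumed
+# stand-in for the noise-prediction network (some pipelines additionally scale
+# its input by 1 / (sigma**2 + 1) ** 0.5).
+#
+#   scheduler = LMSDiscreteScheduler()
+#   scheduler.set_timesteps(50)
+#   sample = torch.randn(1, 4, 64, 64) * scheduler.sigmas[0]
+#   for i, t in enumerate(scheduler.timesteps):
+#       noise_pred = unet(sample, t).sample
+#       sample = scheduler.step(noise_pred, i, sample).prev_sample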
diff --git a/my_half_diffusers/schedulers/scheduling_pndm.py b/my_half_diffusers/schedulers/scheduling_pndm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b43d88bbab7745e3e8579cc66f2ee2ed246e52d7
--- /dev/null
+++ b/my_half_diffusers/schedulers/scheduling_pndm.py
@@ -0,0 +1,378 @@
+# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float32)
+
+
+class PNDMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+    namely the Runge-Kutta method and a linear multi-step method.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional): TODO
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
+ skip_prk_steps (`bool`):
+ allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+            before PLMS steps; defaults to `False`.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ tensor_format: str = "pt",
+ skip_prk_steps: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ self.one = np.array(1.0)
+
+ # For now we only support F-PNDM, i.e. the runge-kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formula (9), (12), (13) and the Algorithm 2.
+ self.pndm_order = 4
+
+ # running values
+ self.cur_model_output = 0
+ self.counter = 0
+ self.cur_sample = None
+ self.ets = []
+
+ # setable values
+ self.num_inference_steps = None
+ self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self._offset = 0
+ self.prk_timesteps = None
+ self.plms_timesteps = None
+ self.timesteps = None
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+            offset (`int`): optional value added to each timestep in the inference schedule.
+ """
+ self.num_inference_steps = num_inference_steps
+ self._timesteps = list(
+ range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
+ )
+ self._offset = offset
+ self._timesteps = np.array([t + self._offset for t in self._timesteps])
+
+ if self.config.skip_prk_steps:
+ # for some models like stable diffusion the prk steps can/should be skipped to
+ # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
+ # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+ self.prk_timesteps = np.array([])
+ self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
+ ::-1
+ ].copy()
+ else:
+ prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
+ np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
+ )
+ self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
+ self.plms_timesteps = self._timesteps[:-3][
+ ::-1
+ ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
+
+ self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+
+ self.ets = []
+ self.counter = 0
+ self.set_format(tensor_format=self.tensor_format)
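+
+    # Worked example (illustrative): with num_train_timesteps=1000,
+    # skip_prk_steps=True and set_timesteps(50), _timesteps is [0, 20, ..., 980],
+    # prk_timesteps is empty and plms_timesteps becomes
+    # [980, 960, 960, 940, ..., 20, 0]; the duplicated 960 gives the PLMS warm-up
+    # (the `self.counter == 1` branch in step_plms) a second pass over the first
+    # interval before the full multistep formula takes over.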
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
+ return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+ else:
+ return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+
+ def step_prk(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+ solution to the differential equation.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
+ prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
+ timestep = self.prk_timesteps[self.counter // 4 * 4]
+
+ if self.counter % 4 == 0:
+ self.cur_model_output += 1 / 6 * model_output
+ self.ets.append(model_output)
+ self.cur_sample = sample
+ elif (self.counter - 1) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 2) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 3) % 4 == 0:
+ model_output = self.cur_model_output + 1 / 6 * model_output
+ self.cur_model_output = 0
+
+ # cur_sample should not be `None`
+ cur_sample = self.cur_sample if self.cur_sample is not None else sample
+
+ prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def step_plms(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+        Step function propagating the sample with the linear multi-step method. Because previously computed model
+        outputs are stored and reused, only one forward pass is needed per step to approximate the solution.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if not self.config.skip_prk_steps and len(self.ets) < 3:
+ raise ValueError(
+ f"{self.__class__} can only be run AFTER scheduler has been run "
+ "in 'prk' mode for at least 12 iterations "
+ "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
+ "for more information."
+ )
+
+ prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
+
+ if self.counter != 1:
+ self.ets.append(model_output)
+ else:
+ prev_timestep = timestep
+ timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
+
+ if len(self.ets) == 1 and self.counter == 0:
+ model_output = model_output
+ self.cur_sample = sample
+ elif len(self.ets) == 1 and self.counter == 1:
+ model_output = (model_output + self.ets[-1]) / 2
+ sample = self.cur_sample
+ self.cur_sample = None
+ elif len(self.ets) == 2:
+ model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
+ elif len(self.ets) == 3:
+ model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+ else:
+ model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+ prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
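+
+    # Sanity check (illustrative): the multistep weights above are the standard
+    # Adams-Bashforth coefficients, so for a constant noise prediction e every
+    # branch reduces to e, e.g. (55 - 59 + 37 - 9) / 24 = 1 and (23 - 16 + 5) / 12 = 1.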
+
+ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
+ # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+ # this function computes x_(t−δ) using the formula of (9)
+ # Note that x_t needs to be added to both sides of the equation
+
+        # Notation (<variable name> -> <name in paper>)
+ # alpha_prod_t -> α_t
+ # alpha_prod_t_prev -> α_(t−δ)
+ # beta_prod_t -> (1 - α_t)
+ # beta_prod_t_prev -> (1 - α_(t−δ))
+ # sample -> x_t
+ # model_output -> e_θ(x_t, t)
+ # prev_sample -> x_(t−δ)
+ alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
+ alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # corresponds to (α_(t−δ) - α_t) divided by
+ # denominator of x_t in formula (9) and plus 1
+        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
+        # sqrt(α_(t−δ)) / sqrt(α_t)
+ sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+ # corresponds to denominator of e_θ(x_t, t) in formula (9)
+ model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+ alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+ ) ** (0.5)
+
+ # full formula (9)
+ prev_sample = (
+ sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+ )
+
+ return prev_sample
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> torch.Tensor:
+ # mps requires indices to be in the same device, so we use cpu as is the default with cuda
+ timesteps = timesteps.to(self.alphas_cumprod.device)
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/my_half_diffusers/schedulers/scheduling_sde_ve.py b/my_half_diffusers/schedulers/scheduling_sde_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..e187f079688723c991b4b80fa1fd4f358896bb4f
--- /dev/null
+++ b/my_half_diffusers/schedulers/scheduling_sde_ve.py
@@ -0,0 +1,283 @@
+# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+@dataclass
+class SdeVeOutput(BaseOutput):
+ """
+ Output class for the ScoreSdeVeScheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
+ """
+
+ prev_sample: torch.FloatTensor
+ prev_sample_mean: torch.FloatTensor
+
+
+class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The variance exploding stochastic differential equation (SDE) scheduler.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ Args:
+ snr (`float`):
+ coefficient weighting the step from the model_output sample (from the network) to the random noise.
+ sigma_min (`float`):
+ initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+ distribution of the data.
+ sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
+        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
+ epsilon.
+ correct_steps (`int`): number of correction steps performed on a produced sample.
+ tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 2000,
+ snr: float = 0.15,
+ sigma_min: float = 0.01,
+ sigma_max: float = 1348.0,
+ sampling_eps: float = 1e-5,
+ correct_steps: int = 1,
+ tensor_format: str = "pt",
+ ):
+ # setable values
+ self.timesteps = None
+
+ self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
+ elif tensor_format == "pt":
+ self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
+ else:
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def set_sigmas(
+ self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
+ ):
+ """
+ Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
+
+ The sigmas control the weight of the `drift` and `diffusion` components of sample update.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sigma_min (`float`, optional):
+ initial noise scale value (overrides value given at Scheduler instantiation).
+ sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
+ sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
+ sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+ if self.timesteps is None:
+ self.set_timesteps(num_inference_steps, sampling_eps)
+
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
+ self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
+ elif tensor_format == "pt":
+ self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
+ self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
+ else:
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def get_adjacent_sigma(self, timesteps, t):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
+ elif tensor_format == "pt":
+ return torch.where(
+ timesteps == 0,
+ torch.zeros_like(t.to(timesteps.device)),
+ self.discrete_sigmas[timesteps - 1].to(timesteps.device),
+ )
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def set_seed(self, seed):
+ warnings.warn(
+ "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
+ " generator instead.",
+ DeprecationWarning,
+ )
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ np.random.seed(seed)
+ elif tensor_format == "pt":
+ torch.manual_seed(seed)
+ else:
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def step_pred(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[SdeVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if "seed" in kwargs and kwargs["seed"] is not None:
+ self.set_seed(kwargs["seed"])
+
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ timestep = timestep * torch.ones(
+ sample.shape[0], device=sample.device
+ ) # torch.repeat_interleave(timestep, sample.shape[0])
+ timesteps = (timestep * (len(self.timesteps) - 1)).long()
+
+ # mps requires indices to be in the same device, so we use cpu as is the default with cuda
+ timesteps = timesteps.to(self.discrete_sigmas.device)
+
+ sigma = self.discrete_sigmas[timesteps].to(sample.device)
+ adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
+ drift = self.zeros_like(sample)
+ diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
+
+ # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
+ # also equation 47 shows the analog from SDE models to ancestral sampling methods
+ drift = drift - diffusion[:, None, None, None] ** 2 * model_output
+
+        # equation 6: sample noise for the diffusion term of the SDE
+ noise = self.randn_like(sample, generator=generator)
+ prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
+ # TODO is the variable diffusion the correct scaling term for the noise?
+ prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
+
+ if not return_dict:
+ return (prev_sample, prev_sample_mean)
+
+ return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
+
+ def step_correct(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ sample: Union[torch.FloatTensor, np.ndarray],
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
+ after making the prediction for the previous timestep.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if "seed" in kwargs and kwargs["seed"] is not None:
+ self.set_seed(kwargs["seed"])
+
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
+ # sample noise for correction
+ noise = self.randn_like(sample, generator=generator)
+
+ # compute step size from the model_output, the noise, and the snr
+ grad_norm = self.norm(model_output)
+ noise_norm = self.norm(noise)
+ step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
+ step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
+ # self.repeat_scalar(step_size, sample.shape[0])
+
+ # compute corrected sample: model_output term and noise term
+ prev_sample_mean = sample + step_size[:, None, None, None] * model_output
+ prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
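+
+# Illustrative usage sketch (not part of the original file): predictor-corrector
+# sampling with the scheduler above. `model(sample, sigma_t).sample` is an
+# assumed stand-in for the score network.
+#
+#   scheduler = ScoreSdeVeScheduler()
+#   scheduler.set_timesteps(1000)
+#   scheduler.set_sigmas(1000)
+#   sample = torch.randn(1, 3, 256, 256) * scheduler.config.sigma_max
+#   for i, t in enumerate(scheduler.timesteps):
+#       sigma_t = scheduler.sigmas[i] * torch.ones(sample.shape[0])
+#       for _ in range(scheduler.config.correct_steps):
+#           score = model(sample, sigma_t).sample
+#           sample = scheduler.step_correct(score, sample).prev_sample
+#       score = model(sample, sigma_t).sample
+#       sample = scheduler.step_pred(score, t, sample).prev_sample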
diff --git a/my_half_diffusers/schedulers/scheduling_sde_vp.py b/my_half_diffusers/schedulers/scheduling_sde_vp.py
new file mode 100644
index 0000000000000000000000000000000000000000..66e6ec6616ab01e5ae988b21e9599a0422a9714a
--- /dev/null
+++ b/my_half_diffusers/schedulers/scheduling_sde_vp.py
@@ -0,0 +1,81 @@
+# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean up a bit
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin
+
+
+class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The variance preserving stochastic differential equation (SDE) scheduler.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ UNDER CONSTRUCTION
+
+ """
+
+ @register_to_config
+ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
+
+ self.sigmas = None
+ self.discrete_sigmas = None
+ self.timesteps = None
+
+ def set_timesteps(self, num_inference_steps):
+ self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+
+ def step_pred(self, score, x, t):
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # TODO(Patrick) better comments + non-PyTorch
+ # postprocess model score
+ log_mean_coeff = (
+ -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
+ )
+ std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
+ score = -score / std[:, None, None, None]
+
+        # compute the (negative) step size of the reverse-time integration, from t=1 down to sampling_eps
+ dt = -1.0 / len(self.timesteps)
+
+ beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
+ drift = -0.5 * beta_t[:, None, None, None] * x
+ diffusion = torch.sqrt(beta_t)
+ drift = drift - diffusion[:, None, None, None] ** 2 * score
+ x_mean = x + drift * dt
+
+ # add noise
+ noise = torch.randn_like(x)
+ x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise
+
+ return x, x_mean
+
+ def __len__(self):
+ return self.config.num_train_timesteps
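+
+
+# NOTE: a minimal usage sketch (hedged; `score_model` and `x` are hypothetical stand-ins
+# for a trained score network and an initial noise sample). Only `set_timesteps` and
+# `step_pred` from this file are assumed:
+#
+#   scheduler = ScoreSdeVpScheduler()
+#   scheduler.set_timesteps(num_inference_steps=1000)
+#   for t in scheduler.timesteps:
+#       t_batch = t * torch.ones(x.shape[0], device=x.device)
+#       score = score_model(x, t_batch)
+#       x, x_mean = scheduler.step_pred(score, x, t_batch)
+#   # `x_mean` (the update without the final noise injection) is what one would keep last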
diff --git a/my_half_diffusers/schedulers/scheduling_utils.py b/my_half_diffusers/schedulers/scheduling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2bcd73acf32c1e152a5d8708479731996731c6d
--- /dev/null
+++ b/my_half_diffusers/schedulers/scheduling_utils.py
@@ -0,0 +1,125 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+import torch
+
+from ..utils import BaseOutput
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+@dataclass
+class SchedulerOutput(BaseOutput):
+ """
+ Base class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class SchedulerMixin:
+ """
+ Mixin containing common functions for the schedulers.
+ """
+
+ config_name = SCHEDULER_CONFIG_NAME
+ ignore_for_config = ["tensor_format"]
+
+ def set_format(self, tensor_format="pt"):
+ self.tensor_format = tensor_format
+ if tensor_format == "pt":
+ for key, value in vars(self).items():
+ if isinstance(value, np.ndarray):
+ setattr(self, key, torch.from_numpy(value))
+
+ return self
+
+ def clip(self, tensor, min_value=None, max_value=None):
+ tensor_format = getattr(self, "tensor_format", "pt")
+
+ if tensor_format == "np":
+ return np.clip(tensor, min_value, max_value)
+ elif tensor_format == "pt":
+ return torch.clamp(tensor, min_value, max_value)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def log(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+
+ if tensor_format == "np":
+ return np.log(tensor)
+ elif tensor_format == "pt":
+ return torch.log(tensor)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
+ """
+ Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
+
+ Args:
+ values: an array or tensor of values to extract.
+ broadcast_array: an array with a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ Returns:
+ a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+
+ tensor_format = getattr(self, "tensor_format", "pt")
+ values = values.flatten()
+
+ while len(values.shape) < len(broadcast_array.shape):
+ values = values[..., None]
+ if tensor_format == "pt":
+ values = values.to(broadcast_array.device)
+
+ return values
+
+ def norm(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.linalg.norm(tensor)
+ elif tensor_format == "pt":
+ return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def randn_like(self, tensor, generator=None):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.random.randn(*np.shape(tensor))
+ elif tensor_format == "pt":
+ # return torch.randn_like(tensor)
+ return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def zeros_like(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.zeros_like(tensor)
+ elif tensor_format == "pt":
+ return torch.zeros_like(tensor)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
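+
+
+# NOTE: a small illustration of `match_shape` (hedged; `scheduler`, `alphas_t` and
+# `sample` are hypothetical). A 1-D per-sample array is reshaped so it broadcasts
+# against a batched sample tensor:
+#
+#   coeff = scheduler.match_shape(alphas_t, sample)   # shape [B] -> [B, 1, 1, 1] for 4-D samples
+#   scaled = coeff * sample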
diff --git a/my_half_diffusers/testing_utils.py b/my_half_diffusers/testing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff8b6aa9b41c45b0ab77f343904bffc53fa9e9cb
--- /dev/null
+++ b/my_half_diffusers/testing_utils.py
@@ -0,0 +1,61 @@
+import os
+import random
+import unittest
+from distutils.util import strtobool
+
+import torch
+
+from packaging import version
+
+
+global_rng = random.Random()
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
+
+if is_torch_higher_equal_than_1_12:
+ torch_device = "mps" if torch.backends.mps.is_available() else torch_device
+
+
+def parse_flag_from_env(key, default=False):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ # KEY isn't set, default to `default`.
+ _value = default
+ else:
+ # KEY is set, convert it to True or False.
+ try:
+ _value = strtobool(value)
+ except ValueError:
+ # More values are supported, but let's keep the message simple.
+ raise ValueError(f"If set, {key} must be yes or no.")
+ return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+
+
+def floats_tensor(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ total_dims = 1
+ for dim in shape:
+ total_dims *= dim
+
+ values = []
+ for _ in range(total_dims):
+ values.append(rng.random() * scale)
+
+ return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
+
+
+def slow(test_case):
+ """
+ Decorator marking a test as slow.
+
+ Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
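+
+
+# NOTE: a usage sketch (hedged; `ExampleTests` is a hypothetical test case) combining
+# `floats_tensor`, `torch_device` and the `@slow` decorator defined above:
+#
+#   class ExampleTests(unittest.TestCase):
+#       @slow
+#       def test_full_loop(self):
+#           sample = floats_tensor((1, 3, 32, 32)).to(torch_device)
+#           self.assertEqual(sample.shape, (1, 3, 32, 32))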
diff --git a/my_half_diffusers/training_utils.py b/my_half_diffusers/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa1694161fc54c7fd097abf3bcbf44c498daad4b
--- /dev/null
+++ b/my_half_diffusers/training_utils.py
@@ -0,0 +1,125 @@
+import copy
+import os
+import random
+
+import numpy as np
+import torch
+
+
+def enable_full_determinism(seed: int):
+ """
+    Helper function for reproducible behavior during distributed training. See
+    https://pytorch.org/docs/stable/notes/randomness.html for the PyTorch background.
+ """
+ # set seed first
+ set_seed(seed)
+
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def set_seed(seed: int):
+ """
+    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
+
+    Args:
+        seed (`int`): The seed to set.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # ^^ safe to call this function even if cuda is not available
+
+
+class EMAModel:
+ """
+ Exponential Moving Average of models weights
+ """
+
+ def __init__(
+ self,
+ model,
+ update_after_step=0,
+ inv_gamma=1.0,
+ power=2 / 3,
+ min_value=0.0,
+ max_value=0.9999,
+ device=None,
+ ):
+ """
+ @crowsonkb's notes on EMA Warmup:
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+ at 215.4k steps).
+ Args:
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+ power (float): Exponential factor of EMA warmup. Default: 2/3.
+ min_value (float): The minimum EMA decay rate. Default: 0.
+ """
+
+ self.averaged_model = copy.deepcopy(model).eval()
+ self.averaged_model.requires_grad_(False)
+
+ self.update_after_step = update_after_step
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.min_value = min_value
+ self.max_value = max_value
+
+ if device is not None:
+ self.averaged_model = self.averaged_model.to(device=device)
+
+ self.decay = 0.0
+ self.optimization_step = 0
+
+ def get_decay(self, optimization_step):
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ step = max(0, optimization_step - self.update_after_step - 1)
+ value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+ if step <= 0:
+ return 0.0
+
+ return max(self.min_value, min(value, self.max_value))
+
+ @torch.no_grad()
+ def step(self, new_model):
+ ema_state_dict = {}
+ ema_params = self.averaged_model.state_dict()
+
+ self.decay = self.get_decay(self.optimization_step)
+
+ for key, param in new_model.named_parameters():
+ if isinstance(param, dict):
+ continue
+ try:
+ ema_param = ema_params[key]
+ except KeyError:
+ ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
+ ema_params[key] = ema_param
+
+ if not param.requires_grad:
+ ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
+ ema_param = ema_params[key]
+ else:
+ ema_param.mul_(self.decay)
+ ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
+
+ ema_state_dict[key] = ema_param
+
+ for key, param in new_model.named_buffers():
+ ema_state_dict[key] = param
+
+ self.averaged_model.load_state_dict(ema_state_dict, strict=False)
+ self.optimization_step += 1
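+
+
+# NOTE: a minimal training-loop sketch (hedged; `model`, `optimizer` and `dataloader`
+# are hypothetical). Only `set_seed` and `EMAModel` from this file are assumed. With
+# the default inv_gamma=1.0 and power=2/3, `get_decay` reaches roughly 0.999 after
+# about 31.6k steps, matching the note in the class docstring.
+#
+#   set_seed(0)
+#   ema = EMAModel(model)
+#   for batch in dataloader:
+#       loss = model(batch).mean()
+#       loss.backward()
+#       optimizer.step()
+#       optimizer.zero_grad()
+#       ema.step(model)             # update the averaged copy after each optimizer step
+#   # evaluate / checkpoint with `ema.averaged_model`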
diff --git a/my_half_diffusers/utils/__init__.py b/my_half_diffusers/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c00a28e1058fbd47451bfe48e23865876c08ed69
--- /dev/null
+++ b/my_half_diffusers/utils/__init__.py
@@ -0,0 +1,53 @@
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+from .import_utils import (
+ ENV_VARS_TRUE_AND_AUTO_VALUES,
+ ENV_VARS_TRUE_VALUES,
+ USE_JAX,
+ USE_TF,
+ USE_TORCH,
+ DummyObject,
+ is_flax_available,
+ is_inflect_available,
+ is_modelcards_available,
+ is_onnx_available,
+ is_scipy_available,
+ is_tf_available,
+ is_torch_available,
+ is_transformers_available,
+ is_unidecode_available,
+ requires_backends,
+)
+from .logging import get_logger
+from .outputs import BaseOutput
+
+
+logger = get_logger(__name__)
+
+
+hf_cache_home = os.path.expanduser(
+ os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
+)
+default_cache_path = os.path.join(hf_cache_home, "diffusers")
+
+
+CONFIG_NAME = "config.json"
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
+DIFFUSERS_CACHE = default_cache_path
+DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
diff --git a/my_half_diffusers/utils/__pycache__/__init__.cpython-38.pyc b/my_half_diffusers/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e7d6f1f1231bf974295a7d453bd8f541f11f002
Binary files /dev/null and b/my_half_diffusers/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_half_diffusers/utils/__pycache__/import_utils.cpython-38.pyc b/my_half_diffusers/utils/__pycache__/import_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c86d9ddce0b25b141284538b60f408cf9e8ab1b
Binary files /dev/null and b/my_half_diffusers/utils/__pycache__/import_utils.cpython-38.pyc differ
diff --git a/my_half_diffusers/utils/__pycache__/logging.cpython-38.pyc b/my_half_diffusers/utils/__pycache__/logging.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2632089b9f2e828831d39e16889228b0de1929a
Binary files /dev/null and b/my_half_diffusers/utils/__pycache__/logging.cpython-38.pyc differ
diff --git a/my_half_diffusers/utils/__pycache__/outputs.cpython-38.pyc b/my_half_diffusers/utils/__pycache__/outputs.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c754eed2fb379698e0190c8973069429ceb12a82
Binary files /dev/null and b/my_half_diffusers/utils/__pycache__/outputs.cpython-38.pyc differ
diff --git a/my_half_diffusers/utils/dummy_scipy_objects.py b/my_half_diffusers/utils/dummy_scipy_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..3706c57541c1b7d9004957422b52cd1e2191ae68
--- /dev/null
+++ b/my_half_diffusers/utils/dummy_scipy_objects.py
@@ -0,0 +1,11 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class LMSDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["scipy"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["scipy"])
diff --git a/my_half_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py b/my_half_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c2aec218c40190bd2d078bfb36fc34fd4ef16c2
--- /dev/null
+++ b/my_half_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
@@ -0,0 +1,10 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+from ..utils import DummyObject, requires_backends
+
+
+class GradTTSPipeline(metaclass=DummyObject):
+ _backends = ["transformers", "inflect", "unidecode"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers", "inflect", "unidecode"])
diff --git a/my_half_diffusers/utils/dummy_transformers_and_onnx_objects.py b/my_half_diffusers/utils/dummy_transformers_and_onnx_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e34b5ce0b69472df7e2c41de40476619d53dee9
--- /dev/null
+++ b/my_half_diffusers/utils/dummy_transformers_and_onnx_objects.py
@@ -0,0 +1,11 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class StableDiffusionOnnxPipeline(metaclass=DummyObject):
+ _backends = ["transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers", "onnx"])
diff --git a/my_half_diffusers/utils/dummy_transformers_objects.py b/my_half_diffusers/utils/dummy_transformers_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..e05eb814d17b3a49eb550a89dfd13ee24fdda134
--- /dev/null
+++ b/my_half_diffusers/utils/dummy_transformers_objects.py
@@ -0,0 +1,32 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class LDMTextToImagePipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
+
+
+class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
+
+
+class StableDiffusionInpaintPipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
+
+
+class StableDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
diff --git a/my_half_diffusers/utils/import_utils.py b/my_half_diffusers/utils/import_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f5e95ada51da97ac67e1dc62538b6eed8784bce
--- /dev/null
+++ b/my_half_diffusers/utils/import_utils.py
@@ -0,0 +1,274 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Import utilities: Utilities related to imports and our lazy inits.
+"""
+import importlib.util
+import os
+import sys
+from collections import OrderedDict
+
+from packaging import version
+
+from . import logging
+
+
+# The package importlib_metadata is in a different place, depending on the python version.
+if sys.version_info < (3, 8):
+ import importlib_metadata
+else:
+ import importlib.metadata as importlib_metadata
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
+
+USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+
+_torch_version = "N/A"
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+ _torch_available = importlib.util.find_spec("torch") is not None
+ if _torch_available:
+ try:
+ _torch_version = importlib_metadata.version("torch")
+ logger.info(f"PyTorch version {_torch_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _torch_available = False
+else:
+ logger.info("Disabling PyTorch because USE_TF is set")
+ _torch_available = False
+
+
+_tf_version = "N/A"
+if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+ _tf_available = importlib.util.find_spec("tensorflow") is not None
+ if _tf_available:
+ candidates = (
+ "tensorflow",
+ "tensorflow-cpu",
+ "tensorflow-gpu",
+ "tf-nightly",
+ "tf-nightly-cpu",
+ "tf-nightly-gpu",
+ "intel-tensorflow",
+ "intel-tensorflow-avx512",
+ "tensorflow-rocm",
+ "tensorflow-macos",
+ "tensorflow-aarch64",
+ )
+ _tf_version = None
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+ for pkg in candidates:
+ try:
+ _tf_version = importlib_metadata.version(pkg)
+ break
+ except importlib_metadata.PackageNotFoundError:
+ pass
+ _tf_available = _tf_version is not None
+ if _tf_available:
+ if version.parse(_tf_version) < version.parse("2"):
+ logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.")
+ _tf_available = False
+ else:
+ logger.info(f"TensorFlow version {_tf_version} available.")
+else:
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
+ _tf_available = False
+
+
+if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+ _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
+ if _flax_available:
+ try:
+ _jax_version = importlib_metadata.version("jax")
+ _flax_version = importlib_metadata.version("flax")
+ logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _flax_available = False
+else:
+ _flax_available = False
+
+
+_transformers_available = importlib.util.find_spec("transformers") is not None
+try:
+ _transformers_version = importlib_metadata.version("transformers")
+ logger.debug(f"Successfully imported transformers version {_transformers_version}")
+except importlib_metadata.PackageNotFoundError:
+ _transformers_available = False
+
+
+_inflect_available = importlib.util.find_spec("inflect") is not None
+try:
+ _inflect_version = importlib_metadata.version("inflect")
+ logger.debug(f"Successfully imported inflect version {_inflect_version}")
+except importlib_metadata.PackageNotFoundError:
+ _inflect_available = False
+
+
+_unidecode_available = importlib.util.find_spec("unidecode") is not None
+try:
+ _unidecode_version = importlib_metadata.version("unidecode")
+ logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
+except importlib_metadata.PackageNotFoundError:
+ _unidecode_available = False
+
+
+_modelcards_available = importlib.util.find_spec("modelcards") is not None
+try:
+ _modelcards_version = importlib_metadata.version("modelcards")
+ logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
+except importlib_metadata.PackageNotFoundError:
+ _modelcards_available = False
+
+
+_onnx_available = importlib.util.find_spec("onnxruntime") is not None
+try:
+ _onnxruntime_version = importlib_metadata.version("onnxruntime")
+ logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
+except importlib_metadata.PackageNotFoundError:
+ _onnx_available = False
+
+
+_scipy_available = importlib.util.find_spec("scipy") is not None
+try:
+ _scipy_version = importlib_metadata.version("scipy")
+ logger.debug(f"Successfully imported transformers version {_scipy_version}")
+except importlib_metadata.PackageNotFoundError:
+ _scipy_available = False
+
+
+def is_torch_available():
+ return _torch_available
+
+
+def is_tf_available():
+ return _tf_available
+
+
+def is_flax_available():
+ return _flax_available
+
+
+def is_transformers_available():
+ return _transformers_available
+
+
+def is_inflect_available():
+ return _inflect_available
+
+
+def is_unidecode_available():
+ return _unidecode_available
+
+
+def is_modelcards_available():
+ return _modelcards_available
+
+
+def is_onnx_available():
+ return _onnx_available
+
+
+def is_scipy_available():
+ return _scipy_available
+
+
+# docstyle-ignore
+FLAX_IMPORT_ERROR = """
+{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
+installation page: https://github.com/google/flax and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+INFLECT_IMPORT_ERROR = """
+{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
+inflect`
+"""
+
+# docstyle-ignore
+PYTORCH_IMPORT_ERROR = """
+{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
+installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+ONNX_IMPORT_ERROR = """
+{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
+install onnxruntime`
+"""
+
+# docstyle-ignore
+SCIPY_IMPORT_ERROR = """
+{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
+scipy`
+"""
+
+# docstyle-ignore
+TENSORFLOW_IMPORT_ERROR = """
+{0} requires the TensorFlow library but it was not found in your environment. Check out the instructions on the
+installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+TRANSFORMERS_IMPORT_ERROR = """
+{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
+install transformers`
+"""
+
+# docstyle-ignore
+UNIDECODE_IMPORT_ERROR = """
+{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
+Unidecode`
+"""
+
+
+BACKENDS_MAPPING = OrderedDict(
+ [
+ ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
+ ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
+ ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
+ ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
+ ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
+ ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
+ ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
+ ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
+ ]
+)
+
+
+def requires_backends(obj, backends):
+ if not isinstance(backends, (list, tuple)):
+ backends = [backends]
+
+ name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+ checks = (BACKENDS_MAPPING[backend] for backend in backends)
+ failed = [msg.format(name) for available, msg in checks if not available()]
+ if failed:
+ raise ImportError("".join(failed))
+
+
+class DummyObject(type):
+ """
+    Metaclass for the dummy objects. Any class using it as a metaclass will raise the ImportError generated by
+    `requires_backends` each time a user tries to access any method of that class.
+ """
+
+ def __getattr__(cls, key):
+ if key.startswith("_"):
+ return super().__getattr__(cls, key)
+ requires_backends(cls, cls._backends)
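+
+
+# NOTE: the autogenerated dummy_*_objects.py files in this package use the pattern below;
+# `HypotheticalScipyScheduler` is only an illustration. Instantiating it (or touching any
+# attribute) without scipy installed raises the SCIPY_IMPORT_ERROR message above.
+#
+#   class HypotheticalScipyScheduler(metaclass=DummyObject):
+#       _backends = ["scipy"]
+#
+#       def __init__(self, *args, **kwargs):
+#           requires_backends(self, ["scipy"])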
diff --git a/my_half_diffusers/utils/logging.py b/my_half_diffusers/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f2d0227b87c66205ceb3391a8e98f5f33285dc4
--- /dev/null
+++ b/my_half_diffusers/utils/logging.py
@@ -0,0 +1,344 @@
+# coding=utf-8
+# Copyright 2020 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Logging utilities."""
+
+import logging
+import os
+import sys
+import threading
+from logging import CRITICAL # NOQA
+from logging import DEBUG # NOQA
+from logging import ERROR # NOQA
+from logging import FATAL # NOQA
+from logging import INFO # NOQA
+from logging import NOTSET # NOQA
+from logging import WARN # NOQA
+from logging import WARNING # NOQA
+from typing import Optional
+
+from tqdm import auto as tqdm_lib
+
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+log_levels = {
+ "debug": logging.DEBUG,
+ "info": logging.INFO,
+ "warning": logging.WARNING,
+ "error": logging.ERROR,
+ "critical": logging.CRITICAL,
+}
+
+_default_log_level = logging.WARNING
+
+_tqdm_active = True
+
+
+def _get_default_logging_level():
+ """
+    If the DIFFUSERS_VERBOSITY env var is set to one of the valid choices, return that as the new default level. If it
+    is not, fall back to `_default_log_level`.
+ """
+ env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
+ if env_level_str:
+ if env_level_str in log_levels:
+ return log_levels[env_level_str]
+ else:
+ logging.getLogger().warning(
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
+ f"has to be one of: { ', '.join(log_levels.keys()) }"
+ )
+ return _default_log_level
+
+
+def _get_library_name() -> str:
+
+ return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+
+ return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+
+ global _default_handler
+
+ with _lock:
+ if _default_handler:
+ # This library has already configured the library root logger.
+ return
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
+ _default_handler.flush = sys.stderr.flush
+
+ # Apply our default configuration to the library root logger.
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_get_default_logging_level())
+ library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+
+ global _default_handler
+
+ with _lock:
+ if not _default_handler:
+ return
+
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.removeHandler(_default_handler)
+ library_root_logger.setLevel(logging.NOTSET)
+ _default_handler = None
+
+
+def get_log_levels_dict():
+ return log_levels
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+ """
+ Return a logger with the specified name.
+
+ This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
+ """
+
+ if name is None:
+ name = _get_library_name()
+
+ _configure_library_root_logger()
+ return logging.getLogger(name)
+
+
+def get_verbosity() -> int:
+ """
+ Return the current level for the 🤗 Diffusers' root logger as an int.
+
+ Returns:
+ `int`: The logging level.
+
+
+
+    🤗 Diffusers has the following logging levels:
+
+ - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+ - 40: `diffusers.logging.ERROR`
+ - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+ - 20: `diffusers.logging.INFO`
+ - 10: `diffusers.logging.DEBUG`
+
+ """
+
+ _configure_library_root_logger()
+ return _get_library_root_logger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity: int) -> None:
+ """
+ Set the verbosity level for the 🤗 Diffusers' root logger.
+
+ Args:
+ verbosity (`int`):
+ Logging level, e.g., one of:
+
+ - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+ - `diffusers.logging.ERROR`
+ - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+ - `diffusers.logging.INFO`
+ - `diffusers.logging.DEBUG`
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_info():
+ """Set the verbosity to the `INFO` level."""
+ return set_verbosity(INFO)
+
+
+def set_verbosity_warning():
+ """Set the verbosity to the `WARNING` level."""
+ return set_verbosity(WARNING)
+
+
+def set_verbosity_debug():
+ """Set the verbosity to the `DEBUG` level."""
+ return set_verbosity(DEBUG)
+
+
+def set_verbosity_error():
+ """Set the verbosity to the `ERROR` level."""
+ return set_verbosity(ERROR)
+
+
+def disable_default_handler() -> None:
+ """Disable the default handler of the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().removeHandler(_default_handler)
+
+
+def enable_default_handler() -> None:
+ """Enable the default handler of the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().addHandler(_default_handler)
+
+
+def add_handler(handler: logging.Handler) -> None:
+ """adds a handler to the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None
+ _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+ """removes given handler from the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None and handler not in _get_library_root_logger().handlers
+ _get_library_root_logger().removeHandler(handler)
+
+
+def disable_propagation() -> None:
+ """
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = False
+
+
+def enable_propagation() -> None:
+ """
+ Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
+ double logging if the root logger has been configured.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = True
+
+
+def enable_explicit_format() -> None:
+ """
+ Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows:
+ ```
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
+ ```
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
+ handler.setFormatter(formatter)
+
+
+def reset_format() -> None:
+ """
+ Resets the formatting for HuggingFace Diffusers' loggers.
+
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ handler.setFormatter(None)
+
+
+def warning_advice(self, *args, **kwargs):
+ """
+    This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
+ warning will not be printed
+ """
+ no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
+ if no_advisory_warnings:
+ return
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_advice = warning_advice
+
+
+class EmptyTqdm:
+ """Dummy tqdm which doesn't do anything."""
+
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
+ self._iterator = args[0] if args else None
+
+ def __iter__(self):
+ return iter(self._iterator)
+
+ def __getattr__(self, _):
+ """Return empty function."""
+
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
+ return
+
+ return empty_fn
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ return
+
+
+class _tqdm_cls:
+ def __call__(self, *args, **kwargs):
+ if _tqdm_active:
+ return tqdm_lib.tqdm(*args, **kwargs)
+ else:
+ return EmptyTqdm(*args, **kwargs)
+
+ def set_lock(self, *args, **kwargs):
+ self._lock = None
+ if _tqdm_active:
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
+
+ def get_lock(self):
+ if _tqdm_active:
+ return tqdm_lib.tqdm.get_lock()
+
+
+tqdm = _tqdm_cls()
+
+
+def is_progress_bar_enabled() -> bool:
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
+ global _tqdm_active
+ return bool(_tqdm_active)
+
+
+def enable_progress_bar():
+ """Enable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = True
+
+
+def disable_progress_bar():
+ """Disable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = False
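+
+
+# NOTE: a minimal usage sketch of the public helpers above (all names are defined in
+# this module):
+#
+#   set_verbosity_info()        # raise the library root logger to INFO
+#   logger = get_logger()       # no name -> the library root logger
+#   logger.info("configured")
+#   disable_progress_bar()      # make the exported `tqdm` wrapper a no-op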
diff --git a/my_half_diffusers/utils/model_card_template.md b/my_half_diffusers/utils/model_card_template.md
new file mode 100644
index 0000000000000000000000000000000000000000..f19c85b0fcf2f7b07e9c3f950a9657b3f2053f21
--- /dev/null
+++ b/my_half_diffusers/utils/model_card_template.md
@@ -0,0 +1,50 @@
+---
+{{ card_data }}
+---
+
+
+
+# {{ model_name | default("Diffusion Model") }}
+
+## Model description
+
+This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
+on the `{{ dataset_name }}` dataset.
+
+## Intended uses & limitations
+
+#### How to use
+
+```python
+# TODO: add an example code snippet for running this diffusion pipeline
+```
+
+#### Limitations and bias
+
+[TODO: provide examples of latent issues and potential remediations]
+
+## Training data
+
+[TODO: describe the data used to train the model]
+
+### Training hyperparameters
+
+The following hyperparameters were used during training:
+- learning_rate: {{ learning_rate }}
+- train_batch_size: {{ train_batch_size }}
+- eval_batch_size: {{ eval_batch_size }}
+- gradient_accumulation_steps: {{ gradient_accumulation_steps }}
+- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
+- lr_scheduler: {{ lr_scheduler }}
+- lr_warmup_steps: {{ lr_warmup_steps }}
+- ema_inv_gamma: {{ ema_inv_gamma }}
+- ema_power: {{ ema_power }}
+- ema_max_decay: {{ ema_max_decay }}
+- mixed_precision: {{ mixed_precision }}
+
+### Training results
+
+📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)
+
+
diff --git a/my_half_diffusers/utils/outputs.py b/my_half_diffusers/utils/outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..b02f62d02d0322401fd9926aca9f792a4696cc1e
--- /dev/null
+++ b/my_half_diffusers/utils/outputs.py
@@ -0,0 +1,109 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Generic utilities
+"""
+
+import warnings
+from collections import OrderedDict
+from dataclasses import fields
+from typing import Any, Tuple
+
+import numpy as np
+
+from .import_utils import is_torch_available
+
+
+def is_tensor(x):
+ """
+ Tests if `x` is a `torch.Tensor` or `np.ndarray`.
+ """
+ if is_torch_available():
+ import torch
+
+ if isinstance(x, torch.Tensor):
+ return True
+
+ return isinstance(x, np.ndarray)
+
+
+class BaseOutput(OrderedDict):
+ """
+    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like
+    a tuple) or by string (like a dictionary), ignoring `None` attributes. Otherwise it behaves like a regular Python
+    dictionary.
+
+
+
+    You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
+    first.
+
+
+ """
+
+ def __post_init__(self):
+ class_fields = fields(self)
+
+ # Safety and consistency checks
+ if not len(class_fields):
+ raise ValueError(f"{self.__class__.__name__} has no fields.")
+
+ for field in class_fields:
+ v = getattr(self, field.name)
+ if v is not None:
+ self[field.name] = v
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __getitem__(self, k):
+ if isinstance(k, str):
+ inner_dict = {k: v for (k, v) in self.items()}
+ if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample":
+ warnings.warn(
+ "The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or"
+ " `'images'` instead.",
+ DeprecationWarning,
+ )
+ return inner_dict["images"]
+ return inner_dict[k]
+ else:
+ return self.to_tuple()[k]
+
+ def __setattr__(self, name, value):
+ if name in self.keys() and value is not None:
+ # Don't call self.__setitem__ to avoid recursion errors
+ super().__setitem__(name, value)
+ super().__setattr__(name, value)
+
+ def __setitem__(self, key, value):
+ # Will raise a KeyException if needed
+ super().__setitem__(key, value)
+ # Don't call self.__setattr__ to avoid recursion errors
+ super().__setattr__(key, value)
+
+ def to_tuple(self) -> Tuple[Any]:
+ """
+ Convert self to a tuple containing all the attributes/keys that are not `None`.
+ """
+ return tuple(self[k] for k in self.keys())
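+
+
+# NOTE: a small illustration (hedged) using `SchedulerOutput` from
+# my_half_diffusers/schedulers/scheduling_utils.py, a dataclass subclass of BaseOutput;
+# `prev` is a hypothetical tensor.
+#
+#   out = SchedulerOutput(prev_sample=prev)
+#   out.prev_sample        # attribute access
+#   out["prev_sample"]     # dict-style access
+#   out[0]                 # tuple-style access (via to_tuple)
+#   out.to_tuple()         # (prev,)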
diff --git a/requirements_local.txt b/requirements_local.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9e6d6fc3a4b61932b3fb4e96b7b41635fbeefc26
--- /dev/null
+++ b/requirements_local.txt
@@ -0,0 +1,13 @@
+--find-links https://download.pytorch.org/whl/torch_stable.html
+torch==1.11.0+cu113
+torchvision
+numpy
+transformers==4.19.2
+diffusers==0.6.0
+omegaconf==2.1.1
+ftfy==6.1.1
+regex==2022.9.13
+timm==0.6.12
+imageio
+gradio
+# cudatoolkit==11.3.1
diff --git a/square_ims/american_gothic.jpg b/square_ims/american_gothic.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..581b305c4891699db630cee77103a2d56eff0a45
Binary files /dev/null and b/square_ims/american_gothic.jpg differ
diff --git a/square_ims/colloseum.jpg b/square_ims/colloseum.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6a6b3669d8f4f7d28b47068f7c2cfc5ba05b563a
Binary files /dev/null and b/square_ims/colloseum.jpg differ
diff --git a/square_ims/einstein.jpg b/square_ims/einstein.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d87ff7da27f41153c6f1a943eb5cd59e1d2c60e2
Binary files /dev/null and b/square_ims/einstein.jpg differ
diff --git a/square_ims/hf.png b/square_ims/hf.png
new file mode 100644
index 0000000000000000000000000000000000000000..7f8121bac8f2316d3d73f53decd6217df79f4798
Binary files /dev/null and b/square_ims/hf.png differ
diff --git a/square_ims/imagenet_cake_2.jpg b/square_ims/imagenet_cake_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..44b1c7f217cc603c5060fa2413a21b903789684c
Binary files /dev/null and b/square_ims/imagenet_cake_2.jpg differ
diff --git a/square_ims/imagenet_dog_1.jpg b/square_ims/imagenet_dog_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..54f6b0231d30369c17a1bd03153397587e5dc332
Binary files /dev/null and b/square_ims/imagenet_dog_1.jpg differ
diff --git a/square_ims/imagenet_dog_2.jpg b/square_ims/imagenet_dog_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b4da94ef784dfce89b9aae029802943229df63cf
Binary files /dev/null and b/square_ims/imagenet_dog_2.jpg differ
diff --git a/square_ims/scream.jpg b/square_ims/scream.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0b27677986bd63259454fa68d5ca8116275a99c7
Binary files /dev/null and b/square_ims/scream.jpg differ
diff --git a/square_ims/ucsb.jpg b/square_ims/ucsb.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cea8da28c0ae74afc4c02cb7f57db6616f43b482
Binary files /dev/null and b/square_ims/ucsb.jpg differ
diff --git a/square_ims/yosemite.jpg b/square_ims/yosemite.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c67c49ea3778448b52f8d91df65c5cc10d228e97
Binary files /dev/null and b/square_ims/yosemite.jpg differ
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..77d58387c13924bbc847ddcec7a8cec2ae7397a5
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,27 @@
+import os
+
+
+class Endpoint:
+ def __init__(self):
+ self._url = None
+
+ @property
+ def url(self):
+ if self._url is None:
+ self._url = self.get_url()
+
+ return self._url
+
+ def get_url(self):
+ endpoint = os.environ.get("endpoint")
+
+ return endpoint
+
+
+def get_token():
+ token = os.environ.get("auth_token")
+
+ if token is None:
+ raise ValueError("auth-token not found in environment variables")
+
+ return token
\ No newline at end of file
diff --git a/working_app.py b/working_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2c3dce7c5dc6a76564870e54996b97023cc780a
--- /dev/null
+++ b/working_app.py
@@ -0,0 +1,139 @@
+import gradio as gr
+import numpy as np
+# from edict_functions import EDICT_editing
+from PIL import Image
+from utils import Endpoint, get_token
+from io import BytesIO
+import requests
+
+
+endpoint = Endpoint()
+
+def local_edict(x, source_text, edit_text,
+ edit_strength, guidance_scale,
+ steps=50, mix_weight=0.93, ):
+ x = Image.fromarray(x)
+ return_im = EDICT_editing(x,
+ source_text,
+ edit_text,
+ steps=steps,
+ mix_weight=mix_weight,
+ init_image_strength=edit_strength,
+ guidance_scale=guidance_scale
+ )[0]
+ return np.array(return_im)
+
+def encode_image(image):
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG", quality=95)
+ buffered.seek(0)
+
+ return buffered
+
+
+
+def decode_image(img_obj):
+ img = Image.open(img_obj).convert("RGB")
+ return img
+
+def edict(x, source_text, edit_text,
+ edit_strength, guidance_scale,
+ steps=50, mix_weight=0.93, ):
+
+ url = endpoint.url
+ url = url + "/api/edit"
+    headers = {
+ "User-Agent": "EDICT HuggingFace Space",
+ "Auth-Token": get_token(),
+ }
+
+ data = {
+ "source_text": source_text,
+ "edit_text": edit_text,
+ "edit_strength": edit_strength,
+ "guidance_scale": guidance_scale,
+ }
+
+ image = encode_image(Image.fromarray(x))
+ files = {"image": image}
+
+ response = requests.post(url, data=data, files=files, headers=headers)
+
+ if response.status_code == 200:
+ return np.array(decode_image(BytesIO(response.content)))
+ else:
+ return "Error: " + response.text
+ # x = decode_image(response)
+ # return np.array(x)
+
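+# NOTE: a usage sketch outside the Gradio UI (hedged; assumes the `endpoint` and
+# `auth_token` environment variables used by utils.Endpoint/get_token are set, and the
+# prompts are illustrative):
+#
+#   im = np.array(Image.open('square_ims/ucsb.jpg'))
+#   edited = edict(im, 'A photo of a campus', 'A photo of a snowy campus',
+#                  edit_strength=0.8, guidance_scale=3)
+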
+examples = [
+ ['square_ims/american_gothic.jpg', 'A painting of two people frowning', 'A painting of two people smiling', 0.5, 3],
+ ['square_ims/colloseum.jpg', 'An old ruined building', 'A new modern office building', 0.8, 3],
+ ]
+
+
+examples.append(['square_ims/scream.jpg', 'A painting of someone screaming', 'A painting of an alien', 0.5, 3])
+examples.append(['square_ims/yosemite.jpg', 'Granite forest valley', 'Granite desert valley', 0.8, 3])
+examples.append(['square_ims/einstein.jpg', 'Mouth open', 'Mouth closed', 0.8, 3])
+examples.append(['square_ims/einstein.jpg', 'A man', 'A man in K.I.S.S. facepaint', 0.8, 3])
+"""
+examples.extend([
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Chinese New Year cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Union Jack cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Nigerian flag cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Santa Claus cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'An Easter cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A hedgehog cupcake', 0.8, 3],
+ ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A rose cupcake', 0.8, 3],
+ ])
+"""
+
+for dog_i in [1, 2]:
+ for breed in ['Golden Retriever', 'Chihuahua', 'Dalmatian']:
+ examples.append([f'square_ims/imagenet_dog_{dog_i}.jpg', 'A dog', f'A {breed}', 0.8, 3])
+
+
+description = 'A gradio demo for [EDICT](https://arxiv.org/abs/2211.12446) (CVPR23)'
+# description = gr.Markdown(description)
+
+article = """
+
+### Prompting Style
+
+As with many text-to-image methods, the prompting style of EDICT can make a big difference. When in doubt, experiment! Some guidance:
+* Keep the *Original Description* and *Edit Description* as parallel in construction as possible. Inserting or editing a single word is often enough to effect a change while preserving much of the original structure
+* Words that set the entire scene (e.g. "A photo of" vs. "A painting of") can make a big difference. Playing around with them can help a lot
+
+### Parameters
+Both `edit_strength` and `guidance_scale` have similar qualitative properties: the higher the value, the more the image will change. We suggest:
+* Increasing/decreasing `edit_strength` first, particularly to alter/preserve more of the original structure/content
+* Then changing `guidance_scale` to make the change in the edited region more or less pronounced.
+
+Usually we find changing `edit_strength` to be enough, but feel free to play around (and report any interesting results)!
+
+### Misc.
+
+Having difficulty coming up with a caption? Try [BLIP](https://huggingface.co/spaces/Salesforce/BLIP2) to automatically generate one!
+
+As with most StableDiffusion approaches, faces/text are often problematic to render, especially if they're small. Having these in the foreground will help keep them cleaner.
+
+A returned black image means that the [Safety Checker](https://huggingface.co/CompVis/stable-diffusion-safety-checker) triggered on the photo. This sometimes happens in odd cases (it often rejects
+the Hugging Face logo or variations of it), but we need to keep it enabled for obvious reasons.
+"""
+# article = gr.Markdown(description)
+
+iface = gr.Interface(fn=edict, inputs=["image",
+ gr.Textbox(label="Original Description"),
+ gr.Textbox(label="Edit Description"),
+ # 50, # gr.Slider(5, 50, value=20, step=1),
+ # 0.93, # gr.Slider(0.5, 1, value=0.7, step=0.05),
+ gr.Slider(0.0, 1, value=0.8, step=0.05),
+ gr.Slider(0, 10, value=3, step=0.5),
+ ],
+ examples = examples,
+ outputs="image",
+ description=description,
+ article=article,
+ cache_examples=True)
+iface.launch()