# Androidonnxfork's picture
# Duplicate from Androidonnxfork/sd-to-diffuserscustom
# e0097f3
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from collections import OrderedDict
import torch
from diffusers import (
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
ControlNetModel,
)
from diffusers.pipelines.auto_pipeline import (
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
AUTO_INPAINT_PIPELINES_MAPPING,
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
)
from diffusers.utils import slow
# Model-family name -> Hub repository id used by the @slow integration tests.
# Each key is also used to index the AUTO_*_PIPELINES_MAPPING tables, and the
# insertion order fixes the order in which the families are exercised.
PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
    (
        ("stable-diffusion", "runwayml/stable-diffusion-v1-5"),
        ("if", "DeepFloyd/IF-I-XL-v1.0"),
        ("kandinsky", "kandinsky-community/kandinsky-2-1"),
        ("kandinsky22", "kandinsky-community/kandinsky-2-2-decoder"),
    )
)
class AutoPipelineFastTest(unittest.TestCase):
    """Fast `from_pipe` round-trip checks using tiny pipelines from hf-internal-testing.

    unittest assertion methods are used instead of bare ``assert`` so the checks
    still run under ``python -O`` and report both values on failure.
    """

    def test_from_pipe_consistent(self):
        """text2img -> img2img -> text2img conversions must leave the config unchanged."""
        pipe = AutoPipelineForText2Image.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
        )
        original_config = dict(pipe.config)

        pipe = AutoPipelineForImage2Image.from_pipe(pipe)
        self.assertEqual(dict(pipe.config), original_config)

        pipe = AutoPipelineForText2Image.from_pipe(pipe)
        self.assertEqual(dict(pipe.config), original_config)

    def test_from_pipe_override(self):
        """Keyword arguments given to `from_pipe` must override the source pipeline's config."""
        pipe = AutoPipelineForText2Image.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
        )

        pipe = AutoPipelineForImage2Image.from_pipe(pipe, requires_safety_checker=True)
        self.assertIs(pipe.config.requires_safety_checker, True)

        pipe = AutoPipelineForText2Image.from_pipe(pipe, requires_safety_checker=True)
        self.assertIs(pipe.config.requires_safety_checker, True)

    def test_from_pipe_consistent_sdxl(self):
        """SDXL img2img -> text2img -> img2img must restore the original img2img config."""
        pipe = AutoPipelineForImage2Image.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-xl-pipe",
            requires_aesthetics_score=True,
            force_zeros_for_empty_prompt=False,
        )
        original_config = dict(pipe.config)

        # NOTE(review): only the final img2img config is compared; the intermediate
        # text2img config is presumably allowed to differ (img2img-only options) — confirm.
        pipe = AutoPipelineForText2Image.from_pipe(pipe)
        pipe = AutoPipelineForImage2Image.from_pipe(pipe)
        self.assertEqual(dict(pipe.config), original_config)
@slow
class AutoPipelineIntegrationTest(unittest.TestCase):
    """Slow checks of the Auto* pipelines against full Hub checkpoints."""

    @staticmethod
    def _auto_tasks(model_name):
        """Return the (Auto* class, expected-class mapping) pairs exercised for ``model_name``.

        Shared by the tests below so the set of tasks per model family is defined once.
        """
        tasks = [
            (AutoPipelineForText2Image, AUTO_TEXT2IMAGE_PIPELINES_MAPPING),
            (AutoPipelineForImage2Image, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING),
        ]
        # Inpainting is not exercised for the Kandinsky families in these tests.
        if "kandinsky" not in model_name:
            tasks.append((AutoPipelineForInpainting, AUTO_INPAINT_PIPELINES_MAPPING))
        return tasks

    def test_pipe_auto(self):
        """Each Auto* class loads the mapped pipeline, and `from_pipe` converts it to every other task."""
        for model_name, model_repo in PRETRAINED_MODEL_REPO_MAPPING.items():
            tasks = self._auto_tasks(model_name)
            for load_class, load_mapping in tasks:
                # test from_pretrained for this task
                pipe_from = load_class.from_pretrained(model_repo, variant="fp16", torch_dtype=torch.float16)
                self.assertIsInstance(pipe_from, load_mapping[model_name])

                # test from_pipe into every task (including the same one)
                for to_class, to_mapping in tasks:
                    pipe_to = to_class.from_pipe(pipe_from)
                    self.assertIsInstance(pipe_to, to_mapping[model_name])

                # Drop the fp16 pipelines before loading the next checkpoint.
                del pipe_from, pipe_to
                gc.collect()

    def test_from_pipe_consistent(self):
        """`from_pipe` between any two supported tasks must preserve the source config."""
        for model_name, model_repo in PRETRAINED_MODEL_REPO_MAPPING.items():
            auto_pipes = [auto_class for auto_class, _ in self._auto_tasks(model_name)]

            # test from_pretrained
            for pipe_from_class in auto_pipes:
                pipe_from = pipe_from_class.from_pretrained(model_repo, variant="fp16", torch_dtype=torch.float16)
                pipe_from_config = dict(pipe_from.config)

                for pipe_to_class in auto_pipes:
                    pipe_to = pipe_to_class.from_pipe(pipe_from)
                    self.assertEqual(dict(pipe_to.config), pipe_from_config)

                del pipe_from, pipe_to
                gc.collect()

    def test_controlnet(self):
        """Passing `controlnet=` selects the ControlNet variants for loading and `from_pipe`."""
        model_repo = "runwayml/stable-diffusion-v1-5"
        controlnet_repo = "lllyasviel/sd-controlnet-canny"
        mapping_key = "stable-diffusion-controlnet"

        controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=torch.float16)

        # test from_pretrained
        pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
            model_repo, controlnet=controlnet, torch_dtype=torch.float16
        )
        self.assertIsInstance(pipe_txt2img, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[mapping_key])

        pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
            model_repo, controlnet=controlnet, torch_dtype=torch.float16
        )
        self.assertIsInstance(pipe_img2img, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[mapping_key])

        pipe_inpaint = AutoPipelineForInpainting.from_pretrained(
            model_repo, controlnet=controlnet, torch_dtype=torch.float16
        )
        self.assertIsInstance(pipe_inpaint, AUTO_INPAINT_PIPELINES_MAPPING[mapping_key])

        # test from_pipe: each source pipeline must convert to every ControlNet task
        # with the same config as the pipeline loaded directly for that task above.
        targets = (
            (AutoPipelineForText2Image, AUTO_TEXT2IMAGE_PIPELINES_MAPPING, pipe_txt2img),
            (AutoPipelineForImage2Image, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, pipe_img2img),
            (AutoPipelineForInpainting, AUTO_INPAINT_PIPELINES_MAPPING, pipe_inpaint),
        )
        for pipe_from in (pipe_txt2img, pipe_img2img, pipe_inpaint):
            for auto_class, mapping, reference_pipe in targets:
                pipe_to = auto_class.from_pipe(pipe_from)
                self.assertIsInstance(pipe_to, mapping[mapping_key])
                self.assertEqual(dict(pipe_to.config), dict(reference_pipe.config))