File size: 3,846 Bytes
f1069cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline
from diffusers.utils import is_flax_available, slow
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
    import jax
    import jax.numpy as jnp
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard


@slow
@require_flax
class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase):
    """Slow integration tests for ``FlaxStableDiffusionPipeline`` on ``stabilityai/stable-diffusion-2``.

    Each test loads the ``bf16`` revision in ``jnp.bfloat16``, generates one image per JAX
    device for a fixed prompt and seed, and compares a 3x3 patch of the last channel of
    the first image against reference values within an absolute tolerance of 1e-2.
    """

    # Checkpoint and prompt shared by every test in this class.
    _MODEL_ID = "stabilityai/stable-diffusion-2"
    _PROMPT = "A painting of a squirrel eating a burger"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()

    def _generate_image_slice(self, sd_pipe, params):
        """Run ``sd_pipe`` data-parallel over all JAX devices and return a flattened image patch.

        Args:
            sd_pipe: a loaded ``FlaxStableDiffusionPipeline``.
            params: the pipeline parameters as returned by ``from_pretrained`` (not yet replicated).

        Returns:
            A flattened ``jnp`` array holding ``images[0, 253:256, 253:256, -1]``, i.e. a 3x3
            patch of the last channel of the first generated image.
        """
        num_samples = jax.device_count()
        prompt_ids = sd_pipe.prepare_inputs(num_samples * [self._PROMPT])

        # Place a copy of the params on every device and split the prompt batch across devices.
        params = replicate(params)
        prompt_ids = shard(prompt_ids)

        # One PRNG key per device, all derived from a fixed seed so results are reproducible.
        prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)

        images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
        assert images.shape == (num_samples, 1, 768, 768, 3)

        # Merge the leading (device, per-device batch) axes into a single batch axis.
        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
        image_slice = images[0, 253:256, 253:256, -1]

        return jnp.asarray(jax.device_get(image_slice.flatten()))

    def _assert_slice_close(self, output_slice, expected_slice, atol=1e-2):
        """Assert the max absolute difference between the two slices is below ``atol``.

        The actual slice is included in the failure message so reference values can be
        updated without re-running the test with extra logging.
        """
        max_diff = jnp.abs(output_slice - expected_slice).max()
        assert max_diff < atol, f"output_slice: {output_slice}, max abs diff: {max_diff}"

    def test_stable_diffusion_flax(self):
        sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
            self._MODEL_ID,
            revision="bf16",
            dtype=jnp.bfloat16,
        )

        output_slice = self._generate_image_slice(sd_pipe, params)
        expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512])
        self._assert_slice_close(output_slice, expected_slice)

    def test_stable_diffusion_dpm_flax(self):
        scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
            self._MODEL_ID, subfolder="scheduler"
        )
        sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
            self._MODEL_ID,
            scheduler=scheduler,
            revision="bf16",
            dtype=jnp.bfloat16,
        )
        # Use the state loaded alongside the DPM-Solver scheduler as the pipeline's scheduler params.
        params["scheduler"] = scheduler_params

        output_slice = self._generate_image_slice(sd_pipe, params)
        expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297])
        self._assert_slice_close(output_slice, expected_slice)