Spaces:
Running
on
Zero
Running
on
Zero
import tempfile | |
from io import BytesIO | |
import requests | |
import torch | |
from huggingface_hub import hf_hub_download, snapshot_download | |
from diffusers.models.attention_processor import AttnProcessor | |
from diffusers.utils.testing_utils import ( | |
numpy_cosine_similarity_distance, | |
torch_device, | |
) | |
def download_single_file_checkpoint(repo_id, filename, tmpdir):
    """Download a single file from a Hub repository into ``tmpdir``.

    Args:
        repo_id: Repository identifier forwarded to ``hf_hub_download``.
        filename: Name of the file inside the repository.
        tmpdir: Directory passed to ``hf_hub_download`` as ``local_dir``.

    Returns:
        Whatever ``hf_hub_download`` returns for the downloaded file.
    """
    return hf_hub_download(repo_id, filename=filename, local_dir=tmpdir)
def download_original_config(config_url, tmpdir, timeout=60):
    """Download an original-format config file and store it as ``config.yaml``.

    Args:
        config_url: URL the config is fetched from with ``requests.get``.
        tmpdir: Directory in which ``config.yaml`` is written.
        timeout: Seconds passed to ``requests.get`` so an unresponsive
            server cannot block the caller forever.

    Returns:
        The path of the written file, ``f"{tmpdir}/config.yaml"``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.get(config_url, timeout=timeout)
    # Fail loudly instead of saving an HTTP error page as if it were the config.
    response.raise_for_status()

    path = f"{tmpdir}/config.yaml"
    with open(path, "wb") as f:
        f.write(response.content)

    return path
def download_diffusers_config(repo_id, tmpdir):
    """Snapshot the JSON/TXT files of a Hub repository into ``tmpdir``.

    Files ending in ``.ckpt``, ``.bin``, ``.pt`` or ``.safetensors`` are listed
    in ``ignore_patterns``; only ``*.json`` and ``*.txt`` files are listed in
    ``allow_patterns``.

    Args:
        repo_id: Repository identifier forwarded to ``snapshot_download``.
        tmpdir: Directory passed to ``snapshot_download`` as ``local_dir``.

    Returns:
        Whatever ``snapshot_download`` returns.
    """
    weight_extensions = ("ckpt", "bin", "pt", "safetensors")
    # For each extension: the nested-directory pattern first, then the top-level one.
    weight_patterns = [
        pattern for extension in weight_extensions for pattern in (f"**/*.{extension}", f"*.{extension}")
    ]
    config_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt"]

    return snapshot_download(
        repo_id,
        ignore_patterns=weight_patterns,
        allow_patterns=config_patterns,
        local_dir=tmpdir,
    )
class SDSingleFileTesterMixin:
    """Checks that a pipeline loaded via ``from_single_file`` matches ``from_pretrained``.

    The methods read ``self.pipeline_class``, ``self.ckpt_path``, ``self.repo_id`` and
    ``self.original_config`` and call ``self.get_inputs``. None of these are defined
    here, so the class using this mixin has to provide them.

    Every ``test_single_file_components*`` method accepts optional ``pipe`` and
    ``single_file_pipe`` arguments; when one is given it is used instead of loading
    a new pipeline.
    """

    def _compare_component_configs(self, pipe, single_file_pipe):
        """Assert that the components of ``single_file_pipe`` agree with those of ``pipe``.

        Neither pipeline is modified.

        Args:
            pipe: Reference pipeline (loaded with ``from_pretrained`` by the tests).
            single_file_pipe: Pipeline loaded with ``from_single_file``.

        Raises:
            AssertionError: If a component is missing, has a different class, or a
                compared config value differs.
        """
        text_encoder_params_to_ignore = ["torch_dtype", "architectures", "_name_or_path"]
        # Serialise the reference text encoder config once instead of once per parameter.
        pretrained_text_encoder_config = pipe.text_encoder.config.to_dict()
        for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
            if param_name in text_encoder_params_to_ignore:
                continue
            assert (
                pretrained_text_encoder_config[param_name] == param_value
            ), f"single file text_encoder {param_name}: {param_value} differs from pretrained {pretrained_text_encoder_config[param_name]}"

        PARAMS_TO_IGNORE = [
            "torch_dtype",
            "_name_or_path",
            "architectures",
            "_use_default_values",
            "_diffusers_version",
        ]
        for component_name, component in single_file_pipe.components.items():
            if component_name in single_file_pipe._optional_components:
                continue

            # skip testing transformer based components here
            # skip text encoders / safety checkers since they have already been tested
            if component_name in ["text_encoder", "tokenizer", "safety_checker", "feature_extractor"]:
                continue

            assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
            pretrained_component = pipe.components[component_name]
            assert isinstance(
                component, pretrained_component.__class__
            ), f"single file {component.__class__.__name__} and pretrained {pretrained_component.__class__.__name__} are not the same"

            for param_name, param_value in component.config.items():
                if param_name in PARAMS_TO_IGNORE:
                    continue

                pretrained_value = pretrained_component.config[param_name]
                # Some pretrained configs will set upcast attention to None
                # In single file loading it defaults to the value in the class __init__ which is False
                # Treat None as matching, without writing into the reference pipeline's config.
                if param_name == "upcast_attention" and pretrained_value is None:
                    pretrained_value = param_value

                assert (
                    pretrained_value == param_value
                ), f"single file {param_name}: {param_value} differs from pretrained {pretrained_value}"

    def test_single_file_components(self, pipe=None, single_file_pipe=None):
        """Compare a default ``from_single_file`` load of ``self.ckpt_path`` with ``from_pretrained``."""
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_local_files_only(self, pipe=None, single_file_pipe=None):
        """Compare a ``local_files_only=True`` load of a locally downloaded checkpoint with ``from_pretrained``."""
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Only download the checkpoint when a pipeline actually has to be built from it.
        if not single_file_pipe:
            with tempfile.TemporaryDirectory() as tmpdir:
                ckpt_filename = self.ckpt_path.split("/")[-1]
                local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)

                single_file_pipe = self.pipeline_class.from_single_file(
                    local_ckpt_path, safety_checker=None, local_files_only=True
                )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        """Compare a load that passes ``original_config=self.original_config`` with ``from_pretrained``."""
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention

        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path,
            original_config=self.original_config,
            safety_checker=None,
            upcast_attention=upcast_attention,
        )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        """Compare a ``local_files_only=True`` load using a downloaded checkpoint and original config."""
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention

        # Only download the checkpoint and config when a pipeline has to be built from them.
        if not single_file_pipe:
            with tempfile.TemporaryDirectory() as tmpdir:
                ckpt_filename = self.ckpt_path.split("/")[-1]
                local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
                local_original_config = download_original_config(self.original_config, tmpdir)

                single_file_pipe = self.pipeline_class.from_single_file(
                    local_ckpt_path,
                    original_config=local_original_config,
                    safety_checker=None,
                    upcast_attention=upcast_attention,
                    local_files_only=True,
                )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
        """Run both pipelines on ``self.get_inputs(torch_device)`` and compare the first image of each.

        Args:
            expected_max_diff: Exclusive upper bound for the value returned by
                ``numpy_cosine_similarity_distance`` on the two flattened images.
        """
        sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None)
        sf_pipe.unet.set_attn_processor(AttnProcessor())
        sf_pipe.enable_model_cpu_offload()

        inputs = self.get_inputs(torch_device)
        image_single_file = sf_pipe(**inputs).images[0]

        pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        pipe.unet.set_attn_processor(AttnProcessor())
        pipe.enable_model_cpu_offload()

        # Inputs are requested again for the second run rather than reused.
        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images[0]

        max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())

        assert max_diff < expected_max_diff

    def test_single_file_components_with_diffusers_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        """Compare a load that passes ``config=self.repo_id`` with ``from_pretrained``."""
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, config=self.repo_id, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_diffusers_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        """Compare a ``local_files_only=True`` load using a downloaded checkpoint and diffusers config."""
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Only download the checkpoint and config when a pipeline has to be built from them.
        if not single_file_pipe:
            with tempfile.TemporaryDirectory() as tmpdir:
                ckpt_filename = self.ckpt_path.split("/")[-1]
                local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
                local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

                single_file_pipe = self.pipeline_class.from_single_file(
                    local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
                )

        self._compare_component_configs(pipe, single_file_pipe)
class SDXLSingleFileTesterMixin:
    """Checks that an SDXL pipeline loaded via ``from_single_file`` matches ``from_pretrained``.

    The methods read ``self.pipeline_class``, ``self.ckpt_path``, ``self.repo_id`` and
    ``self.original_config`` and call ``self.get_inputs``. None of these are defined
    here, so the class using this mixin has to provide them.

    Every ``test_single_file_components*`` method accepts optional ``pipe`` and
    ``single_file_pipe`` arguments; when one is given it is used instead of loading
    a new pipeline.
    """

    def _compare_component_configs(self, pipe, single_file_pipe):
        """Assert that the components of ``single_file_pipe`` agree with those of ``pipe``.

        Neither pipeline is modified.

        Args:
            pipe: Reference pipeline (loaded with ``from_pretrained`` by the tests).
            single_file_pipe: Pipeline loaded with ``from_single_file``.

        Raises:
            AssertionError: If a component is missing, has a different class, or a
                compared config value differs.
        """
        text_encoder_params_to_ignore = ["torch_dtype", "architectures", "_name_or_path"]

        # Skip testing the text_encoder for Refiner Pipelines
        if pipe.text_encoder is not None:
            text_encoder_names = ["text_encoder", "text_encoder_2"]
        else:
            text_encoder_names = ["text_encoder_2"]

        for encoder_name in text_encoder_names:
            # Serialise the reference config once instead of once per parameter.
            pretrained_config = getattr(pipe, encoder_name).config.to_dict()
            single_file_config = getattr(single_file_pipe, encoder_name).config.to_dict()
            for param_name, param_value in single_file_config.items():
                if param_name in text_encoder_params_to_ignore:
                    continue
                assert (
                    pretrained_config[param_name] == param_value
                ), f"single file {encoder_name} {param_name}: {param_value} differs from pretrained {pretrained_config[param_name]}"

        PARAMS_TO_IGNORE = [
            "torch_dtype",
            "_name_or_path",
            "architectures",
            "_use_default_values",
            "_diffusers_version",
        ]
        for component_name, component in single_file_pipe.components.items():
            if component_name in single_file_pipe._optional_components:
                continue

            # skip text encoders since they have already been tested
            if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
                continue

            # skip safety checker if it is not present in the pipeline
            if component_name in ["safety_checker", "feature_extractor"]:
                continue

            assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
            pretrained_component = pipe.components[component_name]
            assert isinstance(
                component, pretrained_component.__class__
            ), f"single file {component.__class__.__name__} and pretrained {pretrained_component.__class__.__name__} are not the same"

            for param_name, param_value in component.config.items():
                if param_name in PARAMS_TO_IGNORE:
                    continue

                pretrained_value = pretrained_component.config[param_name]
                # Some pretrained configs will set upcast attention to None
                # In single file loading it defaults to the value in the class __init__ which is False
                # Treat None as matching, without writing into the reference pipeline's config.
                if param_name == "upcast_attention" and pretrained_value is None:
                    pretrained_value = param_value

                assert (
                    pretrained_value == param_value
                ), f"single file {param_name}: {param_value} differs from pretrained {pretrained_value}"

    def test_single_file_components(self, pipe=None, single_file_pipe=None):
        """Compare a default ``from_single_file`` load of ``self.ckpt_path`` with ``from_pretrained``."""
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(
            pipe,
            single_file_pipe,
        )

    def test_single_file_components_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        """Compare a ``local_files_only=True`` load of a locally downloaded checkpoint with ``from_pretrained``."""
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Only download the checkpoint when a pipeline actually has to be built from it.
        if not single_file_pipe:
            with tempfile.TemporaryDirectory() as tmpdir:
                ckpt_filename = self.ckpt_path.split("/")[-1]
                local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)

                single_file_pipe = self.pipeline_class.from_single_file(
                    local_ckpt_path, safety_checker=None, local_files_only=True
                )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        """Compare a load that passes ``original_config=self.original_config`` with ``from_pretrained``."""
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention

        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path,
            original_config=self.original_config,
            safety_checker=None,
            upcast_attention=upcast_attention,
        )

        self._compare_component_configs(
            pipe,
            single_file_pipe,
        )

    def test_single_file_components_with_original_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        """Compare a ``local_files_only=True`` load using a downloaded checkpoint and original config."""
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention

        # Only download the checkpoint and config when a pipeline has to be built from them.
        if not single_file_pipe:
            with tempfile.TemporaryDirectory() as tmpdir:
                ckpt_filename = self.ckpt_path.split("/")[-1]
                local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
                local_original_config = download_original_config(self.original_config, tmpdir)

                single_file_pipe = self.pipeline_class.from_single_file(
                    local_ckpt_path,
                    original_config=local_original_config,
                    upcast_attention=upcast_attention,
                    safety_checker=None,
                    local_files_only=True,
                )

        self._compare_component_configs(
            pipe,
            single_file_pipe,
        )

    def test_single_file_components_with_diffusers_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        """Compare a load that passes ``config=self.repo_id`` with ``from_pretrained``."""
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, config=self.repo_id, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_diffusers_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        """Compare a ``local_files_only=True`` load using a downloaded checkpoint and diffusers config."""
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Only download the checkpoint and config when a pipeline has to be built from them.
        if not single_file_pipe:
            with tempfile.TemporaryDirectory() as tmpdir:
                ckpt_filename = self.ckpt_path.split("/")[-1]
                local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
                local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

                single_file_pipe = self.pipeline_class.from_single_file(
                    local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
                )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
        """Run both pipelines (loaded with ``torch_dtype=torch.float16``) and compare the first image of each.

        Args:
            expected_max_diff: Exclusive upper bound for the value returned by
                ``numpy_cosine_similarity_distance`` on the two flattened images.
        """
        sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, torch_dtype=torch.float16, safety_checker=None)
        sf_pipe.unet.set_default_attn_processor()
        sf_pipe.enable_model_cpu_offload()

        inputs = self.get_inputs(torch_device)
        image_single_file = sf_pipe(**inputs).images[0]

        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16, safety_checker=None)
        pipe.unet.set_default_attn_processor()
        pipe.enable_model_cpu_offload()

        # Inputs are requested again for the second run rather than reused.
        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images[0]

        max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())

        assert max_diff < expected_max_diff