| | import unittest |
| |
|
| | from diffusers.pipelines.pipeline_utils import is_safetensors_compatible |
| |
|
| |
|
class IsSafetensorsCompatibleTests(unittest.TestCase):
    """Checks for ``is_safetensors_compatible`` over lists of repo filenames.

    Each test builds a list of weight filenames (``.bin`` and/or
    ``.safetensors``), optionally passes a ``variant`` such as ``"fp16"``,
    and asserts whether the helper reports the list as compatible.
    """

    def test_all_is_compatible(self):
        # Every component ships a .bin file together with a .safetensors file.
        repo_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files)
        self.assertTrue(compatible)

    def test_diffusers_model_is_compatible(self):
        # A single component using the diffusion_pytorch_model.* naming,
        # with both weight formats present.
        repo_files = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files)
        self.assertTrue(compatible)

    def test_diffusers_model_is_not_compatible(self):
        # The unet entry has only a .bin file; its .safetensors counterpart
        # is deliberately left out of the list.
        repo_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
        ]
        compatible = is_safetensors_compatible(repo_files)
        self.assertFalse(compatible)

    def test_transformer_model_is_compatible(self):
        # A single component using the pytorch_model.bin / model.safetensors
        # naming, with both weight formats present.
        repo_files = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files)
        self.assertTrue(compatible)

    def test_transformer_model_is_not_compatible(self):
        # The text_encoder entry has only a .bin file; its .safetensors
        # counterpart is deliberately left out of the list.
        repo_files = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files)
        self.assertFalse(compatible)

    def test_all_is_compatible_variant(self):
        # Same layout as the full-pipeline case, but every file carries the
        # "fp16" variant infix and the variant is passed to the helper.
        repo_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertTrue(compatible)

    def test_diffusers_model_is_compatible_variant(self):
        # Single component, both formats present, all with the fp16 infix.
        repo_files = [
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertTrue(compatible)

    def test_diffusers_model_is_compatible_variant_partial(self):
        # A variant is requested, but the listed filenames are the
        # non-variant ones; the helper is still expected to accept them.
        repo_files = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertTrue(compatible)

    def test_diffusers_model_is_not_compatible_variant(self):
        # The unet entry has only an fp16 .bin file; its fp16 .safetensors
        # counterpart is deliberately left out of the list.
        repo_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertFalse(compatible)

    def test_transformer_model_is_compatible_variant(self):
        # Single component, both formats present, all with the fp16 infix.
        repo_files = [
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertTrue(compatible)

    def test_transformer_model_is_compatible_variant_partial(self):
        # A variant is requested, but the listed filenames are the
        # non-variant ones; the helper is still expected to accept them.
        repo_files = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertTrue(compatible)

    def test_transformer_model_is_not_compatible_variant(self):
        # The text_encoder entry has only an fp16 .bin file; its fp16
        # .safetensors counterpart is deliberately left out of the list.
        repo_files = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        compatible = is_safetensors_compatible(repo_files, variant="fp16")
        self.assertFalse(compatible)
| |
|