Loading the model into FlaxStableDiffusionPipeline

#7
by Omorfiamorphism - opened

Hey,

I love the model, but when I try to load it into a FlaxStableDiffusionPipeline with this code:

import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("prompthero/midjourney-v4-diffusion", dtype=jnp.bfloat16)

I get the following error:

OSError: Error no file named flax_model.msgpack or pytorch_model.bin found in directory /root/.cache/huggingface/diffusers/models--prompthero--midjourney-v4-diffusion/snapshots/f8a4391de59f2ba8bdc73020843fe92fe9cfe1a1/text_encoder.

The same happens with the model ID "nitrosocke/mo-di-diffusion".

This confuses me, because the file "pytorch_model.bin" mentioned in the error message does exist. Also, in the README passage "You can also export the model to ONNX, MPS and/or FLAX/JAX.", the FLAX/JAX link just leads back to "https://huggingface.co/prompthero/midjourney-v4-diffusion" instead of to an explanation.

Do you know how to load the model into a FlaxStableDiffusionPipeline?

Thanks a lot!

@akhaliq, could you help us here?

@Omorfiamorphism, if you're still having issues, pass from_pt=True to from_pretrained when loading PyTorch checkpoints with Flax pipelines.
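A minimal sketch of that suggestion, assuming a diffusers version whose Flax pipeline accepts from_pt (recent releases do) and that PyTorch is installed next to JAX/Flax, since the conversion reads the PyTorch .bin weights. Without from_pt, the Flax loader only looks for Flax weights, which would explain why the error appears even though pytorch_model.bin exists. The helper names load_flax_pipeline and save_flax_pipeline and the directory "./midjourney-v4-flax" are hypothetical. The save step is optional: it stores the converted Flax params so that a later load from that directory should not need from_pt.

```python
def load_flax_pipeline(model_id="prompthero/midjourney-v4-diffusion"):
    """Load a PyTorch-only checkpoint into FlaxStableDiffusionPipeline.

    Assumes jax, flax, diffusers, transformers and torch are installed;
    torch is needed because the PyTorch weights are converted at load time.
    """
    # Imported inside the function so this file can be imported without JAX/diffusers.
    import jax.numpy as jnp
    from diffusers import FlaxStableDiffusionPipeline

    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        from_pt=True,        # convert the PyTorch weights instead of requiring Flax weight files
        dtype=jnp.bfloat16,  # same dtype as in the original snippet
    )
    return pipeline, params


def save_flax_pipeline(pipeline, params, save_directory):
    """Optionally save the converted weights in Flax format (hypothetical helper)."""
    # Flax pipelines keep their weights outside the model objects,
    # so save_pretrained takes the params explicitly.
    pipeline.save_pretrained(save_directory, params=params)


# Example usage (downloads the checkpoint):
# pipeline, params = load_flax_pipeline()
# save_flax_pipeline(pipeline, params, "./midjourney-v4-flax")
```

The same call should work for "nitrosocke/mo-di-diffusion" by passing that ID as model_id.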

I have the same issue. Did you manage to fix it, and could you show me how? Thanks a lot!