Convert weights to jax
#18 opened by jfacevedo
Hello. I tried converting the weights to JAX, but I'm running into an error.
Code:
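from diffusers import FlaxStableDiffusionPipeline
model_name = 'riffusion/riffusion-model-v1'
# Load the PyTorch checkpoint and convert its weights to Flax params
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_name, from_pt=True)
# Save the converted pipeline and params in Flax format
pipeline.save_pretrained('riffusion_jax', params=params)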
Error:
  File "riffusion.py", line 39, in <module>
    pipeline.save_pretrained('riffusion_jax', params=params)
  File "/python3.8/site-packages/diffusers/pipeline_flax_utils.py", line 189, in save_pretrained
    save_method(
  File "/python3.8/site-packages/diffusers/modeling_flax_utils.py", line 518, in save_pretrained
    model_to_save.save_config(save_directory)
  File "/python3.8/site-packages/diffusers/configuration_utils.py", line 137, in save_config
    self.to_json_file(output_config_file)
  File "/python3.8/site-packages/diffusers/configuration_utils.py", line 524, in to_json_file
    writer.write(self.to_json_string())
  File "/python3.8/site-packages/diffusers/configuration_utils.py", line 504, in to_json_string
    config_dict["_class_name"] = self.__class__.__name__
  File "/python3.8/site-packages/flax/core/frozen_dict.py", line 72, in __setitem__
    raise ValueError('FrozenDict is immutable.')
ValueError: FrozenDict is immutable.
The same script works without issues for the CompVis/stable-diffusion-v1-4 and runwayml/stable-diffusion-v1-5 models. Any idea what could be causing this? Thanks.