Spaces:
Running
Fix loading issue
Browse filesAdds `_do_init` class variable to `FlaxHybridCLIPModule` as it is passed from the newer version of `FlaxPreTrainedModel`.
```
Traceback (most recent call last):
File "/home/kaushalya/miniconda3/envs/flax_p38/lib/python3.8/site-packages/streamlit/scriptrunner/script_runner.py", line 475, in _run_script
exec(code, module.__dict__)
File "/home/kaushalya/coding/medclip/app.py", line 67, in <module>
model, processor = load_model()
File "/home/kaushalya/miniconda3/envs/flax_p38/lib/python3.8/site-packages/streamlit/legacy_caching/caching.py", line 573, in wrapped_func
return get_or_create_cached_value()
File "/home/kaushalya/miniconda3/envs/flax_p38/lib/python3.8/site-packages/streamlit/legacy_caching/caching.py", line 557, in get_or_create_cached_value
return_value = func(*args, **kwargs)
File "/home/kaushalya/coding/medclip/app.py", line 14, in load_model
model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
File "/home/kaushalya/miniconda3/envs/flax_p38/lib/python3.8/site-packages/transformers/modeling_flax_utils.py", line 596, in from_pretrained
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
File "/home/kaushalya/coding/medclip/medclip/modeling_hybrid_clip.py", line 140, in __init__
module = self.module_class(config=config, dtype=dtype, **kwargs)
TypeError: __init__() got an unexpected keyword argument '_do_init'
```
@@ -32,6 +32,7 @@ logger = logging.get_logger(__name__)
|
|
32 |
class FlaxHybridCLIPModule(nn.Module):
|
33 |
config: HybridCLIPConfig
|
34 |
dtype: jnp.dtype = jnp.float32
|
|
|
35 |
|
36 |
def setup(self):
|
37 |
text_config = self.config.text_config
|
|
|
32 |
class FlaxHybridCLIPModule(nn.Module):
|
33 |
config: HybridCLIPConfig
|
34 |
dtype: jnp.dtype = jnp.float32
|
35 |
+
_do_init: bool = False
|
36 |
|
37 |
def setup(self):
|
38 |
text_config = self.config.text_config
|