kaushalya commited on
Commit
a97b84b
1 Parent(s): b0632fa

Fix loading issue

Browse files

Adds `_do_init` class variable to `FlaxHybridCLIPModule` as it is passed from the newer version of `FlaxPreTrainedModel`.
```
Traceback (most recent call last):
File "/home/kaushalya/miniconda3/envs/flax_p38/lib/python3.8/site-packages/streamlit/scriptrunner/script_runner.py", line 475, in _run_script
exec(code, module.__dict__)
File "/home/kaushalya/coding/medclip/app.py", line 67, in <module>
model, processor = load_model()
File "/home/kaushalya/miniconda3/envs/flax_p38/lib/python3.8/site-packages/streamlit/legacy_caching/caching.py", line 573, in wrapped_func
return get_or_create_cached_value()
File "/home/kaushalya/miniconda3/envs/flax_p38/lib/python3.8/site-packages/streamlit/legacy_caching/caching.py", line 557, in get_or_create_cached_value
return_value = func(*args, **kwargs)
File "/home/kaushalya/coding/medclip/app.py", line 14, in load_model
model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
File "/home/kaushalya/miniconda3/envs/flax_p38/lib/python3.8/site-packages/transformers/modeling_flax_utils.py", line 596, in from_pretrained
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
File "/home/kaushalya/coding/medclip/medclip/modeling_hybrid_clip.py", line 140, in __init__
module = self.module_class(config=config, dtype=dtype, **kwargs)
TypeError: __init__() got an unexpected keyword argument '_do_init'
```

Files changed (1) hide show
  1. medclip/modeling_hybrid_clip.py +1 -0
medclip/modeling_hybrid_clip.py CHANGED
@@ -32,6 +32,7 @@ logger = logging.get_logger(__name__)
32
  class FlaxHybridCLIPModule(nn.Module):
33
  config: HybridCLIPConfig
34
  dtype: jnp.dtype = jnp.float32
 
35
 
36
  def setup(self):
37
  text_config = self.config.text_config
32
  class FlaxHybridCLIPModule(nn.Module):
33
  config: HybridCLIPConfig
34
  dtype: jnp.dtype = jnp.float32
35
+ _do_init: bool = False
36
 
37
  def setup(self):
38
  text_config = self.config.text_config