Getting msgpack ExtraData error from unpackb ("msgpack/_unpacker.pyx", line 201, in msgpack._cmsgpack.unpackb)
Running the conversion script with python3 src/convert_flax_to_pytorch.py produces the following output:

/usr/local/lib/python3.10/site-packages/flax/core/frozen_dict.py:169: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use register_pytree_with_keys() instead.
  jax.tree_util.register_keypaths(
Traceback (most recent call last):
  File "/Users/johnpatrick/JohnData/Projects/gpt-2-tamil/src/convert_flax_to_pytorch.py", line 4, in <module>
    model = GPT2LMHeadModel.from_pretrained("../gpt-2-tamil", from_flax=True)
  File "/usr/local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2457, in from_pretrained
    model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
  File "/usr/local/lib/python3.10/site-packages/transformers/modeling_flax_pytorch_utils.py", line 235, in load_flax_checkpoint_in_pytorch_model
    flax_state_dict = from_bytes(flax_cls, state_f.read())
  File "/usr/local/lib/python3.10/site-packages/flax/serialization.py", line 425, in from_bytes
    state_dict = msgpack_restore(encoded_bytes)
  File "/usr/local/lib/python3.10/site-packages/flax/serialization.py", line 407, in msgpack_restore
    state_dict = msgpack.unpackb(
  File "msgpack/_unpacker.pyx", line 201, in msgpack._cmsgpack.unpackb
msgpack.exceptions.ExtraData: unpack(b) received extra data.
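ExtraData is what msgpack.unpackb raises when the buffer still has bytes left over after it has decoded one complete object. One common cause with Hugging Face model repositories (an assumption here; the traceback alone does not confirm it) is that the repository was cloned without Git LFS, so the checkpoint file is a short text pointer rather than the actual weights. Such a pointer starts with the ASCII byte "v" (0x76), which msgpack decodes as the integer 118, and the rest of the text is then reported as extra data. The sketch below checks for that case using only the standard library. It assumes the default flax_model.msgpack file name used by transformers, and the helper names is_git_lfs_pointer and check_flax_checkpoint are hypothetical.

```python
from pathlib import Path

# First line of every Git LFS pointer file, per the Git LFS pointer spec.
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_git_lfs_pointer(path: Path) -> bool:
    """Return True if `path` begins like a Git LFS pointer instead of binary data."""
    with open(path, "rb") as f:
        head = f.read(len(LFS_POINTER_PREFIX))
    return head == LFS_POINTER_PREFIX


def check_flax_checkpoint(model_dir: str, filename: str = "flax_model.msgpack") -> str:
    """Describe whether the Flax checkpoint in `model_dir` looks like real weights.

    `filename` defaults to flax_model.msgpack (assumed default transformers name).
    """
    path = Path(model_dir) / filename
    if not path.is_file():
        return f"missing: {path}"
    size = path.stat().st_size
    if is_git_lfs_pointer(path):
        # msgpack would decode the leading b"v" as the integer 118 and then
        # raise ExtraData on the remaining pointer text.
        return f"LFS pointer, not weights: {path} ({size} bytes)"
    return f"binary file: {path} ({size} bytes)"
```

For example, print(check_flax_checkpoint("../gpt-2-tamil")) uses the same path that was passed to from_pretrained, resolved against the current working directory. If it reports an LFS pointer, running git lfs install and then git lfs pull inside the cloned repository should download the real weights. A file that is reported as binary but still fails to load is more likely truncated or has extra bytes appended.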