Questions about finetuning the model

#2
by beline1231 - opened
This comment has been hidden

It was my mistake. (Solved)

Solved how?

The JAX-CV codebase already has the code needed to reset the head weights in the wd_taggers_v3_finetune branch; the only missing part is a few lines to load msgpack-serialized weights instead of an orbax checkpoint.
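
Roughly, the difference boils down to a couple of lines. A minimal sketch of the two loading paths, assuming orbax's default pytree layout; the checkpoint path is a placeholder, not from this repo:

import flax.serialization
import orbax.checkpoint as ocp

# Orbax path: restore a checkpoint directory into a pytree.
# "/path/to/ckpt" is a placeholder path.
restored = ocp.PyTreeCheckpointer().restore("/path/to/ckpt")

# Msgpack path: deserialize the raw file bytes into the same kind of pytree.
with open("./model.msgpack", "rb") as f:
    restored = flax.serialization.msgpack_restore(f.read())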

First of all, thank you for your reply.

As you said, I was having trouble loading the msgpack weights (rather than an orbax checkpoint), which I have now solved by pulling the relevant code from wdv3-jax.

I don't know much about deep learning, but since I was just adding a few tags (to be precise, repurposing tags that were unused for my data), I didn't reset the head weights.
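
For anyone who does want to reset the head weights when changing the tag set, here is a rough sketch of what that might look like. The "head" parameter path, the initializer choice, and num_new_tags are all my assumptions; check them against the actual parameter tree of the model you are fine-tuning:

import jax
import jax.numpy as jnp

def reset_head(params, rng, num_new_tags):
    # Assumes params is a plain (unfrozen) dict and the classifier lives
    # under a "head" key with "kernel"/"bias" entries -- both assumptions.
    in_features = params["head"]["kernel"].shape[0]
    # Re-initialize the final dense layer so stale tag logits don't
    # dominate early fine-tuning.
    params["head"]["kernel"] = jax.nn.initializers.lecun_normal()(
        rng, (in_features, num_new_tags), jnp.float32
    )
    params["head"]["bias"] = jnp.zeros((num_new_tags,), jnp.float32)
    return params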

It seems to work fine, although the MCC is a bit lower than expected. I'm going to lower the LR and try again.

I've left the code below for the next person who wants to fine-tune it.

....<omitted>
print(run_name)

# Load msgpack-serialized weights instead of an orbax checkpoint.
import flax.serialization

weights_path = "./model.msgpack"
with open(weights_path, "rb") as f:
    data = f.read()

# The checkpoint stores everything under a top-level "model" key.
restored = flax.serialization.msgpack_restore(data)["model"]
# Full variables dict (params + constants), in case it is needed elsewhere.
variables = {"params": restored["params"], **restored["constants"]}

# Swap the pretrained params into the freshly created train state.
state = state.replace(params=restored["params"])

if restore_params_ckpt or restore_simmim_ckpt:  # not needed when loading msgpack weights
....<omitted>
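
One optional addition of my own, not part of the snippet above: a quick shape sanity check after the restore, to catch checkpoint/model mismatches (e.g. a changed head size) before training starts:

import jax

restored_shapes = [p.shape for p in jax.tree_util.tree_leaves(restored["params"])]
state_shapes = [p.shape for p in jax.tree_util.tree_leaves(state.params)]
assert restored_shapes == state_shapes, "checkpoint/model shape mismatch"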

Thank you.

I've committed official support for msgpack files in the wd_taggers_v3_finetune branch.
