
This repo contains the SMAT meta-tuned vit-sup21k-small model checkpoint for PyTorch.

How to use

With our implementation on GitHub, you can load the pre-trained weights by calling:

model.load_state_dict(torch.load("/path/to/checkpoint.pt"))

For inference with ProtoNet on a few-shot learning task:

# outputs is a dictionary
outputs = model(x_s=x_s,   # support inputs
                y_s=y_s,   # support labels
                x_q=x_q,   # query inputs
                y_q=None,  # None to predict the query labels
                finetune_model=None  # None for direct inference with a ProtoNet classifier
                )

y_q_pred = outputs['y_q_pred']
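With `finetune_model=None`, the classifier is a ProtoNet head: each class prototype is the mean of that class's support embeddings, and each query is assigned to the class of its nearest prototype. The sketch below shows only this classification step in pure Python (hypothetical helper names; in the actual model the embeddings come from the ViT backbone):

```python
# Hypothetical sketch of the ProtoNet classification step.
# z_s / z_q stand for support / query embeddings (lists of floats);
# y_s are integer class labels for the support set.

def prototypes(z_s, y_s):
    """Average the support embeddings per class label."""
    sums, counts = {}, {}
    for z, y in zip(z_s, y_s):
        acc = sums.setdefault(y, [0.0] * len(z))
        for i, v in enumerate(z):
            acc[i] += v
        counts[y] = counts.get(y, 0) + 1
    return {y: [v / counts[y] for v in acc] for y, acc in sums.items()}

def predict(z_q, protos):
    """Label each query embedding with the class of its nearest prototype."""
    def sqdist(a, b):
        return sum((u - v) ** 2 for u, v in zip(a, b))
    return [min(protos, key=lambda y: sqdist(z, protos[y])) for z in z_q]
```

For example, with support embeddings `[[0.0, 0.0], [0.2, 0.0], [1.0, 1.0], [0.8, 1.0]]` and labels `[0, 0, 1, 1]`, a query near the origin is assigned class 0 and one near `[1, 1]` class 1.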

For task-specific full fine-tuning on the support set followed by inference:

# outputs is a dictionary
model.args.meta_learner.inner_lr.lr = lr  # set the learning rate for fine-tuning
model.args.meta_learner.num_finetune_steps = num_finetune_steps  # set the number of fine-tuning steps
outputs = model(x_s=x_s,   # support inputs
                y_s=y_s,   # support labels
                x_q=x_q,   # query inputs
                y_q=None,  # None to predict the query labels
                finetune_model="full"  # 'full' or 'lora'
                )

y_q_pred = outputs['y_q_pred']
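With `finetune_model="full"`, the model is adapted on the support set for `num_finetune_steps` gradient steps at learning rate `lr` before the queries are predicted. The control flow can be sketched with a toy one-parameter model (hypothetical names; squared error stands in for the real training objective):

```python
# Toy sketch of "fine-tune on the support set, then predict the queries".
# A one-parameter model y = w * x is adapted by plain gradient descent,
# mirroring the roles of `lr` and `num_finetune_steps` above.

def finetune_then_predict(x_s, y_s, x_q, lr=0.1, num_finetune_steps=50):
    w = 0.0  # task-specific parameter, adapted on the support set
    for _ in range(num_finetune_steps):
        # gradient of the mean squared error (1/n) * sum((w*x - y)^2) w.r.t. w
        grad = sum(2 * (w * x - y) * x for x, y in zip(x_s, y_s)) / len(x_s)
        w -= lr * grad
    return [w * x for x in x_q]  # predictions for the query inputs
```

For a support set generated by y = 2x, fifty steps at `lr=0.1` converge to w ≈ 2, so a query input of 3.0 is predicted as ≈ 6.0.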

You can visit our GitHub repo for more details on training and inference!
