Meta-tuned SMAT ViTs

A collection of meta-tuned SMAT model checkpoints for pre-trained Vision Transformer backbones of various scales.
This repo contains the SMAT meta-tuned vit-sup21k-small model checkpoint for PyTorch.
With our implementation on GitHub, you can load the pre-trained weights by calling:

```python
model.load_state_dict(torch.load("/path/to/checkpoint.pt"))
```
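If you are loading on a CPU-only machine, `torch.load` accepts a `map_location` argument. A minimal sketch, assuming `model` has already been constructed with the repo's model builder (not shown here):

```python
import torch

# Load the checkpoint onto CPU first; map_location avoids CUDA
# deserialization errors on machines without a GPU.
state_dict = torch.load("/path/to/checkpoint.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # disable dropout etc. for inference
```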
For inference with ProtoNet on a few-shot learning task:

```python
# outputs is a dictionary
outputs = model(
    x_s=x_s,              # support inputs
    y_s=y_s,              # support labels
    x_q=x_q,              # query inputs
    y_q=None,             # None to predict the query labels
    finetune_model=None,  # None for direct inference with a ProtoNet classifier
)
y_q_pred = outputs['y_q_pred']
```
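For concreteness, here is a hedged sketch of what the episode tensors might look like for a 5-way 1-shot task with 15 queries per class. The 224x224 input resolution and the random tensors are illustrative assumptions; real episodes should go through the repo's data pipeline:

```python
import torch

n_way, k_shot, n_query = 5, 1, 15

# Toy episode: random tensors standing in for preprocessed 224x224 RGB images.
x_s = torch.randn(n_way * k_shot, 3, 224, 224)       # support images
y_s = torch.arange(n_way).repeat_interleave(k_shot)  # support labels 0..n_way-1
x_q = torch.randn(n_way * n_query, 3, 224, 224)      # query images

with torch.no_grad():  # no gradients needed for direct ProtoNet inference
    outputs = model(x_s=x_s, y_s=y_s, x_q=x_q, y_q=None, finetune_model=None)
y_q_pred = outputs['y_q_pred']
```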
For task-specific full fine-tuning on the support set followed by inference:

```python
# outputs is a dictionary
model.args.meta_learner.inner_lr.lr = lr                         # learning rate for fine-tuning
model.args.meta_learner.num_finetune_steps = num_finetune_steps  # number of fine-tuning steps
outputs = model(
    x_s=x_s,                # support inputs
    y_s=y_s,                # support labels
    x_q=x_q,                # query inputs
    y_q=None,               # None to predict the query labels
    finetune_model="full",  # one of {'full', 'lora'}
)
y_q_pred = outputs['y_q_pred']
```
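The same call also supports parameter-efficient fine-tuning via `finetune_model="lora"`. A short sketch; the learning rate, step count, and the `y_q_true` labels are illustrative assumptions, and `y_q_pred` is assumed to hold predicted class indices:

```python
# Illustrative hyperparameters only; tune these for your task.
model.args.meta_learner.inner_lr.lr = 1e-4
model.args.meta_learner.num_finetune_steps = 20

outputs = model(x_s=x_s, y_s=y_s, x_q=x_q, y_q=None, finetune_model="lora")
y_q_pred = outputs['y_q_pred']

# With ground-truth query labels (hypothetical y_q_true), episode accuracy is:
accuracy = (y_q_pred == y_q_true).float().mean().item()
```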
Visit our GitHub repo for more details on training and inference!