To use Sum-of-Parts (SOP), you need to install exlib. SOP is currently only available on the dev branch: https://github.com/BrachioLab/exlib/tree/dev
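For example, assuming the repository is pip-installable directly from git (this exact command is an assumption, not taken from exlib's documentation):

```bash
# Install exlib from the dev branch (assumed to be pip-installable from git)
pip install git+https://github.com/BrachioLab/exlib.git@dev
```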
To use the SOP model trained for google/vit-base-patch16-224, follow the code below.
Load the model
```python
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification

from exlib.modules.sop import WrappedModel, SOPImageCls

# Initialize the backbone model and its image processor
backbone_model = AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224')
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Wrap the backbone so SOP can access logits, tuple outputs, and hidden states
original_model = WrappedModel(backbone_model, output_type='logits')
wrapped_backbone_model = WrappedModel(backbone_model, output_type='tuple')
projection_layer = WrappedModel(wrapped_backbone_model, output_type='hidden_states')

# Load the trained SOP model
model = SOPImageCls.from_pretrained('BrachioLab/sop-vit-base-patch16-224',
                                    blackbox_model=wrapped_backbone_model,
                                    projection_layer=projection_layer)
model.eval()
```
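Optionally, you can move the model to a GPU before inference. This is standard PyTorch device placement, not anything SOP-specific; if you do this, also move `inputs` to the same device later:

```python
# Optional: standard PyTorch device placement (not SOP-specific).
# If you do this, also move `inputs` to the same device before calling the model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
```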
Open an image
```python
from PIL import Image

# Open an example image
image_path = '../../examples/ILSVRC2012_val_00000247.JPEG'
image = Image.open(image_path)
image.show()

# Preprocess into a (1, 3, 224, 224) pixel-value tensor
image_rgb = image.convert('RGB')
inputs = processor(image_rgb, return_tensors='pt')['pixel_values']
inputs.shape  # torch.Size([1, 3, 224, 224])
```
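If you want to explain several images at once, Hugging Face image processors also accept a list of images. In the sketch below, `paths` is a hypothetical list of image file paths:

```python
# Sketch: batch preprocessing with a Hugging Face image processor.
# `paths` is a hypothetical list of image file paths.
images = [Image.open(p).convert('RGB') for p in paths]
batch = processor(images, return_tensors='pt')['pixel_values']  # (N, 3, 224, 224)
```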
Get the output from SOP
```python
# Get the outputs from the SOP model (no gradients needed for inference)
with torch.no_grad():
    outputs = model(inputs, return_tuple=True)
```
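As a sanity check, you can also query the wrapped backbone directly for its predicted ImageNet class. Since `original_model` was created with `output_type='logits'` above, the sketch below assumes it is callable like the underlying ViT and returns raw logits:

```python
# Sanity check (sketch): the backbone's own prediction for the same input.
# Assumes original_model (output_type='logits') is callable and returns logits.
with torch.no_grad():
    logits = original_model(inputs)
pred = logits.argmax(-1).item()
print(backbone_model.config.id2label[pred])  # human-readable ImageNet label
```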
Show the groups
```python
from exlib.modules.sop import show_masks_weights

# Visualize the group masks together with their group attribution scores
show_masks_weights(inputs, outputs, i=0)
```