Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,60 @@
|
|
1 |
-
---
|
2 |
-
license: apache-2.0
|
3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
---
|
4 |
+
|
5 |
+
To use Sum-of-Parts(SOP), you would need to install exlib. Currently SOP is only available on the dev branch https://github.com/BrachioLab/exlib/tree/dev
|
6 |
+
|
7 |
+
To use SOP trained for `google/vit-base-patch16-224`, follow the following code.
|
8 |
+
|
9 |
+
### Load the model
|
10 |
+
```
|
11 |
+
import torch
|
12 |
+
import os
|
13 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
14 |
+
|
15 |
+
import sys
|
16 |
+
from exlib.modules.sop import WrappedModel, SOPConfig, SOPImageCls, get_chained_attr
|
17 |
+
|
18 |
+
|
19 |
+
# init backbone model
|
20 |
+
backbone_model = AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
21 |
+
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')
|
22 |
+
|
23 |
+
# get needed wrapped models
|
24 |
+
original_model = WrappedModel(backbone_model, output_type='logits')
|
25 |
+
wrapped_backbone_model = WrappedModel(backbone_model, output_type='tuple')
|
26 |
+
projection_layer = WrappedModel(wrapped_backbone_model, output_type='hidden_states')
|
27 |
+
|
28 |
+
# load trained sop model
|
29 |
+
model = SOPImageCls.from_pretrained('BrachioLab/sop-vit-base-patch16-224',
|
30 |
+
blackbox_model=wrapped_backbone_model,
|
31 |
+
projection_layer=projection_layer)
|
32 |
+
model.eval();
|
33 |
+
```
|
34 |
+
|
35 |
+
### Open an image
|
36 |
+
```
|
37 |
+
from PIL import Image
|
38 |
+
|
39 |
+
# Open an example image
|
40 |
+
# image_path = '../../examples/ILSVRC2012_val_00000873.JPEG'
|
41 |
+
image_path = '../../examples/ILSVRC2012_val_00000247.JPEG'
|
42 |
+
image = Image.open(image_path)
|
43 |
+
image.show()
|
44 |
+
image_rgb = image.convert("RGB")
|
45 |
+
inputs = torch.tensor(processor(image_rgb)['pixel_values'])
|
46 |
+
inputs.shape # (1, 3, 224, 224)
|
47 |
+
```
|
48 |
+
|
49 |
+
### Get the output from SOP
|
50 |
+
```
|
51 |
+
# Get the outputs from the model
|
52 |
+
outputs = model(inputs, return_tuple=True)
|
53 |
+
```
|
54 |
+
|
55 |
+
### Show the groups
|
56 |
+
```
|
57 |
+
from exlib.modules.sop import show_masks_weights
|
58 |
+
|
59 |
+
show_masks_weights(inputs, outputs, i=0) # This allows you to see the group masks with group attribution scores.
|
60 |
+
```
|