fallcat committed a085136 (verified) · 1 Parent(s): 7a24093

Update README.md

Files changed (1): README.md +60 -3
README.md CHANGED
---
license: apache-2.0
---

To use Sum-of-Parts (SOP), you need to install exlib. SOP is currently only available on exlib's dev branch (https://github.com/BrachioLab/exlib/tree/dev).
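
For example, assuming the dev branch is directly pip-installable from GitHub (this command is an illustration, not taken from the exlib documentation):
```
# Install exlib from the dev branch (assumes a standard pip-installable package)
pip install git+https://github.com/BrachioLab/exlib.git@dev
```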

To use the SOP model trained for `google/vit-base-patch16-224`, follow the code below.

### Load the model
```
import torch
import os
from transformers import AutoImageProcessor, AutoModelForImageClassification

import sys
from exlib.modules.sop import WrappedModel, SOPConfig, SOPImageCls, get_chained_attr

# Initialize the backbone model and its image processor
backbone_model = AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224')
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Wrap the backbone in the forms SOP needs
original_model = WrappedModel(backbone_model, output_type='logits')
wrapped_backbone_model = WrappedModel(backbone_model, output_type='tuple')
projection_layer = WrappedModel(wrapped_backbone_model, output_type='hidden_states')

# Load the trained SOP model
model = SOPImageCls.from_pretrained('BrachioLab/sop-vit-base-patch16-224',
                                    blackbox_model=wrapped_backbone_model,
                                    projection_layer=projection_layer)
model.eval()
```

### Open an image
```
from PIL import Image

# Open an example image
# image_path = '../../examples/ILSVRC2012_val_00000873.JPEG'
image_path = '../../examples/ILSVRC2012_val_00000247.JPEG'
image = Image.open(image_path)
image.show()
image_rgb = image.convert("RGB")

# Preprocess the image into a (1, 3, 224, 224) pixel-value tensor
inputs = torch.tensor(processor(image_rgb)['pixel_values'])
print(inputs.shape)  # (1, 3, 224, 224)
```
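
Alternatively, the image processor can return a PyTorch tensor directly (standard `transformers` behavior), which yields the same `(1, 3, 224, 224)` input without the manual `torch.tensor(...)` conversion:
```
# Same preprocessing, letting the processor build the batch tensor itself
inputs = processor(image_rgb, return_tensors="pt")["pixel_values"]
```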

### Get the output from SOP
```
# Get the outputs from the model
outputs = model(inputs, return_tuple=True)
```
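
For reference, you can check which class the underlying ViT backbone predicts for the same input; this uses only the standard `transformers` API and is independent of SOP:
```
# Predicted label from the plain backbone, for comparison
with torch.no_grad():
    logits = backbone_model(inputs).logits
pred_idx = logits.argmax(-1).item()
print(backbone_model.config.id2label[pred_idx])
```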

### Show the groups
```
from exlib.modules.sop import show_masks_weights

# Show the group masks with their group attribution scores
show_masks_weights(inputs, outputs, i=0)
```
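
Finally, a minimal sketch of running the same pipeline on a small batch of images. This assumes the SOP model accepts batched inputs and that the `i` argument of `show_masks_weights` indexes examples within the batch (as the `i=0` call above suggests); `image_paths` below is a hypothetical placeholder:
```
# Hypothetical batched usage (see assumptions above)
image_paths = ['example1.JPEG', 'example2.JPEG']  # placeholder paths
images = [Image.open(p).convert("RGB") for p in image_paths]
inputs = processor(images, return_tensors="pt")["pixel_values"]  # (2, 3, 224, 224)
outputs = model(inputs, return_tuple=True)
for i in range(inputs.shape[0]):
    show_masks_weights(inputs, outputs, i=i)
```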