patrickramos commited on
Commit
617c3e2
1 Parent(s): fc3d6cc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +35 -3
README.md CHANGED
@@ -5,6 +5,7 @@ tags:
5
  - sklearn
6
  - skops
7
  - tabular-classification
 
8
  ---
9
 
10
  # Model description
@@ -71,9 +72,40 @@ Use the code below to get started with the model.
71
  <summary> Click to expand </summary>
72
 
73
  ```python
74
- import pickle
75
- with open('model.pkl', 'rb') as file:
76
- clf = pickle.load(file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  ```
78
 
79
  </details>
 
5
  - sklearn
6
  - skops
7
  - tabular-classification
8
+ - visual emb-gam
9
  ---
10
 
11
  # Model description
 
72
  <summary> Click to expand </summary>
73
 
74
  ```python
75
+ from PIL import Image
76
+ from skops import hub_utils
77
+ import torch
78
+ from transformers import AutoFeatureExtractor, AutoModel
79
+ import pickle
80
+ import os
81
+
82
+ # load embedding model
83
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84
+ feature_extractor = AutoFeatureExtractor.from_pretrained('microsoft/resnet-50')
85
+ model = AutoModel.from_pretrained('microsoft/resnet-50').eval().to(device)
86
+
87
+ # load logistic regression
88
+ os.mkdir('emb-gam-resnet')
89
+ hub_utils.download(repo_id='Ramos-Ramos/emb-gam-resnet', dst='emb-gam-resnet')
90
+
91
+ with open('emb-gam-resnet/model.pkl', 'rb') as file:
92
+ logistic_regression = pickle.load(file)
93
+
94
+ # load image
95
+ img = Image.open('examples/english_springer.png')
96
+
97
+ # preprocess image
98
+ inputs = {k: v.to(device) for k, v in feature_extractor(img, return_tensors='pt').items()}
99
+
100
+ # extract patch embeddings
101
+ with torch.no_grad():
102
+ patch_embeddings = rearrange(model(**inputs).last_hidden_state, 'b d h w -> b (h w) d').cpu()
103
+
104
+ # classify
105
+ pred = logistic_regression.predict(patch_embeddings.sum(dim=0, keepdim=True))
106
+
107
+ # get patch contributions
108
+ patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
109
  ```
110
 
111
  </details>