Patrick Ramos committed on
Commit: b093863
Parent(s): d81f7ec

Update README.md

Files changed (1):
  1. README.md +42 -6
README.md CHANGED
@@ -3083,7 +3083,7 @@ widget:
 
 # Model description
 
-This is a LogisticRegressionCV model trained on averages of patch embeddings from the Imagenette dataset..
+This is a LogisticRegressionCV model trained on averages of patch embeddings from the Imagenette dataset. This forms the GAM of an [Emb-GAM](https://arxiv.org/abs/2209.11799) extended to images. Patch embeddings are meant to be extracted with the [`facebook/dino-vitb16` DINO checkpoint](https://huggingface.co/facebook/dino-vitb16).
 
 ## Intended uses & limitations
 
@@ -3145,9 +3145,40 @@ Use the code below to get started with the model.
 <summary> Click to expand </summary>
 
 ```python
-import pickle
-with open('model.pkl', 'rb') as file:
-    clf = pickle.load(file)
+from PIL import Image
+from skops import hub_utils
+import torch
+from transformers import ViTFeatureExtractor, ViTModel
+import pickle
+import os
+
+# load DINO
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+model = ViTModel.from_pretrained('facebook/dino-vitb16').eval().to(device)
+
+# load logistic regression
+os.mkdir('logistic regression')
+hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')
+
+with open('emb-gam-dino/model.pkl', 'rb') as file:
+    logistic_regression = pickle.load(file)
+
+# load image
+img = Image.open('examples/english_springer.png')
+
+# preprocess image
+inputs = {k: v.to(device) for k, v in feature_extractor(img, return_tensors='pt').items()}
+
+# extract patch embeddings
+with torch.no_grad():
+    patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()
+
+# classify
+pred = logistic_regression.predict(patch_embeddings.mean(dim=0).view(1, -1))
+
+# get patch contributions
+patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
 ```
 
 </details>
@@ -3168,11 +3199,16 @@ You can contact the model card authors through following channels:
 
 # Citation
 
-Below you can find information related to citation.
+Below you can find information related to citation. Note that this is **not our own paper**.
 
 **BibTeX:**
 ```
-[More Information Needed]
+@article{singh2022emb,
+  title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
+  author={Singh, Chandan and Gao, Jianfeng},
+  journal={arXiv preprint arXiv:2209.11799},
+  year={2022}
+}
 ```
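The "get patch contributions" step in the snippet added above relies on the linearity of the logistic regression: because the classifier is trained on the mean of the patch embeddings, its logits decompose into a mean of per-patch terms plus the intercept. Below is a minimal sketch of that check, reusing `logistic_regression` and `patch_embeddings` from the snippet in the diff; the `numpy` import is an addition here.

```python
import numpy as np

# Logits of the classifier on the averaged patch embedding
# (decision_function computes X @ coef_.T + intercept_).
logits = logistic_regression.decision_function(
    patch_embeddings.mean(dim=0).view(1, -1).numpy()
).squeeze()

# The same logits reconstructed additively from the per-patch contributions:
# mean over patches of coef_ @ patch_embedding, plus the intercept.
patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
reconstructed = patch_contributions.mean(axis=1) + logistic_regression.intercept_

assert np.allclose(logits, reconstructed)
```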
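As a further usage sketch (not part of the model card), the per-patch contributions of the predicted class can be viewed as a heatmap. This assumes the default 224x224 preprocessing of `ViTFeatureExtractor`, under which ViT-B/16 produces a 14x14 grid of patches; the `matplotlib` dependency is likewise an assumption.

```python
import matplotlib.pyplot as plt

# Index of the predicted class among the classifier's classes.
class_idx = list(logistic_regression.classes_).index(pred[0])

# With 224x224 inputs, ViT-B/16 yields 14x14 = 196 patches (assumption),
# so the predicted class's contributions reshape into a 14x14 grid.
heatmap = patch_contributions[class_idx].reshape(14, 14)

plt.imshow(heatmap)
plt.colorbar()
plt.title(f'Patch contributions for predicted class: {pred[0]}')
plt.show()
```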