---
license: mit
library_name: sklearn
tags:
- sklearn
- skops
- tabular-classification
- visual emb-gam
---

# Model description

This is a LogisticRegressionCV model trained on averages of patch embeddings from the Imagenette dataset. It forms the GAM component of an [Emb-GAM](https://arxiv.org/abs/2209.11799) extended to images. Patch embeddings are meant to be extracted with the [`microsoft/resnet-50` checkpoint](https://huggingface.co/microsoft/resnet-50).

## Intended uses & limitations

This model is not intended to be used in production.

## Training Procedure

### Hyperparameters

The model is trained with the hyperparameters below.

<details>
<summary> Click to expand </summary>

| Hyperparameter    | Value                                                     |
|-------------------|-----------------------------------------------------------|
| Cs                | 10                                                        |
| class_weight      |                                                           |
| cv                | StratifiedKFold(n_splits=5, random_state=1, shuffle=True) |
| dual              | False                                                     |
| fit_intercept     | True                                                      |
| intercept_scaling | 1.0                                                       |
| l1_ratios         |                                                           |
| max_iter          | 100                                                       |
| multi_class       | auto                                                      |
| n_jobs            |                                                           |
| penalty           | l2                                                        |
| random_state      | 1                                                         |
| refit             | False                                                     |
| scoring           |                                                           |
| solver            | lbfgs                                                     |
| tol               | 0.0001                                                    |
| verbose           | 0                                                         |

</details>
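As a minimal sketch, the estimator described by the table can be rebuilt in scikit-learn as shown here. Only the three settings that appear in the model's text representation (`cv`, `random_state`, `refit`) are passed explicitly; the assumption is that every other row of the table is a library default, which may differ between scikit-learn versions.

```python
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import StratifiedKFold

# 5-fold stratified, shuffled cross-validation with a fixed seed,
# matching the `cv` row of the hyperparameter table.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)

# Unfitted estimator; all other hyperparameters are left at the
# library defaults (assumed to match the remaining table rows).
clf = LogisticRegressionCV(cv=cv, random_state=1, refit=False)

params = clf.get_params()
print(params["cv"], params["random_state"], params["refit"])
```

With `refit=False`, scikit-learn averages the coefficients obtained at the best regularization strength of each fold instead of refitting once on the full training set.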

### Model Plot

The model plot is below.

<style>#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 {color: black;background-color: white;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 pre{padding: 0;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-toggleable {background-color: white;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 label.sk-toggleable__label-arrow:before {content: "▸";float: left;margin-right: 0.25em;color: #696969;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: "▾";}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-estimator {font-family: monospace;background-color: 
#f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-estimator:hover {background-color: #d4ebff;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-parallel-item::after {content: "";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-serial::before {content: "";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 2em;bottom: 0;left: 50%;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-item {z-index: 1;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-parallel::before {content: "";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 2em;bottom: 0;left: 50%;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-parallel-item {display: flex;flex-direction: column;position: relative;background-color: white;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-parallel-item:only-child::after {width: 0;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;position: relative;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-label label {font-family: monospace;font-weight: bold;background-color: white;display: 
inline-block;line-height: 1.2em;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-label-container {position: relative;z-index: 2;text-align: center;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-cdb8d112-cb44-4c1d-aaa0-00223c118110 div.sk-text-repr-fallback {display: none;}</style><div id="sk-cdb8d112-cb44-4c1d-aaa0-00223c118110" class="sk-top-container"><div class="sk-text-repr-fallback"><pre>LogisticRegressionCV(cv=StratifiedKFold(n_splits=5, random_state=1, shuffle=True),random_state=1, refit=False)</pre><b>Please rerun this cell to show the HTML repr or trust the notebook.</b></div><div class="sk-container" hidden><div class="sk-item"><div class="sk-estimator sk-toggleable"><input class="sk-toggleable__control sk-hidden--visually" id="8e021ebe-2e55-4711-8708-111032540cfe" type="checkbox" checked><label for="8e021ebe-2e55-4711-8708-111032540cfe" class="sk-toggleable__label sk-toggleable__label-arrow">LogisticRegressionCV</label><div class="sk-toggleable__content"><pre>LogisticRegressionCV(cv=StratifiedKFold(n_splits=5, random_state=1, shuffle=True),random_state=1, refit=False)</pre></div></div></div></div></div>

## Evaluation Results

The evaluation results are summarized in the table below.



| Metric   |    Value |
|----------|----------|
| accuracy | 0.996688 |
| f1 score | 0.996688 |
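
The card does not state how the F1 score is averaged across classes. The identical accuracy and F1 values would be consistent with micro-averaging, which always equals accuracy for single-label multiclass problems; that reading is an assumption, not something the card confirms. The sketch below illustrates the identity on made-up labels (the `y_true` and `y_pred` lists are hypothetical, not Imagenette predictions).

```python
from sklearn.metrics import accuracy_score, f1_score

# Hypothetical labels for a 3-class problem (illustration only).
y_true = [0, 1, 2, 2, 1, 0, 2, 1]
y_pred = [0, 1, 2, 1, 1, 0, 2, 2]

accuracy = accuracy_score(y_true, y_pred)

# Micro-averaging pools true/false positives over all classes, so for
# single-label multiclass data it reduces to plain accuracy.
micro_f1 = f1_score(y_true, y_pred, average="micro")

print(accuracy, micro_f1)
```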

# How to Get Started with the Model

Use the code below to get started with the model.

<details>
<summary> Click to expand </summary>

```python
from PIL import Image
from skops import hub_utils
import torch
from transformers import AutoFeatureExtractor, AutoModel
import pickle
import os

# load embedding model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = AutoFeatureExtractor.from_pretrained('microsoft/resnet-50')
model = AutoModel.from_pretrained('microsoft/resnet-50').eval().to(device)

# load logistic regression
os.mkdir('emb-gam-resnet')
hub_utils.download(repo_id='Ramos-Ramos/emb-gam-resnet', dst='emb-gam-resnet')

with open('emb-gam-resnet/model.pkl', 'rb') as file: 
  logistic_regression = pickle.load(file)
    
# load image
img = Image.open('examples/english_springer.png')

# preprocess image
inputs = {k: v.to(device) for k, v in feature_extractor(img, return_tensors='pt').items()}

# extract patch embeddings
with torch.no_grad():
  patch_embeddings = model(**inputs).last_hidden_state[0].permute(1, 2, 0).view(7*7, 2048).cpu()

# classify
pred = logistic_regression.predict(patch_embeddings.sum(dim=0, keepdim=True))

# get patch contributions
patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
```

</details>
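
The patch contributions computed in the snippet above are meaningful because the classifier is linear: the logits of the summed embedding equal the sum of the per-patch contributions plus the intercept. The following sketch checks this with NumPy only. The random `coef` and `intercept` arrays are stand-ins for `logistic_regression.coef_` and `logistic_regression.intercept_`, and the random `patch_embeddings` replace real ResNet-50 features; the shapes (49 patches, 2048 dimensions, 10 Imagenette classes) follow the code above.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, dim, num_classes = 7 * 7, 2048, 10

# Stand-ins (assumptions) for real embeddings and fitted parameters.
patch_embeddings = rng.normal(size=(num_patches, dim))
coef = rng.normal(size=(num_classes, dim))
intercept = rng.normal(size=num_classes)

# Contribution of each patch to each class logit: (num_classes, num_patches).
patch_contributions = coef @ patch_embeddings.T

# Logits computed from the summed embedding ...
logits_from_sum = coef @ patch_embeddings.sum(axis=0) + intercept
# ... match the summed per-patch contributions plus the intercept.
logits_from_patches = patch_contributions.sum(axis=1) + intercept
assert np.allclose(logits_from_sum, logits_from_patches)

# Contributions for the predicted class, laid out on the 7x7 patch grid
# (row-major, matching the permute/view order used above).
predicted_class = int(np.argmax(logits_from_sum))
heatmap = patch_contributions[predicted_class].reshape(7, 7)
print(heatmap.shape)
```

Plotting `heatmap` over the input image is one way to see which regions push the prediction toward the chosen class.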




# Model Card Authors

This model card is written by the following authors:

Patrick Ramos and Ryan Ramos

# Model Card Contact

You can contact the model card authors through the following channels:
[More Information Needed]

# Citation

Below you can find information related to citation.

**BibTeX:**
```
@article{singh2022emb,
  title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
  author={Singh, Chandan and Gao, Jianfeng},
  journal={arXiv preprint arXiv:2209.11799},
  year={2022}
}
```


# Additional Content

## confusion_matrix

![confusion_matrix](confusion_matrix.png)