Patrick Ramos
commited on
Commit
•
9444735
1
Parent(s):
53308be
Update README.md
Browse files
README.md
CHANGED
@@ -74,17 +74,17 @@ Use the code below to get started with the model.
|
|
74 |
from PIL import Image
|
75 |
from skops import hub_utils
|
76 |
import torch
|
77 |
-
from transformers import
|
78 |
import pickle
|
79 |
import os
|
80 |
|
81 |
-
# load
|
82 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
83 |
-
feature_extractor =
|
84 |
-
model =
|
85 |
|
86 |
# load logistic regression
|
87 |
-
os.mkdir('
|
88 |
hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')
|
89 |
|
90 |
with open('emb-gam-dino/model.pkl', 'rb') as file:
|
@@ -101,7 +101,7 @@ with torch.no_grad():
|
|
101 |
patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()
|
102 |
|
103 |
# classify
|
104 |
-
pred = logistic_regression.predict(patch_embeddings.
|
105 |
|
106 |
# get patch contributions
|
107 |
patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
|
|
|
74 |
from PIL import Image
|
75 |
from skops import hub_utils
|
76 |
import torch
|
77 |
+
from transformers import AutoFeatureExtractor, AutoModel
|
78 |
import pickle
|
79 |
import os
|
80 |
|
81 |
+
# load embedding model
|
82 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
83 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
|
84 |
+
model = AutoModel.from_pretrained('facebook/dino-vitb16').eval().to(device)
|
85 |
|
86 |
# load logistic regression
|
87 |
+
os.mkdir('emb-gam-dino')
|
88 |
hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')
|
89 |
|
90 |
with open('emb-gam-dino/model.pkl', 'rb') as file:
|
|
|
101 |
patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()
|
102 |
|
103 |
# classify
|
104 |
+
pred = logistic_regression.predict(patch_embeddings.sum(dim=0, keepdim=True))
|
105 |
|
106 |
# get patch contributions
|
107 |
patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
|