Patrick Ramos commited on
Commit
9444735
1 Parent(s): 53308be

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -6
README.md CHANGED
@@ -74,17 +74,17 @@ Use the code below to get started with the model.
74
  from PIL import Image
75
  from skops import hub_utils
76
  import torch
77
- from transformers import ViTFeatureExtractor, ViTModel
78
  import pickle
79
  import os
80
 
81
- # load DINO
82
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
83
- feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
84
- model = ViTModel.from_pretrained('facebook/dino-vitb16').eval().to(device)
85
 
86
  # load logistic regression
87
- os.mkdir('logistic regression')
88
  hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')
89
 
90
  with open('emb-gam-dino/model.pkl', 'rb') as file:
@@ -101,7 +101,7 @@ with torch.no_grad():
101
  patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()
102
 
103
  # classify
104
- pred = logistic_regression.predict(patch_embeddings.mean(dim=0).view(1, -1))
105
 
106
  # get patch contributions
107
  patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
 
74
  from PIL import Image
75
  from skops import hub_utils
76
  import torch
77
+ from transformers import AutoFeatureExtractor, AutoModel
78
  import pickle
79
  import os
80
 
81
+ # load embedding model
82
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
83
+ feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/dino-vitb16')
84
+ model = AutoModel.from_pretrained('facebook/dino-vitb16').eval().to(device)
85
 
86
  # load logistic regression
87
+ os.mkdir('emb-gam-dino')
88
  hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')
89
 
90
  with open('emb-gam-dino/model.pkl', 'rb') as file:
 
101
  patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()
102
 
103
  # classify
104
+ pred = logistic_regression.predict(patch_embeddings.sum(dim=0, keepdim=True))
105
 
106
  # get patch contributions
107
  patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()