jusancp99 commited on
Commit
c82d603
1 Parent(s): acce508

Update similarity_utils.py

Browse files
Files changed (1) hide show
  1. similarity_utils.py +2 -2
similarity_utils.py CHANGED
@@ -18,14 +18,14 @@ np.random.seed(seed)
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  # Load model for computing embeddings..
21
- model_ckpt = "nateraw/vit-base-beans"
22
  extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
23
 
24
  # Data transformation chain.
25
  transformation_chain = T.Compose(
26
  [
27
  # We first resize the input image to 256x256 and then we take center crop.
28
- T.Resize(int((256 / 224) * extractor.size["height"])),
29
  T.CenterCrop(extractor.size["height"]),
30
  T.ToTensor(),
31
  T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  # Load model for computing embeddings..
21
+ model_ckpt = "gjuggler/swin-tiny-patch4-window7-224-finetuned-birds"
22
  extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
23
 
24
  # Data transformation chain.
25
  transformation_chain = T.Compose(
26
  [
27
  # We first resize the input image to 256x256 and then we take center crop.
28
+ T.Resize(224),
29
  T.CenterCrop(extractor.size["height"]),
30
  T.ToTensor(),
31
  T.Normalize(mean=extractor.image_mean, std=extractor.image_std),