Annonymous commited on
Commit
6a21d52
1 Parent(s): c868bf2

Update methods.py

Browse files
Files changed (1) hide show
  1. methods.py +1 -1
methods.py CHANGED
@@ -402,7 +402,7 @@ def smooth_grad(guided, ssl_model, img1, img2, blur_output, steps = 50):
402
  def get_sample_dataset(img_path, num_augments, batch_size, no_shift_transforms, ssl_model, n_components):
403
 
404
  measure = nn.CosineSimilarity(dim=-1)
405
- img = Image.open(img_path).convert('RGB')
406
  no_shift_aug = transforms.Compose([no_shift_transforms['aug'],
407
  transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3))])
408
 
 
402
  def get_sample_dataset(img_path, num_augments, batch_size, no_shift_transforms, ssl_model, n_components):
403
 
404
  measure = nn.CosineSimilarity(dim=-1)
405
+ img = Image.open(img_path).convert('RGB') if isinstance(img_path, str) else img_path
406
  no_shift_aug = transforms.Compose([no_shift_transforms['aug'],
407
  transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3))])
408