serdaryildiz commited on
Commit
f81a237
·
verified ·
1 Parent(s): 2b8e195

Update Model/dino/dino.py

Browse files
Files changed (1) hide show
  1. Model/dino/dino.py +2 -2
Model/dino/dino.py CHANGED
@@ -15,7 +15,7 @@ class DinoV2(nn.Module):
15
  def __init__(self, model_name):
16
  super().__init__()
17
  self.vision_encoder = torch.hub.load('facebookresearch/dinov2', model_name)
18
- self.vision_encoder = self.vision_encoder.eval().cuda().half()
19
  return
20
 
21
  def forward(self, x):
@@ -24,6 +24,6 @@ class DinoV2(nn.Module):
24
  def get_output_dim(self):
25
  with torch.no_grad():
26
  dummpy_input_image = preprocess(Image.fromarray(numpy.zeros((512, 512, 3), dtype=numpy.uint8))).to(
27
- next(self.parameters()).device).half()
28
  encoder_output_size = self.vision_encoder(dummpy_input_image.unsqueeze(0)).shape[-1]
29
  return encoder_output_size
 
15
  def __init__(self, model_name):
16
  super().__init__()
17
  self.vision_encoder = torch.hub.load('facebookresearch/dinov2', model_name)
18
+ self.vision_encoder = self.vision_encoder.eval()
19
  return
20
 
21
  def forward(self, x):
 
24
  def get_output_dim(self):
25
  with torch.no_grad():
26
  dummpy_input_image = preprocess(Image.fromarray(numpy.zeros((512, 512, 3), dtype=numpy.uint8))).to(
27
+ next(self.parameters()).device)
28
  encoder_output_size = self.vision_encoder(dummpy_input_image.unsqueeze(0)).shape[-1]
29
  return encoder_output_size