huzey commited on
Commit
c37e7c7
·
1 Parent(s): 7acde1f
Files changed (1) hide show
  1. alignedthreeattn_model.py +2 -2
alignedthreeattn_model.py CHANGED
@@ -8,11 +8,11 @@ import numpy as np
8
  import torch
9
  import torch.nn.functional as F
10
 
11
- align_weights = torch.load("align_weights.pth")
12
  from torch import nn
13
  from backbone import CLIPAttnNode, DiNOv2AttnNode, MAEAttnNode
14
  class ThreeAttnNodes(nn.Module):
15
- def __init__(self, align_weights=align_weights):
16
  super().__init__()
17
  self.backbone1 = CLIPAttnNode()
18
  self.backbone2 = DiNOv2AttnNode()
 
8
  import torch
9
  import torch.nn.functional as F
10
 
11
+ # align_weights = torch.load("align_weights.pth")
12
  from torch import nn
13
  from backbone import CLIPAttnNode, DiNOv2AttnNode, MAEAttnNode
14
  class ThreeAttnNodes(nn.Module):
15
+ def __init__(self, align_weights):
16
  super().__init__()
17
  self.backbone1 = CLIPAttnNode()
18
  self.backbone2 = DiNOv2AttnNode()