mpsk commited on
Commit
f5a521b
1 Parent(s): 7b11a8b

Update classifier.py

Browse files
Files changed (1) hide show
  1. classifier.py +5 -2
classifier.py CHANGED
@@ -95,8 +95,8 @@ class Classifier:
95
  grad = []
96
  # Normalize the weight before inference
97
  # This will constrain the gradient or you will have an explosion on query vector
98
- self.weight.data /= torch.norm(
99
- self.weight.data, p=2, dim=-1, keepdim=True
100
  )
101
  for n in range(self.num_class):
102
  # select all training sample and create labels
@@ -119,6 +119,9 @@ class Classifier:
119
  # update weights
120
  grad = torch.stack(grad, dim=0)
121
  self.weight -= 0.1 * grad
 
 
 
122
 
123
  def get_weights(self):
124
  xq = self.weight.detach().numpy()
 
95
  grad = []
96
  # Normalize the weight before inference
97
  # This will constrain the gradient or you will have an explosion on query vector
98
+ self.weight /= torch.norm(
99
+ self.weight, p=2, dim=-1, keepdim=True
100
  )
101
  for n in range(self.num_class):
102
  # select all training sample and create labels
 
119
  # update weights
120
  grad = torch.stack(grad, dim=0)
121
  self.weight -= 0.1 * grad
122
+ self.weight /= torch.norm(
123
+ self.weight, p=2, dim=-1, keepdim=True
124
+ )
125
 
126
  def get_weights(self):
127
  xq = self.weight.detach().numpy()