amildravid4292 committed
Commit
e789a6b
1 Parent(s): 4783cef

Update lora_w2w.py

Files changed (1)
  1. lora_w2w.py +3 -5
lora_w2w.py CHANGED
@@ -82,8 +82,8 @@ class LoRAModule(nn.Module):
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().numpy()
         alpha = lora_dim if alpha is None or alpha == 0 else alpha
-        # self.scale = alpha / self.lora_dim
-        # self.scale = self.scale.bfloat16()
+        self.scale = alpha / self.lora_dim
+


         self.multiplier = multiplier
@@ -95,10 +95,8 @@ class LoRAModule(nn.Module):
         del self.org_module

     def forward(self, x):
-
-
         return self.org_forward(x) +\
-            (x@((self.proj@self.v1.T)*self.std1+self.mean1).T)@(((self.proj@self.v2.T)*self.std2+self.mean2))#*self.multiplier*self.scale
+            (x@((self.proj@self.v1.T)*self.std1+self.mean1).T)@(((self.proj@self.v2.T)*self.std2+self.mean2))*self.multiplier*self.scale
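
For readers of this patch: the change uncomments the standard LoRA scale, self.scale = alpha / self.lora_dim, and multiplies the low-rank residual in forward by self.multiplier * self.scale instead of leaving the trailing #*self.multiplier*self.scale commented out. The sketch below reproduces that computation in isolation; the tensor names mirror the diff, but all shapes, the rank-1 structure, and the stand-in base layer are illustrative assumptions, not taken from the repository.

import torch

# Illustrative shapes (assumptions, not from the repo): a 4-dim weight-space
# latent, a linear layer mapping 8 -> 6 features, rank-1 LoRA.
latent_dim, in_dim, out_dim = 4, 8, 6
alpha, lora_dim, multiplier = 1.0, 1.0, 1.0
scale = alpha / lora_dim                    # the scaling this commit re-enables

proj = torch.randn(1, latent_dim)           # weight-space coordinates
v1 = torch.randn(in_dim, latent_dim)        # basis decoding the "down" factor
v2 = torch.randn(out_dim, latent_dim)       # basis decoding the "up" factor
std1, mean1 = torch.ones(in_dim), torch.zeros(in_dim)    # de-normalization stats
std2, mean2 = torch.ones(out_dim), torch.zeros(out_dim)

x = torch.randn(2, in_dim)                  # a batch of inputs
base = x @ torch.randn(in_dim, out_dim)     # stand-in for self.org_forward(x)

# Decode the low-rank factors from the latent, then add the scaled residual,
# matching the new line: base + (x @ down) @ up * multiplier * scale.
down = ((proj @ v1.T) * std1 + mean1).T     # (in_dim, 1)
up = (proj @ v2.T) * std2 + mean2           # (1, out_dim)
out = base + (x @ down) @ up * multiplier * scale
print(out.shape)                            # torch.Size([2, 6])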