ChenyangSi commited on
Commit
ed16a12
1 Parent(s): 861e5b3

Update free_lunch_utils.py

Browse files
Files changed (1) hide show
  1. free_lunch_utils.py +15 -2
free_lunch_utils.py CHANGED
@@ -234,10 +234,23 @@ def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
234
  # --------------- FreeU code -----------------------
235
  # Only operate on the first two stages
236
  if hidden_states.shape[1] == 1280:
237
- hidden_states[:,:640] = hidden_states[:,:640] * self.b1
 
 
 
 
 
 
 
238
  res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
239
  if hidden_states.shape[1] == 640:
240
- hidden_states[:,:320] = hidden_states[:,:320] * self.b2
 
 
 
 
 
 
241
  res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
242
  # ---------------------------------------------------------
243
 
 
234
  # --------------- FreeU code -----------------------
235
  # Only operate on the first two stages
236
  if hidden_states.shape[1] == 1280:
237
+ hidden_mean = hidden_states.mean(1).unsqueeze(1)
238
+ B = hidden_mean.shape[0]
239
+ hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
240
+ hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
241
+
242
+ hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
243
+
244
+ hidden_states[:,:640] = hidden_states[:,:640] * ((self.b1 - 1 ) * hidden_mean + 1)
245
  res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
246
  if hidden_states.shape[1] == 640:
247
+ hidden_mean = hidden_states.mean(1).unsqueeze(1)
248
+ B = hidden_mean.shape[0]
249
+ hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
250
+ hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
251
+ hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
252
+
253
+ hidden_states[:,:320] = hidden_states[:,:320] * ((self.b2 - 1 ) * hidden_mean + 1)
254
  res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
255
  # ---------------------------------------------------------
256