ChenyangSi commited on
Commit
bf85ddb
1 Parent(s): 01e731e

Update free_lunch_utils.py

Browse files
Files changed (1) hide show
  1. free_lunch_utils.py +24 -4
free_lunch_utils.py CHANGED
@@ -20,6 +20,26 @@ def isinstance_str(x: object, cls_name: str):
20
  return False
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  def register_upblock2d(model):
@@ -77,10 +97,10 @@ def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
77
  # Only operate on the first two stages
78
  if hidden_states.shape[1] == 1280:
79
  hidden_states[:,:640] = hidden_states[:,:640] * self.b1
80
- # # res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
81
  if hidden_states.shape[1] == 640:
82
  hidden_states[:,:320] = hidden_states[:,:320] * self.b2
83
- # res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
84
  # ---------------------------------------------------------
85
 
86
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -215,10 +235,10 @@ def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
215
  # Only operate on the first two stages
216
  if hidden_states.shape[1] == 1280:
217
  hidden_states[:,:640] = hidden_states[:,:640] * self.b1
218
- # res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
219
  if hidden_states.shape[1] == 640:
220
  hidden_states[:,:320] = hidden_states[:,:320] * self.b2
221
- # res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
222
  # ---------------------------------------------------------
223
 
224
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
20
  return False
21
 
22
 
23
+ def Fourier_filter(x, threshold, scale):
24
+ dtype = x.dtype
25
+ x = x.type(torch.float32)
26
+ # FFT
27
+ x_freq = fft.fftn(x, dim=(-2, -1))
28
+ x_freq = fft.fftshift(x_freq, dim=(-2, -1))
29
+
30
+ B, C, H, W = x_freq.shape
31
+ mask = torch.ones((B, C, H, W)).cuda()
32
+
33
+ crow, ccol = H // 2, W //2
34
+ mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
35
+ x_freq = x_freq * mask
36
+
37
+ # IFFT
38
+ x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
39
+ x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
40
+
41
+ x_filtered = x_filtered.type(dtype)
42
+ return x_filtered
43
 
44
 
45
  def register_upblock2d(model):
 
97
  # Only operate on the first two stages
98
  if hidden_states.shape[1] == 1280:
99
  hidden_states[:,:640] = hidden_states[:,:640] * self.b1
100
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
101
  if hidden_states.shape[1] == 640:
102
  hidden_states[:,:320] = hidden_states[:,:320] * self.b2
103
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
104
  # ---------------------------------------------------------
105
 
106
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
235
  # Only operate on the first two stages
236
  if hidden_states.shape[1] == 1280:
237
  hidden_states[:,:640] = hidden_states[:,:640] * self.b1
238
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
239
  if hidden_states.shape[1] == 640:
240
  hidden_states[:,:320] = hidden_states[:,:320] * self.b2
241
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
242
  # ---------------------------------------------------------
243
 
244
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)