Spaces:
Runtime error
Runtime error
ChenyangSi
commited on
Commit
•
ed16a12
1
Parent(s):
861e5b3
Update free_lunch_utils.py
Browse files- free_lunch_utils.py +15 -2
free_lunch_utils.py
CHANGED
@@ -234,10 +234,23 @@ def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
|
|
234 |
# --------------- FreeU code -----------------------
|
235 |
# Only operate on the first two stages
|
236 |
if hidden_states.shape[1] == 1280:
|
237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
239 |
if hidden_states.shape[1] == 640:
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
242 |
# ---------------------------------------------------------
|
243 |
|
|
|
234 |
# --------------- FreeU code -----------------------
|
235 |
# Only operate on the first two stages
|
236 |
if hidden_states.shape[1] == 1280:
|
237 |
+
hidden_mean = hidden_states.mean(1).unsqueeze(1)
|
238 |
+
B = hidden_mean.shape[0]
|
239 |
+
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
240 |
+
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
241 |
+
|
242 |
+
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
|
243 |
+
|
244 |
+
hidden_states[:,:640] = hidden_states[:,:640] * ((self.b1 - 1 ) * hidden_mean + 1)
|
245 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
246 |
if hidden_states.shape[1] == 640:
|
247 |
+
hidden_mean = hidden_states.mean(1).unsqueeze(1)
|
248 |
+
B = hidden_mean.shape[0]
|
249 |
+
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
250 |
+
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
251 |
+
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
|
252 |
+
|
253 |
+
hidden_states[:,:320] = hidden_states[:,:320] * ((self.b2 - 1 ) * hidden_mean + 1)
|
254 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
255 |
# ---------------------------------------------------------
|
256 |
|