BorisEm Claude committed
Commit 34e77ba · 1 Parent(s): 841d16c

Fix HAT model architecture to match checkpoint structure

- Replaced simplified implementation with complete HAT architecture
- Added proper RHAG (Residual Hybrid Attention Groups) structure
- Included AttenBlocks, HAB, OCAB components with residual_group attribute
- Added correct relative position calculation and attention mask handling
- Fixed parameter structure to match trained model checkpoint (see the loading sketch below)

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
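
To confirm that the rebuilt architecture really lines up with the checkpoint, it helps to diff the parameter names before calling load_state_dict. The sketch below is illustrative only and is not part of this commit: the checkpoint filename, the 'params_ema'/'params' wrapper keys, and the constructor arguments are assumptions; substitute the configuration the weights were actually trained with, and note that it presumes the HAT class from app.py is importable.

import torch

# Placeholder configuration -- use the settings the checkpoint was trained with.
model = HAT(img_size=64, in_chans=3, embed_dim=180, depths=(6, 6, 6, 6, 6, 6),
            num_heads=(6, 6, 6, 6, 6, 6), window_size=16, mlp_ratio=2,
            upscale=4, upsampler='pixelshuffle')

ckpt = torch.load('HAT_SRx4.pth', map_location='cpu')     # hypothetical path
state = ckpt.get('params_ema', ckpt.get('params', ckpt))  # unwrap nested dicts if present

missing = set(model.state_dict()) - set(state)       # parameters the checkpoint does not provide
unexpected = set(state) - set(model.state_dict())    # checkpoint entries the model has no slot for
print(f'{len(missing)} missing, {len(unexpected)} unexpected')

model.load_state_dict(state, strict=True)  # succeeds once the structures match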

Files changed (2)
  1. app.py +309 -249
  2. app_old.py +700 -0
app.py CHANGED
@@ -3,11 +3,33 @@ import torch
3
  import torch.nn as nn
4
  import numpy as np
5
  from PIL import Image
6
- import cv2
7
  import math
8
  from einops import rearrange
9
 
10
 
 
 
11
  def drop_path(x, drop_prob: float = 0., training: bool = False):
12
  if drop_prob == 0. or not training:
13
  return x
@@ -77,16 +99,16 @@ class Mlp(nn.Module):
77
 
78
 
79
  def window_partition(x, window_size):
80
- B, H, W, C = x.shape
81
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
82
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
83
  return windows
84
 
85
 
86
- def window_reverse(windows, window_size, H, W):
87
- B = int(windows.shape[0] / (H * W / window_size / window_size))
88
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
89
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
90
  return x
91
 
92
 
@@ -97,55 +119,43 @@ class WindowAttention(nn.Module):
97
  self.window_size = window_size
98
  self.num_heads = num_heads
99
  head_dim = dim // num_heads
100
- self.scale = qk_scale or head_dim ** -0.5
101
 
102
  self.relative_position_bias_table = nn.Parameter(
103
  torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
104
 
105
- coords_h = torch.arange(self.window_size[0])
106
- coords_w = torch.arange(self.window_size[1])
107
- coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
108
- coords_flatten = torch.flatten(coords, 1)
109
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
110
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
111
- relative_coords[:, :, 0] += self.window_size[0] - 1
112
- relative_coords[:, :, 1] += self.window_size[1] - 1
113
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
114
- relative_position_index = relative_coords.sum(-1)
115
- self.register_buffer("relative_position_index", relative_position_index)
116
-
117
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
118
  self.attn_drop = nn.Dropout(attn_drop)
119
  self.proj = nn.Linear(dim, dim)
120
  self.proj_drop = nn.Dropout(proj_drop)
121
 
122
- nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
123
  self.softmax = nn.Softmax(dim=-1)
124
 
125
- def forward(self, x, mask=None):
126
- B_, N, C = x.shape
127
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
128
  q, k, v = qkv[0], qkv[1], qkv[2]
129
 
130
  q = q * self.scale
131
  attn = (q @ k.transpose(-2, -1))
132
 
133
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
134
  self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
135
  relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
136
  attn = attn + relative_position_bias.unsqueeze(0)
137
 
138
  if mask is not None:
139
- nW = mask.shape[0]
140
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
141
- attn = attn.view(-1, self.num_heads, N, N)
142
  attn = self.softmax(attn)
143
  else:
144
  attn = self.softmax(attn)
145
 
146
  attn = self.attn_drop(attn)
147
 
148
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
149
  x = self.proj(x)
150
  x = self.proj_drop(x)
151
  return x
@@ -153,8 +163,9 @@ class WindowAttention(nn.Module):
153
 
154
  class HAB(nn.Module):
155
  def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
156
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
157
- act_layer=nn.GELU, norm_layer=nn.LayerNorm, compress_ratio=3, squeeze_factor=30):
 
158
  super().__init__()
159
  self.dim = dim
160
  self.input_resolution = input_resolution
@@ -165,177 +176,225 @@ class HAB(nn.Module):
165
  if min(self.input_resolution) <= self.window_size:
166
  self.shift_size = 0
167
  self.window_size = min(self.input_resolution)
168
- assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
169
 
170
  self.norm1 = norm_layer(dim)
171
  self.attn = WindowAttention(
172
- dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
173
  qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
174
 
 
 
 
175
  self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
176
  self.norm2 = norm_layer(dim)
177
  mlp_hidden_dim = int(dim * mlp_ratio)
178
  self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
179
 
180
- self.conv_scale = nn.Parameter(torch.ones(1))
181
- self.conv_block = CAB(dim, compress_ratio, squeeze_factor)
182
-
183
- if self.shift_size > 0:
184
- H, W = self.input_resolution
185
- img_mask = torch.zeros((1, H, W, 1))
186
- h_slices = (slice(0, -self.window_size),
187
- slice(-self.window_size, -self.shift_size),
188
- slice(-self.shift_size, None))
189
- w_slices = (slice(0, -self.window_size),
190
- slice(-self.window_size, -self.shift_size),
191
- slice(-self.shift_size, None))
192
- cnt = 0
193
- for h in h_slices:
194
- for w in w_slices:
195
- img_mask[:, h, w, :] = cnt
196
- cnt += 1
197
-
198
- mask_windows = window_partition(img_mask, self.window_size)
199
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
200
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
201
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
202
- else:
203
- attn_mask = None
204
-
205
- self.register_buffer("attn_mask", attn_mask)
206
-
207
- def forward(self, x):
208
- H, W = self.input_resolution
209
- B, L, C = x.shape
210
- assert L == H * W, "input feature has wrong size"
211
 
212
  shortcut = x
213
  x = self.norm1(x)
214
- x = x.view(B, H, W, C)
215
 
 
 
 
 
 
216
  if self.shift_size > 0:
217
  shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
 
218
  else:
219
  shifted_x = x
 
220
 
 
221
  x_windows = window_partition(shifted_x, self.window_size)
222
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
223
 
224
- attn_windows = self.attn(x_windows, mask=self.attn_mask)
 
225
 
226
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
227
- shifted_x = window_reverse(attn_windows, self.window_size, H, W)
 
228
 
 
229
  if self.shift_size > 0:
230
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
231
  else:
232
- x = shifted_x
233
- x = x.view(B, H * W, C)
234
 
235
- x = shortcut + self.drop_path(x)
236
-
237
- y = x
238
- x = self.norm2(x)
239
- x = self.mlp(x)
240
- x = y + self.drop_path(x)
241
-
242
- conv_x = self.conv_block(x.view(B, H, W, C).permute(0, 3, 1, 2))
243
- conv_x = conv_x.permute(0, 2, 3, 1).view(B, H * W, C)
244
-
245
- x = x + self.conv_scale * conv_x
246
 
247
  return x
248
 
249
 
250
  class OCAB(nn.Module):
251
  def __init__(self, dim, input_resolution, window_size, overlap_ratio, num_heads,
252
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
253
- drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, compress_ratio=3,
254
- squeeze_factor=30):
255
  super().__init__()
256
  self.dim = dim
257
  self.input_resolution = input_resolution
258
  self.window_size = window_size
259
  self.num_heads = num_heads
260
- self.shift_size = round(overlap_ratio * window_size)
261
- self.mlp_ratio = mlp_ratio
 
262
 
263
- if min(self.input_resolution) <= self.window_size:
264
- self.shift_size = 0
265
- self.window_size = min(self.input_resolution)
 
266
 
267
- assert 0 <= self.shift_size, "shift_size >= 0 is required"
 
268
 
269
- self.norm1 = norm_layer(dim)
270
- self.attn = WindowAttention(
271
- dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
272
- qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
273
 
274
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
275
  self.norm2 = norm_layer(dim)
276
  mlp_hidden_dim = int(dim * mlp_ratio)
277
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
278
 
279
- self.conv_scale = nn.Parameter(torch.ones(1))
280
- self.conv_block = CAB(dim, compress_ratio, squeeze_factor)
281
-
282
- def forward(self, x):
283
- H, W = self.input_resolution
284
- B, L, C = x.shape
285
- assert L == H * W, "input feature has wrong size"
286
 
287
  shortcut = x
288
  x = self.norm1(x)
289
- x = x.view(B, H, W, C)
290
 
291
- pad_l = pad_t = 0
292
- pad_r = (self.window_size - W % self.window_size) % self.window_size
293
- pad_b = (self.window_size - H % self.window_size) % self.window_size
294
- x = torch.nn.functional.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
295
- _, Hp, Wp, _ = x.shape
296
 
297
- if self.shift_size > 0:
298
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
299
- else:
300
- shifted_x = x
301
 
302
- x_windows = window_partition(shifted_x, self.window_size)
303
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
 
 
304
 
305
- attn_windows = self.attn(x_windows, mask=None)
 
306
 
307
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
308
- shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)
309
 
310
- if self.shift_size > 0:
311
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
312
- else:
313
- x = shifted_x
 
 
314
 
315
- if pad_r > 0 or pad_b > 0:
316
- x = x[:, :H, :W, :].contiguous()
317
 
318
- x = x.view(B, H * W, C)
319
- x = shortcut + self.drop_path(x)
 
 
 
320
 
321
- y = x
322
- x = self.norm2(x)
323
- x = self.mlp(x)
324
- x = y + self.drop_path(x)
 
325
 
326
- conv_x = self.conv_block(x.view(B, H, W, C).permute(0, 3, 1, 2))
327
- conv_x = conv_x.permute(0, 2, 3, 1).view(B, H * W, C)
 
328
 
329
- x = x + self.conv_scale * conv_x
330
 
 
 
331
  return x
332
 
333
 
 
 
 
334
  class PatchEmbed(nn.Module):
335
  def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
336
  super().__init__()
337
- img_size = (img_size, img_size)
338
- patch_size = (patch_size, patch_size)
339
  patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
340
  self.img_size = img_size
341
  self.patch_size = patch_size
@@ -345,17 +404,13 @@ class PatchEmbed(nn.Module):
345
  self.in_chans = in_chans
346
  self.embed_dim = embed_dim
347
 
348
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
349
  if norm_layer is not None:
350
  self.norm = norm_layer(embed_dim)
351
  else:
352
  self.norm = None
353
 
354
  def forward(self, x):
355
- B, C, H, W = x.shape
356
- assert H == self.img_size[0] and W == self.img_size[1], \
357
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
358
- x = self.proj(x).flatten(2).transpose(1, 2)
359
  if self.norm is not None:
360
  x = self.norm(x)
361
  return x
@@ -364,8 +419,8 @@ class PatchEmbed(nn.Module):
364
  class PatchUnEmbed(nn.Module):
365
  def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
366
  super().__init__()
367
- img_size = (img_size, img_size)
368
- patch_size = (patch_size, patch_size)
369
  patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
370
  self.img_size = img_size
371
  self.patch_size = patch_size
@@ -376,73 +431,7 @@ class PatchUnEmbed(nn.Module):
376
  self.embed_dim = embed_dim
377
 
378
  def forward(self, x, x_size):
379
- H, W = x_size
380
- B, HW, C = x.shape
381
- x = x.transpose(1, 2).view(B, self.embed_dim, H, W)
382
- return x
383
-
384
-
385
- class RHAG(nn.Module):
386
- def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
387
- squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
388
- drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
389
- use_checkpoint=False):
390
- super().__init__()
391
- self.dim = dim
392
- self.input_resolution = input_resolution
393
- self.depth = depth
394
- self.use_checkpoint = use_checkpoint
395
-
396
- self.blocks_1 = nn.ModuleList([
397
- HAB(dim=dim, input_resolution=input_resolution,
398
- num_heads=num_heads, window_size=window_size,
399
- shift_size=0 if (i % 2 == 0) else window_size // 2,
400
- mlp_ratio=mlp_ratio,
401
- qkv_bias=qkv_bias, qk_scale=qk_scale,
402
- drop=drop, attn_drop=attn_drop,
403
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
404
- norm_layer=norm_layer, compress_ratio=compress_ratio,
405
- squeeze_factor=squeeze_factor)
406
- for i in range(depth // 2)])
407
-
408
- self.blocks_2 = nn.ModuleList([
409
- OCAB(dim=dim, input_resolution=input_resolution,
410
- window_size=window_size, overlap_ratio=overlap_ratio,
411
- num_heads=num_heads, mlp_ratio=mlp_ratio,
412
- qkv_bias=qkv_bias, qk_scale=qk_scale,
413
- drop=drop, attn_drop=attn_drop,
414
- drop_path=drop_path[i + depth//2] if isinstance(drop_path, list) else drop_path,
415
- norm_layer=norm_layer, compress_ratio=compress_ratio,
416
- squeeze_factor=squeeze_factor)
417
- for i in range(depth // 2)])
418
-
419
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
420
- self.conv_scale = conv_scale
421
-
422
- if downsample is not None:
423
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
424
- else:
425
- self.downsample = None
426
-
427
- def forward(self, x, x_size):
428
- H, W = x_size
429
- res = x
430
- for blk in self.blocks_1:
431
- if self.use_checkpoint:
432
- x = torch.utils.checkpoint.checkpoint(blk, x)
433
- else:
434
- x = blk(x)
435
- for blk in self.blocks_2:
436
- if self.use_checkpoint:
437
- x = torch.utils.checkpoint.checkpoint(blk, x)
438
- else:
439
- x = blk(x)
440
-
441
- conv_x = self.conv(x.transpose(1, 2).view(-1, self.dim, H, W)).view(-1, self.dim, H * W).transpose(1, 2)
442
- x = res + x + conv_x * self.conv_scale
443
-
444
- if self.downsample is not None:
445
- x = self.downsample(x)
446
  return x
447
 
448
 
@@ -462,8 +451,8 @@ class Upsample(nn.Sequential):
462
 
463
 
464
  class HAT(nn.Module):
465
- def __init__(self, img_size=64, patch_size=1, in_chans=3, embed_dim=180, depths=[6, 6, 6, 6, 6, 6],
466
- num_heads=[6, 6, 6, 6, 6, 6], window_size=16, compress_ratio=3, squeeze_factor=30,
467
  conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=4., qkv_bias=True, qk_scale=None,
468
  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
469
  ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1.,
@@ -473,6 +462,7 @@ class HAT(nn.Module):
473
  self.window_size = window_size
474
  self.shift_size = window_size // 2
475
  self.overlap_ratio = overlap_ratio
 
476
  num_in_ch = in_chans
477
  num_out_ch = in_chans
478
  num_feat = 64
@@ -485,8 +475,16 @@ class HAT(nn.Module):
485
  self.upscale = upscale
486
  self.upsampler = upsampler
487
 
488
  self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
489
 
 
490
  self.num_layers = len(depths)
491
  self.embed_dim = embed_dim
492
  self.ape = ape
@@ -494,6 +492,7 @@ class HAT(nn.Module):
494
  self.num_features = embed_dim
495
  self.mlp_ratio = mlp_ratio
496
 
 
497
  self.patch_embed = PatchEmbed(
498
  img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
499
  norm_layer=norm_layer if self.patch_norm else None)
@@ -501,52 +500,59 @@ class HAT(nn.Module):
501
  patches_resolution = self.patch_embed.patches_resolution
502
  self.patches_resolution = patches_resolution
503
 
 
504
  self.patch_unembed = PatchUnEmbed(
505
  img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
506
  norm_layer=norm_layer if self.patch_norm else None)
507
 
 
508
  if self.ape:
509
  self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
510
- nn.init.trunc_normal_(self.absolute_pos_embed, std=.02)
511
 
512
  self.pos_drop = nn.Dropout(p=drop_rate)
513
 
 
514
  dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
515
 
 
516
  self.layers = nn.ModuleList()
517
  for i_layer in range(self.num_layers):
518
- layer = RHAG(dim=embed_dim,
519
- input_resolution=(patches_resolution[0],
520
- patches_resolution[1]),
521
- depth=depths[i_layer],
522
- num_heads=num_heads[i_layer],
523
- window_size=window_size,
524
- compress_ratio=compress_ratio,
525
- squeeze_factor=squeeze_factor,
526
- conv_scale=conv_scale,
527
- overlap_ratio=overlap_ratio,
528
- mlp_ratio=self.mlp_ratio,
529
- qkv_bias=qkv_bias, qk_scale=qk_scale,
530
- drop=drop_rate, attn_drop=attn_drop_rate,
531
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
532
- norm_layer=norm_layer,
533
- downsample=None,
534
- use_checkpoint=use_checkpoint)
 
535
  self.layers.append(layer)
536
  self.norm = norm_layer(self.num_features)
537
 
 
538
  if resi_connection == '1conv':
539
  self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
540
- elif resi_connection == '3conv':
541
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
542
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
543
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
544
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
545
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
546
-
547
- if upsampler == 'pixelshuffle':
548
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
549
- nn.LeakyReLU(inplace=True))
550
  self.upsample = Upsample(upscale, num_feat)
551
  self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
552
 
@@ -554,13 +560,65 @@ class HAT(nn.Module):
554
 
555
  def _init_weights(self, m):
556
  if isinstance(m, nn.Linear):
557
- nn.init.trunc_normal_(m.weight, std=.02)
558
  if isinstance(m, nn.Linear) and m.bias is not None:
559
  nn.init.constant_(m.bias, 0)
560
  elif isinstance(m, nn.LayerNorm):
561
  nn.init.constant_(m.bias, 0)
562
  nn.init.constant_(m.weight, 1.0)
563
 
 
 
 
 
564
  @torch.jit.ignore
565
  def no_weight_decay(self):
566
  return {'absolute_pos_embed'}
@@ -571,31 +629,33 @@ class HAT(nn.Module):
571
 
572
  def forward_features(self, x):
573
  x_size = (x.shape[2], x.shape[3])
 
574
  x = self.patch_embed(x)
575
  if self.ape:
576
  x = x + self.absolute_pos_embed
577
  x = self.pos_drop(x)
578
 
579
  for layer in self.layers:
580
- x = layer(x, x_size)
581
 
582
  x = self.norm(x)
583
  x = self.patch_unembed(x, x_size)
584
-
585
  return x
586
 
587
  def forward(self, x):
588
  self.mean = self.mean.type_as(x)
589
  x = (x - self.mean) * self.img_range
590
 
591
- x_first = self.conv_first(x)
592
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
593
  if self.upsampler == 'pixelshuffle':
594
- x = self.conv_before_upsample(res)
 
 
595
  x = self.conv_last(self.upsample(x))
596
 
597
  x = x / self.img_range + self.mean
598
-
599
  return x
600
 
601
 
 
3
  import torch.nn as nn
4
  import numpy as np
5
  from PIL import Image
 
6
  import math
7
  from einops import rearrange
8
 
9
 
10
+ def to_2tuple(x):
11
+ """Convert input to tuple of length 2."""
12
+ if isinstance(x, (tuple, list)):
13
+ return tuple(x)
14
+ return (x, x)
15
+
16
+
17
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
18
+ """Truncated normal initialization."""
19
+ def norm_cdf(x):
20
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
21
+
22
+ with torch.no_grad():
23
+ l = norm_cdf((a - mean) / std)
24
+ u = norm_cdf((b - mean) / std)
25
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
26
+ tensor.erfinv_()
27
+ tensor.mul_(std * math.sqrt(2.))
28
+ tensor.add_(mean)
29
+ tensor.clamp_(min=a, max=b)
30
+ return tensor
31
+
32
+
33
  def drop_path(x, drop_prob: float = 0., training: bool = False):
34
  if drop_prob == 0. or not training:
35
  return x
 
99
 
100
 
101
  def window_partition(x, window_size):
102
+ b, h, w, c = x.shape
103
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
104
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
105
  return windows
106
 
107
 
108
+ def window_reverse(windows, window_size, h, w):
109
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
110
+ x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
111
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
112
  return x
113
 
114
 
 
119
  self.window_size = window_size
120
  self.num_heads = num_heads
121
  head_dim = dim // num_heads
122
+ self.scale = qk_scale or head_dim**-0.5
123
 
124
  self.relative_position_bias_table = nn.Parameter(
125
  torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
126
 
 
 
127
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
128
  self.attn_drop = nn.Dropout(attn_drop)
129
  self.proj = nn.Linear(dim, dim)
130
  self.proj_drop = nn.Dropout(proj_drop)
131
 
132
+ trunc_normal_(self.relative_position_bias_table, std=.02)
133
  self.softmax = nn.Softmax(dim=-1)
134
 
135
+ def forward(self, x, rpi, mask=None):
136
+ b_, n, c = x.shape
137
+ qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
138
  q, k, v = qkv[0], qkv[1], qkv[2]
139
 
140
  q = q * self.scale
141
  attn = (q @ k.transpose(-2, -1))
142
 
143
+ relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
144
  self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
145
  relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
146
  attn = attn + relative_position_bias.unsqueeze(0)
147
 
148
  if mask is not None:
149
+ nw = mask.shape[0]
150
+ attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
151
+ attn = attn.view(-1, self.num_heads, n, n)
152
  attn = self.softmax(attn)
153
  else:
154
  attn = self.softmax(attn)
155
 
156
  attn = self.attn_drop(attn)
157
 
158
+ x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
159
  x = self.proj(x)
160
  x = self.proj_drop(x)
161
  return x
 
163
 
164
  class HAB(nn.Module):
165
  def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
166
+ compress_ratio=3, squeeze_factor=30, conv_scale=0.01, mlp_ratio=4.,
167
+ qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
168
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
169
  super().__init__()
170
  self.dim = dim
171
  self.input_resolution = input_resolution
 
176
  if min(self.input_resolution) <= self.window_size:
177
  self.shift_size = 0
178
  self.window_size = min(self.input_resolution)
179
+ assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
180
 
181
  self.norm1 = norm_layer(dim)
182
  self.attn = WindowAttention(
183
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
184
  qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
185
 
186
+ self.conv_scale = conv_scale
187
+ self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
188
+
189
  self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
190
  self.norm2 = norm_layer(dim)
191
  mlp_hidden_dim = int(dim * mlp_ratio)
192
  self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
193
 
194
+ def forward(self, x, x_size, rpi_sa, attn_mask):
195
+ h, w = x_size
196
+ b, _, c = x.shape
 
 
 
 
197
 
198
  shortcut = x
199
  x = self.norm1(x)
200
+ x = x.view(b, h, w, c)
201
 
202
+ # Conv_X
203
+ conv_x = self.conv_block(x.permute(0, 3, 1, 2))
204
+ conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
205
+
206
+ # cyclic shift
207
  if self.shift_size > 0:
208
  shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
209
+ attn_mask = attn_mask
210
  else:
211
  shifted_x = x
212
+ attn_mask = None
213
 
214
+ # partition windows
215
  x_windows = window_partition(shifted_x, self.window_size)
216
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, c)
217
 
218
+ # W-MSA/SW-MSA
219
+ attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
220
 
221
+ # merge windows
222
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
223
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w)
224
 
225
+ # reverse cyclic shift
226
  if self.shift_size > 0:
227
+ attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
228
  else:
229
+ attn_x = shifted_x
230
+ attn_x = attn_x.view(b, h * w, c)
231
 
232
+ # FFN
233
+ x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
234
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
 
 
235
 
236
  return x
237
 
238
 
239
  class OCAB(nn.Module):
240
  def __init__(self, dim, input_resolution, window_size, overlap_ratio, num_heads,
241
+ qkv_bias=True, qk_scale=None, mlp_ratio=2, norm_layer=nn.LayerNorm):
 
 
242
  super().__init__()
243
  self.dim = dim
244
  self.input_resolution = input_resolution
245
  self.window_size = window_size
246
  self.num_heads = num_heads
247
+ head_dim = dim // num_heads
248
+ self.scale = qk_scale or head_dim**-0.5
249
+ self.overlap_win_size = int(window_size * overlap_ratio) + window_size
250
 
251
+ self.norm1 = norm_layer(dim)
252
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
253
+ self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size),
254
+ stride=window_size, padding=(self.overlap_win_size-window_size)//2)
255
 
256
+ self.relative_position_bias_table = nn.Parameter(
257
+ torch.zeros((window_size + self.overlap_win_size - 1) * (window_size + self.overlap_win_size - 1), num_heads))
258
 
259
+ trunc_normal_(self.relative_position_bias_table, std=.02)
260
+ self.softmax = nn.Softmax(dim=-1)
261
+
262
+ self.proj = nn.Linear(dim,dim)
263
 
 
264
  self.norm2 = norm_layer(dim)
265
  mlp_hidden_dim = int(dim * mlp_ratio)
266
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU)
267
 
268
+ def forward(self, x, x_size, rpi):
269
+ h, w = x_size
270
+ b, _, c = x.shape
 
 
271
 
272
  shortcut = x
273
  x = self.norm1(x)
274
+ x = x.view(b, h, w, c)
275
 
276
+ qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2)
277
+ q = qkv[0].permute(0, 2, 3, 1)
278
+ kv = torch.cat((qkv[1], qkv[2]), dim=1)
 
 
279
 
280
+ # partition windows
281
+ q_windows = window_partition(q, self.window_size)
282
+ q_windows = q_windows.view(-1, self.window_size * self.window_size, c)
 
283
 
284
+ kv_windows = self.unfold(kv)
285
+ kv_windows = rearrange(kv_windows, 'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch',
286
+ nc=2, ch=c, owh=self.overlap_win_size, oww=self.overlap_win_size).contiguous()
287
+ k_windows, v_windows = kv_windows[0], kv_windows[1]
288
 
289
+ b_, nq, _ = q_windows.shape
290
+ _, n, _ = k_windows.shape
291
+ d = self.dim // self.num_heads
292
+ q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3)
293
+ k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3)
294
+ v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3)
295
 
296
+ q = q * self.scale
297
+ attn = (q @ k.transpose(-2, -1))
298
 
299
+ relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
300
+ self.window_size * self.window_size, self.overlap_win_size * self.overlap_win_size, -1)
301
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
302
+ attn = attn + relative_position_bias.unsqueeze(0)
303
+
304
+ attn = self.softmax(attn)
305
+ attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim)
306
+
307
+ # merge windows
308
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim)
309
+ x = window_reverse(attn_windows, self.window_size, h, w)
310
+ x = x.view(b, h * w, self.dim)
311
+
312
+ x = self.proj(x) + shortcut
313
+ x = x + self.mlp(self.norm2(x))
314
+ return x
315
 
 
 
316
 
317
+ class AttenBlocks(nn.Module):
318
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
319
+ squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
320
+ drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
321
+ use_checkpoint=False):
322
+ super().__init__()
323
+ self.dim = dim
324
+ self.input_resolution = input_resolution
325
+ self.depth = depth
326
+ self.use_checkpoint = use_checkpoint
327
+
328
+ # build blocks
329
+ self.blocks = nn.ModuleList([
330
+ HAB(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
331
+ shift_size=0 if (i % 2 == 0) else window_size // 2, compress_ratio=compress_ratio,
332
+ squeeze_factor=squeeze_factor, conv_scale=conv_scale, mlp_ratio=mlp_ratio,
333
+ qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
334
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
335
+ norm_layer=norm_layer) for i in range(depth)
336
+ ])
337
+
338
+ # OCAB
339
+ self.overlap_attn = OCAB(dim=dim, input_resolution=input_resolution, window_size=window_size,
340
+ overlap_ratio=overlap_ratio, num_heads=num_heads, qkv_bias=qkv_bias,
341
+ qk_scale=qk_scale, mlp_ratio=mlp_ratio, norm_layer=norm_layer)
342
 
343
+ # patch merging layer
344
+ if downsample is not None:
345
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
346
+ else:
347
+ self.downsample = None
348
 
349
+ def forward(self, x, x_size, params):
350
+ for blk in self.blocks:
351
+ x = blk(x, x_size, params['rpi_sa'], params['attn_mask'])
352
 
353
+ x = self.overlap_attn(x, x_size, params['rpi_oca'])
354
 
355
+ if self.downsample is not None:
356
+ x = self.downsample(x)
357
  return x
358
 
359
 
360
+ class RHAG(nn.Module):
361
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
362
+ squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
363
+ drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
364
+ use_checkpoint=False, img_size=224, patch_size=4, resi_connection='1conv'):
365
+ super(RHAG, self).__init__()
366
+
367
+ self.dim = dim
368
+ self.input_resolution = input_resolution
369
+
370
+ self.residual_group = AttenBlocks(
371
+ dim=dim, input_resolution=input_resolution, depth=depth, num_heads=num_heads,
372
+ window_size=window_size, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor,
373
+ conv_scale=conv_scale, overlap_ratio=overlap_ratio, mlp_ratio=mlp_ratio,
374
+ qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
375
+ drop_path=drop_path, norm_layer=norm_layer, downsample=downsample,
376
+ use_checkpoint=use_checkpoint)
377
+
378
+ if resi_connection == '1conv':
379
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
380
+ elif resi_connection == 'identity':
381
+ self.conv = nn.Identity()
382
+
383
+ self.patch_embed = PatchEmbed(
384
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
385
+
386
+ self.patch_unembed = PatchUnEmbed(
387
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
388
+
389
+ def forward(self, x, x_size, params):
390
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x
391
+
392
+
393
  class PatchEmbed(nn.Module):
394
  def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
395
  super().__init__()
396
+ img_size = to_2tuple(img_size)
397
+ patch_size = to_2tuple(patch_size)
398
  patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
399
  self.img_size = img_size
400
  self.patch_size = patch_size
 
404
  self.in_chans = in_chans
405
  self.embed_dim = embed_dim
406
 
 
407
  if norm_layer is not None:
408
  self.norm = norm_layer(embed_dim)
409
  else:
410
  self.norm = None
411
 
412
  def forward(self, x):
413
+ x = x.flatten(2).transpose(1, 2)
 
 
 
414
  if self.norm is not None:
415
  x = self.norm(x)
416
  return x
 
419
  class PatchUnEmbed(nn.Module):
420
  def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
421
  super().__init__()
422
+ img_size = to_2tuple(img_size)
423
+ patch_size = to_2tuple(patch_size)
424
  patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
425
  self.img_size = img_size
426
  self.patch_size = patch_size
 
431
  self.embed_dim = embed_dim
432
 
433
  def forward(self, x, x_size):
434
+ x = x.transpose(1, 2).contiguous().view(x.shape[0], self.embed_dim, x_size[0], x_size[1])
 
 
 
 
 
 
 
435
  return x
436
 
437
 
 
451
 
452
 
453
  class HAT(nn.Module):
454
+ def __init__(self, img_size=64, patch_size=1, in_chans=3, embed_dim=96, depths=(6, 6, 6, 6),
455
+ num_heads=(6, 6, 6, 6), window_size=7, compress_ratio=3, squeeze_factor=30,
456
  conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=4., qkv_bias=True, qk_scale=None,
457
  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
458
  ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1.,
 
462
  self.window_size = window_size
463
  self.shift_size = window_size // 2
464
  self.overlap_ratio = overlap_ratio
465
+
466
  num_in_ch = in_chans
467
  num_out_ch = in_chans
468
  num_feat = 64
 
475
  self.upscale = upscale
476
  self.upsampler = upsampler
477
 
478
+ # relative position index
479
+ relative_position_index_SA = self.calculate_rpi_sa()
480
+ relative_position_index_OCA = self.calculate_rpi_oca()
481
+ self.register_buffer('relative_position_index_SA', relative_position_index_SA)
482
+ self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)
483
+
484
+ # shallow feature extraction
485
  self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
486
 
487
+ # deep feature extraction
488
  self.num_layers = len(depths)
489
  self.embed_dim = embed_dim
490
  self.ape = ape
 
492
  self.num_features = embed_dim
493
  self.mlp_ratio = mlp_ratio
494
 
495
+ # split image into non-overlapping patches
496
  self.patch_embed = PatchEmbed(
497
  img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
498
  norm_layer=norm_layer if self.patch_norm else None)
 
500
  patches_resolution = self.patch_embed.patches_resolution
501
  self.patches_resolution = patches_resolution
502
 
503
+ # merge non-overlapping patches into image
504
  self.patch_unembed = PatchUnEmbed(
505
  img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
506
  norm_layer=norm_layer if self.patch_norm else None)
507
 
508
+ # absolute position embedding
509
  if self.ape:
510
  self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
511
+ trunc_normal_(self.absolute_pos_embed, std=.02)
512
 
513
  self.pos_drop = nn.Dropout(p=drop_rate)
514
 
515
+ # stochastic depth
516
  dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
517
 
518
+ # build Residual Hybrid Attention Groups (RHAG)
519
  self.layers = nn.ModuleList()
520
  for i_layer in range(self.num_layers):
521
+ layer = RHAG(
522
+ dim=embed_dim,
523
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
524
+ depth=depths[i_layer],
525
+ num_heads=num_heads[i_layer],
526
+ window_size=window_size,
527
+ compress_ratio=compress_ratio,
528
+ squeeze_factor=squeeze_factor,
529
+ conv_scale=conv_scale,
530
+ overlap_ratio=overlap_ratio,
531
+ mlp_ratio=self.mlp_ratio,
532
+ qkv_bias=qkv_bias,
533
+ qk_scale=qk_scale,
534
+ drop=drop_rate,
535
+ attn_drop=attn_drop_rate,
536
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
537
+ norm_layer=norm_layer,
538
+ downsample=None,
539
+ use_checkpoint=use_checkpoint,
540
+ img_size=img_size,
541
+ patch_size=patch_size,
542
+ resi_connection=resi_connection)
543
  self.layers.append(layer)
544
  self.norm = norm_layer(self.num_features)
545
 
546
+ # build the last conv layer in deep feature extraction
547
  if resi_connection == '1conv':
548
  self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
549
+ elif resi_connection == 'identity':
550
+ self.conv_after_body = nn.Identity()
551
+
552
+ # high quality image reconstruction
553
+ if self.upsampler == 'pixelshuffle':
554
+ self.conv_before_upsample = nn.Sequential(
555
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
 
 
 
556
  self.upsample = Upsample(upscale, num_feat)
557
  self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
558
 
 
560
 
561
  def _init_weights(self, m):
562
  if isinstance(m, nn.Linear):
563
+ trunc_normal_(m.weight, std=.02)
564
  if isinstance(m, nn.Linear) and m.bias is not None:
565
  nn.init.constant_(m.bias, 0)
566
  elif isinstance(m, nn.LayerNorm):
567
  nn.init.constant_(m.bias, 0)
568
  nn.init.constant_(m.weight, 1.0)
569
 
570
+ def calculate_rpi_sa(self):
571
+ coords_h = torch.arange(self.window_size)
572
+ coords_w = torch.arange(self.window_size)
573
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
574
+ coords_flatten = torch.flatten(coords, 1)
575
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
576
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
577
+ relative_coords[:, :, 0] += self.window_size - 1
578
+ relative_coords[:, :, 1] += self.window_size - 1
579
+ relative_coords[:, :, 0] *= 2 * self.window_size - 1
580
+ relative_position_index = relative_coords.sum(-1)
581
+ return relative_position_index
582
+
583
+ def calculate_rpi_oca(self):
584
+ window_size_ori = self.window_size
585
+ window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
586
+
587
+ coords_h = torch.arange(window_size_ori)
588
+ coords_w = torch.arange(window_size_ori)
589
+ coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w]))
590
+ coords_ori_flatten = torch.flatten(coords_ori, 1)
591
+
592
+ coords_h = torch.arange(window_size_ext)
593
+ coords_w = torch.arange(window_size_ext)
594
+ coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w]))
595
+ coords_ext_flatten = torch.flatten(coords_ext, 1)
596
+
597
+ relative_coords = coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None]
598
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
599
+ relative_coords[:, :, 0] += window_size_ori - window_size_ext + 1
600
+ relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
601
+ relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
602
+ relative_position_index = relative_coords.sum(-1)
603
+ return relative_position_index
604
+
605
+ def calculate_mask(self, x_size):
606
+ h, w = x_size
607
+ img_mask = torch.zeros((1, h, w, 1))
608
+ h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
609
+ w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
610
+ cnt = 0
611
+ for h in h_slices:
612
+ for w in w_slices:
613
+ img_mask[:, h, w, :] = cnt
614
+ cnt += 1
615
+
616
+ mask_windows = window_partition(img_mask, self.window_size)
617
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
618
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
619
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
620
+ return attn_mask
621
+
622
  @torch.jit.ignore
623
  def no_weight_decay(self):
624
  return {'absolute_pos_embed'}
 
629
 
630
  def forward_features(self, x):
631
  x_size = (x.shape[2], x.shape[3])
632
+
633
+ attn_mask = self.calculate_mask(x_size).to(x.device)
634
+ params = {'attn_mask': attn_mask, 'rpi_sa': self.relative_position_index_SA, 'rpi_oca': self.relative_position_index_OCA}
635
+
636
  x = self.patch_embed(x)
637
  if self.ape:
638
  x = x + self.absolute_pos_embed
639
  x = self.pos_drop(x)
640
 
641
  for layer in self.layers:
642
+ x = layer(x, x_size, params)
643
 
644
  x = self.norm(x)
645
  x = self.patch_unembed(x, x_size)
 
646
  return x
647
 
648
  def forward(self, x):
649
  self.mean = self.mean.type_as(x)
650
  x = (x - self.mean) * self.img_range
651
 
 
 
652
  if self.upsampler == 'pixelshuffle':
653
+ x = self.conv_first(x)
654
+ x = self.conv_after_body(self.forward_features(x)) + x
655
+ x = self.conv_before_upsample(x)
656
  x = self.conv_last(self.upsample(x))
657
 
658
  x = x / self.img_range + self.mean
 
659
  return x
660
 
661
 
app_old.py ADDED
@@ -0,0 +1,700 @@
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+ import math
8
+ from einops import rearrange
9
+
10
+
11
+ def to_2tuple(x):
12
+ """Convert input to tuple of length 2."""
13
+ if isinstance(x, (tuple, list)):
14
+ return tuple(x)
15
+ return (x, x)
16
+
17
+
18
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
19
+ """Truncated normal initialization."""
20
+ def norm_cdf(x):
21
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
22
+
23
+ with torch.no_grad():
24
+ l = norm_cdf((a - mean) / std)
25
+ u = norm_cdf((b - mean) / std)
26
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
27
+ tensor.erfinv_()
28
+ tensor.mul_(std * math.sqrt(2.))
29
+ tensor.add_(mean)
30
+ tensor.clamp_(min=a, max=b)
31
+ return tensor
32
+
33
+
34
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
35
+ if drop_prob == 0. or not training:
36
+ return x
37
+ keep_prob = 1 - drop_prob
38
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
39
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
40
+ random_tensor.floor_()
41
+ output = x.div(keep_prob) * random_tensor
42
+ return output
43
+
44
+
45
+ class DropPath(nn.Module):
46
+ def __init__(self, drop_prob=None):
47
+ super(DropPath, self).__init__()
48
+ self.drop_prob = drop_prob
49
+
50
+ def forward(self, x):
51
+ return drop_path(x, self.drop_prob, self.training)
52
+
53
+
54
+ class ChannelAttention(nn.Module):
55
+ def __init__(self, num_feat, squeeze_factor=16):
56
+ super(ChannelAttention, self).__init__()
57
+ self.attention = nn.Sequential(
58
+ nn.AdaptiveAvgPool2d(1),
59
+ nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
60
+ nn.ReLU(inplace=True),
61
+ nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
62
+ nn.Sigmoid())
63
+
64
+ def forward(self, x):
65
+ y = self.attention(x)
66
+ return x * y
67
+
68
+
69
+ class CAB(nn.Module):
70
+ def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
71
+ super(CAB, self).__init__()
72
+ self.cab = nn.Sequential(
73
+ nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
74
+ nn.GELU(),
75
+ nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
76
+ ChannelAttention(num_feat, squeeze_factor)
77
+ )
78
+
79
+ def forward(self, x):
80
+ return self.cab(x)
81
+
82
+
83
+ class Mlp(nn.Module):
84
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
85
+ super().__init__()
86
+ out_features = out_features or in_features
87
+ hidden_features = hidden_features or in_features
88
+ self.fc1 = nn.Linear(in_features, hidden_features)
89
+ self.act = act_layer()
90
+ self.fc2 = nn.Linear(hidden_features, out_features)
91
+ self.drop = nn.Dropout(drop)
92
+
93
+ def forward(self, x):
94
+ x = self.fc1(x)
95
+ x = self.act(x)
96
+ x = self.drop(x)
97
+ x = self.fc2(x)
98
+ x = self.drop(x)
99
+ return x
100
+
101
+
102
+ def window_partition(x, window_size):
103
+ B, H, W, C = x.shape
104
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
105
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
106
+ return windows
107
+
108
+
109
+ def window_reverse(windows, window_size, H, W):
110
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
111
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
112
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
113
+ return x
114
+
115
+
116
+ class WindowAttention(nn.Module):
117
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
118
+ super().__init__()
119
+ self.dim = dim
120
+ self.window_size = window_size
121
+ self.num_heads = num_heads
122
+ head_dim = dim // num_heads
123
+ self.scale = qk_scale or head_dim ** -0.5
124
+
125
+ self.relative_position_bias_table = nn.Parameter(
126
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
127
+
128
+ coords_h = torch.arange(self.window_size[0])
129
+ coords_w = torch.arange(self.window_size[1])
130
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
131
+ coords_flatten = torch.flatten(coords, 1)
132
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
133
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
134
+ relative_coords[:, :, 0] += self.window_size[0] - 1
135
+ relative_coords[:, :, 1] += self.window_size[1] - 1
136
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
137
+ relative_position_index = relative_coords.sum(-1)
138
+ self.register_buffer("relative_position_index", relative_position_index)
139
+
140
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
141
+ self.attn_drop = nn.Dropout(attn_drop)
142
+ self.proj = nn.Linear(dim, dim)
143
+ self.proj_drop = nn.Dropout(proj_drop)
144
+
145
+ nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
146
+ self.softmax = nn.Softmax(dim=-1)
147
+
148
+ def forward(self, x, mask=None):
149
+ B_, N, C = x.shape
150
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
151
+ q, k, v = qkv[0], qkv[1], qkv[2]
152
+
153
+ q = q * self.scale
154
+ attn = (q @ k.transpose(-2, -1))
155
+
156
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
157
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
158
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
159
+ attn = attn + relative_position_bias.unsqueeze(0)
160
+
161
+ if mask is not None:
162
+ nW = mask.shape[0]
163
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
164
+ attn = attn.view(-1, self.num_heads, N, N)
165
+ attn = self.softmax(attn)
166
+ else:
167
+ attn = self.softmax(attn)
168
+
169
+ attn = self.attn_drop(attn)
170
+
171
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
172
+ x = self.proj(x)
173
+ x = self.proj_drop(x)
174
+ return x
175
+
176
+
177
+ class HAB(nn.Module):
178
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
179
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
180
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, compress_ratio=3, squeeze_factor=30):
181
+ super().__init__()
182
+ self.dim = dim
183
+ self.input_resolution = input_resolution
184
+ self.num_heads = num_heads
185
+ self.window_size = window_size
186
+ self.shift_size = shift_size
187
+ self.mlp_ratio = mlp_ratio
188
+ if min(self.input_resolution) <= self.window_size:
189
+ self.shift_size = 0
190
+ self.window_size = min(self.input_resolution)
191
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
192
+
193
+ self.norm1 = norm_layer(dim)
194
+ self.attn = WindowAttention(
195
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
196
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
197
+
198
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
199
+ self.norm2 = norm_layer(dim)
200
+ mlp_hidden_dim = int(dim * mlp_ratio)
201
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
202
+
203
+ self.conv_scale = nn.Parameter(torch.ones(1))
204
+ self.conv_block = CAB(dim, compress_ratio, squeeze_factor)
205
+
206
+ if self.shift_size > 0:
207
+ H, W = self.input_resolution
208
+ img_mask = torch.zeros((1, H, W, 1))
209
+ h_slices = (slice(0, -self.window_size),
210
+ slice(-self.window_size, -self.shift_size),
211
+ slice(-self.shift_size, None))
212
+ w_slices = (slice(0, -self.window_size),
213
+ slice(-self.window_size, -self.shift_size),
214
+ slice(-self.shift_size, None))
215
+ cnt = 0
216
+ for h in h_slices:
217
+ for w in w_slices:
218
+ img_mask[:, h, w, :] = cnt
219
+ cnt += 1
220
+
221
+ mask_windows = window_partition(img_mask, self.window_size)
222
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
223
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
224
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
225
+ else:
226
+ attn_mask = None
227
+
228
+ self.register_buffer("attn_mask", attn_mask)
229
+
230
+ def forward(self, x):
231
+ H, W = self.input_resolution
232
+ B, L, C = x.shape
233
+ assert L == H * W, "input feature has wrong size"
234
+
235
+ shortcut = x
236
+ x = self.norm1(x)
237
+ x = x.view(B, H, W, C)
238
+
239
+ if self.shift_size > 0:
240
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
241
+ else:
242
+ shifted_x = x
243
+
244
+ x_windows = window_partition(shifted_x, self.window_size)
245
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
246
+
247
+ attn_windows = self.attn(x_windows, mask=self.attn_mask)
248
+
249
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
250
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W)
251
+
252
+ if self.shift_size > 0:
253
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
254
+ else:
255
+ x = shifted_x
256
+ x = x.view(B, H * W, C)
257
+
258
+ x = shortcut + self.drop_path(x)
259
+
260
+ y = x
261
+ x = self.norm2(x)
262
+ x = self.mlp(x)
263
+ x = y + self.drop_path(x)
264
+
265
+ conv_x = self.conv_block(x.view(B, H, W, C).permute(0, 3, 1, 2))
266
+ conv_x = conv_x.permute(0, 2, 3, 1).view(B, H * W, C)
267
+
268
+ x = x + self.conv_scale * conv_x
269
+
270
+ return x
271
+
272
+
273
+ class OCAB(nn.Module):
274
+ def __init__(self, dim, input_resolution, window_size, overlap_ratio, num_heads,
275
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
276
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, compress_ratio=3,
277
+ squeeze_factor=30):
278
+ super().__init__()
279
+ self.dim = dim
280
+ self.input_resolution = input_resolution
281
+ self.window_size = window_size
282
+ self.num_heads = num_heads
283
+ self.shift_size = round(overlap_ratio * window_size)
284
+ self.mlp_ratio = mlp_ratio
285
+
286
+ if min(self.input_resolution) <= self.window_size:
287
+ self.shift_size = 0
288
+ self.window_size = min(self.input_resolution)
289
+
290
+ assert 0 <= self.shift_size, "shift_size >= 0 is required"
291
+
292
+ self.norm1 = norm_layer(dim)
293
+ self.attn = WindowAttention(
294
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
295
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
296
+
297
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
298
+ self.norm2 = norm_layer(dim)
299
+ mlp_hidden_dim = int(dim * mlp_ratio)
300
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
301
+
302
+ self.conv_scale = nn.Parameter(torch.ones(1))
303
+ self.conv_block = CAB(dim, compress_ratio, squeeze_factor)
304
+
305
+ def forward(self, x):
306
+ H, W = self.input_resolution
307
+ B, L, C = x.shape
308
+ assert L == H * W, "input feature has wrong size"
309
+
310
+ shortcut = x
311
+ x = self.norm1(x)
312
+ x = x.view(B, H, W, C)
313
+
314
+ pad_l = pad_t = 0
315
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
316
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
317
+ x = torch.nn.functional.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
318
+ _, Hp, Wp, _ = x.shape
319
+
320
+ if self.shift_size > 0:
321
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
322
+ else:
323
+ shifted_x = x
324
+
325
+ x_windows = window_partition(shifted_x, self.window_size)
326
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
327
+
328
+ attn_windows = self.attn(x_windows, mask=None)
329
+
330
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
331
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)
332
+
333
+ if self.shift_size > 0:
334
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
335
+ else:
336
+ x = shifted_x
337
+
338
+ if pad_r > 0 or pad_b > 0:
339
+ x = x[:, :H, :W, :].contiguous()
340
+
341
+ x = x.view(B, H * W, C)
342
+ x = shortcut + self.drop_path(x)
343
+
344
+ y = x
345
+ x = self.norm2(x)
346
+ x = self.mlp(x)
347
+ x = y + self.drop_path(x)
348
+
349
+ conv_x = self.conv_block(x.view(B, H, W, C).permute(0, 3, 1, 2))
350
+ conv_x = conv_x.permute(0, 2, 3, 1).view(B, H * W, C)
351
+
352
+ x = x + self.conv_scale * conv_x
353
+
354
+ return x
355
+
356
+
357
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x

class PatchUnEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        H, W = x_size
        B, HW, C = x.shape
        x = x.transpose(1, 2).view(B, self.embed_dim, H, W)
        return x

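# --- Illustrative sketch (not part of the original app) ---
# With patch_size=1, PatchEmbed applies a 1x1 conv and flattens a feature map
# (B, embed_dim, H, W) into tokens (B, H*W, embed_dim); PatchUnEmbed restores the
# spatial layout. A hypothetical shape check, assuming in_chans equals embed_dim:
def _patch_roundtrip_shapes(embed_dim=180, size=64):
    embed = PatchEmbed(img_size=size, patch_size=1, in_chans=embed_dim, embed_dim=embed_dim)
    unembed = PatchUnEmbed(img_size=size, patch_size=1, in_chans=embed_dim, embed_dim=embed_dim)
    feat = torch.randn(1, embed_dim, size, size)
    tokens = embed(feat)                      # (1, size * size, embed_dim)
    restored = unembed(tokens, (size, size))  # (1, embed_dim, size, size)
    return tokens.shape, restored.shape
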
class RHAG(nn.Module):
    def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
                 squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        self.blocks_1 = nn.ModuleList([
            HAB(dim=dim, input_resolution=input_resolution,
                num_heads=num_heads, window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer, compress_ratio=compress_ratio,
                squeeze_factor=squeeze_factor)
            for i in range(depth // 2)])

        self.blocks_2 = nn.ModuleList([
            OCAB(dim=dim, input_resolution=input_resolution,
                 window_size=window_size, overlap_ratio=overlap_ratio,
                 num_heads=num_heads, mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                 drop=drop, attn_drop=attn_drop,
                 drop_path=drop_path[i + depth // 2] if isinstance(drop_path, list) else drop_path,
                 norm_layer=norm_layer, compress_ratio=compress_ratio,
                 squeeze_factor=squeeze_factor)
            for i in range(depth // 2)])

        self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        self.conv_scale = conv_scale

        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        H, W = x_size
        res = x
        for blk in self.blocks_1:
            if self.use_checkpoint:
                x = torch.utils.checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        for blk in self.blocks_2:
            if self.use_checkpoint:
                x = torch.utils.checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        conv_x = self.conv(x.transpose(1, 2).view(-1, self.dim, H, W)).view(-1, self.dim, H * W).transpose(1, 2)
        x = res + x + conv_x * self.conv_scale

        if self.downsample is not None:
            x = self.downsample(x)
        return x

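# --- Illustrative sketch (not part of the original app) ---
# RHAG consumes a per-block list of stochastic-depth rates: index i goes to the i-th
# HAB in blocks_1 and index i + depth // 2 to the i-th OCAB in blocks_2. A small
# hypothetical helper showing how such a list is split:
def _drop_path_split_demo(depth=6, drop_path_rate=0.1):
    rates = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    hab_rates = rates[:depth // 2]
    ocab_rates = [rates[i + depth // 2] for i in range(depth // 2)]
    return hab_rates, ocab_rates
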
class Upsample(nn.Sequential):
    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)

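# --- Illustrative sketch (not part of the original app) ---
# For scale=4 the Upsample block stacks two (3x3 conv -> PixelShuffle(2)) stages, so a
# (B, num_feat, H, W) feature map comes out as (B, num_feat, 4H, 4W). Hypothetical check:
def _upsample_shape_demo(num_feat=64, h=32, w=32):
    up = Upsample(scale=4, num_feat=num_feat)
    out = up(torch.randn(1, num_feat, h, w))
    return out.shape  # torch.Size([1, 64, 128, 128]) for the defaults
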
class HAT(nn.Module):
    def __init__(self, img_size=64, patch_size=1, in_chans=3, embed_dim=180, depths=[6, 6, 6, 6, 6, 6],
                 num_heads=[6, 6, 6, 6, 6, 6], window_size=16, compress_ratio=3, squeeze_factor=30,
                 conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
                 ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1.,
                 upsampler='', resi_connection='1conv', **kwargs):
        super(HAT, self).__init__()

        self.window_size = window_size
        self.shift_size = window_size // 2
        self.overlap_ratio = overlap_ratio
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler

        # Shallow feature extraction
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        # Deep feature extraction (stack of RHAG groups)
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            nn.init.trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: one rate per block across the whole network
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RHAG(dim=embed_dim,
                         input_resolution=(patches_resolution[0],
                                           patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         compress_ratio=compress_ratio,
                         squeeze_factor=squeeze_factor,
                         conv_scale=conv_scale,
                         overlap_ratio=overlap_ratio,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint)
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))

        # Reconstruction (pixel-shuffle upsampling)
        if upsampler == 'pixelshuffle':
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        x_first = self.conv_first(x)
        res = self.conv_after_body(self.forward_features(x_first)) + x_first
        if self.upsampler == 'pixelshuffle':
            x = self.conv_before_upsample(res)
            x = self.conv_last(self.upsample(x))

        x = x / self.img_range + self.mean

        return x

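# --- Illustrative sketch (not part of the original app) ---
# A hypothetical helper for checking that the configured architecture has the expected
# size before loading weights trained with the same settings.
def _count_parameters(module):
    return sum(p.numel() for p in module.parameters())
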
# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = HAT(
    upscale=4,
    in_chans=3,
    img_size=128,
    window_size=16,
    compress_ratio=3,
    squeeze_factor=30,
    conv_scale=0.01,
    overlap_ratio=0.5,
    img_range=1.,
    depths=[6, 6, 6, 6, 6, 6],
    embed_dim=180,
    num_heads=[6, 6, 6, 6, 6, 6],
    mlp_ratio=2,
    upsampler='pixelshuffle',
    resi_connection='1conv'
)

# Load the fine-tuned weights; the checkpoint may store them under 'params_ema' or 'params'
checkpoint = torch.load('net_g_20000.pth', map_location=device)
if 'params_ema' in checkpoint:
    model.load_state_dict(checkpoint['params_ema'])
elif 'params' in checkpoint:
    model.load_state_dict(checkpoint['params'])
else:
    model.load_state_dict(checkpoint)

model.to(device)
model.eval()

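# --- Illustrative sketch (not part of the original app) ---
# If load_state_dict raises a strict-loading error, comparing key sets against the
# checkpoint usually pinpoints the mismatch. Hypothetical debugging helper:
def _diff_checkpoint_keys(state_dict, module):
    model_keys = set(module.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    return {
        'missing_in_checkpoint': sorted(model_keys - ckpt_keys),
        'unexpected_in_checkpoint': sorted(ckpt_keys - model_keys),
    }
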
def upscale_image(image):
    # Convert the PIL image to a float tensor in [0, 1]; force RGB so grayscale/RGBA uploads don't break
    img_np = np.array(image.convert('RGB')).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # Ensure the image dimensions are multiples of the window size (16)
    h, w = img_tensor.shape[2], img_tensor.shape[3]

    # Pad if necessary
    pad_h = (16 - h % 16) % 16
    pad_w = (16 - w % 16) % 16

    if pad_h > 0 or pad_w > 0:
        img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode='reflect')

    with torch.no_grad():
        output = model(img_tensor)

    # Remove the padding (scaled by the 4x upsampling factor) if it was added
    if pad_h > 0 or pad_w > 0:
        output = output[:, :, :h * 4, :w * 4]

    # Convert back to a PIL image
    output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
    output_np = np.clip(output_np * 255.0, 0, 255).astype(np.uint8)

    return Image.fromarray(output_np)

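# --- Example usage outside Gradio (ours; 'input.png' is a placeholder path) ---
# result = upscale_image(Image.open('input.png'))
# result.save('output_x4.png')
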
# Gradio interface
iface = gr.Interface(
    fn=upscale_image,
    inputs=gr.Image(type="pil", label="Input Satellite Image"),
    outputs=gr.Image(type="pil", label="Super-Resolution Output (4x)"),
    title="HAT Super-Resolution for Satellite Images",
    description="Upload a satellite image to enhance its resolution by 4x using a HAT model fine-tuned specifically on satellite imagery.",
    examples=None,
    cache_examples=False
)

if __name__ == "__main__":
    iface.launch()