Dhenenjay committed
Commit 91659d8 · verified · 1 Parent(s): ca82763

Upload app.py with huggingface_hub

Files changed (1): app.py +36 -463
app.py CHANGED
@@ -1,27 +1,18 @@
1
- """
2
- E3Diff: High-Resolution SAR-to-Optical Translation
3
- HuggingFace Spaces Deployment
4
-
5
- Features:
6
- - Full resolution processing with seamless tiling
7
- - Proper diffusion sampling (matching local inference)
8
- - TIFF output support
9
- """
10
 
11
  import os
12
- import sys
13
  import torch
14
- import torch.nn as nn
15
- import torch.nn.functional as F
16
  import numpy as np
17
  from PIL import Image, ImageEnhance
18
  import gradio as gr
19
- from pathlib import Path
20
  import tempfile
21
  import time
22
- from functools import partial
23
  from huggingface_hub import hf_hub_download
24
 
 
 
 
 
25
  # ZeroGPU support
26
  try:
27
  import spaces
@@ -30,442 +21,10 @@ except ImportError:
30
  GPU_AVAILABLE = False
31
  spaces = None
32
 
33
- # ============================================================================
34
- # SoftPool Implementation (Pure PyTorch)
35
- # ============================================================================
36
-
37
- def soft_pool2d(x, kernel_size=(2, 2), stride=None, force_inplace=False):
38
- if stride is None:
39
- stride = kernel_size
40
- if isinstance(kernel_size, int):
41
- kernel_size = (kernel_size, kernel_size)
42
- if isinstance(stride, int):
43
- stride = (stride, stride)
44
-
45
- batch, channels, height, width = x.shape
46
- kh, kw = kernel_size
47
- sh, sw = stride
48
- out_h = (height - kh) // sh + 1
49
- out_w = (width - kw) // sw + 1
50
-
51
- x_unfold = F.unfold(x, kernel_size=kernel_size, stride=stride)
52
- x_unfold = x_unfold.view(batch, channels, kh * kw, out_h * out_w)
53
- x_max = x_unfold.max(dim=2, keepdim=True)[0]
54
- exp_x = torch.exp(x_unfold - x_max)
55
- softpool = (x_unfold * exp_x).sum(dim=2) / (exp_x.sum(dim=2) + 1e-8)
56
- return softpool.view(batch, channels, out_h, out_w)
57
-
58
-
59
- class SoftPool2d(nn.Module):
60
- def __init__(self, kernel_size=(2, 2), stride=None, force_inplace=False):
61
- super(SoftPool2d, self).__init__()
62
- self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
63
- self.stride = stride if stride is not None else self.kernel_size
64
-
65
- def forward(self, x):
66
- return soft_pool2d(x, self.kernel_size, self.stride)
67
-
68
-
69
- # Monkey-patch SoftPool
70
- class SoftPoolModule:
71
- soft_pool2d = staticmethod(soft_pool2d)
72
- SoftPool2d = SoftPool2d
73
- sys.modules['SoftPool'] = SoftPoolModule()
74
-
75
- # ============================================================================
76
- # Model Architecture
77
- # ============================================================================
78
-
79
- import math
80
- from inspect import isfunction
81
-
82
- def exists(x):
83
- return x is not None
84
-
85
- def default(val, d):
86
- if exists(val):
87
- return val
88
- return d() if isfunction(d) else d
89
-
90
-
91
- class PositionalEncoding(nn.Module):
92
- def __init__(self, dim):
93
- super().__init__()
94
- self.dim = dim
95
-
96
- def forward(self, noise_level):
97
- count = self.dim // 2
98
- step = torch.arange(count, dtype=noise_level.dtype, device=noise_level.device) / count
99
- encoding = noise_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
100
- encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
101
- return encoding
102
-
103
-
104
- class Swish(nn.Module):
105
- def forward(self, x):
106
- return x * torch.sigmoid(x)
107
-
108
-
109
- class FeatureWiseAffine(nn.Module):
110
- def __init__(self, in_channels, out_channels, use_affine_level=False):
111
- super(FeatureWiseAffine, self).__init__()
112
- self.use_affine_level = use_affine_level
113
- self.noise_func = nn.Sequential(nn.Linear(in_channels, out_channels*(1+self.use_affine_level)))
114
-
115
- def forward(self, x, noise_embed):
116
- batch = x.shape[0]
117
- if self.use_affine_level:
118
- gamma, beta = self.noise_func(noise_embed).view(batch, -1, 1, 1).chunk(2, dim=1)
119
- x = (1 + gamma) * x + beta
120
- else:
121
- x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1)
122
- return x
123
-
124
-
125
- class Upsample(nn.Module):
126
- def __init__(self, dim):
127
- super().__init__()
128
- self.up = nn.Upsample(scale_factor=2, mode="nearest")
129
- self.conv = nn.Conv2d(dim, dim, 3, padding=1)
130
-
131
- def forward(self, x):
132
- return self.conv(self.up(x))
133
-
134
-
135
- class Downsample(nn.Module):
136
- def __init__(self, dim):
137
- super().__init__()
138
- self.conv = nn.Conv2d(dim, dim, 3, 2, 1)
139
-
140
- def forward(self, x):
141
- return self.conv(x)
142
-
143
-
144
- class Block(nn.Module):
145
- def __init__(self, dim, dim_out, groups=32, dropout=0, stride=1):
146
- super().__init__()
147
- self.block = nn.Sequential(
148
- nn.GroupNorm(groups, dim),
149
- Swish(),
150
- nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
151
- nn.Conv2d(dim, dim_out, 3, stride=stride, padding=1)
152
- )
153
-
154
- def forward(self, x):
155
- return self.block(x)
156
-
157
-
158
- class ResnetBlock(nn.Module):
159
- def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
160
- super().__init__()
161
- self.noise_func = FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level)
162
- self.c_func = nn.Conv2d(dim_out, dim_out, 1)
163
- self.block1 = Block(dim, dim_out, groups=norm_groups)
164
- self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
165
- self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
166
-
167
- def forward(self, x, time_emb, c):
168
- h = self.block1(x)
169
- h = self.noise_func(h, time_emb)
170
- h = self.block2(h)
171
- # Resize condition features to match spatial size
172
- if c.shape[2:] != h.shape[2:]:
173
- c = F.interpolate(c, size=h.shape[2:], mode='bilinear', align_corners=False)
174
- h = self.c_func(c) + h
175
- return h + self.res_conv(x)
176
-
177
-
178
- class SelfAttention(nn.Module):
179
- def __init__(self, in_channel, n_head=1, norm_groups=32):
180
- super().__init__()
181
- self.n_head = n_head
182
- self.norm = nn.GroupNorm(norm_groups, in_channel)
183
- self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
184
- self.out = nn.Conv2d(in_channel, in_channel, 1)
185
-
186
- def forward(self, input, t=None, save_flag=None, file_num=None):
187
- batch, channel, height, width = input.shape
188
- n_head = self.n_head
189
- head_dim = channel // n_head
190
- norm = self.norm(input)
191
- qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
192
- query, key, value = qkv.chunk(3, dim=2)
193
- attn = torch.einsum("bnchw, bncyx -> bnhwyx", query, key).contiguous() / math.sqrt(channel)
194
- attn = attn.view(batch, n_head, height, width, -1)
195
- attn = torch.softmax(attn, -1)
196
- attn = attn.view(batch, n_head, height, width, height, width)
197
- out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
198
- out = self.out(out.view(batch, channel, height, width))
199
- return out + input
200
-
201
-
202
- class ResnetBlocWithAttn(nn.Module):
203
- def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False, size=256):
204
- super().__init__()
205
- self.with_attn = with_attn
206
- self.res_block = ResnetBlock(dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
207
- if with_attn:
208
- self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
209
-
210
- def forward(self, x, time_emb, c):
211
- x = self.res_block(x, time_emb, c)
212
- if self.with_attn:
213
- x = self.attn(x, time_emb)
214
- return x
215
-
216
-
217
- # CPEN Condition Encoder
218
- class CPEN(nn.Module):
219
- def __init__(self, inchannel=3):
220
- super(CPEN, self).__init__()
221
- from SoftPool import SoftPool2d
222
-
223
- self.conv1 = nn.Conv2d(inchannel, 64, 3, 1, 1)
224
- self.pool1 = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))
225
- self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
226
- self.pool2 = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))
227
- self.conv3 = nn.Conv2d(128, 256, 3, 1, 1)
228
- self.pool3 = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))
229
- self.conv4 = nn.Conv2d(256, 512, 3, 1, 1)
230
- self.pool4 = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))
231
- self.conv5 = nn.Conv2d(512, 1024, 3, 1, 1)
232
-
233
- def forward(self, x):
234
- c1 = self.pool1(F.leaky_relu(self.conv1(x)))
235
- c2 = self.pool2(F.leaky_relu(self.conv2(c1)))
236
- c3 = self.pool3(F.leaky_relu(self.conv3(c2)))
237
- c4 = self.pool4(F.leaky_relu(self.conv4(c3)))
238
- c5 = F.leaky_relu(self.conv5(c4))
239
- return c1, c2, c3, c4, c5
240
-
241
-
242
- class UNet(nn.Module):
243
- def __init__(self, in_channel=6, out_channel=3, inner_channel=32, norm_groups=32,
244
- channel_mults=(1, 2, 4, 8, 8), attn_res=(8,), res_blocks=3, dropout=0,
245
- with_noise_level_emb=True, image_size=128, condition_ch=3):
246
- super().__init__()
247
-
248
- self.res_blocks = res_blocks
249
- noise_level_channel = inner_channel
250
- self.noise_level_mlp = nn.Sequential(
251
- PositionalEncoding(inner_channel),
252
- nn.Linear(inner_channel, inner_channel * 4),
253
- Swish(),
254
- nn.Linear(inner_channel * 4, inner_channel)
255
- ) if with_noise_level_emb else None
256
-
257
- num_mults = len(channel_mults)
258
- pre_channel = inner_channel
259
- feat_channels = [pre_channel]
260
- now_res = image_size
261
-
262
- downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)]
263
- for ind in range(num_mults):
264
- is_last = (ind == num_mults - 1)
265
- use_attn = (now_res in attn_res)
266
- channel_mult = inner_channel * channel_mults[ind]
267
- for _ in range(0, res_blocks):
268
- downs.append(ResnetBlocWithAttn(pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel,
269
- norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, size=now_res))
270
- feat_channels.append(channel_mult)
271
- pre_channel = channel_mult
272
- if not is_last:
273
- downs.append(Downsample(pre_channel))
274
- feat_channels.append(pre_channel)
275
- now_res = now_res // 2
276
- self.downs = nn.ModuleList(downs)
277
-
278
- self.mid = nn.ModuleList([
279
- ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
280
- norm_groups=norm_groups, dropout=dropout, with_attn=True, size=now_res),
281
- ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
282
- norm_groups=norm_groups, dropout=dropout, with_attn=False, size=now_res)
283
- ])
284
-
285
- ups = []
286
- for ind in reversed(range(num_mults)):
287
- is_last = (ind < 1)
288
- use_attn = (now_res in attn_res)
289
- channel_mult = inner_channel * channel_mults[ind]
290
- for _ in range(0, res_blocks + 1):
291
- ups.append(ResnetBlocWithAttn(pre_channel + feat_channels.pop(), channel_mult,
292
- noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
293
- dropout=dropout, with_attn=use_attn, size=now_res))
294
- pre_channel = channel_mult
295
- if not is_last:
296
- ups.append(Upsample(pre_channel))
297
- now_res = now_res * 2
298
- self.ups = nn.ModuleList(ups)
299
-
300
- self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)
301
- self.condition = CPEN(inchannel=condition_ch)
302
- self.condition_ch = condition_ch
303
-
304
- def forward(self, x, time, img_s1=None, class_label=None, return_condition=False, t_ori=0):
305
- condition = x[:, :self.condition_ch, ...].clone()
306
- x = x[:, self.condition_ch:, ...]
307
-
308
- c1, c2, c3, c4, c5 = self.condition(condition)
309
- c_base = [c1, c2, c3, c4, c5]
310
-
311
- c = []
312
- for i in range(len(c_base)):
313
- for _ in range(self.res_blocks):
314
- c.append(c_base[i])
315
-
316
- t = self.noise_level_mlp(time) if exists(self.noise_level_mlp) else None
317
-
318
- feats = []
319
- i = 0
320
- for layer in self.downs:
321
- if isinstance(layer, ResnetBlocWithAttn):
322
- x = layer(x, t, c[i])
323
- i += 1
324
- else:
325
- x = layer(x)
326
- feats.append(x)
327
-
328
- for layer in self.mid:
329
- if isinstance(layer, ResnetBlocWithAttn):
330
- x = layer(x, t, c5)
331
- else:
332
- x = layer(x)
333
-
334
- c_base = [c5, c4, c3, c2, c1]
335
- c = []
336
- for i in range(len(c_base)):
337
- for _ in range(self.res_blocks + 1):
338
- c.append(c_base[i])
339
-
340
- i = 0
341
- for layer in self.ups:
342
- if isinstance(layer, ResnetBlocWithAttn):
343
- x = layer(torch.cat((x, feats.pop()), dim=1), t, c[i])
344
- i += 1
345
- else:
346
- x = layer(x)
347
-
348
- if not return_condition:
349
- return self.final_conv(x)
350
- else:
351
- return self.final_conv(x), [c1, c2, c3, c4, c5]
352
-
353
-
354
- # ============================================================================
355
- # GaussianDiffusion - Proper DDIM Sampling
356
- # ============================================================================
357
-
358
- def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2):
359
- if schedule == 'linear':
360
- betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
361
- else:
362
- raise NotImplementedError(schedule)
363
- return betas
364
-
365
-
366
- class GaussianDiffusion(nn.Module):
367
- def __init__(self, denoise_fn, image_size, channels=3, schedule_opt=None, opt=None):
368
- super().__init__()
369
- self.channels = channels
370
- self.image_size = image_size
371
- self.denoise_fn = denoise_fn
372
- self.opt = opt
373
- self.ddim = schedule_opt.get('ddim', 1) if schedule_opt else 1
374
-
375
- def set_new_noise_schedule(self, schedule_opt, device, num_train_timesteps=1000):
376
- self.ddim = schedule_opt['ddim']
377
- self.num_train_timesteps = num_train_timesteps
378
- to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
379
-
380
- betas = make_beta_schedule(
381
- schedule=schedule_opt['schedule'],
382
- n_timestep=num_train_timesteps,
383
- linear_start=schedule_opt['linear_start'],
384
- linear_end=schedule_opt['linear_end']
385
- )
386
-
387
- alphas = 1. - betas
388
- alphas_cumprod = np.cumprod(alphas, axis=0)
389
- self.sqrt_alphas_cumprod_prev = np.sqrt(np.append(1., alphas_cumprod))
390
-
391
- self.num_timesteps = int(betas.shape[0])
392
- self.register_buffer('betas', to_torch(betas))
393
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
394
-
395
- self.ddim_num_steps = schedule_opt['n_timestep']
396
- print(f'DDIM sampling steps: {self.ddim_num_steps}')
397
-
398
- def ddim_sample(self, condition_x, img_or_shape, device, seed=1):
399
- """DDIM sampling - matches the original E3Diff implementation."""
400
- eta = 0.8 # ddim_sampling_eta for linear schedule
401
-
402
- batch = img_or_shape[0]
403
- total_timesteps = self.num_train_timesteps
404
- sampling_timesteps = self.ddim_num_steps
405
-
406
- ts = torch.linspace(total_timesteps, 0, sampling_timesteps + 1).to(device).long()
407
- x = torch.randn(img_or_shape, device=device)
408
- batch_size = x.shape[0]
409
-
410
- imgs = [x]
411
- img_onestep = [condition_x[:, :self.channels, ...]]
412
-
413
- for i in range(1, sampling_timesteps + 1):
414
- cur_t = ts[i - 1] - 1
415
- prev_t = ts[i] - 1
416
-
417
- noise_level = torch.FloatTensor(
418
- [self.sqrt_alphas_cumprod_prev[cur_t.item()]]
419
- ).repeat(batch_size, 1).to(device)
420
-
421
- alpha_prod_t = self.alphas_cumprod[cur_t]
422
- alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0, device=device)
423
- beta_prod_t = 1 - alpha_prod_t
424
-
425
- # Model prediction
426
- model_output = self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level)
427
-
428
- # Compute sigma
429
- sigma_2 = eta * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
430
- noise = torch.randn_like(x)
431
-
432
- # Predict original sample
433
- pred_original_sample = (x - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
434
- pred_original_sample = pred_original_sample.clamp(-1, 1)
435
-
436
- pred_sample_direction = (1 - alpha_prod_t_prev - sigma_2) ** 0.5 * model_output
437
-
438
- x = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction + sigma_2 ** 0.5 * noise
439
-
440
- imgs.append(x)
441
- img_onestep.append(pred_original_sample)
442
-
443
- imgs = torch.cat(imgs, dim=0)
444
- img_onestep = torch.cat(img_onestep, dim=0)
445
-
446
- return imgs, img_onestep
447
-
448
- @torch.no_grad()
449
- def super_resolution(self, x_in, continous=False, seed=1, img_s1=None):
450
- """Main inference method."""
451
- device = self.betas.device
452
- x = x_in
453
- shape = (x.shape[0], self.channels, x.shape[-2], x.shape[-1])
454
-
455
- self.ddim_num_steps = self.opt['ddim_steps']
456
- ret_img, img_onestep = self.ddim_sample(condition_x=x, img_or_shape=shape, device=device, seed=seed)
457
-
458
- if continous:
459
- return ret_img, img_onestep
460
- else:
461
- return ret_img[-x_in.shape[0]:], img_onestep
462
-
463
-
464
- # ============================================================================
465
- # E3Diff Inference Class
466
- # ============================================================================
467
 
468
  class E3DiffInference:
 
 
469
  def __init__(self, weights_path=None, device="cuda", num_inference_steps=1):
470
  self.device = torch.device(device if torch.cuda.is_available() else "cpu")
471
  self.image_size = 256
@@ -480,6 +39,7 @@ class E3DiffInference:
480
  print("[E3Diff] Model ready!")
481
 
482
  def _build_model(self):
 
483
  unet = UNet(
484
  in_channel=3,
485
  out_channel=3,
@@ -505,19 +65,30 @@ class E3DiffInference:
505
  opt = {
506
  'stage': 2,
507
  'ddim_steps': self.num_inference_steps,
508
  }
509
 
510
  model = GaussianDiffusion(
511
  denoise_fn=unet,
512
  image_size=self.image_size,
513
  channels=3,
 
 
514
  schedule_opt=schedule_opt,
 
 
515
  opt=opt
516
  )
517
 
518
  return model.to(self.device)
519
 
520
  def _load_weights(self, weights_path):
 
521
  if weights_path is None:
522
  weights_path = hf_hub_download(
523
  repo_id="Dhenenjay/E3Diff-SAR2Optical",
@@ -530,6 +101,7 @@ class E3DiffInference:
530
  print("[E3Diff] Weights loaded!")
531
 
532
  def preprocess(self, image):
 
533
  if image.mode != 'RGB':
534
  image = image.convert('RGB')
535
  if image.size != (self.image_size, self.image_size):
@@ -541,6 +113,7 @@ class E3DiffInference:
541
  return img_tensor.unsqueeze(0).to(self.device)
542
 
543
  def postprocess(self, tensor):
 
544
  tensor = tensor.squeeze(0).cpu()
545
  tensor = torch.clamp(tensor, -1, 1)
546
  tensor = (tensor + 1.0) / 2.0
@@ -549,12 +122,14 @@ class E3DiffInference:
549
 
550
  @torch.no_grad()
551
  def translate(self, sar_image, seed=42):
 
552
  if seed is not None:
553
  torch.manual_seed(seed)
554
  np.random.seed(seed)
555
 
556
  sar_tensor = self.preprocess(sar_image)
557
 
 
558
  self.model.set_new_noise_schedule(
559
  {
560
  'schedule': 'linear',
@@ -568,22 +143,22 @@ class E3DiffInference:
568
  num_train_timesteps=1000
569
  )
570
 
 
571
  output, _ = self.model.super_resolution(sar_tensor, continous=False, seed=seed, img_s1=sar_tensor)
572
  return self.postprocess(output)
573
 
574
 
575
- # ============================================================================
576
- # High-Resolution Processor
577
- # ============================================================================
578
-
579
  class HighResProcessor:
 
 
580
  def __init__(self, device="cuda"):
581
  self.device = device
582
  self.model = None
583
  self.tile_size = 256
 
584
 
585
  def load_model(self, num_steps=1):
586
- print("Loading E3Diff model...")
587
  self.model = E3DiffInference(device=self.device, num_inference_steps=num_steps)
588
  self.num_steps = num_steps
589
 
@@ -640,7 +215,8 @@ class HighResProcessor:
640
  weights[y:y+tile_size, x:x+tile_size] += blend_weight
641
 
642
  tile_idx += 1
643
- print(f" Tile {tile_idx}/{total_tiles}")
 
644
 
645
  output = output / (weights + 1e-8)
646
  output = output[:h, :w]
@@ -656,12 +232,10 @@ class HighResProcessor:
656
  return image
657
 
658
 
659
- # ============================================================================
660
- # Gradio Interface
661
- # ============================================================================
662
-
663
  processor = None
664
 
 
665
  def load_sar_image(filepath):
666
  """Load SAR image from various formats."""
667
  try:
@@ -686,7 +260,7 @@ def load_sar_image(filepath):
686
 
687
 
688
  def _translate_sar_impl(file, num_steps, overlap, enhance_output):
689
- """Main translation function implementation."""
690
  global processor
691
 
692
  if file is None:
@@ -729,7 +303,7 @@ else:
729
  translate_sar = _translate_sar_impl
730
 
731
 
732
- # Create interface
733
  with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
734
  gr.Markdown("""
735
  # 🛰️ E3Diff: High-Resolution SAR-to-Optical Translation
@@ -743,14 +317,13 @@ with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
743
 
744
  with gr.Row():
745
  with gr.Column():
746
- input_file = gr.File(label="SAR Input (TIFF, PNG, JPG supported)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
747
 
748
  with gr.Row():
749
  num_steps = gr.Slider(1, 8, value=1, step=1, label="Quality Steps (1=fast, 8=best)")
750
  overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap")
751
 
752
  enhance = gr.Checkbox(value=True, label="Apply enhancement")
753
-
754
  submit_btn = gr.Button("🚀 Translate to Optical", variant="primary")
755
 
756
  with gr.Column():
@@ -766,7 +339,7 @@ with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
766
 
767
  gr.Markdown("""
768
  ---
769
- **Tips:** The model works best with Sentinel-1 style SAR imagery. Use steps=1 for speed, steps=4-8 for quality.
770
  """)
771
 
772
 
 
1
+ """E3Diff: SAR-to-Optical Translation - HuggingFace Space."""
2
 
3
  import os
 
4
  import torch
 
 
5
  import numpy as np
6
  from PIL import Image, ImageEnhance
7
  import gradio as gr
 
8
  import tempfile
9
  import time
 
10
  from huggingface_hub import hf_hub_download
11
 
12
+ # Import model components
13
+ from unet import UNet
14
+ from diffusion import GaussianDiffusion
15
+
16
  # ZeroGPU support
17
  try:
18
  import spaces
 
21
  GPU_AVAILABLE = False
22
  spaces = None
23
 
24
 
25
  class E3DiffInference:
26
+ """E3Diff Inference Pipeline - matches local implementation exactly."""
27
+
28
  def __init__(self, weights_path=None, device="cuda", num_inference_steps=1):
29
  self.device = torch.device(device if torch.cuda.is_available() else "cpu")
30
  self.image_size = 256
 
39
  print("[E3Diff] Model ready!")
40
 
41
  def _build_model(self):
42
+ """Build model - exact same config as local inference.py"""
43
  unet = UNet(
44
  in_channel=3,
45
  out_channel=3,
 
65
  opt = {
66
  'stage': 2,
67
  'ddim_steps': self.num_inference_steps,
68
+ 'model': {
69
+ 'beta_schedule': {
70
+ 'train': {'n_timestep': 1000},
71
+ 'val': schedule_opt
72
+ }
73
+ }
74
  }
75
 
76
  model = GaussianDiffusion(
77
  denoise_fn=unet,
78
  image_size=self.image_size,
79
  channels=3,
80
+ loss_type='l1',
81
+ conditional=True,
82
  schedule_opt=schedule_opt,
83
+ xT_noise_r=0,
84
+ seed=1,
85
  opt=opt
86
  )
87
 
88
  return model.to(self.device)
89
 
90
  def _load_weights(self, weights_path):
91
+ """Load weights - same as local inference.py"""
92
  if weights_path is None:
93
  weights_path = hf_hub_download(
94
  repo_id="Dhenenjay/E3Diff-SAR2Optical",
 
101
  print("[E3Diff] Weights loaded!")
102
 
103
  def preprocess(self, image):
104
+ """Preprocess input image."""
105
  if image.mode != 'RGB':
106
  image = image.convert('RGB')
107
  if image.size != (self.image_size, self.image_size):
 
113
  return img_tensor.unsqueeze(0).to(self.device)
114
 
115
  def postprocess(self, tensor):
116
+ """Postprocess output tensor."""
117
  tensor = tensor.squeeze(0).cpu()
118
  tensor = torch.clamp(tensor, -1, 1)
119
  tensor = (tensor + 1.0) / 2.0
 
122
 
123
  @torch.no_grad()
124
  def translate(self, sar_image, seed=42):
125
+ """Translate SAR to optical - same as local inference.py"""
126
  if seed is not None:
127
  torch.manual_seed(seed)
128
  np.random.seed(seed)
129
 
130
  sar_tensor = self.preprocess(sar_image)
131
 
132
+ # Set noise schedule
133
  self.model.set_new_noise_schedule(
134
  {
135
  'schedule': 'linear',
 
143
  num_train_timesteps=1000
144
  )
145
 
146
+ # Run inference
147
  output, _ = self.model.super_resolution(sar_tensor, continous=False, seed=seed, img_s1=sar_tensor)
148
  return self.postprocess(output)
149
 
150
 
151
  class HighResProcessor:
152
+ """High resolution tiled processing."""
153
+
154
  def __init__(self, device="cuda"):
155
  self.device = device
156
  self.model = None
157
  self.tile_size = 256
158
+ self.num_steps = None
159
 
160
  def load_model(self, num_steps=1):
161
+ print(f"Loading E3Diff model with {num_steps} steps...")
162
  self.model = E3DiffInference(device=self.device, num_inference_steps=num_steps)
163
  self.num_steps = num_steps
164
 
 
215
  weights[y:y+tile_size, x:x+tile_size] += blend_weight
216
 
217
  tile_idx += 1
218
+ if tile_idx % 4 == 0 or tile_idx == total_tiles:
219
+ print(f" Tile {tile_idx}/{total_tiles}")
220
 
221
  output = output / (weights + 1e-8)
222
  output = output[:h, :w]
 
232
  return image
233
 
234
 
235
+ # Global processor
236
  processor = None
237
 
238
+
239
  def load_sar_image(filepath):
240
  """Load SAR image from various formats."""
241
  try:
 
260
 
261
 
262
  def _translate_sar_impl(file, num_steps, overlap, enhance_output):
263
+ """Main translation function."""
264
  global processor
265
 
266
  if file is None:
 
303
  translate_sar = _translate_sar_impl
304
 
305
 
306
+ # Create Gradio interface
307
  with gr.Blocks(title="E3Diff: SAR-to-Optical Translation") as demo:
308
  gr.Markdown("""
309
  # 🛰️ E3Diff: High-Resolution SAR-to-Optical Translation
 
317
 
318
  with gr.Row():
319
  with gr.Column():
320
+ input_file = gr.File(label="SAR Input (TIFF, PNG, JPG)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
321
 
322
  with gr.Row():
323
  num_steps = gr.Slider(1, 8, value=1, step=1, label="Quality Steps (1=fast, 8=best)")
324
  overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap")
325
 
326
  enhance = gr.Checkbox(value=True, label="Apply enhancement")
 
327
  submit_btn = gr.Button("🚀 Translate to Optical", variant="primary")
328
 
329
  with gr.Column():
 
339
 
340
  gr.Markdown("""
341
  ---
342
+ **Tips:** Use steps=1 for speed, steps=4-8 for quality. Works best with Sentinel-1 style SAR.
343
  """)
344
 
345
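
A minimal usage sketch of the slimmed-down pipeline this commit leaves in app.py, for anyone trying it outside the Gradio UI. It assumes the refactored `unet.py` and `diffusion.py` modules sit next to `app.py` in the Space, that the weights repo `Dhenenjay/E3Diff-SAR2Optical` is reachable (as `_load_weights` expects), and that `postprocess` returns a PIL image (its tail is truncated in this diff); the input filename is hypothetical.

```python
# Sketch only: drives E3DiffInference the same way HighResProcessor.load_model()
# and translate() do in app.py above. Importing app also builds the Gradio Blocks.
from PIL import Image
from app import E3DiffInference

# num_inference_steps maps to the "Quality Steps" slider (1 = fast, 8 = best).
pipe = E3DiffInference(device="cuda", num_inference_steps=4)  # falls back to CPU if CUDA is absent

sar_tile = Image.open("sar_tile_256.png").convert("RGB")  # hypothetical 256x256 SAR chip
optical = pipe.translate(sar_tile, seed=42)               # DDIM sampling with a fixed seed
optical.save("optical_tile_256.png")                      # assumes postprocess() yields a PIL.Image
```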