JS6969 commited on
Commit
356195b
Β·
verified Β·
1 Parent(s): 3ec4287

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -1
app.py CHANGED
@@ -31,6 +31,10 @@ import numpy
31
  import os
32
  import random
33
  import inspect
 
 
 
 
34
 
35
  from basicsr.archs.rrdbnet_arch import RRDBNet as _RRDBNet
36
  from basicsr.utils.download_util import load_file_from_url
@@ -160,6 +164,69 @@ def build_rrdb(scale: int, num_block: int):
160
  except TypeError as e:
161
  raise TypeError(f"RRDBNet signature not recognized: {e}")
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  # ────────────────────────────────────────────────────────
165
  # Core upscaling
@@ -285,6 +352,141 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
285
 
286
  return display_img
287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
  # ────────────────────────────────────────────────────────
290
  # UI
@@ -339,11 +541,44 @@ def main():
339
  )
340
  reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])
341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  gr.Markdown("") # spacer
343
 
344
  # Disable SSR (ZeroGPU + Gradio logs suggested turning this off)
345
  demo.launch(ssr_mode=False) # set share=True for a public link
346
 
347
 
 
 
 
 
 
348
  if __name__ == "__main__":
349
- main()
 
 
31
  import os
32
  import random
33
  import inspect
34
+ from pathlib import Path
35
+ import zipfile
36
+ import tempfile
37
+
38
 
39
  from basicsr.archs.rrdbnet_arch import RRDBNet as _RRDBNet
40
  from basicsr.utils.download_util import load_file_from_url
 
164
  except TypeError as e:
165
  raise TypeError(f"RRDBNet signature not recognized: {e}")
166
 
167
+ #Factor an upsampler builder
168
+ def get_upsampler(model_name: str, outscale: int, tile: int = 256):
169
+ # Build the same backbone/weights as in realesrgan(), but return a ready RealESRGANer
170
+ if model_name == 'RealESRGAN_x4plus':
171
+ model = build_rrdb(scale=4, num_block=23); netscale = 4
172
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
173
+ elif model_name == 'RealESRNet_x4plus':
174
+ model = build_rrdb(scale=4, num_block=23); netscale = 4
175
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
176
+ elif model_name == 'RealESRGAN_x4plus_anime_6B':
177
+ model = build_rrdb(scale=4, num_block=6); netscale = 4
178
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
179
+ elif model_name == 'RealESRGAN_x2plus':
180
+ model = build_rrdb(scale=2, num_block=23); netscale = 2
181
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
182
+ elif model_name == 'realesr-general-x4v3':
183
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'); netscale = 4
184
+ file_url = [
185
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
186
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth'
187
+ ]
188
+ else:
189
+ raise ValueError(f"Unknown model: {model_name}")
190
+
191
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
192
+ weights_dir = os.path.join(ROOT_DIR, 'weights')
193
+ os.makedirs(weights_dir, exist_ok=True)
194
+ for url in file_url:
195
+ fname = os.path.basename(url)
196
+ local_path = os.path.join(weights_dir, fname)
197
+ if not os.path.isfile(local_path):
198
+ load_file_from_url(url=url, model_dir=weights_dir, progress=True)
199
+
200
+ if model_name == 'realesr-general-x4v3':
201
+ model_path = [
202
+ os.path.join(weights_dir, 'realesr-general-x4v3.pth'),
203
+ os.path.join(weights_dir, 'realesr-general-wdn-x4v3.pth'),
204
+ ]
205
+ dni_weight = None # supplied at call site if using denoise blend
206
+ else:
207
+ model_path = os.path.join(weights_dir, f"{model_name}.pth")
208
+ dni_weight = None
209
+
210
+ use_cuda = False
211
+ try:
212
+ use_cuda = hasattr(cv2, "cuda") and cv2.cuda.getCudaEnabledDeviceCount() > 0
213
+ except Exception:
214
+ use_cuda = False
215
+ gpu_id = 0 if use_cuda else None
216
+
217
+ upsampler = RealESRGANer(
218
+ scale=netscale,
219
+ model_path=model_path,
220
+ dni_weight=dni_weight,
221
+ model=model,
222
+ tile=tile or 256,
223
+ tile_pad=10,
224
+ pre_pad=10,
225
+ half=bool(use_cuda),
226
+ gpu_id=gpu_id
227
+ )
228
+ return upsampler, netscale, use_cuda, model_path
229
+
230
 
231
  # ────────────────────────────────────────────────────────
232
  # Core upscaling
 
352
 
353
  return display_img
354
 
355
+ #Add a batch upscaler that preserves filenames
356
+ def render_progress(pct: float, text: str = "") -> str:
357
+ pct = max(0.0, min(100.0, float(pct)))
358
+ bar = f"<div style='width:100%;border:1px solid #ddd;border-radius:6px;overflow:hidden;height:12px;'><div style='height:100%;width:{pct:.1f}%;background:#3b82f6;'></div></div>"
359
+ label = f"<div style='font-size:12px;opacity:.8;margin-top:4px;'>{text} {pct:.1f}%</div>"
360
+ return bar + label
361
+
362
+ def batch_realesrgan(
363
+ files: list, # from gr.Files (type='filepath')
364
+ model_name: str,
365
+ denoise_strength: float,
366
+ face_enhance: bool,
367
+ outscale: int,
368
+ tile: int,
369
+ batch_size: int = 16,
370
+ ):
371
+ """
372
+ Processes multiple images in batches, preserves original file names for outputs,
373
+ and returns (gallery, zip_file, details, progress_html) with streamed progress.
374
+ """
375
+ # Validate
376
+ if not files or len(files) == 0:
377
+ yield None, None, "No files uploaded.", render_progress(0, "Idle")
378
+ return
379
+
380
+ # Build upsampler once (much faster than per-image)
381
+ upsampler, netscale, use_cuda, model_path = get_upsampler(model_name, outscale, tile=tile)
382
+
383
+ # Optional: face enhancer (same as your single-image path)
384
+ face_enhancer = None
385
+ if face_enhance:
386
+ from gfpgan import GFPGANer
387
+ face_enhancer = GFPGANer(
388
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
389
+ upscale=outscale,
390
+ arch='clean',
391
+ channel_multiplier=2,
392
+ bg_upsampler=upsampler
393
+ )
394
+
395
+ # Prepare work/output dirs
396
+ work = Path(tempfile.mkdtemp(prefix="batch_up_"))
397
+ out_dir = work / "upscaled"
398
+ out_dir.mkdir(parents=True, exist_ok=True)
399
+
400
+ # Normalize list of input paths
401
+ src_paths = [Path(f.name if hasattr(f, "name") else f) for f in files]
402
+
403
+ total = len(src_paths)
404
+ done = 0
405
+ out_paths = []
406
+
407
+ # If realesr-general-x4v3: support blending base + WDN via dni (optional)
408
+ dni_weight = None
409
+ if model_name == "realesr-general-x4v3":
410
+ # Blend [base, WDN] with user's slider
411
+ denoise_strength = float(denoise_strength)
412
+ dni_weight = [1.0 - denoise_strength, denoise_strength]
413
+ # RealESRGANer.enhance accepts dni_weight override via attribute on the instance
414
+ try:
415
+ upsampler.dni_weight = dni_weight
416
+ except Exception:
417
+ pass
418
+
419
+ # Process in batches (I/O and PIL open are still per-file)
420
+ for i in range(0, total, int(max(1, batch_size))):
421
+ batch = src_paths[i:i + int(max(1, batch_size))]
422
+ for src in batch:
423
+ try:
424
+ # Load as RGB consistently
425
+ from PIL import Image
426
+ with Image.open(src) as im:
427
+ img = im.convert("RGB")
428
+ arr = numpy.array(img)
429
+ arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
430
+
431
+ if face_enhancer:
432
+ _, _, output = face_enhancer.enhance(arr, has_aligned=False, only_center_face=False, paste_back=True)
433
+ else:
434
+ output, _ = upsampler.enhance(arr, outscale=int(outscale))
435
+
436
+ # Preserve original file name & (reasonable) extension
437
+ orig_ext = src.suffix.lower()
438
+ ext = orig_ext if orig_ext in (".png", ".jpg", ".jpeg") else ".png"
439
+ out_path = out_dir / (src.stem + ext)
440
+
441
+ # Save (keep alpha if produced, else RGB)
442
+ if output.ndim == 3 and output.shape[2] == 4:
443
+ cv2.imwrite(str(out_path.with_suffix(".png")), output) # 4ch β†’ PNG
444
+ out_path = out_path.with_suffix(".png")
445
+ else:
446
+ if ext in (".jpg", ".jpeg"):
447
+ cv2.imwrite(str(out_path), output, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
448
+ else:
449
+ cv2.imwrite(str(out_path), output) # PNG default
450
+ out_paths.append(out_path)
451
+ except Exception as e:
452
+ # Continue on errors
453
+ print(f"[batch] Error on {src}: {e}")
454
+ finally:
455
+ done += 1
456
+
457
+ pct = (done / total) * 100.0 if total else 0.0
458
+ remaining = max(0, total - done)
459
+ msg = f"Upscaling… {done}/{total} done Β· {remaining} remaining (batch {(i//batch_size)+1}/{(total+batch_size-1)//batch_size})"
460
+ yield None, None, msg, render_progress(pct, msg)
461
+
462
+ if not out_paths:
463
+ yield None, None, "No outputs produced.", render_progress(100, "Finished")
464
+ return
465
+
466
+ # Small even-sampled gallery for preview
467
+ def _sample_even(seq, n=30):
468
+ if not seq: return []
469
+ if len(seq) <= n: return [str(p) for p in seq]
470
+ step = (len(seq)-1) / (n-1)
471
+ idxs = [round(i*step) for i in range(n)]
472
+ seen, out = set(), []
473
+ for i in idxs:
474
+ if i not in seen:
475
+ out.append(str(seq[int(i)])); seen.add(int(i))
476
+ return out
477
+
478
+ out_paths = sorted(out_paths) # stable
479
+ gallery = _sample_even(out_paths, 30)
480
+
481
+ # Zip with same file names
482
+ zip_path = work / "upscaled.zip"
483
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
484
+ for p in out_paths:
485
+ zf.write(p, arcname=p.name)
486
+
487
+ details = f"Upscaled {len(out_paths)} images β†’ {out_dir}"
488
+ yield gallery, str(zip_path), details, render_progress(100, "Complete")
489
+
490
 
491
  # ────────────────────────────────────────────────────────
492
  # UI
 
541
  )
542
  reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])
543
 
544
+ # --- Batch Upscale (multi-image) ---
545
+ gr.Markdown("### Batch Upscale")
546
+ with gr.Accordion("Batch options", open=True):
547
+ with gr.Row():
548
+ batch_files = gr.Files(
549
+ label="Upload multiple images (PNG/JPG/JPEG)",
550
+ type="filepath",
551
+ file_types=[".png", ".jpg", ".jpeg"],
552
+ )
553
+ with gr.Row():
554
+ batch_tile = gr.Number(label="Tile size (0/auto β†’ 256)", value=256, precision=0)
555
+ batch_size = gr.Number(label="Batch size (images per batch)", value=16, precision=0)
556
+
557
+ with gr.Row():
558
+ batch_btn = gr.Button("Upscale Batch", variant="primary")
559
+
560
+ batch_prog = gr.HTML(render_progress(0.0, "Idle"))
561
+ batch_gallery = gr.Gallery(label="Preview (sampled 30)", columns=6, height=420)
562
+ batch_zip = gr.File(label="Download upscaled.zip")
563
+ batch_details = gr.Markdown("")
564
+
565
+ # Wire it up (generator β†’ streaming)
566
+ batch_btn.click(
567
+ fn=batch_realesrgan,
568
+ inputs=[batch_files, model_name, denoise_strength, face_enhance, outscale, batch_tile, batch_size],
569
+ outputs=[batch_gallery, batch_zip, batch_details, batch_prog],
570
+ )
571
  gr.Markdown("") # spacer
572
 
573
  # Disable SSR (ZeroGPU + Gradio logs suggested turning this off)
574
  demo.launch(ssr_mode=False) # set share=True for a public link
575
 
576
 
577
+ def main():
578
+ with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
579
+ # ... your current UI (plus batch section) ...
580
+ return demo
581
+
582
  if __name__ == "__main__":
583
+ demo = main()
584
+ demo.queue().launch(ssr_mode=False)