Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -31,6 +31,10 @@ import numpy
|
|
| 31 |
import os
|
| 32 |
import random
|
| 33 |
import inspect
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
from basicsr.archs.rrdbnet_arch import RRDBNet as _RRDBNet
|
| 36 |
from basicsr.utils.download_util import load_file_from_url
|
|
@@ -160,6 +164,69 @@ def build_rrdb(scale: int, num_block: int):
|
|
| 160 |
except TypeError as e:
|
| 161 |
raise TypeError(f"RRDBNet signature not recognized: {e}")
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 165 |
# Core upscaling
|
|
@@ -285,6 +352,141 @@ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
|
|
| 285 |
|
| 286 |
return display_img
|
| 287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 290 |
# UI
|
|
@@ -339,11 +541,44 @@ def main():
|
|
| 339 |
)
|
| 340 |
reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])
|
| 341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
gr.Markdown("") # spacer
|
| 343 |
|
| 344 |
# Disable SSR (ZeroGPU + Gradio logs suggested turning this off)
|
| 345 |
demo.launch(ssr_mode=False) # set share=True for a public link
|
| 346 |
|
| 347 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
if __name__ == "__main__":
|
| 349 |
-
main()
|
|
|
|
|
|
| 31 |
import os
|
| 32 |
import random
|
| 33 |
import inspect
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
import zipfile
|
| 36 |
+
import tempfile
|
| 37 |
+
|
| 38 |
|
| 39 |
from basicsr.archs.rrdbnet_arch import RRDBNet as _RRDBNet
|
| 40 |
from basicsr.utils.download_util import load_file_from_url
|
|
|
|
| 164 |
except TypeError as e:
|
| 165 |
raise TypeError(f"RRDBNet signature not recognized: {e}")
|
| 166 |
|
| 167 |
+
#Factor an upsampler builder
|
| 168 |
+
def get_upsampler(model_name: str, outscale: int, tile: int = 256):
|
| 169 |
+
# Build the same backbone/weights as in realesrgan(), but return a ready RealESRGANer
|
| 170 |
+
if model_name == 'RealESRGAN_x4plus':
|
| 171 |
+
model = build_rrdb(scale=4, num_block=23); netscale = 4
|
| 172 |
+
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
| 173 |
+
elif model_name == 'RealESRNet_x4plus':
|
| 174 |
+
model = build_rrdb(scale=4, num_block=23); netscale = 4
|
| 175 |
+
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
|
| 176 |
+
elif model_name == 'RealESRGAN_x4plus_anime_6B':
|
| 177 |
+
model = build_rrdb(scale=4, num_block=6); netscale = 4
|
| 178 |
+
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
|
| 179 |
+
elif model_name == 'RealESRGAN_x2plus':
|
| 180 |
+
model = build_rrdb(scale=2, num_block=23); netscale = 2
|
| 181 |
+
file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
|
| 182 |
+
elif model_name == 'realesr-general-x4v3':
|
| 183 |
+
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'); netscale = 4
|
| 184 |
+
file_url = [
|
| 185 |
+
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
|
| 186 |
+
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth'
|
| 187 |
+
]
|
| 188 |
+
else:
|
| 189 |
+
raise ValueError(f"Unknown model: {model_name}")
|
| 190 |
+
|
| 191 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 192 |
+
weights_dir = os.path.join(ROOT_DIR, 'weights')
|
| 193 |
+
os.makedirs(weights_dir, exist_ok=True)
|
| 194 |
+
for url in file_url:
|
| 195 |
+
fname = os.path.basename(url)
|
| 196 |
+
local_path = os.path.join(weights_dir, fname)
|
| 197 |
+
if not os.path.isfile(local_path):
|
| 198 |
+
load_file_from_url(url=url, model_dir=weights_dir, progress=True)
|
| 199 |
+
|
| 200 |
+
if model_name == 'realesr-general-x4v3':
|
| 201 |
+
model_path = [
|
| 202 |
+
os.path.join(weights_dir, 'realesr-general-x4v3.pth'),
|
| 203 |
+
os.path.join(weights_dir, 'realesr-general-wdn-x4v3.pth'),
|
| 204 |
+
]
|
| 205 |
+
dni_weight = None # supplied at call site if using denoise blend
|
| 206 |
+
else:
|
| 207 |
+
model_path = os.path.join(weights_dir, f"{model_name}.pth")
|
| 208 |
+
dni_weight = None
|
| 209 |
+
|
| 210 |
+
use_cuda = False
|
| 211 |
+
try:
|
| 212 |
+
use_cuda = hasattr(cv2, "cuda") and cv2.cuda.getCudaEnabledDeviceCount() > 0
|
| 213 |
+
except Exception:
|
| 214 |
+
use_cuda = False
|
| 215 |
+
gpu_id = 0 if use_cuda else None
|
| 216 |
+
|
| 217 |
+
upsampler = RealESRGANer(
|
| 218 |
+
scale=netscale,
|
| 219 |
+
model_path=model_path,
|
| 220 |
+
dni_weight=dni_weight,
|
| 221 |
+
model=model,
|
| 222 |
+
tile=tile or 256,
|
| 223 |
+
tile_pad=10,
|
| 224 |
+
pre_pad=10,
|
| 225 |
+
half=bool(use_cuda),
|
| 226 |
+
gpu_id=gpu_id
|
| 227 |
+
)
|
| 228 |
+
return upsampler, netscale, use_cuda, model_path
|
| 229 |
+
|
| 230 |
|
| 231 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 232 |
# Core upscaling
|
|
|
|
| 352 |
|
| 353 |
return display_img
|
| 354 |
|
| 355 |
+
#Add a batch upscaler that preserves filenames
|
| 356 |
+
def render_progress(pct: float, text: str = "") -> str:
|
| 357 |
+
pct = max(0.0, min(100.0, float(pct)))
|
| 358 |
+
bar = f"<div style='width:100%;border:1px solid #ddd;border-radius:6px;overflow:hidden;height:12px;'><div style='height:100%;width:{pct:.1f}%;background:#3b82f6;'></div></div>"
|
| 359 |
+
label = f"<div style='font-size:12px;opacity:.8;margin-top:4px;'>{text} {pct:.1f}%</div>"
|
| 360 |
+
return bar + label
|
| 361 |
+
|
| 362 |
+
def batch_realesrgan(
|
| 363 |
+
files: list, # from gr.Files (type='filepath')
|
| 364 |
+
model_name: str,
|
| 365 |
+
denoise_strength: float,
|
| 366 |
+
face_enhance: bool,
|
| 367 |
+
outscale: int,
|
| 368 |
+
tile: int,
|
| 369 |
+
batch_size: int = 16,
|
| 370 |
+
):
|
| 371 |
+
"""
|
| 372 |
+
Processes multiple images in batches, preserves original file names for outputs,
|
| 373 |
+
and returns (gallery, zip_file, details, progress_html) with streamed progress.
|
| 374 |
+
"""
|
| 375 |
+
# Validate
|
| 376 |
+
if not files or len(files) == 0:
|
| 377 |
+
yield None, None, "No files uploaded.", render_progress(0, "Idle")
|
| 378 |
+
return
|
| 379 |
+
|
| 380 |
+
# Build upsampler once (much faster than per-image)
|
| 381 |
+
upsampler, netscale, use_cuda, model_path = get_upsampler(model_name, outscale, tile=tile)
|
| 382 |
+
|
| 383 |
+
# Optional: face enhancer (same as your single-image path)
|
| 384 |
+
face_enhancer = None
|
| 385 |
+
if face_enhance:
|
| 386 |
+
from gfpgan import GFPGANer
|
| 387 |
+
face_enhancer = GFPGANer(
|
| 388 |
+
model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
|
| 389 |
+
upscale=outscale,
|
| 390 |
+
arch='clean',
|
| 391 |
+
channel_multiplier=2,
|
| 392 |
+
bg_upsampler=upsampler
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
# Prepare work/output dirs
|
| 396 |
+
work = Path(tempfile.mkdtemp(prefix="batch_up_"))
|
| 397 |
+
out_dir = work / "upscaled"
|
| 398 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 399 |
+
|
| 400 |
+
# Normalize list of input paths
|
| 401 |
+
src_paths = [Path(f.name if hasattr(f, "name") else f) for f in files]
|
| 402 |
+
|
| 403 |
+
total = len(src_paths)
|
| 404 |
+
done = 0
|
| 405 |
+
out_paths = []
|
| 406 |
+
|
| 407 |
+
# If realesr-general-x4v3: support blending base + WDN via dni (optional)
|
| 408 |
+
dni_weight = None
|
| 409 |
+
if model_name == "realesr-general-x4v3":
|
| 410 |
+
# Blend [base, WDN] with user's slider
|
| 411 |
+
denoise_strength = float(denoise_strength)
|
| 412 |
+
dni_weight = [1.0 - denoise_strength, denoise_strength]
|
| 413 |
+
# RealESRGANer.enhance accepts dni_weight override via attribute on the instance
|
| 414 |
+
try:
|
| 415 |
+
upsampler.dni_weight = dni_weight
|
| 416 |
+
except Exception:
|
| 417 |
+
pass
|
| 418 |
+
|
| 419 |
+
# Process in batches (I/O and PIL open are still per-file)
|
| 420 |
+
for i in range(0, total, int(max(1, batch_size))):
|
| 421 |
+
batch = src_paths[i:i + int(max(1, batch_size))]
|
| 422 |
+
for src in batch:
|
| 423 |
+
try:
|
| 424 |
+
# Load as RGB consistently
|
| 425 |
+
from PIL import Image
|
| 426 |
+
with Image.open(src) as im:
|
| 427 |
+
img = im.convert("RGB")
|
| 428 |
+
arr = numpy.array(img)
|
| 429 |
+
arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
|
| 430 |
+
|
| 431 |
+
if face_enhancer:
|
| 432 |
+
_, _, output = face_enhancer.enhance(arr, has_aligned=False, only_center_face=False, paste_back=True)
|
| 433 |
+
else:
|
| 434 |
+
output, _ = upsampler.enhance(arr, outscale=int(outscale))
|
| 435 |
+
|
| 436 |
+
# Preserve original file name & (reasonable) extension
|
| 437 |
+
orig_ext = src.suffix.lower()
|
| 438 |
+
ext = orig_ext if orig_ext in (".png", ".jpg", ".jpeg") else ".png"
|
| 439 |
+
out_path = out_dir / (src.stem + ext)
|
| 440 |
+
|
| 441 |
+
# Save (keep alpha if produced, else RGB)
|
| 442 |
+
if output.ndim == 3 and output.shape[2] == 4:
|
| 443 |
+
cv2.imwrite(str(out_path.with_suffix(".png")), output) # 4ch β PNG
|
| 444 |
+
out_path = out_path.with_suffix(".png")
|
| 445 |
+
else:
|
| 446 |
+
if ext in (".jpg", ".jpeg"):
|
| 447 |
+
cv2.imwrite(str(out_path), output, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
|
| 448 |
+
else:
|
| 449 |
+
cv2.imwrite(str(out_path), output) # PNG default
|
| 450 |
+
out_paths.append(out_path)
|
| 451 |
+
except Exception as e:
|
| 452 |
+
# Continue on errors
|
| 453 |
+
print(f"[batch] Error on {src}: {e}")
|
| 454 |
+
finally:
|
| 455 |
+
done += 1
|
| 456 |
+
|
| 457 |
+
pct = (done / total) * 100.0 if total else 0.0
|
| 458 |
+
remaining = max(0, total - done)
|
| 459 |
+
msg = f"Upscaling⦠{done}/{total} done · {remaining} remaining (batch {(i//batch_size)+1}/{(total+batch_size-1)//batch_size})"
|
| 460 |
+
yield None, None, msg, render_progress(pct, msg)
|
| 461 |
+
|
| 462 |
+
if not out_paths:
|
| 463 |
+
yield None, None, "No outputs produced.", render_progress(100, "Finished")
|
| 464 |
+
return
|
| 465 |
+
|
| 466 |
+
# Small even-sampled gallery for preview
|
| 467 |
+
def _sample_even(seq, n=30):
|
| 468 |
+
if not seq: return []
|
| 469 |
+
if len(seq) <= n: return [str(p) for p in seq]
|
| 470 |
+
step = (len(seq)-1) / (n-1)
|
| 471 |
+
idxs = [round(i*step) for i in range(n)]
|
| 472 |
+
seen, out = set(), []
|
| 473 |
+
for i in idxs:
|
| 474 |
+
if i not in seen:
|
| 475 |
+
out.append(str(seq[int(i)])); seen.add(int(i))
|
| 476 |
+
return out
|
| 477 |
+
|
| 478 |
+
out_paths = sorted(out_paths) # stable
|
| 479 |
+
gallery = _sample_even(out_paths, 30)
|
| 480 |
+
|
| 481 |
+
# Zip with same file names
|
| 482 |
+
zip_path = work / "upscaled.zip"
|
| 483 |
+
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
| 484 |
+
for p in out_paths:
|
| 485 |
+
zf.write(p, arcname=p.name)
|
| 486 |
+
|
| 487 |
+
details = f"Upscaled {len(out_paths)} images β {out_dir}"
|
| 488 |
+
yield gallery, str(zip_path), details, render_progress(100, "Complete")
|
| 489 |
+
|
| 490 |
|
| 491 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 492 |
# UI
|
|
|
|
| 541 |
)
|
| 542 |
reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])
|
| 543 |
|
| 544 |
+
# --- Batch Upscale (multi-image) ---
|
| 545 |
+
gr.Markdown("### Batch Upscale")
|
| 546 |
+
with gr.Accordion("Batch options", open=True):
|
| 547 |
+
with gr.Row():
|
| 548 |
+
batch_files = gr.Files(
|
| 549 |
+
label="Upload multiple images (PNG/JPG/JPEG)",
|
| 550 |
+
type="filepath",
|
| 551 |
+
file_types=[".png", ".jpg", ".jpeg"],
|
| 552 |
+
)
|
| 553 |
+
with gr.Row():
|
| 554 |
+
batch_tile = gr.Number(label="Tile size (0/auto β 256)", value=256, precision=0)
|
| 555 |
+
batch_size = gr.Number(label="Batch size (images per batch)", value=16, precision=0)
|
| 556 |
+
|
| 557 |
+
with gr.Row():
|
| 558 |
+
batch_btn = gr.Button("Upscale Batch", variant="primary")
|
| 559 |
+
|
| 560 |
+
batch_prog = gr.HTML(render_progress(0.0, "Idle"))
|
| 561 |
+
batch_gallery = gr.Gallery(label="Preview (sampled 30)", columns=6, height=420)
|
| 562 |
+
batch_zip = gr.File(label="Download upscaled.zip")
|
| 563 |
+
batch_details = gr.Markdown("")
|
| 564 |
+
|
| 565 |
+
# Wire it up (generator β streaming)
|
| 566 |
+
batch_btn.click(
|
| 567 |
+
fn=batch_realesrgan,
|
| 568 |
+
inputs=[batch_files, model_name, denoise_strength, face_enhance, outscale, batch_tile, batch_size],
|
| 569 |
+
outputs=[batch_gallery, batch_zip, batch_details, batch_prog],
|
| 570 |
+
)
|
| 571 |
gr.Markdown("") # spacer
|
| 572 |
|
| 573 |
# Disable SSR (ZeroGPU + Gradio logs suggested turning this off)
|
| 574 |
demo.launch(ssr_mode=False) # set share=True for a public link
|
| 575 |
|
| 576 |
|
| 577 |
+
def main():
|
| 578 |
+
with gr.Blocks(title="Real-ESRGAN Gradio Demo", theme="ParityError/Interstellar") as demo:
|
| 579 |
+
# ... your current UI (plus batch section) ...
|
| 580 |
+
return demo
|
| 581 |
+
|
| 582 |
if __name__ == "__main__":
|
| 583 |
+
demo = main()
|
| 584 |
+
demo.queue().launch(ssr_mode=False)
|