import json, logging, time, sys, os, tempfile, threading
from io import BytesIO
from pathlib import Path
from contextlib import asynccontextmanager
from typing import Literal

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from scipy import ndimage

from fastapi import FastAPI, File, UploadFile, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse

from mewzoom.model import MewZoom

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

MEWZOOM_MODELS = {"2x": "andrewdalpino/MewZoom-V1-2X-Unet", "4x": "andrewdalpino/MewZoom-V1-4X-Unet"}
MAX_DIM = {"2x": 2048, "4x": 1024, "invsr": 256}
CACHE_DIR = Path("models")
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Device: %s", _DEVICE)

_mz_models: dict[str, MewZoom] = {}


def _load_mewzoom(scale: str) -> MewZoom:
    """Load and cache the MewZoom model for the given scale ("2x" or "4x")."""
    if scale in _mz_models:
        return _mz_models[scale]
    mid = MEWZOOM_MODELS[scale]
    logger.info("Loading MewZoom %s ...", scale)
    CACHE_DIR.mkdir(exist_ok=True)
    m = MewZoom.from_pretrained(mid, cache_dir=str(CACHE_DIR))
    m.to(_DEVICE).eval()
    _mz_models[scale] = m
    logger.info("MewZoom %s ready (%s params)", scale, f"{sum(p.numel() for p in m.parameters()):,}")
    return m


def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert an RGB PIL image to a float CHW tensor in [0, 1]."""
    arr = np.array(img, dtype=np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1)


def _resize_if_needed(img: Image.Image, scale: str) -> tuple[Image.Image, bool]:
    """Shrink the image so its longest side fits the per-model cap; returns (image, was_resized)."""
    md = MAX_DIM.get(scale, 1024)
    w, h = img.size
    if max(w, h) <= md:
        return img, False
    r = md / max(w, h)
    return img.resize((int(w * r), int(h * r)), Image.LANCZOS), True

def upscale_mewzoom(image_bytes: bytes, scale: str) -> tuple[bytes, dict]:
    """Upscale an image with MewZoom and return (png_bytes, info_dict)."""
    model = _load_mewzoom(scale)
    factor = int(scale[0])
    pil = Image.open(BytesIO(image_bytes)).convert("RGB")
    orig = (pil.width, pil.height)
    pil, resized = _resize_if_needed(pil, scale)
    # Reject requests whose upscaled output would exceed 64 megapixels.
    out_mp = pil.width * factor * pil.height * factor / 1e6
    if out_mp > 64:
        raise HTTPException(400, f"Output too large ({out_mp:.0f}MP)")
    x = _pil_to_tensor(pil).unsqueeze(0).to(_DEVICE)
    with torch.inference_mode():
        y = model.upscale(x)
    result_np = (y.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
    result = Image.fromarray(result_np)
    buf = BytesIO()
    result.save(buf, format="PNG")
    buf.seek(0)
    return buf.getvalue(), {
        "scale": scale,
        "input": f"{orig[0]}x{orig[1]}",
        "output": f"{result.width}x{result.height}",
        "resized": resized,
    }

# ---------------------------------------------------------------------------
# InvSR (diffusion-based 4x), loaded in a background thread at startup
# ---------------------------------------------------------------------------
_INVSR_PATH = Path("/app/InvSR")
_sampler_invsr = None
_invsr_status = "not_loaded"  # not_loaded -> patching -> downloading_* -> loading -> ready | error
_invsr_error = None
_invsr_jobs: dict[str, dict] = {}
_job_counter = 0


def _patch_invsr_source():
    """Rewrite the vendored InvSR sampler so it runs on CPU as well as CUDA.

    The string replacements below must match the upstream sampler_invsr.py
    formatting exactly; if upstream changes, they silently become no-ops.
    """
    p = _INVSR_PATH / "sampler_invsr.py"
    code = p.read_text()
    # Defer the dataset import; it is only needed for directory inputs.
    code = code.replace("from datapipe.datasets import create_dataset", "")
    # Add a `device` argument to BaseSampler and pick CUDA/CPU automatically.
    code = code.replace(
        "class BaseSampler:\n    def __init__(self, configs):\n        '''\n        Input:\n            configs: config",
        "class BaseSampler:\n    def __init__(self, configs, device='auto'):\n        '''\n        Input:\n            configs: config"
    )
    code = code.replace(
        "self.configs = configs\n\n        self.setup_seed()\n\n        self.build_model()",
        "self.configs = configs\n"
        "        if device == 'auto':\n"
        "            device = 'cuda' if torch.cuda.is_available() else 'cpu'\n"
        "        self.device = torch.device(device)\n"
        "        self.dtype = torch.float16 if self.device.type == 'cuda' else torch.float32\n"
        "        self.setup_seed()\n"
        "        self.build_model()"
    )
    # Guard CUDA-only calls and route tensors/checkpoints to the chosen device.
    code = code.replace(
        "torch.cuda.manual_seed_all(seed)",
        "if torch.cuda.is_available():\n            torch.cuda.manual_seed_all(seed)"
    )
    code = code.replace('sd_pipe.to(f"cuda")', "sd_pipe.to(self.device)")
    code = code.replace("model_start.cuda()", "model_start.to(self.device)")
    code = code.replace('map_location=f"cuda"', "map_location=self.device")
    code = code.replace("im_cond.type(torch.float16)", "im_cond.type(self.dtype)")
    code = code.replace(".type(torch.float16)", ".type(self.dtype)")
    code = code.replace("data['lq'].cuda()", "data['lq'].to(self.device)")
    code = code.replace("util_image.img2tensor(im_cond).cuda()", "util_image.img2tensor(im_cond).to(self.device)")
    code = code.replace(
        "if in_path.is_dir():\n            data_config",
        "if in_path.is_dir():\n            from datapipe.datasets import create_dataset\n            data_config"
    )
    p.write_text(code)
    logger.info("InvSR source patched for device-agnostic (CPU/GPU) execution")

def _load_invsr_sync():
    """Download weights and build the InvSR sampler; runs in a background thread at startup."""
    global _sampler_invsr, _invsr_status, _invsr_error
    try:
        _invsr_status = "patching"
        _patch_invsr_source()
        sys.path.insert(0, str(_INVSR_PATH))
        sys.path.insert(0, str(_INVSR_PATH / "src"))

        from omegaconf import OmegaConf
        from huggingface_hub import snapshot_download, hf_hub_download
        from sampler_invsr import InvSamplerSR

        invsr_cache = str(CACHE_DIR / "invsr")
        CACHE_DIR.mkdir(exist_ok=True)

        _invsr_status = "downloading_sd_turbo"
        logger.info("Downloading SD-Turbo (~5GB, one-time, 10-20 min)...")
        snapshot_download("stabilityai/sd-turbo", cache_dir=invsr_cache, resume_download=True)
        logger.info("SD-Turbo downloaded")

        _invsr_status = "downloading_noise_pred"
        logger.info("Downloading noise predictor...")
        hf_hub_download("OAOA/InvSR", "noise_predictor_sd_turbo_v5.pth", cache_dir=invsr_cache)
        ckpt = next((str(f) for f in Path(invsr_cache).rglob("noise_predictor_sd_turbo_v5.pth")), None)
        if not ckpt:
            raise FileNotFoundError("Noise predictor not found")

        _invsr_status = "loading"
        cfg = OmegaConf.load(str(_INVSR_PATH / "configs" / "sample-sd-turbo.yaml"))
        # Match the pipe dtype to the dtype the patched sampler casts its inputs to.
        cfg.sd_pipe.params.torch_dtype = "torch.float16" if _DEVICE == "cuda" else "torch.float32"
        cfg.sd_pipe.params.cache_dir = invsr_cache
        cfg.sd_pipe.params.local_files_only = True
        cfg.model_start.ckpt_path = ckpt
        cfg.timesteps = [200]
        cfg.bs = 1
        cfg.tiled_vae = True
        cfg.color_fix = "wavelet"
        cfg.basesr.chopping.pch_size = 128
        cfg.basesr.chopping.extra_bs = 8

        logger.info("Loading InvSR into memory...")
        _sampler_invsr = InvSamplerSR(cfg, device="auto")
        if _DEVICE == "cpu":
            _sampler_invsr.sd_pipe = _sampler_invsr.sd_pipe.to(dtype=torch.float32)
        _invsr_status = "ready"
        logger.info("InvSR ready on %s", _DEVICE)
    except Exception as e:
        _invsr_status = "error"
        _invsr_error = str(e)
        logger.exception("InvSR load failed: %s", e)

def upscale_invsr(image_bytes: bytes, num_steps: int = 1) -> bytes:
    """Run InvSR diffusion upscaling and return PNG bytes."""
    if _invsr_status == "error":
        raise HTTPException(500, f"InvSR failed to load: {_invsr_error}")
    if _sampler_invsr is None:
        raise HTTPException(503, f"InvSR is still {_invsr_status}. Check /health for progress.")
    sampler = _sampler_invsr
    sys.path.insert(0, str(_INVSR_PATH))
    from utils import util_image

    # Enforce the configured InvSR input cap (MAX_DIM["invsr"]) before sampling.
    pil = Image.open(BytesIO(image_bytes)).convert("RGB")
    pil, _ = _resize_if_needed(pil, "invsr")
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        pil.save(tmp, format="PNG")
        tmp.close()
        im = util_image.imread(tmp.name, chn="rgb", dtype="float32")
    finally:
        os.unlink(tmp.name)
    im_cond = util_image.img2tensor(im).to(sampler.device)
    # More steps means more denoising passes at descending noise levels (slower, sometimes sharper).
    steps = {1: [200], 2: [200, 100], 3: [200, 100, 50], 4: [200, 150, 100, 50], 5: [250, 200, 150, 100, 50]}
    sampler.configs.timesteps = steps.get(num_steps, [200])
    sampler.configs.basesr.chopping.pch_size = 128
    result = sampler.sample_func(im_cond).squeeze(0)
    result = (result * 255).clip(0, 255).astype(np.uint8)
    img = Image.fromarray(result)
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return buf.getvalue()

# ---------------------------------------------------------------------------
# "Finegrain" path: a minimal RRDBNet in the original ESRGAN ("old-arch")
# sequential layout; the 4x-UltraSharp checkpoint is loaded into it by
# state-dict key name (see _load_fg_esrgan).
# ---------------------------------------------------------------------------
FG_ESRGAN_PATH = CACHE_DIR / "esrgan"
_fg_esrgan_model = None
_fg_esrgan_loading = False


def _conv_block(in_nc, out_nc):
    return torch.nn.Sequential(
        torch.nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
        torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
    )

class _ResidualDenseBlock5C(torch.nn.Module):
    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.conv1 = _conv_block(nf, gc)
        self.conv2 = _conv_block(nf + gc, gc)
        self.conv3 = _conv_block(nf + 2 * gc, gc)
        self.conv4 = _conv_block(nf + 3 * gc, gc)
        self.conv5 = torch.nn.Sequential(torch.nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class _RRDB(torch.nn.Module):
    def __init__(self, nf):
        super().__init__()
        self.RDB1 = _ResidualDenseBlock5C(nf)
        self.RDB2 = _ResidualDenseBlock5C(nf)
        self.RDB3 = _ResidualDenseBlock5C(nf)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class _SkipBlock(torch.nn.Module):
    def __init__(self, sub):
        super().__init__()
        self.sub = sub

    def forward(self, x):
        return x + self.sub(x)

class _RRDBNet(torch.nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
            _SkipBlock(torch.nn.Sequential(
                *(_RRDB(nf) for _ in range(nb)),
                torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            )),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            torch.nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.model(x)

def _load_fg_esrgan():
    """Download and cache the 4x-UltraSharp ESRGAN checkpoint; returns the model or None."""
    global _fg_esrgan_model, _fg_esrgan_loading
    if _fg_esrgan_model is not None:
        return _fg_esrgan_model
    if _fg_esrgan_loading:
        return None
    _fg_esrgan_loading = True
    try:
        from huggingface_hub import hf_hub_download
        logger.info("Downloading ESRGAN 4x-UltraSharp model...")
        ckpt = hf_hub_download("philz1337x/upscaler", "4x-UltraSharp.pth", cache_dir=str(FG_ESRGAN_PATH))
        logger.info("Loading ESRGAN...")
        state = torch.load(ckpt, map_location="cpu", weights_only=True)
        model = _RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
        model.load_state_dict(state, strict=False)
        model.eval()
        _fg_esrgan_model = model
        logger.info("ESRGAN 4X ready (CPU)")
    except Exception as e:
        logger.error("Failed to load ESRGAN: %s", e)
        _fg_esrgan_model = None
    _fg_esrgan_loading = False
    return _fg_esrgan_model

def upscale_finegrain(image_bytes: bytes, use_sd_refinement: bool = False) -> tuple[bytes, dict]:
    """4x upscale with the ESRGAN UltraSharp model, tiling large inputs to bound memory."""
    model = _load_fg_esrgan()
    if model is None:
        raise HTTPException(503, "ESRGAN model not loaded. Check /health.")

    img = Image.open(BytesIO(image_bytes)).convert("RGB")
    in_w, in_h = img.size

    # Tiled inference parameters: 512px tiles with 64px overlap, averaged where they meet.
    tile_size = 512
    overlap = 64
    w, h = img.size

    if w <= tile_size and h <= tile_size:
        # Small image: single forward pass. RGB -> BGR (original ESRGAN convention), NCHW, [0, 1].
        img_np = np.array(img)[:, :, ::-1]
        img_np = np.transpose(img_np, (2, 0, 1))[np.newaxis, :].astype(np.float32) / 255.0
        with torch.no_grad():
            result = model(torch.from_numpy(img_np))
        result = result.squeeze().clamp(0, 1).numpy()
        result = np.transpose(result, (1, 2, 0))[:, :, ::-1]
        out = Image.fromarray((result * 255).astype(np.uint8))
    else:
        # Large image: process overlapping tiles and blend by simple averaging in the overlaps.
        stride = tile_size - overlap
        cols = -(-max(0, w - overlap) // stride) if w > tile_size else 1  # ceiling division
        rows = -(-max(0, h - overlap) // stride) if h > tile_size else 1
        out_arr = np.zeros((h * 4, w * 4, 3), dtype=np.float32)
        weight = np.zeros((h * 4, w * 4, 1), dtype=np.float32)
        for row in range(rows):
            y1 = min(row * stride, h - tile_size) if h > tile_size else 0
            y2 = min(y1 + tile_size, h)
            for col in range(cols):
                x1 = min(col * stride, w - tile_size) if w > tile_size else 0
                x2 = min(x1 + tile_size, w)
                tile = img.crop((x1, y1, x2, y2))
                tile_np = np.array(tile)[:, :, ::-1]
                tile_np = np.transpose(tile_np, (2, 0, 1))[np.newaxis, :].astype(np.float32) / 255.0
                with torch.no_grad():
                    res_tile = model(torch.from_numpy(tile_np))
                res_tile = res_tile.squeeze().clamp(0, 1).numpy()
                res_tile = np.transpose(res_tile, (1, 2, 0))
                ys, ye = y1 * 4, y2 * 4
                xs, xe = x1 * 4, x2 * 4
                out_arr[ys:ye, xs:xe] += res_tile
                weight[ys:ye, xs:xe] += 1.0
        out_arr = out_arr / np.maximum(weight, 1e-8)
        out = Image.fromarray((out_arr[:, :, ::-1] * 255).astype(np.uint8))

    if use_sd_refinement:
        # SD refinement is not implemented; the ESRGAN result is returned unchanged.
        logger.warning("sd_refinement requested but not implemented; returning ESRGAN output")

    buf = BytesIO()
    out.save(buf, format="PNG")
    buf.seek(0)
    info = {"model": "esrgan_4x", "input": f"{in_w}x{in_h}", "output": f"{out.width}x{out.height}"}
    return buf.getvalue(), info

def compute_metrics(img: Image.Image) -> dict:
    """No-reference quality metrics: Laplacian-variance sharpness, grayscale entropy, edge density, contrast."""
    arr = np.array(img.convert("L"), dtype=np.float64)
    lap = ndimage.laplace(arr)
    hist = np.histogram(arr, bins=256, range=(0, 256))[0]
    hist = hist[hist > 0] / hist.sum()
    mag = np.hypot(ndimage.sobel(arr, axis=0), ndimage.sobel(arr, axis=1))
    return {
        "size": f"{img.width}x{img.height}",
        "sharpness": round(float(lap.var()), 4),
        "entropy": round(float(-np.sum(hist * np.log2(hist))), 4),
        "edge_density": round(float(np.mean(mag > mag.mean() + mag.std())), 4),
        "contrast_std": round(float(np.array(img).std()), 2),
    }

def generate_comparison(image_bytes: bytes) -> tuple[bytes, dict]:
    """Upscale with both MewZoom models and return a labeled side-by-side PNG plus per-image metrics."""
    original = Image.open(BytesIO(image_bytes)).convert("RGB")
    metrics = {"original": compute_metrics(original)}
    upscaled = {}
    for scale in MEWZOOM_MODELS:
        t0 = time.perf_counter()
        rb, info = upscale_mewzoom(image_bytes, scale)
        t = time.perf_counter() - t0
        img = Image.open(BytesIO(rb)).convert("RGB")
        upscaled[scale] = img
        metrics[scale] = {**compute_metrics(img), "time_s": round(t, 3), **info}
    # Resize the original to the 2x output size so the panels are visually comparable.
    orig_r = original.resize(upscaled["2x"].size, Image.LANCZOS)
    images = [orig_r, upscaled["2x"], upscaled["4x"]]
    labels = ["Original", "MewZoom 2X", "MewZoom 4X"]
    lh, gap = 30, 8
    mh = max(i.height for i in images)
    tw = sum(i.width for i in images) + gap * (len(images) - 1)
    canvas = Image.new("RGB", (tw, mh + lh), (30, 30, 30))
    draw = ImageDraw.Draw(canvas)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
    except Exception:
        font = ImageFont.load_default()
    x = 0
    for img, lbl in zip(images, labels):
        canvas.paste(img, (x, lh))
        bb = draw.textbbox((0, 0), lbl, font=font)
        draw.text((x + (img.width - (bb[2] - bb[0])) // 2, 6), lbl, fill=(255, 255, 255), font=font)
        x += img.width + gap
    buf = BytesIO()
    canvas.save(buf, format="PNG")
    buf.seek(0)
    return buf.getvalue(), metrics

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Preload the MewZoom models, then warm InvSR and ESRGAN in background threads."""
    logger.info("Loading MewZoom models...")
    for s in MEWZOOM_MODELS:
        _load_mewzoom(s)
    threading.Thread(target=_load_invsr_sync, daemon=True).start()
    threading.Thread(target=_load_fg_esrgan, daemon=True).start()
    yield


app = FastAPI(title="Super-Resolution API", version="2.0.0", lifespan=lifespan,
              description="MewZoom 2X/4X + InvSR diffusion 4X + comparison + quality metrics")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

@app.get("/")
@app.get("/health")
async def health():
    return JSONResponse({
        "status": "healthy", "device": _DEVICE,
        "models": ["2x", "4x", "invsr", "finegrain"],
        "gpu": torch.cuda.is_available(),
        "invsr_status": _invsr_status, "invsr_error": _invsr_error,
        "finegrain_loaded": _fg_esrgan_model is not None,
    })

@app.post("/upscale/2x")
async def route_2x(file: UploadFile = File(...)):
    r, i = upscale_mewzoom(await file.read(), "2x")
    return StreamingResponse(BytesIO(r), media_type="image/png", headers={"X-Info": json.dumps(i)})


@app.post("/upscale/4x")
async def route_4x(file: UploadFile = File(...)):
    r, i = upscale_mewzoom(await file.read(), "4x")
    return StreamingResponse(BytesIO(r), media_type="image/png", headers={"X-Info": json.dumps(i)})
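
# Example client usage (a sketch, not executed by the service): assumes the API is
# reachable at http://localhost:8000 and that the `requests` package is installed;
# "photo.png" is a placeholder input path.
#
#   import json, requests
#   with open("photo.png", "rb") as f:
#       resp = requests.post("http://localhost:8000/upscale/2x", files={"file": f})
#   resp.raise_for_status()
#   info = json.loads(resp.headers["X-Info"])   # e.g. {"scale": "2x", "input": "640x480", ...}
#   with open("photo_2x.png", "wb") as f:
#       f.write(resp.content)                   # upscaled PNG body
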
@app.post("/upscale/compare")
async def route_compare(file: UploadFile = File(...), format: Literal["image", "json", "both"] = Query("both")):
    img, m = generate_comparison(await file.read())
    if format == "json":
        return JSONResponse(m)
    if format == "image":
        return StreamingResponse(BytesIO(img), media_type="image/png")
    return StreamingResponse(BytesIO(img), media_type="image/png", headers={"X-Metrics": json.dumps(m)})


@app.post("/upscale/metrics")
async def route_metrics(file: UploadFile = File(...)):
    _, m = generate_comparison(await file.read())
    return JSONResponse(m)

@app.post("/upscale/finegrain")
async def route_finegrain(
    file: UploadFile = File(...),
    sd_refinement: bool = Query(False, description="Use SD1.5 refinement (GPU only)"),
):
    try:
        r, i = upscale_finegrain(await file.read(), use_sd_refinement=sd_refinement)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(500, detail=f"Finegrain failed: {e}")
    return StreamingResponse(BytesIO(r), media_type="image/png", headers={"X-Info": json.dumps(i)})

@app.post("/upscale/invsr")
async def route_invsr(file: UploadFile = File(...), num_steps: int = Query(1, ge=1, le=5)):
    if _invsr_status == "error":
        raise HTTPException(500, f"InvSR not loaded: {_invsr_error}")
    if _sampler_invsr is None:
        raise HTTPException(503, f"InvSR is {_invsr_status}. Check /health for status.")
    # InvSR is slow (especially on CPU), so run it as a background job and return a job id to poll.
    global _job_counter
    _job_counter += 1
    job_id = str(_job_counter)
    _invsr_jobs[job_id] = {"status": "queued", "image_bytes": await file.read(), "num_steps": num_steps}
    threading.Thread(target=_run_invsr_job, args=(job_id,), daemon=True).start()
    return JSONResponse({"job_id": job_id, "status": "queued", "check": f"/upscale/invsr/{job_id}"})
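
# Example submit-and-poll flow (a sketch; assumes the API at http://localhost:8000
# and the `requests` and `time` modules available, as in the example above):
#
#   job = requests.post("http://localhost:8000/upscale/invsr",
#                       files={"file": open("photo.png", "rb")}).json()
#   while True:
#       r = requests.get(f"http://localhost:8000{job['check']}")
#       if r.headers["content-type"] == "image/png":
#           open("photo_invsr.png", "wb").write(r.content)
#           break
#       time.sleep(5)   # still queued/processing (JSON status body)
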
def _run_invsr_job(job_id: str):
    """Background worker for one InvSR job; stores the PNG result or the error on the job dict."""
    job = _invsr_jobs.get(job_id)
    if not job:
        return
    try:
        job["status"] = "processing"
        job["result"] = upscale_invsr(job["image_bytes"], job["num_steps"])
        job["status"] = "done"
    except Exception as e:
        job["status"] = "error"
        job["error"] = str(e)
    finally:
        # Drop the input bytes; only the result (or error) needs to stay around for polling.
        job.pop("image_bytes", None)

@app.get("/upscale/invsr/{job_id}")
async def route_invsr_status(job_id: str):
    job = _invsr_jobs.get(job_id)
    if not job:
        raise HTTPException(404, "Job not found")
    if job["status"] == "done":
        return StreamingResponse(BytesIO(job["result"]), media_type="image/png")
    return JSONResponse({"job_id": job_id, "status": job["status"], "error": job.get("error")})
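
# Local entry point (a sketch): the service is normally launched by an external ASGI
# server; the host and port below are assumptions for running the module directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)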
|