Spaces:
Running
Running
| """ | |
| export_models.py | |
| ---------------- | |
| Downloads publicly available pretrained weights for SRCNN and EDSR (HResNet-style) | |
| and exports them as ONNX files into the ./model/ directory. | |
| Run once before starting app.py: | |
| pip install torch torchvision huggingface_hub basicsr | |
| python export_models.py | |
| After this script finishes you should have: | |
| model/SRCNN_x4.onnx | |
| model/HResNet_x4.onnx | |
| Then upload both files to Google Drive, copy the file IDs into DRIVE_IDS in app.py, | |
| OR set LOCAL_ONLY = True below to skip Drive entirely and load straight from disk. | |
| """ | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.onnx | |
| from pathlib import Path | |
| MODEL_DIR = Path("model") | |
| MODEL_DIR.mkdir(exist_ok=True) | |
| # --------------------------------------------------------------------------- | |
| # Set to True to skip Drive and have app.py load the ONNX files from disk | |
| # directly. In app.py, remove the download_from_drive call for these keys | |
| # (or just leave the placeholder Drive ID β the script already guards against | |
| # missing files gracefully). | |
| # --------------------------------------------------------------------------- | |
| LOCAL_ONLY = True # flip to False once you have Drive IDs | |
# ===========================================================================
# 1. SRCNN ×4
#    Architecture: Dong et al. 2014 — 3 conv layers, no upsampling inside
#    the network. Input is bicubic-upscaled LR; output is the refined HR.
#    We bicubic-upsample inside a wrapper so the ONNX takes a raw LR image.
# ===========================================================================
class SRCNN(nn.Module):
    """Original SRCNN (Dong et al., 2014): patch extraction, non-linear
    mapping, and reconstruction as three same-padding convolutions."""

    def __init__(self, num_channels: int = 3):
        super().__init__()
        # Same-padding keeps spatial size constant through every layer.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        # No activation on the final reconstruction layer.
        return self.conv3(out)
class SRCNNx4Wrapper(nn.Module):
    """
    Wraps SRCNN so the ONNX input is a LOW-resolution image.

    Internally bicubic-upsamples by ×4 before feeding SRCNN,
    matching the interface expected by app.py's tile_upscale_model.
    """

    def __init__(self, srcnn: SRCNN, scale: int = 4):
        super().__init__()
        self.srcnn = srcnn
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (1, 3, H, W) low-res, float32 in [0, 1].
        # Bicubic pre-upsample, then let SRCNN refine the result.
        upsampled = torch.nn.functional.interpolate(
            x,
            scale_factor=self.scale,
            mode="bicubic",
            align_corners=False,
        )
        return self.srcnn(upsampled)
def build_srcnn_x4() -> nn.Module:
    """
    Build an ×4 SRCNN wrapped to accept a raw low-resolution input.

    Loads pretrained SRCNN weights into the wrapped module, downloading
    them on first run. Falls back to random init with a warning if the
    download or the state-dict load fails.

    Returns:
        SRCNNx4Wrapper around a (possibly pretrained) SRCNN.
    """
    srcnn = SRCNN(num_channels=3)
    wrapper = SRCNNx4Wrapper(srcnn, scale=4)
    # Pretrained weights from the basicsr / mmedit community
    # (original Caffe weights re-converted to PyTorch by https://github.com/yjn870/SRCNN-pytorch)
    # NOTE(review): the yjn870 checkpoints are commonly single-channel (Y only) —
    # verify this URL serves 3-channel weights, otherwise the strict load below
    # will always fall through to the partial-load path.
    SRCNN_WEIGHTS_URL = (
        "https://github.com/yjn870/SRCNN-pytorch/raw/master/models/"
        "srcnn_x4.pth"
    )
    weights_path = MODEL_DIR / "srcnn_x4.pth"
    if not weights_path.exists():
        print(" Downloading SRCNN Γ4 weights β¦")
        try:
            import urllib.request
            urllib.request.urlretrieve(SRCNN_WEIGHTS_URL, weights_path)
            print(f" Saved β {weights_path}")
        except Exception as e:
            # BUGFIX: remove a half-written file so the next run retries the
            # download instead of torch.load-ing a corrupt checkpoint.
            weights_path.unlink(missing_ok=True)
            print(f" [WARN] Could not download SRCNN weights: {e}")
            print(" Continuing with random init (quality will be poor).")
            return wrapper
    state = torch.load(weights_path, map_location="cpu")
    # Unwrap a nested state_dict if present (mirrors build_edsr_x4).
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    # The yjn870 checkpoint uses keys conv1/conv2/conv3 matching our module
    try:
        srcnn.load_state_dict(state, strict=True)
        print(" SRCNN weights loaded β")
    except RuntimeError as e:
        print(f" [WARN] Weight mismatch: {e}\n Proceeding with partial load.")
        srcnn.load_state_dict(state, strict=False)
    return wrapper
# ===========================================================================
# 2. EDSR (HResNet-style) ×4
#    EDSR-baseline (Lim et al., 2017) is the canonical "deep residual" SR
#    network. Pretrained weights from eugenesiow/torch-sr (HuggingFace).
# ===========================================================================
class ResBlock(nn.Module):
    """EDSR residual block: conv → ReLU → conv with a scaled identity skip."""

    def __init__(self, n_feats: int, res_scale: float = 1.0):
        super().__init__()
        # Kept as a 3-element Sequential so state-dict keys stay body.0 / body.2.
        conv_in = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        conv_out = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        self.body = nn.Sequential(conv_in, nn.ReLU(inplace=True), conv_out)
        self.res_scale = res_scale

    def forward(self, x):
        # Scale the residual branch before rejoining the identity path.
        residual = self.body(x)
        return x + residual * self.res_scale
class Upsampler(nn.Sequential):
    """
    PixelShuffle upsampling head.

    ×2 / ×4 are built from one / two (conv → PixelShuffle(2)) stages;
    ×3 uses a single (conv → PixelShuffle(3)) stage.

    Raises:
        ValueError: if *scale* is not 2, 3 or 4. (The previous version
        silently built an EMPTY Sequential — i.e. an identity with no
        upsampling at all — for any other scale.)
    """

    def __init__(self, scale: int, n_feats: int):
        layers = []
        if scale in (2, 4):
            steps = {2: 1, 4: 2}[scale]
            for _ in range(steps):
                layers += [
                    nn.Conv2d(n_feats, 4 * n_feats, 3, padding=1),
                    nn.PixelShuffle(2),
                ]
        elif scale == 3:
            layers += [
                nn.Conv2d(n_feats, 9 * n_feats, 3, padding=1),
                nn.PixelShuffle(3),
            ]
        else:
            # BUGFIX: fail loudly instead of degrading to an identity module.
            raise ValueError(f"Unsupported upscale factor: {scale} (expected 2, 3 or 4)")
        super().__init__(*layers)
class EDSR(nn.Module):
    """
    EDSR-baseline: 16 residual blocks, 64 feature channels.
    Matches the publicly released weights from eugenesiow/torch-sr.
    """

    def __init__(self, n_resblocks: int = 16, n_feats: int = 64,
                 scale: int = 4, num_channels: int = 3):
        super().__init__()
        # Shallow feature extractor.
        self.head = nn.Conv2d(num_channels, n_feats, 3, padding=1)
        # Deep residual trunk, closed by one conv before the global skip.
        trunk = [ResBlock(n_feats) for _ in range(n_resblocks)]
        self.body = nn.Sequential(*trunk)
        self.body_tail = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        # Upsampling + RGB reconstruction.
        self.tail = nn.Sequential(
            Upsampler(scale, n_feats),
            nn.Conv2d(n_feats, num_channels, 3, padding=1),
        )

    def forward(self, x):
        shallow = self.head(x)
        deep = self.body_tail(self.body(shallow))
        # Global residual connection around the whole trunk.
        return self.tail(shallow + deep)
def build_edsr_x4() -> nn.Module:
    """
    Build an ×4 EDSR-baseline and load pretrained weights.

    Downloads EDSR-baseline ×4 weights on first run and loads them,
    falling back to random init (with a warning) if the download fails,
    and to a partial load if the checkpoint keys don't match exactly.
    Source: eugenesiow/torch-sr (Apache-2.0 licensed).

    Returns:
        EDSR module, pretrained when possible.
    """
    model = EDSR(n_resblocks=16, n_feats=64, scale=4)
    # Direct link to the EDSR-baseline Γ4 checkpoint
    EDSR_WEIGHTS_URL = (
        "https://huggingface.co/eugenesiow/edsr-base/resolve/main/"
        "pytorch_model_4x.pt"
    )
    weights_path = MODEL_DIR / "edsr_x4.pt"
    if not weights_path.exists():
        print(" Downloading EDSR Γ4 weights from HuggingFace β¦")
        try:
            import urllib.request
            urllib.request.urlretrieve(EDSR_WEIGHTS_URL, weights_path)
            print(f" Saved β {weights_path}")
        except Exception as e:
            # BUGFIX: remove a half-written file so the next run retries the
            # download instead of torch.load-ing a corrupt checkpoint.
            weights_path.unlink(missing_ok=True)
            print(f" [WARN] Could not download EDSR weights: {e}")
            print(" Continuing with random init (quality will be poor).")
            return model
    state = torch.load(weights_path, map_location="cpu")
    # Some checkpoints are a fully pickled nn.Module rather than a bare
    # state_dict; extract the state_dict so the unwrapping below works.
    if isinstance(state, nn.Module):
        state = state.state_dict()
    # eugenesiow checkpoints may wrap state_dict under a 'model' key.
    # BUGFIX: guard with isinstance so a non-dict load fails with a clear
    # error below rather than on the `in` test.
    if isinstance(state, dict) and "model" in state:
        state = state["model"]
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    # Strip any 'module.' prefix from DataParallel wrapping
    state = {k.replace("module.", ""): v for k, v in state.items()}
    try:
        model.load_state_dict(state, strict=True)
        print(" EDSR weights loaded β")
    except RuntimeError as e:
        print(f" [WARN] Weight mismatch ({e}). Trying strict=False β¦")
        model.load_state_dict(state, strict=False)
        print(" EDSR weights loaded (partial) β")
    return model
# ===========================================================================
# ONNX export helper
# ===========================================================================
| def export_onnx(model: nn.Module, out_path: Path, tile_h: int = 128, tile_w: int = 128): | |
| """Export *model* to ONNX with dynamic H/W axes.""" | |
| model.eval() | |
| dummy = torch.zeros(1, 3, tile_h, tile_w) | |
| torch.onnx.export( | |
| model, | |
| dummy, | |
| str(out_path), | |
| opset_version=17, | |
| input_names=["input"], | |
| output_names=["output"], | |
| dynamic_axes={ | |
| "input": {0: "batch", 2: "H", 3: "W"}, | |
| "output": {0: "batch", 2: "H_out", 3: "W_out"}, | |
| }, | |
| ) | |
| size_mb = out_path.stat().st_size / 1_048_576 | |
| print(f" Exported β {out_path} ({size_mb:.1f} MB)") | |
# ===========================================================================
# Main
# ===========================================================================
| if __name__ == "__main__": | |
| print("=" * 60) | |
| print("SpectraGAN β ONNX model exporter") | |
| print("=" * 60) | |
| # -- SRCNN Γ4 ------------------------------------------------------------ | |
| srcnn_out = MODEL_DIR / "SRCNN_x4.onnx" | |
| if srcnn_out.exists(): | |
| print(f"\n[SKIP] {srcnn_out} already exists.") | |
| else: | |
| print("\n[1/2] Building SRCNN Γ4 β¦") | |
| srcnn_model = build_srcnn_x4() | |
| print(" Exporting to ONNX β¦") | |
| export_onnx(srcnn_model, srcnn_out, tile_h=128, tile_w=128) | |
| # -- EDSR (HResNet) Γ4 --------------------------------------------------- | |
| edsr_out = MODEL_DIR / "HResNet_x4.onnx" | |
| if edsr_out.exists(): | |
| print(f"\n[SKIP] {edsr_out} already exists.") | |
| else: | |
| print("\n[2/2] Building EDSR (HResNet) Γ4 β¦") | |
| edsr_model = build_edsr_x4() | |
| print(" Exporting to ONNX β¦") | |
| export_onnx(edsr_model, edsr_out, tile_h=128, tile_w=128) | |
| print("\n" + "=" * 60) | |
| print("Done! Files created:") | |
| for p in [srcnn_out, edsr_out]: | |
| status = "β" if p.exists() else "β MISSING" | |
| print(f" {status} {p}") | |
| print() | |
| if LOCAL_ONLY: | |
| print("LOCAL_ONLY = True:") | |
| print(" app.py will load these files directly from disk.") | |
| print(" No Google Drive upload needed.") | |
| else: | |
| print("Next step:") | |
| print(" Upload the .onnx files to Google Drive and paste") | |
| print(" the file IDs into DRIVE_IDS in app.py.") | |
| print("=" * 60) | |