# app.py — Gradio front-end that calls test.py IN-PROCESS (ZeroGPU-safe)
# Folder layout per run (under TEMP_ROOT):
#   input_video/<video_name>/00000.png ...
#   ref/<video_name>/ref.png
#   output/<video_name>/*.png
# Final mp4: TEMP_ROOT/<video_name>.mp4

import os
import sys
import shutil
import urllib.request
from os import path
import io
from contextlib import redirect_stdout, redirect_stderr
import subprocess
import tempfile
import importlib

import gradio as gr
import spaces
from PIL import Image
import cv2
import torch  # used for cuda sync & empty_cache

# ----------------- BASIC INFO -----------------
CHECKPOINT_URL = "https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth"
CHECKPOINT_LOCAL = "DINOv2FeatureV6_LocalAtten_s2_154000.pth"

TITLE = "ColorMNet — 视频着色 / Video Colorization (ZeroGPU, CUDA-only)"

DESC = """
**中文**
上传**黑白视频**与**参考图像**,点击「开始着色 / Start Coloring」。
此版本在 **app.py 中调度 ZeroGPU**,并**在同一进程**调用 `test.py` 的入口函数。
临时工作目录结构:
- 抽帧:`_colormnet_tmp/input_video/<视频名>/00000.png ...`
- 参考:`_colormnet_tmp/ref/<视频名>/ref.png`
- 输出:`_colormnet_tmp/output/<视频名>/*.png`
- 合成视频:`_colormnet_tmp/<视频名>.mp4`

**English**
Upload a **B&W video** and a **reference image**, then click "Start Coloring".
This app runs **ZeroGPU scheduling in `app.py`** and calls `test.py` **in-process**.
Temp workspace layout:
- Frames: `_colormnet_tmp/input_video/<video_name>/00000.png ...`
- Reference: `_colormnet_tmp/ref/<video_name>/ref.png`
- Output frames: `_colormnet_tmp/output/<video_name>/*.png`
- Final video: `_colormnet_tmp/<video_name>.mp4`
"""

PAPER = """
### 论文 / Paper
**ECCV 2024 — ColorMNet: A Memory-based Deep Spatial-Temporal Feature Propagation Network for Video Colorization**

如果你喜欢这个项目,欢迎到 GitHub 点个 ⭐ Star:
**GitHub**: https://github.com/yyang181/colormnet

**BibTeX 引用 / BibTeX Citation**
```bibtex
@inproceedings{yang2024colormnet,
  author    = {Yixin Yang and Jiangxin Dong and Jinhui Tang and Jinshan Pan},
  title     = {ColorMNet: A Memory-based Deep Spatial-Temporal Feature Propagation Network for Video Colorization},
  booktitle = {ECCV},
  year      = {2024}
}
```
"""

BADGES_HTML = """
"""

# ----------------- REFERENCE FRAME GUIDE (NO CROPPING) -----------------
REF_GUIDE_MD = r"""
## 参考帧制作指南 / Reference Frame Guide

**目的 / Goal**
为模型提供一张与你的视频关键帧在**姿态、光照、构图**上尽量接近的**彩色参考图**,用来指导整段视频的着色风格与主体颜色。

---

### 中文步骤
1. **挑帧**:从视频里挑一帧(或相近角度的照片),尽量与要着色的镜头在**姿态 / 光照 / 场景**一致。
2. **上色方式**:若你只有黑白参考图、但需要彩色参考,可用 **通义千问·图像编辑(Qwen-Image)**:
   - 打开 Qwen-Image → 选择**图像编辑**
   - 上传你的黑白参考图
   - 在提示词里输入:**「帮我给这张照片上色,只修改颜色,不要修改内容」**
   - 可按需多次编辑(如补充「衣服为复古蓝、肤色自然、不要锐化」)
3. **保存格式**:PNG/JPG 均可;推荐分辨率 ≥ **480px**(短边)。
4. **文件放置**:本应用会自动放置为 `ref/<视频名>/ref.png`。

**注意事项(Do/Don’t)**
- ✅ 主体清晰、颜色干净,不要过曝或强滤镜。
- ✅ 关键区域(衣服、皮肤、头发、天空等)颜色与目标风格一致。
- ❌ 不要更改几何结构(如人脸形状/姿态),**只修改颜色**。
- ❌ 避免文字、贴纸、重度风格化滤镜。

---

### English Steps
1. **Pick a frame** (or a similar photo) that matches the target shot in **pose / lighting / composition**.
2. **Colorizing if your reference is B&W** — use **Qwen-Image (Image Editing)**:
   - Open Qwen-Image → **Image Editing**
   - Upload your B&W reference
   - Prompt: **“Help me colorize this photo; only change colors, do not alter the content.”**
   - Iterate if needed (e.g., “vintage blue jacket, natural skin tone; avoid sharpening”).
3. **Format**: PNG/JPG; recommended short side ≥ **480px**.
4. **File placement**: The app will place it as `ref/<video_name>/ref.png`.

**Do / Don’t**
- ✅ Clean subject and palette; avoid overexposure/harsh filters.
- ✅ Ensure key regions (clothes/skin/hair/sky) match the intended colors.
- ❌ Do not change geometry/structure — **colors only**.
- ❌ Avoid text/stickers/heavy stylization filters.
"""

# ----------------- TEMP WORKDIR -----------------
TEMP_ROOT = path.join(os.getcwd(), "_colormnet_tmp")
INPUT_DIR = "input_video"
REF_DIR = "ref"
OUTPUT_DIR = "output"


def reset_temp_root():
    """Wipe and recreate the temp workspace before each run."""
    if path.isdir(TEMP_ROOT):
        shutil.rmtree(TEMP_ROOT, ignore_errors=True)
    os.makedirs(TEMP_ROOT, exist_ok=True)
    for sub in (INPUT_DIR, REF_DIR, OUTPUT_DIR):
        os.makedirs(path.join(TEMP_ROOT, sub), exist_ok=True)


def ensure_dir(d: str):
    os.makedirs(d, exist_ok=True)


# ----------------- CHECKPOINT (optional) -----------------
def ensure_checkpoint():
    """If test.py loads the weights from the CWD, pre-download them here to avoid a first-run timeout."""
    try:
        if not path.exists(CHECKPOINT_LOCAL):
            print(f"[INFO] Downloading checkpoint from: {CHECKPOINT_URL}")
            urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_LOCAL)
            print("[INFO] Checkpoint downloaded:", CHECKPOINT_LOCAL)
    except Exception as e:
        print(f"[WARN] Checkpoint pre-download failed (will retry at first inference): {e}")


# ----------------- VIDEO UTILS -----------------
def video_to_frames_dir(video_path: str, frames_dir: str):
    """
    Extract frames into frames_dir/00000.png ...
    Returns: (w, h, fps, n_frames)
    """
    ensure_dir(frames_dir)
    cap = cv2.VideoCapture(video_path)
    assert cap.isOpened(), f"Cannot open video: {video_path}"
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0  # fall back to 25 fps if the container reports none
    idx = 0
    w = h = None
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame is None:
            continue
        h, w = frame.shape[:2]
        out_path = path.join(frames_dir, f"{idx:05d}.png")
        ok = cv2.imwrite(out_path, frame)
        if not ok:
            raise RuntimeError(f"写入抽帧失败 / Failed to write: {out_path}")
        idx += 1
    cap.release()
    if idx == 0:
        raise RuntimeError("视频无可读帧 / Input video has no readable frames.")
    return w, h, fps, idx


def encode_frames_to_video(frames_dir: str, out_path: str, fps: float):
    frames = sorted([f for f in os.listdir(frames_dir) if f.lower().endswith(".png")])
    if not frames:
        raise RuntimeError(f"No frames found in {frames_dir}")
    first = cv2.imread(path.join(frames_dir, frames[0]))
    if first is None:
        raise RuntimeError(f"Failed to read first frame {frames[0]}")
    h, w = first.shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    vw = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
    for f in frames:
        img = cv2.imread(path.join(frames_dir, f))
        if img is None:
            continue
        vw.write(img)
    vw.release()


# ----------------- CLI MAPPING -----------------
CONFIG_TO_CLI = {
    "FirstFrameIsNotExemplar": "--FirstFrameIsNotExemplar",  # bool
    "dataset": "--dataset",
    "split": "--split",
    "save_all": "--save_all",  # bool
    "benchmark": "--benchmark",  # bool
    "disable_long_term": "--disable_long_term",  # bool
    "max_mid_term_frames": "--max_mid_term_frames",
    "min_mid_term_frames": "--min_mid_term_frames",
    "max_long_term_elements": "--max_long_term_elements",
    "num_prototypes": "--num_prototypes",
    "top_k": "--top_k",
    "mem_every": "--mem_every",
    "deep_update_every": "--deep_update_every",
    "save_scores": "--save_scores",  # bool
    "flip": "--flip",  # bool
    "size": "--size",
    "reverse": "--reverse",  # bool
}


def build_args_list_for_test(d16_batch_path: str, out_path: str, ref_root: str, cfg: dict):
    """
    Build the argument list passed to test.run_cli(args_list).
    Always required: --d16_batch_path, --ref_path, --output
    """
    args = [
        "--d16_batch_path", d16_batch_path,
        "--ref_path", ref_root,
        "--output", out_path,
    ]
    for k, v in cfg.items():
        if k not in CONFIG_TO_CLI:
            continue
        flag = CONFIG_TO_CLI[k]
        if isinstance(v, bool):
            if v:
                args.append(flag)  # store_true flag: presence means True
        elif v is None:
            continue
        else:
            args.extend([flag, str(v)])
    return args
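# Illustrative usage of build_args_list_for_test (the paths below are hypothetical
# examples of this app's temp layout). Booleans map to bare store_true flags
# (False/None entries are dropped); everything else becomes a "--flag value" pair:
#
#   build_args_list_for_test(
#       d16_batch_path="_colormnet_tmp/input_video",
#       out_path="_colormnet_tmp/output",
#       ref_root="_colormnet_tmp/ref",
#       cfg={"save_all": True, "benchmark": False, "top_k": 30},
#   )
#   # -> ["--d16_batch_path", "_colormnet_tmp/input_video",
#   #     "--ref_path", "_colormnet_tmp/ref",
#   #     "--output", "_colormnet_tmp/output",
#   #     "--save_all", "--top_k", "30"]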
# ===== Install Pytorch-Correlation-extension on demand, after ZeroGPU allocation =====
_CORR_OK_STAMP = path.join(os.getcwd(), ".corr_ext_installed")


def ensure_correlation_extension_installed():
    """
    Call after ZeroGPU allocation (i.e., inside the @spaces.GPU function body):
    - If the module already imports, or a local stamp file exists, return immediately.
    - Otherwise run:
        git clone https://github.com/ClementPinard/Pytorch-Correlation-extension.git
        cd Pytorch-Correlation-extension && python setup.py install && cd ..
    """
    # 1) Try importing directly
    try:
        import spatial_correlation_sampler  # noqa: F401
        return
    except Exception:
        pass

    # 2) A previous install already succeeded (stamp file present)
    if path.exists(_CORR_OK_STAMP):
        return

    repo_url = "https://github.com/ClementPinard/Pytorch-Correlation-extension.git"
    workdir = tempfile.mkdtemp(prefix="corr_ext_")
    repo_dir = path.join(workdir, "Pytorch-Correlation-extension")
    try:
        print("[INFO] Installing Pytorch-Correlation-extension ...")
        # clone
        subprocess.run(
            ["git", "clone", "--depth", "1", repo_url],
            cwd=workdir, check=True
        )
        # build & install
        subprocess.run(
            [sys.executable, "setup.py", "install"],
            cwd=repo_dir, check=True
        )
        # verify the import now succeeds
        importlib.invalidate_caches()
        import spatial_correlation_sampler  # noqa: F401
        # write the stamp so later runs skip the rebuild
        with open(_CORR_OK_STAMP, "w") as f:
            f.write("ok")
        print("[INFO] Pytorch-Correlation-extension installed successfully.")
    except subprocess.CalledProcessError as e:
        print(f"[WARN] Failed to build/install correlation extension: {e}")
        print("You can still proceed if your pipeline doesn't use it.")
    except Exception as e:
        print(f"[WARN] Correlation extension install check failed: {e}")
    finally:
        # clean up the temp build dir
        shutil.rmtree(workdir, ignore_errors=True)


# ----------------- GRADIO HANDLER -----------------
@spaces.GPU(duration=600)  # ensure CUDA initialization happens inside this function body
def gradio_infer(
    debug_shapes,  # UI toggle for verbose logs (not otherwise consumed here)
    bw_video, ref_image,
    first_not_exemplar, dataset, split, save_all, benchmark,
    disable_long_term, max_mid, min_mid, max_long, num_proto, top_k,
    mem_every, deep_update, save_scores, flip, size, reverse
):
    # <<< After ZeroGPU allocation: install Pytorch-Correlation-extension on demand >>>
    ensure_correlation_extension_installed()
    # --------------------------------------------------------------

    # 1) Basic validation & temp workspace
    if bw_video is None:
        return None, "请上传黑白视频 / Please upload a B&W video."
    if ref_image is None:
        return None, "请上传参考图像 / Please upload a reference image."
    reset_temp_root()

    # 2) Resolve the source video path & stem
    if isinstance(bw_video, dict) and "name" in bw_video:
        src_video_path = bw_video["name"]
    elif isinstance(bw_video, str):
        src_video_path = bw_video
    else:
        return None, "无法读取视频输入 / Failed to read video input."
    video_stem = path.splitext(path.basename(src_video_path))[0]

    # 3) Build temp paths
    input_root  = path.join(TEMP_ROOT, INPUT_DIR)   # _colormnet_tmp/input_video
    ref_root    = path.join(TEMP_ROOT, REF_DIR)     # _colormnet_tmp/ref
    output_root = path.join(TEMP_ROOT, OUTPUT_DIR)  # _colormnet_tmp/output

    input_frames_dir = path.join(input_root, video_stem)
    ref_dir          = path.join(ref_root, video_stem)
    out_frames_dir   = path.join(output_root, video_stem)

    for d in (input_root, ref_root, output_root, input_frames_dir, ref_dir, out_frames_dir):
        ensure_dir(d)

    # 4) Extract frames -> input_video/<video_name>/
    try:
        _w, _h, fps, _n = video_to_frames_dir(src_video_path, input_frames_dir)
    except Exception as e:
        return None, f"抽帧失败 / Frame extraction failed:\n{e}"

    # 5) Reference frame -> ref/<video_name>/ref.png
    ref_png_path = path.join(ref_dir, "ref.png")
    if isinstance(ref_image, Image.Image):
        try:
            ref_image.save(ref_png_path)
        except Exception as e:
            return None, f"保存参考图像失败 / Failed to save reference image:\n{e}"
    elif isinstance(ref_image, str):
        try:
            shutil.copy2(ref_image, ref_png_path)
        except Exception as e:
            return None, f"复制参考图像失败 / Failed to copy reference image:\n{e}"
    else:
        return None, "无法读取参考图像输入 / Failed to read reference image."

    # 6) Collect the UI config
    default_config = {
        "FirstFrameIsNotExemplar": True,
        "dataset": "D16_batch",
        "split": "val",
        "save_all": True,
        "benchmark": False,
        "disable_long_term": False,
        "max_mid_term_frames": 10,
        "min_mid_term_frames": 5,
        "max_long_term_elements": 10000,
        "num_prototypes": 128,
        "top_k": 30,
        "mem_every": 5,
        "deep_update_every": -1,
        "save_scores": False,
        "flip": False,
        "size": -1,
        "reverse": False,
    }
    user_config = {
        "FirstFrameIsNotExemplar": bool(first_not_exemplar) if first_not_exemplar is not None else default_config["FirstFrameIsNotExemplar"],
        "dataset": str(dataset) if dataset else default_config["dataset"],
        "split": str(split) if split else default_config["split"],
        "save_all": bool(save_all) if save_all is not None else default_config["save_all"],
        "benchmark": bool(benchmark) if benchmark is not None else default_config["benchmark"],
        "disable_long_term": bool(disable_long_term) if disable_long_term is not None else default_config["disable_long_term"],
        "max_mid_term_frames": int(max_mid) if max_mid is not None else default_config["max_mid_term_frames"],
        "min_mid_term_frames": int(min_mid) if min_mid is not None else default_config["min_mid_term_frames"],
        "max_long_term_elements": int(max_long) if max_long is not None else default_config["max_long_term_elements"],
        "num_prototypes": int(num_proto) if num_proto is not None else default_config["num_prototypes"],
        "top_k": int(top_k) if top_k is not None else default_config["top_k"],
        "mem_every": int(mem_every) if mem_every is not None else default_config["mem_every"],
        "deep_update_every": int(deep_update) if deep_update is not None else default_config["deep_update_every"],
        "save_scores": bool(save_scores) if save_scores is not None else default_config["save_scores"],
        "flip": bool(flip) if flip is not None else default_config["flip"],
        "size": int(size) if size is not None else default_config["size"],
        "reverse": bool(reverse) if reverse is not None else default_config["reverse"],
    }

    # 7) Pre-download the checkpoint (optional)
    ensure_checkpoint()

    # 8) Call test.py in-process
    try:
        import test  # test.py must sit in the same directory and expose run_cli(args_list)
    except Exception as e:
        return None, f"导入 test.py 失败 / Failed to import test.py:\n{e}"

    args_list = build_args_list_for_test(
        d16_batch_path=input_root,   # points at the input_video root
        out_path=output_root,        # points at the output root (test.py writes output/<video_name>/*.png)
        ref_root=ref_root,           # points at the ref root (test.py reads ref/<video_name>/ref.png)
        cfg=user_config
    )

    buf = io.StringIO()
    try:
        with redirect_stdout(buf), redirect_stderr(buf):
            entry = getattr(test, "run_cli", None)
            if entry is None or not callable(entry):
                raise RuntimeError("test.py 未提供可调用的 run_cli(args_list) 接口。 / test.py does not expose a callable run_cli(args_list).")
            entry(args_list)
        log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}"
    except Exception as e:
        log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}\n\nERROR: {e}"
        return None, log

    # Before muxing the mp4: flush CUDA so GPU memory is released
    try:
        torch.cuda.synchronize()
    except Exception:
        pass
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass

    # 9) Mux the mp4: output/<video_name>/ frames -> TEMP_ROOT/<video_name>.mp4
    out_frames = path.join(output_root, video_stem)
    if not path.isdir(out_frames):
        return None, f"未找到输出帧目录 / Output frame dir not found: {out_frames}\n\n{log}"

    final_mp4 = path.abspath(path.join(TEMP_ROOT, f"{video_stem}.mp4"))
    try:
        encode_frames_to_video(out_frames, final_mp4, fps=fps)
    except Exception as e:
        return None, f"合成视频失败 / Video mux failed:\n{e}\n\n{log}"

    return final_mp4, f"完成 ✅ / Done ✅\n\n{log}"
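# The in-process contract assumed above: test.py must expose run_cli(args_list) and
# accept the flags emitted by CONFIG_TO_CLI. A minimal argparse shim sketching that
# assumed interface (NOT ColorMNet's actual test.py; main(args) is hypothetical):
#
#   # in test.py
#   import argparse
#
#   def run_cli(args_list):
#       parser = argparse.ArgumentParser()
#       parser.add_argument("--d16_batch_path", required=True)
#       parser.add_argument("--ref_path", required=True)
#       parser.add_argument("--output", required=True)
#       parser.add_argument("--save_all", action="store_true")
#       parser.add_argument("--top_k", type=int, default=30)
#       # ... remaining CONFIG_TO_CLI flags ...
#       args = parser.parse_args(args_list)
#       main(args)  # hypothetical entry that runs inference with the parsed config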
# ----------------- UI -----------------
with gr.Blocks() as demo:
    gr.Markdown(f"# {TITLE}")
    gr.HTML(BADGES_HTML)
    gr.Markdown(PAPER)
    gr.Markdown(DESC)

    # Reference-frame guide (bilingual; no cropping steps)
    with gr.Accordion("参考帧制作指南 / Reference Frame Guide", open=False):
        gr.Markdown(REF_GUIDE_MD)

    debug_shapes = gr.Checkbox(
        label="调试日志 / Debug Logs(仅用于显示更完整日志 / show verbose logs)",
        value=False,
    )

    with gr.Row():
        inp_video = gr.Video(label="黑白视频(mp4/webm/avi) / B&W Video", interactive=True)
        inp_ref = gr.Image(label="参考图像(RGB) / Reference Image (RGB)", type="pil")

    gr.Examples(
        label="示例 / Examples",
        examples=[["./example/4.mp4", "./example/4.png"]],
        inputs=[inp_video, inp_ref],
        cache_examples=False,
    )

    with gr.Accordion("高级参数设置 / Advanced Settings(传给 test.py / passed to test.py)", open=False):
        with gr.Row():
            first_not_exemplar = gr.Checkbox(label="FirstFrameIsNotExemplar (--FirstFrameIsNotExemplar)", value=True)
            reverse = gr.Checkbox(label="reverse (--reverse)", value=False)
            dataset = gr.Textbox(label="dataset (--dataset)", value="D16_batch")
            split = gr.Textbox(label="split (--split)", value="val")
            save_all = gr.Checkbox(label="save_all (--save_all)", value=True)
            benchmark = gr.Checkbox(label="benchmark (--benchmark)", value=False)
        with gr.Row():
            disable_long_term = gr.Checkbox(label="disable_long_term (--disable_long_term)", value=False)
            max_mid = gr.Number(label="max_mid_term_frames (--max_mid_term_frames)", value=10, precision=0)
            min_mid = gr.Number(label="min_mid_term_frames (--min_mid_term_frames)", value=5, precision=0)
            max_long = gr.Number(label="max_long_term_elements (--max_long_term_elements)", value=10000, precision=0)
            num_proto = gr.Number(label="num_prototypes (--num_prototypes)", value=128, precision=0)
        with gr.Row():
            top_k = gr.Number(label="top_k (--top_k)", value=30, precision=0)
            mem_every = gr.Number(label="mem_every (--mem_every)", value=5, precision=0)
            deep_update = gr.Number(label="deep_update_every (--deep_update_every)", value=-1, precision=0)
            save_scores = gr.Checkbox(label="save_scores (--save_scores)", value=False)
            flip = gr.Checkbox(label="flip (--flip)", value=False)
            size = gr.Number(label="size (--size)", value=-1, precision=0)

    run_btn = gr.Button("开始着色 / Start Coloring")

    with gr.Row():
        out_video = gr.Video(label="输出视频(着色结果) / Output (Colorized)", autoplay=True)
        status = gr.Textbox(label="状态 / 日志输出 / Status & Logs", interactive=False, lines=16)

    run_btn.click(
        fn=gradio_infer,
        inputs=[
            debug_shapes,
            inp_video, inp_ref,
            first_not_exemplar, dataset, split, save_all, benchmark,
            disable_long_term, max_mid, min_mid, max_long, num_proto, top_k,
            mem_every, deep_update, save_scores, flip, size, reverse,
        ],
        outputs=[out_video, status],
    )

    gr.HTML("")
    gr.HTML(BADGES_HTML)


if __name__ == "__main__":
    try:
        ensure_checkpoint()
    except Exception as e:
        print(f"[WARN] Checkpoint pre-download failed (will retry at first inference): {e}")
    demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860)