import pathlib from dataclasses import dataclass import cv2 import numpy as np import redis from logger import logger from config import settings try: logger.info( f"Connecting to Redis: {settings.get('redis_url', 'redis://localhost:6379')}" ) redis_client = redis.from_url(settings.get("redis_url", "redis://localhost:6379")) logger.info(f"Redis ping: {redis_client.ping()}") except Exception as e: logger.error(f"Failed to connect to Redis: {e}") exit(1) @dataclass class ModelInfo: name: str = "" path: str = "" scale: int = 4 algo: str = "" BASE_STREAM_NAME = ( "super_resolution_api_queue" if not settings.get("worker_id") else f"super_resolution_api_queue_{settings.get('worker_id')}" ) WORKER_KEY_PREFIX = "super_resolution_api_worker_" DISTRIBUTED_STREAM_NAME = "super_resolution_api_distributed_queue" RESULT_KEY_PREFIX = ( "super_resolution_api_result_" if not settings.get("worker_id") else f"super_resolution_api_result_{settings.get('worker_id')}_" ) PROGRESS_TIMEOUT = settings.get("timeout", 30) MAX_ALLOWED_TIMEOUT = settings.get("max_timeout", 300) MAX_THREAD = settings.get("max_thread", 8) MODEL_NAME_DEFAULT = "x4_Anime_6B-Official" MODEL_NAME_X4_JP_ILLUSTRATION_FIX1 = "x4_JP_Illustration-fix1" MODEL_NAME_X4_JP_ILLUSTRATION_FIX2 = "x4_JP_Illustration-fix2" MODEL_NAME_X4_JP_ILLUSTRATION_FIX1_D = "x4_JP_Illustration-fix1-d" MODEL_NAME_X4_ANIME_6B_OFFICIAL = "x4_Anime_6B-Official" model_Anime_Official = ModelInfo( MODEL_NAME_X4_ANIME_6B_OFFICIAL, "models/x4_Anime_6B-Official.onnx", 4, "real-esrgan", ) model_JP_Illustration_fix1 = ModelInfo( MODEL_NAME_X4_JP_ILLUSTRATION_FIX1, "models/x4_jp_Illustration-fix1.onnx", 4, "real-hatgan", ) model_JP_Illustration_fix2 = ModelInfo( MODEL_NAME_X4_JP_ILLUSTRATION_FIX2, "models/x4_jp_Illustration-fix2.onnx", 4, "real-esrgan", ) model_JP_Illustration_fix1_d = ModelInfo( MODEL_NAME_X4_JP_ILLUSTRATION_FIX1_D, "models/x4_jp_Illustration-fix1-d.onnx", 4, "real-esrgan", ) models = { MODEL_NAME_X4_ANIME_6B_OFFICIAL: model_Anime_Official, MODEL_NAME_X4_JP_ILLUSTRATION_FIX1: model_JP_Illustration_fix1, MODEL_NAME_X4_JP_ILLUSTRATION_FIX2: model_JP_Illustration_fix2, MODEL_NAME_X4_JP_ILLUSTRATION_FIX1_D: model_JP_Illustration_fix1_d, } def get_image_size(image_path: pathlib.Path) -> tuple[int, int]: """ return: (width, height) """ img = cv2.imread(str(image_path)) if img is None: raise Exception(f"Failed to load image: {image_path}") return img.shape[1], img.shape[0] @dataclass class TileInfo: x: int y: int filpath: pathlib.Path def split_image( img_path: pathlib.Path, save_dir: pathlib.Path, grid_size: tuple[int, int], overlap: int = 16, ) -> list[TileInfo]: save_path = pathlib.Path(save_dir) save_path.mkdir(parents=True, exist_ok=True) img = cv2.imread(str(img_path)) if img is None: raise Exception(f"Failed to load image: {img_path}") height, width = img.shape[:2] rows, cols = grid_size base_h = height // rows base_w = width // cols tiles_info = [] for row in range(rows): for col in range(cols): x1 = max(0, col * base_w - overlap) y1 = max(0, row * base_h - overlap) x2 = min(width, (col + 1) * base_w + overlap) y2 = min(height, (row + 1) * base_h + overlap) tile = img[y1:y2, x1:x2] tile_name = f"{img_path.stem}_tile_{row}_{col}.png" tile_path = save_path / tile_name cv2.imwrite(str(tile_path), tile) tiles_info.append(TileInfo(col, row, tile_path)) return tiles_info def merge_sr_tiles( tiles: list[TileInfo], output: pathlib.Path, original_size: tuple[int, int], scale: int, overlap: int = 16, ): """ 合并超分辨率后的图块 tiles: 超分辨率后的图块信息列表, 需要根据 filepath 读取图块, 根据 x, y 位置信息进行拼接 output: 合并后的图片保存路径 original_size: 原始图片的尺寸 overlap 为原始图片切割时设定的重叠像素数 scale 为超分辨率倍数 """ # Calculate output dimensions logger.debug( f"正在合并 {len(tiles)} 张超分辨率图块, 原尺寸: {original_size}, 缩放倍数: {scale}" ) width, height = original_size out_width = width * scale out_height = height * scale output_img = np.zeros((out_height, out_width, 3), dtype=np.uint8) # Calculate base tile sizes rows = max([t.y for t in tiles]) + 1 cols = max([t.x for t in tiles]) + 1 base_h = height // rows base_w = width // cols # Scale dimensions scaled_base_h = base_h * scale scaled_base_w = base_w * scale scaled_overlap = overlap * scale for tile_info in tiles: # Read tile tile = cv2.imread(str(tile_info.filpath)) if tile is None: raise Exception(f"Failed to load tile: {tile_info.filpath}") # Calculate positions x1 = max(0, tile_info.x * scaled_base_w - scaled_overlap) y1 = max(0, tile_info.y * scaled_base_h - scaled_overlap) x2 = min(out_width, (tile_info.x + 1) * scaled_base_w + scaled_overlap) y2 = min(out_height, (tile_info.y + 1) * scaled_base_h + scaled_overlap) # Calculate blend mask for overlapping regions h, w = y2 - y1, x2 - x1 blend_mask = np.ones((h, w, 1), dtype=np.float32) # Apply feathering at edges if tile_info.x > 0: # Left edge blend_mask[:, :scaled_overlap] = np.linspace(0, 1, scaled_overlap).reshape( 1, -1, 1 ) if tile_info.x < cols - 1: # Right edge blend_mask[:, -scaled_overlap:] = np.linspace(1, 0, scaled_overlap).reshape( 1, -1, 1 ) if tile_info.y > 0: # Top edge blend_mask[:scaled_overlap, :] *= np.linspace(0, 1, scaled_overlap).reshape( -1, 1, 1 ) if tile_info.y < rows - 1: # Bottom edge blend_mask[-scaled_overlap:, :] *= np.linspace( 1, 0, scaled_overlap ).reshape(-1, 1, 1) # Blend tiles output_img[y1:y2, x1:x2] = ( output_img[y1:y2, x1:x2] * (1 - blend_mask) + tile[: y2 - y1, : x2 - x1] * blend_mask ).astype(np.uint8) cv2.imwrite(output.as_posix(), output_img) def calculate_grid(image_width, image_height, workers): if workers <= 0: raise ValueError("Worker count must be positive") best_rows, best_cols = 1, workers min_aspect_diff = float("inf") for rows in range(1, workers + 1): if workers % rows == 0: cols = workers // rows tile_width = image_width / cols tile_height = image_height / rows aspect_ratio = max(tile_width, tile_height) / min(tile_width, tile_height) aspect_diff = aspect_ratio - 1 if aspect_diff < min_aspect_diff: best_rows, best_cols = rows, cols min_aspect_diff = aspect_diff logger.debug(f"calculate_grid: {best_rows}x{best_cols}") return best_rows, best_cols