|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
from torchvision import transforms |
|
|
|
|
|
import config |
|
|
from config import logger |
|
|
|
|
|
|
|
|
class RVMProcessor:
    """Background-removal (matting) processor backed by RVM (Robust Video Matting).

    The model definition is loaded through ``torch.hub`` from a local checkout
    of the RobustVideoMatting repository and the weights from a local file, so
    no network access is needed. Inference runs on ``self.device`` ("cpu").
    """

    def __init__(self):
        """Load the model described by ``config`` and mark the processor usable.

        Reads ``config.RVM_LOCAL_REPO``, ``config.RVM_WEIGHTS_PATH`` and
        ``config.RVM_MODEL``. Initialization never raises: any failure is
        logged and leaves ``self.available`` False and ``self.model`` None.
        """
        self.model = None
        self.available = False
        self.device = "cpu"

        try:
            local_repo = getattr(config, 'RVM_LOCAL_REPO', '')
            weights_path = getattr(config, 'RVM_WEIGHTS_PATH', '')

            if not local_repo or not os.path.isdir(local_repo):
                raise RuntimeError("RVM_LOCAL_REPO not set or invalid. Please set env RVM_LOCAL_REPO to local RobustVideoMatting repo path (with hubconf.py)")

            if not weights_path or not os.path.isfile(weights_path):
                raise RuntimeError("RVM_WEIGHTS_PATH not set or file not found. Please set env RVM_WEIGHTS_PATH to local RVM weights file path")

            logger.info(f"Loading RVM model {config.RVM_MODEL} from local repo: {local_repo}")

            # Build the architecture only; the weights come from the local file.
            model = torch.hub.load(local_repo, config.RVM_MODEL, source='local', pretrained=False)

            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only point RVM_WEIGHTS_PATH at trusted weights.
            state = torch.load(weights_path, map_location=self.device)
            if isinstance(state, dict) and 'state_dict' in state:
                # Checkpoint wrapper: the real parameters live under 'state_dict'.
                state = state['state_dict']
            missing, unexpected = model.load_state_dict(state, strict=False)

            # Publish the model only once it is fully loaded, so a failure
            # half-way through never leaves a partially initialized model behind.
            self.model = model.to(self.device).eval()
            self.available = True
            logger.info("RVM background removal processor initialized successfully (local mode)")
            if missing:
                logger.warning(f"RVM weights missing keys: {list(missing)[:5]}... total={len(missing)}")
            if unexpected:
                logger.warning(f"RVM weights unexpected keys: {list(unexpected)[:5]}... total={len(unexpected)}")

        except Exception as e:
            logger.error(f"RVM background removal processor initialization failed: {e}")
            self.model = None
            self.available = False

    def is_available(self) -> bool:
        """Return True when initialization succeeded and a model is loaded."""
        return bool(self.available) and self.model is not None

    @staticmethod
    def _auto_downsample_ratio(height: int, width: int) -> float:
        """Pick a ratio that brings the longer image side down to at most 512 px."""
        return min(512 / max(height, width), 1)

    @staticmethod
    def _to_uint8(values: np.ndarray) -> np.ndarray:
        """Quantize values nominally in [0, 1] to uint8 with rounding and clipping."""
        scaled = np.rint(np.asarray(values) * 255.0)
        return np.ascontiguousarray(np.clip(scaled, 0, 255).astype(np.uint8))

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert a grayscale, BGR or BGRA OpenCV image to 3-channel RGB."""
        if image.ndim == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.ndim == 3:
            channels = image.shape[2]
            if channels == 1:
                return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            if channels == 3:
                return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            if channels == 4:
                return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        raise ValueError(f"unsupported image shape: {image.shape}")

    @staticmethod
    def _blend_on_color(foreground: np.ndarray, alpha_u8: np.ndarray, background_color) -> np.ndarray:
        """Alpha-composite ``foreground`` (H, W, 3) over a solid color using a uint8 matte."""
        height, width = foreground.shape[:2]
        backdrop = np.full((height, width, 3), background_color, dtype=np.uint8)
        alpha = (alpha_u8.astype(np.float32) / 255.0)[..., np.newaxis]
        blended = foreground.astype(np.float32) * alpha + backdrop.astype(np.float32) * (1.0 - alpha)
        # Round instead of truncating: float error would otherwise darken
        # pixels by one level (e.g. white over white giving 254).
        return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

    def remove_background(self, image: np.ndarray, background_color: tuple = None,
                          downsample_ratio: float = 0.25) -> np.ndarray:
        """Remove the background of an image with RVM.

        :param image: input OpenCV image (BGR; grayscale and BGRA are also accepted)
        :param background_color: replacement background color (BGR); when None
            the background is kept transparent
        :param downsample_ratio: ratio passed to the model for its internal
            low-resolution pass, in (0, 1]; pass None to derive it from the
            image size so the longer side is processed at about 512 px
        :return: 3-channel BGR image composited over ``background_color``, or a
            4-channel BGRA image when ``background_color`` is None
        :raises RuntimeError: if the processor is unavailable or processing fails
        """
        if not self.is_available():
            raise RuntimeError("RVM抠图处理器不可用")

        try:
            logger.info("Starting to remove background using RVM...")

            if not isinstance(image, np.ndarray) or image.size == 0:
                raise ValueError("image must be a non-empty numpy array")

            original_height, original_width = image.shape[:2]
            image_rgb = self._to_rgb(image)

            if downsample_ratio is None:
                downsample_ratio = self._auto_downsample_ratio(original_height, original_width)
            elif not 0 < downsample_ratio <= 1:
                raise ValueError(f"downsample_ratio must be in (0, 1], got {downsample_ratio}")

            # HWC image -> 1xCxHxW float tensor on the inference device.
            src = transforms.ToTensor()(image_rgb).unsqueeze(0).to(self.device)

            # Single still image: no recurrent state to carry in, and the
            # recurrent state the model returns is discarded.
            with torch.no_grad():
                fgr, pha, *_ = self.model(src, None, None, None, None, downsample_ratio=downsample_ratio)

            fgr = self._to_uint8(fgr[0].permute(1, 2, 0).cpu().numpy())
            pha = self._to_uint8(pha[0, 0].cpu().numpy())

            # Bring the outputs back to the input size if the model changed it.
            if fgr.shape[:2] != (original_height, original_width):
                fgr = cv2.resize(fgr, (original_width, original_height))
                pha = cv2.resize(pha, (original_width, original_height))

            fgr_bgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR)
            if background_color is not None:
                result = self._blend_on_color(fgr_bgr, pha, background_color)
            else:
                # Transparent output: append the matte as the alpha channel (BGRA).
                result = np.dstack((fgr_bgr, pha))

            logger.info("RVM background removal completed")
            return result

        except Exception as e:
            logger.error(f"RVM background removal failed: {e}")
            raise RuntimeError(f"背景移除失败: {str(e)}") from e

    def create_id_photo(self, image: np.ndarray, background_color: tuple = (255, 255, 255),
                        downsample_ratio: float = 0.25) -> np.ndarray:
        """Create an ID photo: remove the background and fill it with a solid color.

        :param image: input OpenCV image
        :param background_color: background color (BGR), white by default
        :param downsample_ratio: forwarded to ``remove_background``; None
            selects the ratio automatically from the image size
        :return: the processed ID photo
        :raises RuntimeError: if the processor is unavailable or processing fails
        """
        logger.info(f"Starting to create ID photo, background color: {background_color}")

        id_photo = self.remove_background(image, background_color, downsample_ratio=downsample_ratio)

        logger.info("ID photo creation completed")
        return id_photo
|
|
|