|
import logging |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
from pydantic import BaseModel, Field |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class RifeNCNNOptions(BaseModel): |
|
model_path: Path = Field(..., description="Path to RIFE model directory") |
|
input_path: Path = Field(..., description="Path to source frames directory") |
|
output_path: Optional[Path] = Field(None, description="Path to output frames directory") |
|
num_frame: Optional[int] = Field(None, description="Number of frames to generate (default N*2)") |
|
time_step: float = Field(0.5, description="Time step for interpolation (default 0.5)", gt=0.0, le=1.0) |
|
gpu_id: Optional[int | list[int]] = Field( |
|
None, description="GPU ID(s) to use (default: auto, -1 for CPU)" |
|
) |
|
load_threads: int = Field(1, description="Number of threads for frame loading", gt=0) |
|
process_threads: int = Field(2, description="Number of threads used for frame processing", gt=0) |
|
save_threads: int = Field(2, description="Number of threads for frame saving", gt=0) |
|
spatial_tta: bool = Field(False, description="Enable spatial TTA mode") |
|
temporal_tta: bool = Field(False, description="Enable temporal TTA mode") |
|
uhd: bool = Field(False, description="Enable UHD mode") |
|
verbose: bool = Field(False, description="Enable verbose logging") |
|
|
|
def get_args(self, frame_multiplier: int = 7) -> list[str]: |
|
"""Generate arguments to pass to rife-ncnn-vulkan. |
|
|
|
Frame multiplier is used to calculate the number of frames to generate, if num_frame is not set. |
|
""" |
|
if self.output_path is None: |
|
self.output_path = self.input_path.joinpath("out") |
|
|
|
|
|
if self.num_frame is None: |
|
num_src_frames = len([x for x in self.input_path.glob("*.png") if x.is_file()]) |
|
logger.info(f"Found {num_src_frames} source frames, using multiplier {frame_multiplier}") |
|
num_frame = num_src_frames * frame_multiplier |
|
logger.info(f"We will generate {num_frame} frames") |
|
else: |
|
num_frame = self.num_frame |
|
|
|
|
|
if self.gpu_id is None: |
|
gpu_id = "auto" |
|
process_threads = self.process_threads |
|
elif isinstance(self.gpu_id, list): |
|
gpu_id = ",".join([str(x) for x in self.gpu_id]) |
|
process_threads = ",".join([str(self.process_threads) for _ in self.gpu_id]) |
|
else: |
|
gpu_id = str(self.gpu_id) |
|
process_threads = str(self.process_threads) |
|
|
|
|
|
args_list = [ |
|
"-i", |
|
f"{self.input_path.resolve()}/", |
|
"-o", |
|
f"{self.output_path.resolve()}/", |
|
"-m", |
|
f"{self.model_path.resolve()}/", |
|
"-n", |
|
num_frame, |
|
"-s", |
|
f"{self.time_step:.5f}", |
|
"-g", |
|
gpu_id, |
|
"-j", |
|
f"{self.load_threads}:{process_threads}:{self.save_threads}", |
|
] |
|
|
|
|
|
if self.spatial_tta: |
|
args_list.append("-x") |
|
if self.temporal_tta: |
|
args_list.append("-z") |
|
if self.uhd: |
|
args_list.append("-u") |
|
if self.verbose: |
|
args_list.append("-v") |
|
|
|
|
|
return [str(x) for x in args_list] |
|
|