File size: 3,462 Bytes
d0ffe9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import logging
from pathlib import Path
from typing import Optional

from pydantic import BaseModel, Field

# Module-level logger, named after this module; get_args() uses it to report
# the detected source-frame count and the number of frames to generate.
logger = logging.getLogger(__name__)


class RifeNCNNOptions(BaseModel):
    """Options for a rife-ncnn-vulkan invocation, convertible to CLI arguments.

    Each field carries its own description and numeric constraints via
    ``Field``. Call :meth:`get_args` to turn a populated instance into the
    argument list for the executable.
    """

    model_path: Path = Field(..., description="Path to RIFE model directory")
    input_path: Path = Field(..., description="Path to source frames directory")
    output_path: Optional[Path] = Field(None, description="Path to output frames directory")
    num_frame: Optional[int] = Field(None, description="Number of frames to generate (default N*2)")
    time_step: float = Field(0.5, description="Time step for interpolation (default 0.5)", gt=0.0, le=1.0)
    gpu_id: Optional[int | list[int]] = Field(
        None, description="GPU ID(s) to use (default: auto, -1 for CPU)"
    )
    load_threads: int = Field(1, description="Number of threads for frame loading", gt=0)
    process_threads: int = Field(2, description="Number of threads used for frame processing", gt=0)
    save_threads: int = Field(2, description="Number of threads for frame saving", gt=0)
    spatial_tta: bool = Field(False, description="Enable spatial TTA mode")
    temporal_tta: bool = Field(False, description="Enable temporal TTA mode")
    uhd: bool = Field(False, description="Enable UHD mode")
    verbose: bool = Field(False, description="Enable verbose logging")

    def get_args(self, frame_multiplier: int = 7) -> list[str]:
        """Generate arguments to pass to rife-ncnn-vulkan.

        Args:
            frame_multiplier: Factor applied to the number of ``*.png`` files
                found directly inside ``input_path`` to obtain the target
                frame count. Only used when ``num_frame`` is not set.

        Returns:
            The argument list (executable name not included); every element
            is a string.

        Side effects:
            When ``output_path`` is unset it is assigned
            ``input_path / "out"`` on this instance. The directory itself is
            not created here.
        """
        if self.output_path is None:
            self.output_path = self.input_path.joinpath("out")

        # Work out how many frames to request.
        if self.num_frame is None:
            num_src_frames = sum(1 for x in self.input_path.glob("*.png") if x.is_file())
            logger.info("Found %s source frames, using multiplier %s", num_src_frames, frame_multiplier)
            num_frame = num_src_frames * frame_multiplier
            logger.info("We will generate %s frames", num_frame)
        else:
            num_frame = self.num_frame

        # GPU IDs and processing-thread counts are passed as comma-separated
        # lists with one thread-count entry per GPU.
        process_threads = str(self.process_threads)
        if self.gpu_id is None or (isinstance(self.gpu_id, list) and not self.gpu_id):
            # An empty list would otherwise join to "" and yield a malformed
            # "-g ''" / "-j load::save" pair, so treat it like "not set".
            # NOTE(review): confirm the binary accepts the literal "auto".
            gpu_id = "auto"
        elif isinstance(self.gpu_id, list):
            gpu_id = ",".join(str(x) for x in self.gpu_id)
            process_threads = ",".join([process_threads] * len(self.gpu_id))
        else:
            gpu_id = str(self.gpu_id)

        # Directory arguments are passed as absolute paths with a trailing slash.
        args_list = [
            "-i",
            f"{self.input_path.resolve()}/",
            "-o",
            f"{self.output_path.resolve()}/",
            "-m",
            f"{self.model_path.resolve()}/",
            "-n",
            str(num_frame),
            "-s",
            f"{self.time_step:.5f}",
            "-g",
            gpu_id,
            "-j",
            f"{self.load_threads}:{process_threads}:{self.save_threads}",
        ]

        # Append boolean switches only when enabled.
        if self.spatial_tta:
            args_list.append("-x")
        if self.temporal_tta:
            args_list.append("-z")
        if self.uhd:
            args_list.append("-u")
        if self.verbose:
            args_list.append("-v")

        return args_list