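"""Cog predictor for a text-to-video model.

Shells out to inference.py, pointing it at the selected model in the
local weights cache, and returns the rendered .mp4 files.
"""
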
import os
import shutil
import subprocess
from typing import List

from cog import BasePredictor, Input, Path

# Local cache of model weights; each model choice maps to a subdirectory
MODEL_CACHE = "model-cache"

class Predictor(BasePredictor):
    def setup(self):
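        # Nothing to preload here: model loading happens inside the
        # inference.py subprocess on each call to predict().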
        pass

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt", default="An astronaut riding a horse"
        ),
        negative_prompt: str = Input(
            description="Negative prompt", default=None
        ),
        init_video: Path = Input(
            description="Initial video to guide the generation (optional)", default=None
        ),
        init_weight: float = Input(
            description="Strength of init_video", default=0.5
        ),
        num_frames: int = Input(
            description="Number of frames for the output video", default=24
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        width: int = Input(
            description="Width of the output video", ge=256, default=576
        ),
        height: int = Input(
            description="Height of the output video", ge=256, default=320
        ),
        guidance_scale: float = Input(
            description="Guidance scale", ge=1.0, le=100.0, default=7.5
        ),
        fps: int = Input(description="Frames per second for the output video", default=8),
        model: str = Input(
            description="Model to use", default="xl", choices=["xl", "576w", "potat1", "animov-512x"]
        ),
        batch_size: int = Input(description="Batch size", default=1, ge=1),
        remove_watermark: bool = Input(
            description="Remove watermark", default=False
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
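        """Generate video(s) for the prompt and return paths to the .mp4 files."""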
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")  # random 16-bit seed
        print(f"Using seed: {seed}")

        # Start each prediction with a clean output directory
        shutil.rmtree("output", ignore_errors=True)
        os.makedirs("output", exist_ok=True)

        # Keys double as CLI flag names for inference.py, which is why
        # hyphenated and underscored spellings are mixed.
        args = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "batch_size": batch_size,
            "num_frames": num_frames,
            "num_steps": num_inference_steps,
            "seed": seed,
            "guidance-scale": guidance_scale,
            "width": width,
            "height": height,
            "fps": fps,
            "device": "cuda",
            "output_dir": "output",
            "remove-watermark": remove_watermark,
        }

        args["model"] = os.path.join(MODEL_CACHE, model)

        if init_video is not None:
            # Copy the Cog-provided temp file to a stable local path; passing
            # the original path to the subprocess didn't work reliably.
            if os.path.exists("input.mp4"):
                os.unlink("input.mp4")
            shutil.copy(init_video, "input.mp4")

            args["init-video"] = "input.mp4"
            args["init-weight"] = init_weight
            print(f"init video size: {os.stat('input.mp4').st_size} bytes")

        # Build the CLI invocation, skipping flags whose value is None
        cmd = ["python", "inference.py"]
        for k, v in args.items():
            if v is not None:
                cmd.append(f"--{k}")
                cmd.append(str(v))  # booleans serialize as "True"/"False"
        subprocess.check_call(cmd)

        # Collect the rendered videos in a stable order
        outputs = []
        for f in sorted(os.listdir("output")):
            if f.endswith(".mp4"):
                outputs.append(Path(os.path.join("output", f)))
        return outputs