"""Inference helper for Tune-A-Video: loads a fine-tuned 3D UNet together with
its base diffusion pipeline and renders text-to-video samples to an MP4 file."""
from __future__ import annotations

import gc
import pathlib
import sys
import tempfile

import gradio as gr
import imageio
import torch
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from huggingface_hub import ModelCard

sys.path.append("Tune-A-Video")  # make the local Tune-A-Video checkout importable

from tuneavideo.models.unet import UNet3DConditionModel
from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline


class InferencePipeline:
    def __init__(self, hf_token: str | None = None):
        self.hf_token = hf_token
        self.pipe: TuneAVideoPipeline | None = None
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model_id: str | None = None

    def clear(self) -> None:
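        """Release the loaded pipeline and reclaim GPU/CPU memory."""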
        self.model_id = None
        del self.pipe
        self.pipe = None
        torch.cuda.empty_cache()
        gc.collect()

    @staticmethod
    def check_if_model_is_local(model_id: str) -> bool:
        return pathlib.Path(model_id).exists()

    @staticmethod
    def get_model_card(model_id: str, hf_token: str | None = None) -> ModelCard:
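        """Load the model card from a local README.md or directly from the Hub."""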
        if InferencePipeline.check_if_model_is_local(model_id):
            card_path = (pathlib.Path(model_id) / "README.md").as_posix()
        else:
            card_path = model_id
        return ModelCard.load(card_path, token=hf_token)

    @staticmethod
    def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
        card = InferencePipeline.get_model_card(model_id, hf_token)
        return card.data.base_model

    def load_pipe(self, model_id: str) -> None:
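        """Load `model_id` (a fine-tuned UNet plus its base pipeline), skipping the work if it is already active."""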
        if model_id == self.model_id:
            return
        base_model_id = self.get_base_model_info(model_id, self.hf_token)
        unet = UNet3DConditionModel.from_pretrained(
            model_id, subfolder="unet", torch_dtype=torch.float16, use_auth_token=self.hf_token
        )
        pipe = TuneAVideoPipeline.from_pretrained(
            base_model_id, unet=unet, torch_dtype=torch.float16, use_auth_token=self.hf_token
        )
        pipe = pipe.to(self.device)
        if is_xformers_available():
            pipe.unet.enable_xformers_memory_efficient_attention()
        self.pipe = pipe
        self.model_id = model_id

    def run(
        self,
        model_id: str,
        prompt: str,
        video_length: int,
        fps: int,
        seed: int,
        n_steps: int,
        guidance_scale: float,
    ) -> str:
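        """Render `prompt` as a short video and return the path of the resulting MP4."""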
        if not torch.cuda.is_available():
            raise gr.Error("CUDA is not available.")

        self.load_pipe(model_id)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        out = self.pipe(
            prompt,
            video_length=video_length,
            width=512,
            height=512,
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )  # type: ignore

        # (c, t, h, w) in [0, 1] -> (t, h, w, c) uint8 frames for imageio
        frames = rearrange(out.videos[0], "c t h w -> t h w c")
        frames = (frames * 255).to(torch.uint8).numpy()

        out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        with imageio.get_writer(out_file.name, fps=fps) as writer:
            for frame in frames:
                writer.append_data(frame)

        return out_file.name
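

# A minimal usage sketch: "some-user/some-tuned-model" is a hypothetical
# Tune-A-Video checkpoint id; any repo whose model card declares a
# `base_model` should work. Requires a CUDA device.
if __name__ == "__main__":
    pipeline = InferencePipeline(hf_token=None)
    video_path = pipeline.run(
        model_id="some-user/some-tuned-model",  # hypothetical repo id
        prompt="a panda surfing a wave",
        video_length=8,
        fps=8,
        seed=0,
        n_steps=50,
        guidance_scale=7.5,
    )
    print(f"video written to {video_path}")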