from __future__ import annotations
import gc
import pathlib
import sys
import tempfile
import gradio as gr
import imageio
import torch
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from huggingface_hub import ModelCard

# Make the local Tune-A-Video checkout importable before loading its modules.
sys.path.append("Tune-A-Video")

from tuneavideo.models.unet import UNet3DConditionModel
from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline


class InferencePipeline:
    """Lazily loads a Tune-A-Video pipeline and caches it per model id."""

    def __init__(self, hf_token: str | None = None):
        self.hf_token = hf_token
        self.pipe: TuneAVideoPipeline | None = None
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model_id: str | None = None

    def clear(self) -> None:
        """Drop the current pipeline and release GPU memory."""
        self.model_id = None
        del self.pipe
        self.pipe = None
        # Collect dropped Python references first so empty_cache() can
        # actually return the freed blocks to the GPU.
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def check_if_model_is_local(model_id: str) -> bool:
        return pathlib.Path(model_id).exists()

    @staticmethod
    def get_model_card(model_id: str, hf_token: str | None = None) -> ModelCard:
        if InferencePipeline.check_if_model_is_local(model_id):
            card_path = (pathlib.Path(model_id) / "README.md").as_posix()
        else:
            card_path = model_id
        return ModelCard.load(card_path, token=hf_token)

    @staticmethod
    def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
        card = InferencePipeline.get_model_card(model_id, hf_token)
        return card.data.base_model

    def load_pipe(self, model_id: str) -> None:
        if model_id == self.model_id:
            return
        # The fine-tuned UNet comes from the given repo; the rest of the
        # pipeline is rebuilt from the base model named in its model card.
        base_model_id = self.get_base_model_info(model_id, self.hf_token)
        unet = UNet3DConditionModel.from_pretrained(
            model_id, subfolder="unet", torch_dtype=torch.float16, use_auth_token=self.hf_token
        )
        pipe = TuneAVideoPipeline.from_pretrained(
            base_model_id, unet=unet, torch_dtype=torch.float16, use_auth_token=self.hf_token
        )
        pipe = pipe.to(self.device)
        if is_xformers_available():
            pipe.unet.enable_xformers_memory_efficient_attention()
        self.pipe = pipe
        self.model_id = model_id

    def run(
        self,
        model_id: str,
        prompt: str,
        video_length: int,
        fps: int,
        seed: int,
        n_steps: int,
        guidance_scale: float,
    ) -> str:
        if not torch.cuda.is_available():
            raise gr.Error("CUDA is not available.")

        self.load_pipe(model_id)
        assert self.pipe is not None  # set by load_pipe()

        generator = torch.Generator(device=self.device).manual_seed(seed)
        out = self.pipe(
            prompt,
            video_length=video_length,
            width=512,
            height=512,
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        # out.videos has shape (batch, channels, frames, height, width); take the
        # first video and convert it to (frames, h, w, channels) uint8 for imageio.
        frames = rearrange(out.videos[0], "c t h w -> t h w c")
        frames = (frames * 255).to(torch.uint8).numpy()

        # Encode the frames into a temporary MP4 file and return its path.
        out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        writer = imageio.get_writer(out_file.name, fps=fps)
        for frame in frames:
            writer.append_data(frame)
        writer.close()
        return out_file.name
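

# A minimal usage sketch (illustrative, not part of the original module). The
# model id below is a hypothetical placeholder for a fine-tuned Tune-A-Video
# checkpoint on the Hub; substitute your own repo id or a local directory.
if __name__ == "__main__":
    pipeline = InferencePipeline(hf_token=None)
    video_path = pipeline.run(
        model_id="your-username/your-tune-a-video-model",  # hypothetical repo id
        prompt="a panda is surfing",
        video_length=8,
        fps=8,
        seed=0,
        n_steps=50,
        guidance_scale=7.5,
    )
    print(video_path)
    pipeline.clear()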