|
"""Text-to-video wrapper around a Tune-A-Video style pipeline.

Builds a TuneAVideoPipeline from a pretrained Stable Diffusion checkpoint,
inflating its 2D UNet to 3D, and exposes an `inference` method that turns a
text prompt into a short video.
"""

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from text_to_video.tuneavideo.models.unet import UNet3DConditionModel
from text_to_video.tuneavideo.pipelines.pipeline_text_to_video import TuneAVideoPipeline

class TextToVideo:

    def __init__(self, sd_path=None, motion_field_strength=12, video_length=8,
                 t0=881, t1=941, use_cf_attn=True, use_motion_field=True) -> None:
        # Fixed-seed CUDA generator so results are reproducible across runs.
        g = torch.Generator(device='cuda')
        g.manual_seed(22)
        self.g = g

        assert sd_path is not None, "sd_path must point to a Stable Diffusion checkpoint"
        print(f"Loading Stable Diffusion model files from {sd_path}")

        self.dtype = torch.float16

        # Load the Stable Diffusion components from the checkpoint.
        scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")

        # Inflate the pretrained 2D UNet to 3D; use_cf_attn swaps self-attention
        # for cross-frame attention so frames stay consistent with each other.
        unet = UNet3DConditionModel.from_pretrained_2d(
            sd_path, subfolder="unet", use_cf_attn=use_cf_attn)

        # Reuse the scheduler loaded above instead of loading it a second time;
        # the pipeline sets its own timesteps from num_inference_steps at call time.
        self.pipe = TuneAVideoPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler,
        ).to('cuda').to(self.dtype)

        # DDIM timestep window [t0, t1] within which motion dynamics are
        # applied to the latent codes.
        self.t0 = t0
        self.t1 = t1

        # Motion-field controls that translate latents between frames.
        self.use_foreground_motion_field = False
        self.motion_field_strength = motion_field_strength
        self.use_motion_field = use_motion_field

        # Background smoothing between frames (disabled by default).
        self.smooth_bg = False
        self.smooth_bg_strength = 0.4

        # Number of frames to generate per prompt.
        self.video_length = video_length

    def inference(self, prompt):
        # The pipeline expects a list of prompts; wrap the single prompt.
        prompt_compute = [prompt]

        # Initial latent noise for the first frame: batch 1, 4 latent channels,
        # 1 frame, 64x64 latent grid (decodes to 512x512 pixels).
        xT = torch.randn((1, 4, 1, 64, 64), dtype=self.dtype, device="cuda")

        result = self.pipe(
            prompt_compute,
            video_length=self.video_length,
            height=512,
            width=512,
            num_inference_steps=50,
            guidance_scale=7.5,
            guidance_stop_step=1.0,
            t0=self.t0,
            t1=self.t1,
            xT=xT,
            use_foreground_motion_field=self.use_foreground_motion_field,
            motion_field_strength=self.motion_field_strength,
            use_motion_field=self.use_motion_field,
            smooth_bg=self.smooth_bg,
            smooth_bg_strength=self.smooth_bg_strength,
            generator=self.g,
        )

        # The pipeline returns a batch of videos; return the only one.
        return result.videos[0]
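

# A minimal usage sketch, assuming a local Stable Diffusion v1.5 checkpoint
# directory and that the returned video is a (frames, height, width, channels)
# tensor with values in [0, 1]; the checkpoint path, output layout, and
# imageio-based saving are assumptions, not guarantees of this module.
if __name__ == "__main__":
    import imageio

    model = TextToVideo(sd_path="./checkpoints/stable-diffusion-v1-5")  # assumed path
    video = model.inference("a panda surfing a wave")

    # Convert the assumed [0, 1] float frames to uint8 and write a GIF.
    frames = (video * 255).to(torch.uint8).cpu().numpy()
    imageio.mimsave("output.gif", list(frames), fps=4)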
|
|