File size: 3,183 Bytes
b944fa1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
from text_to_video.tuneavideo.pipelines.pipeline_text_to_video import TuneAVideoPipeline
from text_to_video.tuneavideo.models.unet import UNet3DConditionModel
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
class TextToVideo:
    """Text-to-video generation on top of a Stable Diffusion checkpoint.

    Builds a ``TuneAVideoPipeline`` from the scheduler, tokenizer, text
    encoder, VAE and UNet stored under ``sd_path`` (the UNet is loaded via
    ``UNet3DConditionModel.from_pretrained_2d``) and exposes a single
    :meth:`inference` call that returns one generated video per prompt.
    """

    def __init__(
        self,
        sd_path: "str | None" = None,
        motion_field_strength: float = 12,
        video_length: int = 8,
        t0: int = 881,
        t1: int = 941,
        use_cf_attn: bool = True,
        use_motion_field: bool = True,
        device: str = "cuda",
    ) -> None:
        """Load all model components from ``sd_path`` and build the pipeline.

        Args:
            sd_path: Path (or model id) handed to each component's
                ``from_pretrained``; it must provide the ``scheduler``,
                ``tokenizer``, ``text_encoder``, ``vae`` and ``unet``
                subfolders. Required despite the ``None`` default.
            motion_field_strength: Strength of the motion field
                (delta_x = delta_y in Sect 3.3.1).
            video_length: Forwarded to the pipeline as ``video_length``
                (presumably the number of frames).
            t0: DDIM backward sampling runs from noise until timestep t0.
            t1: From t0, DDPM forward is applied until timestep t1.
            use_cf_attn: Forwarded to ``from_pretrained_2d``; presumably
                toggles cross-frame attention -- TODO confirm.
            use_motion_field: Whether to apply the general motion field.
            device: Torch device for the pipeline, the RNG and the initial
                latent. Defaults to ``"cuda"`` (the previously hard-coded
                value).

        Raises:
            ValueError: If ``sd_path`` is None.
        """
        # Validate before touching the GPU or the filesystem. A bare
        # ``assert`` would be silently stripped under ``python -O``.
        if sd_path is None:
            raise ValueError("sd_path must be provided (path to a Stable Diffusion checkpoint)")

        self.device = device
        # Seeded once per instance; successive inference() calls advance it,
        # so a fresh instance reproduces the same sequence of samples.
        self.g = torch.Generator(device=device)
        self.g.manual_seed(22)

        print(f"Loading model SD-Net model file from {sd_path}")
        self.dtype = torch.float16

        tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
        unet = UNet3DConditionModel.from_pretrained_2d(
            sd_path, subfolder="unet", use_cf_attn=use_cf_attn)
        # The scheduler is loaded once and owned by the pipeline. The former
        # standalone ``noise_scheduler`` (with its own set_timesteps call)
        # was never used by anything, so it has been removed.
        scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
        self.pipe = TuneAVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        ).to(device).to(self.dtype)

        # t0 parameter (DDIM backward from noise until t0)
        self.t0 = t0
        # from t0 apply DDPM forward until t1
        self.t1 = t1
        # apply motion field on foreground object (not used)
        self.use_foreground_motion_field = False
        # strength of motion field (delta_x = delta_y in Sect 3.3.1)
        self.motion_field_strength = motion_field_strength
        # apply general motion field
        self.use_motion_field = use_motion_field
        # temporally smooth background
        self.smooth_bg = False
        # alpha = (1 - self.smooth_bg_strength) in Eq (9)
        self.smooth_bg_strength = 0.4
        self.video_length = video_length

    def inference(
        self,
        prompt: str,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
    ):
        """Generate a single video for ``prompt``.

        Args:
            prompt: Text prompt; it is wrapped in a one-element list, so the
                pipeline sees a batch of size 1.
            height: Output height in pixels; must be a multiple of 8.
            width: Output width in pixels; must be a multiple of 8.
            num_inference_steps: Forwarded to the pipeline.
            guidance_scale: Forwarded to the pipeline.

        Returns:
            ``result.videos[0]``, the first (only) entry of the pipeline
            output's ``videos``. Its exact type and layout are defined by
            ``TuneAVideoPipeline``.

        Raises:
            ValueError: If ``height`` or ``width`` is not a multiple of 8.
        """
        # Latents are 8x smaller than pixels; the previous hard-coded 64x64
        # latent for 512x512 output encodes this ratio. Assumes the SD VAE
        # downsampling factor of 8.
        if height % 8 or width % 8:
            raise ValueError(
                f"height and width must be multiples of 8, got {height}x{width}")

        # Initial latent with shape (1, 4, 1, height/8, width/8), as before.
        # The size-1 third axis is presumably the frame axis (TODO confirm
        # against the pipeline). It is drawn from the seeded ``self.g``; the
        # unseeded global RNG used previously made the seed ineffective here.
        xT = torch.randn(
            (1, 4, 1, height // 8, width // 8),
            generator=self.g,
            dtype=self.dtype,
            device=self.device,
        )
        result = self.pipe(
            [prompt],
            video_length=self.video_length,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            guidance_stop_step=1.0,
            t0=self.t0,
            t1=self.t1,
            xT=xT,
            use_foreground_motion_field=self.use_foreground_motion_field,
            motion_field_strength=self.motion_field_strength,
            use_motion_field=self.use_motion_field,
            smooth_bg=self.smooth_bg,
            smooth_bg_strength=self.smooth_bg_strength,
            generator=self.g,
        )
        return result.videos[0]
|