File size: 5,391 Bytes
ddb7519
f6f9150
ddb7519
 
 
5a57f66
d1f4ed8
f45086c
f6f9150
ddb7519
 
 
d1f4ed8
f6f9150
 
 
 
 
 
 
ddb7519
d1f4ed8
ddb7519
5a57f66
 
 
eea7935
0f38a31
 
 
eea7935
ddb7519
1a070e2
 
f45086c
 
5a57f66
 
f45086c
5a57f66
f45086c
5a57f66
ddb7519
 
 
 
 
 
f45086c
ddb7519
d1f4ed8
 
 
f45086c
 
d1f4ed8
ddb7519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1f4ed8
ddb7519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6f9150
 
 
 
 
 
 
a74fae4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Hugging Face custom inference handler for AnimateDiff text-to-video generation.

import os
from typing import Any

import numpy as np
import torch
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from huggingface_hub import hf_hub_download, try_to_load_from_cache
from omegaconf import OmegaConf
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.util import load_weights


class EndpointHandler:
    """Text-to-video inference handler built on an AnimateDiff ``AnimationPipeline``.

    ``__init__`` loads every model component and moves the pipeline to CUDA;
    ``__call__`` runs one generation and returns the frames as ``uint8`` arrays.
    """

    # Hub repository that hosts the inference config, UNet weights and motion module.
    REPO_ID = "bluestarburst/AnimateDiff-SceneFusion"

    def __init__(self, model_path: str = REPO_ID):
        """Build the animation pipeline on CUDA.

        Args:
            model_path: repo id or local directory passed to ``from_pretrained``
                for the tokenizer, text encoder and VAE. The inference config,
                UNet weights and motion module are always fetched from ``REPO_ID``.

        Raises:
            RuntimeError: if xformers is not available.
        """
        inference_config_path = self._download("configs/inference/inference-v3.yaml")
        print(inference_config_path)
        inference_config = OmegaConf.load(inference_config_path)

        ### >>> create validation pipeline >>> ###
        tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="models/StableDiffusion/tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="models/StableDiffusion/text_encoder")
        vae = AutoencoderKL.from_pretrained(model_path, subfolder="models/StableDiffusion/vae")

        unet_model_path = self._download("models/StableDiffusion/unet/diffusion_pytorch_model.bin")
        unet_config_path = self._download("models/StableDiffusion/unet/config.json")
        print(unet_model_path)

        unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_model_path=unet_model_path,
            unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
            config_path=unet_config_path,
        )

        # Raise rather than `assert` so the requirement is not stripped under `python -O`.
        if not is_xformers_available():
            raise RuntimeError("xformers is required but is not available")
        unet.enable_xformers_memory_efficient_attention()

        scheduler = DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs.DDIMScheduler))
        self.pipeline = AnimationPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler,
        ).to("cuda")

        # Load the motion module from the path hf_hub_download returns (the local
        # cache file), not a working-directory-relative path the download never fills.
        motion_module = self._download("models/Motion_Module/mm_sd_v15.ckpt")

        self.pipeline = load_weights(
            self.pipeline,
            # motion module
            motion_module_path=motion_module,
            motion_module_lora_configs=[],
            # image layers
            dreambooth_model_path="",
            lora_model_path="",
            lora_alpha=0.8,
        ).to("cuda")

    def _download(self, filename: str) -> str:
        """Return the local path of *filename* from ``REPO_ID``, downloading it if needed."""
        return hf_hub_download(repo_id=self.REPO_ID, filename=filename)

    @staticmethod
    def _make_grid(images: torch.Tensor, nrow: int = 8, padding: int = 2, pad_value: float = 0.0) -> torch.Tensor:
        """Tile an (N, C, H, W) batch into one (C, H', W') image.

        Mirrors ``torchvision.utils.make_grid`` with ``normalize=False``:
        single-channel input is repeated to 3 channels, a batch of one is
        returned unpadded, otherwise images are laid out ``nrow`` per row
        with ``padding`` pixels of ``pad_value`` around each cell.
        """
        if images.size(1) == 1:
            images = torch.cat((images, images, images), 1)
        if images.size(0) == 1:
            return images.squeeze(0)

        count, channels, height, width = images.shape
        cols = min(nrow, count)
        rows = (count + cols - 1) // cols
        cell_h, cell_w = height + padding, width + padding
        grid = images.new_full((channels, cell_h * rows + padding, cell_w * cols + padding), pad_value)
        for index in range(count):
            row, col = divmod(index, cols)
            top = row * cell_h + padding
            left = col * cell_w + padding
            grid[:, top:top + height, left:left + width] = images[index]
        return grid

    def __call__(self, prompt, negative_prompt, steps, guidance_scale, *,
                 width=256, height=256, video_length=5, n_rows=6, rescale=False) -> list:
        """Generate a video and return its frames as images; called once per request.

        NOTE(review): Hugging Face Inference Endpoints invoke custom handlers as
        ``handler(data)`` with a single dict — confirm how this handler is called.

        Args:
            prompt: prompt(s) forwarded to the pipeline.
            negative_prompt: negative prompt(s) forwarded to the pipeline.
            steps: number of inference steps.
            guidance_scale: guidance scale forwarded to the pipeline.
            width: output width in pixels.
            height: output height in pixels.
            video_length: number of frames to generate.
            n_rows: maximum number of videos per grid row when the pipeline
                returns a batch (``nrow`` of the grid layout).
            rescale: map values from [-1, 1] to [0, 1] before quantising.

        Returns:
            One ``uint8`` numpy array of shape (H, W, C) per frame; when the
            pipeline returns several videos, each frame is a padded grid of them.
        """
        videos = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            video_length=video_length,
        ).videos

        # (b, c, t, h, w) -> (t, b, c, h, w): iterate over frames, each a batch of images.
        frames = videos.detach().cpu().permute(2, 0, 1, 3, 4)
        outputs = []
        for frame in frames:
            image = self._make_grid(frame, nrow=n_rows).permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
            if rescale:
                image = (image + 1.0) / 2.0  # -1,1 -> 0,1
            # Clamp so out-of-range values saturate instead of wrapping around in uint8.
            outputs.append((image.clamp(0.0, 1.0) * 255).numpy().astype(np.uint8))
        return outputs
    

# Example instantiation, left disabled: constructing the handler downloads the
# model weights from the Hub and moves the pipeline to CUDA.


# new_handler = EndpointHandler()