Pie31415 committed on
Commit 43b5157
1 Parent(s): e2f5469
Files changed (1)
  1. text_to_animation/model_flax.py +191 -0
text_to_animation/model_flax.py ADDED
@@ -0,0 +1,191 @@
+import torch
+from enum import Enum
+import gc
+import numpy as np
+import jax.numpy as jnp
+import jax
+
+from PIL import Image
+from typing import List
+
+from flax.training.common_utils import shard
+from flax.jax_utils import replicate
+from flax import jax_utils
+import einops
+
+from transformers import CLIPTokenizer, CLIPFeatureExtractor, FlaxCLIPTextModel
+from diffusers import (
+    FlaxDDIMScheduler,
+    FlaxAutoencoderKL,
+    FlaxUNet2DConditionModel as VanillaFlaxUNet2DConditionModel,
+)
+from text_to_animation.models.unet_2d_condition_flax import FlaxUNet2DConditionModel
+from diffusers import FlaxControlNetModel
+
+from text_to_animation.pipelines.text_to_video_pipeline_flax import (
+    FlaxTextToVideoPipeline,
+)
+
+import utils.utils as utils
+import utils.gradio_utils as gradio_utils
+import os
+
+on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
+
+unshard = lambda x: einops.rearrange(x, "d b ... -> (d b) ...")
+
+
+class ModelType(Enum):
+    Text2Video = 1
+    ControlNetPose = 2
+    StableDiffusion = 3
+
+
+def replicate_devices(array):
+    # Add a leading device axis and repeat the array once per local device (for pmap).
+    return jnp.expand_dims(array, 0).repeat(jax.device_count(), 0)
+
+
+class ControlAnimationModel:
+    def __init__(self, dtype, **kwargs):
+        self.dtype = dtype
+        self.rng = jax.random.PRNGKey(0)
+        self.pipe = None
+        self.model_type = None
+
+        self.states = {}
+        self.model_name = ""
+
+    def set_model(
+        self,
+        model_id: str,
+        **kwargs,
+    ):
+        if hasattr(self, "pipe") and self.pipe is not None:
+            del self.pipe
+            self.pipe = None
+        gc.collect()
+
+        # OpenPose ControlNet, converted from its PyTorch checkpoint.
+        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
+            "fusing/stable-diffusion-v1-5-controlnet-openpose",
+            from_pt=True,
+            dtype=jnp.float16,
+        )
+
+        scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
+            model_id, subfolder="scheduler", from_pt=True
+        )
+        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+        feature_extractor = CLIPFeatureExtractor.from_pretrained(
+            model_id, subfolder="feature_extractor"
+        )
+        unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
+        )
+        unet_vanilla = VanillaFlaxUNet2DConditionModel.from_config(
+            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
+        )
+        vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+            model_id, subfolder="vae", from_pt=True, dtype=self.dtype
+        )
+        text_encoder = FlaxCLIPTextModel.from_pretrained(
+            model_id, subfolder="text_encoder", from_pt=True, dtype=self.dtype
+        )
+        self.pipe = FlaxTextToVideoPipeline(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            unet_vanilla=unet_vanilla,
+            controlnet=controlnet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=feature_extractor,
+        )
+        self.params = {
+            "unet": unet_params,
+            "vae": vae_params,
+            "scheduler": scheduler_state,
+            "controlnet": controlnet_params,
+            "text_encoder": text_encoder.params,
+        }
+        # Replicate all parameters across local devices for pmapped inference.
+        self.p_params = jax_utils.replicate(self.params)
+        self.model_name = model_id
+
+    def generate_initial_frames(
+        self,
+        prompt: str,
+        video_path: str,
+        n_prompt: str = "",
+        num_imgs: int = 4,
+        resolution: int = 512,
+        model_id: str = "runwayml/stable-diffusion-v1-5",
+    ) -> List[Image.Image]:
+        self.set_model(model_id=model_id)
+
+        video_path = gradio_utils.motion_to_video_path(video_path)
+
+        added_prompt = "high quality, best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth"
+        prompts = added_prompt + ", " + prompt
+
+        added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly"
+        negative_prompts = added_n_prompt + ", " + n_prompt
+
+        video, fps = utils.prepare_video(
+            video_path, resolution, None, self.dtype, False, output_fps=4
+        )
+        control = utils.pre_process_pose(video, apply_pose_detect=False)
+
+        # One PRNG key per candidate frame so each image is seeded independently.
+        seeds = [seed for seed in jax.random.randint(self.rng, [num_imgs], 0, 65536)]
+        prngs = [jax.random.PRNGKey(seed) for seed in seeds]
+        print(seeds)
+        images = self.pipe.generate_starting_frames(
+            params=self.p_params,
+            prngs=prngs,
+            controlnet_image=control,
+            prompt=prompts,
+            neg_prompt=negative_prompts,
+        )
+
+        images = [np.array(images[i]) for i in range(images.shape[0])]
+
+        return images
+
+    def generate_video_from_frame(self, controlnet_video, prompt, seed, neg_prompt=""):
+        # generate a video using the seed provided
+        prng_seed = jax.random.PRNGKey(seed)
+        len_vid = controlnet_video.shape[0]
+        # print(f"Generating video from prompt {'<aardman> style '+ prompt}, with {controlnet_video.shape[0]} frames and prng seed {seed}")
+        added_prompt = "high quality, best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth"
+        prompts = added_prompt + ", " + prompt
+
+        added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly"
+        negative_prompts = added_n_prompt + ", " + neg_prompt
+
+        # prompt_ids = self.pipe.prepare_text_inputs(["aardman style "+ prompt]*len_vid)
+        # n_prompt_ids = self.pipe.prepare_text_inputs([neg_prompt]*len_vid)
+
+        prompt_ids = self.pipe.prepare_text_inputs([prompts] * len_vid)
+        n_prompt_ids = self.pipe.prepare_text_inputs([negative_prompts] * len_vid)
+        prng = replicate_devices(
+            prng_seed
+        )  # jax.random.split(prng, jax.device_count())
+        image = replicate_devices(controlnet_video)
+        prompt_ids = replicate_devices(prompt_ids)
+        n_prompt_ids = replicate_devices(n_prompt_ids)
+        motion_field_strength_x = replicate_devices(jnp.array(3))
+        motion_field_strength_y = replicate_devices(jnp.array(4))
+        smooth_bg_strength = replicate_devices(jnp.array(0.8))
+        vid = (
+            self.pipe(
+                image=image,
+                prompt_ids=prompt_ids,
+                neg_prompt_ids=n_prompt_ids,
+                params=self.p_params,
+                prng_seed=prng,
+                jit=True,
+                smooth_bg_strength=smooth_bg_strength,
+                motion_field_strength_x=motion_field_strength_x,
+                motion_field_strength_y=motion_field_strength_y,
+            ).images
+        )[0]
+        return utils.create_gif(np.array(vid), 4, path=None, watermark=None)
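
For reference, a minimal usage sketch of the ControlAnimationModel added in this commit (not part of the diff; the prompt and the motion name passed through gradio_utils.motion_to_video_path are hypothetical placeholders, and a JAX-visible accelerator plus the repo's utils/ modules on PYTHONPATH are assumed):

import jax.numpy as jnp
from text_to_animation.model_flax import ControlAnimationModel

# Half-precision weights; generate_initial_frames calls set_model with its
# default model_id ("runwayml/stable-diffusion-v1-5").
model = ControlAnimationModel(dtype=jnp.float16)

# Sample four candidate first frames conditioned on a pose video.
# "Motion 1" is a hypothetical motion selection understood by gradio_utils.
frames = model.generate_initial_frames(
    prompt="a clay figure waving at the camera",  # hypothetical prompt
    video_path="Motion 1",
    num_imgs=4,
)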