import torch
from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
from ..models.sd_unet import PushBlock, PopBlock
from ..controlnets import MultiControlNetManager
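
# `lets_dance` and `lets_dance_xl` run a single UNet forward pass of an SD /
# SDXL model over a batch of video frames, with optional hooks for ControlNet
# residuals, AnimateDiff motion modules, IP-Adapter kwargs and tiled attention.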


def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample=None,
    timestep=None,
    encoder_hidden_states=None,
    ipadapter_kwargs_list={},
    controlnet_frames=None,
    unet_batch_size=1,
    controlnet_batch_size=1,
    cross_frame_attention=False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device="cuda",
    vram_limit_level=0,
):
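    """
    Run one UNet forward pass over `sample` ([frames, C, H, W]).

    `unet_batch_size` / `controlnet_batch_size` cap how many frames pass
    through the UNet blocks / ControlNet at once, and `vram_limit_level >= 1`
    offloads skip-connection residuals to CPU while they are not in use.
    """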
    # 1. ControlNet
    # This part is repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
    # It stays here as a separate step because more ControlNet features are planned.
    controlnet_insert_block_id = 30
    if controlnet is not None and controlnet_frames is not None:
        res_stacks = []
        # Process the ControlNet frames in batches of controlnet_batch_size.
        for batch_id in range(0, sample.shape[0], controlnet_batch_size):
            batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
            res_stack = controlnet(
                sample[batch_id: batch_id_],
                timestep,
                encoder_hidden_states[batch_id: batch_id_],
                controlnet_frames[:, batch_id: batch_id_],
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            if vram_limit_level >= 1:
                res_stack = [res.cpu() for res in res_stack]
            res_stacks.append(res_stack)
        # Concatenate the per-batch residuals along the frame dimension.
        additional_res_stack = []
        for i in range(len(res_stacks[0])):
            res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
            additional_res_stack.append(res)
    else:
        additional_res_stack = None
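    # additional_res_stack now holds the ControlNet residuals for all frames;
    # they are injected into the UNet at controlnet_insert_block_id below.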

    # 2. time
    time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)

    # 3. pre-process
    height, width = sample.shape[2], sample.shape[3]
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states.cpu() if vram_limit_level >= 1 else hidden_states]
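    # res_stack is the UNet skip-connection stack: PushBlock appends to it in
    # the encoder half and PopBlock consumes it in the decoder half. With
    # vram_limit_level >= 1 the entries are parked on CPU except while in use.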

    # 4. blocks
    for block_id, block in enumerate(unet.blocks):
        # 4.1 UNet
        if isinstance(block, PushBlock):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
            if vram_limit_level >= 1:
                res_stack[-1] = res_stack[-1].cpu()
        elif isinstance(block, PopBlock):
            if vram_limit_level >= 1:
                res_stack[-1] = res_stack[-1].to(device)
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
        else:
            hidden_states_input = hidden_states
            hidden_states_output = []
            for batch_id in range(0, sample.shape[0], unet_batch_size):
                batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
                hidden_states, _, _, _ = block(
                    hidden_states_input[batch_id: batch_id_],
                    time_emb,
                    text_emb[batch_id: batch_id_],
                    res_stack,
                    cross_frame_attention=cross_frame_attention,
                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )
                hidden_states_output.append(hidden_states)
            hidden_states = torch.concat(hidden_states_output, dim=0)
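        # time_emb is shared by every frame in a slice, while the hidden
        # states and text_emb are sliced per frame; the concat above restores
        # the full frame batch before the motion modules run.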
        # 4.2 AnimateDiff
        if motion_modules is not None:
            if block_id in motion_modules.call_block_id:
                motion_module_id = motion_modules.call_block_id[block_id]
                hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                    hidden_states, time_emb, text_emb, res_stack,
                    batch_size=1
                )
        # 4.3 ControlNet
        # Add the final ControlNet residual to the current hidden states and
        # merge the remaining residuals into the skip-connection stack.
        if block_id == controlnet_insert_block_id and additional_res_stack is not None:
            hidden_states += additional_res_stack.pop().to(device)
            if vram_limit_level >= 1:
                res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
            else:
                res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]

    # 5. output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)
    return hidden_states
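
# A minimal call sketch (assumes `unet`, `sample`, `timestep` and
# `encoder_hidden_states` have already been prepared by the caller, e.g. by a
# surrounding pipeline; shapes follow the slicing logic above):
#
#     noise_pred = lets_dance(
#         unet,
#         sample=sample,                                # [frames, C, H, W] latents
#         timestep=timestep,                            # scalar tensor
#         encoder_hidden_states=encoder_hidden_states,  # per-frame text embeddings
#         unet_batch_size=8,
#         vram_limit_level=1,                           # offload residuals to CPU
#     )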


def lets_dance_xl(
    unet: SDXLUNet,
    motion_modules: SDXLMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample=None,
    add_time_id=None,
    add_text_embeds=None,
    timestep=None,
    encoder_hidden_states=None,
    ipadapter_kwargs_list={},
    controlnet_frames=None,
    unet_batch_size=1,
    controlnet_batch_size=1,
    cross_frame_attention=False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device="cuda",
    vram_limit_level=0,
):
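    """
    SDXL variant of `lets_dance`, adding the micro-conditioning inputs
    `add_time_id` and `add_text_embeds`.

    Note: `controlnet`, `controlnet_frames`, `unet_batch_size`,
    `controlnet_batch_size`, `cross_frame_attention`, `device` and
    `vram_limit_level` are accepted for signature parity with `lets_dance`
    but are not used in the body below.
    """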
    # 2. time
    t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    t_emb = unet.time_embedding(t_emb)
    time_embeds = unet.add_time_proj(add_time_id)
    time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
    add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
    add_embeds = add_embeds.to(sample.dtype)
    add_embeds = unet.add_time_embedding(add_embeds)
    time_emb = t_emb + add_embeds
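    # add_time_id carries SDXL's size/crop micro-conditioning and
    # add_text_embeds the pooled prompt embedding; both are folded into the
    # timestep embedding, following standard SDXL conditioning.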

    # 3. pre-process
    height, width = sample.shape[2], sample.shape[3]
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states]

    # 4. blocks
    for block_id, block in enumerate(unet.blocks):
        # 4.1 UNet
        hidden_states, time_emb, text_emb, res_stack = block(
            hidden_states, time_emb, text_emb, res_stack,
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
            ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
        )
        # 4.2 AnimateDiff
        if motion_modules is not None:
            if block_id in motion_modules.call_block_id:
                motion_module_id = motion_modules.call_block_id[block_id]
                hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                    hidden_states, time_emb, text_emb, res_stack,
                    batch_size=1
                )

    # 5. output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)
    return hidden_states
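
# A minimal call sketch for the SDXL path (assumes the conditioning tensors
# have been prepared by the caller; `add_time_id` encodes the original size,
# crop coordinates and target size, and `add_text_embeds` is the pooled
# prompt embedding, as in standard SDXL):
#
#     noise_pred = lets_dance_xl(
#         unet,
#         sample=sample,
#         add_time_id=add_time_id,
#         add_text_embeds=add_text_embeds,
#         timestep=timestep,
#         encoder_hidden_states=encoder_hidden_states,
#     )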