# videox_fun/models/wan_audio_encoder.py
# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/audio_encoder.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import librosa
import numpy as np
import torch
import torch.nn.functional as F
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from diffusers.configuration_utils import ConfigMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.modeling_utils import ModelMixin
def get_sample_indices(original_fps,
total_frames,
target_fps,
num_sample,
fixed_start=None):
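    """Pick `num_sample` frame indices from a clip of `total_frames` frames at
    `original_fps`, spaced for `target_fps` playback, starting at `fixed_start`
    when it is given and non-negative, otherwise at a random valid offset."""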
required_duration = num_sample / target_fps
required_origin_frames = int(np.ceil(required_duration * original_fps))
if required_duration > total_frames / original_fps:
raise ValueError("required_duration must be less than video length")
    if fixed_start is not None and fixed_start >= 0:
start_frame = fixed_start
else:
max_start = total_frames - required_origin_frames
if max_start < 0:
raise ValueError("video length is too short")
start_frame = np.random.randint(0, max_start + 1)
start_time = start_frame / original_fps
end_time = start_time + required_duration
time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
frame_indices = np.clip(frame_indices, 0, total_frames - 1)
return frame_indices
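

# Illustrative sketch (not part of the original module): a hypothetical helper
# showing how get_sample_indices lays a 16 fps sampling grid over a 30 fps
# clip; the numbers below are arbitrary examples.
def _example_sample_indices():
    # 6 s of 30 fps video, sampled at 16 fps for 81 frames (~5.06 s needed).
    idx = get_sample_indices(
        original_fps=30, total_frames=180, target_fps=16, num_sample=81,
        fixed_start=0)
    assert len(idx) == 81 and idx[0] == 0
    return idx
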
def linear_interpolation(features, input_fps, output_fps, output_len=None):
"""
features: shape=[1, T, 512]
input_fps: fps for audio, f_a
output_fps: fps for video, f_m
output_len: video length
"""
features = features.transpose(1, 2) # [1, 512, T]
seq_len = features.shape[2] / float(input_fps) # T/f_a
if output_len is None:
output_len = int(seq_len * output_fps) # f_m*T/f_a
output_features = F.interpolate(
features, size=output_len, align_corners=True,
mode='linear') # [1, 512, output_len]
return output_features.transpose(1, 2) # [1, output_len, 512]
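

# Illustrative sketch (not part of the original module): a hypothetical helper
# resampling a feature track from the 50 Hz wav2vec2 rate to 30 fps. The
# tensor is random; real features come from WanAudioEncoder below.
def _example_linear_interpolation():
    feats_50hz = torch.randn(1, 100, 768)  # 2 s of features at 50 Hz
    feats_30fps = linear_interpolation(feats_50hz, input_fps=50, output_fps=30)
    assert feats_30fps.shape == (1, 60, 768)  # 2 s at 30 fps
    return feats_30fps
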
class WanAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin):
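    """Wav2Vec2-based audio encoder (modified from Wan2.2 S2V): turns raw
    audio into per-video-frame embeddings for audio-driven video generation."""
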
def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device='cpu'):
        super().__init__()
# load pretrained model
self.processor = Wav2Vec2Processor.from_pretrained(pretrained_model_path)
self.model = Wav2Vec2ForCTC.from_pretrained(pretrained_model_path)
self.model = self.model.to(device)
self.video_rate = 30
def extract_audio_feat(self,
audio_path,
return_all_layers=False,
dtype=torch.float32):
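        """Load `audio_path` at 16 kHz, encode it with wav2vec2, and resample
        the hidden states from 50 Hz to `self.video_rate` fps. Returns
        [1, T, C], or [num_layers, T, C] when `return_all_layers` is True."""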
audio_input, sample_rate = librosa.load(audio_path, sr=16000)
input_values = self.processor(
audio_input, sampling_rate=sample_rate, return_tensors="pt"
).input_values
        # Run the wav2vec2 model and keep the hidden states of every layer.
res = self.model(
input_values.to(self.model.device), output_hidden_states=True)
if return_all_layers:
feat = torch.cat(res.hidden_states)
else:
feat = res.hidden_states[-1]
feat = linear_interpolation(
feat, input_fps=50, output_fps=self.video_rate)
z = feat.to(dtype) # Encoding for the motion
return z
def extract_audio_feat_without_file_load(self, audio_input, sample_rate, return_all_layers=False, dtype=torch.float32):
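        """Same as `extract_audio_feat`, but takes an already loaded waveform
        (`audio_input`) and its `sample_rate` instead of a file path."""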
input_values = self.processor(
audio_input, sampling_rate=sample_rate, return_tensors="pt"
).input_values
        # Run the wav2vec2 model and keep the hidden states of every layer.
res = self.model(
input_values.to(self.model.device), output_hidden_states=True)
if return_all_layers:
feat = torch.cat(res.hidden_states)
else:
feat = res.hidden_states[-1]
feat = linear_interpolation(
feat, input_fps=50, output_fps=self.video_rate)
z = feat.to(dtype) # Encoding for the motion
return z
def get_audio_embed_bucket(self,
audio_embed,
stride=2,
batch_frames=12,
m=2):
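        """Group per-frame audio embeddings into buckets of `batch_frames`
        positions sampled every `stride` frames; each position gathers a window
        of `2 * m + 1` frames around it (zero-padded past the end). Returns
        (buckets, min_batch_num)."""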
num_layers, audio_frame_num, audio_dim = audio_embed.shape
        return_all_layers = num_layers > 1
min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
bucket_num = min_batch_num * batch_frames
batch_idx = [stride * i for i in range(bucket_num)]
batch_audio_eb = []
for bi in batch_idx:
if bi < audio_frame_num:
audio_sample_stride = 2
chosen_idx = list(
range(bi - m * audio_sample_stride,
bi + (m + 1) * audio_sample_stride,
audio_sample_stride))
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
chosen_idx = [
audio_frame_num - 1 if c >= audio_frame_num else c
for c in chosen_idx
]
if return_all_layers:
frame_audio_embed = audio_embed[:, chosen_idx].flatten(
start_dim=-2, end_dim=-1)
else:
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
else:
                if return_all_layers:
                    frame_audio_embed = torch.zeros(
                        [num_layers, audio_dim * (2 * m + 1)],
                        device=audio_embed.device)
                else:
                    frame_audio_embed = torch.zeros(
                        [audio_dim * (2 * m + 1)],
                        device=audio_embed.device)
batch_audio_eb.append(frame_audio_embed)
        batch_audio_eb = torch.stack(batch_audio_eb, dim=0)
return batch_audio_eb, min_batch_num
def get_audio_embed_bucket_fps(self,
audio_embed,
fps=16,
batch_frames=81,
m=0):
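        """Like `get_audio_embed_bucket`, but aligns the 30 Hz audio embeddings
        with video rendered at `fps`: bucket positions come from
        `get_sample_indices`, each gathering a window of `2 * m + 1` frames
        (zero-padded past the end). Returns (buckets, min_batch_num)."""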
num_layers, audio_frame_num, audio_dim = audio_embed.shape
        return_all_layers = num_layers > 1
scale = self.video_rate / fps
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
bucket_num = min_batch_num * batch_frames
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps *
self.video_rate) - audio_frame_num
batch_idx = get_sample_indices(
original_fps=self.video_rate,
total_frames=audio_frame_num + padd_audio_num,
target_fps=fps,
num_sample=bucket_num,
fixed_start=0)
batch_audio_eb = []
audio_sample_stride = int(self.video_rate / fps)
for bi in batch_idx:
if bi < audio_frame_num:
chosen_idx = list(
range(bi - m * audio_sample_stride,
bi + (m + 1) * audio_sample_stride,
audio_sample_stride))
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
chosen_idx = [
audio_frame_num - 1 if c >= audio_frame_num else c
for c in chosen_idx
]
if return_all_layers:
frame_audio_embed = audio_embed[:, chosen_idx].flatten(
start_dim=-2, end_dim=-1)
else:
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
else:
                if return_all_layers:
                    frame_audio_embed = torch.zeros(
                        [num_layers, audio_dim * (2 * m + 1)],
                        device=audio_embed.device)
                else:
                    frame_audio_embed = torch.zeros(
                        [audio_dim * (2 * m + 1)],
                        device=audio_embed.device)
batch_audio_eb.append(frame_audio_embed)
        batch_audio_eb = torch.stack(batch_audio_eb, dim=0)
return batch_audio_eb, min_batch_num
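

# Illustrative sketch (not part of the original module): end-to-end usage.
# "sample.wav" is a placeholder path; fps=16, batch_frames=81 and m=0 mirror
# the defaults of get_audio_embed_bucket_fps; loading
# facebook/wav2vec2-base-960h needs network access (or a local cache).
if __name__ == "__main__":
    encoder = WanAudioEncoder(device="cpu")
    # Per-frame features at 30 Hz: [num_layers, T, C] with return_all_layers.
    audio_feat = encoder.extract_audio_feat("sample.wav", return_all_layers=True)
    buckets, num_batches = encoder.get_audio_embed_bucket_fps(
        audio_feat, fps=16, batch_frames=81, m=0)
    # buckets: [num_batches * 81, num_layers, C * (2 * m + 1)]
    print(buckets.shape, num_batches)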