Spaces:
Running
Running
import logging

import torchaudio

import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.utils

from .wav2vec2 import Wav2Vec2Model
from .whisper import WhisperLargeV3
class AudioEncoderModel():
    """Wrapper that loads an audio encoder backbone and manages its devices.

    Supports two backbones, selected by config["model_type"]:
    "wav2vec2" -> Wav2Vec2Model, "whisper3" -> WhisperLargeV3.
    The model is wrapped in a ModelPatcher so ComfyUI's model management
    can move it between the load and offload devices on demand.
    """

    def __init__(self, config):
        """Build the backbone from *config*.

        config: dict with a "model_type" key plus backbone kwargs.
                The caller's dict is not mutated.
        Raises ValueError for an unsupported model_type.
        """
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        # Copy first so popping "model_type" does not mutate the caller's dict.
        model_config = dict(config)
        model_type = model_config.pop("model_type")
        model_config.update({
            "dtype": self.dtype,
            "device": offload_device,
            "operations": comfy.ops.manual_cast
        })
        if model_type == "wav2vec2":
            self.model = Wav2Vec2Model(**model_config)
        elif model_type == "whisper3":
            self.model = WhisperLargeV3(**model_config)
        else:
            # Fail fast with a clear message instead of a confusing
            # AttributeError on self.model a few lines below.
            raise ValueError("unsupported audio encoder model_type: {}".format(model_type))
        self.model.eval()
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        # Both supported backbones expect 16 kHz input.
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        """Load a state dict into the backbone; returns (missing, unexpected) keys."""
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        """Return the backbone's state dict."""
        return self.model.state_dict()

    def encode_audio(self, audio, sample_rate):
        """Encode a waveform tensor into feature embeddings.

        audio: waveform tensor; indexed as audio.shape[2] below, so it is
               presumably (batch, channels, samples) — TODO confirm with callers.
        sample_rate: sample rate of *audio*; it is resampled to 16 kHz first.
        Returns a dict with the final-layer features ("encoded_audio"),
        all intermediate layer outputs ("encoded_audio_all_layers") and the
        resampled sample count ("audio_samples").
        """
        comfy.model_management.load_model_gpu(self.patcher)
        audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
        out, all_layers = self.model(audio.to(self.load_device))
        outputs = {}
        outputs["encoded_audio"] = out
        outputs["encoded_audio_all_layers"] = all_layers
        outputs["audio_samples"] = audio.shape[2]
        return outputs
def load_audio_encoder_from_sd(sd, prefix=""):
    """Detect the encoder architecture in a state dict and build an AudioEncoderModel.

    sd: checkpoint state dict; a "wav2vec2." key prefix is stripped first.
    prefix: unused, kept for interface compatibility.
    Raises RuntimeError when the checkpoint is not a recognized encoder.
    """
    sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
    if "encoder.layer_norm.bias" in sd:
        # wav2vec2 checkpoint: the layer-norm width identifies the variant.
        embed_dim = sd["encoder.layer_norm.bias"].shape[0]
        known_variants = {
            1024: {  # large
                "model_type": "wav2vec2",
                "embed_dim": 1024,
                "num_heads": 16,
                "num_layers": 24,
                "conv_norm": True,
                "conv_bias": True,
                "do_normalize": True,
                "do_stable_layer_norm": True
            },
            768: {  # base; chinese-wav2vec2-base ships with do_normalize False
                "model_type": "wav2vec2",
                "embed_dim": 768,
                "num_heads": 12,
                "num_layers": 12,
                "conv_norm": False,
                "conv_bias": False,
                "do_normalize": False,
                "do_stable_layer_norm": False
            },
        }
        config = known_variants.get(embed_dim)
        if config is None:
            raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
    elif "model.encoder.embed_positions.weight" in sd:
        # Whisper checkpoint: drop the "model." prefix before loading.
        sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
        config = {
            "model_type": "whisper3",
        }
    else:
        raise RuntimeError("ERROR: audio encoder not supported.")
    audio_encoder = AudioEncoderModel(config)
    missing, unexpected = audio_encoder.load_sd(sd)
    if missing:
        logging.warning("missing audio encoder: {}".format(missing))
    if unexpected:
        logging.warning("unexpected audio encoder: {}".format(unexpected))
    return audio_encoder