import torch

from cosyvoice.transformer.activation import Swish
from cosyvoice.transformer.subsampling import (
    LinearNoSubsampling,
    LegacyLinearNoSubsampling,
    EmbedinigNoSubsampling,  # spelling matches the class name defined upstream
    Conv1dSubsampling2,
    Conv2dSubsampling4,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
)
from cosyvoice.transformer.embedding import (
    PositionalEncoding,
    RelPositionalEncoding,
    EspnetRelPositionalEncoding,
    WhisperPositionalEncoding,
    LearnablePositionalEncoding,
    NoPositionalEncoding,
)
from cosyvoice.transformer.attention import (
    MultiHeadedAttention,
    RelPositionMultiHeadedAttention,
)
from cosyvoice.llm.llm import TransformerLM, Qwen2LM
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
from cosyvoice.hifigan.generator import HiFTGenerator
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model


COSYVOICE_ACTIVATION_CLASSES = {
    "hardtanh": torch.nn.Hardtanh,
    "tanh": torch.nn.Tanh,
    "relu": torch.nn.ReLU,
    "selu": torch.nn.SELU,
    # use the built-in SiLU when this PyTorch version provides it,
    # otherwise fall back to the local Swish implementation
    "swish": getattr(torch.nn, "SiLU", Swish),
    "gelu": torch.nn.GELU,
}
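# Illustrative lookup (example, not part of the original module): encoder
# configs name the activation by key and instantiate it from this registry,
# e.g. COSYVOICE_ACTIVATION_CLASSES["swish"]() yields a torch.nn.SiLU()
# instance on any PyTorch version that defines SiLU.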


COSYVOICE_SUBSAMPLE_CLASSES = {
    "linear": LinearNoSubsampling,
    "linear_legacy": LegacyLinearNoSubsampling,
    "embed": EmbedinigNoSubsampling,
    "conv1d2": Conv1dSubsampling2,
    "conv2d": Conv2dSubsampling4,
    "conv2d6": Conv2dSubsampling6,
    "conv2d8": Conv2dSubsampling8,
    "paraformer_dummy": torch.nn.Identity,
}
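# Note (assumed, following WeNet-style naming): the numeric suffix of each
# Conv* class is its time-subsampling factor, so "conv2d" reduces the frame
# rate by 4 and "conv2d8" by 8, while the linear/embed variants keep the
# input frame rate unchanged.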


COSYVOICE_EMB_CLASSES = {
    "embed": PositionalEncoding,
    "abs_pos": PositionalEncoding,
    "rel_pos": RelPositionalEncoding,
    "rel_pos_espnet": EspnetRelPositionalEncoding,
    "no_pos": NoPositionalEncoding,
    "abs_pos_whisper": WhisperPositionalEncoding,
    "embed_learnable_pe": LearnablePositionalEncoding,
}


COSYVOICE_ATTENTION_CLASSES = {
    "selfattn": MultiHeadedAttention,
    "rel_selfattn": RelPositionMultiHeadedAttention,
}
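# Illustrative wiring (assumed, based on typical WeNet-style YAML configs):
# an encoder section resolves its string options against these registries
# when the model is built, e.g.
#
#     input_layer: 'linear'                     -> COSYVOICE_SUBSAMPLE_CLASSES
#     pos_enc_layer_type: 'rel_pos_espnet'      -> COSYVOICE_EMB_CLASSES
#     selfattention_layer_type: 'rel_selfattn'  -> COSYVOICE_ATTENTION_CLASSES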


def get_model_type(configs):
    """Return the model wrapper class matching the instantiated components.

    CosyVoice (v1) pairs TransformerLM with MaskedDiffWithXvec, while
    CosyVoice 2 pairs Qwen2LM with CausalMaskedDiffWithXvec; both use
    HiFTGenerator as the vocoder.
    """
    if isinstance(configs['llm'], TransformerLM) and \
            isinstance(configs['flow'], MaskedDiffWithXvec) and \
            isinstance(configs['hift'], HiFTGenerator):
        return CosyVoiceModel
    if isinstance(configs['llm'], Qwen2LM) and \
            isinstance(configs['flow'], CausalMaskedDiffWithXvec) and \
            isinstance(configs['hift'], HiFTGenerator):
        return CosyVoice2Model
    raise TypeError('No valid model type found!')
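

# Illustrative call site (assumed; the real constructors may take additional
# arguments such as an fp16 flag):
#
#     model_cls = get_model_type(configs)  # CosyVoiceModel or CosyVoice2Model
#     model = model_cls(configs['llm'], configs['flow'], configs['hift'])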