ekwek committed on
Commit 63d4ab6 · verified · 1 Parent(s): 7497466

Upload 10 files
soprano/__init__.py ADDED
@@ -0,0 +1 @@
+ from .tts import SopranoTTS
soprano/backends/base.py ADDED
@@ -0,0 +1,20 @@
+ class BaseModel:
+     def infer(self,
+               prompts,
+               top_p=0.95,
+               temperature=0.3,
+               repetition_penalty=1.2):
+         '''
+         Takes a list of prompts and returns a list of dicts with 'finish_reason' and 'hidden_state' keys, one per prompt.
+         '''
+         pass
+
+     def stream_infer(self,
+                      prompt,
+                      top_p=0.95,
+                      temperature=0.3,
+                      repetition_penalty=1.2):
+         '''
+         Takes a single prompt and returns an iterator of dicts with 'finish_reason' and 'hidden_state' keys, one per generated token.
+         '''
+         pass
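For reference, a minimal sketch (not part of this upload) of the contract the concrete backends below satisfy: infer() returns one dict per prompt and stream_infer() yields one dict per generated token, each with 'finish_reason' and 'hidden_state' keys. The DummyBackend class and its random tensors are purely illustrative; the 512-dimensional hidden size is taken from the decoder's input channels.

import torch
from soprano.backends.base import BaseModel

class DummyBackend(BaseModel):  # hypothetical, for illustration only
    def infer(self, prompts, top_p=0.95, temperature=0.3, repetition_penalty=1.2):
        # One dict per prompt: generated hidden states of shape (num_tokens, 512).
        return [{'finish_reason': 'stop', 'hidden_state': torch.randn(8, 512)}
                for _ in prompts]

    def stream_infer(self, prompt, top_p=0.95, temperature=0.3, repetition_penalty=1.2):
        # Yield one dict per generated token; a non-None finish_reason marks the end.
        for _ in range(8):
            yield {'finish_reason': None, 'hidden_state': torch.randn(1, 512)}
        yield {'finish_reason': 'stop', 'hidden_state': None}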
soprano/backends/lmdeploy.py ADDED
@@ -0,0 +1,54 @@
+ import torch
+ from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
+ from .base import BaseModel
+
+
+ class LMDeployModel(BaseModel):
+     def __init__(self,
+                  device='cuda',
+                  cache_size_mb=100,
+                  **kwargs):
+         assert device == 'cuda', "lmdeploy only supports CUDA devices; change the device or use a different backend."
+         cache_size_ratio = cache_size_mb * 1024**2 / torch.cuda.get_device_properties('cuda').total_memory
+         backend_config = TurbomindEngineConfig(cache_max_entry_count=cache_size_ratio)
+         self.pipeline = pipeline('ekwek/Soprano-80M',
+                                  log_level='ERROR',
+                                  backend_config=backend_config)
+
+     def infer(self,
+               prompts,
+               top_p=0.95,
+               temperature=0.3,
+               repetition_penalty=1.2):
+         gen_config = GenerationConfig(output_last_hidden_state='generation',
+                                       do_sample=True,
+                                       top_p=top_p,
+                                       temperature=temperature,
+                                       repetition_penalty=repetition_penalty,
+                                       max_new_tokens=512)
+         responses = self.pipeline(prompts, gen_config=gen_config)
+         res = []
+         for response in responses:
+             res.append({
+                 'finish_reason': response.finish_reason,
+                 'hidden_state': response.last_hidden_state
+             })
+         return res
+
+     def stream_infer(self,
+                      prompt,
+                      top_p=0.95,
+                      temperature=0.3,
+                      repetition_penalty=1.2):
+         gen_config = GenerationConfig(output_last_hidden_state='generation',
+                                       do_sample=True,
+                                       top_p=top_p,
+                                       temperature=temperature,
+                                       repetition_penalty=repetition_penalty,
+                                       max_new_tokens=512)
+         responses = self.pipeline.stream_infer([prompt], gen_config=gen_config)
+         for response in responses:
+             yield {
+                 'finish_reason': response.finish_reason,
+                 'hidden_state': response.last_hidden_state
+             }
soprano/backends/transformers.py ADDED
@@ -0,0 +1,68 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from .base import BaseModel
+
+
+ class TransformersModel(BaseModel):
+     def __init__(self,
+                  device='cuda',
+                  **kwargs):
+         self.device = device
+
+         self.model = AutoModelForCausalLM.from_pretrained(
+             'ekwek/Soprano-80M',
+             torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
+             device_map=device
+         )
+         self.tokenizer = AutoTokenizer.from_pretrained('ekwek/Soprano-80M')
+         self.model.eval()
+
+     def infer(self,
+               prompts,
+               top_p=0.95,
+               temperature=0.3,
+               repetition_penalty=1.2):
+         inputs = self.tokenizer(
+             prompts,
+             return_tensors='pt',
+             padding=True,
+             truncation=True,
+             max_length=512,
+         ).to(self.device)
+
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 input_ids=inputs['input_ids'],
+                 attention_mask=inputs['attention_mask'],
+                 max_new_tokens=512,
+                 do_sample=True,
+                 top_p=top_p,
+                 temperature=temperature,
+                 repetition_penalty=repetition_penalty,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 return_dict_in_generate=True,
+                 output_hidden_states=True,
+             )
+         res = []
+         eos_token_id = self.model.config.eos_token_id
+         for i in range(len(prompts)):
+             seq = outputs.sequences[i]
+             hidden_states = []
+             num_output_tokens = len(outputs.hidden_states)  # one entry per generated token
+             for j in range(num_output_tokens):
+                 token = seq[j + seq.size(0) - num_output_tokens]  # j-th generated token (sequences include the prompt)
+                 if token != eos_token_id: hidden_states.append(outputs.hidden_states[j][-1][i, -1, :])  # last layer, batch i, last position
+             last_hidden_state = torch.stack(hidden_states).squeeze()
+             finish_reason = 'stop' if seq[-1].item() == eos_token_id else 'length'
+             res.append({
+                 'finish_reason': finish_reason,
+                 'hidden_state': last_hidden_state
+             })
+         return res
+
+     def stream_infer(self,
+                      prompt,
+                      top_p=0.95,
+                      temperature=0.3,
+                      repetition_penalty=1.2):
+         raise NotImplementedError("The transformers backend does not currently support streaming; consider using the lmdeploy backend instead.")
soprano/tts.py ADDED
@@ -0,0 +1,172 @@
+ from .vocos.decoder import SopranoDecoder
+ import torch
+ import re
+ from unidecode import unidecode
+ from scipy.io import wavfile
+ from huggingface_hub import hf_hub_download
+ import os
+ import time
+
+
+ class SopranoTTS:
+     def __init__(self,
+                  backend='auto',
+                  device='cuda',
+                  cache_size_mb=10,
+                  decoder_batch_size=1):
+         RECOGNIZED_DEVICES = ['cuda']
+         RECOGNIZED_BACKENDS = ['auto', 'lmdeploy', 'transformers']
+         assert device in RECOGNIZED_DEVICES, f"unrecognized device {device}, device must be in {RECOGNIZED_DEVICES}"
+         if backend == 'auto':
+             if device == 'cpu':
+                 backend = 'transformers'
+             else:
+                 try:
+                     import lmdeploy  # availability check only
+                     backend = 'lmdeploy'
+                 except ImportError:
+                     backend = 'transformers'
+         print(f"Using backend {backend}.")
+         assert backend in RECOGNIZED_BACKENDS, f"unrecognized backend {backend}, backend must be in {RECOGNIZED_BACKENDS}"
+
+         if backend == 'lmdeploy':
+             from .backends.lmdeploy import LMDeployModel
+             self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb)
+         elif backend == 'transformers':
+             from .backends.transformers import TransformersModel
+             self.pipeline = TransformersModel(device=device)
+
+         self.decoder = SopranoDecoder().cuda()
+         decoder_path = hf_hub_download(repo_id='ekwek/Soprano-80M', filename='decoder.pth')
+         self.decoder.load_state_dict(torch.load(decoder_path))
+         self.decoder_batch_size = decoder_batch_size
+         self.RECEPTIVE_FIELD = 4  # Decoder receptive field
+         self.TOKEN_SIZE = 2048  # Number of samples per audio token
+
+         self.infer("Hello world!")  # warmup
+
+     def _preprocess_text(self, texts):
+         '''
+         Splits each text into sentences, removes unsupported characters, and wraps each sentence in the prompt format tagged with its (text_idx, sentence_idx).
+         '''
+         res = []
+         for text_idx, text in enumerate(texts):
+             text = text.strip()
+             sentences = re.split(r"(?<=[.!?])\s+", text)
+             processed_sentences = []
+             for sentence_idx, sentence in enumerate(sentences):
+                 old_len = len(sentence)
+                 new_sentence = re.sub(r"[^A-Za-z !\$%&'*+,-./0123456789<>?_]", "", sentence)
+                 new_sentence = re.sub(r"[<>/_+]", "", new_sentence)
+                 new_sentence = re.sub(r"\.\.[^\.]", ".", new_sentence)
+                 new_len = len(new_sentence)
+                 if old_len != new_len:
+                     print(f"Warning: unsupported characters found in sentence: {sentence}\n\tThese characters have been removed.")
+                 new_sentence = unidecode(new_sentence.strip())
+                 processed_sentences.append((f'[STOP][TEXT]{new_sentence}[START]', text_idx, sentence_idx))
+             res.extend(processed_sentences)
+         return res
+
+     def infer(self,
+               text,
+               out_path=None,
+               top_p=0.95,
+               temperature=0.3,
+               repetition_penalty=1.2):
+         results = self.infer_batch([text],
+                                    top_p=top_p,
+                                    temperature=temperature,
+                                    repetition_penalty=repetition_penalty,
+                                    out_dir=None)[0]
+         if out_path:
+             wavfile.write(out_path, 32000, results.cpu().numpy())
+         return results
+
+     def infer_batch(self,
+                     texts,
+                     out_dir=None,
+                     top_p=0.95,
+                     temperature=0.3,
+                     repetition_penalty=1.2):
+         sentence_data = self._preprocess_text(texts)
+         prompts = list(map(lambda x: x[0], sentence_data))
+         responses = self.pipeline.infer(prompts,
+                                         top_p=top_p,
+                                         temperature=temperature,
+                                         repetition_penalty=repetition_penalty)
+         hidden_states = []
+         for response in responses:
+             if response['finish_reason'] != 'stop':
+                 print("Warning: a sentence did not complete generation (hit the token limit), likely due to hallucination.")
+             hidden_state = response['hidden_state']
+             hidden_states.append(hidden_state)
+         combined = list(zip(hidden_states, sentence_data))
+         combined.sort(key=lambda x: -x[0].size(0))  # longest sentences first so per-batch left-padding is minimal
+         hidden_states, sentence_data = zip(*combined)
+
+         num_texts = len(texts)
+         audio_concat = [[] for _ in range(num_texts)]
+         for sentence in sentence_data:
+             audio_concat[sentence[1]].append(None)  # reserve one slot per sentence of each text
+         for idx in range(0, len(hidden_states), self.decoder_batch_size):
+             batch_hidden_states = []
+             lengths = list(map(lambda x: x.size(0), hidden_states[idx:idx+self.decoder_batch_size]))
+             N = len(lengths)
+             for i in range(N):
+                 batch_hidden_states.append(torch.cat([
+                     torch.zeros((1, 512, lengths[0]-lengths[i]), device='cuda'),
+                     hidden_states[idx+i].unsqueeze(0).transpose(1, 2).cuda().to(torch.float32),
+                 ], dim=2))
+             batch_hidden_states = torch.cat(batch_hidden_states)
+             with torch.no_grad():
+                 audio = self.decoder(batch_hidden_states)
+
+             for i in range(N):
+                 text_id = sentence_data[idx+i][1]
+                 sentence_id = sentence_data[idx+i][2]
+                 audio_concat[text_id][sentence_id] = audio[i].squeeze()[-(lengths[i]*self.TOKEN_SIZE-self.TOKEN_SIZE):]
+         audio_concat = [torch.cat(x).cpu() for x in audio_concat]
+
+         if out_dir:
+             os.makedirs(out_dir, exist_ok=True)
+             for i in range(len(audio_concat)):
+                 wavfile.write(f"{out_dir}/{i}.wav", 32000, audio_concat[i].cpu().numpy())
+         return audio_concat
+
+     def infer_stream(self,
+                      text,
+                      chunk_size=1,
+                      top_p=0.95,
+                      temperature=0.3,
+                      repetition_penalty=1.2):
+         start_time = time.time()
+         sentence_data = self._preprocess_text([text])
+
+         first_chunk = True
+         for sentence, _, _ in sentence_data:
+             responses = self.pipeline.stream_infer(sentence,
+                                                    top_p=top_p,
+                                                    temperature=temperature,
+                                                    repetition_penalty=repetition_penalty)
+             hidden_states_buffer = []
+             chunk_counter = chunk_size
+             for token in responses:
+                 finished = token['finish_reason'] is not None
+                 if not finished: hidden_states_buffer.append(token['hidden_state'][-1])
+                 hidden_states_buffer = hidden_states_buffer[-(2*self.RECEPTIVE_FIELD+chunk_size):]  # keep only the context the decoder needs
+                 if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
+                     if finished or chunk_counter == chunk_size:
+                         batch_hidden_states = torch.stack(hidden_states_buffer)
+                         inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).cuda().to(torch.float32)
+                         with torch.no_grad():
+                             audio = self.decoder(inp)[0]
+                         if finished:
+                             audio_chunk = audio[-((self.RECEPTIVE_FIELD+chunk_counter-1)*self.TOKEN_SIZE-self.TOKEN_SIZE):]
+                         else:
+                             audio_chunk = audio[-((self.RECEPTIVE_FIELD+chunk_size)*self.TOKEN_SIZE-self.TOKEN_SIZE):-(self.RECEPTIVE_FIELD*self.TOKEN_SIZE-self.TOKEN_SIZE)]
+                         chunk_counter = 0
+                         if first_chunk:
+                             print(f"Streaming latency: {1000*(time.time()-start_time):.2f} ms")
+                             first_chunk = False
+                         yield audio_chunk.cpu()
+                     chunk_counter += 1
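For orientation, a hedged usage sketch of the public API defined above (not part of the upload). It assumes a CUDA device and that the ekwek/Soprano-80M weights can be fetched from the Hugging Face Hub; the file and directory names are illustrative.

from soprano import SopranoTTS

tts = SopranoTTS(backend='auto', device='cuda')

# Single utterance -> 32 kHz waveform tensor, optionally written to disk.
audio = tts.infer("Hello world!", out_path='hello.wav')

# Batched synthesis -> one waveform per text, optionally saved as {out_dir}/{i}.wav.
waveforms = tts.infer_batch(["First sentence.", "Second sentence."], out_dir='out')

# Streaming synthesis -> chunks of roughly chunk_size tokens (2048 samples per token).
for chunk in tts.infer_stream("Streaming synthesis example.", chunk_size=2):
    pass  # e.g. push chunk.numpy() into an audio playback buffer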
soprano/vocos/decoder.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ from torch import nn
+
+ from .models import VocosBackbone
+ from .heads import ISTFTHead
+
+
+ class SopranoDecoder(nn.Module):
+     def __init__(self,
+                  num_input_channels=512,
+                  decoder_num_layers=8,
+                  decoder_dim=512,
+                  decoder_intermediate_dim=None,
+                  hop_length=512,
+                  n_fft=2048,
+                  upscale=4,
+                  dw_kernel=3,
+                  ):
+         super().__init__()
+         self.decoder_initial_channels = num_input_channels
+         self.num_layers = decoder_num_layers
+         self.dim = decoder_dim
+         self.intermediate_dim = decoder_intermediate_dim if decoder_intermediate_dim else decoder_dim*3
+         self.hop_length = hop_length
+         self.n_fft = n_fft
+         self.upscale = upscale
+         self.dw_kernel = dw_kernel
+
+         self.decoder = VocosBackbone(input_channels=self.decoder_initial_channels,
+                                      dim=self.dim,
+                                      intermediate_dim=self.intermediate_dim,
+                                      num_layers=self.num_layers,
+                                      input_kernel_size=dw_kernel,
+                                      dw_kernel_size=dw_kernel,
+                                      )
+         self.head = ISTFTHead(dim=self.dim,
+                               n_fft=self.n_fft,
+                               hop_length=self.hop_length)
+
+     def forward(self, x):
+         T = x.size(2)
+         x = torch.nn.functional.interpolate(x, size=self.upscale*(T-1)+1, mode='linear', align_corners=True)
+         x = self.decoder(x)
+         reconstructed = self.head(x)
+         return reconstructed
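A small shape-check sketch (not part of the upload) of the decoder's data flow, run with randomly initialised weights; it assumes a CUDA device since the ISTFT window is created on CUDA. Each hidden state is linearly upsampled 4x before the backbone, and with hop_length=512 every input token contributes 4 x 512 = 2048 audio samples, matching TOKEN_SIZE in tts.py.

import torch
from soprano.vocos.decoder import SopranoDecoder

decoder = SopranoDecoder().cuda()           # random weights, shape checking only
T = 10                                      # number of LM hidden states (tokens)
x = torch.randn(1, 512, T, device='cuda')   # (batch, channels, tokens)
with torch.no_grad():
    audio = decoder(x)                      # interpolate to 4*(T-1)+1 frames, then ISTFT head
print(audio.shape)                          # torch.Size([1, 18432]) == (1, (T-1)*2048)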
soprano/vocos/heads.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ from torch import nn
+ from .spectral_ops import ISTFT
+
+
+ class ISTFTHead(nn.Module):
+     """
+     ISTFT Head module for predicting STFT complex coefficients.
+
+     Args:
+         dim (int): Hidden dimension of the model.
+         n_fft (int): Size of Fourier transform.
+         hop_length (int): The distance between neighboring sliding window frames, which should align with
+             the resolution of the input features.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "center".
+     """
+
+     def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "center"):
+         super().__init__()
+         out_dim = n_fft + 2
+         self.out = torch.nn.Linear(dim, out_dim)
+         self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
+
+     @torch.compiler.disable
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the ISTFTHead module.
+
+         Args:
+             x (Tensor): Input tensor of shape (B, H, L), where B is the batch size,
+                 H denotes the model dimension, and L is the sequence length.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         x = self.out(x.transpose(1, 2)).transpose(1, 2)
+         mag, p = x.chunk(2, dim=1)
+         mag = torch.exp(mag)
+         mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
+         # wrapping happens here; these two lines produce the real and imaginary parts
+         x = torch.cos(p)
+         y = torch.sin(p)
+         # recalculating the phase here does not produce anything new
+         # and only costs time:
+         # phase = torch.atan2(y, x)
+         # S = mag * torch.exp(phase * 1j)
+         # it is better to produce the complex value directly
+         S = mag * (x + 1j * y)
+         audio = self.istft(S)
+         return audio
soprano/vocos/models.py ADDED
@@ -0,0 +1,61 @@
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+ from .modules import ConvNeXtBlock
+
+ class VocosBackbone(nn.Module):
+     """
+     Vocos backbone module built with ConvNeXt blocks.
+
+     Args:
+         input_channels (int): Number of input feature channels.
+         dim (int): Hidden dimension of the model.
+         intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+         num_layers (int): Number of ConvNeXtBlock layers.
+         layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers**0.5`.
+     """
+
+     def __init__(
+         self,
+         input_channels: int,
+         dim: int,
+         intermediate_dim: int,
+         num_layers: int,
+         input_kernel_size: int = 9,
+         dw_kernel_size: int = 9,
+         layer_scale_init_value: Optional[float] = None,
+         pad: str = 'zeros',
+     ):
+         super().__init__()
+         self.embed = nn.Conv1d(input_channels, dim, kernel_size=input_kernel_size, padding=input_kernel_size//2, padding_mode=pad)
+         self.norm = nn.LayerNorm(dim, eps=1e-6)
+         self.convnext = nn.ModuleList(
+             [
+                 ConvNeXtBlock(
+                     dim=dim,
+                     intermediate_dim=intermediate_dim,
+                     dw_kernel_size=dw_kernel_size,
+                     layer_scale_init_value=layer_scale_init_value or 1 / num_layers**0.5,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Conv1d, nn.Linear)):
+             nn.init.trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None: nn.init.constant_(m.bias, 0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.embed(x)  # (B, C, L)
+         x = self.norm(x.transpose(1, 2))
+         x = x.transpose(1, 2)
+         for conv_block in self.convnext:
+             x = conv_block(x)
+         x = self.final_layer_norm(x.transpose(1, 2))
+         x = x.transpose(1, 2)
+         return x
soprano/vocos/modules.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ from torch import nn
+
+
+ class ConvNeXtBlock(nn.Module):
+     """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+     Args:
+         dim (int): Number of input channels.
+         intermediate_dim (int): Dimensionality of the intermediate layer.
+         layer_scale_init_value (float): Initial value for the layer scale; values <= 0 disable scaling.
+         dw_kernel_size (int, optional): Kernel size of the depthwise convolution. Defaults to 9.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         intermediate_dim: int,
+         layer_scale_init_value: float,
+         dw_kernel_size: int = 9,
+     ):
+         super().__init__()
+         self.dwconv = nn.Conv1d(dim, dim, kernel_size=dw_kernel_size, padding=dw_kernel_size//2, groups=dim)  # depthwise conv
+         self.norm = nn.LayerNorm(dim, eps=1e-6)
+         self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
+         self.act = nn.GELU()
+         self.pwconv2 = nn.Linear(intermediate_dim, dim)
+         self.gamma = (
+             nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+             if layer_scale_init_value > 0
+             else None
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         residual = x
+         x = self.dwconv(x)
+         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
+         x = self.norm(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.pwconv2(x)
+         if self.gamma is not None:
+             x = self.gamma * x
+         x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
+
+         x = residual + x
+         return x
soprano/vocos/spectral_ops.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ from torch import nn
+
+ class ISTFT(nn.Module):
+     """
+     Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+     windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+     See issue: https://github.com/pytorch/pytorch/issues/62323
+     Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+     The NOLA constraint is met as we trim padded samples anyway.
+
+     Args:
+         n_fft (int): Size of Fourier transform.
+         hop_length (int): The distance between neighboring sliding window frames.
+         win_length (int): The size of window frame and STFT filter.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+     """
+
+     def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
+         super().__init__()
+         if padding not in ["center", "same"]:
+             raise ValueError("Padding must be 'center' or 'same'.")
+         self.padding = padding
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.win_length = win_length
+         window = torch.hann_window(win_length).to('cuda')  # created on CUDA (the package targets CUDA-only inference); registered as a buffer so it follows the module's device
+         self.register_buffer("window", window)
+
+     def forward(self, spec: torch.Tensor) -> torch.Tensor:
+         """
+         Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+         Args:
+             spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+                 N is the number of frequency bins, and T is the number of time frames.
+
+         Returns:
+             Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+         """
+         if self.padding == "center":
+             spec[:, 0] = 0  # zero the first/last frequency bins; works around an issue where they have no effect at batch sizes < 16, which caused exploding gradients
+             spec[:, -1] = 0
+             # Fall back to the native PyTorch implementation
+             return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
+         elif self.padding == "same":
+             pad = (self.win_length - self.hop_length) // 2
+         else:
+             raise ValueError("Padding must be 'center' or 'same'.")
+
+         assert spec.dim() == 3, "Expected a 3D tensor as input"
+         B, N, T = spec.shape
+
+         # Inverse FFT
+         ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+         ifft = ifft * self.window[None, :, None]
+
+         # Overlap and Add
+         output_size = (T - 1) * self.hop_length + self.win_length
+         y = torch.nn.functional.fold(
+             ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+         )[:, 0, 0, pad:-pad]
+
+         # Window envelope
+         window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+         window_envelope = torch.nn.functional.fold(
+             window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+         ).squeeze()[pad:-pad]
+
+         # Normalize
+         assert (window_envelope > 1e-11).all()
+         y = y / window_envelope
+
+         return y