Plachta committed on
Commit 893262b • 1 Parent(s): 1c7cc12

Update modules/diffusion_transformer.py
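
Summary of the change (inferred from the diff below): TimestepEmbedder now precomputes its sinusoidal frequency table once in __init__, with max_period=10000 and scale=1000 fixed as attributes, and stores it via register_buffer("freqs", ...); timestep_embedding becomes an instance method that reuses that buffer, replacing the @staticmethod that rebuilt freqs and moved it to t.device on every call. The 'mlp' final-layer branch of DiT also drops its trailing final_conv (both the nn.Conv1d in __init__ and the call at the end of forward), and a commented-out import of modules.torchscript_modules.gpt_fast_model is added next to the gpt_fast import, which may hint at TorchScript/export work, though the commit message does not say so.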

Files changed (1)
  1. modules/diffusion_transformer.py +240 -237
modules/diffusion_transformer.py CHANGED
@@ -1,237 +1,240 @@
- import torch
- from torch import nn
- import math
-
- from modules.gpt_fast.model import ModelArgs, Transformer
- from modules.wavenet import WN
- from modules.commons import sequence_mask
-
- from torch.nn.utils import weight_norm
-
- def modulate(x, shift, scale):
-     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-
-
- #################################################################################
- #               Embedding Layers for Timesteps and Class Labels                 #
- #################################################################################
-
- class TimestepEmbedder(nn.Module):
-     """
-     Embeds scalar timesteps into vector representations.
-     """
-     def __init__(self, hidden_size, frequency_embedding_size=256):
-         super().__init__()
-         self.mlp = nn.Sequential(
-             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
-             nn.SiLU(),
-             nn.Linear(hidden_size, hidden_size, bias=True),
-         )
-         self.frequency_embedding_size = frequency_embedding_size
-
-     @staticmethod
-     def timestep_embedding(t, dim, max_period=10000, scale=1000):
-         """
-         Create sinusoidal timestep embeddings.
-         :param t: a 1-D Tensor of N indices, one per batch element.
-                   These may be fractional.
-         :param dim: the dimension of the output.
-         :param max_period: controls the minimum frequency of the embeddings.
-         :return: an (N, D) Tensor of positional embeddings.
-         """
-         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
-         half = dim // 2
-         freqs = torch.exp(
-             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-         ).to(device=t.device)
-         args = scale * t[:, None].float() * freqs[None]
-         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-         if dim % 2:
-             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-         return embedding
-
-     def forward(self, t):
-         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
-         t_emb = self.mlp(t_freq)
-         return t_emb
-
-
- class StyleEmbedder(nn.Module):
-     """
-     Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
-     """
-     def __init__(self, input_size, hidden_size, dropout_prob):
-         super().__init__()
-         use_cfg_embedding = dropout_prob > 0
-         self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
-         self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
-         self.input_size = input_size
-         self.dropout_prob = dropout_prob
-
-     def forward(self, labels, train, force_drop_ids=None):
-         use_dropout = self.dropout_prob > 0
-         if (train and use_dropout) or (force_drop_ids is not None):
-             labels = self.token_drop(labels, force_drop_ids)
-         else:
-             labels = self.style_in(labels)
-         embeddings = labels
-         return embeddings
-
- class FinalLayer(nn.Module):
-     """
-     The final layer of DiT.
-     """
-     def __init__(self, hidden_size, patch_size, out_channels):
-         super().__init__()
-         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-         self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
-         self.adaLN_modulation = nn.Sequential(
-             nn.SiLU(),
-             nn.Linear(hidden_size, 2 * hidden_size, bias=True)
-         )
-
-     def forward(self, x, c):
-         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
-         x = modulate(self.norm_final(x), shift, scale)
-         x = self.linear(x)
-         return x
-
- class DiT(torch.nn.Module):
-     def __init__(
-         self,
-         args
-     ):
-         super(DiT, self).__init__()
-         self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
-         self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
-         self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
-         model_args = ModelArgs(
-             block_size=16384,#args.DiT.block_size,
-             n_layer=args.DiT.depth,
-             n_head=args.DiT.num_heads,
-             dim=args.DiT.hidden_dim,
-             head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
-             vocab_size=1024,
-             uvit_skip_connection=self.uvit_skip_connection,
-         )
-         self.transformer = Transformer(model_args)
-         self.in_channels = args.DiT.in_channels
-         self.out_channels = args.DiT.in_channels
-         self.num_heads = args.DiT.num_heads
-
-         self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
-
-         self.content_type = args.DiT.content_type # 'discrete' or 'continuous'
-         self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
-         self.content_dim = args.DiT.content_dim # for continuous content
-         self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content
-         self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content
-
-         self.is_causal = args.DiT.is_causal
-
-         self.n_f0_bins = args.DiT.n_f0_bins
-         self.f0_bins = torch.arange(2, 1024, 1024 // args.DiT.n_f0_bins)
-         self.f0_embedder = nn.Embedding(args.DiT.n_f0_bins, args.DiT.hidden_dim)
-         self.f0_condition = args.DiT.f0_condition
-
-         self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
-         self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
-         # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
-         # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
-
-         input_pos = torch.arange(16384)
-         self.register_buffer("input_pos", input_pos)
-
-         self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
-         self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
-         self.final_layer_type = args.DiT.final_layer_type # mlp or wavenet
-         if self.final_layer_type == 'wavenet':
-             self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
-                               kernel_size=args.wavenet.kernel_size,
-                               dilation_rate=args.wavenet.dilation_rate,
-                               n_layers=args.wavenet.num_layers,
-                               gin_channels=args.wavenet.hidden_dim,
-                               p_dropout=args.wavenet.p_dropout,
-                               causal=False)
-             self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
-         else:
-             self.final_mlp = nn.Sequential(
-                 nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
-                 nn.SiLU(),
-                 nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
-             )
-             self.final_conv = nn.Conv1d(args.DiT.in_channels, args.DiT.in_channels, kernel_size=3, padding=1)
-         self.transformer_style_condition = args.DiT.style_condition
-         self.wavenet_style_condition = args.wavenet.style_condition
-         assert args.DiT.style_condition == args.wavenet.style_condition
-
-         self.class_dropout_prob = args.DiT.class_dropout_prob
-         self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
-         self.res_projection = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim) # residual connection from tranformer output to final output
-         self.long_skip_connection = args.DiT.long_skip_connection
-         self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
-
-         self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
-                                              args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
-                                              args.DiT.hidden_dim)
-         if self.style_as_token:
-             self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
-
-     def setup_caches(self, max_batch_size, max_seq_length):
-         self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
-     def forward(self, x, prompt_x, x_lens, t, style, cond, f0=None, mask_content=False):
-         class_dropout = False
-         if self.training and torch.rand(1) < self.class_dropout_prob:
-             class_dropout = True
-         if not self.training and mask_content:
-             class_dropout = True
-         # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
-         cond_in_module = self.cond_projection
-
-         B, _, T = x.size()
-
-
-         t1 = self.t_embedder(t) # (N, D)
-
-         cond = cond_in_module(cond)
-         if self.f0_condition and f0 is not None:
-             quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
-             cond = cond + self.f0_embedder(quantized_f0)
-
-         x = x.transpose(1, 2)
-         prompt_x = prompt_x.transpose(1, 2)
-
-         x_in = torch.cat([x, prompt_x, cond], dim=-1)
-         if self.transformer_style_condition and not self.style_as_token:
-             x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)
-         if class_dropout:
-             x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
-         x_in = self.cond_x_merge_linear(x_in) # (N, T, D)
-
-         if self.style_as_token:
-             style = self.style_in(style)
-             style = torch.zeros_like(style) if class_dropout else style
-             x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
-         if self.time_as_token:
-             x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
-         x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
-         input_pos = self.input_pos[:x_in.size(1)] # (T,)
-         x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
-         x_res = self.transformer(x_in, None if self.time_as_token else t1.unsqueeze(1), input_pos, x_mask_expanded)
-         x_res = x_res[:, 1:] if self.time_as_token else x_res
-         x_res = x_res[:, 1:] if self.style_as_token else x_res
-         if self.long_skip_connection:
-             x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
-         if self.final_layer_type == 'wavenet':
-             x = self.conv1(x_res)
-             x = x.transpose(1, 2)
-             t2 = self.t_embedder2(t)
-             x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
-                 x_res) # long residual connection
-             x = self.final_layer(x, t1).transpose(1, 2)
-             x = self.conv2(x)
-         else:
-             x = self.final_mlp(x_res)
-             x = x.transpose(1, 2)
-             x = self.final_conv(x)
-         return x
 
 
 
 
+ import torch
+ from torch import nn
+ import math
+
+ from modules.gpt_fast.model import ModelArgs, Transformer
+ # from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
+ from modules.wavenet import WN
+ from modules.commons import sequence_mask
+
+ from torch.nn.utils import weight_norm
+
+ def modulate(x, shift, scale):
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+ #################################################################################
+ #               Embedding Layers for Timesteps and Class Labels                 #
+ #################################################################################
+
+ class TimestepEmbedder(nn.Module):
+     """
+     Embeds scalar timesteps into vector representations.
+     """
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+         self.max_period = 10000
+         self.scale = 1000
+
+         half = frequency_embedding_size // 2
+         freqs = torch.exp(
+             -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+         )
+         self.register_buffer("freqs", freqs)
+
+     def timestep_embedding(self, t):
+         """
+         Create sinusoidal timestep embeddings.
+         :param t: a 1-D Tensor of N indices, one per batch element.
+                   These may be fractional.
+         :param dim: the dimension of the output.
+         :param max_period: controls the minimum frequency of the embeddings.
+         :return: an (N, D) Tensor of positional embeddings.
+         """
+         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+
+         args = self.scale * t[:, None].float() * self.freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if self.frequency_embedding_size % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t):
+         t_freq = self.timestep_embedding(t)
+         t_emb = self.mlp(t_freq)
+         return t_emb
+
+
+ class StyleEmbedder(nn.Module):
+     """
+     Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+     """
+     def __init__(self, input_size, hidden_size, dropout_prob):
+         super().__init__()
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
+         self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
+         self.input_size = input_size
+         self.dropout_prob = dropout_prob
+
+     def forward(self, labels, train, force_drop_ids=None):
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels = self.token_drop(labels, force_drop_ids)
+         else:
+             labels = self.style_in(labels)
+         embeddings = labels
+         return embeddings
+
+ class FinalLayer(nn.Module):
+     """
+     The final layer of DiT.
+     """
+     def __init__(self, hidden_size, patch_size, out_channels):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+         )
+
+     def forward(self, x, c):
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+         x = modulate(self.norm_final(x), shift, scale)
+         x = self.linear(x)
+         return x
+
+ class DiT(torch.nn.Module):
+     def __init__(
+         self,
+         args
+     ):
+         super(DiT, self).__init__()
+         self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
+         self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
+         self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
+         model_args = ModelArgs(
+             block_size=16384,#args.DiT.block_size,
+             n_layer=args.DiT.depth,
+             n_head=args.DiT.num_heads,
+             dim=args.DiT.hidden_dim,
+             head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
+             vocab_size=1024,
+             uvit_skip_connection=self.uvit_skip_connection,
+         )
+         self.transformer = Transformer(model_args)
+         self.in_channels = args.DiT.in_channels
+         self.out_channels = args.DiT.in_channels
+         self.num_heads = args.DiT.num_heads
+
+         self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
+
+         self.content_type = args.DiT.content_type # 'discrete' or 'continuous'
+         self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
+         self.content_dim = args.DiT.content_dim # for continuous content
+         self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content
+         self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content
+
+         self.is_causal = args.DiT.is_causal
+
+         self.n_f0_bins = args.DiT.n_f0_bins
+         self.f0_bins = torch.arange(2, 1024, 1024 // args.DiT.n_f0_bins)
+         self.f0_embedder = nn.Embedding(args.DiT.n_f0_bins, args.DiT.hidden_dim)
+         self.f0_condition = args.DiT.f0_condition
+
+         self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
+         self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
+         # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
+         # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
+
+         input_pos = torch.arange(16384)
+         self.register_buffer("input_pos", input_pos)
+
+         self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
+         self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
+         self.final_layer_type = args.DiT.final_layer_type # mlp or wavenet
+         if self.final_layer_type == 'wavenet':
+             self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
+                               kernel_size=args.wavenet.kernel_size,
+                               dilation_rate=args.wavenet.dilation_rate,
+                               n_layers=args.wavenet.num_layers,
+                               gin_channels=args.wavenet.hidden_dim,
+                               p_dropout=args.wavenet.p_dropout,
+                               causal=False)
+             self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
+         else:
+             self.final_mlp = nn.Sequential(
+                 nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
+                 nn.SiLU(),
+                 nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
+             )
+         self.transformer_style_condition = args.DiT.style_condition
+         self.wavenet_style_condition = args.wavenet.style_condition
+         assert args.DiT.style_condition == args.wavenet.style_condition
+
+         self.class_dropout_prob = args.DiT.class_dropout_prob
+         self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
+         self.res_projection = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim) # residual connection from tranformer output to final output
+         self.long_skip_connection = args.DiT.long_skip_connection
+         self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
+
+         self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
+                                              args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
+                                              args.DiT.hidden_dim)
+         if self.style_as_token:
+             self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
+
+     def setup_caches(self, max_batch_size, max_seq_length):
+         self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
+     def forward(self, x, prompt_x, x_lens, t, style, cond, f0=None, mask_content=False):
+         class_dropout = False
+         if self.training and torch.rand(1) < self.class_dropout_prob:
+             class_dropout = True
+         if not self.training and mask_content:
+             class_dropout = True
+         # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
+         cond_in_module = self.cond_projection
+
+         B, _, T = x.size()
+
+
+         t1 = self.t_embedder(t) # (N, D)
+
+         cond = cond_in_module(cond)
+         if self.f0_condition and f0 is not None:
+             quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
+             cond = cond + self.f0_embedder(quantized_f0)
+
+         x = x.transpose(1, 2)
+         prompt_x = prompt_x.transpose(1, 2)
+
+         x_in = torch.cat([x, prompt_x, cond], dim=-1)
+         if self.transformer_style_condition and not self.style_as_token:
+             x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)
+         if class_dropout:
+             x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
+         x_in = self.cond_x_merge_linear(x_in) # (N, T, D)
+
+         if self.style_as_token:
+             style = self.style_in(style)
+             style = torch.zeros_like(style) if class_dropout else style
+             x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
+         if self.time_as_token:
+             x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
+         x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
+         input_pos = self.input_pos[:x_in.size(1)] # (T,)
+         x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
+         x_res = self.transformer(x_in, None if self.time_as_token else t1.unsqueeze(1), input_pos, x_mask_expanded)
+         x_res = x_res[:, 1:] if self.time_as_token else x_res
+         x_res = x_res[:, 1:] if self.style_as_token else x_res
+         if self.long_skip_connection:
+             x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
+         if self.final_layer_type == 'wavenet':
+             x = self.conv1(x_res)
+             x = x.transpose(1, 2)
+             t2 = self.t_embedder2(t)
+             x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
+                 x_res) # long residual connection
+             x = self.final_layer(x, t1).transpose(1, 2)
+             x = self.conv2(x)
+         else:
+             x = self.final_mlp(x_res)
+             x = x.transpose(1, 2)
+         return x
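
The core pattern this commit introduces is isolated below as a minimal, self-contained sketch (the class name SinusoidalTimestepTable and the demo values are illustrative, not part of the repository): precompute the frequency table at construction time and register it as a buffer, so it travels with the module across .to(device) calls instead of being rebuilt and re-transferred on every forward pass.

import math
import torch
from torch import nn

class SinusoidalTimestepTable(nn.Module):
    # Sketch of the buffer-precomputation pattern from this commit.
    def __init__(self, dim=256, max_period=10000, scale=1000):
        super().__init__()
        self.dim = dim
        self.scale = scale
        half = dim // 2
        # Built once here; register_buffer makes the table part of the module's
        # state, so .to(device)/.cuda() moves it along with the parameters and
        # no per-call torch.exp(...).to(t.device) is needed.
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        )
        self.register_buffer("freqs", freqs)

    def forward(self, t):
        # t: (N,) possibly fractional timesteps -> (N, dim) embeddings.
        args = self.scale * t[:, None].float() * self.freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:  # pad one zero column when dim is odd
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

# Hypothetical usage:
table = SinusoidalTimestepTable(dim=256)
print(table(torch.tensor([0.0, 0.25, 1.0])).shape)  # torch.Size([3, 256])

One side effect worth noting: because "freqs" is registered as a persistent buffer, it now appears in the module's state_dict, so checkpoints saved before this commit will lack that key and may need load_state_dict(..., strict=False) to load cleanly. This is a consequence of how register_buffer works, not something stated in the commit.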