del diffusion [unused]
- audiocraft/audiogen.py +0 -129
- audiocraft/builders.py +2 -24
- audiocraft/diffusion_schedule.py +0 -272
- audiocraft/loaders.py +0 -24
- audiocraft/rope.py +0 -125
- audiocraft/unet.py +0 -214
audiocraft/audiogen.py
DELETED
@@ -1,129 +0,0 @@
-import typing as tp
-import torch
-from audiocraft.loaders import load_compression_model, load_lm_model
-import typing as tp
-import omegaconf
-import torch
-import numpy as np
-from .lm import LMModel
-from .conditioners import ConditioningAttributes
-from .utils.autocast import TorchAutocast
-
-
-
-def _shift(x):
-    n = x.shape[2]
-    i = np.random.randint(.24 * n, max(1, .74 * n))  # high should be above >= 0 TBD do we have very short segments
-    x = torch.roll(x, i, dims=2)
-    return x
-
-
-class AudioGen():
-    """Base generative model with convenient generation API.
-
-    Args:
-        name (str)
-        compression_model (CompressionModel): Encodec with Seanet Decoder
-        lm
-        max_duration (float, optional): As is using top250 token draw() we can gen xN sequences
-    """
-    def __init__(self,
-                 name,
-                 compression_model,
-                 lm,
-                 max_duration=None):
-        self.name = name
-        self.compression_model = compression_model
-        self.lm = lm
-        self.cfg: tp.Optional[omegaconf.DictConfig] = None
-        # Just to be safe, let's put everything in eval mode.
-        self.compression_model.eval()
-        self.lm.eval()
-
-        if hasattr(lm, 'cfg'):
-            cfg = lm.cfg
-            assert isinstance(cfg, omegaconf.DictConfig)
-            self.cfg = cfg
-
-        if max_duration is None:
-            if self.cfg is not None:
-                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
-            else:
-                raise ValueError("You must provide max_duration when building directly your GenModel")
-        assert max_duration is not None
-
-        self.max_duration: float = max_duration
-        self.duration = self.max_duration
-        self.device = next(iter(lm.parameters())).device
-        self.generation_params = {}
-
-        if self.device.type == 'cpu':
-            self.autocast = TorchAutocast(enabled=False)
-        else:
-            self.autocast = TorchAutocast(
-                enabled=True,
-                device_type=self.device.type,
-                dtype=torch.float16)
-
-    @property
-    def frame_rate(self) -> float:
-        """Roughly the number of AR steps per seconds."""
-        return self.compression_model.frame_rate
-
-    @property
-    def sample_rate(self) -> int:
-        """Sample rate of the generated audio."""
-        return self.compression_model.sample_rate
-
-
-
-
-
-    def generate(self, descriptions):
-        attributes = [
-            ConditioningAttributes(text={'description': d}) for d in descriptions]
-        tokens = self._generate_tokens(attributes)
-        print(f'\n{tokens.shape=}\n{tokens=} FINAL 5 AUD')
-        return self.generate_audio(tokens)
-
-    def _generate_tokens(self, attributes):
-
-        total_gen_len = int(self.duration * self.frame_rate)
-
-        if self.duration <= self.max_duration:
-            # generate by sampling from LM, simple case.
-
-            with self.autocast:
-                gen_tokens = self.lm.generate(conditions=attributes, max_gen_len=total_gen_len)
-        else:
-            print('<>Long gen ?<>')
-        # print(f'{gen_tokens.shape=}')  # [5,4,35]
-        # FLATTEN BATCH AS EXTRA SEQUENCE (BATCH IS VIRTUAL JUST MULTINOMIAL SAMPLING OF N_DRAW TOKENS)
-        gen_tokens = gen_tokens.transpose(0, 1).reshape(4, -1)[None, :, :]
-        for _ in range(3):
-            print(gen_tokens.shape)
-            gen_tokens = _shift(gen_tokens)
-        return gen_tokens
-
-    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
-        """Generate Audio from tokens."""
-        assert gen_tokens.dim() == 3
-        with torch.no_grad():
-            gen_audio = self.compression_model.decode(gen_tokens, None)
-        return gen_audio
-
-
-def get_pretrained(name='facebook/audiogen-medium',
-                   device=None):
-    """Return pretrained model, we provide a single model for now:
-    - facebook/audiogen-medium (1.5B), text to sound,
-    # see: https://huggingface.co/facebook/audiogen-medium
-    """
-    compression_model = load_compression_model(name, device=device)
-    lm = load_lm_model(name, device=device)
-    assert 'self_wav' not in lm.condition_provider.conditioners, \
-        "AudioGen do not support waveform conditioning for now"
-    return AudioGen(name, compression_model, lm)
-
-
-
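For reference, a minimal sketch of how the deleted AudioGen wrapper was meant to be driven, based on the `get_pretrained` and `generate` functions shown above. The text prompt and the `cuda` device are illustrative assumptions, and the import path assumes the pre-deletion layout of this repository.

# Hypothetical usage of the module removed by this commit (pre-deletion layout assumed).
from audiocraft.audiogen import get_pretrained

model = get_pretrained('facebook/audiogen-medium', device='cuda')  # checkpoint name from the docstring above
wav = model.generate(['dogs barking in the rain'])  # prompt is illustrative
# wav: [B, C, T] audio tensor at model.sample_rate, decoded by the compression model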
audiocraft/builders.py
CHANGED
@@ -16,10 +16,10 @@ from .conditioners import (
     ConditioningProvider,
     T5Conditioner,
 )
-
+
 from .vq import ResidualVectorQuantizer
 
-
+
 
 def dict_from_config(cfg):
     dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)

@@ -155,25 +155,3 @@ def get_codebooks_pattern_provider(n_q, cfg):
 
     klass = pattern_providers[name]
     return klass(n_q, **kwargs)
-
-
-
-
-
-def get_diffusion_model(cfg: omegaconf.DictConfig):
-    # TODO Find a way to infer the channels from dset
-    channels = cfg.channels
-    num_steps = cfg.schedule.num_steps
-    return DiffusionUnet(
-        chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
-
-
-def get_processor(cfg, sample_rate: int = 24000):
-    sample_processor = SampleProcessor()
-    if cfg.use:
-        kw = dict(cfg)
-        kw.pop('use')
-        kw.pop('name')
-        if cfg.name == "multi_band_processor":
-            sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
-    return sample_processor
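The removed `get_diffusion_model` only read `cfg.channels`, `cfg.schedule.num_steps` and `cfg.diffusion_unet` from the config before constructing a `DiffusionUnet`. A hedged sketch of the kind of omegaconf config it expected follows; the concrete values and the `hidden`/`depth` keys are illustrative, not taken from this repo's configs.

from omegaconf import OmegaConf

# Illustrative config for the removed helper; keys mirror what get_diffusion_model read,
# values are made up.
cfg = OmegaConf.create({
    'channels': 1,
    'schedule': {'num_steps': 1000},
    'diffusion_unet': {'hidden': 24, 'depth': 3},
})
# get_diffusion_model(cfg) expanded to:
#   DiffusionUnet(chin=cfg.channels, num_steps=cfg.schedule.num_steps, **cfg.diffusion_unet)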
audiocraft/diffusion_schedule.py
DELETED
@@ -1,272 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
-"""
-
-from collections import namedtuple
-import random
-import typing as tp
-import julius
-import torch
-
-TrainingItem = namedtuple("TrainingItem", "noisy noise step")
-
-
-def betas_from_alpha_bar(alpha_bar):
-    alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
-    return 1 - alphas
-
-
-class SampleProcessor(torch.nn.Module):
-    def project_sample(self, x: torch.Tensor):
-        """Project the original sample to the 'space' where the diffusion will happen."""
-        return x
-
-    def return_sample(self, z: torch.Tensor):
-        """Project back from diffusion space to the actual sample space."""
-        return z
-
-
-class MultiBandProcessor(SampleProcessor):
-    """
-    MultiBand sample processor. The input audio is splitted across
-    frequency bands evenly distributed in mel-scale.
-
-    Each band will be rescaled to match the power distribution
-    of Gaussian noise in that band, using online metrics
-    computed on the first few samples.
-
-    Args:
-        n_bands (int): Number of mel-bands to split the signal over.
-        sample_rate (int): Sample rate of the audio.
-        num_samples (int): Number of samples to use to fit the rescaling
-            for each band. The processor won't be stable
-            until it has seen that many samples.
-        power_std (float or list/tensor): The rescaling factor computed to match the
-            power of Gaussian noise in each band is taken to
-            that power, i.e. `1.` means full correction of the energy
-            in each band, and values less than `1` means only partial
-            correction. Can be used to balance the relative importance
-            of low vs. high freq in typical audio signals.
-    """
-    def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
-                 num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
-        super().__init__()
-        self.n_bands = n_bands
-        self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
-        self.num_samples = num_samples
-        self.power_std = power_std
-        if isinstance(power_std, list):
-            assert len(power_std) == n_bands
-            power_std = torch.tensor(power_std)
-        self.register_buffer('counts', torch.zeros(1))
-        self.register_buffer('sum_x', torch.zeros(n_bands))
-        self.register_buffer('sum_x2', torch.zeros(n_bands))
-        self.register_buffer('sum_target_x2', torch.zeros(n_bands))
-        self.counts: torch.Tensor
-        self.sum_x: torch.Tensor
-        self.sum_x2: torch.Tensor
-        self.sum_target_x2: torch.Tensor
-
-    @property
-    def mean(self):
-        mean = self.sum_x / self.counts
-        return mean
-
-    @property
-    def std(self):
-        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
-        return std
-
-    @property
-    def target_std(self):
-        target_std = self.sum_target_x2 / self.counts
-        return target_std
-
-    def project_sample(self, x: torch.Tensor):
-        assert x.dim() == 3
-        bands = self.split_bands(x)
-        if self.counts.item() < self.num_samples:
-            ref_bands = self.split_bands(torch.randn_like(x))
-            self.counts += len(x)
-            self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
-            self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
-            self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
-        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std  # same output size
-        bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
-        return bands.sum(dim=0)
-
-    def return_sample(self, x: torch.Tensor):
-        assert x.dim() == 3
-        bands = self.split_bands(x)
-        rescale = (self.std / self.target_std) ** self.power_std
-        bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
-        return bands.sum(dim=0)
-
-
-class NoiseSchedule:
-    """Noise schedule for diffusion.
-
-    Args:
-        beta_t0 (float): Variance of the first diffusion step.
-        beta_t1 (float): Variance of the last diffusion step.
-        beta_exp (float): Power schedule exponent
-        num_steps (int): Number of diffusion step.
-        variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
-        clip (float): clipping value for the denoising steps
-        rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1)
-        repartition (str): shape of the schedule only power schedule is supported
-        sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution
-        noise_scale (float): Scaling factor for the noise
-    """
-    def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
-                 clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
-                 repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
-                 sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
-
-        self.beta_t0 = beta_t0
-        self.beta_t1 = beta_t1
-        self.variance = variance
-        self.num_steps = num_steps
-        self.clip = clip
-        self.sample_processor = sample_processor
-        self.rescale = rescale
-        self.n_bands = n_bands
-        self.noise_scale = noise_scale
-        assert n_bands is None
-        if repartition == "power":
-            self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
-                                        device=device, dtype=torch.float) ** beta_exp
-        else:
-            raise RuntimeError('Not implemented')
-        self.rng = random.Random(1234)
-
-    def get_beta(self, step: tp.Union[int, torch.Tensor]):
-        if self.n_bands is None:
-            return self.betas[step]
-        else:
-            return self.betas[:, step]  # [n_bands, len(step)]
-
-    def get_initial_noise(self, x: torch.Tensor):
-        if self.n_bands is None:
-            return torch.randn_like(x)
-        return torch.randn((x.size(0), self.n_bands, x.size(2)))
-
-    def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
-        """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
-        if step is None:
-            return (1 - self.betas).cumprod(dim=-1)  # works for simgle and multi bands
-        if type(step) is int:
-            return (1 - self.betas[:step + 1]).prod()
-        else:
-            return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
-
-    def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
-        """Create a noisy data item for diffusion model training:
-
-        Args:
-            x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
-            tensor_step (bool): If tensor_step = false, only one step t is sample,
-                the whole batch is diffused to the same step and t is int.
-                If tensor_step = true, t is a tensor of size (x.size(0),)
-                every element of the batch is diffused to a independently sampled.
-        """
-        step: tp.Union[int, torch.Tensor]
-        if tensor_step:
-            bs = x.size(0)
-            step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
-        else:
-            step = self.rng.randrange(self.num_steps)
-        alpha_bar = self.get_alpha_bar(step)  # [batch_size, n_bands, 1]
-
-        x = self.sample_processor.project_sample(x)
-        noise = torch.randn_like(x)
-        noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
-        return TrainingItem(noisy, noise, step)
-
-    def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
-                 condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
-        """Full ddpm reverse process.
-
-        Args:
-            model (nn.Module): Diffusion model.
-            initial (tensor): Initial Noise.
-            condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
-            return_list (bool): Whether to return the whole process or only the sampled point.
-        """
-        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
-        current = initial
-        iterates = [initial]
-        for step in range(self.num_steps)[::-1]:
-            with torch.no_grad():
-                estimate = model(current, step, condition=condition).sample
-            alpha = 1 - self.betas[step]
-            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
-            previous_alpha_bar = self.get_alpha_bar(step=step - 1)
-            if step == 0:
-                sigma2 = 0
-            elif self.variance == 'beta':
-                sigma2 = 1 - alpha
-            elif self.variance == 'beta_tilde':
-                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
-            elif self.variance == 'none':
-                sigma2 = 0
-            else:
-                raise ValueError(f'Invalid variance type {self.variance}')
-
-            if sigma2 > 0:
-                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
-            if self.clip:
-                previous = previous.clamp(-self.clip, self.clip)
-            current = previous
-            alpha_bar = previous_alpha_bar
-            if step == 0:
-                previous *= self.rescale
-            if return_list:
-                iterates.append(previous.cpu())
-
-        if return_list:
-            return iterates
-        else:
-            return self.sample_processor.return_sample(previous)
-
-    def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
-                            condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
-        """Reverse process that only goes through Markov chain states in step_list."""
-        if step_list is None:
-            step_list = list(range(1000))[::-50] + [0]
-        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
-        alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
-        betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
-        current = initial * self.noise_scale
-        iterates = [current]
-        for idx, step in enumerate(step_list[:-1]):
-            with torch.no_grad():
-                estimate = model(current, step, condition=condition).sample * self.noise_scale
-            alpha = 1 - betas_subsampled[-1 - idx]
-            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
-            previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
-            if step == step_list[-2]:
-                sigma2 = 0
-                previous_alpha_bar = torch.tensor(1.0)
-            else:
-                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
-            if sigma2 > 0:
-                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
-            if self.clip:
-                previous = previous.clamp(-self.clip, self.clip)
-            current = previous
-            alpha_bar = previous_alpha_bar
-            if step == 0:
-                previous *= self.rescale
-            if return_list:
-                iterates.append(previous.cpu())
-        if return_list:
-            return iterates
-        else:
-            return self.sample_processor.return_sample(previous)
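The core of `get_training_item` above is the standard DDPM forward (noising) step, noisy = sqrt(alpha_bar) * x + sqrt(1 - alpha_bar) * noise. A self-contained sketch with the default power schedule and rescale/noise_scale left at 1; the step index and tensor shapes are illustrative.

import torch

# Forward (noising) step as in NoiseSchedule.get_training_item, default linear-beta schedule.
num_steps, beta_t0, beta_t1 = 1000, 1e-4, 0.02
betas = torch.linspace(beta_t0, beta_t1, num_steps)

step = 250                                   # illustrative diffusion step
alpha_bar = (1 - betas[:step + 1]).prod()    # cumulative product of (1 - beta)

x = torch.randn(2, 1, 16000)                 # stand-in for clean audio, [B, 1, T]
noise = torch.randn_like(x)
noisy = alpha_bar.sqrt() * x + (1 - alpha_bar).sqrt() * noise
print(noisy.shape)                           # torch.Size([2, 1, 16000])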
audiocraft/loaders.py
CHANGED
@@ -1,33 +1,9 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Utility functions to load from the checkpoints.
-Each checkpoint is a torch.saved dict with the following keys:
-- 'xp.cfg': the hydra config as dumped during training. This should be used
-    to rebuild the object using the audiocraft.models.builders functions,
-- 'model_best_state': a readily loadable best state for the model, including
-    the conditioner. The model obtained from `xp.cfg` should be compatible
-    with this state dict. In the case of a LM, the encodec model would not be
-    bundled along but instead provided separately.
-
-Those functions also support loading from a remote location with the Torch Hub API.
-They also support overriding some parameters, in particular the device and dtype
-of the returned model.
-"""
-
 from pathlib import Path
 from huggingface_hub import hf_hub_download
 import typing as tp
 import os
-
 from omegaconf import OmegaConf, DictConfig
 import torch
-
-import audiocraft
 from . import builders
 from .encodec import EncodecModel
 
audiocraft/rope.py
DELETED
@@ -1,125 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import typing as tp
-
-from torch import nn
-import torch
-
-
-class XPos(nn.Module):
-    """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
-    This applies an exponential decay to the RoPE rotation matrix.
-
-    Args:
-        dim (int): Embedding dimension.
-        smoothing (float): Smoothing factor applied to the decay rates.
-        base_scale (int): Base decay rate, given in terms of scaling time.
-        device (torch.device, optional): Device on which to initialize the module.
-        dtype (torch.dtype): dtype to use to generate the embedding.
-    """
-    def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
-                 device=None, dtype: torch.dtype = torch.float32):
-        super().__init__()
-        assert dim % 2 == 0
-        assert dtype in [torch.float64, torch.float32]
-        self.dtype = dtype
-        self.base_scale = base_scale
-
-        half_dim = dim // 2
-        adim = torch.arange(half_dim, device=device, dtype=dtype)
-        decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
-        self.register_buffer("decay_rates", decay_rates)
-        self.decay: tp.Optional[torch.Tensor] = None
-
-    def get_decay(self, start: int, end: int):
-        """Create complex decay tensor, cache values for fast computation."""
-        if self.decay is None or end > self.decay.shape[0]:
-            assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
-            idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
-            power = idx / self.base_scale
-            scale = self.decay_rates ** power.unsqueeze(-1)
-            self.decay = torch.polar(scale, torch.zeros_like(scale))
-        return self.decay[start:end]  # [T, C/2]
-
-
-class RotaryEmbedding(nn.Module):
-    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
-
-    Args:
-        dim (int): Embedding dimension (twice the number of frequencies).
-        max_period (float): Maximum period of the rotation frequencies.
-        xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
-        scale (float): Scale of positional embedding, set to 0 to deactivate.
-        device (torch.device, optional): Device on which to initialize the module.
-        dtype (torch.dtype): dtype to use to generate the embedding.
-    """
-    def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
-                 scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
-        super().__init__()
-        assert dim % 2 == 0
-        self.scale = scale
-        assert dtype in [torch.float64, torch.float32]
-        self.dtype = dtype
-
-        adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
-        frequencies = 1.0 / (max_period ** (adim / dim))
-        self.register_buffer("frequencies", frequencies)
-        self.rotation: tp.Optional[torch.Tensor] = None
-
-        self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
-
-    def get_rotation(self, start: int, end: int):
-        """Create complex rotation tensor, cache values for fast computation."""
-        if self.rotation is None or end > self.rotation.shape[0]:
-            assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
-            idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
-            angles = torch.outer(idx, self.frequencies)
-            self.rotation = torch.polar(torch.ones_like(angles), angles)
-        return self.rotation[start:end]
-
-    def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
-        """Apply rope rotation to query or key tensor."""
-        T = x.shape[time_dim]
-        target_shape = [1] * x.dim()
-        target_shape[time_dim] = T
-        target_shape[-1] = -1
-        rotation = self.get_rotation(start, start + T).view(target_shape)
-
-        if self.xpos:
-            decay = self.xpos.get_decay(start, start + T).view(target_shape)
-        else:
-            decay = 1.0
-
-        if invert_decay:
-            decay = decay ** -1
-
-        x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
-        scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
-        x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
-
-        return x_out.type_as(x)
-
-    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
-        """ Apply rope rotation to both query and key tensors.
-        Supports streaming mode, in which query and key are not expected to have the same shape.
-        In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
-        query will be [C] (typically C == 1).
-
-        Args:
-            query (torch.Tensor): Query to rotate.
-            key (torch.Tensor): Key to rotate.
-            start (int): Start index of the sequence for time offset.
-            time_dim (int): which dimension represent the time steps.
-        """
-        query_timesteps = query.shape[time_dim]
-        key_timesteps = key.shape[time_dim]
-        streaming_offset = key_timesteps - query_timesteps
-
-        query_out = self.rotate(query, start + streaming_offset, time_dim)
-        key_out = self.rotate(key, start, time_dim, invert_decay=True)
-
-        return query_out, key_out
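The heart of `RotaryEmbedding.rotate` is viewing channel pairs as complex numbers and multiplying them by unit-modulus rotations whose angle grows linearly with position. A self-contained torch sketch of that operation follows (no xPos decay, scale fixed at 1); the tensor sizes are illustrative.

import torch

# Minimal RoPE rotation, mirroring the view_as_complex / torch.polar trick in the deleted rope.py.
B, T, dim = 1, 8, 16
max_period = 10000.0
adim = torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2]
frequencies = 1.0 / (max_period ** (adim / dim))

positions = torch.arange(T, dtype=torch.float32)
angles = torch.outer(positions, frequencies)                 # [T, dim/2]
rotation = torch.polar(torch.ones_like(angles), angles)      # complex rotations, [T, dim/2]

x = torch.randn(B, T, dim)
x_complex = torch.view_as_complex(x.reshape(B, T, -1, 2))    # channel pairs as complex, [B, T, dim/2]
x_rotated = torch.view_as_real(x_complex * rotation).view_as(x)
print(x_rotated.shape)                                       # torch.Size([1, 8, 16])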
audiocraft/unet.py
DELETED
@@ -1,214 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Pytorch Unet Module used for diffusion.
-"""
-
-from dataclasses import dataclass
-import typing as tp
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from .transformer import StreamingTransformer, create_sin_embedding
-
-
-@dataclass
-class Output:
-    sample: torch.Tensor
-
-
-def get_model(cfg, channels: int, side: int, num_steps: int):
-    if cfg.model == 'unet':
-        return DiffusionUnet(
-            chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
-    else:
-        raise RuntimeError('Not Implemented')
-
-
-class ResBlock(nn.Module):
-    def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
-                 dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
-                 dropout: float = 0.):
-        super().__init__()
-        stride = 1
-        padding = dilation * (kernel - stride) // 2
-        Conv = nn.Conv1d
-        Drop = nn.Dropout1d
-        self.norm1 = nn.GroupNorm(norm_groups, channels)
-        self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
-        self.activation1 = activation()
-        self.dropout1 = Drop(dropout)
-
-        self.norm2 = nn.GroupNorm(norm_groups, channels)
-        self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
-        self.activation2 = activation()
-        self.dropout2 = Drop(dropout)
-
-    def forward(self, x):
-        h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
-        h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
-        return x + h
-
-
-class DecoderLayer(nn.Module):
-    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
-                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
-                 dropout: float = 0.):
-        super().__init__()
-        padding = (kernel - stride) // 2
-        self.res_blocks = nn.Sequential(
-            *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
-              for idx in range(res_blocks)])
-        self.norm = nn.GroupNorm(norm_groups, chin)
-        ConvTr = nn.ConvTranspose1d
-        self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
-        self.activation = activation()
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.res_blocks(x)
-        x = self.norm(x)
-        x = self.activation(x)
-        x = self.convtr(x)
-        return x
-
-
-class EncoderLayer(nn.Module):
-    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
-                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
-                 dropout: float = 0.):
-        super().__init__()
-        padding = (kernel - stride) // 2
-        Conv = nn.Conv1d
-        self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
-        self.norm = nn.GroupNorm(norm_groups, chout)
-        self.activation = activation()
-        self.res_blocks = nn.Sequential(
-            *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
-              for idx in range(res_blocks)])
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, C, T = x.shape
-        stride, = self.conv.stride
-        pad = (stride - (T % stride)) % stride
-        x = F.pad(x, (0, pad))
-
-        x = self.conv(x)
-        x = self.norm(x)
-        x = self.activation(x)
-        x = self.res_blocks(x)
-        return x
-
-
-class BLSTM(nn.Module):
-    """BiLSTM with same hidden units as input dim.
-    """
-    def __init__(self, dim, layers=2):
-        super().__init__()
-        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
-        self.linear = nn.Linear(2 * dim, dim)
-
-    def forward(self, x):
-        x = x.permute(2, 0, 1)
-        x = self.lstm(x)[0]
-        x = self.linear(x)
-        x = x.permute(1, 2, 0)
-        return x
-
-
-class DiffusionUnet(nn.Module):
-    def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
-                 max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
-                 bilstm: bool = False, transformer: bool = False,
-                 codec_dim: tp.Optional[int] = None, **kwargs):
-        super().__init__()
-        self.encoders = nn.ModuleList()
-        self.decoders = nn.ModuleList()
-        self.embeddings: tp.Optional[nn.ModuleList] = None
-        self.embedding = nn.Embedding(num_steps, hidden)
-        if emb_all_layers:
-            self.embeddings = nn.ModuleList()
-        self.condition_embedding: tp.Optional[nn.Module] = None
-        for d in range(depth):
-            encoder = EncoderLayer(chin, hidden, **kwargs)
-            decoder = DecoderLayer(hidden, chin, **kwargs)
-            self.encoders.append(encoder)
-            self.decoders.insert(0, decoder)
-            if emb_all_layers and d > 0:
-                assert self.embeddings is not None
-                self.embeddings.append(nn.Embedding(num_steps, hidden))
-            chin = hidden
-            hidden = min(int(chin * growth), max_channels)
-        self.bilstm: tp.Optional[nn.Module]
-        if bilstm:
-            self.bilstm = BLSTM(chin)
-        else:
-            self.bilstm = None
-        self.use_transformer = transformer
-        self.cross_attention = False
-        if transformer:
-            self.cross_attention = cross_attention
-            self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
-                                                    cross_attention=cross_attention)
-
-        self.use_codec = False
-        if codec_dim is not None:
-            self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
-            self.use_codec = True
-
-    def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
-        skips = []
-        bs = x.size(0)
-        z = x
-        view_args = [1]
-        if type(step) is torch.Tensor:
-            step_tensor = step
-        else:
-            step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
-
-        for idx, encoder in enumerate(self.encoders):
-            z = encoder(z)
-            if idx == 0:
-                z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
-            elif self.embeddings is not None:
-                z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
-
-            skips.append(z)
-
-        if self.use_codec:  # insert condition in the bottleneck
-            assert condition is not None, "Model defined for conditionnal generation"
-            condition_emb = self.conv_codec(condition)  # reshape to the bottleneck dim
-            assert condition_emb.size(-1) <= 2 * z.size(-1), \
-                f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
-            if not self.cross_attention:
-
-                condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
-                assert z.size() == condition_emb.size()
-                z += condition_emb
-                cross_attention_src = None
-            else:
-                cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
-                B, T, C = cross_attention_src.shape
-                positions = torch.arange(T, device=x.device).view(1, -1, 1)
-                pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
-                cross_attention_src = cross_attention_src + pos_emb
-        if self.use_transformer:
-            z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
-        else:
-            if self.bilstm is None:
-                z = torch.zeros_like(z)
-            else:
-                z = self.bilstm(z)
-
-        for decoder in self.decoders:
-            s = skips.pop(-1)
-            z = z[:, :, :s.shape[2]]
-            z = z + s
-            z = decoder(z)
-
-        z = z[:, :, :x.shape[2]]
-        return Output(z)
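In `DiffusionUnet.forward`, the diffusion step index conditions the network through an `nn.Embedding` lookup that is broadcast over time and added to the first encoder's activations. A small stand-alone sketch of that broadcast follows; the dimensions are illustrative.

import torch
from torch import nn

# Timestep conditioning as used by the deleted DiffusionUnet: look up the step index,
# reshape to [B, hidden, 1] and broadcast it over the time axis of the feature map.
num_steps, hidden, bs, T = 1000, 24, 2, 125
embedding = nn.Embedding(num_steps, hidden)

z = torch.randn(bs, hidden, T)                              # activations after the first EncoderLayer
step_tensor = torch.tensor([10], dtype=torch.long).expand(bs)
z = z + embedding(step_tensor).view(bs, -1, 1).expand_as(z)
print(z.shape)                                              # torch.Size([2, 24, 125])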