jbetker commited on
Commit
5a958b4
·
1 Parent(s): 051f500

Initial commit

Browse files
.gitignore CHANGED
@@ -127,3 +127,6 @@ dmypy.json
127
 
128
  # Pyre type checker
129
  .pyre/
 
 
 
 
127
 
128
  # Pyre type checker
129
  .pyre/
130
+
131
+ .idea/*
132
+ .models/*
README.md CHANGED
@@ -1,2 +1,41 @@
1
- # tortoise-tts
2
- A multi-voice TTS system trained with an emphasis on quality
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Tortoise-TTS
2
+
3
+ Tortoise TTS is an experimental text-to-speech program that uses recent machine learning techniques to generate
4
+ high-quality speech samples.
5
+
6
+ This repo contains all the code needed to run Tortoise TTS in inference mode.
7
+
8
+ ## What's in a name?
9
+
10
+ I'm naming my speech-related repos after Mojave desert flora and fauna. Tortoise is a bit tongue in cheek: this model
11
+ is insanely slow. It leverages both an autoregressive speech alignment model and a diffusion model, both of which
12
+ are known for their slow inference. It also performs CLIP sampling, which slows things down even further. You can
13
+ expect ~5 seconds of speech to take ~30 seconds to produce on the latest hardware. Still, the results are pretty cool.
14
+
15
+ ## What the heck is this?
16
+
17
+ Tortoise TTS is inspired by OpenAI's DALLE, applied to speech data. It is made up of 4 separate models that work together:
18
+
19
+ First, an autoregressive transformer stack predicts discrete speech "tokens" given a text prompt. This model is very
20
+ similar to the GPT model used by DALLE, except it operates on speech data.
21
+
22
+ Next, a CLIP model judges a batch of outputs from the autoregressive transformer against the provided text and stack
23
+ ranks the outputs according to most probable. You could use greedy or beam-search decoding but in my experience CLIP
24
+ decoding creates considerably better results.
25
+
26
+ Next, the speech "tokens" are decoded into a low-quality MEL spectrogram using a VQVAE.
27
+
28
+ Finally, the output of the VQVAE is further decoded by a UNet diffusion model into raw audio, which can be placed in
29
+ a wav file.
30
+
31
+ ## How do I use this?
32
+
33
+ <incoming>
34
+
35
+ ## How do I train this?
36
+
37
+ Frankly - you don't. Building this model has been a labor of love for me, consuming most of my 6 RTX3090s worth of
38
+ resources for the better part of 6 months. It uses a dataset I've gathered, refined and transcribed that consists of
39
+ a lot of audio data which I cannot distribute because of copywrite or no open licenses.
40
+
41
+ With that said, I'm willing to help you out if you really want to give it a shot. DM me.
data/tokenizer.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}
do_tts.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import yaml
9
+ from tqdm import tqdm
10
+
11
+ from models.arch_util import TorchMelSpectrogram
12
+ from models.discrete_diffusion_vocoder import DiscreteDiffusionVocoder
13
+ from models.lucidrains_dvae import DiscreteVAE
14
+ from models.text_voice_clip import VoiceCLIP
15
+ from models.unified_voice import UnifiedVoice
16
+ from utils.audio import load_audio
17
+ from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
18
+ from utils.tokenizer import VoiceBpeTokenizer
19
+
20
+
21
+ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
22
+ """
23
+ Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
24
+ """
25
+ return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
26
+ model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps))
27
+
28
+
29
+ def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128):
30
+ """
31
+ Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
32
+ """
33
+ with torch.no_grad():
34
+ mel = dvae_model.decode(mel_codes)[0]
35
+
36
+ # Pad MEL to multiples of 2048//spectrogram_compression_factor
37
+ msl = mel.shape[-1]
38
+ dsl = 2048 // spectrogram_compression_factor
39
+ gap = dsl - (msl % dsl)
40
+ if gap > 0:
41
+ mel = torch.nn.functional.pad(mel, (0, gap))
42
+
43
+ output_shape = (mel.shape[0], 1, mel.shape[-1] * spectrogram_compression_factor)
44
+ return diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
45
+
46
+
47
+ def load_conditioning(path, sample_rate=22050, cond_length=44100):
48
+ rel_clip = load_audio(path, sample_rate)
49
+ gap = rel_clip.shape[-1] - cond_length
50
+ if gap < 0:
51
+ rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
52
+ elif gap > 0:
53
+ rand_start = random.randint(0, gap)
54
+ rel_clip = rel_clip[:, rand_start:rand_start + cond_length]
55
+ mel_clip = TorchMelSpectrogram()(rel_clip.unsqueeze(0)).squeeze(0)
56
+ return mel_clip.unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
57
+
58
+
59
+ def fix_autoregressive_output(codes, stop_token):
60
+ """
61
+ This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
62
+ trained on and what the autoregressive code generator creates (which has no padding or end).
63
+ This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
64
+ a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE
65
+ and copying out the last few codes.
66
+
67
+ Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
68
+ """
69
+ # Strip off the autoregressive stop token and add padding.
70
+ stop_token_indices = (codes == stop_token).nonzero()
71
+ if len(stop_token_indices) == 0:
72
+ print("No stop tokens found, enjoy that output of yours!")
73
+ return
74
+ else:
75
+ codes[stop_token_indices] = 83
76
+ stm = stop_token_indices.min().item()
77
+ codes[stm:] = 83
78
+ if stm - 3 < codes.shape[0]:
79
+ codes[-3] = 45
80
+ codes[-2] = 45
81
+ codes[-1] = 248
82
+
83
+ return codes
84
+
85
+
86
+ if __name__ == '__main__':
87
+ preselected_cond_voices = {
88
+ 'simmons': ['Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav'],
89
+ 'news_girl': ['Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav', 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00016.wav'],
90
+ 'dan_carlin': ['Y:\\clips\\books1\\5_dchha06 Shield of the West\\00476.wav', 'Y:\\clips\\books1\\15_dchha16 Nazi Tidbits\\00036.wav'],
91
+ 'libri_test': ['Y:\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav'],
92
+ }
93
+
94
+ parser = argparse.ArgumentParser()
95
+ parser.add_argument('-autoregressive_model_path', type=str, help='Autoregressive model checkpoint to load.', default='.models/unified_voice.pth')
96
+ parser.add_argument('-clip_model_path', type=str, help='CLIP model checkpoint to load.', default='.models/clip.pth')
97
+ parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='./models/diffusion_vocoder.pth')
98
+ parser.add_argument('-dvae_model_path', type=str, help='DVAE model checkpoint to load.', default='./models/dvae.pth')
99
+ parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
100
+ parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dan_carlin')
101
+ parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=32)
102
+ parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=2)
103
+ parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2)
104
+ parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
105
+ args = parser.parse_args()
106
+ os.makedirs(args.output_path, exist_ok=True)
107
+
108
+ print("Loading GPT TTS..")
109
+ autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024, heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).eval()
110
+ autoregressive.load_state_dict(torch.load(args.autoregressive_model_path))
111
+ stop_mel_token = autoregressive.stop_mel_token
112
+
113
+ print("Loading data..")
114
+ tokenizer = VoiceBpeTokenizer()
115
+ text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
116
+ text = F.pad(text, (0,1)) # This may not be necessary.
117
+ cond_paths = preselected_cond_voices[args.cond_preset]
118
+ conds = []
119
+ for cond_path in cond_paths:
120
+ c, cond_wav = load_conditioning(cond_path, cond_length=132300)
121
+ conds.append(c)
122
+ conds = torch.stack(conds, dim=1) # And just use the last cond_wav for the diffusion model.
123
+
124
+ with torch.no_grad():
125
+ print("Performing GPT inference..")
126
+ samples = []
127
+ for b in tqdm(range(args.num_batches)):
128
+ codes = autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=50, top_p=.95,
129
+ temperature=.9, num_return_sequences=args.num_samples//args.num_batches, length_penalty=1)
130
+ padding_needed = 250 - codes.shape[1]
131
+ codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
132
+ samples.append(codes)
133
+ samples = torch.cat(samples, dim=0)
134
+ del autoregressive
135
+
136
+ print("Loading CLIP..")
137
+ clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8,
138
+ num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).eval()
139
+ clip.load_state_dict(torch.load(args.clip_model_path))
140
+ print("Performing CLIP filtering..")
141
+ for i in range(samples.shape[0]):
142
+ samples[i] = fix_autoregressive_output(samples[i], stop_mel_token)
143
+ clip_results = clip(text.repeat(samples.shape[0], 1),
144
+ torch.full((samples.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'),
145
+ samples, torch.full((samples.shape[0],), fill_value=samples.shape[1]*1024, dtype=torch.long, device='cuda'),
146
+ return_loss=False)
147
+ best_results = samples[torch.topk(clip_results, k=args.num_outputs).indices]
148
+
149
+ # Delete the autoregressive and clip models to free up GPU memory
150
+ del samples, clip
151
+
152
+ print("Loading DVAE..")
153
+ dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,
154
+ record_codes=True, kernel_size=3, use_transposed_convs=False).eval()
155
+ dvae.load_state_dict(torch.load(args.dvae_model_path))
156
+ print("Loading Diffusion Model..")
157
+ diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
158
+ spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
159
+ conditioning_inputs_provided=True, time_embed_dim_multiplier=4).eval()
160
+ diffusion.load_state_dict(torch.load(args.diffusion_model_path))
161
+ diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
162
+
163
+ print("Performing vocoding..")
164
+ # Perform vocoding on each batch element separately: Vocoding is very memory (and compute!) intensive.
165
+ for b in range(best_results.shape[0]):
166
+ code = best_results[b].unsqueeze(0)
167
+ wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256)
168
+ torchaudio.save(os.path.join(args.output_path, f'gpt_tts_output_{b}.wav'), wav.squeeze(0).cpu(), 22050)
models/arch_util.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchaudio
7
+
8
+
9
+ def zero_module(module):
10
+ """
11
+ Zero out the parameters of a module and return it.
12
+ """
13
+ for p in module.parameters():
14
+ p.detach().zero_()
15
+ return module
16
+
17
+
18
+ class GroupNorm32(nn.GroupNorm):
19
+ def forward(self, x):
20
+ return super().forward(x.float()).type(x.dtype)
21
+
22
+
23
+ def normalization(channels):
24
+ """
25
+ Make a standard normalization layer.
26
+
27
+ :param channels: number of input channels.
28
+ :return: an nn.Module for normalization.
29
+ """
30
+ groups = 32
31
+ if channels <= 16:
32
+ groups = 8
33
+ elif channels <= 64:
34
+ groups = 16
35
+ while channels % groups != 0:
36
+ groups = int(groups / 2)
37
+ assert groups > 2
38
+ return GroupNorm32(groups, channels)
39
+
40
+
41
+ class QKVAttentionLegacy(nn.Module):
42
+ """
43
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
44
+ """
45
+
46
+ def __init__(self, n_heads):
47
+ super().__init__()
48
+ self.n_heads = n_heads
49
+
50
+ def forward(self, qkv, mask=None):
51
+ """
52
+ Apply QKV attention.
53
+
54
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
55
+ :return: an [N x (H * C) x T] tensor after attention.
56
+ """
57
+ bs, width, length = qkv.shape
58
+ assert width % (3 * self.n_heads) == 0
59
+ ch = width // (3 * self.n_heads)
60
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
61
+ scale = 1 / math.sqrt(math.sqrt(ch))
62
+ weight = torch.einsum(
63
+ "bct,bcs->bts", q * scale, k * scale
64
+ ) # More stable with f16 than dividing afterwards
65
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
66
+ if mask is not None:
67
+ # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
68
+ mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
69
+ weight = weight * mask
70
+ a = torch.einsum("bts,bcs->bct", weight, v)
71
+
72
+ return a.reshape(bs, -1, length)
73
+
74
+
75
+ class AttentionBlock(nn.Module):
76
+ """
77
+ An attention block that allows spatial positions to attend to each other.
78
+
79
+ Originally ported from here, but adapted to the N-d case.
80
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ channels,
86
+ num_heads=1,
87
+ num_head_channels=-1,
88
+ ):
89
+ super().__init__()
90
+ self.channels = channels
91
+ if num_head_channels == -1:
92
+ self.num_heads = num_heads
93
+ else:
94
+ assert (
95
+ channels % num_head_channels == 0
96
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
97
+ self.num_heads = channels // num_head_channels
98
+ self.norm = normalization(channels)
99
+ self.qkv = nn.Conv1d(channels, channels * 3, 1)
100
+ self.attention = QKVAttentionLegacy(self.num_heads)
101
+
102
+ self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
103
+
104
+ def forward(self, x, mask=None):
105
+ if mask is not None:
106
+ return self._forward(x, mask)
107
+ else:
108
+ return self._forward(x)
109
+
110
+ def _forward(self, x, mask=None):
111
+ b, c, *spatial = x.shape
112
+ x = x.reshape(b, c, -1)
113
+ qkv = self.qkv(self.norm(x))
114
+ h = self.attention(qkv, mask)
115
+ h = self.proj_out(h)
116
+ return (x + h).reshape(b, c, *spatial)
117
+
118
+
119
+ class Upsample(nn.Module):
120
+ """
121
+ An upsampling layer with an optional convolution.
122
+
123
+ :param channels: channels in the inputs and outputs.
124
+ :param use_conv: a bool determining if a convolution is applied.
125
+ """
126
+
127
+ def __init__(self, channels, use_conv, out_channels=None, factor=4):
128
+ super().__init__()
129
+ self.channels = channels
130
+ self.out_channels = out_channels or channels
131
+ self.use_conv = use_conv
132
+ self.factor = factor
133
+ if use_conv:
134
+ ksize = 5
135
+ pad = 2
136
+ self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)
137
+
138
+ def forward(self, x):
139
+ assert x.shape[1] == self.channels
140
+ x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
141
+ if self.use_conv:
142
+ x = self.conv(x)
143
+ return x
144
+
145
+
146
+ class Downsample(nn.Module):
147
+ """
148
+ A downsampling layer with an optional convolution.
149
+
150
+ :param channels: channels in the inputs and outputs.
151
+ :param use_conv: a bool determining if a convolution is applied.
152
+ """
153
+
154
+ def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
155
+ super().__init__()
156
+ self.channels = channels
157
+ self.out_channels = out_channels or channels
158
+ self.use_conv = use_conv
159
+
160
+ stride = factor
161
+ if use_conv:
162
+ self.op = nn.Conv1d(
163
+ self.channels, self.out_channels, ksize, stride=stride, padding=pad
164
+ )
165
+ else:
166
+ assert self.channels == self.out_channels
167
+ self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
168
+
169
+ def forward(self, x):
170
+ assert x.shape[1] == self.channels
171
+ return self.op(x)
172
+
173
+
174
+ class ResBlock(nn.Module):
175
+ def __init__(
176
+ self,
177
+ channels,
178
+ dropout,
179
+ out_channels=None,
180
+ use_conv=False,
181
+ use_scale_shift_norm=False,
182
+ up=False,
183
+ down=False,
184
+ kernel_size=3,
185
+ ):
186
+ super().__init__()
187
+ self.channels = channels
188
+ self.dropout = dropout
189
+ self.out_channels = out_channels or channels
190
+ self.use_conv = use_conv
191
+ self.use_scale_shift_norm = use_scale_shift_norm
192
+ padding = 1 if kernel_size == 3 else 2
193
+
194
+ self.in_layers = nn.Sequential(
195
+ normalization(channels),
196
+ nn.SiLU(),
197
+ nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
198
+ )
199
+
200
+ self.updown = up or down
201
+
202
+ if up:
203
+ self.h_upd = Upsample(channels, False)
204
+ self.x_upd = Upsample(channels, False)
205
+ elif down:
206
+ self.h_upd = Downsample(channels, False)
207
+ self.x_upd = Downsample(channels, False)
208
+ else:
209
+ self.h_upd = self.x_upd = nn.Identity()
210
+
211
+ self.out_layers = nn.Sequential(
212
+ normalization(self.out_channels),
213
+ nn.SiLU(),
214
+ nn.Dropout(p=dropout),
215
+ zero_module(
216
+ nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
217
+ ),
218
+ )
219
+
220
+ if self.out_channels == channels:
221
+ self.skip_connection = nn.Identity()
222
+ elif use_conv:
223
+ self.skip_connection = nn.Conv1d(
224
+ channels, self.out_channels, kernel_size, padding=padding
225
+ )
226
+ else:
227
+ self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
228
+
229
+ def forward(self, x):
230
+ if self.updown:
231
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
232
+ h = in_rest(x)
233
+ h = self.h_upd(h)
234
+ x = self.x_upd(x)
235
+ h = in_conv(h)
236
+ else:
237
+ h = self.in_layers(x)
238
+ h = self.out_layers(h)
239
+ return self.skip_connection(x) + h
240
+
241
+
242
+ class AudioMiniEncoder(nn.Module):
243
+ def __init__(self,
244
+ spec_dim,
245
+ embedding_dim,
246
+ base_channels=128,
247
+ depth=2,
248
+ resnet_blocks=2,
249
+ attn_blocks=4,
250
+ num_attn_heads=4,
251
+ dropout=0,
252
+ downsample_factor=2,
253
+ kernel_size=3):
254
+ super().__init__()
255
+ self.init = nn.Sequential(
256
+ nn.Conv1d(spec_dim, base_channels, 3, padding=1)
257
+ )
258
+ ch = base_channels
259
+ res = []
260
+ for l in range(depth):
261
+ for r in range(resnet_blocks):
262
+ res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
263
+ res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
264
+ ch *= 2
265
+ self.res = nn.Sequential(*res)
266
+ self.final = nn.Sequential(
267
+ normalization(ch),
268
+ nn.SiLU(),
269
+ nn.Conv1d(ch, embedding_dim, 1)
270
+ )
271
+ attn = []
272
+ for a in range(attn_blocks):
273
+ attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
274
+ self.attn = nn.Sequential(*attn)
275
+ self.dim = embedding_dim
276
+
277
+ def forward(self, x):
278
+ h = self.init(x)
279
+ h = self.res(h)
280
+ h = self.final(h)
281
+ h = self.attn(h)
282
+ return h[:, :, 0]
283
+
284
+
285
+ class TorchMelSpectrogram(nn.Module):
286
+ def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
287
+ sampling_rate=22050, normalize=False, mel_norm_file='data/mel_norms.pth'):
288
+ super().__init__()
289
+ # These are the default tacotron values for the MEL spectrogram.
290
+ self.filter_length = filter_length
291
+ self.hop_length = hop_length
292
+ self.win_length = win_length
293
+ self.n_mel_channels = n_mel_channels
294
+ self.mel_fmin = mel_fmin
295
+ self.mel_fmax = mel_fmax
296
+ self.sampling_rate = sampling_rate
297
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
298
+ win_length=self.win_length, power=2, normalized=normalize,
299
+ sample_rate=self.sampling_rate, f_min=self.mel_fmin,
300
+ f_max=self.mel_fmax, n_mels=self.n_mel_channels,
301
+ norm="slaney")
302
+ self.mel_norm_file = mel_norm_file
303
+ if self.mel_norm_file is not None:
304
+ self.mel_norms = torch.load(self.mel_norm_file)
305
+ else:
306
+ self.mel_norms = None
307
+
308
+ def forward(self, inp):
309
+ if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
310
+ inp = inp.squeeze(1)
311
+ assert len(inp.shape) == 2
312
+ self.mel_stft = self.mel_stft.to(inp.device)
313
+ mel = self.mel_stft(inp)
314
+ # Perform dynamic range compression
315
+ mel = torch.log(torch.clamp(mel, min=1e-5))
316
+ if self.mel_norms is not None:
317
+ self.mel_norms = self.mel_norms.to(mel.device)
318
+ mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
319
+ return mel
models/discrete_diffusion_vocoder.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This model is based on OpenAI's UNet from improved diffusion, with modifications to support a MEL conditioning signal
3
+ and an audio conditioning input. It has also been simplified somewhat.
4
+ Credit: https://github.com/openai/improved-diffusion
5
+ """
6
+
7
+
8
+ import math
9
+ from abc import abstractmethod
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock
15
+
16
+
17
+ def timestep_embedding(timesteps, dim, max_period=10000):
18
+ """
19
+ Create sinusoidal timestep embeddings.
20
+
21
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
22
+ These may be fractional.
23
+ :param dim: the dimension of the output.
24
+ :param max_period: controls the minimum frequency of the embeddings.
25
+ :return: an [N x dim] Tensor of positional embeddings.
26
+ """
27
+ half = dim // 2
28
+ freqs = torch.exp(
29
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
30
+ ).to(device=timesteps.device)
31
+ args = timesteps[:, None].float() * freqs[None]
32
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
33
+ if dim % 2:
34
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
35
+ return embedding
36
+
37
+
38
+ class TimestepBlock(nn.Module):
39
+ """
40
+ Any module where forward() takes timestep embeddings as a second argument.
41
+ """
42
+
43
+ @abstractmethod
44
+ def forward(self, x, emb):
45
+ """
46
+ Apply the module to `x` given `emb` timestep embeddings.
47
+ """
48
+
49
+
50
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
51
+ """
52
+ A sequential module that passes timestep embeddings to the children that
53
+ support it as an extra input.
54
+ """
55
+
56
+ def forward(self, x, emb):
57
+ for layer in self:
58
+ if isinstance(layer, TimestepBlock):
59
+ x = layer(x, emb)
60
+ else:
61
+ x = layer(x)
62
+ return x
63
+
64
+
65
+ class TimestepResBlock(TimestepBlock):
66
+ """
67
+ A residual block that can optionally change the number of channels.
68
+
69
+ :param channels: the number of input channels.
70
+ :param emb_channels: the number of timestep embedding channels.
71
+ :param dropout: the rate of dropout.
72
+ :param out_channels: if specified, the number of out channels.
73
+ :param use_conv: if True and out_channels is specified, use a spatial
74
+ convolution instead of a smaller 1x1 convolution to change the
75
+ channels in the skip connection.
76
+ :param dims: determines if the signal is 1D, 2D, or 3D.
77
+ :param up: if True, use this block for upsampling.
78
+ :param down: if True, use this block for downsampling.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ channels,
84
+ emb_channels,
85
+ dropout,
86
+ out_channels=None,
87
+ use_conv=False,
88
+ use_scale_shift_norm=False,
89
+ up=False,
90
+ down=False,
91
+ kernel_size=3,
92
+ ):
93
+ super().__init__()
94
+ self.channels = channels
95
+ self.emb_channels = emb_channels
96
+ self.dropout = dropout
97
+ self.out_channels = out_channels or channels
98
+ self.use_conv = use_conv
99
+ self.use_scale_shift_norm = use_scale_shift_norm
100
+ padding = 1 if kernel_size == 3 else (2 if kernel_size == 5 else 0)
101
+
102
+ self.in_layers = nn.Sequential(
103
+ normalization(channels),
104
+ nn.SiLU(),
105
+ nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
106
+ )
107
+
108
+ self.updown = up or down
109
+
110
+ if up:
111
+ self.h_upd = Upsample(channels, False, dims)
112
+ self.x_upd = Upsample(channels, False, dims)
113
+ elif down:
114
+ self.h_upd = Downsample(channels, False, dims)
115
+ self.x_upd = Downsample(channels, False, dims)
116
+ else:
117
+ self.h_upd = self.x_upd = nn.Identity()
118
+
119
+ self.emb_layers = nn.Sequential(
120
+ nn.SiLU(),
121
+ nn.Linear(
122
+ emb_channels,
123
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
124
+ ),
125
+ )
126
+ self.out_layers = nn.Sequential(
127
+ normalization(self.out_channels),
128
+ nn.SiLU(),
129
+ nn.Dropout(p=dropout),
130
+ zero_module(
131
+ nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
132
+ ),
133
+ )
134
+
135
+ if self.out_channels == channels:
136
+ self.skip_connection = nn.Identity()
137
+ elif use_conv:
138
+ self.skip_connection = nn.Conv1d(
139
+ channels, self.out_channels, kernel_size, padding=padding
140
+ )
141
+ else:
142
+ self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
143
+
144
+ def forward(self, x, emb):
145
+ if self.updown:
146
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
147
+ h = in_rest(x)
148
+ h = self.h_upd(h)
149
+ x = self.x_upd(x)
150
+ h = in_conv(h)
151
+ else:
152
+ h = self.in_layers(x)
153
+ emb_out = self.emb_layers(emb).type(h.dtype)
154
+ while len(emb_out.shape) < len(h.shape):
155
+ emb_out = emb_out[..., None]
156
+ if self.use_scale_shift_norm:
157
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
158
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
159
+ h = out_norm(h) * (1 + scale) + shift
160
+ h = out_rest(h)
161
+ else:
162
+ h = h + emb_out
163
+ h = self.out_layers(h)
164
+ return self.skip_connection(x) + h
165
+
166
+
167
+ class DiscreteSpectrogramConditioningBlock(nn.Module):
168
+ def __init__(self, dvae_channels, channels, level):
169
+ super().__init__()
170
+ self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1),
171
+ normalization(channels),
172
+ nn.SiLU(),
173
+ nn.Conv1d(channels, channels, kernel_size=3))
174
+ self.level = level
175
+
176
+ """
177
+ Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape.
178
+
179
+ :param x: bxcxS waveform latent
180
+ :param codes: bxN discrete codes, N <= S
181
+ """
182
+ def forward(self, x, dvae_in):
183
+ b, c, S = x.shape
184
+ _, q, N = dvae_in.shape
185
+ emb = self.intg(dvae_in)
186
+ emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
187
+ return torch.cat([x, emb], dim=1)
188
+
189
+
190
+ class DiscreteDiffusionVocoder(nn.Module):
191
+ """
192
+ The full UNet model with attention and timestep embedding.
193
+
194
+ Customized to be conditioned on a spectrogram prior.
195
+
196
+ :param in_channels: channels in the input Tensor.
197
+ :param spectrogram_channels: channels in the conditioning spectrogram.
198
+ :param model_channels: base channel count for the model.
199
+ :param out_channels: channels in the output Tensor.
200
+ :param num_res_blocks: number of residual blocks per downsample.
201
+ :param attention_resolutions: a collection of downsample rates at which
202
+ attention will take place. May be a set, list, or tuple.
203
+ For example, if this contains 4, then at 4x downsampling, attention
204
+ will be used.
205
+ :param dropout: the dropout probability.
206
+ :param channel_mult: channel multiplier for each level of the UNet.
207
+ :param conv_resample: if True, use learned convolutions for upsampling and
208
+ downsampling.
209
+ :param dims: determines if the signal is 1D, 2D, or 3D.
210
+ :param num_heads: the number of attention heads in each attention layer.
211
+ :param num_heads_channels: if specified, ignore num_heads and instead use
212
+ a fixed channel width per attention head.
213
+ :param num_heads_upsample: works with num_heads to set a different number
214
+ of heads for upsampling. Deprecated.
215
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
216
+ :param resblock_updown: use residual blocks for up/downsampling.
217
+ :param use_new_attention_order: use a different attention pattern for potentially
218
+ increased efficiency.
219
+ """
220
+
221
+ def __init__(
222
+ self,
223
+ model_channels,
224
+ in_channels=1,
225
+ out_channels=2, # mean and variance
226
+ dvae_dim=512,
227
+ dropout=0,
228
+ # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
229
+ channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
230
+ num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
231
+ # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
232
+ # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
233
+ spectrogram_conditioning_resolutions=(512,),
234
+ attention_resolutions=(512,1024,2048),
235
+ conv_resample=True,
236
+ dims=1,
237
+ use_fp16=False,
238
+ num_heads=1,
239
+ num_head_channels=-1,
240
+ num_heads_upsample=-1,
241
+ use_scale_shift_norm=False,
242
+ resblock_updown=False,
243
+ kernel_size=3,
244
+ scale_factor=2,
245
+ conditioning_inputs_provided=True,
246
+ time_embed_dim_multiplier=4,
247
+ ):
248
+ super().__init__()
249
+
250
+ if num_heads_upsample == -1:
251
+ num_heads_upsample = num_heads
252
+
253
+ self.in_channels = in_channels
254
+ self.model_channels = model_channels
255
+ self.out_channels = out_channels
256
+ self.attention_resolutions = attention_resolutions
257
+ self.dropout = dropout
258
+ self.channel_mult = channel_mult
259
+ self.conv_resample = conv_resample
260
+ self.dtype = torch.float16 if use_fp16 else torch.float32
261
+ self.num_heads = num_heads
262
+ self.num_head_channels = num_head_channels
263
+ self.num_heads_upsample = num_heads_upsample
264
+ self.dims = dims
265
+
266
+ padding = 1 if kernel_size == 3 else 2
267
+
268
+ time_embed_dim = model_channels * time_embed_dim_multiplier
269
+ self.time_embed = nn.Sequential(
270
+ nn.Linear(model_channels, time_embed_dim),
271
+ nn.SiLU(),
272
+ nn.Linear(time_embed_dim, time_embed_dim),
273
+ )
274
+
275
+ self.conditioning_enabled = conditioning_inputs_provided
276
+ if conditioning_inputs_provided:
277
+ self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
278
+ attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
279
+
280
+ seqlyr = TimestepEmbedSequential(
281
+ nn.Conv1d(in_channels, model_channels, kernel_size, padding=padding)
282
+ )
283
+ seqlyr.level = 0
284
+ self.input_blocks = nn.ModuleList([seqlyr])
285
+ spectrogram_blocks = []
286
+ self._feature_size = model_channels
287
+ input_block_chans = [model_channels]
288
+ ch = model_channels
289
+ ds = 1
290
+
291
+ for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
292
+ if ds in spectrogram_conditioning_resolutions:
293
+ spec_cond_block = DiscreteSpectrogramConditioningBlock(dvae_dim, ch, 2 ** level)
294
+ self.input_blocks.append(spec_cond_block)
295
+ spectrogram_blocks.append(spec_cond_block)
296
+ ch *= 2
297
+
298
+ for _ in range(num_blocks):
299
+ layers = [
300
+ TimestepResBlock(
301
+ ch,
302
+ time_embed_dim,
303
+ dropout,
304
+ out_channels=int(mult * model_channels),
305
+ use_scale_shift_norm=use_scale_shift_norm,
306
+ kernel_size=kernel_size,
307
+ )
308
+ ]
309
+ ch = int(mult * model_channels)
310
+ if ds in attention_resolutions:
311
+ layers.append(
312
+ AttentionBlock(
313
+ ch,
314
+ num_heads=num_heads,
315
+ num_head_channels=num_head_channels,
316
+ )
317
+ )
318
+ layer = TimestepEmbedSequential(*layers)
319
+ layer.level = 2 ** level
320
+ self.input_blocks.append(layer)
321
+ self._feature_size += ch
322
+ input_block_chans.append(ch)
323
+ if level != len(channel_mult) - 1:
324
+ out_ch = ch
325
+ upblk = TimestepEmbedSequential(
326
+ TimestepResBlock(
327
+ ch,
328
+ time_embed_dim,
329
+ dropout,
330
+ out_channels=out_ch,
331
+ use_scale_shift_norm=use_scale_shift_norm,
332
+ down=True,
333
+ kernel_size=kernel_size,
334
+ )
335
+ if resblock_updown
336
+ else Downsample(
337
+ ch, conv_resample, out_channels=out_ch, factor=scale_factor
338
+ )
339
+ )
340
+ upblk.level = 2 ** level
341
+ self.input_blocks.append(upblk)
342
+ ch = out_ch
343
+ input_block_chans.append(ch)
344
+ ds *= 2
345
+ self._feature_size += ch
346
+
347
+ self.middle_block = TimestepEmbedSequential(
348
+ TimestepResBlock(
349
+ ch,
350
+ time_embed_dim,
351
+ dropout,
352
+ use_scale_shift_norm=use_scale_shift_norm,
353
+ kernel_size=kernel_size,
354
+ ),
355
+ AttentionBlock(
356
+ ch,
357
+ num_heads=num_heads,
358
+ num_head_channels=num_head_channels,
359
+ ),
360
+ TimestepResBlock(
361
+ ch,
362
+ time_embed_dim,
363
+ dropout,
364
+ use_scale_shift_norm=use_scale_shift_norm,
365
+ kernel_size=kernel_size,
366
+ ),
367
+ )
368
+ self._feature_size += ch
369
+
370
+ self.output_blocks = nn.ModuleList([])
371
+ for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
372
+ for i in range(num_blocks + 1):
373
+ ich = input_block_chans.pop()
374
+ layers = [
375
+ TimestepResBlock(
376
+ ch + ich,
377
+ time_embed_dim,
378
+ dropout,
379
+ out_channels=int(model_channels * mult),
380
+ use_scale_shift_norm=use_scale_shift_norm,
381
+ kernel_size=kernel_size,
382
+ )
383
+ ]
384
+ ch = int(model_channels * mult)
385
+ if ds in attention_resolutions:
386
+ layers.append(
387
+ AttentionBlock(
388
+ ch,
389
+ num_heads=num_heads_upsample,
390
+ num_head_channels=num_head_channels,
391
+ )
392
+ )
393
+ if level and i == num_blocks:
394
+ out_ch = ch
395
+ layers.append(
396
+ TimestepResBlock(
397
+ ch,
398
+ time_embed_dim,
399
+ dropout,
400
+ out_channels=out_ch,
401
+ use_scale_shift_norm=use_scale_shift_norm,
402
+ up=True,
403
+ kernel_size=kernel_size,
404
+ )
405
+ if resblock_updown
406
+ else Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor)
407
+ )
408
+ ds //= 2
409
+ layer = TimestepEmbedSequential(*layers)
410
+ layer.level = 2 ** level
411
+ self.output_blocks.append(layer)
412
+ self._feature_size += ch
413
+
414
+ self.out = nn.Sequential(
415
+ normalization(ch),
416
+ nn.SiLU(),
417
+ zero_module(nn.Conv1d(model_channels, out_channels, kernel_size, padding=padding)),
418
+ )
419
+
420
+ def forward(self, x, timesteps, spectrogram, conditioning_input=None):
421
+ """
422
+ Apply the model to an input batch.
423
+
424
+ :param x: an [N x C x ...] Tensor of inputs.
425
+ :param timesteps: a 1-D batch of timesteps.
426
+ :param y: an [N] Tensor of labels, if class-conditional.
427
+ :return: an [N x C x ...] Tensor of outputs.
428
+ """
429
+ assert x.shape[-1] % 2048 == 0 # This model operates at base//2048 at it's bottom levels, thus this requirement.
430
+ if self.conditioning_enabled:
431
+ assert conditioning_input is not None
432
+
433
+ hs = []
434
+ emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
435
+ if self.conditioning_enabled:
436
+ emb2 = self.contextual_embedder(conditioning_input)
437
+ emb = emb1 + emb2
438
+ else:
439
+ emb = emb1
440
+
441
+ h = x.type(self.dtype)
442
+ for k, module in enumerate(self.input_blocks):
443
+ if isinstance(module, DiscreteSpectrogramConditioningBlock):
444
+ h = module(h, spectrogram)
445
+ else:
446
+ h = module(h, emb)
447
+ hs.append(h)
448
+ h = self.middle_block(h, emb)
449
+ for module in self.output_blocks:
450
+ h = torch.cat([h, hs.pop()], dim=1)
451
+ h = module(h, emb)
452
+ h = h.type(x.dtype)
453
+ return self.out(h)
454
+
455
+
456
+ # Test for ~4 second audio clip at 22050Hz
457
+ if __name__ == '__main__':
458
+ clip = torch.randn(2, 1, 40960)
459
+ spec = torch.randn(2,80,160)
460
+ cond = torch.randn(2, 1, 40960)
461
+ ts = torch.LongTensor([555, 556])
462
+ model = DiscreteDiffusionVocoder(model_channels=128, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8],
463
+ num_res_blocks=[1,2, 2, 2, 2, 2, 2, 2, 2, 1, 1 ], spectrogram_conditioning_resolutions=[2,512],
464
+ dropout=.05, attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
465
+ conditioning_inputs_provided=True, conditioning_input_dim=80, time_embed_dim_multiplier=4,
466
+ dvae_dim=80)
467
+
468
+ print(model(clip, ts, spec, cond).shape)
models/lucidrains_dvae.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from math import sqrt
3
+
4
+ import torch
5
+ import torch.distributed as distributed
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+
11
+ def default(val, d):
12
+ return val if val is not None else d
13
+
14
+
15
+ def eval_decorator(fn):
16
+ def inner(model, *args, **kwargs):
17
+ was_training = model.training
18
+ model.eval()
19
+ out = fn(model, *args, **kwargs)
20
+ model.train(was_training)
21
+ return out
22
+ return inner
23
+
24
+
25
+ # Quantizer implemented by the rosinality vqvae repo.
26
+ # Credit: https://github.com/rosinality/vq-vae-2-pytorch
27
+ class Quantize(nn.Module):
28
+ def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False):
29
+ super().__init__()
30
+
31
+ self.dim = dim
32
+ self.n_embed = n_embed
33
+ self.decay = decay
34
+ self.eps = eps
35
+
36
+ self.balancing_heuristic = balancing_heuristic
37
+ self.codes = None
38
+ self.max_codes = 64000
39
+ self.codes_full = False
40
+ self.new_return_order = new_return_order
41
+
42
+ embed = torch.randn(dim, n_embed)
43
+ self.register_buffer("embed", embed)
44
+ self.register_buffer("cluster_size", torch.zeros(n_embed))
45
+ self.register_buffer("embed_avg", embed.clone())
46
+
47
+ def forward(self, input, return_soft_codes=False):
48
+ if self.balancing_heuristic and self.codes_full:
49
+ h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
50
+ mask = torch.logical_or(h > .9, h < .01).unsqueeze(1)
51
+ ep = self.embed.permute(1,0)
52
+ ea = self.embed_avg.permute(1,0)
53
+ rand_embed = torch.randn_like(ep) * mask
54
+ self.embed = (ep * ~mask + rand_embed).permute(1,0)
55
+ self.embed_avg = (ea * ~mask + rand_embed).permute(1,0)
56
+ self.cluster_size = self.cluster_size * ~mask.squeeze()
57
+ if torch.any(mask):
58
+ print(f"Reset {torch.sum(mask)} embedding codes.")
59
+ self.codes = None
60
+ self.codes_full = False
61
+
62
+ flatten = input.reshape(-1, self.dim)
63
+ dist = (
64
+ flatten.pow(2).sum(1, keepdim=True)
65
+ - 2 * flatten @ self.embed
66
+ + self.embed.pow(2).sum(0, keepdim=True)
67
+ )
68
+ soft_codes = -dist
69
+ _, embed_ind = soft_codes.max(1)
70
+ embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
71
+ embed_ind = embed_ind.view(*input.shape[:-1])
72
+ quantize = self.embed_code(embed_ind)
73
+
74
+ if self.balancing_heuristic:
75
+ if self.codes is None:
76
+ self.codes = embed_ind.flatten()
77
+ else:
78
+ self.codes = torch.cat([self.codes, embed_ind.flatten()])
79
+ if len(self.codes) > self.max_codes:
80
+ self.codes = self.codes[-self.max_codes:]
81
+ self.codes_full = True
82
+
83
+ if self.training:
84
+ embed_onehot_sum = embed_onehot.sum(0)
85
+ embed_sum = flatten.transpose(0, 1) @ embed_onehot
86
+
87
+ if distributed.is_initialized() and distributed.get_world_size() > 1:
88
+ distributed.all_reduce(embed_onehot_sum)
89
+ distributed.all_reduce(embed_sum)
90
+
91
+ self.cluster_size.data.mul_(self.decay).add_(
92
+ embed_onehot_sum, alpha=1 - self.decay
93
+ )
94
+ self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
95
+ n = self.cluster_size.sum()
96
+ cluster_size = (
97
+ (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
98
+ )
99
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
100
+ self.embed.data.copy_(embed_normalized)
101
+
102
+ diff = (quantize.detach() - input).pow(2).mean()
103
+ quantize = input + (quantize - input).detach()
104
+
105
+ if return_soft_codes:
106
+ return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
107
+ elif self.new_return_order:
108
+ return quantize, embed_ind, diff
109
+ else:
110
+ return quantize, diff, embed_ind
111
+
112
+ def embed_code(self, embed_id):
113
+ return F.embedding(embed_id, self.embed.transpose(0, 1))
114
+
115
+
116
+ # Fits a soft-discretized input to a normal-PDF across the specified dimension.
117
+ # In other words, attempts to force the discretization function to have a mean equal utilization across all discrete
118
+ # values with the specified expected variance.
119
+ class DiscretizationLoss(nn.Module):
120
+ def __init__(self, discrete_bins, dim, expected_variance, store_past=0):
121
+ super().__init__()
122
+ self.discrete_bins = discrete_bins
123
+ self.dim = dim
124
+ self.dist = torch.distributions.Normal(0, scale=expected_variance)
125
+ if store_past > 0:
126
+ self.record_past = True
127
+ self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu'))
128
+ self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu'))
129
+ self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins))
130
+ else:
131
+ self.record_past = False
132
+
133
+ def forward(self, x):
134
+ other_dims = set(range(len(x.shape)))-set([self.dim])
135
+ averaged = x.sum(dim=tuple(other_dims)) / x.sum()
136
+ averaged = averaged - averaged.mean()
137
+
138
+ if self.record_past:
139
+ acc_count = self.accumulator.shape[0]
140
+ avg = averaged.detach().clone()
141
+ if self.accumulator_filled > 0:
142
+ averaged = torch.mean(self.accumulator, dim=0) * (acc_count-1) / acc_count + \
143
+ averaged / acc_count
144
+
145
+ # Also push averaged into the accumulator.
146
+ self.accumulator[self.accumulator_index] = avg
147
+ self.accumulator_index += 1
148
+ if self.accumulator_index >= acc_count:
149
+ self.accumulator_index *= 0
150
+ if self.accumulator_filled <= 0:
151
+ self.accumulator_filled += 1
152
+
153
+ return torch.sum(-self.dist.log_prob(averaged))
154
+
155
+
156
+ class ResBlock(nn.Module):
157
+ def __init__(self, chan, conv, activation):
158
+ super().__init__()
159
+ self.net = nn.Sequential(
160
+ conv(chan, chan, 3, padding = 1),
161
+ activation(),
162
+ conv(chan, chan, 3, padding = 1),
163
+ activation(),
164
+ conv(chan, chan, 1)
165
+ )
166
+
167
+ def forward(self, x):
168
+ return self.net(x) + x
169
+
170
+
171
+ class UpsampledConv(nn.Module):
172
+ def __init__(self, conv, *args, **kwargs):
173
+ super().__init__()
174
+ assert 'stride' in kwargs.keys()
175
+ self.stride = kwargs['stride']
176
+ del kwargs['stride']
177
+ self.conv = conv(*args, **kwargs)
178
+
179
+ def forward(self, x):
180
+ up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest')
181
+ return self.conv(up)
182
+
183
+
184
+ # DiscreteVAE partially derived from lucidrains DALLE implementation
185
+ # Credit: https://github.com/lucidrains/DALLE-pytorch
186
+ class DiscreteVAE(nn.Module):
187
+ def __init__(
188
+ self,
189
+ positional_dims=2,
190
+ num_tokens = 512,
191
+ codebook_dim = 512,
192
+ num_layers = 3,
193
+ num_resnet_blocks = 0,
194
+ hidden_dim = 64,
195
+ channels = 3,
196
+ stride = 2,
197
+ kernel_size = 4,
198
+ use_transposed_convs = True,
199
+ encoder_norm = False,
200
+ activation = 'relu',
201
+ smooth_l1_loss = False,
202
+ straight_through = False,
203
+ normalization = None, # ((0.5,) * 3, (0.5,) * 3),
204
+ record_codes = False,
205
+ discretization_loss_averaging_steps = 100,
206
+ lr_quantizer_args = {},
207
+ ):
208
+ super().__init__()
209
+ has_resblocks = num_resnet_blocks > 0
210
+
211
+ self.num_tokens = num_tokens
212
+ self.num_layers = num_layers
213
+ self.straight_through = straight_through
214
+ self.positional_dims = positional_dims
215
+ self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
216
+
217
+ assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
218
+ if positional_dims == 2:
219
+ conv = nn.Conv2d
220
+ conv_transpose = nn.ConvTranspose2d
221
+ else:
222
+ conv = nn.Conv1d
223
+ conv_transpose = nn.ConvTranspose1d
224
+ if not use_transposed_convs:
225
+ conv_transpose = functools.partial(UpsampledConv, conv)
226
+
227
+ if activation == 'relu':
228
+ act = nn.ReLU
229
+ elif activation == 'silu':
230
+ act = nn.SiLU
231
+ else:
232
+ assert NotImplementedError()
233
+
234
+
235
+ enc_layers = []
236
+ dec_layers = []
237
+
238
+ if num_layers > 0:
239
+ enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
240
+ dec_chans = list(reversed(enc_chans))
241
+
242
+ enc_chans = [channels, *enc_chans]
243
+
244
+ dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
245
+ dec_chans = [dec_init_chan, *dec_chans]
246
+
247
+ enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
248
+
249
+ pad = (kernel_size - 1) // 2
250
+ for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
251
+ enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
252
+ if encoder_norm:
253
+ enc_layers.append(nn.GroupNorm(8, enc_out))
254
+ dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
255
+ dec_out_chans = dec_chans[-1]
256
+ innermost_dim = dec_chans[0]
257
+ else:
258
+ enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act()))
259
+ dec_out_chans = hidden_dim
260
+ innermost_dim = hidden_dim
261
+
262
+ for _ in range(num_resnet_blocks):
263
+ dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
264
+ enc_layers.append(ResBlock(innermost_dim, conv, act))
265
+
266
+ if num_resnet_blocks > 0:
267
+ dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
268
+
269
+
270
+ enc_layers.append(conv(innermost_dim, codebook_dim, 1))
271
+ dec_layers.append(conv(dec_out_chans, channels, 1))
272
+
273
+ self.encoder = nn.Sequential(*enc_layers)
274
+ self.decoder = nn.Sequential(*dec_layers)
275
+
276
+ self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
277
+ self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
278
+
279
+ # take care of normalization within class
280
+ self.normalization = normalization
281
+ self.record_codes = record_codes
282
+ if record_codes:
283
+ self.codes = torch.zeros((1228800,), dtype=torch.long)
284
+ self.code_ind = 0
285
+ self.total_codes = 0
286
+ self.internal_step = 0
287
+
288
+ def norm(self, images):
289
+ if not self.normalization is not None:
290
+ return images
291
+
292
+ means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
293
+ arrange = 'c -> () c () ()' if self.positional_dims == 2 else 'c -> () c ()'
294
+ means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
295
+ images = images.clone()
296
+ images.sub_(means).div_(stds)
297
+ return images
298
+
299
+ def get_debug_values(self, step, __):
300
+ if self.record_codes and self.total_codes > 0:
301
+ # Report annealing schedule
302
+ return {'histogram_codes': self.codes[:self.total_codes]}
303
+ else:
304
+ return {}
305
+
306
+ @torch.no_grad()
307
+ @eval_decorator
308
+ def get_codebook_indices(self, images):
309
+ img = self.norm(images)
310
+ logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
311
+ sampled, codes, _ = self.codebook(logits)
312
+ self.log_codes(codes)
313
+ return codes
314
+
315
+ def decode(
316
+ self,
317
+ img_seq
318
+ ):
319
+ self.log_codes(img_seq)
320
+ if hasattr(self.codebook, 'embed_code'):
321
+ image_embeds = self.codebook.embed_code(img_seq)
322
+ else:
323
+ image_embeds = F.embedding(img_seq, self.codebook.codebook)
324
+ b, n, d = image_embeds.shape
325
+
326
+ kwargs = {}
327
+ if self.positional_dims == 1:
328
+ arrange = 'b n d -> b d n'
329
+ else:
330
+ h = w = int(sqrt(n))
331
+ arrange = 'b (h w) d -> b d h w'
332
+ kwargs = {'h': h, 'w': w}
333
+ image_embeds = rearrange(image_embeds, arrange, **kwargs)
334
+ images = [image_embeds]
335
+ for layer in self.decoder:
336
+ images.append(layer(images[-1]))
337
+ return images[-1], images[-2]
338
+
339
+ def infer(self, img):
340
+ img = self.norm(img)
341
+ logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
342
+ sampled, codes, commitment_loss = self.codebook(logits)
343
+ return self.decode(codes)
344
+
345
+ # Note: This module is not meant to be run in forward() except while training. It has special logic which performs
346
+ # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
347
+ # more lossy (but useful for determining network performance).
348
+ def forward(
349
+ self,
350
+ img
351
+ ):
352
+ img = self.norm(img)
353
+ logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
354
+ sampled, codes, commitment_loss = self.codebook(logits)
355
+ sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
356
+
357
+ if self.training:
358
+ out = sampled
359
+ for d in self.decoder:
360
+ out = d(out)
361
+ self.log_codes(codes)
362
+ else:
363
+ # This is non-differentiable, but gives a better idea of how the network is actually performing.
364
+ out, _ = self.decode(codes)
365
+
366
+ # reconstruction loss
367
+ recon_loss = self.loss_fn(img, out, reduction='none')
368
+
369
+ return recon_loss, commitment_loss, out
370
+
371
+ def log_codes(self, codes):
372
+ # This is so we can debug the distribution of codes being learned.
373
+ if self.record_codes and self.internal_step % 10 == 0:
374
+ codes = codes.flatten()
375
+ l = codes.shape[0]
376
+ i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
377
+ self.codes[i:i+l] = codes.cpu()
378
+ self.code_ind = self.code_ind + l
379
+ if self.code_ind >= self.codes.shape[0]:
380
+ self.code_ind = 0
381
+ self.total_codes += 1
382
+ self.internal_step += 1
383
+
384
+
385
+ if __name__ == '__main__':
386
+ v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
387
+ hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False)
388
+ r,l,o=v(torch.randn(1,80,256))
389
+ v.decode(torch.randint(0,8192,(1,256)))
390
+ print(o.shape, l.shape)
models/text_voice_clip.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import einsum
5
+ from models.transformer import Transformer
6
+
7
+
8
+ def exists(val):
9
+ return val is not None
10
+
11
+
12
+ def masked_mean(t, mask, dim = 1):
13
+ t = t.masked_fill(~mask[:, :, None], 0.)
14
+ return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
15
+
16
+
17
+ class VoiceCLIP(nn.Module):
18
+ """
19
+ CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
20
+ transcribed text.
21
+
22
+ Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ *,
28
+ dim_text=512,
29
+ dim_speech=512,
30
+ dim_latent=512,
31
+ num_text_tokens=256,
32
+ text_enc_depth=6,
33
+ text_seq_len=120,
34
+ text_heads=8,
35
+ num_speech_tokens=8192,
36
+ speech_enc_depth=6,
37
+ speech_heads=8,
38
+ speech_seq_len=250,
39
+ text_mask_percentage=0,
40
+ voice_mask_percentage=0,
41
+ wav_token_compression=1024,
42
+ ):
43
+ super().__init__()
44
+ self.text_emb = nn.Embedding(num_text_tokens, dim_text)
45
+ self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
46
+ self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
47
+ heads=text_heads)
48
+ self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
49
+
50
+ self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
51
+ self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
52
+ self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
53
+ depth=speech_enc_depth, heads=speech_heads)
54
+ self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
55
+
56
+ self.temperature = nn.Parameter(torch.tensor(1.))
57
+ self.text_mask_percentage = text_mask_percentage
58
+ self.voice_mask_percentage = voice_mask_percentage
59
+ self.wav_token_compression = wav_token_compression
60
+
61
+ def forward(
62
+ self,
63
+ text,
64
+ text_lengths,
65
+ speech_tokens,
66
+ wav_lengths,
67
+ return_loss=False
68
+ ):
69
+ # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
70
+ # chopping the inputs by the maximum actual length.
71
+ max_text_len = text_lengths.max()
72
+ text = text[:, :max_text_len]
73
+ max_mel_len = wav_lengths.max() // self.wav_token_compression
74
+ speech_tokens = speech_tokens[:, :max_mel_len]
75
+
76
+ b, device = text.shape[0], text.device
77
+ if self.training:
78
+ text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
79
+ voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
80
+ else:
81
+ text_mask = torch.ones_like(text.float()).bool()
82
+ voice_mask = torch.ones_like(speech_tokens.float()).bool()
83
+
84
+ text_emb = self.text_emb(text)
85
+ text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
86
+
87
+ speech_emb = self.speech_emb(speech_tokens)
88
+ speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
89
+
90
+ enc_text = self.text_transformer(text_emb, mask=text_mask)
91
+ enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
92
+
93
+ text_latents = masked_mean(enc_text, text_mask, dim=1)
94
+ speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
95
+
96
+ text_latents = self.to_text_latent(text_latents)
97
+ speech_latents = self.to_speech_latent(speech_latents)
98
+
99
+ text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
100
+
101
+ temp = self.temperature.exp()
102
+
103
+ if not return_loss:
104
+ sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
105
+ return sim
106
+
107
+ sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
108
+ labels = torch.arange(b, device=device)
109
+ loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
110
+ return loss
111
+
112
+
113
+ if __name__ == '__main__':
114
+ clip = VoiceCLIP(text_mask_percentage=.2, voice_mask_percentage=.2)
115
+ clip(torch.randint(0,256,(2,120)),
116
+ torch.tensor([50,100]),
117
+ torch.randint(0,8192,(2,250)),
118
+ torch.tensor([101,102]),
119
+ return_loss=True)
120
+ nonloss = clip(torch.randint(0,256,(2,120)),
121
+ torch.tensor([50,100]),
122
+ torch.randint(0,8192,(2,250)),
123
+ torch.tensor([101,102]),
124
+ return_loss=False)
125
+ print(nonloss.shape)
models/transformer.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from rotary_embedding_torch import RotaryEmbedding, broadcat
7
+ from torch import nn
8
+
9
+
10
+ # helpers
11
+
12
+
13
+ def exists(val):
14
+ return val is not None
15
+
16
+
17
+ def default(val, d):
18
+ return val if exists(val) else d
19
+
20
+
21
+ def cast_tuple(val, depth = 1):
22
+ if isinstance(val, list):
23
+ val = tuple(val)
24
+ return val if isinstance(val, tuple) else (val,) * depth
25
+
26
+
27
+ def max_neg_value(t):
28
+ return -torch.finfo(t.dtype).max
29
+
30
+
31
+ def stable_softmax(t, dim = -1, alpha = 32 ** 2):
32
+ t = t / alpha
33
+ t = t - torch.amax(t, dim = dim, keepdim = True).detach()
34
+ return (t * alpha).softmax(dim = dim)
35
+
36
+
37
+ def route_args(router, args, depth):
38
+ routed_args = [(dict(), dict()) for _ in range(depth)]
39
+ matched_keys = [key for key in args.keys() if key in router]
40
+
41
+ for key in matched_keys:
42
+ val = args[key]
43
+ for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
44
+ new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
45
+ routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
46
+ return routed_args
47
+
48
+
49
+ # classes
50
+ class SequentialSequence(nn.Module):
51
+ def __init__(self, layers, args_route = {}, layer_dropout = 0.):
52
+ super().__init__()
53
+ assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
54
+ self.layers = layers
55
+ self.args_route = args_route
56
+ self.layer_dropout = layer_dropout
57
+
58
+ def forward(self, x, **kwargs):
59
+ args = route_args(self.args_route, kwargs, len(self.layers))
60
+ layers_and_args = list(zip(self.layers, args))
61
+
62
+ for (f, g), (f_args, g_args) in layers_and_args:
63
+ x = x + f(x, **f_args)
64
+ x = x + g(x, **g_args)
65
+ return x
66
+
67
+
68
+ class DivideMax(nn.Module):
69
+ def __init__(self, dim):
70
+ super().__init__()
71
+ self.dim = dim
72
+
73
+ def forward(self, x):
74
+ maxes = x.amax(dim = self.dim, keepdim = True).detach()
75
+ return x / maxes
76
+
77
+
78
+ # https://arxiv.org/abs/2103.17239
79
+ class LayerScale(nn.Module):
80
+ def __init__(self, dim, depth, fn):
81
+ super().__init__()
82
+ if depth <= 18:
83
+ init_eps = 0.1
84
+ elif depth > 18 and depth <= 24:
85
+ init_eps = 1e-5
86
+ else:
87
+ init_eps = 1e-6
88
+
89
+ scale = torch.zeros(1, 1, dim).fill_(init_eps)
90
+ self.scale = nn.Parameter(scale)
91
+ self.fn = fn
92
+ def forward(self, x, **kwargs):
93
+ return self.fn(x, **kwargs) * self.scale
94
+
95
+ # layer norm
96
+
97
+
98
+ class PreNorm(nn.Module):
99
+ def __init__(self, dim, fn, sandwich = False):
100
+ super().__init__()
101
+ self.norm = nn.LayerNorm(dim)
102
+ self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
103
+ self.fn = fn
104
+
105
+ def forward(self, x, **kwargs):
106
+ x = self.norm(x)
107
+ x = self.fn(x, **kwargs)
108
+ return self.norm_out(x)
109
+
110
+ # feed forward
111
+
112
+
113
+ class GEGLU(nn.Module):
114
+ def forward(self, x):
115
+ x, gates = x.chunk(2, dim = -1)
116
+ return x * F.gelu(gates)
117
+
118
+
119
+ class FeedForward(nn.Module):
120
+ def __init__(self, dim, dropout = 0., mult = 4.):
121
+ super().__init__()
122
+ self.net = nn.Sequential(
123
+ nn.Linear(dim, dim * mult * 2),
124
+ GEGLU(),
125
+ nn.Dropout(dropout),
126
+ nn.Linear(dim * mult, dim)
127
+ )
128
+
129
+ def forward(self, x):
130
+ return self.net(x)
131
+
132
+ # Attention
133
+
134
+
135
+ class Attention(nn.Module):
136
+ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
137
+ super().__init__()
138
+ inner_dim = dim_head * heads
139
+ self.heads = heads
140
+ self.seq_len = seq_len
141
+ self.scale = dim_head ** -0.5
142
+
143
+ self.causal = causal
144
+
145
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
146
+ self.to_out = nn.Sequential(
147
+ nn.Linear(inner_dim, dim),
148
+ nn.Dropout(dropout)
149
+ )
150
+
151
+ def forward(self, x, mask = None):
152
+ b, n, _, h, device = *x.shape, self.heads, x.device
153
+ softmax = torch.softmax
154
+
155
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
156
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
157
+
158
+ q = q * self.scale
159
+
160
+ dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
161
+ mask_value = max_neg_value(dots)
162
+
163
+ if exists(mask):
164
+ mask = rearrange(mask, 'b j -> b () () j')
165
+ dots.masked_fill_(~mask, mask_value)
166
+ del mask
167
+
168
+ if self.causal:
169
+ i, j = dots.shape[-2:]
170
+ mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
171
+ dots.masked_fill_(mask, mask_value)
172
+
173
+ attn = softmax(dots, dim=-1)
174
+
175
+ out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
176
+ out = rearrange(out, 'b h n d -> b n (h d)')
177
+ out = self.to_out(out)
178
+ return out
179
+
180
+
181
+ # main transformer class
182
+ class Transformer(nn.Module):
183
+ def __init__(
184
+ self,
185
+ *,
186
+ dim,
187
+ depth,
188
+ seq_len,
189
+ causal = True,
190
+ heads = 8,
191
+ dim_head = 64,
192
+ ff_mult = 4,
193
+ attn_dropout = 0.,
194
+ ff_dropout = 0.,
195
+ sparse_attn = False,
196
+ sandwich_norm = False,
197
+ ):
198
+ super().__init__()
199
+ layers = nn.ModuleList([])
200
+ sparse_layer = cast_tuple(sparse_attn, depth)
201
+
202
+ for ind, sparse_attn in zip(range(depth), sparse_layer):
203
+ attn = Attention(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
204
+
205
+ ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
206
+
207
+ layers.append(nn.ModuleList([
208
+ LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
209
+ LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm))
210
+ ]))
211
+
212
+ execute_type = SequentialSequence
213
+ route_attn = ((True, False),) * depth
214
+ attn_route_map = {'mask': route_attn}
215
+
216
+ self.layers = execute_type(layers, args_route = attn_route_map)
217
+
218
+ def forward(self, x, **kwargs):
219
+ return self.layers(x, **kwargs)
models/unified_voice.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import GPT2Config, GPT2PreTrainedModel
7
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
8
+ from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
9
+ from models.arch_util import AttentionBlock
10
+
11
+
12
+
13
+ def null_position_embeddings(range, dim):
14
+ return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
15
+
16
+
17
+ class ResBlock(nn.Module):
18
+ """
19
+ Basic residual convolutional block that uses GroupNorm.
20
+ """
21
+ def __init__(self, chan):
22
+ super().__init__()
23
+ self.net = nn.Sequential(
24
+ nn.Conv1d(chan, chan, kernel_size=3, padding=1),
25
+ nn.GroupNorm(chan//8, chan),
26
+ nn.ReLU(),
27
+ nn.Conv1d(chan, chan, kernel_size=3, padding=1),
28
+ nn.GroupNorm(chan//8, chan)
29
+ )
30
+
31
+ def forward(self, x):
32
+ return F.relu(self.net(x) + x)
33
+
34
+
35
+ class GPT2InferenceModel(GPT2PreTrainedModel):
36
+ def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
37
+ super().__init__(config)
38
+ self.transformer = gpt
39
+ self.text_pos_embedding = text_pos_emb
40
+ self.embeddings = embeddings
41
+ self.lm_head = nn.Sequential(norm, linear)
42
+
43
+ # Model parallel
44
+ self.model_parallel = False
45
+ self.device_map = None
46
+ self.cached_mel_emb = None
47
+
48
+ def parallelize(self, device_map=None):
49
+ self.device_map = (
50
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
51
+ if device_map is None
52
+ else device_map
53
+ )
54
+ assert_device_map(self.device_map, len(self.transformer.h))
55
+ self.transformer.parallelize(self.device_map)
56
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
57
+ self.model_parallel = True
58
+
59
+ def deparallelize(self):
60
+ self.transformer.deparallelize()
61
+ self.transformer = self.transformer.to("cpu")
62
+ self.lm_head = self.lm_head.to("cpu")
63
+ self.model_parallel = False
64
+ torch.cuda.empty_cache()
65
+
66
+ def get_output_embeddings(self):
67
+ return self.lm_head
68
+
69
+ def set_output_embeddings(self, new_embeddings):
70
+ self.lm_head = new_embeddings
71
+
72
+ def store_mel_emb(self, mel_emb):
73
+ self.cached_mel_emb = mel_emb
74
+
75
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
76
+
77
+ token_type_ids = kwargs.get("token_type_ids", None)
78
+ # only last token for inputs_ids if past is defined in kwargs
79
+ if past:
80
+ input_ids = input_ids[:, -1].unsqueeze(-1)
81
+ if token_type_ids is not None:
82
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
83
+
84
+ attention_mask = kwargs.get("attention_mask", None)
85
+ position_ids = kwargs.get("position_ids", None)
86
+
87
+ if attention_mask is not None and position_ids is None:
88
+ # create position_ids on the fly for batch generation
89
+ position_ids = attention_mask.long().cumsum(-1) - 1
90
+ position_ids.masked_fill_(attention_mask == 0, 1)
91
+ if past:
92
+ position_ids = position_ids[:, -1].unsqueeze(-1)
93
+ else:
94
+ position_ids = None
95
+ return {
96
+ "input_ids": input_ids,
97
+ "past_key_values": past,
98
+ "use_cache": kwargs.get("use_cache"),
99
+ "position_ids": position_ids,
100
+ "attention_mask": attention_mask,
101
+ "token_type_ids": token_type_ids,
102
+ }
103
+
104
+ def forward(
105
+ self,
106
+ input_ids=None,
107
+ past_key_values=None,
108
+ attention_mask=None,
109
+ token_type_ids=None,
110
+ position_ids=None,
111
+ head_mask=None,
112
+ inputs_embeds=None,
113
+ encoder_hidden_states=None,
114
+ encoder_attention_mask=None,
115
+ labels=None,
116
+ use_cache=None,
117
+ output_attentions=None,
118
+ output_hidden_states=None,
119
+ return_dict=None,
120
+ ):
121
+ assert self.cached_mel_emb is not None
122
+ assert inputs_embeds is None # Not supported by this inference model.
123
+ assert labels is None # Training not supported by this inference model.
124
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
125
+
126
+ # Create embedding
127
+ mel_len = self.cached_mel_emb.shape[1]
128
+ if input_ids.shape[1] != 1:
129
+ text_inputs = input_ids[:, mel_len:]
130
+ text_emb = self.embeddings(text_inputs)
131
+ text_emb = text_emb + self.text_pos_embedding(text_emb)
132
+ if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
133
+ mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
134
+ else:
135
+ mel_emb = self.cached_mel_emb
136
+ emb = torch.cat([mel_emb, text_emb], dim=1)
137
+ else:
138
+ emb = self.embeddings(input_ids)
139
+ emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device)
140
+
141
+ transformer_outputs = self.transformer(
142
+ inputs_embeds=emb,
143
+ past_key_values=past_key_values,
144
+ attention_mask=attention_mask,
145
+ token_type_ids=token_type_ids,
146
+ position_ids=position_ids,
147
+ head_mask=head_mask,
148
+ encoder_hidden_states=encoder_hidden_states,
149
+ encoder_attention_mask=encoder_attention_mask,
150
+ use_cache=use_cache,
151
+ output_attentions=output_attentions,
152
+ output_hidden_states=output_hidden_states,
153
+ return_dict=return_dict,
154
+ )
155
+ hidden_states = transformer_outputs[0]
156
+
157
+ # Set device for model parallelism
158
+ if self.model_parallel:
159
+ torch.cuda.set_device(self.transformer.first_device)
160
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
161
+
162
+ lm_logits = self.lm_head(hidden_states)
163
+
164
+ if not return_dict:
165
+ return (lm_logits,) + transformer_outputs[1:]
166
+
167
+ return CausalLMOutputWithCrossAttentions(
168
+ loss=None,
169
+ logits=lm_logits,
170
+ past_key_values=transformer_outputs.past_key_values,
171
+ hidden_states=transformer_outputs.hidden_states,
172
+ attentions=transformer_outputs.attentions,
173
+ cross_attentions=transformer_outputs.cross_attentions,
174
+ )
175
+
176
+ @staticmethod
177
+ def _reorder_cache(past, beam_idx):
178
+ """
179
+ This function is used to re-order the :obj:`past_key_values` cache if
180
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
181
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
182
+ """
183
+ return tuple(
184
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
185
+ for layer_past in past
186
+ )
187
+
188
+
189
+ class ConditioningEncoder(nn.Module):
190
+ def __init__(self,
191
+ spec_dim,
192
+ embedding_dim,
193
+ attn_blocks=6,
194
+ num_attn_heads=4,
195
+ do_checkpointing=False):
196
+ super().__init__()
197
+ attn = []
198
+ self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
199
+ for a in range(attn_blocks):
200
+ attn.append(AttentionBlock(embedding_dim, num_attn_heads))
201
+ self.attn = nn.Sequential(*attn)
202
+ self.dim = embedding_dim
203
+ self.do_checkpointing = do_checkpointing
204
+
205
+ def forward(self, x):
206
+ h = self.init(x)
207
+ h = self.attn(h)
208
+ return h[:, :, 0]
209
+
210
+
211
+ class LearnedPositionEmbeddings(nn.Module):
212
+ def __init__(self, seq_len, model_dim, init=.02):
213
+ super().__init__()
214
+ self.emb = nn.Embedding(seq_len, model_dim)
215
+ # Initializing this way is standard for GPT-2
216
+ self.emb.weight.data.normal_(mean=0.0, std=init)
217
+
218
+ def forward(self, x):
219
+ sl = x.shape[1]
220
+ return self.emb(torch.arange(0, sl, device=x.device))
221
+
222
+ def get_fixed_embedding(self, ind, dev):
223
+ return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
224
+
225
+
226
+ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
227
+ """
228
+ GPT-2 implemented by the HuggingFace library.
229
+ """
230
+ from transformers import GPT2Config, GPT2Model
231
+ gpt_config = GPT2Config(vocab_size=256, # Unused.
232
+ n_positions=max_mel_seq_len+max_text_seq_len,
233
+ n_ctx=max_mel_seq_len+max_text_seq_len,
234
+ n_embd=model_dim,
235
+ n_layer=layers,
236
+ n_head=heads,
237
+ gradient_checkpointing=checkpointing,
238
+ use_cache=not checkpointing)
239
+ gpt = GPT2Model(gpt_config)
240
+ # Override the built in positional embeddings
241
+ del gpt.wpe
242
+ gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
243
+ # Built-in token embeddings are unused.
244
+ del gpt.wte
245
+ return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
246
+ None, None
247
+
248
+
249
+ class MelEncoder(nn.Module):
250
+ def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
251
+ super().__init__()
252
+ self.channels = channels
253
+ self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
254
+ nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
255
+ nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
256
+ nn.GroupNorm(channels//16, channels//2),
257
+ nn.ReLU(),
258
+ nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
259
+ nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
260
+ nn.GroupNorm(channels//8, channels),
261
+ nn.ReLU(),
262
+ nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
263
+ )
264
+ self.reduction = 4
265
+
266
+
267
+ def forward(self, x):
268
+ for e in self.encoder:
269
+ x = e(x)
270
+ return x.permute(0,2,1)
271
+
272
+
273
+ class UnifiedVoice(nn.Module):
274
+ def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
275
+ mel_length_compression=1024, number_text_tokens=256,
276
+ start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
277
+ stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
278
+ checkpointing=True):
279
+ """
280
+ Args:
281
+ layers: Number of layers in transformer stack.
282
+ model_dim: Operating dimensions of the transformer
283
+ heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
284
+ max_text_tokens: Maximum number of text tokens that will be encountered by model.
285
+ max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
286
+ max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
287
+ mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
288
+ number_text_tokens:
289
+ start_text_token:
290
+ stop_text_token:
291
+ number_mel_codes:
292
+ start_mel_token:
293
+ stop_mel_token:
294
+ train_solo_embeddings:
295
+ use_mel_codes_as_input:
296
+ checkpointing:
297
+ """
298
+ super().__init__()
299
+
300
+ self.number_text_tokens = number_text_tokens
301
+ self.start_text_token = start_text_token
302
+ self.stop_text_token = stop_text_token
303
+ self.number_mel_codes = number_mel_codes
304
+ self.start_mel_token = start_mel_token
305
+ self.stop_mel_token = stop_mel_token
306
+ self.layers = layers
307
+ self.heads = heads
308
+ self.max_mel_tokens = max_mel_tokens
309
+ self.max_text_tokens = max_text_tokens
310
+ self.model_dim = model_dim
311
+ self.max_conditioning_inputs = max_conditioning_inputs
312
+ self.mel_length_compression = mel_length_compression
313
+ self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
314
+ self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
315
+ if use_mel_codes_as_input:
316
+ self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
317
+ else:
318
+ self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
319
+ self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
320
+ build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
321
+ if train_solo_embeddings:
322
+ self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
323
+ self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
324
+ else:
325
+ self.mel_solo_embedding = 0
326
+ self.text_solo_embedding = 0
327
+
328
+ self.final_norm = nn.LayerNorm(model_dim)
329
+ self.text_head = nn.Linear(model_dim, self.number_text_tokens)
330
+ self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
331
+
332
+ # Initialize the embeddings per the GPT-2 scheme
333
+ embeddings = [self.text_embedding]
334
+ if use_mel_codes_as_input:
335
+ embeddings.append(self.mel_embedding)
336
+ for module in embeddings:
337
+ module.weight.data.normal_(mean=0.0, std=.02)
338
+
339
+ def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
340
+ inp = F.pad(input, (1,0), value=start_token)
341
+ tar = F.pad(input, (0,1), value=stop_token)
342
+ return inp, tar
343
+
344
+ def set_mel_padding(self, mel_input_tokens, wav_lengths):
345
+ """
346
+ Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
347
+ that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
348
+ preformatting to create a working TTS model.
349
+ """
350
+ # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
351
+ mel_lengths = wav_lengths // self.mel_length_compression
352
+ for b in range(len(mel_lengths)):
353
+ actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
354
+ if actual_end < mel_input_tokens.shape[-1]:
355
+ mel_input_tokens[b, actual_end:] = self.stop_mel_token
356
+ return mel_input_tokens
357
+
358
+ def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False):
359
+ if second_inputs is not None:
360
+ emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
361
+ else:
362
+ emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
363
+
364
+ gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
365
+ if get_attns:
366
+ return gpt_out.attentions
367
+
368
+ enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
369
+ enc = self.final_norm(enc)
370
+ first_logits = enc[:, :first_inputs.shape[1]]
371
+ first_logits = first_head(first_logits)
372
+ first_logits = first_logits.permute(0,2,1)
373
+ if second_inputs is not None:
374
+ second_logits = enc[:, -second_inputs.shape[1]:]
375
+ second_logits = second_head(second_logits)
376
+ second_logits = second_logits.permute(0,2,1)
377
+ return first_logits, second_logits
378
+ else:
379
+ return first_logits
380
+
381
+ def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
382
+ """
383
+ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
384
+ (actuated by `text_first`).
385
+
386
+ speech_conditioning_input: MEL float tensor, (b,80,s)
387
+ text_inputs: long tensor, (b,t)
388
+ text_lengths: long tensor, (b,)
389
+ mel_inputs: long tensor, (b,m)
390
+ wav_lengths: long tensor, (b,)
391
+ raw_mels: MEL float tensor (b,80,s)
392
+ """
393
+ assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
394
+ assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
395
+
396
+ # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
397
+ # chopping the inputs by the maximum actual length.
398
+ max_text_len = text_lengths.max()
399
+ text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
400
+ max_mel_len = wav_lengths.max() // self.mel_length_compression
401
+ mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
402
+ if raw_mels is not None:
403
+ raw_mels = raw_mels[:, :, :max_mel_len*4]
404
+ mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
405
+
406
+ speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
407
+ conds = []
408
+ for j in range(speech_conditioning_input.shape[1]):
409
+ conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
410
+ conds = torch.stack(conds, dim=1)
411
+
412
+ text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
413
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
414
+ mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
415
+ if raw_mels is not None:
416
+ mel_inp = F.pad(raw_mels, (0, 8))
417
+ else:
418
+ mel_inp = mel_codes
419
+ mel_emb = self.mel_embedding(mel_inp)
420
+ mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
421
+ if text_first:
422
+ text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
423
+ else:
424
+ mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions)
425
+
426
+ if return_attentions:
427
+ return mel_logits
428
+ loss_text = F.cross_entropy(text_logits, text_targets.long())
429
+ loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
430
+ return loss_text.mean(), loss_mel.mean(), mel_logits
431
+
432
+ def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
433
+ """
434
+ Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
435
+ model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
436
+ """
437
+ assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
438
+
439
+ # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
440
+ # chopping the inputs by the maximum actual length.
441
+ max_text_len = text_lengths.max()
442
+ text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
443
+
444
+ speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
445
+ conds = []
446
+ for j in range(speech_conditioning_input.shape[1]):
447
+ conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
448
+ conds = torch.stack(conds, dim=1)
449
+
450
+ text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
451
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
452
+ text_logits = self.get_logits(conds, text_emb, self.text_head)
453
+ loss_text = F.cross_entropy(text_logits, text_targets.long())
454
+ return loss_text.mean()
455
+
456
+ def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None):
457
+ """
458
+ Performs autoregressive modeling on only speech data.
459
+ """
460
+ assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
461
+
462
+ # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
463
+ # chopping the inputs by the maximum actual length.
464
+ max_mel_len = wav_lengths.max() // self.mel_length_compression
465
+ mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
466
+ mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
467
+ if raw_mels is not None:
468
+ raw_mels = raw_mels[:, :, :max_mel_len*4]
469
+
470
+ speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
471
+ conds = []
472
+ for j in range(speech_conditioning_input.shape[1]):
473
+ conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
474
+ conds = torch.stack(conds, dim=1)
475
+
476
+ mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
477
+ if raw_mels is not None:
478
+ mel_inp = F.pad(raw_mels, (0, 4))
479
+ else:
480
+ mel_inp = mel_codes
481
+ mel_emb = self.mel_embedding(mel_inp)
482
+ mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
483
+ mel_logits = self.get_logits(conds, mel_emb, self.mel_head)
484
+ loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
485
+ return loss_mel.mean()
486
+
487
+ def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
488
+ seq_length = self.max_mel_tokens + self.max_text_tokens + 2
489
+ if not hasattr(self, 'inference_model'):
490
+ # TODO: Decouple gpt_config from this inference model.
491
+ gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
492
+ n_positions=seq_length,
493
+ n_ctx=seq_length,
494
+ n_embd=self.model_dim,
495
+ n_layer=self.layers,
496
+ n_head=self.heads,
497
+ gradient_checkpointing=False,
498
+ use_cache=True)
499
+ self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
500
+ self.gpt.wte = self.mel_embedding
501
+
502
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
503
+ text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
504
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
505
+
506
+ speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
507
+ conds = []
508
+ for j in range(speech_conditioning_input.shape[1]):
509
+ conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
510
+ conds = torch.stack(conds, dim=1)
511
+
512
+ emb = torch.cat([conds, text_emb], dim=1)
513
+ self.inference_model.store_mel_emb(emb)
514
+
515
+ fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
516
+ fake_inputs[:,-1] = self.start_mel_token
517
+
518
+ gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
519
+ max_length=seq_length, **hf_generate_kwargs)
520
+ return gen[:, fake_inputs.shape[1]:]
521
+
522
+
523
+ if __name__ == '__main__':
524
+ gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
525
+ l = gpt(torch.randn(2, 3, 80, 800),
526
+ torch.randint(high=120, size=(2,120)),
527
+ torch.tensor([32, 120]),
528
+ torch.randint(high=8192, size=(2,250)),
529
+ torch.tensor([250*256,195*256]))
530
+ gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchaudio
3
+ rotary_embedding_torch
4
+ transformers
5
+ tokenizers
6
+ pyfastmp3decoder
7
+ inflect
utils/audio.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+
4
+
5
+ def load_wav_to_torch(full_path):
6
+ sampling_rate, data = read(full_path)
7
+ if data.dtype == np.int32:
8
+ norm_fix = 2 ** 31
9
+ elif data.dtype == np.int16:
10
+ norm_fix = 2 ** 15
11
+ elif data.dtype == np.float16 or data.dtype == np.float32:
12
+ norm_fix = 1.
13
+ else:
14
+ raise NotImplemented(f"Provided data dtype not supported: {data.dtype}")
15
+ return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
16
+
17
+
18
+ def load_audio(audiopath, sampling_rate):
19
+ if audiopath[-4:] == '.wav':
20
+ audio, lsr = load_wav_to_torch(audiopath)
21
+ elif audiopath[-4:] == '.mp3':
22
+ # https://github.com/neonbjb/pyfastmp3decoder - Definitely worth it.
23
+ from pyfastmp3decoder.mp3decoder import load_mp3
24
+ audio, lsr = load_mp3(audiopath, sampling_rate)
25
+ audio = torch.FloatTensor(audio)
26
+
27
+ # Remove any channel data.
28
+ if len(audio.shape) > 1:
29
+ if audio.shape[0] < 5:
30
+ audio = audio[0]
31
+ else:
32
+ assert audio.shape[1] < 5
33
+ audio = audio[:, 0]
34
+
35
+ if lsr != sampling_rate:
36
+ audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
37
+
38
+ # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
39
+ # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
40
+ if torch.any(audio > 2) or not torch.any(audio < 0):
41
+ print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
42
+ audio.clip_(-1, 1)
43
+
44
+ return audio.unsqueeze(0)
utils/diffusion.py ADDED
@@ -0,0 +1,1232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is an almost carbon copy of gaussian_diffusion.py from OpenAI's ImprovedDiffusion repo, which itself:
3
+
4
+ This code started out as a PyTorch port of Ho et al's diffusion models:
5
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
6
+
7
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
8
+ """
9
+
10
+ import enum
11
+ import math
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch as th
16
+ from tqdm import tqdm
17
+
18
+
19
+ def normal_kl(mean1, logvar1, mean2, logvar2):
20
+ """
21
+ Compute the KL divergence between two gaussians.
22
+
23
+ Shapes are automatically broadcasted, so batches can be compared to
24
+ scalars, among other use cases.
25
+ """
26
+ tensor = None
27
+ for obj in (mean1, logvar1, mean2, logvar2):
28
+ if isinstance(obj, th.Tensor):
29
+ tensor = obj
30
+ break
31
+ assert tensor is not None, "at least one argument must be a Tensor"
32
+
33
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
34
+ # Tensors, but it does not work for th.exp().
35
+ logvar1, logvar2 = [
36
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
37
+ for x in (logvar1, logvar2)
38
+ ]
39
+
40
+ return 0.5 * (
41
+ -1.0
42
+ + logvar2
43
+ - logvar1
44
+ + th.exp(logvar1 - logvar2)
45
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
46
+ )
47
+
48
+
49
+ def approx_standard_normal_cdf(x):
50
+ """
51
+ A fast approximation of the cumulative distribution function of the
52
+ standard normal.
53
+ """
54
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
55
+
56
+
57
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
58
+ """
59
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
60
+ given image.
61
+
62
+ :param x: the target images. It is assumed that this was uint8 values,
63
+ rescaled to the range [-1, 1].
64
+ :param means: the Gaussian mean Tensor.
65
+ :param log_scales: the Gaussian log stddev Tensor.
66
+ :return: a tensor like x of log probabilities (in nats).
67
+ """
68
+ assert x.shape == means.shape == log_scales.shape
69
+ centered_x = x - means
70
+ inv_stdv = th.exp(-log_scales)
71
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
72
+ cdf_plus = approx_standard_normal_cdf(plus_in)
73
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
74
+ cdf_min = approx_standard_normal_cdf(min_in)
75
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
76
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
77
+ cdf_delta = cdf_plus - cdf_min
78
+ log_probs = th.where(
79
+ x < -0.999,
80
+ log_cdf_plus,
81
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
82
+ )
83
+ assert log_probs.shape == x.shape
84
+ return log_probs
85
+
86
+
87
+ def mean_flat(tensor):
88
+ """
89
+ Take the mean over all non-batch dimensions.
90
+ """
91
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
92
+
93
+
94
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
95
+ """
96
+ Get a pre-defined beta schedule for the given name.
97
+
98
+ The beta schedule library consists of beta schedules which remain similar
99
+ in the limit of num_diffusion_timesteps.
100
+ Beta schedules may be added, but should not be removed or changed once
101
+ they are committed to maintain backwards compatibility.
102
+ """
103
+ if schedule_name == "linear":
104
+ # Linear schedule from Ho et al, extended to work for any number of
105
+ # diffusion steps.
106
+ scale = 1000 / num_diffusion_timesteps
107
+ beta_start = scale * 0.0001
108
+ beta_end = scale * 0.02
109
+ return np.linspace(
110
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
111
+ )
112
+ elif schedule_name == "cosine":
113
+ return betas_for_alpha_bar(
114
+ num_diffusion_timesteps,
115
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
116
+ )
117
+ else:
118
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
119
+
120
+
121
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
122
+ """
123
+ Create a beta schedule that discretizes the given alpha_t_bar function,
124
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
125
+
126
+ :param num_diffusion_timesteps: the number of betas to produce.
127
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
128
+ produces the cumulative product of (1-beta) up to that
129
+ part of the diffusion process.
130
+ :param max_beta: the maximum beta to use; use values lower than 1 to
131
+ prevent singularities.
132
+ """
133
+ betas = []
134
+ for i in range(num_diffusion_timesteps):
135
+ t1 = i / num_diffusion_timesteps
136
+ t2 = (i + 1) / num_diffusion_timesteps
137
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
138
+ return np.array(betas)
139
+
140
+
141
+ class ModelMeanType(enum.Enum):
142
+ """
143
+ Which type of output the model predicts.
144
+ """
145
+
146
+ PREVIOUS_X = 'previous_x' # the model predicts x_{t-1}
147
+ START_X = 'start_x' # the model predicts x_0
148
+ EPSILON = 'epsilon' # the model predicts epsilon
149
+
150
+
151
+ class ModelVarType(enum.Enum):
152
+ """
153
+ What is used as the model's output variance.
154
+
155
+ The LEARNED_RANGE option has been added to allow the model to predict
156
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
157
+ """
158
+
159
+ LEARNED = 'learned'
160
+ FIXED_SMALL = 'fixed_small'
161
+ FIXED_LARGE = 'fixed_large'
162
+ LEARNED_RANGE = 'learned_range'
163
+
164
+
165
+ class LossType(enum.Enum):
166
+ MSE = 'mse' # use raw MSE loss (and KL when learning variances)
167
+ RESCALED_MSE = 'rescaled_mse' # use raw MSE loss (with RESCALED_KL when learning variances)
168
+ KL = 'kl' # use the variational lower-bound
169
+ RESCALED_KL = 'rescaled_kl' # like KL, but rescale to estimate the full VLB
170
+
171
+ def is_vb(self):
172
+ return self == LossType.KL or self == LossType.RESCALED_KL
173
+
174
+
175
+ class GaussianDiffusion:
176
+ """
177
+ Utilities for training and sampling diffusion models.
178
+
179
+ Ported directly from here, and then adapted over time to further experimentation.
180
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
181
+
182
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
183
+ starting at T and going to 1.
184
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
185
+ :param model_var_type: a ModelVarType determining how variance is output.
186
+ :param loss_type: a LossType determining the loss function to use.
187
+ :param rescale_timesteps: if True, pass floating point timesteps into the
188
+ model so that they are always scaled like in the
189
+ original paper (0 to 1000).
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ *,
195
+ betas,
196
+ model_mean_type,
197
+ model_var_type,
198
+ loss_type,
199
+ rescale_timesteps=False,
200
+ ):
201
+ self.model_mean_type = ModelMeanType(model_mean_type)
202
+ self.model_var_type = ModelVarType(model_var_type)
203
+ self.loss_type = LossType(loss_type)
204
+ self.rescale_timesteps = rescale_timesteps
205
+
206
+ # Use float64 for accuracy.
207
+ betas = np.array(betas, dtype=np.float64)
208
+ self.betas = betas
209
+ assert len(betas.shape) == 1, "betas must be 1-D"
210
+ assert (betas > 0).all() and (betas <= 1).all()
211
+
212
+ self.num_timesteps = int(betas.shape[0])
213
+
214
+ alphas = 1.0 - betas
215
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
216
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
217
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
218
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
219
+
220
+ # calculations for diffusion q(x_t | x_{t-1}) and others
221
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
222
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
223
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
224
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
225
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
226
+
227
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
228
+ self.posterior_variance = (
229
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
230
+ )
231
+ # log calculation clipped because the posterior variance is 0 at the
232
+ # beginning of the diffusion chain.
233
+ self.posterior_log_variance_clipped = np.log(
234
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
235
+ )
236
+ self.posterior_mean_coef1 = (
237
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
238
+ )
239
+ self.posterior_mean_coef2 = (
240
+ (1.0 - self.alphas_cumprod_prev)
241
+ * np.sqrt(alphas)
242
+ / (1.0 - self.alphas_cumprod)
243
+ )
244
+
245
+ def q_mean_variance(self, x_start, t):
246
+ """
247
+ Get the distribution q(x_t | x_0).
248
+
249
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
250
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
251
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
252
+ """
253
+ mean = (
254
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
255
+ )
256
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
257
+ log_variance = _extract_into_tensor(
258
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
259
+ )
260
+ return mean, variance, log_variance
261
+
262
+ def q_sample(self, x_start, t, noise=None):
263
+ """
264
+ Diffuse the data for a given number of diffusion steps.
265
+
266
+ In other words, sample from q(x_t | x_0).
267
+
268
+ :param x_start: the initial data batch.
269
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
270
+ :param noise: if specified, the split-out normal noise.
271
+ :return: A noisy version of x_start.
272
+ """
273
+ if noise is None:
274
+ noise = th.randn_like(x_start)
275
+ assert noise.shape == x_start.shape
276
+ return (
277
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
278
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
279
+ * noise
280
+ )
281
+
282
+ def q_posterior_mean_variance(self, x_start, x_t, t):
283
+ """
284
+ Compute the mean and variance of the diffusion posterior:
285
+
286
+ q(x_{t-1} | x_t, x_0)
287
+
288
+ """
289
+ assert x_start.shape == x_t.shape
290
+ posterior_mean = (
291
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
292
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
293
+ )
294
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
295
+ posterior_log_variance_clipped = _extract_into_tensor(
296
+ self.posterior_log_variance_clipped, t, x_t.shape
297
+ )
298
+ assert (
299
+ posterior_mean.shape[0]
300
+ == posterior_variance.shape[0]
301
+ == posterior_log_variance_clipped.shape[0]
302
+ == x_start.shape[0]
303
+ )
304
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
305
+
306
+ def p_mean_variance(
307
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
308
+ ):
309
+ """
310
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
311
+ the initial x, x_0.
312
+
313
+ :param model: the model, which takes a signal and a batch of timesteps
314
+ as input.
315
+ :param x: the [N x C x ...] tensor at time t.
316
+ :param t: a 1-D Tensor of timesteps.
317
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
318
+ :param denoised_fn: if not None, a function which applies to the
319
+ x_start prediction before it is used to sample. Applies before
320
+ clip_denoised.
321
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
322
+ pass to the model. This can be used for conditioning.
323
+ :return: a dict with the following keys:
324
+ - 'mean': the model mean output.
325
+ - 'variance': the model variance output.
326
+ - 'log_variance': the log of 'variance'.
327
+ - 'pred_xstart': the prediction for x_0.
328
+ """
329
+ if model_kwargs is None:
330
+ model_kwargs = {}
331
+
332
+ B, C = x.shape[:2]
333
+ assert t.shape == (B,)
334
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
335
+
336
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
337
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
338
+ model_output, model_var_values = th.split(model_output, C, dim=1)
339
+ if self.model_var_type == ModelVarType.LEARNED:
340
+ model_log_variance = model_var_values
341
+ model_variance = th.exp(model_log_variance)
342
+ else:
343
+ min_log = _extract_into_tensor(
344
+ self.posterior_log_variance_clipped, t, x.shape
345
+ )
346
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
347
+ # The model_var_values is [-1, 1] for [min_var, max_var].
348
+ frac = (model_var_values + 1) / 2
349
+ model_log_variance = frac * max_log + (1 - frac) * min_log
350
+ model_variance = th.exp(model_log_variance)
351
+ else:
352
+ model_variance, model_log_variance = {
353
+ # for fixedlarge, we set the initial (log-)variance like so
354
+ # to get a better decoder log likelihood.
355
+ ModelVarType.FIXED_LARGE: (
356
+ np.append(self.posterior_variance[1], self.betas[1:]),
357
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
358
+ ),
359
+ ModelVarType.FIXED_SMALL: (
360
+ self.posterior_variance,
361
+ self.posterior_log_variance_clipped,
362
+ ),
363
+ }[self.model_var_type]
364
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
365
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
366
+
367
+ def process_xstart(x):
368
+ if denoised_fn is not None:
369
+ x = denoised_fn(x)
370
+ if clip_denoised:
371
+ return x.clamp(-1, 1)
372
+ return x
373
+
374
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
375
+ pred_xstart = process_xstart(
376
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
377
+ )
378
+ model_mean = model_output
379
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
380
+ if self.model_mean_type == ModelMeanType.START_X:
381
+ pred_xstart = process_xstart(model_output)
382
+ else:
383
+ pred_xstart = process_xstart(
384
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
385
+ )
386
+ model_mean, _, _ = self.q_posterior_mean_variance(
387
+ x_start=pred_xstart, x_t=x, t=t
388
+ )
389
+ else:
390
+ raise NotImplementedError(self.model_mean_type)
391
+
392
+ assert (
393
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
394
+ )
395
+ return {
396
+ "mean": model_mean,
397
+ "variance": model_variance,
398
+ "log_variance": model_log_variance,
399
+ "pred_xstart": pred_xstart,
400
+ }
401
+
402
+ def _predict_xstart_from_eps(self, x_t, t, eps):
403
+ assert x_t.shape == eps.shape
404
+ return (
405
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
406
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
407
+ )
408
+
409
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
410
+ assert x_t.shape == xprev.shape
411
+ return ( # (xprev - coef2*x_t) / coef1
412
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
413
+ - _extract_into_tensor(
414
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
415
+ )
416
+ * x_t
417
+ )
418
+
419
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
420
+ return (
421
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
422
+ - pred_xstart
423
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
424
+
425
+ def _scale_timesteps(self, t):
426
+ if self.rescale_timesteps:
427
+ return t.float() * (1000.0 / self.num_timesteps)
428
+ return t
429
+
430
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
431
+ """
432
+ Compute the mean for the previous step, given a function cond_fn that
433
+ computes the gradient of a conditional log probability with respect to
434
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
435
+ condition on y.
436
+
437
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
438
+ """
439
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
440
+ new_mean = (
441
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
442
+ )
443
+ return new_mean
444
+
445
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
446
+ """
447
+ Compute what the p_mean_variance output would have been, should the
448
+ model's score function be conditioned by cond_fn.
449
+
450
+ See condition_mean() for details on cond_fn.
451
+
452
+ Unlike condition_mean(), this instead uses the conditioning strategy
453
+ from Song et al (2020).
454
+ """
455
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
456
+
457
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
458
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
459
+ x, self._scale_timesteps(t), **model_kwargs
460
+ )
461
+
462
+ out = p_mean_var.copy()
463
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
464
+ out["mean"], _, _ = self.q_posterior_mean_variance(
465
+ x_start=out["pred_xstart"], x_t=x, t=t
466
+ )
467
+ return out
468
+
469
+ def p_sample(
470
+ self,
471
+ model,
472
+ x,
473
+ t,
474
+ clip_denoised=True,
475
+ denoised_fn=None,
476
+ cond_fn=None,
477
+ model_kwargs=None,
478
+ ):
479
+ """
480
+ Sample x_{t-1} from the model at the given timestep.
481
+
482
+ :param model: the model to sample from.
483
+ :param x: the current tensor at x_{t-1}.
484
+ :param t: the value of t, starting at 0 for the first diffusion step.
485
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
486
+ :param denoised_fn: if not None, a function which applies to the
487
+ x_start prediction before it is used to sample.
488
+ :param cond_fn: if not None, this is a gradient function that acts
489
+ similarly to the model.
490
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
491
+ pass to the model. This can be used for conditioning.
492
+ :return: a dict containing the following keys:
493
+ - 'sample': a random sample from the model.
494
+ - 'pred_xstart': a prediction of x_0.
495
+ """
496
+ out = self.p_mean_variance(
497
+ model,
498
+ x,
499
+ t,
500
+ clip_denoised=clip_denoised,
501
+ denoised_fn=denoised_fn,
502
+ model_kwargs=model_kwargs,
503
+ )
504
+ noise = th.randn_like(x)
505
+ nonzero_mask = (
506
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
507
+ ) # no noise when t == 0
508
+ if cond_fn is not None:
509
+ out["mean"] = self.condition_mean(
510
+ cond_fn, out, x, t, model_kwargs=model_kwargs
511
+ )
512
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
513
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
514
+
515
+ def p_sample_loop(
516
+ self,
517
+ model,
518
+ shape,
519
+ noise=None,
520
+ clip_denoised=True,
521
+ denoised_fn=None,
522
+ cond_fn=None,
523
+ model_kwargs=None,
524
+ device=None,
525
+ progress=False,
526
+ ):
527
+ """
528
+ Generate samples from the model.
529
+
530
+ :param model: the model module.
531
+ :param shape: the shape of the samples, (N, C, H, W).
532
+ :param noise: if specified, the noise from the encoder to sample.
533
+ Should be of the same shape as `shape`.
534
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
535
+ :param denoised_fn: if not None, a function which applies to the
536
+ x_start prediction before it is used to sample.
537
+ :param cond_fn: if not None, this is a gradient function that acts
538
+ similarly to the model.
539
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
540
+ pass to the model. This can be used for conditioning.
541
+ :param device: if specified, the device to create the samples on.
542
+ If not specified, use a model parameter's device.
543
+ :param progress: if True, show a tqdm progress bar.
544
+ :return: a non-differentiable batch of samples.
545
+ """
546
+ final = None
547
+ for sample in self.p_sample_loop_progressive(
548
+ model,
549
+ shape,
550
+ noise=noise,
551
+ clip_denoised=clip_denoised,
552
+ denoised_fn=denoised_fn,
553
+ cond_fn=cond_fn,
554
+ model_kwargs=model_kwargs,
555
+ device=device,
556
+ progress=progress,
557
+ ):
558
+ final = sample
559
+ return final["sample"]
560
+
561
+ def p_sample_loop_progressive(
562
+ self,
563
+ model,
564
+ shape,
565
+ noise=None,
566
+ clip_denoised=True,
567
+ denoised_fn=None,
568
+ cond_fn=None,
569
+ model_kwargs=None,
570
+ device=None,
571
+ progress=False,
572
+ ):
573
+ """
574
+ Generate samples from the model and yield intermediate samples from
575
+ each timestep of diffusion.
576
+
577
+ Arguments are the same as p_sample_loop().
578
+ Returns a generator over dicts, where each dict is the return value of
579
+ p_sample().
580
+ """
581
+ if device is None:
582
+ device = next(model.parameters()).device
583
+ assert isinstance(shape, (tuple, list))
584
+ if noise is not None:
585
+ img = noise
586
+ else:
587
+ img = th.randn(*shape, device=device)
588
+ indices = list(range(self.num_timesteps))[::-1]
589
+
590
+ for i in tqdm(indices):
591
+ t = th.tensor([i] * shape[0], device=device)
592
+ with th.no_grad():
593
+ out = self.p_sample(
594
+ model,
595
+ img,
596
+ t,
597
+ clip_denoised=clip_denoised,
598
+ denoised_fn=denoised_fn,
599
+ cond_fn=cond_fn,
600
+ model_kwargs=model_kwargs,
601
+ )
602
+ yield out
603
+ img = out["sample"]
604
+
605
+ def ddim_sample(
606
+ self,
607
+ model,
608
+ x,
609
+ t,
610
+ clip_denoised=True,
611
+ denoised_fn=None,
612
+ cond_fn=None,
613
+ model_kwargs=None,
614
+ eta=0.0,
615
+ ):
616
+ """
617
+ Sample x_{t-1} from the model using DDIM.
618
+
619
+ Same usage as p_sample().
620
+ """
621
+ out = self.p_mean_variance(
622
+ model,
623
+ x,
624
+ t,
625
+ clip_denoised=clip_denoised,
626
+ denoised_fn=denoised_fn,
627
+ model_kwargs=model_kwargs,
628
+ )
629
+ if cond_fn is not None:
630
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
631
+
632
+ # Usually our model outputs epsilon, but we re-derive it
633
+ # in case we used x_start or x_prev prediction.
634
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
635
+
636
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
637
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
638
+ sigma = (
639
+ eta
640
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
641
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
642
+ )
643
+ # Equation 12.
644
+ noise = th.randn_like(x)
645
+ mean_pred = (
646
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
647
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
648
+ )
649
+ nonzero_mask = (
650
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
651
+ ) # no noise when t == 0
652
+ sample = mean_pred + nonzero_mask * sigma * noise
653
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
654
+
655
+ def ddim_reverse_sample(
656
+ self,
657
+ model,
658
+ x,
659
+ t,
660
+ clip_denoised=True,
661
+ denoised_fn=None,
662
+ model_kwargs=None,
663
+ eta=0.0,
664
+ ):
665
+ """
666
+ Sample x_{t+1} from the model using DDIM reverse ODE.
667
+ """
668
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
669
+ out = self.p_mean_variance(
670
+ model,
671
+ x,
672
+ t,
673
+ clip_denoised=clip_denoised,
674
+ denoised_fn=denoised_fn,
675
+ model_kwargs=model_kwargs,
676
+ )
677
+ # Usually our model outputs epsilon, but we re-derive it
678
+ # in case we used x_start or x_prev prediction.
679
+ eps = (
680
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
681
+ - out["pred_xstart"]
682
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
683
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
684
+
685
+ # Equation 12. reversed
686
+ mean_pred = (
687
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
688
+ + th.sqrt(1 - alpha_bar_next) * eps
689
+ )
690
+
691
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
692
+
693
+ def ddim_sample_loop(
694
+ self,
695
+ model,
696
+ shape,
697
+ noise=None,
698
+ clip_denoised=True,
699
+ denoised_fn=None,
700
+ cond_fn=None,
701
+ model_kwargs=None,
702
+ device=None,
703
+ progress=False,
704
+ eta=0.0,
705
+ ):
706
+ """
707
+ Generate samples from the model using DDIM.
708
+
709
+ Same usage as p_sample_loop().
710
+ """
711
+ final = None
712
+ for sample in self.ddim_sample_loop_progressive(
713
+ model,
714
+ shape,
715
+ noise=noise,
716
+ clip_denoised=clip_denoised,
717
+ denoised_fn=denoised_fn,
718
+ cond_fn=cond_fn,
719
+ model_kwargs=model_kwargs,
720
+ device=device,
721
+ progress=progress,
722
+ eta=eta,
723
+ ):
724
+ final = sample
725
+ return final["sample"]
726
+
727
+ def ddim_sample_loop_progressive(
728
+ self,
729
+ model,
730
+ shape,
731
+ noise=None,
732
+ clip_denoised=True,
733
+ denoised_fn=None,
734
+ cond_fn=None,
735
+ model_kwargs=None,
736
+ device=None,
737
+ progress=False,
738
+ eta=0.0,
739
+ ):
740
+ """
741
+ Use DDIM to sample from the model and yield intermediate samples from
742
+ each timestep of DDIM.
743
+
744
+ Same usage as p_sample_loop_progressive().
745
+ """
746
+ if device is None:
747
+ device = next(model.parameters()).device
748
+ assert isinstance(shape, (tuple, list))
749
+ if noise is not None:
750
+ img = noise
751
+ else:
752
+ img = th.randn(*shape, device=device)
753
+ indices = list(range(self.num_timesteps))[::-1]
754
+
755
+ if progress:
756
+ # Lazy import so that we don't depend on tqdm.
757
+ from tqdm.auto import tqdm
758
+
759
+ indices = tqdm(indices)
760
+
761
+ for i in indices:
762
+ t = th.tensor([i] * shape[0], device=device)
763
+ with th.no_grad():
764
+ out = self.ddim_sample(
765
+ model,
766
+ img,
767
+ t,
768
+ clip_denoised=clip_denoised,
769
+ denoised_fn=denoised_fn,
770
+ cond_fn=cond_fn,
771
+ model_kwargs=model_kwargs,
772
+ eta=eta,
773
+ )
774
+ yield out
775
+ img = out["sample"]
776
+
777
+ def _vb_terms_bpd(
778
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
779
+ ):
780
+ """
781
+ Get a term for the variational lower-bound.
782
+
783
+ The resulting units are bits (rather than nats, as one might expect).
784
+ This allows for comparison to other papers.
785
+
786
+ :return: a dict with the following keys:
787
+ - 'output': a shape [N] tensor of NLLs or KLs.
788
+ - 'pred_xstart': the x_0 predictions.
789
+ """
790
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
791
+ x_start=x_start, x_t=x_t, t=t
792
+ )
793
+ out = self.p_mean_variance(
794
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
795
+ )
796
+ kl = normal_kl(
797
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
798
+ )
799
+ kl = mean_flat(kl) / np.log(2.0)
800
+
801
+ decoder_nll = -discretized_gaussian_log_likelihood(
802
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
803
+ )
804
+ assert decoder_nll.shape == x_start.shape
805
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
806
+
807
+ # At the first timestep return the decoder NLL,
808
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
809
+ output = th.where((t == 0), decoder_nll, kl)
810
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
811
+
812
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
813
+ """
814
+ Compute training losses for a single timestep.
815
+
816
+ :param model: the model to evaluate loss on.
817
+ :param x_start: the [N x C x ...] tensor of inputs.
818
+ :param t: a batch of timestep indices.
819
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
820
+ pass to the model. This can be used for conditioning.
821
+ :param noise: if specified, the specific Gaussian noise to try to remove.
822
+ :return: a dict with the key "loss" containing a tensor of shape [N].
823
+ Some mean or variance settings may also have other keys.
824
+ """
825
+ if model_kwargs is None:
826
+ model_kwargs = {}
827
+ if noise is None:
828
+ noise = th.randn_like(x_start)
829
+ x_t = self.q_sample(x_start, t, noise=noise)
830
+
831
+ terms = {}
832
+
833
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
834
+ # TODO: support multiple model outputs for this mode.
835
+ terms["loss"] = self._vb_terms_bpd(
836
+ model=model,
837
+ x_start=x_start,
838
+ x_t=x_t,
839
+ t=t,
840
+ clip_denoised=False,
841
+ model_kwargs=model_kwargs,
842
+ )["output"]
843
+ if self.loss_type == LossType.RESCALED_KL:
844
+ terms["loss"] *= self.num_timesteps
845
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
846
+ model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
847
+ if isinstance(model_outputs, tuple):
848
+ model_output = model_outputs[0]
849
+ terms['extra_outputs'] = model_outputs[1:]
850
+ else:
851
+ model_output = model_outputs
852
+
853
+ if self.model_var_type in [
854
+ ModelVarType.LEARNED,
855
+ ModelVarType.LEARNED_RANGE,
856
+ ]:
857
+ B, C = x_t.shape[:2]
858
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
859
+ model_output, model_var_values = th.split(model_output, C, dim=1)
860
+ # Learn the variance using the variational bound, but don't let
861
+ # it affect our mean prediction.
862
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
863
+ terms["vb"] = self._vb_terms_bpd(
864
+ model=lambda *args, r=frozen_out: r,
865
+ x_start=x_start,
866
+ x_t=x_t,
867
+ t=t,
868
+ clip_denoised=False,
869
+ )["output"]
870
+ if self.loss_type == LossType.RESCALED_MSE:
871
+ # Divide by 1000 for equivalence with initial implementation.
872
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
873
+ terms["vb"] *= self.num_timesteps / 1000.0
874
+
875
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
876
+ target = self.q_posterior_mean_variance(
877
+ x_start=x_start, x_t=x_t, t=t
878
+ )[0]
879
+ x_start_pred = torch.zeros(x_start) # Not supported.
880
+ elif self.model_mean_type == ModelMeanType.START_X:
881
+ target = x_start
882
+ x_start_pred = model_output
883
+ elif self.model_mean_type == ModelMeanType.EPSILON:
884
+ target = noise
885
+ x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
886
+ else:
887
+ raise NotImplementedError(self.model_mean_type)
888
+ assert model_output.shape == target.shape == x_start.shape
889
+ terms["mse"] = mean_flat((target - model_output) ** 2)
890
+ terms["x_start_predicted"] = x_start_pred
891
+ if "vb" in terms:
892
+ terms["loss"] = terms["mse"] + terms["vb"]
893
+ else:
894
+ terms["loss"] = terms["mse"]
895
+ else:
896
+ raise NotImplementedError(self.loss_type)
897
+
898
+ return terms
899
+
900
+ def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
901
+ """
902
+ Compute training losses for a single timestep.
903
+
904
+ :param model: the model to evaluate loss on.
905
+ :param x_start: the [N x C x ...] tensor of inputs.
906
+ :param t: a batch of timestep indices.
907
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
908
+ pass to the model. This can be used for conditioning.
909
+ :param noise: if specified, the specific Gaussian noise to try to remove.
910
+ :return: a dict with the key "loss" containing a tensor of shape [N].
911
+ Some mean or variance settings may also have other keys.
912
+ """
913
+ if model_kwargs is None:
914
+ model_kwargs = {}
915
+ if noise is None:
916
+ noise = th.randn_like(x_start)
917
+ x_t = self.q_sample(x_start, t, noise=noise)
918
+ terms = {}
919
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
920
+ assert False # not currently supported for this type of diffusion.
921
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
922
+ model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
923
+ terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
924
+ model_output = terms[gd_out_key]
925
+ if self.model_var_type in [
926
+ ModelVarType.LEARNED,
927
+ ModelVarType.LEARNED_RANGE,
928
+ ]:
929
+ B, C = x_t.shape[:2]
930
+ assert model_output.shape == (B, C, 2, *x_t.shape[2:])
931
+ model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1]
932
+ # Learn the variance using the variational bound, but don't let
933
+ # it affect our mean prediction.
934
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
935
+ terms["vb"] = self._vb_terms_bpd(
936
+ model=lambda *args, r=frozen_out: r,
937
+ x_start=x_start,
938
+ x_t=x_t,
939
+ t=t,
940
+ clip_denoised=False,
941
+ )["output"]
942
+ if self.loss_type == LossType.RESCALED_MSE:
943
+ # Divide by 1000 for equivalence with initial implementation.
944
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
945
+ terms["vb"] *= self.num_timesteps / 1000.0
946
+
947
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
948
+ target = self.q_posterior_mean_variance(
949
+ x_start=x_start, x_t=x_t, t=t
950
+ )[0]
951
+ x_start_pred = torch.zeros(x_start) # Not supported.
952
+ elif self.model_mean_type == ModelMeanType.START_X:
953
+ target = x_start
954
+ x_start_pred = model_output
955
+ elif self.model_mean_type == ModelMeanType.EPSILON:
956
+ target = noise
957
+ x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
958
+ else:
959
+ raise NotImplementedError(self.model_mean_type)
960
+ assert model_output.shape == target.shape == x_start.shape
961
+ terms["mse"] = mean_flat((target - model_output) ** 2)
962
+ terms["x_start_predicted"] = x_start_pred
963
+ if "vb" in terms:
964
+ terms["loss"] = terms["mse"] + terms["vb"]
965
+ else:
966
+ terms["loss"] = terms["mse"]
967
+ else:
968
+ raise NotImplementedError(self.loss_type)
969
+
970
+ return terms
971
+
972
+ def _prior_bpd(self, x_start):
973
+ """
974
+ Get the prior KL term for the variational lower-bound, measured in
975
+ bits-per-dim.
976
+
977
+ This term can't be optimized, as it only depends on the encoder.
978
+
979
+ :param x_start: the [N x C x ...] tensor of inputs.
980
+ :return: a batch of [N] KL values (in bits), one per batch element.
981
+ """
982
+ batch_size = x_start.shape[0]
983
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
984
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
985
+ kl_prior = normal_kl(
986
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
987
+ )
988
+ return mean_flat(kl_prior) / np.log(2.0)
989
+
990
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
991
+ """
992
+ Compute the entire variational lower-bound, measured in bits-per-dim,
993
+ as well as other related quantities.
994
+
995
+ :param model: the model to evaluate loss on.
996
+ :param x_start: the [N x C x ...] tensor of inputs.
997
+ :param clip_denoised: if True, clip denoised samples.
998
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
999
+ pass to the model. This can be used for conditioning.
1000
+
1001
+ :return: a dict containing the following keys:
1002
+ - total_bpd: the total variational lower-bound, per batch element.
1003
+ - prior_bpd: the prior term in the lower-bound.
1004
+ - vb: an [N x T] tensor of terms in the lower-bound.
1005
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
1006
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
1007
+ """
1008
+ device = x_start.device
1009
+ batch_size = x_start.shape[0]
1010
+
1011
+ vb = []
1012
+ xstart_mse = []
1013
+ mse = []
1014
+ for t in list(range(self.num_timesteps))[::-1]:
1015
+ t_batch = th.tensor([t] * batch_size, device=device)
1016
+ noise = th.randn_like(x_start)
1017
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
1018
+ # Calculate VLB term at the current timestep
1019
+ with th.no_grad():
1020
+ out = self._vb_terms_bpd(
1021
+ model,
1022
+ x_start=x_start,
1023
+ x_t=x_t,
1024
+ t=t_batch,
1025
+ clip_denoised=clip_denoised,
1026
+ model_kwargs=model_kwargs,
1027
+ )
1028
+ vb.append(out["output"])
1029
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
1030
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
1031
+ mse.append(mean_flat((eps - noise) ** 2))
1032
+
1033
+ vb = th.stack(vb, dim=1)
1034
+ xstart_mse = th.stack(xstart_mse, dim=1)
1035
+ mse = th.stack(mse, dim=1)
1036
+
1037
+ prior_bpd = self._prior_bpd(x_start)
1038
+ total_bpd = vb.sum(dim=1) + prior_bpd
1039
+ return {
1040
+ "total_bpd": total_bpd,
1041
+ "prior_bpd": prior_bpd,
1042
+ "vb": vb,
1043
+ "xstart_mse": xstart_mse,
1044
+ "mse": mse,
1045
+ }
1046
+
1047
+
1048
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
1049
+ """
1050
+ Get a pre-defined beta schedule for the given name.
1051
+
1052
+ The beta schedule library consists of beta schedules which remain similar
1053
+ in the limit of num_diffusion_timesteps.
1054
+ Beta schedules may be added, but should not be removed or changed once
1055
+ they are committed to maintain backwards compatibility.
1056
+ """
1057
+ if schedule_name == "linear":
1058
+ # Linear schedule from Ho et al, extended to work for any number of
1059
+ # diffusion steps.
1060
+ scale = 1000 / num_diffusion_timesteps
1061
+ beta_start = scale * 0.0001
1062
+ beta_end = scale * 0.02
1063
+ return np.linspace(
1064
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
1065
+ )
1066
+ elif schedule_name == "cosine":
1067
+ return betas_for_alpha_bar(
1068
+ num_diffusion_timesteps,
1069
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
1070
+ )
1071
+ else:
1072
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
1073
+
1074
+
1075
+ class SpacedDiffusion(GaussianDiffusion):
1076
+ """
1077
+ A diffusion process which can skip steps in a base diffusion process.
1078
+
1079
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
1080
+ original diffusion process to retain.
1081
+ :param kwargs: the kwargs to create the base diffusion process.
1082
+ """
1083
+
1084
+ def __init__(self, use_timesteps, **kwargs):
1085
+ self.use_timesteps = set(use_timesteps)
1086
+ self.timestep_map = []
1087
+ self.original_num_steps = len(kwargs["betas"])
1088
+
1089
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
1090
+ last_alpha_cumprod = 1.0
1091
+ new_betas = []
1092
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
1093
+ if i in self.use_timesteps:
1094
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
1095
+ last_alpha_cumprod = alpha_cumprod
1096
+ self.timestep_map.append(i)
1097
+ kwargs["betas"] = np.array(new_betas)
1098
+ super().__init__(**kwargs)
1099
+
1100
+ def p_mean_variance(
1101
+ self, model, *args, **kwargs
1102
+ ): # pylint: disable=signature-differs
1103
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
1104
+
1105
+ def training_losses(
1106
+ self, model, *args, **kwargs
1107
+ ): # pylint: disable=signature-differs
1108
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
1109
+
1110
+ def autoregressive_training_losses(
1111
+ self, model, *args, **kwargs
1112
+ ): # pylint: disable=signature-differs
1113
+ return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
1114
+
1115
+ def condition_mean(self, cond_fn, *args, **kwargs):
1116
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
1117
+
1118
+ def condition_score(self, cond_fn, *args, **kwargs):
1119
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
1120
+
1121
+ def _wrap_model(self, model, autoregressive=False):
1122
+ if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel):
1123
+ return model
1124
+ mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
1125
+ return mod(
1126
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
1127
+ )
1128
+
1129
+ def _scale_timesteps(self, t):
1130
+ # Scaling is done by the wrapped model.
1131
+ return t
1132
+
1133
+
1134
+ def space_timesteps(num_timesteps, section_counts):
1135
+ """
1136
+ Create a list of timesteps to use from an original diffusion process,
1137
+ given the number of timesteps we want to take from equally-sized portions
1138
+ of the original process.
1139
+
1140
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
1141
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
1142
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
1143
+
1144
+ If the stride is a string starting with "ddim", then the fixed striding
1145
+ from the DDIM paper is used, and only one section is allowed.
1146
+
1147
+ :param num_timesteps: the number of diffusion steps in the original
1148
+ process to divide up.
1149
+ :param section_counts: either a list of numbers, or a string containing
1150
+ comma-separated numbers, indicating the step count
1151
+ per section. As a special case, use "ddimN" where N
1152
+ is a number of steps to use the striding from the
1153
+ DDIM paper.
1154
+ :return: a set of diffusion steps from the original process to use.
1155
+ """
1156
+ if isinstance(section_counts, str):
1157
+ if section_counts.startswith("ddim"):
1158
+ desired_count = int(section_counts[len("ddim") :])
1159
+ for i in range(1, num_timesteps):
1160
+ if len(range(0, num_timesteps, i)) == desired_count:
1161
+ return set(range(0, num_timesteps, i))
1162
+ raise ValueError(
1163
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
1164
+ )
1165
+ section_counts = [int(x) for x in section_counts.split(",")]
1166
+ size_per = num_timesteps // len(section_counts)
1167
+ extra = num_timesteps % len(section_counts)
1168
+ start_idx = 0
1169
+ all_steps = []
1170
+ for i, section_count in enumerate(section_counts):
1171
+ size = size_per + (1 if i < extra else 0)
1172
+ if size < section_count:
1173
+ raise ValueError(
1174
+ f"cannot divide section of {size} steps into {section_count}"
1175
+ )
1176
+ if section_count <= 1:
1177
+ frac_stride = 1
1178
+ else:
1179
+ frac_stride = (size - 1) / (section_count - 1)
1180
+ cur_idx = 0.0
1181
+ taken_steps = []
1182
+ for _ in range(section_count):
1183
+ taken_steps.append(start_idx + round(cur_idx))
1184
+ cur_idx += frac_stride
1185
+ all_steps += taken_steps
1186
+ start_idx += size
1187
+ return set(all_steps)
1188
+
1189
+
1190
+ class _WrappedModel:
1191
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
1192
+ self.model = model
1193
+ self.timestep_map = timestep_map
1194
+ self.rescale_timesteps = rescale_timesteps
1195
+ self.original_num_steps = original_num_steps
1196
+
1197
+ def __call__(self, x, ts, **kwargs):
1198
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
1199
+ new_ts = map_tensor[ts]
1200
+ if self.rescale_timesteps:
1201
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
1202
+ return self.model(x, new_ts, **kwargs)
1203
+
1204
+
1205
+ class _WrappedAutoregressiveModel:
1206
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
1207
+ self.model = model
1208
+ self.timestep_map = timestep_map
1209
+ self.rescale_timesteps = rescale_timesteps
1210
+ self.original_num_steps = original_num_steps
1211
+
1212
+ def __call__(self, x, x0, ts, **kwargs):
1213
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
1214
+ new_ts = map_tensor[ts]
1215
+ if self.rescale_timesteps:
1216
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
1217
+ return self.model(x, x0, new_ts, **kwargs)
1218
+
1219
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1220
+ """
1221
+ Extract values from a 1-D numpy array for a batch of indices.
1222
+
1223
+ :param arr: the 1-D numpy array.
1224
+ :param timesteps: a tensor of indices into the array to extract.
1225
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1226
+ dimension equal to the length of timesteps.
1227
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1228
+ """
1229
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1230
+ while len(res.shape) < len(broadcast_shape):
1231
+ res = res[..., None]
1232
+ return res.expand(broadcast_shape)
utils/tokenizer.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import inflect
4
+ import torch
5
+ from tokenizers import Tokenizer
6
+
7
+
8
+ # Regular expression matching whitespace:
9
+ from unidecode import unidecode
10
+
11
+ _whitespace_re = re.compile(r'\s+')
12
+
13
+
14
+ # List of (regular expression, replacement) pairs for abbreviations:
15
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
16
+ ('mrs', 'misess'),
17
+ ('mr', 'mister'),
18
+ ('dr', 'doctor'),
19
+ ('st', 'saint'),
20
+ ('co', 'company'),
21
+ ('jr', 'junior'),
22
+ ('maj', 'major'),
23
+ ('gen', 'general'),
24
+ ('drs', 'doctors'),
25
+ ('rev', 'reverend'),
26
+ ('lt', 'lieutenant'),
27
+ ('hon', 'honorable'),
28
+ ('sgt', 'sergeant'),
29
+ ('capt', 'captain'),
30
+ ('esq', 'esquire'),
31
+ ('ltd', 'limited'),
32
+ ('col', 'colonel'),
33
+ ('ft', 'fort'),
34
+ ]]
35
+
36
+
37
+ def expand_abbreviations(text):
38
+ for regex, replacement in _abbreviations:
39
+ text = re.sub(regex, replacement, text)
40
+ return text
41
+
42
+
43
+ _inflect = inflect.engine()
44
+ _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
45
+ _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
46
+ _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
47
+ _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
48
+ _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
49
+ _number_re = re.compile(r'[0-9]+')
50
+
51
+
52
+ def _remove_commas(m):
53
+ return m.group(1).replace(',', '')
54
+
55
+
56
+ def _expand_decimal_point(m):
57
+ return m.group(1).replace('.', ' point ')
58
+
59
+
60
+ def _expand_dollars(m):
61
+ match = m.group(1)
62
+ parts = match.split('.')
63
+ if len(parts) > 2:
64
+ return match + ' dollars' # Unexpected format
65
+ dollars = int(parts[0]) if parts[0] else 0
66
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
67
+ if dollars and cents:
68
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
69
+ cent_unit = 'cent' if cents == 1 else 'cents'
70
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
71
+ elif dollars:
72
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
73
+ return '%s %s' % (dollars, dollar_unit)
74
+ elif cents:
75
+ cent_unit = 'cent' if cents == 1 else 'cents'
76
+ return '%s %s' % (cents, cent_unit)
77
+ else:
78
+ return 'zero dollars'
79
+
80
+
81
+ def _expand_ordinal(m):
82
+ return _inflect.number_to_words(m.group(0))
83
+
84
+
85
+ def _expand_number(m):
86
+ num = int(m.group(0))
87
+ if num > 1000 and num < 3000:
88
+ if num == 2000:
89
+ return 'two thousand'
90
+ elif num > 2000 and num < 2010:
91
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
92
+ elif num % 100 == 0:
93
+ return _inflect.number_to_words(num // 100) + ' hundred'
94
+ else:
95
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
96
+ else:
97
+ return _inflect.number_to_words(num, andword='')
98
+
99
+
100
+ def normalize_numbers(text):
101
+ text = re.sub(_comma_number_re, _remove_commas, text)
102
+ text = re.sub(_pounds_re, r'\1 pounds', text)
103
+ text = re.sub(_dollars_re, _expand_dollars, text)
104
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
105
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
106
+ text = re.sub(_number_re, _expand_number, text)
107
+ return text
108
+
109
+
110
+ def expand_numbers(text):
111
+ return normalize_numbers(text)
112
+
113
+
114
+ def lowercase(text):
115
+ return text.lower()
116
+
117
+
118
+ def collapse_whitespace(text):
119
+ return re.sub(_whitespace_re, ' ', text)
120
+
121
+
122
+ def convert_to_ascii(text):
123
+ return unidecode(text)
124
+
125
+
126
+ def basic_cleaners(text):
127
+ '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
128
+ text = lowercase(text)
129
+ text = collapse_whitespace(text)
130
+ return text
131
+
132
+
133
+ def transliteration_cleaners(text):
134
+ '''Pipeline for non-English text that transliterates to ASCII.'''
135
+ text = convert_to_ascii(text)
136
+ text = lowercase(text)
137
+ text = collapse_whitespace(text)
138
+ return text
139
+
140
+
141
+ def english_cleaners(text):
142
+ '''Pipeline for English text, including number and abbreviation expansion.'''
143
+ text = convert_to_ascii(text)
144
+ text = lowercase(text)
145
+ text = expand_numbers(text)
146
+ text = expand_abbreviations(text)
147
+ text = collapse_whitespace(text)
148
+ text = text.replace('"', '')
149
+ return text
150
+
151
+
152
+ class VoiceBpeTokenizer:
153
+ def __init__(self, vocab_file='data/tokenizer.json'):
154
+ if vocab_file is not None:
155
+ self.tokenizer = Tokenizer.from_file(vocab_file)
156
+
157
+ def preprocess_text(self, txt):
158
+ txt = english_cleaners(txt)
159
+ return txt
160
+
161
+ def encode(self, txt):
162
+ txt = self.preprocess_text(txt)
163
+ txt = txt.replace(' ', '[SPACE]')
164
+ return self.tokenizer.encode(txt).ids
165
+
166
+ def decode(self, seq):
167
+ if isinstance(seq, torch.Tensor):
168
+ seq = seq.cpu().numpy()
169
+ txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
170
+ txt = txt.replace('[SPACE]', ' ')
171
+ txt = txt.replace('[STOP]', '')
172
+ txt = txt.replace('[UNK]', '')
173
+ return txt