Spaces: Running on Zero

Ngaima Sandiman committed
Commit • 685ecb2
1 Parent(s): 8fd0e3f

Initial commit.
- app.py +36 -0
- config.json +58 -0
- config.yaml +14 -0
- media/voicecraft/generated/empty.txt +0 -0
- media/voicecraft/voices/84_121550_000074_000000.wav +0 -0
- media/voicecraft/voices/mfa_alignments/84_121550_000074_000000.csv +109 -0
- models/pretrained/imagecraft/empty.txt +0 -0
- models/pretrained/voicecraft/empty.txt +0 -0
- packages.txt +14 -0
- requirements.txt +13 -0
- setup.py +16 -0
- src/__init__.py +0 -0
- src/model/modules/__init__.py +0 -0
- src/model/modules/activation.py +638 -0
- src/model/modules/codebooks_patterns.py +538 -0
- src/model/modules/embedding.py +98 -0
- src/model/modules/gemma.py +423 -0
- src/model/modules/imagecraft.py +490 -0
- src/model/modules/imagecraftconfig.py +47 -0
- src/model/modules/imagecraftprocessor.py +96 -0
- src/model/modules/kv_cache.py +38 -0
- src/model/modules/sampling.py +65 -0
- src/model/modules/scaling.py +1391 -0
- src/model/modules/siglip.py +258 -0
- src/model/modules/tokenizer.py +149 -0
- src/model/modules/transformer.py +690 -0
- src/model/modules/voicecraft.py +1999 -0
- src/model/modules/voicecraftconfig.py +37 -0
- src/utils/__init__.py +0 -0
- src/utils/image_utils.py +100 -0
- src/utils/model_utils.py +73 -0
- src/utils/tools.py +22 -0
- src/utils/util.py +305 -0
app.py
ADDED
@@ -0,0 +1,36 @@
import gradio as gr
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["USER"] = "imagecraft"


from src.model.modules.imagecraft import ImageCraft

model = ImageCraft.from_pretrained("nsandiman/imagecraft-ft-co-224")


def imagecraft_interface(image_path, reference_text):
    """Process image inputs and generate audio response."""
    # reference_text is collected for evaluation only; it is not used by generate().
    transcript, audio_buffer = model.generate(image_path, output_type="buffer")

    return audio_buffer, transcript


# Define Gradio interface
gradio_interface = gr.Interface(
    fn=imagecraft_interface,
    inputs=[
        gr.Image(type="filepath", label="Upload an image"),
        gr.Textbox(label="Reference Text (for evaluation)"),
    ],
    outputs=[gr.Audio(label="Speech"), gr.Textbox(label="Transcript")],
    title="ImageCraft",
    description="Upload an image and get the speech responses.",
    allow_flagging="never",
)

# Launch the Gradio app
gradio_interface.launch(share=True)
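The `generate` call wired into the Gradio app above can also be exercised directly for local testing. The following is a minimal sketch, not part of the commit: it assumes the same checkpoint and `output_type="buffer"` behavior used in app.py, and the image path and the buffer's `getbuffer()` accessor are illustrative assumptions.

```python
# Illustrative sketch: run ImageCraft generation without the Gradio UI.
from src.model.modules.imagecraft import ImageCraft

model = ImageCraft.from_pretrained("nsandiman/imagecraft-ft-co-224")

# "example.jpg" is a placeholder path; generate() is called exactly as in app.py.
transcript, audio_buffer = model.generate("example.jpg", output_type="buffer")
print(transcript)

# Assumes audio_buffer behaves like an in-memory bytes buffer (e.g. io.BytesIO).
with open("speech.wav", "wb") as f:
    f.write(audio_buffer.getbuffer())
```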
config.json
ADDED
@@ -0,0 +1,58 @@
{
  "_name_or_path": "nsandiman/imagecraft-ft-co-224",
  "_vocab_size": 257216,
  "architectures": [
    "PaliGemmaForConditionalGeneration"
  ],
  "bos_token_id": 2,
  "eos_token_id": 1,
  "hidden_size": 2048,
  "ignore_index": -100,
  "image_token_index": 257152,
  "model_type": "paligemma",
  "pad_token_id": 0,
  "projection_dim": 2048,
  "text_config": {
    "hidden_size": 2048,
    "intermediate_size": 16384,
    "model_type": "gemma",
    "num_attention_heads": 8,
    "num_hidden_layers": 18,
    "num_image_tokens": 256,
    "num_key_value_heads": 1,
    "torch_dtype": "float32",
    "vocab_size": 257216
  },
  "torch_dtype": "float32",
  "transformers_version": "4.41.0.dev0",
  "vision_config": {
    "hidden_size": 1152,
    "intermediate_size": 4304,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 27,
    "num_image_tokens": 256,
    "patch_size": 14,
    "projection_dim": 2048,
    "projector_hidden_act": "gelu_fast",
    "vision_use_head": false
  },
  "vocab_size": 257216,
  "voicecraft_config": {
    "model_name": "330M_TTSEnhanced.pth",
    "encoded": "encodec_4cb2048_giga.th",
    "voice_audio_path": "84_121550_000074_000000.wav",
    "voice_audio_transcript": "But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks",
    "top_k": 0,
    "top_p": 0.9,
    "temperature": 1,
    "kvcache": 1,
    "codec_sr": 50,
    "codec_audio_sr": 16000,
    "silence_tokens": [1388, 1898, 131],
    "stop_repetition": 3,
    "sample_batch_size": 2,
    "seed": 1,
    "cut_off_sec": 7.87
  }
}
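The nested `voicecraft_config` block above carries the TTS sampling parameters. A minimal sketch of reading them with the standard library, assuming the repository-root copy of config.json shown above:

```python
import json

# Load the committed config.json and pull the VoiceCraft sampling parameters.
with open("config.json") as f:
    config = json.load(f)

vc = config["voicecraft_config"]
print(vc["model_name"], vc["top_p"], vc["temperature"], vc["codec_audio_sr"])
# e.g. 330M_TTSEnhanced.pth 0.9 1 16000
```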
config.yaml
ADDED
@@ -0,0 +1,14 @@
model_flickr: nsandiman/imagecraft-pt-fk-224
model_coco: models/imagecraft-pt-co-224
model_tiny: models/imagecraft-pt-ty-224
checkpoint_dir: models/checkpoint
pretrained_dir: models/pretrained
model_dir: models
data:
  raw_dir: data/raw
  interim_dir: data/interim
  processed_dir: data/processed
  log_dir: data/logs
  wandb_dir: data/wandb
  tensorboard_log_dir: data/tensorboard/logs
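A short sketch of loading config.yaml with PyYAML; it assumes the nested `data` block reconstructed above and that `pyyaml` is available in the environment (it is not pinned in requirements.txt):

```python
import yaml

# Read the path configuration committed above.
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model_coco"])        # models/imagecraft-pt-co-224
print(cfg["data"]["raw_dir"])   # data/raw
```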
media/voicecraft/generated/empty.txt
ADDED
File without changes
media/voicecraft/voices/84_121550_000074_000000.wav
ADDED
Binary file (508 kB).
media/voicecraft/voices/mfa_alignments/84_121550_000074_000000.csv
ADDED
@@ -0,0 +1,109 @@
Begin,End,Label,Type,Speaker
0.03,0.18,but,words,temp
0.18,0.32,when,words,temp
0.32,0.48,i,words,temp
0.48,0.64,had,words,temp
0.64,1.19,approached,words,temp
1.22,1.58,so,words,temp
1.58,1.91,near,words,temp
1.91,2.07,to,words,temp
2.07,2.42,them,words,temp
2.53,2.61,the,words,temp
2.61,3.01,common,words,temp
3.05,3.62,object,words,temp
3.68,3.93,which,words,temp
3.93,4.02,the,words,temp
4.02,4.34,sense,words,temp
4.34,4.97,deceives,words,temp
5.04,5.54,lost,words,temp
5.54,6.0,not,words,temp
6.0,6.14,by,words,temp
6.14,6.67,distance,words,temp
6.79,7.05,any,words,temp
7.05,7.18,of,words,temp
7.18,7.34,its,words,temp
7.34,7.87,marks,words,temp
0.03,0.06,B,phones,temp
0.06,0.09,AH1,phones,temp
0.09,0.18,T,phones,temp
0.18,0.23,W,phones,temp
0.23,0.27,EH1,phones,temp
0.27,0.32,N,phones,temp
0.32,0.48,AY1,phones,temp
0.48,0.49,HH,phones,temp
0.49,0.6,AE1,phones,temp
0.6,0.64,D,phones,temp
0.64,0.7,AH0,phones,temp
0.7,0.83,P,phones,temp
0.83,0.88,R,phones,temp
0.88,0.99,OW1,phones,temp
0.99,1.12,CH,phones,temp
1.12,1.19,T,phones,temp
1.22,1.4,S,phones,temp
1.4,1.58,OW1,phones,temp
1.58,1.7,N,phones,temp
1.7,1.84,IH1,phones,temp
1.84,1.91,R,phones,temp
1.91,2.01,T,phones,temp
2.01,2.07,AH0,phones,temp
2.07,2.13,DH,phones,temp
2.13,2.3,EH1,phones,temp
2.3,2.42,M,phones,temp
2.53,2.55,DH,phones,temp
2.55,2.61,AH0,phones,temp
2.61,2.73,K,phones,temp
2.73,2.85,AA1,phones,temp
2.85,2.9,M,phones,temp
2.9,2.95,AH0,phones,temp
2.95,3.01,N,phones,temp
3.05,3.22,AA1,phones,temp
3.22,3.27,B,phones,temp
3.27,3.34,JH,phones,temp
3.34,3.48,EH0,phones,temp
3.48,3.54,K,phones,temp
3.54,3.62,T,phones,temp
3.68,3.69,HH,phones,temp
3.69,3.76,W,phones,temp
3.76,3.8,IH1,phones,temp
3.8,3.93,CH,phones,temp
3.93,3.95,DH,phones,temp
3.95,4.02,AH0,phones,temp
4.02,4.12,S,phones,temp
4.12,4.21,EH1,phones,temp
4.21,4.27,N,phones,temp
4.27,4.34,S,phones,temp
4.34,4.42,D,phones,temp
4.42,4.45,IH0,phones,temp
4.45,4.59,S,phones,temp
4.59,4.79,IY1,phones,temp
4.79,4.87,V,phones,temp
4.87,4.97,Z,phones,temp
5.04,5.12,L,phones,temp
5.12,5.33,AO1,phones,temp
5.33,5.42,S,phones,temp
5.42,5.54,T,phones,temp
5.54,5.7,N,phones,temp
5.7,5.89,AA1,phones,temp
5.89,6.0,T,phones,temp
6.0,6.05,B,phones,temp
6.05,6.14,AY1,phones,temp
6.14,6.24,D,phones,temp
6.24,6.3,IH1,phones,temp
6.3,6.38,S,phones,temp
6.38,6.45,T,phones,temp
6.45,6.51,AH0,phones,temp
6.51,6.57,N,phones,temp
6.57,6.67,S,phones,temp
6.79,6.89,EH1,phones,temp
6.89,6.95,N,phones,temp
6.95,7.05,IY0,phones,temp
7.05,7.13,AH0,phones,temp
7.13,7.18,V,phones,temp
7.18,7.22,IH0,phones,temp
7.22,7.29,T,phones,temp
7.29,7.34,S,phones,temp
7.34,7.39,M,phones,temp
7.39,7.5,AA1,phones,temp
7.5,7.58,R,phones,temp
7.58,7.7,K,phones,temp
7.7,7.87,S,phones,temp
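The MFA alignment file above interleaves word-level and phone-level rows. A minimal sketch of filtering the word rows and recovering the end time of the last word, which matches the `cut_off_sec: 7.87` value in config.json:

```python
import csv

# Read the MFA alignment CSV committed above and keep only the word-level rows.
path = "media/voicecraft/voices/mfa_alignments/84_121550_000074_000000.csv"
with open(path, newline="") as f:
    rows = [r for r in csv.DictReader(f) if r["Type"] == "words"]

words = [r["Label"] for r in rows]
last_word_end = float(rows[-1]["End"])
print(" ".join(words))   # but when i had approached ... its marks
print(last_word_end)     # 7.87, consistent with cut_off_sec in config.json
```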
models/pretrained/imagecraft/empty.txt
ADDED
File without changes
models/pretrained/voicecraft/empty.txt
ADDED
File without changes
packages.txt
ADDED
@@ -0,0 +1,14 @@
espeak-ng
espeak
espeak-data
libespeak1
libespeak-dev
festival*
build-essential
flac
libasound2-dev
libsndfile1-dev
vorbis-tools
libxml2-dev
libxslt-dev
zlib1g-dev
requirements.txt
ADDED
@@ -0,0 +1,13 @@
-e git+https://github.com/facebookresearch/audiocraft.git@f83babff6b5e97f75562127c4cc8122229c8f099#egg=audiocraft
git+https://github.com/huggingface/transformers.git
phonemizer
spaces
huggingface-hub
num2words
numpy
pillow
safetensors
tokenizers
torchaudio
torchvision
aeneas
setup.py
ADDED
@@ -0,0 +1,16 @@
import platform
from setuptools import setup, find_packages

if platform.python_version_tuple()[:2] != ("3", "11"):
    raise RuntimeError("Python version 3.11 required")

setup(
    name="distilvit",
    version="0.1",
    packages=find_packages(),
    entry_points={
        "console_scripts": [
            "train=src.model.train:main",  # "main" is a function in "train_model.py"
        ],
    },
)
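After an editable install (`pip install -e .`), setuptools exposes the `train` console script declared above. The following sketch only demonstrates how that entry point can be resolved programmatically on Python 3.11; it assumes the package has already been installed and does not touch the training code itself:

```python
from importlib.metadata import entry_points

# Find the "train" console script declared in setup.py and resolve its target.
scripts = entry_points(group="console_scripts")
train = next(ep for ep in scripts if ep.name == "train")
print(train.value)   # src.model.train:main
main = train.load()  # imports src.model.train and returns its main() callable
```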
src/__init__.py
ADDED
File without changes
src/model/modules/__init__.py
ADDED
File without changes
src/model/modules/activation.py
ADDED
@@ -0,0 +1,638 @@
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py, modified by Puyuan Peng, 2024
from typing import Optional, Tuple

import warnings  # used by _canonical_mask below
import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
import logging
from typing import Callable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from torch.types import _dtype as DType
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int

def _canonical_mask(
    mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[DType],
    other_name: str,
    target_type: DType,
    check_other: bool = True,
) -> Optional[Tensor]:

    if mask is not None:
        _mask_dtype = mask.dtype
        _mask_is_float = torch.is_floating_point(mask)
        if _mask_dtype != torch.bool and not _mask_is_float:
            raise AssertionError(
                f"only bool and floating types of {mask_name} are supported")
        if check_other and other_type is not None:
            if _mask_dtype != other_type:
                warnings.warn(
                    f"Support for mismatched {mask_name} and {other_name} "
                    "is deprecated. Use same type for both instead."
                )
        if not _mask_is_float:
            mask = (
                torch.zeros_like(mask, dtype=target_type)
                .masked_fill_(mask, float("-inf"))
            )
    return mask

def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.
    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.
    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension
        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
            same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # self-attention
            proj = F.linear(q, w, b)
            # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
            proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
            return proj[0], proj[1], proj[2]
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            q_proj = F.linear(q, w_q, b_q)
            kv_proj = F.linear(k, w_kv, b_kv)
            # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
            kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
            return (q_proj, kv_proj[0], kv_proj[1])
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
    if input is None:
        return None
    elif isinstance(input, torch.Tensor):
        return input.dtype
    raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")

class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
    Multi-Head Attention is defined as:
    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
    ``forward()`` will use a special optimized implementation if all of the following
    conditions are met:
    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
      restriction will be loosened in the future.)
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - dropout is 0
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``batch_first`` is ``True`` and the input is batched
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
      nor ``attn_mask`` is passed
    If the optimized implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the input
    that is padding can be expected.
    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
    Examples::
        >>> # xdoctest: +SKIP
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = (
            self.kdim == embed_dim and self.vdim == embed_dim
        )

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_v = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                # go down this route with voicecraft
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:  # True by default
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

                self.out_proj = linear2_cls(
                    embed_dim, embed_dim, bias=bias, **factory_kwargs
                )

                if self.bias_k is not None:
                    xavier_normal_(self.bias_k)
                if self.bias_v is not None:
                    xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        past: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
                :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
                Queries are compared against key-value pairs to produce the output.
                See "Attention Is All You Need" for more details.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
                or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
                :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
                See "Attention Is All You Need" for more details.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
                sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
                See "Attention Is All You Need" for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
                Binary and byte masks are supported.
                For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
                :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
                broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
                Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
                corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
                corresponding position is not allowed to attend. For a float mask, the mask values will be added to
                the attention weight.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
                heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
                effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
              embedding dimension ``embed_dim``.
            - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
              returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
              :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        is_batched = query.dim() == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif (
            self.in_proj_bias is not None
            and query.dtype != self.in_proj_bias.dtype
        ):
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
            self.in_proj_weight is not None
            and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.dropout:
            why_not_fast_path = f"dropout was {self.dropout}, required zero"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif attn_mask is not None:
            why_not_fast_path = "attn_mask was not None"
        elif query.is_nested and key_padding_mask is not None:
            why_not_fast_path = (
                "key_padding_mask is not supported with NestedTensor input"
            )
        elif self.num_heads % 2 == 1:
            why_not_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                [
                    (x is None or x.is_cuda or "cpu" in str(x.device))
                    for x in tensor_args
                ]
            ):
                why_not_fast_path = (
                    "some Tensor argument is neither CUDA nor CPU"
                )
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                return torch._native_multi_head_attention(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    self.in_proj_weight,
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    key_padding_mask
                    if key_padding_mask is not None
                    else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
                    if key_padding_mask is not None
                    else 0
                    if attn_mask is not None
                    else None,
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
            "MultiheadAttention does not support NestedTensor outside of its fast path. "
            + f"The fast path was not hit because {why_not_fast_path}"
        )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [
                    x.transpose(1, 0) for x in (query, key, value)
                ]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
            )
        else:
            # re-write the self.attention here, to get k, v cache
            tgt_len, bsz, embed_dim = query.shape
            src_len, _, _ = key.shape
            num_heads = self.num_heads
            key_padding_mask = _canonical_mask(
                mask=key_padding_mask,
                mask_name="key_padding_mask",
                other_type=_none_or_dtype(attn_mask),
                other_name="attn_mask",
                target_type=query.dtype
            )
            attn_mask = _canonical_mask(
                mask=attn_mask,
                mask_name="attn_mask",
                other_type=None,
                other_name="",
                target_type=query.dtype,
                check_other=False,
            )
            head_dim = self.embed_dim // self.num_heads
            assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}"
            assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
            q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
            # k_present, v_present = k, v

            #
            # reshape q, k, v for multihead attention and make em batch first
            #

            q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
            k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
            v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)  # (bsz * num_heads, src_len, head_dim)
            src_len = k.size(1)
            if past is not None and past.ndim > 2:
                expected_src_len = src_len + past[0].shape[-2]
            else:
                expected_src_len = src_len

            # ensure attn_mask's dim is 3
            if attn_mask.dim() == 2:
                correct_2d_size = (tgt_len, expected_src_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

            if key_padding_mask is not None:
                assert key_padding_mask.shape == (bsz, expected_src_len), \
                    f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}"
                key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \
                    expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len)
                if attn_mask is None:
                    attn_mask = key_padding_mask
                else:
                    attn_mask = attn_mask + key_padding_mask

            if not self.training:
                dropout_p = 0.0
            else:
                dropout_p = self.dropout

            if need_weights:
                raise NotImplementedError("need_weights not implemented for voicecraft")
                # B, Nt, E = q.shape
                # q_scaled = q / math.sqrt(E)

                # assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"

                # if attn_mask is not None:
                #     attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
                # else:
                #     attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
                # attn_output_weights = softmax(attn_output_weights, dim=-1)
                # if dropout_p > 0.0:
                #     attn_output_weights = dropout(attn_output_weights, p=dropout_p)

                # attn_output = torch.bmm(attn_output_weights, v)

                # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
                # attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
                # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

                # # optionally average attention weights over heads
                # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
                # if average_attn_weights:
                #     attn_output_weights = attn_output_weights.mean(dim=1)

                # if not is_batched:
                #     # squeeze the output if input was unbatched
                #     attn_output = attn_output.squeeze(1)
                #     attn_output_weights = attn_output_weights.squeeze(0)
                # return attn_output, attn_output_weights
            else:
                # attn_mask can be either (L,S) or (N*num_heads, L, S)
                # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
                # in order to match the input for SDPA of (N, num_heads, L, S)
                if attn_mask is not None:
                    if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                        attn_mask = attn_mask.unsqueeze(0)
                    else:
                        attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len)

                q = q.view(bsz, num_heads, tgt_len, head_dim)
                k = k.view(bsz, num_heads, src_len, head_dim)
                v = v.view(bsz, num_heads, src_len, head_dim)
                # logging.info(f"shape of past: {past.shape}")
                if past is not None:
                    present = torch.stack([k, v], dim=0)  # (2, bsz, num_heads, src_len, head_dim)
                    if past.ndim > 2:  # this means we use kvcache, otherwise we just pass in a placeholder, but not actually using kvcache
                        pk, pv = past
                        k = torch.cat([pk, k], dim=-2)
                        v = torch.cat([pv, v], dim=-2)
                else:
                    present = None
                attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False)
                attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

                attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)
                attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
                if not is_batched:
                    # squeeze the output if input was unbatched
                    attn_output = attn_output.squeeze(1)
                # if self.training:
                #     return attn_output, None
                # else:
                #     return (attn_output, present), None

            # harded coded, the code do not support returning attn weigths yet
            attn_output_weights = None
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), present
        else:
            return attn_output, present
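The modified `MultiheadAttention` above returns `(attn_output, present)` instead of PyTorch's usual `(output, weights)`, and accepts a `past` tensor of stacked keys/values for incremental decoding. Below is a minimal sketch of driving it on its custom (non-fast-path) code path; the shapes and the causal mask are illustrative, and `need_weights=False` is required because weight return is not implemented.

```python
import torch
from src.model.modules.activation import MultiheadAttention

embed_dim, num_heads, bsz, seq_len = 16, 4, 2, 6
mha = MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(bsz, seq_len, embed_dim)
# A float causal mask; passing an explicit attn_mask keeps the call off the native fast path.
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# past=None skips the kv cache; a stacked (k, v) tensor of shape
# (2, bsz, num_heads, past_len, head_dim) would prepend cached keys/values.
out, present = mha(x, x, x, attn_mask=causal, need_weights=False, past=None)
print(out.shape)   # torch.Size([2, 6, 16])
print(present)     # None here; a stacked k/v tensor when past is provided
```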
src/model/modules/codebooks_patterns.py
ADDED
@@ -0,0 +1,538 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import namedtuple
from dataclasses import dataclass
from functools import lru_cache
import logging
import typing as tp

from abc import ABC, abstractmethod
import torch

LayoutCoord = namedtuple('LayoutCoord', ['t', 'q'])  # (timestep, codebook index)
PatternLayout = tp.List[tp.List[LayoutCoord]]  # Sequence of coordinates


@dataclass
class Pattern:
    """Base implementation of a pattern over a sequence with multiple codebooks.

    The codebook pattern consists in a layout, defining for each sequence step
    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
    The first item of the pattern is always an empty list in order to properly insert a special token
    to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
    and ``timesteps`` the number of timesteps corresponding to the original sequence.

    The pattern provides convenient methods to build and revert interleaved sequences from it:
    ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
    to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
    K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
    for the output sequence. The unfilled positions are replaced with a special token and the built sequence
    is returned along with a mask indicating valid tokens.
    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
    of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
    to fill and specify invalid positions if needed.
    See the dedicated methods for more details.
    """
    # Pattern layout, for each sequence step, we have a list of coordinates
    # corresponding to the original codebook timestep and position.
    # The first list is always an empty list in order to properly insert
    # a special token to start with.
    layout: PatternLayout
    timesteps: int
    n_q: int

    def __post_init__(self):
        assert len(self.layout) > 0
        assert self.layout[0] == []
        self._validate_layout()
        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
        # logging.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))

    def _validate_layout(self):
        """Runs checks on the layout to ensure a valid pattern is defined.
        A pattern is considered invalid if:
            - Multiple timesteps for a same codebook are defined in the same sequence step
            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
              (this would mean that we have future timesteps before past timesteps).
        """
        q_timesteps = {q: 0 for q in range(self.n_q)}
        for s, seq_coords in enumerate(self.layout):
            if len(seq_coords) > 0:
                qs = set()
                for coord in seq_coords:
                    qs.add(coord.q)
                    last_q_timestep = q_timesteps[coord.q]
                    assert coord.t >= last_q_timestep, \
                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
                    q_timesteps[coord.q] = coord.t
                # each sequence step contains at max 1 coordinate per codebook
                assert len(qs) == len(seq_coords), \
                    f"Multiple entries for a same codebook are found at step {s}"

    @property
    def num_sequence_steps(self):
        return len(self.layout) - 1

    @property
    def max_delay(self):
        max_t_in_seq_coords = 0
        for seq_coords in self.layout[1:]:
            for coords in seq_coords:
                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
        return max_t_in_seq_coords - self.timesteps

    @property
    def valid_layout(self):
        valid_step = len(self.layout) - self.max_delay
        return self.layout[:valid_step]

    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
        """Get codebook coordinates in the layout that corresponds to the specified timestep t
        and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
        and the actual codebook coordinates.
        """
        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
        if q is not None:
            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
        coords = []
        for s, seq_codes in enumerate(self.layout):
            for code in seq_codes:
                if code.t == t and (q is None or code.q == q):
                    coords.append((s, code))
        return coords

    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
        return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]

    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
        steps_with_timesteps = self.get_steps_with_timestep(t, q)
        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None

    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
                                                device: tp.Union[torch.device, str] = 'cpu'):
        """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.

        Args:
            timesteps (int): Maximum number of timesteps steps to consider.
            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
            device (Union[torch.device, str]): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
        """
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
        # use the proper layout based on whether we limit ourselves to valid steps only or not,
        # note that using the valid_layout will result in a truncated sequence up to the valid steps
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
        # which will correspond to the index: n_q * timesteps
        indexes[:] = n_q * timesteps
        # iterate over the pattern and fill scattered indexes and mask
        for s, sequence_coords in enumerate(ref_layout):
            for coords in sequence_coords:
                if coords.t < timesteps:
                    indexes[coords.q, s] = coords.t + coords.q * timesteps
                    mask[coords.q, s] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Build sequence corresponding to the pattern from the input tensor z.
        The sequence is built using up to sequence_steps if specified, and non-pattern
        coordinates are filled with the special token.

        Args:
            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                Steps that are beyond valid steps will be replaced by the special_token in that case.
        Returns:
            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
                corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
        """
        B, K, T = z.shape
        indexes, mask = self._build_pattern_sequence_scatter_indexes(
            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
        )
        z = z.view(B, -1)
        # we append the special token as the last index of our flattened z tensor
        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
        values = z[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
                                                 keep_only_valid_steps: bool = False,
                                                 is_model_output: bool = False,
                                                 device: tp.Union[torch.device, str] = 'cpu'):
        """Builds scatter indexes required to retrieve the original multi-codebook sequence
        from interleaving pattern.

        Args:
            sequence_steps (int): Sequence steps.
            n_q (int): Number of codebooks.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                Steps that are beyond valid steps will be replaced by the special_token in that case.
            is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
            device (Union[torch.device, str]): Device for created tensors.
        Returns:
            torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
        timesteps = self.timesteps
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert sequence_steps <= len(ref_layout), \
            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"

        # ensure we take the appropriate indexes to keep the model output from the first special token as well
        if is_model_output:
            ref_layout = ref_layout[1:]

        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        indexes[:] = n_q * sequence_steps
        for s, sequence_codes in enumerate(ref_layout):
            if s < sequence_steps:
                for code in sequence_codes:
                    if code.t < timesteps:
                        indexes[code.q, code.t] = s + code.q * sequence_steps
                        mask[code.q, code.t] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
        The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
        are filled with the special token.

        Args:
            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
        Returns:
            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
                corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        B, K, S = s.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
        )
        s = s.view(B, -1)
        # we append the special token as the last index of our flattened z tensor
        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
+
s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
|
243 |
+
values = s[:, indexes.view(-1)]
|
244 |
+
values = values.view(B, K, indexes.shape[-1])
|
245 |
+
return values, indexes, mask
|
246 |
+
|
247 |
+
def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
|
248 |
+
"""Revert model logits obtained on a sequence built from the pattern
|
249 |
+
back to a tensor matching the original sequence.
|
250 |
+
|
251 |
+
This method is similar to ``revert_pattern_sequence`` with the following specificities:
|
252 |
+
1. It is designed to work with the extra cardinality dimension
|
253 |
+
2. We return the logits for the first sequence item that matches the special_token and
|
254 |
+
whose matching target in the original sequence is the first item of the sequence,
|
255 |
+
while we skip the last logits as there is no matching target
|
256 |
+
"""
|
257 |
+
B, card, K, S = logits.shape
|
258 |
+
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
259 |
+
S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
|
260 |
+
)
|
261 |
+
logits = logits.reshape(B, card, -1)
|
262 |
+
# we append the special token as the last index of our flattened z tensor
|
263 |
+
logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
|
264 |
+
values = logits[:, :, indexes.view(-1)]
|
265 |
+
values = values.view(B, card, K, indexes.shape[-1])
|
266 |
+
return values, indexes, mask
|
267 |
+
|
268 |
+
|
269 |
+
class CodebooksPatternProvider(ABC):
|
270 |
+
"""Abstraction around providing pattern for interleaving codebooks.
|
271 |
+
|
272 |
+
The CodebooksPatternProvider abstraction allows implementing various strategies to
|
273 |
+
define interleaving pattern of sequences composed of multiple codebooks. For a given
|
274 |
+
number of codebooks `n_q`, the pattern provider can generate a specified pattern
|
275 |
+
corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
|
276 |
+
can be used to construct a new sequence from the original codes respecting the specified
|
277 |
+
pattern. The pattern is defined as a list of list of code coordinates, code coordinate
|
278 |
+
being a tuple with the original timestep and codebook to build the new sequence.
|
279 |
+
Note that all patterns must start with an empty list that is then used to insert a first
|
280 |
+
sequence step of special tokens in the newly generated sequence.
|
281 |
+
|
282 |
+
Args:
|
283 |
+
n_q (int): number of codebooks.
|
284 |
+
cached (bool): if True, patterns for a given length are cached. In general
|
285 |
+
that should be true for efficiency reasons to avoid synchronization points.
|
286 |
+
"""
|
287 |
+
def __init__(self, n_q: int, cached: bool = True):
|
288 |
+
assert n_q > 0
|
289 |
+
self.n_q = n_q
|
290 |
+
self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
|
291 |
+
|
292 |
+
@abstractmethod
|
293 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
294 |
+
"""Builds pattern with specific interleaving between codebooks.
|
295 |
+
|
296 |
+
Args:
|
297 |
+
timesteps (int): Total number of timesteps.
|
298 |
+
"""
|
299 |
+
raise NotImplementedError()
|
300 |
+
|
301 |
+
|
302 |
+
class DelayedPatternProvider(CodebooksPatternProvider):
|
303 |
+
"""Provider for delayed pattern across delayed codebooks.
|
304 |
+
Codebooks are delayed in the sequence and sequence steps will contain codebooks
|
305 |
+
from different timesteps.
|
306 |
+
|
307 |
+
Example:
|
308 |
+
Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
|
309 |
+
[[1, 2, 3, 4],
|
310 |
+
[1, 2, 3, 4],
|
311 |
+
[1, 2, 3, 4]]
|
312 |
+
The resulting sequence obtained from the returned pattern is:
|
313 |
+
[[S, 1, 2, 3, 4],
|
314 |
+
[S, S, 1, 2, 3],
|
315 |
+
[S, S, S, 1, 2]]
|
316 |
+
(with S being a special token)
|
317 |
+
|
318 |
+
Args:
|
319 |
+
n_q (int): Number of codebooks.
|
320 |
+
delays (Optional[List[int]]): Delay for each of the codebooks.
|
321 |
+
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
322 |
+
flatten_first (int): Flatten the first N timesteps.
|
323 |
+
empty_initial (int): Prepend with N empty list of coordinates.
|
324 |
+
"""
|
325 |
+
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
|
326 |
+
flatten_first: int = 0, empty_initial: int = 0):
|
327 |
+
super().__init__(n_q)
|
328 |
+
if delays is None:
|
329 |
+
delays = list(range(n_q))
|
330 |
+
self.delays = delays
|
331 |
+
self.flatten_first = flatten_first
|
332 |
+
self.empty_initial = empty_initial
|
333 |
+
assert len(self.delays) == self.n_q
|
334 |
+
assert sorted(self.delays) == self.delays
|
335 |
+
|
336 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
337 |
+
out: PatternLayout = [[]]
|
338 |
+
max_delay = max(self.delays)
|
339 |
+
if self.empty_initial:
|
340 |
+
out += [[] for _ in range(self.empty_initial)]
|
341 |
+
if self.flatten_first:
|
342 |
+
for t in range(min(timesteps, self.flatten_first)):
|
343 |
+
for q in range(self.n_q):
|
344 |
+
out.append([LayoutCoord(t, q)])
|
345 |
+
for t in range(self.flatten_first, timesteps + max_delay):
|
346 |
+
v = []
|
347 |
+
for q, delay in enumerate(self.delays):
|
348 |
+
t_for_q = t - delay
|
349 |
+
if t_for_q >= self.flatten_first:
|
350 |
+
v.append(LayoutCoord(t_for_q, q))
|
351 |
+
out.append(v)
|
352 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
353 |
+
|
354 |
+
|
355 |
+
class ParallelPatternProvider(DelayedPatternProvider):
|
356 |
+
"""Provider for parallel pattern across codebooks.
|
357 |
+
This pattern provider is a special case of the delayed pattern with actually no delay,
|
358 |
+
hence delays=repeat(0, n_q).
|
359 |
+
|
360 |
+
Args:
|
361 |
+
n_q (int): Number of codebooks.
|
362 |
+
"""
|
363 |
+
def __init__(self, n_q: int):
|
364 |
+
super().__init__(n_q, [0] * n_q)
|
365 |
+
|
366 |
+
|
367 |
+
class UnrolledPatternProvider(CodebooksPatternProvider):
|
368 |
+
"""Provider for unrolling codebooks pattern.
|
369 |
+
This pattern provider enables representing the codebooks flattened completely or only to some extent
|
370 |
+
while also specifying a given delay between the flattened codebooks representation, allowing to
|
371 |
+
unroll the codebooks in the sequence.
|
372 |
+
|
373 |
+
Example:
|
374 |
+
1. Flattening of the codebooks.
|
375 |
+
By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
|
376 |
+
taking n_q = 3 and timesteps = 4:
|
377 |
+
[[1, 2, 3, 4],
|
378 |
+
[1, 2, 3, 4],
|
379 |
+
[1, 2, 3, 4]]
|
380 |
+
will result into:
|
381 |
+
[[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
|
382 |
+
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
383 |
+
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
|
384 |
+
2. Partial flattening of the codebooks. The ``flattening`` parameter allows specifying the inner step
|
385 |
+
for each of the codebooks, allowing to define which codebooks to flatten (or keep in parallel), for example
|
386 |
+
taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
|
387 |
+
[[1, 2, 3, 4],
|
388 |
+
[1, 2, 3, 4],
|
389 |
+
[1, 2, 3, 4]]
|
390 |
+
will result into:
|
391 |
+
[[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
392 |
+
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
393 |
+
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
|
394 |
+
3. Flattening with delay. The ``delays`` parameter allows further unrolling the sequence of codebooks,
|
395 |
+
allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
|
396 |
+
same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
|
397 |
+
and delays = [0, 3, 3]:
|
398 |
+
[[1, 2, 3, 4],
|
399 |
+
[1, 2, 3, 4],
|
400 |
+
[1, 2, 3, 4]]
|
401 |
+
will result into:
|
402 |
+
[[S, S, S, 1, S, 2, S, 3, S, 4],
|
403 |
+
[S, S, S, 1, S, 2, S, 3, S, 4],
|
404 |
+
[1, 2, 3, S, 4, S, 5, S, 6, S]]
|
405 |
+
|
406 |
+
Args:
|
407 |
+
n_q (int): Number of codebooks.
|
408 |
+
flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
|
409 |
+
the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
|
410 |
+
have n_q extra steps for each timestep.
|
411 |
+
delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
|
412 |
+
no delay is added and therefore will default to [0] * ``n_q``.
|
413 |
+
Note that two codebooks that will be flattened to the same inner step
|
414 |
+
should have the same delay, otherwise the pattern is considered as invalid.
|
415 |
+
"""
|
416 |
+
FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
|
417 |
+
|
418 |
+
def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
|
419 |
+
delays: tp.Optional[tp.List[int]] = None):
|
420 |
+
super().__init__(n_q)
|
421 |
+
if flattening is None:
|
422 |
+
flattening = list(range(n_q))
|
423 |
+
if delays is None:
|
424 |
+
delays = [0] * n_q
|
425 |
+
assert len(flattening) == n_q
|
426 |
+
assert len(delays) == n_q
|
427 |
+
assert sorted(flattening) == flattening
|
428 |
+
assert sorted(delays) == delays
|
429 |
+
self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
|
430 |
+
self.max_delay = max(delays)
|
431 |
+
|
432 |
+
def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
|
433 |
+
"""Build a flattened codebooks representation as a dictionary of inner step
|
434 |
+
and the actual codebook indices corresponding to the flattened codebook. For convenience, we
|
435 |
+
also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
|
436 |
+
"""
|
437 |
+
flattened_codebooks: dict = {}
|
438 |
+
for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
|
439 |
+
if inner_step not in flattened_codebooks:
|
440 |
+
flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
|
441 |
+
else:
|
442 |
+
flat_codebook = flattened_codebooks[inner_step]
|
443 |
+
assert flat_codebook.delay == delay, (
|
444 |
+
"Delay and flattening between codebooks is inconsistent: ",
|
445 |
+
"two codebooks flattened to the same position should have the same delay."
|
446 |
+
)
|
447 |
+
flat_codebook.codebooks.append(q)
|
448 |
+
flattened_codebooks[inner_step] = flat_codebook
|
449 |
+
return flattened_codebooks
|
450 |
+
|
451 |
+
@property
|
452 |
+
def _num_inner_steps(self):
|
453 |
+
"""Number of inner steps to unroll between timesteps in order to flatten the codebooks.
|
454 |
+
"""
|
455 |
+
return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
|
456 |
+
|
457 |
+
def num_virtual_steps(self, timesteps: int) -> int:
|
458 |
+
return timesteps * self._num_inner_steps + 1
|
459 |
+
|
460 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
461 |
+
"""Builds pattern for delay across codebooks.
|
462 |
+
|
463 |
+
Args:
|
464 |
+
timesteps (int): Total number of timesteps.
|
465 |
+
"""
|
466 |
+
# the PatternLayout is built as a tuple of sequence position and list of coordinates
|
467 |
+
# so that it can be reordered properly given the required delay between codebooks of given timesteps
|
468 |
+
indexed_out: list = [(-1, [])]
|
469 |
+
max_timesteps = timesteps + self.max_delay
|
470 |
+
for t in range(max_timesteps):
|
471 |
+
# for each timestep, we unroll the flattened codebooks,
|
472 |
+
# emitting the sequence step with the corresponding delay
|
473 |
+
for step in range(self._num_inner_steps):
|
474 |
+
if step in self._flattened_codebooks:
|
475 |
+
# we have codebooks at this virtual step to emit
|
476 |
+
step_codebooks = self._flattened_codebooks[step]
|
477 |
+
t_for_q = t + step_codebooks.delay
|
478 |
+
coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
|
479 |
+
if t_for_q < max_timesteps and t < max_timesteps:
|
480 |
+
indexed_out.append((t_for_q, coords))
|
481 |
+
else:
|
482 |
+
# there is no codebook in this virtual step so we emit an empty list
|
483 |
+
indexed_out.append((t, []))
|
484 |
+
out = [coords for _, coords in sorted(indexed_out)]
|
485 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
486 |
+
|
487 |
+
|
488 |
+
class VALLEPattern(CodebooksPatternProvider):
|
489 |
+
"""Almost VALL-E style pattern. We futher allow some delays for the
|
490 |
+
codebooks other than the first one.
|
491 |
+
|
492 |
+
Args:
|
493 |
+
n_q (int): Number of codebooks.
|
494 |
+
delays (Optional[List[int]]): Delay for each of the codebooks.
|
495 |
+
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
496 |
+
"""
|
497 |
+
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
|
498 |
+
super().__init__(n_q)
|
499 |
+
if delays is None:
|
500 |
+
delays = [0] * (n_q - 1)
|
501 |
+
self.delays = delays
|
502 |
+
assert len(self.delays) == self.n_q - 1
|
503 |
+
assert sorted(self.delays) == self.delays
|
504 |
+
|
505 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
506 |
+
out: PatternLayout = [[]]
|
507 |
+
for t in range(timesteps):
|
508 |
+
out.append([LayoutCoord(t, 0)])
|
509 |
+
max_delay = max(self.delays)
|
510 |
+
for t in range(timesteps + max_delay):
|
511 |
+
v = []
|
512 |
+
for q, delay in enumerate(self.delays):
|
513 |
+
t_for_q = t - delay
|
514 |
+
if t_for_q >= 0:
|
515 |
+
v.append(LayoutCoord(t_for_q, q + 1))
|
516 |
+
out.append(v)
|
517 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
518 |
+
|
519 |
+
|
520 |
+
class MusicLMPattern(CodebooksPatternProvider):
|
521 |
+
"""Almost MusicLM style pattern. This is equivalent to full flattening
|
522 |
+
but in a different order.
|
523 |
+
|
524 |
+
Args:
|
525 |
+
n_q (int): Number of codebooks.
|
526 |
+
group_by (int): Number of codebooks to group together.
|
527 |
+
"""
|
528 |
+
def __init__(self, n_q: int, group_by: int = 2):
|
529 |
+
super().__init__(n_q)
|
530 |
+
self.group_by = group_by
|
531 |
+
|
532 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
533 |
+
out: PatternLayout = [[]]
|
534 |
+
for offset in range(0, self.n_q, self.group_by):
|
535 |
+
for t in range(timesteps):
|
536 |
+
for q in range(offset, offset + self.group_by):
|
537 |
+
out.append([LayoutCoord(t, q)])
|
538 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
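A minimal sketch of the round trip these pattern classes implement, using the DelayedPatternProvider defined above (illustrative only, not part of the committed file; the special-token id 1024 and the tensor sizes are arbitrary assumptions):

import torch

n_q, T, special = 3, 4, 1024
provider = DelayedPatternProvider(n_q=n_q)  # default delays: [0, 1, 2]
pattern = provider.get_pattern(T)

z = torch.arange(n_q * T).view(1, n_q, T)  # fake codes, shape [B=1, K=3, T=4]
seq, _, _ = pattern.build_pattern_sequence(z, special_token=special)
# seq has shape [1, 3, T + max_delay + 1]; coordinates outside the pattern hold `special`

rev, _, rev_mask = pattern.revert_pattern_sequence(seq, special_token=special)
# wherever the revert mask is set, the original codes are recovered exactly
assert torch.equal(rev[:, rev_mask], z[:, rev_mask])

With the default delays the delayed pattern covers every (codebook, timestep) coordinate, so the assertion effectively recovers the full [B, K, T] tensor.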
src/model/modules/embedding.py
ADDED
@@ -0,0 +1,98 @@
1 |
+
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
|
2 |
+
# Copyright 2023 (authors: Feiteng Li)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import math
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
|
22 |
+
class TokenEmbedding(nn.Module):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
dim_model: int,
|
26 |
+
vocab_size: int,
|
27 |
+
dropout: float = 0.0,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
|
31 |
+
self.vocab_size = vocab_size
|
32 |
+
self.dim_model = dim_model
|
33 |
+
|
34 |
+
self.dropout = torch.nn.Dropout(p=dropout)
|
35 |
+
self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
|
36 |
+
|
37 |
+
@property
|
38 |
+
def weight(self) -> torch.Tensor:
|
39 |
+
return self.word_embeddings.weight
|
40 |
+
|
41 |
+
def embedding(self, index: int) -> torch.Tensor:
|
42 |
+
return self.word_embeddings.weight[index : index + 1]
|
43 |
+
|
44 |
+
def forward(self, x: torch.Tensor):
|
45 |
+
X = self.word_embeddings(x)
|
46 |
+
X = self.dropout(X)
|
47 |
+
|
48 |
+
return X
|
49 |
+
|
50 |
+
|
51 |
+
class SinePositionalEmbedding(nn.Module):
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
dim_model: int,
|
55 |
+
dropout: float = 0.0,
|
56 |
+
scale: bool = False,
|
57 |
+
alpha: bool = False,
|
58 |
+
):
|
59 |
+
super().__init__()
|
60 |
+
self.dim_model = dim_model
|
61 |
+
self.x_scale = math.sqrt(dim_model) if scale else 1.0
|
62 |
+
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
|
63 |
+
self.dropout = torch.nn.Dropout(p=dropout)
|
64 |
+
|
65 |
+
self.reverse = False
|
66 |
+
self.pe = None
|
67 |
+
self.extend_pe(torch.tensor(0.0).expand(1, 4000))
|
68 |
+
|
69 |
+
def extend_pe(self, x):
|
70 |
+
"""Reset the positional encodings."""
|
71 |
+
if self.pe is not None:
|
72 |
+
if self.pe.size(1) >= x.size(1):
|
73 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
74 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
75 |
+
return
|
76 |
+
pe = torch.zeros(x.size(1), self.dim_model)
|
77 |
+
if self.reverse:
|
78 |
+
position = torch.arange(
|
79 |
+
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
80 |
+
).unsqueeze(1)
|
81 |
+
else:
|
82 |
+
position = torch.arange(
|
83 |
+
0, x.size(1), dtype=torch.float32
|
84 |
+
).unsqueeze(1)
|
85 |
+
div_term = torch.exp(
|
86 |
+
torch.arange(0, self.dim_model, 2, dtype=torch.float32)
|
87 |
+
* -(math.log(10000.0) / self.dim_model)
|
88 |
+
)
|
89 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
90 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
91 |
+
pe = pe.unsqueeze(0)
|
92 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
|
93 |
+
|
94 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
95 |
+
self.extend_pe(x)
|
96 |
+
output = x.unsqueeze(-1) if x.ndim == 2 else x
|
97 |
+
output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
|
98 |
+
return self.dropout(output)
|
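The two modules above are typically composed as token embedding followed by positional encoding; a short sketch under assumed sizes (512-dimensional model, 1,000-token vocabulary, chosen only for illustration):

import torch

emb = TokenEmbedding(dim_model=512, vocab_size=1000)
pos = SinePositionalEmbedding(dim_model=512, dropout=0.0, scale=True, alpha=False)

tokens = torch.randint(0, 1000, (2, 16))  # [batch, seq_len]
x = pos(emb(tokens))                      # inputs scaled by sqrt(512), summed with sinusoids
print(x.shape)                            # torch.Size([2, 16, 512])

Because `alpha=False`, the learnable scale on the positional term stays frozen at 1.0; with `alpha=True` it would be trained.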
src/model/modules/gemma.py
ADDED
@@ -0,0 +1,423 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from typing import Optional, Tuple
|
4 |
+
import math
|
5 |
+
from src.model.modules.kv_cache import KVCache
|
6 |
+
|
7 |
+
|
8 |
+
class GemmaConfig:
|
9 |
+
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
vocab_size,
|
13 |
+
hidden_size,
|
14 |
+
intermediate_size,
|
15 |
+
num_hidden_layers,
|
16 |
+
num_attention_heads,
|
17 |
+
num_key_value_heads,
|
18 |
+
head_dim=256,
|
19 |
+
max_position_embeddings=8192,
|
20 |
+
rms_norm_eps=1e-6,
|
21 |
+
rope_theta=10000.0,
|
22 |
+
attention_bias=False,
|
23 |
+
attention_dropout=0.0,
|
24 |
+
pad_token_id=None,
|
25 |
+
**kwargs,
|
26 |
+
):
|
27 |
+
super().__init__()
|
28 |
+
self.vocab_size = vocab_size
|
29 |
+
self.max_position_embeddings = max_position_embeddings
|
30 |
+
self.hidden_size = hidden_size
|
31 |
+
self.intermediate_size = intermediate_size
|
32 |
+
self.num_hidden_layers = num_hidden_layers
|
33 |
+
self.num_attention_heads = num_attention_heads
|
34 |
+
self.head_dim = head_dim
|
35 |
+
self.num_key_value_heads = num_key_value_heads
|
36 |
+
self.rms_norm_eps = rms_norm_eps
|
37 |
+
self.rope_theta = rope_theta
|
38 |
+
self.attention_bias = attention_bias
|
39 |
+
self.attention_dropout = attention_dropout
|
40 |
+
self.pad_token_id = pad_token_id
|
41 |
+
|
42 |
+
|
43 |
+
class GemmaRMSNorm(nn.Module):
|
44 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
45 |
+
super().__init__()
|
46 |
+
self.eps = eps
|
47 |
+
self.weight = nn.Parameter(torch.zeros(dim))
|
48 |
+
|
49 |
+
def _norm(self, x):
|
50 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
output = self._norm(x.float())
|
54 |
+
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
|
55 |
+
# See https://github.com/huggingface/transformers/pull/29402
|
56 |
+
output = output * (1.0 + self.weight.float())
|
57 |
+
return output.type_as(x)
|
58 |
+
|
59 |
+
|
60 |
+
class GemmaRotaryEmbedding(nn.Module):
|
61 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
62 |
+
super().__init__()
|
63 |
+
|
64 |
+
self.dim = dim # it is set to the head_dim
|
65 |
+
self.max_position_embeddings = max_position_embeddings
|
66 |
+
self.base = base
|
67 |
+
|
68 |
+
# Calculate the theta according to the formula theta_i = base^(2i/dim) where i = 0, 1, 2, ..., dim // 2
|
69 |
+
inv_freq = 1.0 / (
|
70 |
+
self.base
|
71 |
+
** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
|
72 |
+
)
|
73 |
+
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
def forward(self, x, position_ids, seq_len=None):
|
77 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
78 |
+
self.inv_freq.to(x.device)
|
79 |
+
# Copy the inv_freq tensor for batch in the sequence
|
80 |
+
# inv_freq_expanded: [Batch_Size, Head_Dim // 2, 1]
|
81 |
+
inv_freq_expanded = (
|
82 |
+
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
83 |
+
)
|
84 |
+
# position_ids_expanded: [Batch_Size, 1, Seq_Len]
|
85 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
86 |
+
device_type = x.device.type
|
87 |
+
device_type = (
|
88 |
+
device_type
|
89 |
+
if isinstance(device_type, str) and device_type != "mps"
|
90 |
+
else "cpu"
|
91 |
+
)
|
92 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
93 |
+
# Multiply each theta by the position (which is the argument of the sin and cos functions)
|
94 |
+
# freqs: [Batch_Size, Head_Dim // 2, 1] @ [Batch_Size, 1, Seq_Len] --> [Batch_Size, Seq_Len, Head_Dim // 2]
|
95 |
+
freqs = (
|
96 |
+
inv_freq_expanded.float() @ position_ids_expanded.float()
|
97 |
+
).transpose(1, 2)
|
98 |
+
# emb: [Batch_Size, Seq_Len, Head_Dim]
|
99 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
100 |
+
# cos, sin: [Batch_Size, Seq_Len, Head_Dim]
|
101 |
+
cos = emb.cos()
|
102 |
+
sin = emb.sin()
|
103 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
104 |
+
|
105 |
+
|
106 |
+
def rotate_half(x):
|
107 |
+
# Build the [-x2, x1, -x4, x3, ...] tensor for the sin part of the positional encoding.
|
108 |
+
x1 = x[..., : x.shape[-1] // 2] # Takes the first half of the last dimension
|
109 |
+
x2 = x[..., x.shape[-1] // 2 :] # Takes the second half of the last dimension
|
110 |
+
return torch.cat((-x2, x1), dim=-1)
|
111 |
+
|
112 |
+
|
113 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
114 |
+
cos = cos.unsqueeze(unsqueeze_dim) # Add the head dimension
|
115 |
+
sin = sin.unsqueeze(unsqueeze_dim) # Add the head dimension
|
116 |
+
# Apply the formula (34) of the Rotary Positional Encoding paper.
|
117 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
118 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
119 |
+
return q_embed, k_embed
|
120 |
+
|
121 |
+
|
122 |
+
class GemmaMLP(nn.Module):
|
123 |
+
def __init__(self, config):
|
124 |
+
super().__init__()
|
125 |
+
self.config = config
|
126 |
+
self.hidden_size = config.hidden_size
|
127 |
+
self.intermediate_size = config.intermediate_size
|
128 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
129 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
130 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
131 |
+
|
132 |
+
def forward(self, x):
|
133 |
+
# Equivalent to:
|
134 |
+
# y = self.gate_proj(x) # [Batch_Size, Seq_Len, Hidden_Size] -> [Batch_Size, Seq_Len, Intermediate_Size]
|
135 |
+
# y = torch.gelu(y, approximate="tanh") # [Batch_Size, Seq_Len, Intermediate_Size]
|
136 |
+
# j = self.up_proj(x) # [Batch_Size, Seq_Len, Hidden_Size] -> [Batch_Size, Seq_Len, Intermediate_Size]
|
137 |
+
# z = y * j # [Batch_Size, Seq_Len, Intermediate_Size]
|
138 |
+
# z = self.down_proj(z) # [Batch_Size, Seq_Len, Intermediate_Size] -> [Batch_Size, Seq_Len, Hidden_Size]
|
139 |
+
return self.down_proj(
|
140 |
+
nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)
|
141 |
+
)
|
142 |
+
|
143 |
+
|
144 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
145 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
146 |
+
if n_rep == 1:
|
147 |
+
return hidden_states
|
148 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
149 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
150 |
+
)
|
151 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
152 |
+
|
153 |
+
|
154 |
+
class GemmaAttention(nn.Module):
|
155 |
+
|
156 |
+
def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
|
157 |
+
super().__init__()
|
158 |
+
self.config = config
|
159 |
+
self.layer_idx = layer_idx
|
160 |
+
|
161 |
+
self.attention_dropout = config.attention_dropout
|
162 |
+
self.hidden_size = config.hidden_size
|
163 |
+
self.num_heads = config.num_attention_heads
|
164 |
+
self.head_dim = config.head_dim
|
165 |
+
self.num_key_value_heads = config.num_key_value_heads
|
166 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
167 |
+
self.max_position_embeddings = config.max_position_embeddings
|
168 |
+
self.rope_theta = config.rope_theta
|
169 |
+
self.is_causal = True
|
170 |
+
|
171 |
+
assert self.hidden_size % self.num_heads == 0
|
172 |
+
|
173 |
+
self.q_proj = nn.Linear(
|
174 |
+
self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
|
175 |
+
)
|
176 |
+
self.k_proj = nn.Linear(
|
177 |
+
self.hidden_size,
|
178 |
+
self.num_key_value_heads * self.head_dim,
|
179 |
+
bias=config.attention_bias,
|
180 |
+
)
|
181 |
+
self.v_proj = nn.Linear(
|
182 |
+
self.hidden_size,
|
183 |
+
self.num_key_value_heads * self.head_dim,
|
184 |
+
bias=config.attention_bias,
|
185 |
+
)
|
186 |
+
self.o_proj = nn.Linear(
|
187 |
+
self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
|
188 |
+
)
|
189 |
+
self.rotary_emb = GemmaRotaryEmbedding(
|
190 |
+
self.head_dim,
|
191 |
+
max_position_embeddings=self.max_position_embeddings,
|
192 |
+
base=self.rope_theta,
|
193 |
+
)
|
194 |
+
|
195 |
+
def forward(
|
196 |
+
self,
|
197 |
+
hidden_states: torch.Tensor,
|
198 |
+
attention_mask: Optional[torch.Tensor] = None,
|
199 |
+
position_ids: Optional[torch.LongTensor] = None,
|
200 |
+
kv_cache: Optional[KVCache] = None,
|
201 |
+
**kwargs,
|
202 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
203 |
+
bsz, q_len, _ = hidden_states.size() # [Batch_Size, Seq_Len, Hidden_Size]
|
204 |
+
# [Batch_Size, Seq_Len, Num_Heads_Q * Head_Dim]
|
205 |
+
query_states = self.q_proj(hidden_states)
|
206 |
+
# [Batch_Size, Seq_Len, Num_Heads_KV * Head_Dim]
|
207 |
+
key_states = self.k_proj(hidden_states)
|
208 |
+
# [Batch_Size, Seq_Len, Num_Heads_KV * Head_Dim]
|
209 |
+
value_states = self.v_proj(hidden_states)
|
210 |
+
# [Batch_Size, Num_Heads_Q, Seq_Len, Head_Dim]
|
211 |
+
query_states = query_states.view(
|
212 |
+
bsz, q_len, self.num_heads, self.head_dim
|
213 |
+
).transpose(1, 2)
|
214 |
+
# [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
|
215 |
+
key_states = key_states.view(
|
216 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
217 |
+
).transpose(1, 2)
|
218 |
+
# [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
|
219 |
+
value_states = value_states.view(
|
220 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
221 |
+
).transpose(1, 2)
|
222 |
+
|
223 |
+
# [Batch_Size, Seq_Len, Head_Dim], [Batch_Size, Seq_Len, Head_Dim]
|
224 |
+
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
225 |
+
# [Batch_Size, Num_Heads_Q, Seq_Len, Head_Dim], [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
|
226 |
+
query_states, key_states = apply_rotary_pos_emb(
|
227 |
+
query_states, key_states, cos, sin
|
228 |
+
)
|
229 |
+
|
230 |
+
if kv_cache is not None:
|
231 |
+
key_states, value_states = kv_cache.update(
|
232 |
+
key_states, value_states, self.layer_idx
|
233 |
+
)
|
234 |
+
|
235 |
+
# Repeat the key and values to match the number of heads of the query
|
236 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
237 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
238 |
+
# Perform the calculation as usual, Q * K^T / sqrt(head_dim). Shape: [Batch_Size, Num_Heads_Q, Seq_Len_Q, Seq_Len_KV]
|
239 |
+
attn_weights = torch.matmul(
|
240 |
+
query_states, key_states.transpose(2, 3)
|
241 |
+
) / math.sqrt(self.head_dim)
|
242 |
+
|
243 |
+
assert attention_mask is not None
|
244 |
+
attn_weights = attn_weights + attention_mask
|
245 |
+
|
246 |
+
# Apply the softmax
|
247 |
+
# [Batch_Size, Num_Heads_Q, Seq_Len_Q, Seq_Len_KV]
|
248 |
+
attn_weights = nn.functional.softmax(
|
249 |
+
attn_weights, dim=-1, dtype=torch.float32
|
250 |
+
).to(query_states.dtype)
|
251 |
+
# Apply the dropout
|
252 |
+
attn_weights = nn.functional.dropout(
|
253 |
+
attn_weights, p=self.attention_dropout, training=self.training
|
254 |
+
)
|
255 |
+
# Multiply by the values. [Batch_Size, Num_Heads_Q, Seq_Len_Q, Seq_Len_KV] x [Batch_Size, Num_Heads_KV, Seq_Len_KV, Head_Dim] -> [Batch_Size, Num_Heads_Q, Seq_Len_Q, Head_Dim]
|
256 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
257 |
+
|
258 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
259 |
+
raise ValueError(
|
260 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
261 |
+
f" {attn_output.size()}"
|
262 |
+
)
|
263 |
+
# Make sure the sequence length is the second dimension. # [Batch_Size, Num_Heads_Q, Seq_Len_Q, Head_Dim] -> [Batch_Size, Seq_Len_Q, Num_Heads_Q, Head_Dim]
|
264 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
265 |
+
# Concatenate all the heads together. [Batch_Size, Seq_Len_Q, Num_Heads_Q, Head_Dim] -> [Batch_Size, Seq_Len_Q, Num_Heads_Q * Head_Dim]
|
266 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
267 |
+
# Multiply by W_o. [Batch_Size, Seq_Len_Q, Hidden_Size]
|
268 |
+
attn_output = self.o_proj(attn_output)
|
269 |
+
|
270 |
+
return attn_output, attn_weights
|
271 |
+
|
272 |
+
|
273 |
+
class GemmaDecoderLayer(nn.Module):
|
274 |
+
|
275 |
+
def __init__(self, config: GemmaConfig, layer_idx: int):
|
276 |
+
super().__init__()
|
277 |
+
self.hidden_size = config.hidden_size
|
278 |
+
|
279 |
+
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
|
280 |
+
|
281 |
+
self.mlp = GemmaMLP(config)
|
282 |
+
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
283 |
+
self.post_attention_layernorm = GemmaRMSNorm(
|
284 |
+
config.hidden_size, eps=config.rms_norm_eps
|
285 |
+
)
|
286 |
+
|
287 |
+
def forward(
|
288 |
+
self,
|
289 |
+
hidden_states: torch.Tensor,
|
290 |
+
attention_mask: Optional[torch.Tensor] = None,
|
291 |
+
position_ids: Optional[torch.LongTensor] = None,
|
292 |
+
kv_cache: Optional[KVCache] = None,
|
293 |
+
) -> Tuple[
|
294 |
+
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
295 |
+
]:
|
296 |
+
residual = hidden_states
|
297 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
298 |
+
hidden_states = self.input_layernorm(hidden_states)
|
299 |
+
|
300 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
301 |
+
(
|
302 |
+
hidden_states,
|
303 |
+
_,
|
304 |
+
) = self.self_attn(
|
305 |
+
hidden_states=hidden_states,
|
306 |
+
attention_mask=attention_mask,
|
307 |
+
position_ids=position_ids,
|
308 |
+
kv_cache=kv_cache,
|
309 |
+
)
|
310 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
311 |
+
hidden_states = residual + hidden_states
|
312 |
+
|
313 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
314 |
+
residual = hidden_states
|
315 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
316 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
317 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
318 |
+
hidden_states = self.mlp(hidden_states)
|
319 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
320 |
+
hidden_states = residual + hidden_states
|
321 |
+
|
322 |
+
return hidden_states
|
323 |
+
|
324 |
+
|
325 |
+
class GemmaModel(nn.Module):
|
326 |
+
|
327 |
+
def __init__(self, config: GemmaConfig):
|
328 |
+
super().__init__()
|
329 |
+
self.config = config
|
330 |
+
self.padding_idx = config.pad_token_id
|
331 |
+
self.vocab_size = config.vocab_size
|
332 |
+
|
333 |
+
self.embed_tokens = nn.Embedding(
|
334 |
+
config.vocab_size, config.hidden_size, self.padding_idx
|
335 |
+
)
|
336 |
+
self.layers = nn.ModuleList(
|
337 |
+
[
|
338 |
+
GemmaDecoderLayer(config, layer_idx)
|
339 |
+
for layer_idx in range(config.num_hidden_layers)
|
340 |
+
]
|
341 |
+
)
|
342 |
+
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
343 |
+
|
344 |
+
def get_input_embeddings(self):
|
345 |
+
return self.embed_tokens
|
346 |
+
|
347 |
+
# Ignore copy
|
348 |
+
def forward(
|
349 |
+
self,
|
350 |
+
attention_mask: Optional[torch.Tensor] = None,
|
351 |
+
position_ids: Optional[torch.LongTensor] = None,
|
352 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
353 |
+
kv_cache: Optional[KVCache] = None,
|
354 |
+
) -> torch.FloatTensor:
|
355 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
356 |
+
hidden_states = inputs_embeds
|
357 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
358 |
+
normalizer = torch.tensor(
|
359 |
+
self.config.hidden_size**0.5, dtype=hidden_states.dtype
|
360 |
+
)
|
361 |
+
hidden_states = hidden_states * normalizer
|
362 |
+
|
363 |
+
for decoder_layer in self.layers:
|
364 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
365 |
+
hidden_states = decoder_layer(
|
366 |
+
hidden_states,
|
367 |
+
attention_mask=attention_mask,
|
368 |
+
position_ids=position_ids,
|
369 |
+
kv_cache=kv_cache,
|
370 |
+
)
|
371 |
+
|
372 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
373 |
+
hidden_states = self.norm(hidden_states)
|
374 |
+
|
375 |
+
# [Batch_Size, Seq_Len, Hidden_Size]
|
376 |
+
return hidden_states
|
377 |
+
|
378 |
+
|
379 |
+
class GemmaForCausalLM(nn.Module):
|
380 |
+
|
381 |
+
def __init__(self, config):
|
382 |
+
super().__init__()
|
383 |
+
self.config = config
|
384 |
+
self.model = GemmaModel(config)
|
385 |
+
self.vocab_size = config.vocab_size
|
386 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
387 |
+
|
388 |
+
def get_input_embeddings(self):
|
389 |
+
return self.model.embed_tokens
|
390 |
+
|
391 |
+
def tie_weights(self):
|
392 |
+
self.lm_head.weight = self.model.embed_tokens.weight
|
393 |
+
|
394 |
+
def forward(
|
395 |
+
self,
|
396 |
+
attention_mask: Optional[torch.Tensor] = None,
|
397 |
+
position_ids: Optional[torch.LongTensor] = None,
|
398 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
399 |
+
kv_cache: Optional[KVCache] = None,
|
400 |
+
) -> Tuple:
|
401 |
+
|
402 |
+
# input_embeds: [Batch_Size, Seq_Len, Hidden_Size]
|
403 |
+
# outputs: [Batch_Size, Seq_Len, Hidden_Size]
|
404 |
+
outputs = self.model(
|
405 |
+
attention_mask=attention_mask,
|
406 |
+
position_ids=position_ids,
|
407 |
+
inputs_embeds=inputs_embeds,
|
408 |
+
kv_cache=kv_cache,
|
409 |
+
)
|
410 |
+
|
411 |
+
hidden_states = outputs
|
412 |
+
logits = self.lm_head(hidden_states)
|
413 |
+
logits = logits.float()
|
414 |
+
|
415 |
+
return_data = {
|
416 |
+
"logits": logits,
|
417 |
+
}
|
418 |
+
|
419 |
+
if kv_cache is not None:
|
420 |
+
# Return the updated cache
|
421 |
+
return_data["kv_cache"] = kv_cache
|
422 |
+
|
423 |
+
return return_data
|
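A small shape check for the rotary-embedding helpers and the grouped-query `repeat_kv` above (a CPU-only sketch; the head counts and sequence length are arbitrary assumptions, not values taken from the model config):

import torch

head_dim, n_q_heads, n_kv_heads, seq_len = 256, 8, 1, 5
rope = GemmaRotaryEmbedding(head_dim)

q = torch.randn(1, n_q_heads, seq_len, head_dim)
k = torch.randn(1, n_kv_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)    # [1, seq_len]

cos, sin = rope(q, position_ids)                     # each [1, seq_len, head_dim]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # same shapes as q and k

# each KV head is shared by n_q_heads // n_kv_heads query heads
k_full = repeat_kv(k_rot, n_q_heads // n_kv_heads)
print(k_full.shape)                                  # torch.Size([1, 8, 5, 256])

This mirrors what GemmaAttention.forward does before the Q·K^T product, where repeating the keys and values lets the key/value cache stay small while every query head still gets a full-size tensor to attend over.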
src/model/modules/imagecraft.py
ADDED
@@ -0,0 +1,490 @@
1 |
+
from argparse import Namespace
|
2 |
+
import glob
|
3 |
+
import logging
|
4 |
+
from pathlib import Path
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
from typing import Optional, Tuple
|
8 |
+
from PIL import Image
|
9 |
+
from safetensors import safe_open
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
import torchaudio
|
13 |
+
from src.model.modules import voicecraft
|
14 |
+
from src.model.modules.gemma import GemmaForCausalLM, KVCache
|
15 |
+
from src.model.modules.imagecraftconfig import ImageCraftConfig
|
16 |
+
from src.model.modules.imagecraftprocessor import (
|
17 |
+
ImageCraftProcessor,
|
18 |
+
)
|
19 |
+
from src.model.modules.siglip import SiglipVisionModel
|
20 |
+
|
21 |
+
from transformers import AutoTokenizer
|
22 |
+
|
23 |
+
from src.model.modules.tokenizer import (
|
24 |
+
AudioTokenizer,
|
25 |
+
TextTokenizer,
|
26 |
+
tokenize_audio,
|
27 |
+
tokenize_text,
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
from src.utils import tools
|
32 |
+
from src.utils.image_utils import is_valid_image
|
33 |
+
from src.utils.model_utils import get_config, get_model_inputs
|
34 |
+
from src.utils.util import (
|
35 |
+
replace_numbers_with_words,
|
36 |
+
sample_top_p,
|
37 |
+
save_to_buffer,
|
38 |
+
save_to_file,
|
39 |
+
split_line_to_sentences,
|
40 |
+
)
|
41 |
+
|
42 |
+
from huggingface_hub import HfApi
|
43 |
+
|
44 |
+
logger = logging.getLogger(__name__)
|
45 |
+
|
46 |
+
|
47 |
+
class ImageCraftMultiModalProjector(nn.Module):
|
48 |
+
def __init__(self, config: ImageCraftConfig):
|
49 |
+
super().__init__()
|
50 |
+
self.linear = nn.Linear(
|
51 |
+
config.vision_config.hidden_size,
|
52 |
+
config.vision_config.projection_dim,
|
53 |
+
bias=True,
|
54 |
+
)
|
55 |
+
|
56 |
+
def forward(self, image_features):
|
57 |
+
hidden_states = self.linear(image_features)
|
58 |
+
return hidden_states
|
59 |
+
|
60 |
+
|
61 |
+
class ImageCraft(nn.Module):
|
62 |
+
config_class = ImageCraftConfig
|
63 |
+
|
64 |
+
def __init__(self, config: ImageCraftConfig):
|
65 |
+
super(ImageCraft, self).__init__()
|
66 |
+
self.config = config
|
67 |
+
self.vision_tower = SiglipVisionModel(config.vision_config)
|
68 |
+
self.multi_modal_projector = ImageCraftMultiModalProjector(config)
|
69 |
+
self.vocab_size = config.text_config.vocab_size
|
70 |
+
|
71 |
+
self.language_model = GemmaForCausalLM(config.text_config)
|
72 |
+
|
73 |
+
self.pad_token_id = (
|
74 |
+
self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
75 |
+
)
|
76 |
+
|
77 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
78 |
+
"google/paligemma-3b-pt-224", padding_side="right"
|
79 |
+
)
|
80 |
+
assert tokenizer.padding_side == "right"
|
81 |
+
|
82 |
+
num_image_tokens = config.vision_config.num_image_tokens
|
83 |
+
image_size = config.vision_config.image_size
|
84 |
+
self.processor = ImageCraftProcessor(tokenizer, num_image_tokens, image_size)
|
85 |
+
|
86 |
+
self.text_tokenizer = None
|
87 |
+
|
88 |
+
self.voicecraft_model = None
|
89 |
+
self.audio_tokenizer = None
|
90 |
+
|
91 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
92 |
+
|
93 |
+
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights with Llava->PaliGemma
|
94 |
+
def tie_weights(self):
|
95 |
+
return self.language_model.tie_weights()
|
96 |
+
|
97 |
+
def forward(
|
98 |
+
self,
|
99 |
+
input_ids: torch.LongTensor = None,
|
100 |
+
pixel_values: torch.FloatTensor = None,
|
101 |
+
attention_mask: Optional[torch.Tensor] = None,
|
102 |
+
labels: Optional[torch.LongTensor] = None,
|
103 |
+
kv_cache: Optional[KVCache] = None,
|
104 |
+
) -> Tuple:
|
105 |
+
# Make sure the input is right-padded
|
106 |
+
assert torch.all(attention_mask == 1), "The input cannot be padded"
|
107 |
+
|
108 |
+
# 1. Extract the input embeddings
|
109 |
+
# shape: (Batch_Size, Seq_Len, Hidden_Size)
|
110 |
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
111 |
+
|
112 |
+
# 2. Merge text and images
|
113 |
+
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
|
114 |
+
selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
|
115 |
+
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Hidden_Size]
|
116 |
+
image_features = self.multi_modal_projector(selected_image_feature)
|
117 |
+
|
118 |
+
# Merge the embeddings of the text tokens and the image tokens
|
119 |
+
inputs_embeds, attention_mask, position_ids = (
|
120 |
+
self._merge_input_ids_with_image_features(
|
121 |
+
image_features, inputs_embeds, input_ids, attention_mask, kv_cache
|
122 |
+
)
|
123 |
+
)
|
124 |
+
|
125 |
+
outputs = self.language_model(
|
126 |
+
attention_mask=attention_mask,
|
127 |
+
position_ids=position_ids,
|
128 |
+
inputs_embeds=inputs_embeds,
|
129 |
+
kv_cache=kv_cache,
|
130 |
+
)
|
131 |
+
|
132 |
+
return outputs
|
133 |
+
|
134 |
+
def _merge_input_ids_with_image_features(
|
135 |
+
self,
|
136 |
+
image_features: torch.Tensor,
|
137 |
+
inputs_embeds: torch.Tensor,
|
138 |
+
input_ids: torch.Tensor,
|
139 |
+
attention_mask: torch.Tensor,
|
140 |
+
kv_cache: Optional[KVCache] = None,
|
141 |
+
):
|
142 |
+
_, _, embed_dim = image_features.shape
|
143 |
+
batch_size, sequence_length = input_ids.shape
|
144 |
+
dtype, device = inputs_embeds.dtype, inputs_embeds.device
|
145 |
+
# Shape: [Batch_Size, Seq_Len, Hidden_Size]
|
146 |
+
scaled_image_features = image_features / (self.config.hidden_size**0.5)
|
147 |
+
|
148 |
+
# Combine the embeddings of the image tokens, the text tokens and mask out all the padding tokens.
|
149 |
+
final_embedding = torch.zeros(
|
150 |
+
batch_size,
|
151 |
+
sequence_length,
|
152 |
+
embed_dim,
|
153 |
+
dtype=inputs_embeds.dtype,
|
154 |
+
device=inputs_embeds.device,
|
155 |
+
)
|
156 |
+
# Shape: [Batch_Size, Seq_Len]. True for text tokens
|
157 |
+
text_mask = (input_ids != self.config.image_token_index) & (
|
158 |
+
input_ids != self.pad_token_id
|
159 |
+
)
|
160 |
+
# Shape: [Batch_Size, Seq_Len]. True for image tokens
|
161 |
+
image_mask = input_ids == self.config.image_token_index
|
162 |
+
# Shape: [Batch_Size, Seq_Len]. True for padding tokens
|
163 |
+
pad_mask = input_ids == self.pad_token_id
|
164 |
+
|
165 |
+
# We need to expand the masks to the embedding dimension otherwise we can't use them in torch.where
|
166 |
+
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
|
167 |
+
pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
|
168 |
+
image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
|
169 |
+
|
170 |
+
# Add the text embeddings
|
171 |
+
final_embedding = torch.where(
|
172 |
+
text_mask_expanded, inputs_embeds, final_embedding
|
173 |
+
)
|
174 |
+
# Insert image embeddings. We can't use torch.where because the sequence length of scaled_image_features is not equal to the sequence length of the final embedding
|
175 |
+
final_embedding = final_embedding.masked_scatter(
|
176 |
+
image_mask_expanded, scaled_image_features
|
177 |
+
)
|
178 |
+
# Zero out padding tokens
|
179 |
+
final_embedding = torch.where(
|
180 |
+
pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding
|
181 |
+
)
|
182 |
+
|
183 |
+
#### CREATE THE ATTENTION MASK ####
|
184 |
+
|
185 |
+
dtype, device = inputs_embeds.dtype, inputs_embeds.device
|
186 |
+
min_dtype = torch.finfo(dtype).min
|
187 |
+
q_len = inputs_embeds.shape[1]
|
188 |
+
|
189 |
+
if kv_cache is None or kv_cache.num_items() == 0:
|
190 |
+
# Do not mask any token, because we're in the prefill phase
|
191 |
+
# This only works when we have no padding
|
192 |
+
causal_mask = torch.full(
|
193 |
+
(batch_size, q_len, q_len), fill_value=0, dtype=dtype, device=device
|
194 |
+
)
|
195 |
+
else:
|
196 |
+
# Since we are generating tokens, the query must be one single token
|
197 |
+
assert q_len == 1
|
198 |
+
kv_len = kv_cache.num_items() + q_len
|
199 |
+
# Also in this case we don't need to mask anything, since each query should be able to attend all previous tokens.
|
200 |
+
# This only works when we have no padding
|
201 |
+
causal_mask = torch.full(
|
202 |
+
(batch_size, q_len, kv_len), fill_value=0, dtype=dtype, device=device
|
203 |
+
)
|
204 |
+
|
205 |
+
# Add the head dimension
|
206 |
+
# [Batch_Size, Q_Len, KV_Len] -> [Batch_Size, Num_Heads_Q, Q_Len, KV_Len]
|
207 |
+
causal_mask = causal_mask.unsqueeze(1)
|
208 |
+
|
209 |
+
if kv_cache is not None and kv_cache.num_items() > 0:
|
210 |
+
# The position of the query is just the last position
|
211 |
+
position_ids = attention_mask.cumsum(-1)[:, -1]
|
212 |
+
if position_ids.dim() == 1:
|
213 |
+
position_ids = position_ids.unsqueeze(0)
|
214 |
+
else:
|
215 |
+
# Create a position_ids based on the size of the attention_mask
|
216 |
+
# For masked tokens, use the number 1 as position.
|
217 |
+
position_ids = (
|
218 |
+
(attention_mask.cumsum(-1))
|
219 |
+
.masked_fill_((attention_mask == 0), 1)
|
220 |
+
.to(device)
|
221 |
+
)
|
222 |
+
|
223 |
+
return final_embedding, causal_mask, position_ids
|
224 |
+
|
225 |
+
def _generate_caption(self, image, max_tokens=100, do_sample=False):
|
226 |
+
prompt = "caption en"
|
227 |
+
image = (
|
228 |
+
image.convert("RGB")
|
229 |
+
if is_valid_image(image)
|
230 |
+
else Image.open(image).convert("RGB")
|
231 |
+
)
|
232 |
+
|
233 |
+
inputs = get_model_inputs(
|
234 |
+
processor=self.processor, prompt=prompt, image=image, device=self.device
|
235 |
+
)
|
236 |
+
|
237 |
+
image.close()
|
238 |
+
|
239 |
+
input_ids = inputs["input_ids"]
|
240 |
+
attention_mask = inputs["attention_mask"]
|
241 |
+
pixel_values = inputs["pixel_values"]
|
242 |
+
|
243 |
+
kv_cache = KVCache()
|
244 |
+
|
245 |
+
stop_token = self.processor.tokenizer.eos_token_id
|
246 |
+
generated_tokens = []
|
247 |
+
|
248 |
+
for _ in range(max_tokens):
|
249 |
+
outputs = self(
|
250 |
+
input_ids=input_ids,
|
251 |
+
pixel_values=pixel_values,
|
252 |
+
attention_mask=attention_mask,
|
253 |
+
kv_cache=kv_cache,
|
254 |
+
)
|
255 |
+
kv_cache = outputs["kv_cache"]
|
256 |
+
next_token_logits = outputs["logits"][:, -1, :]
|
257 |
+
if do_sample:
|
258 |
+
next_token_logits = torch.softmax(
|
259 |
+
next_token_logits / self.config.temperature, dim=-1
|
260 |
+
)
|
261 |
+
next_token = sample_top_p(next_token_logits, self.config.top_p)
|
262 |
+
else:
|
263 |
+
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
|
264 |
+
assert next_token.size() == (1, 1)
|
265 |
+
next_token = next_token.squeeze(0)
|
266 |
+
generated_tokens.append(next_token)
|
267 |
+
if next_token.item() == stop_token:
|
268 |
+
break
|
269 |
+
input_ids = next_token.unsqueeze(-1)
|
270 |
+
attention_mask = torch.cat(
|
271 |
+
[attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=-1
|
272 |
+
)
|
273 |
+
|
274 |
+
generated_tokens = torch.cat(generated_tokens, dim=-1)
|
275 |
+
decoded_text = self.processor.tokenizer.decode(
|
276 |
+
generated_tokens, skip_special_tokens=True
|
277 |
+
)
|
278 |
+
decoded_text = (
|
279 |
+
parts[1] if len(parts := decoded_text.split("\n", 1)) > 1 else decoded_text
|
280 |
+
)
|
281 |
+
|
282 |
+
return decoded_text.rstrip(" .").strip().capitalize() + "."
|
283 |
+
|
284 |
+
def _generate_speech(self, text: str, output_type="file"):
|
285 |
+
|
286 |
+
sentences = split_line_to_sentences(text)
|
287 |
+
|
288 |
+
voice_audio = (
|
289 |
+
f"media/voicecraft/voices/{self.config.voicecraft_config.voice_audio_path}"
|
290 |
+
)
|
291 |
+
voice_transcript = self.config.voicecraft_config.voice_audio_transcript
|
292 |
+
cut_off_sec = self.config.voicecraft_config.cut_off_sec
|
293 |
+
|
294 |
+
decode_config = {
|
295 |
+
"top_k": self.config.voicecraft_config.top_k,
|
296 |
+
"top_p": self.config.voicecraft_config.top_p,
|
297 |
+
"temperature": self.config.voicecraft_config.temperature,
|
298 |
+
"stop_repetition": self.config.voicecraft_config.stop_repetition,
|
299 |
+
"kvcache": self.config.voicecraft_config.kvcache,
|
300 |
+
"codec_audio_sr": self.config.voicecraft_config.codec_audio_sr,
|
301 |
+
"codec_sr": self.config.voicecraft_config.codec_sr,
|
302 |
+
"silence_tokens": self.config.voicecraft_config.silence_tokens,
|
303 |
+
"sample_batch_size": self.config.voicecraft_config.sample_batch_size,
|
304 |
+
}
|
305 |
+
|
306 |
+
info = torchaudio.info(voice_audio)
|
307 |
+
audio_dur = info.num_frames / info.sample_rate
|
308 |
+
prompt_end_frame = int(min(audio_dur, cut_off_sec) * info.sample_rate)
|
309 |
+
|
310 |
+
audio_tensors = []
|
311 |
+
transcript = voice_transcript
|
312 |
+
|
313 |
+
for sentence in sentences:
|
314 |
+
|
315 |
+
transcript += sentence + "\n"
|
316 |
+
transcript = replace_numbers_with_words(transcript).replace("  ", " ")  # collapse double spaces
|
317 |
+
|
318 |
+
# phonemize
|
319 |
+
phn2num = self.voicecraft_model.args.phn2num
|
320 |
+
text_tokens = [
|
321 |
+
phn2num[phn]
|
322 |
+
for phn in tokenize_text(self.text_tokenizer, text=transcript.strip())
|
323 |
+
if phn in phn2num
|
324 |
+
]
|
325 |
+
text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
|
326 |
+
text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
|
327 |
+
|
328 |
+
# encode audio
|
329 |
+
encoded_frames = tokenize_audio(
|
330 |
+
self.audio_tokenizer,
|
331 |
+
voice_audio,
|
332 |
+
offset=0,
|
333 |
+
num_frames=prompt_end_frame,
|
334 |
+
)
|
335 |
+
original_audio = encoded_frames[0][0].transpose(2, 1) # [1,T,K]
|
336 |
+
model_args = vars(self.voicecraft_model.args)
|
337 |
+
model_args = Namespace(**model_args)
|
338 |
+
|
339 |
+
assert (
|
340 |
+
original_audio.ndim == 3
|
341 |
+
and original_audio.shape[0] == 1
|
342 |
+
and original_audio.shape[2] == model_args.n_codebooks
|
343 |
+
), original_audio.shape
|
344 |
+
|
345 |
+
# forward
|
346 |
+
stime = time.time()
|
347 |
+
if decode_config["sample_batch_size"] <= 1:
|
348 |
+
_, gen_frames = self.voicecraft_model.inference_tts(
|
349 |
+
text_tokens.to(self.device),
|
350 |
+
text_tokens_lens.to(self.device),
|
351 |
+
original_audio[..., : model_args.n_codebooks].to(
|
352 |
+
self.device
|
353 |
+
), # [1,T,8]
|
354 |
+
top_k=decode_config["top_k"],
|
355 |
+
top_p=decode_config["top_p"],
|
356 |
+
temperature=decode_config["temperature"],
|
357 |
+
stop_repetition=decode_config["stop_repetition"],
|
358 |
+
kvcache=decode_config["kvcache"],
|
359 |
+
silence_tokens=(
|
360 |
+
eval(decode_config["silence_tokens"])
|
361 |
+
if type(decode_config["silence_tokens"]) == str
|
362 |
+
else decode_config["silence_tokens"]
|
363 |
+
),
|
364 |
+
) # output is [1,K,T]
|
365 |
+
else:
|
366 |
+
_, gen_frames = self.voicecraft_model.inference_tts_batch(
|
367 |
+
text_tokens.to(self.device),
|
368 |
+
text_tokens_lens.to(self.device),
|
369 |
+
original_audio[..., : model_args.n_codebooks].to(
|
370 |
+
self.device
|
371 |
+
), # [1,T,8]
|
372 |
+
top_k=decode_config["top_k"],
|
373 |
+
top_p=decode_config["top_p"],
|
374 |
+
temperature=decode_config["temperature"],
|
375 |
+
stop_repetition=decode_config["stop_repetition"],
|
376 |
+
kvcache=decode_config["kvcache"],
|
377 |
+
batch_size=decode_config["sample_batch_size"],
|
378 |
+
silence_tokens=(
|
379 |
+
eval(decode_config["silence_tokens"])
|
380 |
+
if type(decode_config["silence_tokens"]) == str
|
381 |
+
else decode_config["silence_tokens"]
|
382 |
+
),
|
383 |
+
) # output is [1,K,T]
|
384 |
+
gen_sample = self.audio_tokenizer.decode([(gen_frames, None)])
|
385 |
+
gen_audio = gen_sample[0].cpu()
|
386 |
+
audio_tensors.append(gen_audio)
|
387 |
+
|
388 |
+
output = None
|
389 |
+
|
390 |
+
if output_type == "file":
|
391 |
+
output = save_to_file(audio_tensors, decode_config["codec_audio_sr"])
|
392 |
+
else:
|
393 |
+
output = save_to_buffer(audio_tensors, decode_config["codec_audio_sr"])
|
394 |
+
|
395 |
+
# Empty cuda cache between runs
|
396 |
+
if torch.cuda.is_available():
|
397 |
+
torch.cuda.empty_cache()
|
398 |
+
|
399 |
+
return output
|
400 |
+
|
401 |
+
@torch.inference_mode()
|
402 |
+
def generate(
|
403 |
+
self,
|
404 |
+
image,
|
405 |
+
max_tokens=30,
|
406 |
+
do_sample=False,
|
407 |
+
output_type="file",
|
408 |
+
return_output="speech",
|
409 |
+
):
|
410 |
+
if return_output == "speech" or return_output is None:
|
411 |
+
transcript = self._generate_caption(image, max_tokens, do_sample)
|
412 |
+
speech = self._generate_speech(transcript, output_type)
|
413 |
+
return transcript, speech
|
414 |
+
else:
|
415 |
+
transcript = self._generate_caption(image, max_tokens, do_sample)
|
416 |
+
return transcript
|
417 |
+
|
418 |
+
@classmethod
|
419 |
+
def from_pretrained(
|
420 |
+
cls,
|
421 |
+
model_path=None,
|
422 |
+
):
|
423 |
+
api = HfApi()
|
424 |
+
|
425 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
426 |
+
|
427 |
+
env_config = tools.load_config()
|
428 |
+
pretrained_dir = env_config["pretrained_dir"]
|
429 |
+
imagecraft_cache_dir = f"{pretrained_dir}/imagecraft"
|
430 |
+
voicecraft_cache_dir = f"{pretrained_dir}/voicecraft"
|
431 |
+
|
432 |
+
state_dict = {}
|
433 |
+
|
434 |
+
if Path(model_path).is_file():
|
435 |
+
checkpoint = torch.load(model_path, weights_only=False)
|
436 |
+
state_dict = checkpoint["state_dict"]
|
437 |
+
|
438 |
+
else:
|
439 |
+
|
440 |
+
model_path = api.snapshot_download(
|
441 |
+
repo_id=model_path,
|
442 |
+
repo_type="model",
|
443 |
+
cache_dir=imagecraft_cache_dir,
|
444 |
+
local_files_only=False,
|
445 |
+
)
|
446 |
+
|
447 |
+
safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors"))
|
448 |
+
|
449 |
+
for safetensors_file in safetensors_files:
|
450 |
+
with safe_open(safetensors_file, framework="pt", device="cpu") as f:
|
451 |
+
for key in f.keys():
|
452 |
+
state_dict[key] = f.get_tensor(key)
|
453 |
+
|
454 |
+
imagecraft_config = get_config()
|
455 |
+
|
456 |
+
model = cls(imagecraft_config).to(device)
|
457 |
+
|
458 |
+
# Load the state dict of the model
|
459 |
+
model.load_state_dict(state_dict, strict=False)
|
460 |
+
|
461 |
+
# Tie weights
|
462 |
+
model.tie_weights()
|
463 |
+
|
464 |
+
model = model.eval()
|
465 |
+
|
466 |
+
# Load voicecraft module
|
467 |
+
|
468 |
+
model.voicecraft_model = voicecraft.VoiceCraft.from_pretrained(
|
469 |
+
f"pyp1/VoiceCraft_{model.config.voicecraft_config.model_name.replace('.pth', '')}",
|
470 |
+
cache_dir=voicecraft_cache_dir,
|
471 |
+
)
|
472 |
+
|
473 |
+
encodec_fn = f"{voicecraft_cache_dir}/{model.config.voicecraft_config.encodec}"
|
474 |
+
|
475 |
+
if not os.path.exists(encodec_fn):
|
476 |
+
os.system(
|
477 |
+
f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{model.config.voicecraft_config.encodec}"
|
478 |
+
)
|
479 |
+
os.system(f"mv {model.config.voicecraft_config.encodec} {encodec_fn}")
|
480 |
+
|
481 |
+
model.audio_tokenizer = AudioTokenizer(
|
482 |
+
signature=encodec_fn,
|
483 |
+
device=device,
|
484 |
+
)
|
485 |
+
|
486 |
+
model.text_tokenizer = TextTokenizer(backend="espeak")
|
487 |
+
|
488 |
+
model.voicecraft_model.to(device)
|
489 |
+
|
490 |
+
return model
|
src/model/modules/imagecraftconfig.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from src.model.modules.gemma import GemmaConfig
|
2 |
+
# from src.model.modules.siglip import SiglipVisionConfig
|
3 |
+
from src.model.modules.voicecraftconfig import VoiceCraftConfig
|
4 |
+
|
5 |
+
from transformers import SiglipVisionConfig, GemmaConfig, PretrainedConfig
|
6 |
+
|
7 |
+
|
8 |
+
class ImageCraftConfig(PretrainedConfig):
|
9 |
+
|
10 |
+
model_type = "imagecraft"
|
11 |
+
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
vision_config=None,
|
15 |
+
text_config=None,
|
16 |
+
voicecraft_config=None,
|
17 |
+
ignore_index=-100,
|
18 |
+
image_token_index=256000,
|
19 |
+
vocab_size=257152,
|
20 |
+
projection_dim=2048,
|
21 |
+
hidden_size=2048,
|
22 |
+
pad_token_id=None,
|
23 |
+
**kwargs
|
24 |
+
):
|
25 |
+
super().__init__()
|
26 |
+
self.ignore_index = ignore_index
|
27 |
+
self.image_token_index = image_token_index
|
28 |
+
self.vocab_size = vocab_size
|
29 |
+
self.projection_dim = projection_dim
|
30 |
+
self.hidden_size = hidden_size
|
31 |
+
self.is_encoder_decoder = False
|
32 |
+
|
33 |
+
self.pad_token_id = pad_token_id if pad_token_id is not None else -1
|
34 |
+
|
35 |
+
self.vision_config = SiglipVisionConfig(**vision_config)
|
36 |
+
|
37 |
+
self.text_config = GemmaConfig(**text_config, pad_token_id=pad_token_id)
|
38 |
+
self.vocab_size = self.text_config.vocab_size
|
39 |
+
|
40 |
+
self.text_config.num_image_tokens = (
|
41 |
+
self.vision_config.image_size // self.vision_config.patch_size
|
42 |
+
) ** 2
|
43 |
+
self.vision_config.projection_dim = projection_dim
|
44 |
+
|
45 |
+
self.voicecraft_config = VoiceCraftConfig(**voicecraft_config)
|
46 |
+
|
47 |
+
super().__init__(**kwargs)
|
src/model/modules/imagecraftprocessor.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
from src.utils.util import (
|
8 |
+
IMAGENET_STANDARD_MEAN,
|
9 |
+
IMAGENET_STANDARD_STD,
|
10 |
+
add_image_tokens_to_prompt,
|
11 |
+
process_images,
|
12 |
+
)
|
13 |
+
|
14 |
+
from transformers import SiglipImageProcessor
|
15 |
+
|
16 |
+
|
17 |
+
class ImageCraftProcessor:
|
18 |
+
|
19 |
+
IMAGE_TOKEN = "<image>"
|
20 |
+
|
21 |
+
def __init__(self, tokenizer, num_image_tokens: int, image_size: int):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
self.image_seq_length = num_image_tokens
|
25 |
+
self.image_size = image_size
|
26 |
+
|
27 |
+
# Tokenizer described here: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
|
28 |
+
tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
|
29 |
+
tokenizer.add_special_tokens(tokens_to_add)
|
30 |
+
EXTRA_TOKENS = [
|
31 |
+
f"<loc{i:04d}>" for i in range(1024)
|
32 |
+
] # These tokens are used for object detection (bounding boxes)
|
33 |
+
EXTRA_TOKENS += [
|
34 |
+
f"<seg{i:03d}>" for i in range(128)
|
35 |
+
] # These tokens are used for object segmentation
|
36 |
+
tokenizer.add_tokens(EXTRA_TOKENS)
|
37 |
+
self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
|
38 |
+
# We will add the BOS and EOS tokens ourselves
|
39 |
+
tokenizer.add_bos_token = False
|
40 |
+
tokenizer.add_eos_token = False
|
41 |
+
|
42 |
+
self.tokenizer = tokenizer
|
43 |
+
# self.image_processor = SiglipImageProcessor.from_pretrained(
|
44 |
+
# "google/siglip-so400m-patch14-384"
|
45 |
+
# )
|
46 |
+
|
47 |
+
def __call__(
|
48 |
+
self,
|
49 |
+
text: List[str],
|
50 |
+
images: List[Image.Image],
|
51 |
+
padding: str = "longest",
|
52 |
+
truncation: bool = True,
|
53 |
+
) -> dict:
|
54 |
+
assert (
|
55 |
+
len(images) == 1 and len(text) == 1
|
56 |
+
), f"Received {len(images)} images for {len(text)} prompts."
|
57 |
+
|
58 |
+
# pixel_values = self.image_processor(images=images, return_tensors="pt")[
|
59 |
+
# "pixel_values"
|
60 |
+
# ]
|
61 |
+
pixel_values = process_images(
|
62 |
+
images,
|
63 |
+
size=(self.image_size, self.image_size),
|
64 |
+
resample=Image.Resampling.BICUBIC,
|
65 |
+
rescale_factor=1 / 255.0,
|
66 |
+
image_mean=IMAGENET_STANDARD_MEAN,
|
67 |
+
image_std=IMAGENET_STANDARD_STD,
|
68 |
+
)
|
69 |
+
# Convert the list of numpy arrays to a single numpy array with shape [Batch_Size, Channel, Height, Width]
|
70 |
+
pixel_values = np.stack(pixel_values, axis=0)
|
71 |
+
# Convert the numpy array to a PyTorch tensor
|
72 |
+
pixel_values = torch.tensor(pixel_values)
|
73 |
+
|
74 |
+
input_strings = [
|
75 |
+
add_image_tokens_to_prompt(
|
76 |
+
prefix_prompt=prompt,
|
77 |
+
bos_token=self.tokenizer.bos_token,
|
78 |
+
image_seq_length=self.image_seq_length,
|
79 |
+
image_token=self.IMAGE_TOKEN,
|
80 |
+
)
|
81 |
+
for prompt in text
|
82 |
+
]
|
83 |
+
|
84 |
+
# max_length += self.image_seq_length
|
85 |
+
|
86 |
+
inputs = self.tokenizer(
|
87 |
+
input_strings,
|
88 |
+
return_tensors="pt",
|
89 |
+
padding=padding,
|
90 |
+
max_length=512,
|
91 |
+
truncation=truncation,
|
92 |
+
)
|
93 |
+
|
94 |
+
return_data = {"pixel_values": pixel_values, **inputs}
|
95 |
+
|
96 |
+
return return_data
|
src/model/modules/kv_cache.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class KVCache:
|
5 |
+
|
6 |
+
def __init__(self) -> None:
|
7 |
+
self.key_cache: List[torch.Tensor] = []
|
8 |
+
self.value_cache: List[torch.Tensor] = []
|
9 |
+
|
10 |
+
def num_items(self) -> int:
|
11 |
+
if len(self.key_cache) == 0:
|
12 |
+
return 0
|
13 |
+
else:
|
14 |
+
# The shape of the key_cache is [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
|
15 |
+
return self.key_cache[0].shape[-2]
|
16 |
+
|
17 |
+
def update(
|
18 |
+
self,
|
19 |
+
key_states: torch.Tensor,
|
20 |
+
value_states: torch.Tensor,
|
21 |
+
layer_idx: int,
|
22 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
23 |
+
if len(self.key_cache) <= layer_idx:
|
24 |
+
# If we never added anything to the KV-Cache of this layer, let's create it.
|
25 |
+
self.key_cache.append(key_states)
|
26 |
+
self.value_cache.append(value_states)
|
27 |
+
else:
|
28 |
+
# ... otherwise we concatenate the new keys with the existing ones.
|
29 |
+
# each tensor has shape: [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
|
30 |
+
self.key_cache[layer_idx] = torch.cat(
|
31 |
+
[self.key_cache[layer_idx], key_states], dim=-2
|
32 |
+
)
|
33 |
+
self.value_cache[layer_idx] = torch.cat(
|
34 |
+
[self.value_cache[layer_idx], value_states], dim=-2
|
35 |
+
)
|
36 |
+
|
37 |
+
# ... and then we return all the existing keys + the new ones.
|
38 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
src/model/modules/sampling.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# cp from https://github.com/jasonppy/VoiceCraft/blob/master/models/modules/sampling.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
def top_k_top_p_filtering(
|
7 |
+
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
8 |
+
):
|
9 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
10 |
+
Args:
|
11 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
12 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
13 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
14 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
15 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
16 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
17 |
+
"""
|
18 |
+
if top_k > 0:
|
19 |
+
top_k = min(
|
20 |
+
max(top_k, min_tokens_to_keep), logits.size(-1)
|
21 |
+
) # Safety check
|
22 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
23 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
24 |
+
logits[indices_to_remove] = filter_value
|
25 |
+
|
26 |
+
if top_p < 1.0:
|
27 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
28 |
+
cumulative_probs = torch.cumsum(
|
29 |
+
F.softmax(sorted_logits, dim=-1), dim=-1
|
30 |
+
)
|
31 |
+
|
32 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
33 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
34 |
+
if min_tokens_to_keep > 1:
|
35 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
36 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
37 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
38 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
39 |
+
..., :-1
|
40 |
+
].clone()
|
41 |
+
sorted_indices_to_remove[..., 0] = 0
|
42 |
+
|
43 |
+
# scatter sorted tensors to original indexing
|
44 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
45 |
+
1, sorted_indices, sorted_indices_to_remove
|
46 |
+
)
|
47 |
+
logits[indices_to_remove] = filter_value
|
48 |
+
return logits
|
49 |
+
|
50 |
+
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
51 |
+
# temperature: (`optional`) float
|
52 |
+
# The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
53 |
+
# top_k: (`optional`) int
|
54 |
+
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
55 |
+
# top_p: (`optional`) float
|
56 |
+
# The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
57 |
+
|
58 |
+
# Temperature (higher temperature => more likely to sample low probability tokens)
|
59 |
+
if temperature != 1.0:
|
60 |
+
logits = logits / temperature
|
61 |
+
# Top-p/top-k filtering
|
62 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
63 |
+
# Sample
|
64 |
+
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
|
65 |
+
return token
|
src/model/modules/scaling.py
ADDED
@@ -0,0 +1,1391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/scaling.py
|
2 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
3 |
+
#
|
4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
|
19 |
+
import collections
|
20 |
+
import logging
|
21 |
+
import random
|
22 |
+
import math
|
23 |
+
from functools import reduce
|
24 |
+
from itertools import repeat
|
25 |
+
from typing import Optional, Tuple, Union
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from torch import Tensor
|
31 |
+
from torch.nn import Embedding as ScaledEmbedding
|
32 |
+
|
33 |
+
# from valle.utils import Transpose
|
34 |
+
|
35 |
+
class Transpose(nn.Identity):
|
36 |
+
"""(N, T, D) -> (N, D, T)"""
|
37 |
+
|
38 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
39 |
+
return input.transpose(1, 2)
|
40 |
+
|
41 |
+
class ActivationBalancerFunction(torch.autograd.Function):
|
42 |
+
@staticmethod
|
43 |
+
def forward(
|
44 |
+
ctx,
|
45 |
+
x: Tensor,
|
46 |
+
scale_factor: Tensor,
|
47 |
+
sign_factor: Optional[Tensor],
|
48 |
+
channel_dim: int,
|
49 |
+
) -> Tensor:
|
50 |
+
if channel_dim < 0:
|
51 |
+
channel_dim += x.ndim
|
52 |
+
ctx.channel_dim = channel_dim
|
53 |
+
xgt0 = x > 0
|
54 |
+
if sign_factor is None:
|
55 |
+
ctx.save_for_backward(xgt0, scale_factor)
|
56 |
+
else:
|
57 |
+
ctx.save_for_backward(xgt0, scale_factor, sign_factor)
|
58 |
+
return x
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
|
62 |
+
if len(ctx.saved_tensors) == 3:
|
63 |
+
xgt0, scale_factor, sign_factor = ctx.saved_tensors
|
64 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
65 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
66 |
+
sign_factor = sign_factor.unsqueeze(-1)
|
67 |
+
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
68 |
+
else:
|
69 |
+
xgt0, scale_factor = ctx.saved_tensors
|
70 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
71 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
72 |
+
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
73 |
+
neg_delta_grad = x_grad.abs() * factor
|
74 |
+
return (
|
75 |
+
x_grad - neg_delta_grad,
|
76 |
+
None,
|
77 |
+
None,
|
78 |
+
None,
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
def _compute_scale_factor(
|
83 |
+
x: Tensor,
|
84 |
+
channel_dim: int,
|
85 |
+
min_abs: float,
|
86 |
+
max_abs: float,
|
87 |
+
gain_factor: float,
|
88 |
+
max_factor: float,
|
89 |
+
) -> Tensor:
|
90 |
+
if channel_dim < 0:
|
91 |
+
channel_dim += x.ndim
|
92 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
93 |
+
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
|
94 |
+
|
95 |
+
if min_abs == 0.0:
|
96 |
+
below_threshold = 0.0
|
97 |
+
else:
|
98 |
+
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
|
99 |
+
# x_abs)_mean , min_abs.
|
100 |
+
below_threshold = (
|
101 |
+
(min_abs - x_abs_mean) * (gain_factor / min_abs)
|
102 |
+
).clamp(min=0, max=max_factor)
|
103 |
+
|
104 |
+
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
|
105 |
+
min=0, max=max_factor
|
106 |
+
)
|
107 |
+
|
108 |
+
return below_threshold - above_threshold
|
109 |
+
|
110 |
+
|
111 |
+
def _compute_sign_factor(
|
112 |
+
x: Tensor,
|
113 |
+
channel_dim: int,
|
114 |
+
min_positive: float,
|
115 |
+
max_positive: float,
|
116 |
+
gain_factor: float,
|
117 |
+
max_factor: float,
|
118 |
+
) -> Tensor:
|
119 |
+
if channel_dim < 0:
|
120 |
+
channel_dim += x.ndim
|
121 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
122 |
+
proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
|
123 |
+
if min_positive == 0.0:
|
124 |
+
factor1 = 0.0
|
125 |
+
else:
|
126 |
+
# 0 if proportion_positive >= min_positive, else can be
|
127 |
+
# as large as max_factor.
|
128 |
+
factor1 = (
|
129 |
+
(min_positive - proportion_positive) * (gain_factor / min_positive)
|
130 |
+
).clamp_(min=0, max=max_factor)
|
131 |
+
|
132 |
+
if max_positive == 1.0:
|
133 |
+
factor2 = 0.0
|
134 |
+
else:
|
135 |
+
# 0 if self.proportion_positive <= max_positive, else can be
|
136 |
+
# as large as -max_factor.
|
137 |
+
factor2 = (
|
138 |
+
(proportion_positive - max_positive)
|
139 |
+
* (gain_factor / (1.0 - max_positive))
|
140 |
+
).clamp_(min=0, max=max_factor)
|
141 |
+
sign_factor = factor1 - factor2
|
142 |
+
# require min_positive != 0 or max_positive != 1:
|
143 |
+
assert not isinstance(sign_factor, float)
|
144 |
+
return sign_factor
|
145 |
+
|
146 |
+
|
147 |
+
class ActivationScaleBalancerFunction(torch.autograd.Function):
|
148 |
+
"""
|
149 |
+
This object is used in class ActivationBalancer when the user specified
|
150 |
+
min_positive=0, max_positive=1, so there are no constraints on the signs
|
151 |
+
of the activations and only the absolute value has a constraint.
|
152 |
+
"""
|
153 |
+
|
154 |
+
@staticmethod
|
155 |
+
def forward(
|
156 |
+
ctx,
|
157 |
+
x: Tensor,
|
158 |
+
sign_factor: Tensor,
|
159 |
+
scale_factor: Tensor,
|
160 |
+
channel_dim: int,
|
161 |
+
) -> Tensor:
|
162 |
+
if channel_dim < 0:
|
163 |
+
channel_dim += x.ndim
|
164 |
+
ctx.channel_dim = channel_dim
|
165 |
+
xgt0 = x > 0
|
166 |
+
ctx.save_for_backward(xgt0, sign_factor, scale_factor)
|
167 |
+
return x
|
168 |
+
|
169 |
+
@staticmethod
|
170 |
+
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
|
171 |
+
xgt0, sign_factor, scale_factor = ctx.saved_tensors
|
172 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
173 |
+
sign_factor = sign_factor.unsqueeze(-1)
|
174 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
175 |
+
|
176 |
+
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
177 |
+
neg_delta_grad = x_grad.abs() * factor
|
178 |
+
return (
|
179 |
+
x_grad - neg_delta_grad,
|
180 |
+
None,
|
181 |
+
None,
|
182 |
+
None,
|
183 |
+
)
|
184 |
+
|
185 |
+
|
186 |
+
class RandomClampFunction(torch.autograd.Function):
|
187 |
+
@staticmethod
|
188 |
+
def forward(
|
189 |
+
ctx,
|
190 |
+
x: Tensor,
|
191 |
+
min: Optional[float],
|
192 |
+
max: Optional[float],
|
193 |
+
prob: float,
|
194 |
+
reflect: float,
|
195 |
+
) -> Tensor:
|
196 |
+
x_clamped = torch.clamp(x, min=min, max=max)
|
197 |
+
mask = torch.rand_like(x) < prob
|
198 |
+
ans = torch.where(mask, x_clamped, x)
|
199 |
+
if x.requires_grad:
|
200 |
+
ctx.save_for_backward(ans == x)
|
201 |
+
ctx.reflect = reflect
|
202 |
+
if reflect != 0.0:
|
203 |
+
ans = ans * (1.0 + reflect) - (x * reflect)
|
204 |
+
return ans
|
205 |
+
|
206 |
+
@staticmethod
|
207 |
+
def backward(
|
208 |
+
ctx, ans_grad: Tensor
|
209 |
+
) -> Tuple[Tensor, None, None, None, None]:
|
210 |
+
(is_same,) = ctx.saved_tensors
|
211 |
+
x_grad = ans_grad * is_same.to(ans_grad.dtype)
|
212 |
+
reflect = ctx.reflect
|
213 |
+
if reflect != 0.0:
|
214 |
+
x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
|
215 |
+
return x_grad, None, None, None, None
|
216 |
+
|
217 |
+
|
218 |
+
def random_clamp(
|
219 |
+
x: Tensor,
|
220 |
+
min: Optional[float] = None,
|
221 |
+
max: Optional[float] = None,
|
222 |
+
prob: float = 0.5,
|
223 |
+
reflect: float = 0.0,
|
224 |
+
):
|
225 |
+
return RandomClampFunction.apply(x, min, max, prob, reflect)
|
226 |
+
|
227 |
+
|
228 |
+
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
|
229 |
+
"""
|
230 |
+
A randomized way of casting a floating point value to half precision.
|
231 |
+
"""
|
232 |
+
if x.dtype == torch.float16:
|
233 |
+
return x
|
234 |
+
x_abs = x.abs()
|
235 |
+
is_too_small = x_abs < min_abs
|
236 |
+
# for elements where is_too_small is true, random_val will contain +-min_abs with
|
237 |
+
# probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
|
238 |
+
# for those elements].
|
239 |
+
random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
|
240 |
+
return torch.where(is_too_small, random_val, x).to(torch.float16)
|
241 |
+
|
242 |
+
|
243 |
+
class RandomGradFunction(torch.autograd.Function):
|
244 |
+
"""
|
245 |
+
Does nothing in forward pass; in backward pass, gets rid of very small grads using
|
246 |
+
randomized approach that preserves expectations (intended to reduce roundoff).
|
247 |
+
"""
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
|
251 |
+
ctx.min_abs = min_abs
|
252 |
+
return x
|
253 |
+
|
254 |
+
@staticmethod
|
255 |
+
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
|
256 |
+
if ans_grad.dtype == torch.float16:
|
257 |
+
return (
|
258 |
+
random_cast_to_half(
|
259 |
+
ans_grad.to(torch.float32), min_abs=ctx.min_abs
|
260 |
+
),
|
261 |
+
None,
|
262 |
+
)
|
263 |
+
else:
|
264 |
+
return ans_grad, None
|
265 |
+
|
266 |
+
|
267 |
+
class RandomGrad(torch.nn.Module):
|
268 |
+
"""
|
269 |
+
Gets rid of very small gradients using an expectation-preserving method, intended to increase
|
270 |
+
accuracy of training when using amp (automatic mixed precision)
|
271 |
+
"""
|
272 |
+
|
273 |
+
def __init__(self, min_abs: float = 5.0e-06):
|
274 |
+
super(RandomGrad, self).__init__()
|
275 |
+
self.min_abs = min_abs
|
276 |
+
|
277 |
+
def forward(self, x: Tensor):
|
278 |
+
if (
|
279 |
+
torch.jit.is_scripting()
|
280 |
+
or not self.training
|
281 |
+
or torch.jit.is_tracing()
|
282 |
+
):
|
283 |
+
return x
|
284 |
+
else:
|
285 |
+
return RandomGradFunction.apply(x, self.min_abs)
|
286 |
+
|
287 |
+
|
288 |
+
class SoftmaxFunction(torch.autograd.Function):
|
289 |
+
"""
|
290 |
+
Tries to handle half-precision derivatives in a randomized way that should
|
291 |
+
be more accurate for training than the default behavior.
|
292 |
+
"""
|
293 |
+
|
294 |
+
@staticmethod
|
295 |
+
def forward(ctx, x: Tensor, dim: int):
|
296 |
+
ans = x.softmax(dim=dim)
|
297 |
+
# if x dtype is float16, x.softmax() returns a float32 because
|
298 |
+
# (presumably) that op does not support float16, and autocast
|
299 |
+
# is enabled.
|
300 |
+
if torch.is_autocast_enabled():
|
301 |
+
ans = ans.to(torch.float16)
|
302 |
+
ctx.save_for_backward(ans)
|
303 |
+
ctx.x_dtype = x.dtype
|
304 |
+
ctx.dim = dim
|
305 |
+
return ans
|
306 |
+
|
307 |
+
@staticmethod
|
308 |
+
def backward(ctx, ans_grad: Tensor):
|
309 |
+
(ans,) = ctx.saved_tensors
|
310 |
+
with torch.cuda.amp.autocast(enabled=False):
|
311 |
+
ans_grad = ans_grad.to(torch.float32)
|
312 |
+
ans = ans.to(torch.float32)
|
313 |
+
x_grad = ans_grad * ans
|
314 |
+
x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
|
315 |
+
return x_grad, None
|
316 |
+
|
317 |
+
|
318 |
+
def softmax(x: Tensor, dim: int):
|
319 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
320 |
+
return x.softmax(dim)
|
321 |
+
|
322 |
+
return SoftmaxFunction.apply(x, dim)
|
323 |
+
|
324 |
+
|
325 |
+
class MaxEigLimiterFunction(torch.autograd.Function):
|
326 |
+
@staticmethod
|
327 |
+
def forward(
|
328 |
+
ctx,
|
329 |
+
x: Tensor,
|
330 |
+
coeffs: Tensor,
|
331 |
+
direction: Tensor,
|
332 |
+
channel_dim: int,
|
333 |
+
grad_scale: float,
|
334 |
+
) -> Tensor:
|
335 |
+
ctx.channel_dim = channel_dim
|
336 |
+
ctx.grad_scale = grad_scale
|
337 |
+
ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
|
338 |
+
return x
|
339 |
+
|
340 |
+
@staticmethod
|
341 |
+
def backward(ctx, x_grad, *args):
|
342 |
+
with torch.enable_grad():
|
343 |
+
(x_orig, coeffs, new_direction) = ctx.saved_tensors
|
344 |
+
x_orig.requires_grad = True
|
345 |
+
num_channels = x_orig.shape[ctx.channel_dim]
|
346 |
+
x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
|
347 |
+
new_direction.requires_grad = False
|
348 |
+
x = x - x.mean(dim=0)
|
349 |
+
x_var = (x ** 2).mean()
|
350 |
+
x_residual = x - coeffs * new_direction
|
351 |
+
x_residual_var = (x_residual ** 2).mean()
|
352 |
+
# `variance_proportion` is the proportion of the variance accounted for
|
353 |
+
# by the top eigen-direction. This is to be minimized.
|
354 |
+
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
|
355 |
+
variance_proportion.backward()
|
356 |
+
x_orig_grad = x_orig.grad
|
357 |
+
x_extra_grad = (
|
358 |
+
x_orig.grad
|
359 |
+
* ctx.grad_scale
|
360 |
+
* x_grad.norm()
|
361 |
+
/ (x_orig_grad.norm() + 1.0e-20)
|
362 |
+
)
|
363 |
+
return x_grad + x_extra_grad.detach(), None, None, None, None
|
364 |
+
|
365 |
+
|
366 |
+
class BasicNorm(torch.nn.Module):
|
367 |
+
"""
|
368 |
+
This is intended to be a simpler, and hopefully cheaper, replacement for
|
369 |
+
LayerNorm. The observation this is based on, is that Transformer-type
|
370 |
+
networks, especially with pre-norm, sometimes seem to set one of the
|
371 |
+
feature dimensions to a large constant value (e.g. 50), which "defeats"
|
372 |
+
the LayerNorm because the output magnitude is then not strongly dependent
|
373 |
+
on the other (useful) features. Presumably the weight and bias of the
|
374 |
+
LayerNorm are required to allow it to do this.
|
375 |
+
So the idea is to introduce this large constant value as an explicit
|
376 |
+
parameter, that takes the role of the "eps" in LayerNorm, so the network
|
377 |
+
doesn't have to do this trick. We make the "eps" learnable.
|
378 |
+
Args:
|
379 |
+
num_channels: the number of channels, e.g. 512.
|
380 |
+
channel_dim: the axis/dimension corresponding to the channel,
|
381 |
+
interprted as an offset from the input's ndim if negative.
|
382 |
+
shis is NOT the num_channels; it should typically be one of
|
383 |
+
{-2, -1, 0, 1, 2, 3}.
|
384 |
+
eps: the initial "epsilon" that we add as ballast in:
|
385 |
+
scale = ((input_vec**2).mean() + epsilon)**-0.5
|
386 |
+
Note: our epsilon is actually large, but we keep the name
|
387 |
+
to indicate the connection with conventional LayerNorm.
|
388 |
+
learn_eps: if true, we learn epsilon; if false, we keep it
|
389 |
+
at the initial value.
|
390 |
+
eps_min: float
|
391 |
+
eps_max: float
|
392 |
+
"""
|
393 |
+
|
394 |
+
def __init__(
|
395 |
+
self,
|
396 |
+
num_channels: int,
|
397 |
+
channel_dim: int = -1, # CAUTION: see documentation.
|
398 |
+
eps: float = 0.25,
|
399 |
+
learn_eps: bool = True,
|
400 |
+
eps_min: float = -3.0,
|
401 |
+
eps_max: float = 3.0,
|
402 |
+
) -> None:
|
403 |
+
super(BasicNorm, self).__init__()
|
404 |
+
self.num_channels = num_channels
|
405 |
+
self.channel_dim = channel_dim
|
406 |
+
if learn_eps:
|
407 |
+
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
|
408 |
+
else:
|
409 |
+
self.register_buffer("eps", torch.tensor(eps).log().detach())
|
410 |
+
self.eps_min = eps_min
|
411 |
+
self.eps_max = eps_max
|
412 |
+
|
413 |
+
def forward(self, x: Tensor) -> Tensor:
|
414 |
+
assert x.shape[self.channel_dim] == self.num_channels
|
415 |
+
eps = self.eps
|
416 |
+
if self.training and random.random() < 0.25:
|
417 |
+
# with probability 0.25, in training mode, clamp eps between the min
|
418 |
+
# and max; this will encourage it to learn parameters within the
|
419 |
+
# allowed range by making parameters that are outside the allowed
|
420 |
+
# range noisy.
|
421 |
+
|
422 |
+
# gradients to allow the parameter to get back into the allowed region if it happens to exit it.
|
423 |
+
eps = eps.clamp(min=self.eps_min, max=self.eps_max)
|
424 |
+
scales = (
|
425 |
+
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
|
426 |
+
) ** -0.5
|
427 |
+
return x * scales
|
428 |
+
|
429 |
+
|
430 |
+
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
|
431 |
+
"""
|
432 |
+
Behaves like a constructor of a modified version of nn.Linear
|
433 |
+
that gives an easy way to set the default initial parameter scale.
|
434 |
+
Args:
|
435 |
+
Accepts the standard args and kwargs that nn.Linear accepts
|
436 |
+
e.g. in_features, out_features, bias=False.
|
437 |
+
initial_scale: you can override this if you want to increase
|
438 |
+
or decrease the initial magnitude of the module's output
|
439 |
+
(affects the initialization of weight_scale and bias_scale).
|
440 |
+
Another option, if you want to do something like this, is
|
441 |
+
to re-initialize the parameters.
|
442 |
+
"""
|
443 |
+
ans = nn.Linear(*args, **kwargs)
|
444 |
+
with torch.no_grad():
|
445 |
+
ans.weight[:] *= initial_scale
|
446 |
+
if ans.bias is not None:
|
447 |
+
torch.nn.init.uniform_(
|
448 |
+
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
|
449 |
+
)
|
450 |
+
return ans
|
451 |
+
|
452 |
+
|
453 |
+
def ScaledConv1d(
|
454 |
+
*args,
|
455 |
+
initial_scale: float = 1.0,
|
456 |
+
kernel_size: int = 3,
|
457 |
+
padding: str = "same",
|
458 |
+
**kwargs,
|
459 |
+
) -> nn.Conv1d:
|
460 |
+
"""
|
461 |
+
Behaves like a constructor of a modified version of nn.Conv1d
|
462 |
+
that gives an easy way to set the default initial parameter scale.
|
463 |
+
Args:
|
464 |
+
Accepts the standard args and kwargs that nn.Linear accepts
|
465 |
+
e.g. in_features, out_features, bias=False.
|
466 |
+
initial_scale: you can override this if you want to increase
|
467 |
+
or decrease the initial magnitude of the module's output
|
468 |
+
(affects the initialization of weight_scale and bias_scale).
|
469 |
+
Another option, if you want to do something like this, is
|
470 |
+
to re-initialize the parameters.
|
471 |
+
"""
|
472 |
+
ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
|
473 |
+
with torch.no_grad():
|
474 |
+
ans.weight[:] *= initial_scale
|
475 |
+
if ans.bias is not None:
|
476 |
+
torch.nn.init.uniform_(
|
477 |
+
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
|
478 |
+
)
|
479 |
+
return ans
|
480 |
+
|
481 |
+
|
482 |
+
def TransposeScaledConv1d(
|
483 |
+
*args,
|
484 |
+
initial_scale: float = 1.0,
|
485 |
+
kernel_size: int = 3,
|
486 |
+
padding: str = "same",
|
487 |
+
**kwargs,
|
488 |
+
) -> nn.Sequential:
|
489 |
+
"""
|
490 |
+
Transpose -> ScaledConv1d
|
491 |
+
"""
|
492 |
+
return nn.Sequential(
|
493 |
+
Transpose(),
|
494 |
+
ScaledConv1d(
|
495 |
+
*args,
|
496 |
+
initial_scale=initial_scale,
|
497 |
+
kernel_size=kernel_size,
|
498 |
+
padding=padding,
|
499 |
+
**kwargs,
|
500 |
+
),
|
501 |
+
)
|
502 |
+
|
503 |
+
|
504 |
+
def ScaledConv1dTranspose(
|
505 |
+
*args,
|
506 |
+
initial_scale: float = 1.0,
|
507 |
+
kernel_size: int = 3,
|
508 |
+
padding: str = "same",
|
509 |
+
**kwargs,
|
510 |
+
) -> nn.Sequential:
|
511 |
+
"""
|
512 |
+
Transpose -> ScaledConv1d
|
513 |
+
"""
|
514 |
+
return nn.Sequential(
|
515 |
+
ScaledConv1d(
|
516 |
+
*args,
|
517 |
+
initial_scale=initial_scale,
|
518 |
+
kernel_size=kernel_size,
|
519 |
+
padding=padding,
|
520 |
+
**kwargs,
|
521 |
+
),
|
522 |
+
Transpose(),
|
523 |
+
)
|
524 |
+
|
525 |
+
|
526 |
+
def TransposeConv1d(
|
527 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
528 |
+
) -> nn.Sequential:
|
529 |
+
"""
|
530 |
+
Transpose -> Conv1d
|
531 |
+
"""
|
532 |
+
return nn.Sequential(
|
533 |
+
Transpose(),
|
534 |
+
nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
535 |
+
)
|
536 |
+
|
537 |
+
|
538 |
+
def Conv1dTranspose(
|
539 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
540 |
+
) -> nn.Sequential:
|
541 |
+
"""
|
542 |
+
ScaledConv1d -> Transpose
|
543 |
+
"""
|
544 |
+
return nn.Sequential(
|
545 |
+
nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
546 |
+
Transpose(),
|
547 |
+
)
|
548 |
+
|
549 |
+
|
550 |
+
class SRLinear(nn.Linear):
|
551 |
+
"""https://arxiv.org/abs/2303.06296
|
552 |
+
Stabilizing Transformer Training by Preventing Attention Entropy Collapse
|
553 |
+
"""
|
554 |
+
|
555 |
+
def __init__(self, in_features, out_features, bias=True, **kwargs):
|
556 |
+
super().__init__(in_features, out_features, bias=bias, **kwargs)
|
557 |
+
self.register_buffer(
|
558 |
+
"u", nn.functional.normalize(torch.randn(in_features), dim=0)
|
559 |
+
)
|
560 |
+
with torch.no_grad():
|
561 |
+
sigma = self.get_sigma()
|
562 |
+
self.register_buffer("spectral_norm", sigma)
|
563 |
+
self.sigma = nn.Parameter(torch.ones(1))
|
564 |
+
|
565 |
+
def get_sigma(self):
|
566 |
+
with torch.no_grad():
|
567 |
+
u = self.u
|
568 |
+
v = self.weight.mv(u)
|
569 |
+
v = nn.functional.normalize(v, dim=0)
|
570 |
+
u = self.weight.T.mv(v)
|
571 |
+
u = nn.functional.normalize(u, dim=0)
|
572 |
+
self.u.data.copy_(u)
|
573 |
+
return torch.einsum("c,cd,d->", v, self.weight, u)
|
574 |
+
|
575 |
+
def get_weight(self):
|
576 |
+
sigma = self.get_sigma()
|
577 |
+
if self.training:
|
578 |
+
self.spectral_norm.data.copy_(sigma)
|
579 |
+
weight = (self.sigma / sigma) * self.weight
|
580 |
+
return weight
|
581 |
+
|
582 |
+
def forward(self, x):
|
583 |
+
return nn.functional.linear(x, self.get_weight(), self.bias)
|
584 |
+
|
585 |
+
|
586 |
+
class SRConv1d(SRLinear):
|
587 |
+
def __init__(
|
588 |
+
self,
|
589 |
+
in_features,
|
590 |
+
out_features,
|
591 |
+
kernel_size,
|
592 |
+
stride: int = 1,
|
593 |
+
padding: str = "same",
|
594 |
+
bias: bool = True,
|
595 |
+
**kwargs,
|
596 |
+
):
|
597 |
+
in_features = in_features * kernel_size
|
598 |
+
super().__init__(in_features, out_features, bias=bias, **kwargs)
|
599 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
600 |
+
self.kernel_size = kernel_size
|
601 |
+
self.stride = stride
|
602 |
+
self.padding = padding
|
603 |
+
|
604 |
+
def forward(self, x):
|
605 |
+
in_features = self.in_features // self.kernel_size
|
606 |
+
weight = self.get_weight().view(
|
607 |
+
self.out_features, in_features, self.kernel_size
|
608 |
+
)
|
609 |
+
return nn.functional.conv1d(
|
610 |
+
x, weight, bias=self.bias, stride=self.stride, padding=self.padding
|
611 |
+
)
|
612 |
+
|
613 |
+
|
614 |
+
def TransposeSRConv1d(
|
615 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
616 |
+
) -> nn.Sequential:
|
617 |
+
"""
|
618 |
+
Transpose -> SRConv1d
|
619 |
+
"""
|
620 |
+
return nn.Sequential(
|
621 |
+
Transpose(),
|
622 |
+
SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
623 |
+
)
|
624 |
+
|
625 |
+
|
626 |
+
def SRConv1dTranspose(
|
627 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
628 |
+
) -> nn.Sequential:
|
629 |
+
"""
|
630 |
+
SRConv1d -> Transpose
|
631 |
+
"""
|
632 |
+
return nn.Sequential(
|
633 |
+
SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
634 |
+
Transpose(),
|
635 |
+
)
|
636 |
+
|
637 |
+
|
638 |
+
class ActivationBalancer(torch.nn.Module):
|
639 |
+
"""
|
640 |
+
Modifies the backpropped derivatives of a function to try to encourage, for
|
641 |
+
each channel, that it is positive at least a proportion `threshold` of the
|
642 |
+
time. It does this by multiplying negative derivative values by up to
|
643 |
+
(1+max_factor), and positive derivative values by up to (1-max_factor),
|
644 |
+
interpolated from 1 at the threshold to those extremal values when none
|
645 |
+
of the inputs are positive.
|
646 |
+
Args:
|
647 |
+
num_channels: the number of channels
|
648 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
649 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
650 |
+
min_positive: the minimum, per channel, of the proportion of the time
|
651 |
+
that (x > 0), below which we start to modify the derivatives.
|
652 |
+
max_positive: the maximum, per channel, of the proportion of the time
|
653 |
+
that (x > 0), above which we start to modify the derivatives.
|
654 |
+
max_factor: the maximum factor by which we modify the derivatives for
|
655 |
+
either the sign constraint or the magnitude constraint;
|
656 |
+
e.g. with max_factor=0.02, the the derivatives would be multiplied by
|
657 |
+
values in the range [0.98..1.02].
|
658 |
+
sign_gain_factor: determines the 'gain' with which we increase the
|
659 |
+
change in gradient once the constraints on min_positive and max_positive
|
660 |
+
are violated.
|
661 |
+
scale_gain_factor: determines the 'gain' with which we increase the
|
662 |
+
change in gradient once the constraints on min_abs and max_abs
|
663 |
+
are violated.
|
664 |
+
min_abs: the minimum average-absolute-value difference from the mean
|
665 |
+
value per channel, which we allow, before we start to modify
|
666 |
+
the derivatives to prevent this.
|
667 |
+
max_abs: the maximum average-absolute-value difference from the mean
|
668 |
+
value per channel, which we allow, before we start to modify
|
669 |
+
the derivatives to prevent this.
|
670 |
+
min_prob: determines the minimum probability with which we modify the
|
671 |
+
gradients for the {min,max}_positive and {min,max}_abs constraints,
|
672 |
+
on each forward(). This is done randomly to prevent all layers
|
673 |
+
from doing it at the same time. Early in training we may use
|
674 |
+
higher probabilities than this; it will decay to this value.
|
675 |
+
"""
|
676 |
+
|
677 |
+
def __init__(
|
678 |
+
self,
|
679 |
+
num_channels: int,
|
680 |
+
channel_dim: int,
|
681 |
+
min_positive: float = 0.05,
|
682 |
+
max_positive: float = 0.95,
|
683 |
+
max_factor: float = 0.04,
|
684 |
+
sign_gain_factor: float = 0.01,
|
685 |
+
scale_gain_factor: float = 0.02,
|
686 |
+
min_abs: float = 0.2,
|
687 |
+
max_abs: float = 100.0,
|
688 |
+
min_prob: float = 0.1,
|
689 |
+
):
|
690 |
+
super(ActivationBalancer, self).__init__()
|
691 |
+
self.num_channels = num_channels
|
692 |
+
self.channel_dim = channel_dim
|
693 |
+
self.min_positive = min_positive
|
694 |
+
self.max_positive = max_positive
|
695 |
+
self.max_factor = max_factor
|
696 |
+
self.min_abs = min_abs
|
697 |
+
self.max_abs = max_abs
|
698 |
+
self.min_prob = min_prob
|
699 |
+
self.sign_gain_factor = sign_gain_factor
|
700 |
+
self.scale_gain_factor = scale_gain_factor
|
701 |
+
|
702 |
+
# count measures how many times the forward() function has been called.
|
703 |
+
# We occasionally sync this to a tensor called `count`, that exists to
|
704 |
+
# make sure it is synced to disk when we load and save the model.
|
705 |
+
self.cpu_count = 0
|
706 |
+
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
|
707 |
+
|
708 |
+
def forward(self, x: Tensor) -> Tensor:
|
709 |
+
if (
|
710 |
+
torch.jit.is_scripting()
|
711 |
+
or not x.requires_grad
|
712 |
+
or torch.jit.is_tracing()
|
713 |
+
):
|
714 |
+
return _no_op(x)
|
715 |
+
|
716 |
+
count = self.cpu_count
|
717 |
+
self.cpu_count += 1
|
718 |
+
|
719 |
+
if random.random() < 0.01:
|
720 |
+
# Occasionally sync self.cpu_count with self.count.
|
721 |
+
# count affects the decay of 'prob'. don't do this on every iter,
|
722 |
+
# because syncing with the GPU is slow.
|
723 |
+
self.cpu_count = max(self.cpu_count, self.count.item())
|
724 |
+
self.count.fill_(self.cpu_count)
|
725 |
+
|
726 |
+
# the prob of doing some work exponentially decreases from 0.5 till it hits
|
727 |
+
# a floor at min_prob (==0.1, by default)
|
728 |
+
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
|
729 |
+
|
730 |
+
if random.random() < prob:
|
731 |
+
sign_gain_factor = 0.5
|
732 |
+
if self.min_positive != 0.0 or self.max_positive != 1.0:
|
733 |
+
sign_factor = _compute_sign_factor(
|
734 |
+
x,
|
735 |
+
self.channel_dim,
|
736 |
+
self.min_positive,
|
737 |
+
self.max_positive,
|
738 |
+
gain_factor=self.sign_gain_factor / prob,
|
739 |
+
max_factor=self.max_factor,
|
740 |
+
)
|
741 |
+
else:
|
742 |
+
sign_factor = None
|
743 |
+
|
744 |
+
scale_factor = _compute_scale_factor(
|
745 |
+
x.detach(),
|
746 |
+
self.channel_dim,
|
747 |
+
min_abs=self.min_abs,
|
748 |
+
max_abs=self.max_abs,
|
749 |
+
gain_factor=self.scale_gain_factor / prob,
|
750 |
+
max_factor=self.max_factor,
|
751 |
+
)
|
752 |
+
return ActivationBalancerFunction.apply(
|
753 |
+
x,
|
754 |
+
scale_factor,
|
755 |
+
sign_factor,
|
756 |
+
self.channel_dim,
|
757 |
+
)
|
758 |
+
else:
|
759 |
+
return _no_op(x)
|
760 |
+
|
761 |
+
|
762 |
+
def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
|
763 |
+
"""
|
764 |
+
Returns x unmodified, but in backprop will put a penalty for the excess of
|
765 |
+
the absolute values of elements of x over the limit "limit". E.g. if
|
766 |
+
limit == 10.0, then if x has any values over 10 it will get a penalty.
|
767 |
+
Caution: the value of this penalty will be affected by grad scaling used
|
768 |
+
in automatic mixed precision training. For this reasons we use this,
|
769 |
+
it shouldn't really matter, or may even be helpful; we just use this
|
770 |
+
to disallow really implausible values of scores to be given to softmax.
|
771 |
+
"""
|
772 |
+
x_sign = x.sign()
|
773 |
+
over_limit = (x.abs() - limit) > 0
|
774 |
+
# The following is a memory efficient way to penalize the absolute values of
|
775 |
+
# x that's over the limit. (The memory efficiency comes when you think
|
776 |
+
# about which items torch needs to cache for the autograd, and which ones it
|
777 |
+
# can throw away). The numerical value of aux_loss as computed here will
|
778 |
+
# actually be larger than it should be, by limit * over_limit.sum(), but it
|
779 |
+
# has the same derivative as the real aux_loss which is penalty * (x.abs() -
|
780 |
+
# limit).relu().
|
781 |
+
aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
|
782 |
+
# note: we don't do sum() here on aux)_loss, but it's as if we had done
|
783 |
+
# sum() due to how with_loss() works.
|
784 |
+
x = with_loss(x, aux_loss)
|
785 |
+
# you must use x for something, or this will be ineffective.
|
786 |
+
return x
|
787 |
+
|
788 |
+
|
789 |
+
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
|
790 |
+
if x.ndim == 2:
|
791 |
+
return x.diag()
|
792 |
+
else:
|
793 |
+
(batch, dim, dim) = x.shape
|
794 |
+
x = x.reshape(batch, dim * dim)
|
795 |
+
x = x[:, :: dim + 1]
|
796 |
+
assert x.shape == (batch, dim)
|
797 |
+
return x
|
798 |
+
|
799 |
+
|
800 |
+
def _whitening_metric(x: Tensor, num_groups: int):
|
801 |
+
"""
|
802 |
+
Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
|
803 |
+
of the centered feature covariance are the same within each group's covariance matrix
|
804 |
+
and also between groups.
|
805 |
+
Args:
|
806 |
+
x: a Tensor of shape (*, num_channels)
|
807 |
+
num_groups: the number of groups of channels, a number >=1 that divides num_channels
|
808 |
+
Returns:
|
809 |
+
Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
|
810 |
+
greater than 1.0 otherwise.
|
811 |
+
"""
|
812 |
+
assert x.dtype != torch.float16
|
813 |
+
x = x.reshape(-1, x.shape[-1])
|
814 |
+
(num_frames, num_channels) = x.shape
|
815 |
+
assert num_channels % num_groups == 0
|
816 |
+
channels_per_group = num_channels // num_groups
|
817 |
+
x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
|
818 |
+
# x now has shape (num_groups, num_frames, channels_per_group)
|
819 |
+
# subtract the mean so we use the centered, not uncentered, covariance.
|
820 |
+
# My experience has been that when we "mess with the gradients" like this,
|
821 |
+
# it's better not do anything that tries to move the mean around, because
|
822 |
+
# that can easily cause instability.
|
823 |
+
x = x - x.mean(dim=1, keepdim=True)
|
824 |
+
# x_covar: (num_groups, channels_per_group, channels_per_group)
|
825 |
+
x_covar = torch.matmul(x.transpose(1, 2), x)
|
826 |
+
x_covar_mean_diag = _diag(x_covar).mean()
|
827 |
+
# the following expression is what we'd get if we took the matrix product
|
828 |
+
# of each covariance and measured the mean of its trace, i.e.
|
829 |
+
# the same as _diag(torch.matmul(x_covar, x_covar)).mean().
|
830 |
+
x_covarsq_mean_diag = (x_covar ** 2).sum() / (
|
831 |
+
num_groups * channels_per_group
|
832 |
+
)
|
833 |
+
# this metric will be >= 1.0; the larger it is, the less 'white' the data was.
|
834 |
+
metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
|
835 |
+
return metric
|
836 |
+
|
837 |
+
|
838 |
+
class WhiteningPenaltyFunction(torch.autograd.Function):
|
839 |
+
@staticmethod
|
840 |
+
def forward(
|
841 |
+
ctx,
|
842 |
+
x: Tensor,
|
843 |
+
num_groups: int,
|
844 |
+
whitening_limit: float,
|
845 |
+
grad_scale: float,
|
846 |
+
) -> Tensor:
|
847 |
+
ctx.save_for_backward(x)
|
848 |
+
ctx.num_groups = num_groups
|
849 |
+
ctx.whitening_limit = whitening_limit
|
850 |
+
ctx.grad_scale = grad_scale
|
851 |
+
return x
|
852 |
+
|
853 |
+
@staticmethod
|
854 |
+
def backward(ctx, x_grad: Tensor):
|
855 |
+
(x_orig,) = ctx.saved_tensors
|
856 |
+
with torch.enable_grad():
|
857 |
+
with torch.cuda.amp.autocast(enabled=False):
|
858 |
+
x_detached = x_orig.to(torch.float32).detach()
|
859 |
+
x_detached.requires_grad = True
|
860 |
+
|
861 |
+
metric = _whitening_metric(x_detached, ctx.num_groups)
|
862 |
+
|
863 |
+
if random.random() < 0.005 or __name__ == "__main__":
|
864 |
+
logging.info(
|
865 |
+
f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
|
866 |
+
f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
|
867 |
+
)
|
868 |
+
|
869 |
+
(metric - ctx.whitening_limit).relu().backward()
|
870 |
+
penalty_grad = x_detached.grad
|
871 |
+
scale = ctx.grad_scale * (
|
872 |
+
x_grad.to(torch.float32).norm()
|
873 |
+
/ (penalty_grad.norm() + 1.0e-20)
|
874 |
+
)
|
875 |
+
penalty_grad = penalty_grad * scale
|
876 |
+
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
|
877 |
+
|
878 |
+
|
879 |
+
class Whiten(nn.Module):
|
880 |
+
def __init__(
|
881 |
+
self,
|
882 |
+
num_groups: int,
|
883 |
+
whitening_limit: float,
|
884 |
+
prob: Union[float, Tuple[float, float]],
|
885 |
+
grad_scale: float,
|
886 |
+
):
|
887 |
+
"""
|
888 |
+
Args:
|
889 |
+
num_groups: the number of groups to divide the channel dim into before
|
890 |
+
whitening. We will attempt to make the feature covariance
|
891 |
+
within each group, after mean subtraction, as "white" as possible,
|
892 |
+
while having the same trace across all groups.
|
893 |
+
whitening_limit: a value greater than 1.0, that dictates how much
|
894 |
+
freedom we have to violate the constraints. 1.0 would mean perfectly
|
895 |
+
white, with exactly the same trace across groups; larger values
|
896 |
+
give more freedom. E.g. 2.0.
|
897 |
+
prob: the probability with which we apply the gradient modification
|
898 |
+
(also affects the grad scale). May be supplied as a float,
|
899 |
+
or as a pair (min_prob, max_prob)
|
900 |
+
grad_scale: determines the scale on the gradient term from this object,
|
901 |
+
relative to the rest of the gradient on the attention weights.
|
902 |
+
E.g. 0.02 (you may want to use smaller values than this if prob is large)
|
903 |
+
"""
|
904 |
+
super(Whiten, self).__init__()
|
905 |
+
assert num_groups >= 1
|
906 |
+
assert whitening_limit >= 1
|
907 |
+
assert grad_scale >= 0
|
908 |
+
self.num_groups = num_groups
|
909 |
+
self.whitening_limit = whitening_limit
|
910 |
+
if isinstance(prob, float):
|
911 |
+
assert 0 < prob <= 1
|
912 |
+
self.prob = prob
|
913 |
+
else:
|
914 |
+
(self.min_prob, self.max_prob) = prob
|
915 |
+
assert 0 < self.min_prob < self.max_prob <= 1
|
916 |
+
self.prob = self.max_prob
|
917 |
+
|
918 |
+
self.grad_scale = grad_scale
|
919 |
+
|
920 |
+
def forward(self, x: Tensor) -> Tensor:
|
921 |
+
"""
|
922 |
+
In the forward pass, this function just returns the input unmodified.
|
923 |
+
In the backward pass, it will modify the gradients to ensure that the
|
924 |
+
distribution in each group has close to (lambda times I) as the covariance
|
925 |
+
after mean subtraction, with the same lambda across groups.
|
926 |
+
For whitening_limit > 1, there will be more freedom to violate this
|
927 |
+
constraint.
|
928 |
+
Args:
|
929 |
+
x: the input of shape (*, num_channels)
|
930 |
+
Returns:
|
931 |
+
x, unmodified. You should make sure
|
932 |
+
you use the returned value, or the graph will be freed
|
933 |
+
and nothing will happen in backprop.
|
934 |
+
"""
|
935 |
+
if (
|
936 |
+
not x.requires_grad
|
937 |
+
or random.random() > self.prob
|
938 |
+
or self.grad_scale == 0
|
939 |
+
):
|
940 |
+
return _no_op(x)
|
941 |
+
else:
|
942 |
+
if hasattr(self, "min_prob") and random.random() < 0.25:
|
943 |
+
# occasionally switch between min_prob and max_prob, based on whether
|
944 |
+
# we are above or below the threshold.
|
945 |
+
if (
|
946 |
+
_whitening_metric(x.to(torch.float32), self.num_groups)
|
947 |
+
> self.whitening_limit
|
948 |
+
):
|
949 |
+
# there would be a change to the grad.
|
950 |
+
self.prob = self.max_prob
|
951 |
+
else:
|
952 |
+
self.prob = self.min_prob
|
953 |
+
|
954 |
+
return WhiteningPenaltyFunction.apply(
|
955 |
+
x, self.num_groups, self.whitening_limit, self.grad_scale
|
956 |
+
)
|
957 |
+
|
958 |
+
|
959 |
+
class WithLoss(torch.autograd.Function):
|
960 |
+
@staticmethod
|
961 |
+
def forward(ctx, x: Tensor, y: Tensor):
|
962 |
+
ctx.y_shape = y.shape
|
963 |
+
return x
|
964 |
+
|
965 |
+
@staticmethod
|
966 |
+
def backward(ctx, ans_grad: Tensor):
|
967 |
+
return ans_grad, torch.ones(
|
968 |
+
ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
|
969 |
+
)
|
970 |
+
|
971 |
+
|
972 |
+
def with_loss(x, y):
|
973 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
974 |
+
return x
|
975 |
+
# returns x but adds y.sum() to the loss function.
|
976 |
+
return WithLoss.apply(x, y)
|
977 |
+
|
978 |
+
|
979 |
+
def _no_op(x: Tensor) -> Tensor:
|
980 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
981 |
+
return x
|
982 |
+
else:
|
983 |
+
# a no-op function that will have a node in the autograd graph,
|
984 |
+
# to avoid certain bugs relating to backward hooks
|
985 |
+
return x.chunk(1, dim=-1)[0]
|
986 |
+
|
987 |
+
|
988 |
+
class Identity(torch.nn.Module):
|
989 |
+
def __init__(self):
|
990 |
+
super(Identity, self).__init__()
|
991 |
+
|
992 |
+
def forward(self, x):
|
993 |
+
return _no_op(x)
|
994 |
+
|
995 |
+
|
996 |
+
class MaxEig(torch.nn.Module):
|
997 |
+
"""
|
998 |
+
Modifies the backpropped derivatives of a function to discourage
|
999 |
+
any given direction in activation space from accounting for more than
|
1000 |
+
a specified proportion of the covariance (e.g. 0.2).
|
1001 |
+
Args:
|
1002 |
+
num_channels: the number of channels
|
1003 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
1004 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
1005 |
+
max_var_per_eig: the maximum proportion of the variance of the
|
1006 |
+
features/channels, after mean subtraction, that can come from
|
1007 |
+
any given eigenvalue.
|
1008 |
+
min_prob: the minimum probability with which we apply this during any invocation
|
1009 |
+
of forward(), assuming last time we applied the constraint it was
|
1010 |
+
not active; supplied for speed.
|
1011 |
+
scale: determines the scale with which we modify the gradients, relative
|
1012 |
+
to the existing / unmodified gradients
|
1013 |
+
"""
|
1014 |
+
|
1015 |
+
def __init__(
|
1016 |
+
self,
|
1017 |
+
num_channels: int,
|
1018 |
+
channel_dim: int,
|
1019 |
+
max_var_per_eig: float = 0.2,
|
1020 |
+
min_prob: float = 0.01,
|
1021 |
+
scale: float = 0.01,
|
1022 |
+
):
|
1023 |
+
super(MaxEig, self).__init__()
|
1024 |
+
self.num_channels = num_channels
|
1025 |
+
self.channel_dim = channel_dim
|
1026 |
+
self.scale = scale
|
1027 |
+
assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
|
1028 |
+
self.max_var_per_eig = max_var_per_eig
|
1029 |
+
|
1030 |
+
# we figure out the dominant direction using the power method: starting with
|
1031 |
+
# a random vector, keep multiplying by the covariance and renormalizing.
|
1032 |
+
with torch.no_grad():
|
1033 |
+
# arbitrary.. would use randn() but want to leave the rest of the model's
|
1034 |
+
# random parameters unchanged for comparison
|
1035 |
+
direction = torch.arange(num_channels).to(torch.float)
|
1036 |
+
direction = direction / direction.norm()
|
1037 |
+
self.register_buffer("max_eig_direction", direction)
|
1038 |
+
|
1039 |
+
self.min_prob = min_prob
|
1040 |
+
# cur_prob is the current probability we'll use to apply the ActivationBalancer.
|
1041 |
+
# We'll regress this towards min_prob each time we try to apply it and it is not
|
1042 |
+
# active.
|
1043 |
+
self.cur_prob = 1.0
|
1044 |
+
|
1045 |
+
def forward(self, x: Tensor) -> Tensor:
|
1046 |
+
if (
|
1047 |
+
torch.jit.is_scripting()
|
1048 |
+
or self.max_var_per_eig <= 0
|
1049 |
+
or random.random() > self.cur_prob
|
1050 |
+
or torch.jit.is_tracing()
|
1051 |
+
):
|
1052 |
+
return _no_op(x)
|
1053 |
+
|
1054 |
+
with torch.cuda.amp.autocast(enabled=False):
|
1055 |
+
eps = 1.0e-20
|
1056 |
+
orig_x = x
|
1057 |
+
x = x.to(torch.float32)
|
1058 |
+
with torch.no_grad():
|
1059 |
+
x = x.transpose(self.channel_dim, -1).reshape(
|
1060 |
+
-1, self.num_channels
|
1061 |
+
)
|
1062 |
+
x = x - x.mean(dim=0)
|
1063 |
+
new_direction, coeffs = self._find_direction_coeffs(
|
1064 |
+
x, self.max_eig_direction
|
1065 |
+
)
|
1066 |
+
x_var = (x ** 2).mean()
|
1067 |
+
x_residual = x - coeffs * new_direction
|
1068 |
+
x_residual_var = (x_residual ** 2).mean()
|
1069 |
+
|
1070 |
+
# `variance_proportion` is the proportion of the variance accounted for
|
1071 |
+
# by the top eigen-direction.
|
1072 |
+
variance_proportion = (x_var - x_residual_var) / (
|
1073 |
+
x_var + 1.0e-20
|
1074 |
+
)
|
1075 |
+
|
1076 |
+
# ensure new direction is nonzero even if x == 0, by including `direction`.
|
1077 |
+
self._set_direction(
|
1078 |
+
0.1 * self.max_eig_direction + new_direction
|
1079 |
+
)
|
1080 |
+
|
1081 |
+
if random.random() < 0.01 or __name__ == "__main__":
|
1082 |
+
logging.info(
|
1083 |
+
f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
|
1084 |
+
)
|
1085 |
+
|
1086 |
+
if variance_proportion >= self.max_var_per_eig:
|
1087 |
+
# The constraint is active.  We should reach here only rarely,
|
1088 |
+
# typically near the beginning of training if the model is
|
1089 |
+
# starting to diverge.
|
1090 |
+
cur_prob = self.cur_prob
|
1091 |
+
self.cur_prob = (
|
1092 |
+
1.0 # next time, do the update with probability 1.0.
|
1093 |
+
)
|
1094 |
+
return MaxEigLimiterFunction.apply(
|
1095 |
+
orig_x, coeffs, new_direction, self.channel_dim, self.scale
|
1096 |
+
)
|
1097 |
+
else:
|
1098 |
+
# let self.cur_prob exponentially approach self.min_prob, as
|
1099 |
+
# long as the constraint is inactive.
|
1100 |
+
self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
|
1101 |
+
return orig_x
|
1102 |
+
|
1103 |
+
def _set_direction(self, direction: Tensor):
|
1104 |
+
"""
|
1105 |
+
Sets self.max_eig_direction to a normalized version of `direction`
|
1106 |
+
"""
|
1107 |
+
direction = direction.detach()
|
1108 |
+
direction = direction / direction.norm()
|
1109 |
+
direction_sum = direction.sum().item()
|
1110 |
+
if direction_sum - direction_sum == 0: # no inf/nan
|
1111 |
+
self.max_eig_direction[:] = direction
|
1112 |
+
else:
|
1113 |
+
logging.info(
|
1114 |
+
f"Warning: sum of direction in MaxEig is {direction_sum}, "
|
1115 |
+
"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
def _find_direction_coeffs(
|
1119 |
+
self, x: Tensor, prev_direction: Tensor
|
1120 |
+
) -> Tuple[Tensor, Tensor]:
|
1121 |
+
"""
|
1122 |
+
Figure out (an approximation to) the proportion of the variance of a set of
|
1123 |
+
feature vectors that can be attributed to the top eigen-direction.
|
1124 |
+
Args:
|
1125 |
+
x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
|
1126 |
+
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
|
1127 |
+
of the top eigen-direction, or a random direction if this is the first
|
1128 |
+
iteration. Does not have to be normalized, but should be nonzero.
|
1129 |
+
Returns: (cur_direction, coeffs), where:
|
1130 |
+
cur_direction: a Tensor of shape (num_channels,) that is the current
|
1131 |
+
estimate of the top eigen-direction.
|
1132 |
+
coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
|
1133 |
+
approximately minimizes, (x - coeffs * cur_direction).norm()
|
1134 |
+
"""
|
1135 |
+
(num_frames, num_channels) = x.shape
|
1136 |
+
assert num_channels > 1 and num_frames > 1
|
1137 |
+
assert prev_direction.shape == (num_channels,)
|
1138 |
+
# `coeffs` are the coefficients of `prev_direction` in x.
|
1139 |
+
# actually represent the coeffs up to a constant positive factor.
|
1140 |
+
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
|
1141 |
+
cur_direction = (x * coeffs).sum(dim=0) / (
|
1142 |
+
(coeffs ** 2).sum() + 1.0e-20
|
1143 |
+
)
|
1144 |
+
return cur_direction, coeffs
|
1145 |
+
|
1146 |
+
|
1147 |
+
class DoubleSwishFunction(torch.autograd.Function):
|
1148 |
+
"""
|
1149 |
+
double_swish(x) = x * torch.sigmoid(x-1)
|
1150 |
+
This is a definition, originally motivated by its close numerical
|
1151 |
+
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
|
1152 |
+
Memory-efficient derivative computation:
|
1153 |
+
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
1154 |
+
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
1155 |
+
Now, s'(x) = s(x) * (1-s(x)).
|
1156 |
+
double_swish'(x) = x * s'(x) + s(x).
|
1157 |
+
= x * s(x) * (1-s(x)) + s(x).
|
1158 |
+
= double_swish(x) * (1-s(x)) + s(x)
|
1159 |
+
... so we just need to remember s(x) but not x itself.
|
1160 |
+
"""
|
1161 |
+
|
1162 |
+
@staticmethod
|
1163 |
+
def forward(ctx, x: Tensor) -> Tensor:
|
1164 |
+
requires_grad = x.requires_grad
|
1165 |
+
x_dtype = x.dtype
|
1166 |
+
if x.dtype == torch.float16:
|
1167 |
+
x = x.to(torch.float32)
|
1168 |
+
|
1169 |
+
s = torch.sigmoid(x - 1.0)
|
1170 |
+
y = x * s
|
1171 |
+
|
1172 |
+
if requires_grad:
|
1173 |
+
deriv = y * (1 - s) + s
|
1174 |
+
# notes on derivative of x * sigmoid(x - 1):
|
1175 |
+
# https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
|
1176 |
+
# min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound.
|
1177 |
+
# max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
|
1178 |
+
# the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
|
1179 |
+
# floors), should be expectation-preserving.
|
1180 |
+
floor = -0.043637
|
1181 |
+
ceil = 1.2
|
1182 |
+
d_scaled = (deriv - floor) * (
|
1183 |
+
255.0 / (ceil - floor)
|
1184 |
+
) + torch.rand_like(deriv)
|
1185 |
+
if __name__ == "__main__":
|
1186 |
+
# for self-testing only.
|
1187 |
+
assert d_scaled.min() >= 0.0
|
1188 |
+
assert d_scaled.max() < 256.0
|
1189 |
+
d_int = d_scaled.to(torch.uint8)
|
1190 |
+
ctx.save_for_backward(d_int)
|
1191 |
+
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
1192 |
+
y = y.to(torch.float16)
|
1193 |
+
return y
|
1194 |
+
|
1195 |
+
@staticmethod
|
1196 |
+
def backward(ctx, y_grad: Tensor) -> Tensor:
|
1197 |
+
(d,) = ctx.saved_tensors
|
1198 |
+
# the same constants as used in forward pass.
|
1199 |
+
floor = -0.043637
|
1200 |
+
ceil = 1.2
|
1201 |
+
d = d * ((ceil - floor) / 255.0) + floor
|
1202 |
+
return y_grad * d
|
1203 |
+
|
1204 |
+
|
1205 |
+
class DoubleSwish(torch.nn.Module):
|
1206 |
+
def forward(self, x: Tensor) -> Tensor:
|
1207 |
+
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
1208 |
+
that is computed here as x * sigmoid(x - 1).
|
1209 |
+
"""
|
1210 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
1211 |
+
return x * torch.sigmoid(x - 1.0)
|
1212 |
+
return DoubleSwishFunction.apply(x)
|
1213 |
+
|
1214 |
+
|
1215 |
+
def BalancedDoubleSwish(
|
1216 |
+
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
|
1217 |
+
) -> nn.Sequential:
|
1218 |
+
"""
|
1219 |
+
ActivationBalancer -> DoubleSwish
|
1220 |
+
"""
|
1221 |
+
balancer = ActivationBalancer(
|
1222 |
+
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
|
1223 |
+
)
|
1224 |
+
return nn.Sequential(
|
1225 |
+
balancer,
|
1226 |
+
DoubleSwish(),
|
1227 |
+
)
|
1228 |
+
|
1229 |
+
|
1230 |
+
def _test_max_eig():
|
1231 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1232 |
+
logging.info(f"proportion = {proportion}")
|
1233 |
+
x = torch.randn(100, 128)
|
1234 |
+
direction = torch.randn(128)
|
1235 |
+
coeffs = torch.randn(100, 1)
|
1236 |
+
x += proportion * direction * coeffs
|
1237 |
+
|
1238 |
+
x.requires_grad = True
|
1239 |
+
|
1240 |
+
num_channels = 128
|
1241 |
+
m = MaxEig(
|
1242 |
+
num_channels, 1, 0.5, scale=0.1  # channel_dim=1, max_var_per_eig=0.5, scale acts as the grad scale
|
1243 |
+
)
|
1244 |
+
|
1245 |
+
for _ in range(4):
|
1246 |
+
y = m(x)
|
1247 |
+
|
1248 |
+
y_grad = torch.randn_like(x)
|
1249 |
+
y.backward(gradient=y_grad)
|
1250 |
+
|
1251 |
+
if proportion < 0.2:
|
1252 |
+
assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
|
1253 |
+
elif proportion > 1.0:
|
1254 |
+
assert not torch.allclose(x.grad, y_grad)
|
1255 |
+
|
1256 |
+
|
1257 |
+
def _test_whiten():
|
1258 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1259 |
+
logging.info(f"_test_whiten(): proportion = {proportion}")
|
1260 |
+
x = torch.randn(100, 128)
|
1261 |
+
direction = torch.randn(128)
|
1262 |
+
coeffs = torch.randn(100, 1)
|
1263 |
+
x += proportion * direction * coeffs
|
1264 |
+
|
1265 |
+
x.requires_grad = True
|
1266 |
+
|
1267 |
+
num_channels = 128
|
1268 |
+
m = Whiten(
|
1269 |
+
1, 5.0, prob=1.0, grad_scale=0.1  # num_groups=1, whitening_limit=5.0
|
1270 |
+
)
|
1271 |
+
|
1272 |
+
for _ in range(4):
|
1273 |
+
y = m(x)
|
1274 |
+
|
1275 |
+
y_grad = torch.randn_like(x)
|
1276 |
+
y.backward(gradient=y_grad)
|
1277 |
+
|
1278 |
+
if proportion < 0.2:
|
1279 |
+
assert torch.allclose(x.grad, y_grad)
|
1280 |
+
elif proportion > 1.0:
|
1281 |
+
assert not torch.allclose(x.grad, y_grad)
|
1282 |
+
|
1283 |
+
|
1284 |
+
def _test_activation_balancer_sign():
|
1285 |
+
probs = torch.arange(0, 1, 0.01)
|
1286 |
+
N = 1000
|
1287 |
+
x = 1.0 * (
|
1288 |
+
(2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
|
1289 |
+
)
|
1290 |
+
x = x.detach()
|
1291 |
+
x.requires_grad = True
|
1292 |
+
m = ActivationBalancer(
|
1293 |
+
probs.numel(),
|
1294 |
+
channel_dim=0,
|
1295 |
+
min_positive=0.05,
|
1296 |
+
max_positive=0.95,
|
1297 |
+
max_factor=0.2,
|
1298 |
+
min_abs=0.0,
|
1299 |
+
)
|
1300 |
+
|
1301 |
+
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
1302 |
+
|
1303 |
+
y = m(x)
|
1304 |
+
y.backward(gradient=y_grad)
|
1305 |
+
print("_test_activation_balancer_sign: x = ", x)
|
1306 |
+
print("_test_activation_balancer_sign: y grad = ", y_grad)
|
1307 |
+
print("_test_activation_balancer_sign: x grad = ", x.grad)
|
1308 |
+
|
1309 |
+
|
1310 |
+
def _test_activation_balancer_magnitude():
|
1311 |
+
magnitudes = torch.arange(0, 1, 0.01)
|
1312 |
+
N = 1000
|
1313 |
+
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
|
1314 |
+
-1
|
1315 |
+
)
|
1316 |
+
x = x.detach()
|
1317 |
+
x.requires_grad = True
|
1318 |
+
m = ActivationBalancer(
|
1319 |
+
magnitudes.numel(),
|
1320 |
+
channel_dim=0,
|
1321 |
+
min_positive=0.0,
|
1322 |
+
max_positive=1.0,
|
1323 |
+
max_factor=0.2,
|
1324 |
+
min_abs=0.2,
|
1325 |
+
max_abs=0.8,
|
1326 |
+
min_prob=1.0,
|
1327 |
+
)
|
1328 |
+
|
1329 |
+
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
1330 |
+
|
1331 |
+
y = m(x)
|
1332 |
+
y.backward(gradient=y_grad)
|
1333 |
+
print("_test_activation_balancer_magnitude: x = ", x)
|
1334 |
+
print("_test_activation_balancer_magnitude: y grad = ", y_grad)
|
1335 |
+
print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
1336 |
+
|
1337 |
+
|
1338 |
+
def _test_basic_norm():
|
1339 |
+
num_channels = 128
|
1340 |
+
m = BasicNorm(num_channels=num_channels, channel_dim=1)
|
1341 |
+
|
1342 |
+
x = torch.randn(500, num_channels)
|
1343 |
+
|
1344 |
+
y = m(x)
|
1345 |
+
|
1346 |
+
assert y.shape == x.shape
|
1347 |
+
x_rms = (x ** 2).mean().sqrt()
|
1348 |
+
y_rms = (y ** 2).mean().sqrt()
|
1349 |
+
print("x rms = ", x_rms)
|
1350 |
+
print("y rms = ", y_rms)
|
1351 |
+
assert y_rms < x_rms
|
1352 |
+
assert y_rms > 0.5 * x_rms
|
1353 |
+
|
1354 |
+
|
1355 |
+
def _test_double_swish_deriv():
|
1356 |
+
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
1357 |
+
x.requires_grad = True
|
1358 |
+
m = DoubleSwish()
|
1359 |
+
|
1360 |
+
tol = (1.2 - (-0.043637)) / 255.0
|
1361 |
+
torch.autograd.gradcheck(m, x, atol=tol)
|
1362 |
+
|
1363 |
+
# for self-test.
|
1364 |
+
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
|
1365 |
+
x.requires_grad = True
|
1366 |
+
y = m(x)
|
1367 |
+
|
1368 |
+
|
1369 |
+
def _test_softmax():
|
1370 |
+
a = torch.randn(2, 10, dtype=torch.float64)
|
1371 |
+
b = a.clone()
|
1372 |
+
a.requires_grad = True
|
1373 |
+
b.requires_grad = True
|
1374 |
+
a.softmax(dim=1)[:, 0].sum().backward()
|
1375 |
+
print("a grad = ", a.grad)
|
1376 |
+
softmax(b, dim=1)[:, 0].sum().backward()
|
1377 |
+
print("b grad = ", b.grad)
|
1378 |
+
assert torch.allclose(a.grad, b.grad)
|
1379 |
+
|
1380 |
+
|
1381 |
+
if __name__ == "__main__":
|
1382 |
+
logging.getLogger().setLevel(logging.INFO)
|
1383 |
+
torch.set_num_threads(1)
|
1384 |
+
torch.set_num_interop_threads(1)
|
1385 |
+
_test_softmax()
|
1386 |
+
_test_whiten()
|
1387 |
+
_test_max_eig()
|
1388 |
+
_test_activation_balancer_sign()
|
1389 |
+
_test_activation_balancer_magnitude()
|
1390 |
+
_test_basic_norm()
|
1391 |
+
_test_double_swish_deriv()
|
src/model/modules/siglip.py
ADDED
@@ -0,0 +1,258 @@
1 |
+
from typing import Optional, Tuple
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class SiglipVisionConfig:
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
hidden_size=768,
|
11 |
+
intermediate_size=3072,
|
12 |
+
num_hidden_layers=12,
|
13 |
+
num_attention_heads=12,
|
14 |
+
num_channels=3,
|
15 |
+
image_size=224,
|
16 |
+
patch_size=16,
|
17 |
+
layer_norm_eps=1e-6,
|
18 |
+
attention_dropout=0.0,
|
19 |
+
num_image_tokens: int = None,
|
20 |
+
**kwargs,
|
21 |
+
):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
self.hidden_size = hidden_size
|
25 |
+
self.intermediate_size = intermediate_size
|
26 |
+
self.num_hidden_layers = num_hidden_layers
|
27 |
+
self.num_attention_heads = num_attention_heads
|
28 |
+
self.num_channels = num_channels
|
29 |
+
self.patch_size = patch_size
|
30 |
+
self.image_size = image_size
|
31 |
+
self.attention_dropout = attention_dropout
|
32 |
+
self.layer_norm_eps = layer_norm_eps
|
33 |
+
self.num_image_tokens = num_image_tokens
|
34 |
+
|
35 |
+
|
36 |
+
class SiglipVisionEmbeddings(nn.Module):
|
37 |
+
def __init__(self, config: SiglipVisionConfig):
|
38 |
+
super().__init__()
|
39 |
+
self.config = config
|
40 |
+
self.embed_dim = config.hidden_size
|
41 |
+
self.image_size = config.image_size
|
42 |
+
self.patch_size = config.patch_size
|
43 |
+
|
44 |
+
self.patch_embedding = nn.Conv2d(
|
45 |
+
in_channels=config.num_channels,
|
46 |
+
out_channels=self.embed_dim,
|
47 |
+
kernel_size=self.patch_size,
|
48 |
+
stride=self.patch_size,
|
49 |
+
padding="valid", # This indicates no padding is added
|
50 |
+
)
|
51 |
+
|
52 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
53 |
+
self.num_positions = self.num_patches
|
54 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
55 |
+
self.register_buffer(
|
56 |
+
"position_ids",
|
57 |
+
torch.arange(self.num_positions).expand((1, -1)),
|
58 |
+
persistent=False,
|
59 |
+
)
|
60 |
+
|
61 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
62 |
+
_, _, height, width = (
|
63 |
+
pixel_values.shape
|
64 |
+
) # [Batch_Size, Channels, Height, Width]
|
65 |
+
# Convolve the `patch_size` kernel over the image, with no overlapping patches since the stride is equal to the kernel size
|
66 |
+
# The output of the convolution will have shape [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
|
67 |
+
# where Num_Patches_H = height // patch_size and Num_Patches_W = width // patch_size
|
68 |
+
patch_embeds = self.patch_embedding(pixel_values)
|
69 |
+
# [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W] -> [Batch_Size, Embed_Dim, Num_Patches]
|
70 |
+
# where Num_Patches = Num_Patches_H * Num_Patches_W
|
71 |
+
embeddings = patch_embeds.flatten(2)
|
72 |
+
# [Batch_Size, Embed_Dim, Num_Patches] -> [Batch_Size, Num_Patches, Embed_Dim]
|
73 |
+
embeddings = embeddings.transpose(1, 2)
|
74 |
+
# Add position embeddings to each patch. Each positional encoding is a vector of size [Embed_Dim]
|
75 |
+
embeddings = embeddings + self.position_embedding(self.position_ids)
|
76 |
+
# [Batch_Size, Num_Patches, Embed_Dim]
|
77 |
+
return embeddings
|
78 |
+
|
79 |
+
|
80 |
+
class SiglipAttention(nn.Module):
|
81 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
82 |
+
|
83 |
+
def __init__(self, config):
|
84 |
+
super().__init__()
|
85 |
+
self.config = config
|
86 |
+
self.embed_dim = config.hidden_size
|
87 |
+
self.num_heads = config.num_attention_heads
|
88 |
+
self.head_dim = self.embed_dim // self.num_heads
|
89 |
+
self.scale = self.head_dim**-0.5 # Equivalent to 1 / sqrt(self.head_dim)
|
90 |
+
self.dropout = config.attention_dropout
|
91 |
+
|
92 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
93 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
94 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
95 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
96 |
+
|
97 |
+
def forward(
|
98 |
+
self,
|
99 |
+
hidden_states: torch.Tensor,
|
100 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
101 |
+
|
102 |
+
# hidden_states: [Batch_Size, Num_Patches, Embed_Dim]
|
103 |
+
batch_size, seq_len, _ = hidden_states.size()
|
104 |
+
# query_states: [Batch_Size, Num_Patches, Embed_Dim]
|
105 |
+
query_states = self.q_proj(hidden_states)
|
106 |
+
# key_states: [Batch_Size, Num_Patches, Embed_Dim]
|
107 |
+
key_states = self.k_proj(hidden_states)
|
108 |
+
# value_states: [Batch_Size, Num_Patches, Embed_Dim]
|
109 |
+
value_states = self.v_proj(hidden_states)
|
110 |
+
# query_states: [Batch_Size, Num_Heads, Num_Patches, Head_Dim]
|
111 |
+
query_states = query_states.view(
|
112 |
+
batch_size, seq_len, self.num_heads, self.head_dim
|
113 |
+
).transpose(1, 2)
|
114 |
+
|
115 |
+
key_states = key_states.view(
|
116 |
+
batch_size, seq_len, self.num_heads, self.head_dim
|
117 |
+
).transpose(1, 2)
|
118 |
+
|
119 |
+
value_states = value_states.view(
|
120 |
+
batch_size, seq_len, self.num_heads, self.head_dim
|
121 |
+
).transpose(1, 2)
|
122 |
+
# Calculate the attention using the formula Q * K^T / sqrt(d_k). attn_weights: [Batch_Size, Num_Heads, Num_Patches, Num_Patches]
|
123 |
+
attn_weights = (
|
124 |
+
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
125 |
+
)
|
126 |
+
|
127 |
+
if attn_weights.size() != (batch_size, self.num_heads, seq_len, seq_len):
|
128 |
+
raise ValueError(
|
129 |
+
f"Attention weights should be of size {(batch_size, self.num_heads, seq_len, seq_len)}, but is"
|
130 |
+
f" {attn_weights.size()}"
|
131 |
+
)
|
132 |
+
|
133 |
+
# Apply the softmax row-wise. attn_weights: [Batch_Size, Num_Heads, Num_Patches, Num_Patches]
|
134 |
+
attn_weights = nn.functional.softmax(
|
135 |
+
attn_weights, dim=-1, dtype=torch.float32
|
136 |
+
).to(query_states.dtype)
|
137 |
+
# Apply dropout only during training
|
138 |
+
attn_weights = nn.functional.dropout(
|
139 |
+
attn_weights, p=self.dropout, training=self.training
|
140 |
+
)
|
141 |
+
# Multiply the attention weights by the value states. attn_output: [Batch_Size, Num_Heads, Num_Patches, Head_Dim]
|
142 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
143 |
+
|
144 |
+
if attn_output.size() != (batch_size, self.num_heads, seq_len, self.head_dim):
|
145 |
+
raise ValueError(
|
146 |
+
f"`attn_output` should be of size {(batch_size, self.num_heads, seq_len, self.head_dim)}, but is"
|
147 |
+
f" {attn_output.size()}"
|
148 |
+
)
|
149 |
+
# [Batch_Size, Num_Heads, Num_Patches, Head_Dim] -> [Batch_Size, Num_Patches, Num_Heads, Head_Dim]
|
150 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
151 |
+
# [Batch_Size, Num_Patches, Num_Heads, Head_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
|
152 |
+
attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
|
153 |
+
# [Batch_Size, Num_Patches, Embed_Dim]
|
154 |
+
attn_output = self.out_proj(attn_output)
|
155 |
+
|
156 |
+
return attn_output, attn_weights
|
157 |
+
|
158 |
+
|
159 |
+
class SiglipMLP(nn.Module):
|
160 |
+
def __init__(self, config):
|
161 |
+
super().__init__()
|
162 |
+
self.config = config
|
163 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
164 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
165 |
+
|
166 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
167 |
+
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Intermediate_Size]
|
168 |
+
hidden_states = self.fc1(hidden_states)
|
169 |
+
# hidden_states: [Batch_Size, Num_Patches, Intermediate_Size]
|
170 |
+
hidden_states = nn.functional.gelu(hidden_states, approximate="tanh")
|
171 |
+
# [Batch_Size, Num_Patches, Intermediate_Size] -> [Batch_Size, Num_Patches, Embed_Dim]
|
172 |
+
hidden_states = self.fc2(hidden_states)
|
173 |
+
|
174 |
+
return hidden_states
|
175 |
+
|
176 |
+
|
177 |
+
class SiglipEncoderLayer(nn.Module):
|
178 |
+
def __init__(self, config: SiglipVisionConfig):
|
179 |
+
super().__init__()
|
180 |
+
self.embed_dim = config.hidden_size
|
181 |
+
self.self_attn = SiglipAttention(config)
|
182 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
183 |
+
self.mlp = SiglipMLP(config)
|
184 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
185 |
+
|
186 |
+
# Ignore copy
|
187 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
188 |
+
# residual: [Batch_Size, Num_Patches, Embed_Dim]
|
189 |
+
residual = hidden_states
|
190 |
+
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
|
191 |
+
hidden_states = self.layer_norm1(hidden_states)
|
192 |
+
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
|
193 |
+
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
194 |
+
# [Batch_Size, Num_Patches, Embed_Dim]
|
195 |
+
hidden_states = residual + hidden_states
|
196 |
+
# residual: [Batch_Size, Num_Patches, Embed_Dim]
|
197 |
+
residual = hidden_states
|
198 |
+
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
|
199 |
+
hidden_states = self.layer_norm2(hidden_states)
|
200 |
+
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
|
201 |
+
hidden_states = self.mlp(hidden_states)
|
202 |
+
# [Batch_Size, Num_Patches, Embed_Dim]
|
203 |
+
hidden_states = residual + hidden_states
|
204 |
+
|
205 |
+
return hidden_states
|
206 |
+
|
207 |
+
|
208 |
+
class SiglipEncoder(nn.Module):
|
209 |
+
def __init__(self, config: SiglipVisionConfig):
|
210 |
+
super().__init__()
|
211 |
+
self.config = config
|
212 |
+
self.layers = nn.ModuleList(
|
213 |
+
[SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
|
214 |
+
)
|
215 |
+
|
216 |
+
# Ignore copy
|
217 |
+
def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
|
218 |
+
# inputs_embeds: [Batch_Size, Num_Patches, Embed_Dim]
|
219 |
+
hidden_states = inputs_embeds
|
220 |
+
|
221 |
+
for encoder_layer in self.layers:
|
222 |
+
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
|
223 |
+
hidden_states = encoder_layer(hidden_states)
|
224 |
+
|
225 |
+
return hidden_states
|
226 |
+
|
227 |
+
|
228 |
+
class SiglipVisionTransformer(nn.Module):
|
229 |
+
def __init__(self, config: SiglipVisionConfig):
|
230 |
+
super().__init__()
|
231 |
+
self.config = config
|
232 |
+
embed_dim = config.hidden_size
|
233 |
+
|
234 |
+
self.embeddings = SiglipVisionEmbeddings(config)
|
235 |
+
self.encoder = SiglipEncoder(config)
|
236 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
237 |
+
|
238 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
239 |
+
# pixel_values: [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
|
240 |
+
hidden_states = self.embeddings(pixel_values)
|
241 |
+
|
242 |
+
last_hidden_state = self.encoder(inputs_embeds=hidden_states)
|
243 |
+
|
244 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
245 |
+
|
246 |
+
return last_hidden_state
|
247 |
+
|
248 |
+
|
249 |
+
class SiglipVisionModel(nn.Module):
|
250 |
+
|
251 |
+
def __init__(self, config: SiglipVisionConfig):
|
252 |
+
super().__init__()
|
253 |
+
self.config = config
|
254 |
+
self.vision_model = SiglipVisionTransformer(config)
|
255 |
+
|
256 |
+
def forward(self, pixel_values) -> Tuple:
|
257 |
+
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
|
258 |
+
return self.vision_model(pixel_values=pixel_values)
|
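
A minimal sketch of running the vision tower above on a dummy image; the 224x224 resolution and 16x16 patches come from the SiglipVisionConfig defaults, the rest is an illustrative assumption:

import torch
from src.model.modules.siglip import SiglipVisionConfig, SiglipVisionModel

config = SiglipVisionConfig()               # 224x224 input, 16x16 patches, 768-dim embeddings
model = SiglipVisionModel(config).eval()
pixel_values = torch.randn(1, 3, 224, 224)  # [Batch_Size, Channels, Height, Width]
with torch.no_grad():
    patch_embeds = model(pixel_values)      # [1, 196, 768]: one embedding per patch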
src/model/modules/tokenizer.py
ADDED
@@ -0,0 +1,149 @@
1 |
+
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
|
2 |
+
# Copyright 2023 (authors: Feiteng Li)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import re
|
17 |
+
from dataclasses import asdict, dataclass
|
18 |
+
from typing import Any, Dict, List, Optional, Pattern, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import torchaudio
|
23 |
+
# from lhotse.features import FeatureExtractor
|
24 |
+
# from lhotse.utils import Seconds, compute_num_frames
|
25 |
+
from phonemizer.backend import EspeakBackend
|
26 |
+
from phonemizer.backend.espeak.language_switch import LanguageSwitch
|
27 |
+
from phonemizer.backend.espeak.words_mismatch import WordMismatch
|
28 |
+
from phonemizer.punctuation import Punctuation
|
29 |
+
from phonemizer.separator import Separator
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
class TextTokenizer:
|
34 |
+
"""Phonemize Text."""
|
35 |
+
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
language="en-us",
|
39 |
+
backend="espeak",
|
40 |
+
separator=Separator(word="_", syllable="-", phone="|"),
|
41 |
+
preserve_punctuation=True,
|
42 |
+
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
|
43 |
+
with_stress: bool = False,
|
44 |
+
tie: Union[bool, str] = False,
|
45 |
+
language_switch: LanguageSwitch = "keep-flags",
|
46 |
+
words_mismatch: WordMismatch = "ignore",
|
47 |
+
) -> None:
|
48 |
+
phonemizer = EspeakBackend(
|
49 |
+
language,
|
50 |
+
punctuation_marks=punctuation_marks,
|
51 |
+
preserve_punctuation=preserve_punctuation,
|
52 |
+
with_stress=with_stress,
|
53 |
+
tie=tie,
|
54 |
+
language_switch=language_switch,
|
55 |
+
words_mismatch=words_mismatch,
|
56 |
+
)
|
57 |
+
|
58 |
+
self.backend = phonemizer
|
59 |
+
self.separator = separator
|
60 |
+
|
61 |
+
def to_list(self, phonemized: str) -> List[str]:
|
62 |
+
fields = []
|
63 |
+
for word in phonemized.split(self.separator.word):
|
64 |
+
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
|
65 |
+
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
|
66 |
+
fields.extend(
|
67 |
+
[p for p in pp if p != self.separator.phone]
|
68 |
+
+ [self.separator.word]
|
69 |
+
)
|
70 |
+
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
|
71 |
+
self.separator.phone
|
72 |
+
)
|
73 |
+
return fields[:-1]
|
74 |
+
|
75 |
+
def __call__(self, text, strip=True) -> List[List[str]]:
|
76 |
+
if isinstance(text, str):
|
77 |
+
text = [text]
|
78 |
+
|
79 |
+
phonemized = self.backend.phonemize(
|
80 |
+
text, separator=self.separator, strip=strip, njobs=1
|
81 |
+
)
|
82 |
+
return [self.to_list(p) for p in phonemized]
|
83 |
+
|
84 |
+
|
85 |
+
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
|
86 |
+
phonemes = tokenizer([text.strip()])
|
87 |
+
return phonemes[0] # k2symbols
|
88 |
+
|
89 |
+
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
|
90 |
+
assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
|
91 |
+
if target_channels == 1:
|
92 |
+
wav = wav.mean(0, keepdim=True)
|
93 |
+
elif target_channels == 2:
|
94 |
+
*shape, _, length = wav.shape
|
95 |
+
wav = wav.expand(*shape, target_channels, length)
|
96 |
+
elif wav.shape[0] == 1:
|
97 |
+
wav = wav.expand(target_channels, -1)
|
98 |
+
wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
|
99 |
+
return wav
|
100 |
+
|
101 |
+
class AudioTokenizer:
|
102 |
+
"""EnCodec audio."""
|
103 |
+
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
device: Any = None,
|
107 |
+
signature=None
|
108 |
+
) -> None:
|
109 |
+
from audiocraft.solvers import CompressionSolver
|
110 |
+
model = CompressionSolver.model_from_checkpoint(signature)
|
111 |
+
self.sample_rate = model.sample_rate
|
112 |
+
self.channels = model.channels
|
113 |
+
|
114 |
+
if not device:
|
115 |
+
device = torch.device("cpu")
|
116 |
+
if torch.cuda.is_available():
|
117 |
+
device = torch.device("cuda:0")
|
118 |
+
|
119 |
+
self._device = device
|
120 |
+
|
121 |
+
self.codec = model.to(device)
|
122 |
+
|
123 |
+
@property
|
124 |
+
def device(self):
|
125 |
+
return self._device
|
126 |
+
|
127 |
+
def encode(self, wav: torch.Tensor) -> torch.Tensor:
|
128 |
+
codes = self.codec.encode(wav.to(self.device))
|
129 |
+
return [(codes[0], None)]
|
130 |
+
|
131 |
+
def decode(self, frames: torch.Tensor) -> torch.Tensor:
|
132 |
+
frames = frames[0][0] # [1,4,T]
|
133 |
+
return self.codec.decode(frames)
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset=-1, num_frames=-1):
|
138 |
+
# Load and pre-process the audio waveform
|
139 |
+
if offset != -1 and num_frames != -1:
|
140 |
+
wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
|
141 |
+
else:
|
142 |
+
wav, sr = torchaudio.load(audio_path)
|
143 |
+
wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
|
144 |
+
wav = wav.unsqueeze(0)
|
145 |
+
|
146 |
+
# Extract discrete codes from EnCodec
|
147 |
+
with torch.no_grad():
|
148 |
+
encoded_frames = tokenizer.encode(wav)
|
149 |
+
return encoded_frames
|
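
A minimal sketch of phonemizing a caption with the TextTokenizer above; it assumes espeak-ng and the phonemizer package are available, as required by the imports, and the example sentence is arbitrary:

from src.model.modules.tokenizer import TextTokenizer, tokenize_text

text_tokenizer = TextTokenizer(language="en-us")   # espeak backend, "_" marks word boundaries
phones = tokenize_text(text_tokenizer, "A dog runs on the beach.")
print(phones)                                      # e.g. ['ɐ', '_', 'd', ...] (IPA symbols and separators)

AudioTokenizer, by contrast, needs an audiocraft EnCodec checkpoint signature before tokenize_audio can be used, so it is not shown here.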
src/model/modules/transformer.py
ADDED
@@ -0,0 +1,690 @@
1 |
+
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng 2024
|
2 |
+
import copy
|
3 |
+
import numbers
|
4 |
+
from functools import partial
|
5 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import Tensor, nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
|
11 |
+
from .activation import MultiheadAttention
|
12 |
+
from .scaling import ActivationBalancer, BalancedDoubleSwish
|
13 |
+
from .scaling import BasicNorm as _BasicNorm
|
14 |
+
|
15 |
+
_shape_t = Union[int, List[int], torch.Size]
|
16 |
+
|
17 |
+
|
18 |
+
class LayerNorm(nn.Module):
|
19 |
+
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
|
20 |
+
normalized_shape: Tuple[int, ...]
|
21 |
+
eps: float
|
22 |
+
elementwise_affine: bool
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
normalized_shape: _shape_t,
|
27 |
+
eps: float = 1e-5,
|
28 |
+
elementwise_affine: bool = True,
|
29 |
+
device=None,
|
30 |
+
dtype=None,
|
31 |
+
) -> None:
|
32 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
33 |
+
super(LayerNorm, self).__init__()
|
34 |
+
if isinstance(normalized_shape, numbers.Integral):
|
35 |
+
# mypy error: incompatible types in assignment
|
36 |
+
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
37 |
+
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
38 |
+
self.eps = eps
|
39 |
+
self.elementwise_affine = elementwise_affine
|
40 |
+
if self.elementwise_affine:
|
41 |
+
self.weight = nn.Parameter(
|
42 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
43 |
+
)
|
44 |
+
self.bias = nn.Parameter(
|
45 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
46 |
+
)
|
47 |
+
else:
|
48 |
+
self.register_parameter("weight", None)
|
49 |
+
self.register_parameter("bias", None)
|
50 |
+
|
51 |
+
self.reset_parameters()
|
52 |
+
|
53 |
+
def reset_parameters(self) -> None:
|
54 |
+
if self.elementwise_affine:
|
55 |
+
nn.init.ones_(self.weight)
|
56 |
+
nn.init.zeros_(self.bias)
|
57 |
+
|
58 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
59 |
+
if isinstance(input, tuple):
|
60 |
+
input, embedding = input
|
61 |
+
return (
|
62 |
+
F.layer_norm(
|
63 |
+
input,
|
64 |
+
self.normalized_shape,
|
65 |
+
self.weight,
|
66 |
+
self.bias,
|
67 |
+
self.eps,
|
68 |
+
),
|
69 |
+
embedding,
|
70 |
+
)
|
71 |
+
|
72 |
+
assert embedding is None
|
73 |
+
return F.layer_norm(
|
74 |
+
input, self.normalized_shape, self.weight, self.bias, self.eps
|
75 |
+
)
|
76 |
+
|
77 |
+
def extra_repr(self) -> str:
|
78 |
+
return (
|
79 |
+
"{normalized_shape}, eps={eps}, "
|
80 |
+
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
81 |
+
)
|
82 |
+
|
83 |
+
|
84 |
+
class AdaptiveLayerNorm(nn.Module):
|
85 |
+
r"""Adaptive Layer Normalization"""
|
86 |
+
|
87 |
+
def __init__(self, d_model, norm) -> None:
|
88 |
+
super(AdaptiveLayerNorm, self).__init__()
|
89 |
+
self.project_layer = nn.Linear(d_model, 2 * d_model)
|
90 |
+
self.norm = norm
|
91 |
+
self.d_model = d_model
|
92 |
+
self.eps = self.norm.eps
|
93 |
+
|
94 |
+
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
|
95 |
+
if isinstance(input, tuple):
|
96 |
+
input, embedding = input
|
97 |
+
weight, bias = torch.split(
|
98 |
+
self.project_layer(embedding),
|
99 |
+
split_size_or_sections=self.d_model,
|
100 |
+
dim=-1,
|
101 |
+
)
|
102 |
+
return (weight * self.norm(input) + bias, embedding)
|
103 |
+
|
104 |
+
weight, bias = torch.split(
|
105 |
+
self.project_layer(embedding),
|
106 |
+
split_size_or_sections=self.d_model,
|
107 |
+
dim=-1,
|
108 |
+
)
|
109 |
+
return weight * self.norm(input) + bias
|
110 |
+
|
111 |
+
|
112 |
+
class BasicNorm(_BasicNorm):
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
d_model: int,
|
116 |
+
eps: float = 1e-5,
|
117 |
+
device=None,
|
118 |
+
dtype=None,
|
119 |
+
):
|
120 |
+
super(BasicNorm, self).__init__(d_model, eps=eps)
|
121 |
+
|
122 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
123 |
+
if isinstance(input, tuple):
|
124 |
+
input, embedding = input
|
125 |
+
return (
|
126 |
+
super(BasicNorm, self).forward(input),
|
127 |
+
embedding,
|
128 |
+
)
|
129 |
+
|
130 |
+
assert embedding is None
|
131 |
+
return super(BasicNorm, self).forward(input)
|
132 |
+
|
133 |
+
|
134 |
+
class BalancedBasicNorm(nn.Module):
|
135 |
+
def __init__(
|
136 |
+
self,
|
137 |
+
d_model: int,
|
138 |
+
eps: float = 1e-5,
|
139 |
+
device=None,
|
140 |
+
dtype=None,
|
141 |
+
):
|
142 |
+
super(BalancedBasicNorm, self).__init__()
|
143 |
+
self.balancer = ActivationBalancer(
|
144 |
+
d_model,
|
145 |
+
channel_dim=-1,
|
146 |
+
min_positive=0.45,
|
147 |
+
max_positive=0.55,
|
148 |
+
max_abs=6.0,
|
149 |
+
)
|
150 |
+
self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
|
151 |
+
|
152 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
153 |
+
if isinstance(input, tuple):
|
154 |
+
input, embedding = input
|
155 |
+
return self.norm((self.balancer(input), embedding))
|
156 |
+
|
157 |
+
assert embedding is None
|
158 |
+
return self.norm(self.balancer(input))
|
159 |
+
|
160 |
+
|
161 |
+
class IdentityNorm(nn.Module):
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
d_model: int,
|
165 |
+
eps: float = 1e-5,
|
166 |
+
device=None,
|
167 |
+
dtype=None,
|
168 |
+
) -> None:
|
169 |
+
super(IdentityNorm, self).__init__()
|
170 |
+
|
171 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
172 |
+
if isinstance(input, tuple):
|
173 |
+
return input
|
174 |
+
|
175 |
+
assert embedding is None
|
176 |
+
return input
|
177 |
+
|
178 |
+
|
179 |
+
class TransformerEncoderLayer(nn.Module):
|
180 |
+
__constants__ = ["batch_first", "norm_first"]
|
181 |
+
|
182 |
+
def __init__(
|
183 |
+
self,
|
184 |
+
d_model: int,
|
185 |
+
nhead: int,
|
186 |
+
dim_feedforward: int = 2048,
|
187 |
+
dropout: float = 0.1,
|
188 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
189 |
+
batch_first: bool = False,
|
190 |
+
norm_first: bool = False,
|
191 |
+
device=None,
|
192 |
+
dtype=None,
|
193 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
194 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
195 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
196 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
197 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
198 |
+
layer_norm_eps: float = 1e-5,
|
199 |
+
adaptive_layer_norm=False,
|
200 |
+
) -> None:
|
201 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
202 |
+
super(TransformerEncoderLayer, self).__init__()
|
203 |
+
self.self_attn = MultiheadAttention(
|
204 |
+
d_model,
|
205 |
+
nhead,
|
206 |
+
dropout=dropout,
|
207 |
+
batch_first=batch_first,
|
208 |
+
linear1_cls=linear1_self_attention_cls,
|
209 |
+
linear2_cls=linear2_self_attention_cls,
|
210 |
+
**factory_kwargs,
|
211 |
+
)
|
212 |
+
|
213 |
+
# Implementation of Feedforward model
|
214 |
+
self.linear1 = linear1_feedforward_cls(
|
215 |
+
d_model, dim_feedforward, **factory_kwargs
|
216 |
+
)
|
217 |
+
self.dropout = nn.Dropout(dropout)
|
218 |
+
self.linear2 = linear2_feedforward_cls(
|
219 |
+
dim_feedforward, d_model, **factory_kwargs
|
220 |
+
)
|
221 |
+
|
222 |
+
self.norm_first = norm_first
|
223 |
+
self.dropout1 = nn.Dropout(dropout)
|
224 |
+
self.dropout2 = nn.Dropout(dropout)
|
225 |
+
|
226 |
+
# Legacy string support for activation function.
|
227 |
+
if isinstance(activation, str):
|
228 |
+
activation = _get_activation_fn(activation)
|
229 |
+
elif isinstance(activation, partial):
|
230 |
+
activation = activation(d_model)
|
231 |
+
elif activation == BalancedDoubleSwish:
|
232 |
+
activation = BalancedDoubleSwish(d_model)
|
233 |
+
|
234 |
+
# # We can't test self.activation in forward() in TorchScript,
|
235 |
+
# # so stash some information about it instead.
|
236 |
+
# if activation is F.relu or isinstance(activation, torch.nn.ReLU):
|
237 |
+
# self.activation_relu_or_gelu = 1
|
238 |
+
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
|
239 |
+
# self.activation_relu_or_gelu = 2
|
240 |
+
# else:
|
241 |
+
# self.activation_relu_or_gelu = 0
|
242 |
+
self.activation = activation
|
243 |
+
|
244 |
+
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
245 |
+
if layer_norm_cls == IdentityNorm:
|
246 |
+
norm2 = BalancedBasicNorm(
|
247 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
248 |
+
)
|
249 |
+
else:
|
250 |
+
norm2 = layer_norm_cls(
|
251 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
252 |
+
)
|
253 |
+
|
254 |
+
if adaptive_layer_norm:
|
255 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
256 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
257 |
+
else:
|
258 |
+
self.norm1 = norm1
|
259 |
+
self.norm2 = norm2
|
260 |
+
|
261 |
+
def __setstate__(self, state):
|
262 |
+
super(TransformerEncoderLayer, self).__setstate__(state)
|
263 |
+
if not hasattr(self, "activation"):
|
264 |
+
self.activation = F.relu
|
265 |
+
|
266 |
+
def forward(
|
267 |
+
self,
|
268 |
+
src: Tensor,
|
269 |
+
src_mask: Optional[Tensor] = None,
|
270 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
271 |
+
need_weights: Optional[bool] = False,
|
272 |
+
past: Optional[Tensor] = None,
|
273 |
+
) -> Tensor:
|
274 |
+
r"""Pass the input through the encoder layer.
|
275 |
+
Args:
|
276 |
+
src: the sequence to the encoder layer (required).
|
277 |
+
src_mask: the mask for the src sequence (optional).
|
278 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
279 |
+
Shape:
|
280 |
+
see the docs in Transformer class.
|
281 |
+
"""
|
282 |
+
x, stage_embedding = src, None
|
283 |
+
is_src_tuple = False
|
284 |
+
if isinstance(src, tuple):
|
285 |
+
x, stage_embedding = src
|
286 |
+
is_src_tuple = True
|
287 |
+
|
288 |
+
if src_key_padding_mask is not None:
|
289 |
+
_skpm_dtype = src_key_padding_mask.dtype
|
290 |
+
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
291 |
+
src_key_padding_mask
|
292 |
+
):
|
293 |
+
raise AssertionError(
|
294 |
+
"only bool and floating types of key_padding_mask are supported"
|
295 |
+
)
|
296 |
+
if need_weights:
|
297 |
+
if self.norm_first:
|
298 |
+
out, attn = self._sa_block_attn(
|
299 |
+
self.norm1(x, stage_embedding),
|
300 |
+
src_mask,
|
301 |
+
src_key_padding_mask,
|
302 |
+
past
|
303 |
+
)
|
304 |
+
out, present = out # present is the kvcache of the present timestep
|
305 |
+
x = x + out
|
306 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
307 |
+
else:
|
308 |
+
out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past)
|
309 |
+
out, present = out # present is the kvcache of the present timestep
|
310 |
+
x = self.norm1(
|
311 |
+
x + out,
|
312 |
+
stage_embedding,
|
313 |
+
)
|
314 |
+
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
315 |
+
assert not is_src_tuple
|
316 |
+
# return (x, stage_embedding)
|
317 |
+
return (x, attn)
|
318 |
+
else:
|
319 |
+
if self.norm_first:
|
320 |
+
out = self._sa_block(
|
321 |
+
self.norm1(x, stage_embedding),
|
322 |
+
src_mask,
|
323 |
+
src_key_padding_mask, past
|
324 |
+
)
|
325 |
+
out, present = out # present is the kvcache of the present timestep
|
326 |
+
x = x + out
|
327 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
328 |
+
else:
|
329 |
+
out = self._sa_block(x, src_mask, src_key_padding_mask, past)
|
330 |
+
out, present = out # present is the kvcache of the present timestep
|
331 |
+
x = self.norm1(
|
332 |
+
x + out,
|
333 |
+
stage_embedding,
|
334 |
+
)
|
335 |
+
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
336 |
+
|
337 |
+
if is_src_tuple:
|
338 |
+
x = (x, stage_embedding)
|
339 |
+
if present is not None:
|
340 |
+
x = [x, present]
|
341 |
+
return x
|
342 |
+
|
343 |
+
# self-attention block
|
344 |
+
def _sa_block(
|
345 |
+
self,
|
346 |
+
x: Tensor,
|
347 |
+
attn_mask: Optional[Tensor],
|
348 |
+
key_padding_mask: Optional[Tensor],
|
349 |
+
past: Optional[Tensor] = None,
|
350 |
+
) -> Tensor:
|
351 |
+
x = self.self_attn(
|
352 |
+
x,
|
353 |
+
x,
|
354 |
+
x,
|
355 |
+
attn_mask=attn_mask,
|
356 |
+
key_padding_mask=key_padding_mask,
|
357 |
+
need_weights=False,
|
358 |
+
past=past
|
359 |
+
)
|
360 |
+
x, present = x
|
361 |
+
return self.dropout1(x), present
|
362 |
+
|
363 |
+
# self-attention block, also return attention weights
|
364 |
+
def _sa_block_attn(
|
365 |
+
self,
|
366 |
+
x: Tensor,
|
367 |
+
attn_mask: Optional[Tensor],
|
368 |
+
key_padding_mask: Optional[Tensor],
|
369 |
+
past: Optional[Tensor] = None,
|
370 |
+
) -> Tensor:
|
371 |
+
x, attn = self.self_attn(
|
372 |
+
x,
|
373 |
+
x,
|
374 |
+
x,
|
375 |
+
attn_mask=attn_mask,
|
376 |
+
key_padding_mask=key_padding_mask,
|
377 |
+
need_weights=True,
|
378 |
+
past=past
|
379 |
+
)
|
380 |
+
x, present = x
|
381 |
+
return (self.dropout1(x), present), attn
|
382 |
+
|
383 |
+
# feed forward block
|
384 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
385 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
386 |
+
return self.dropout2(x)
|
387 |
+
|
388 |
+
|
389 |
+
class TransformerEncoder(nn.Module):
|
390 |
+
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
391 |
+
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
392 |
+
Args:
|
393 |
+
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
394 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
395 |
+
norm: the layer normalization component (optional).
|
396 |
+
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
397 |
+
(and convert back on output). This will improve the overall performance of
|
398 |
+
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
399 |
+
Examples::
|
400 |
+
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
|
401 |
+
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
|
402 |
+
>>> src = torch.rand(10, 32, 512)
|
403 |
+
>>> out = transformer_encoder(src)
|
404 |
+
"""
|
405 |
+
__constants__ = ["norm"]
|
406 |
+
|
407 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
408 |
+
super(TransformerEncoder, self).__init__()
|
409 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
410 |
+
self.num_layers = num_layers
|
411 |
+
self.norm = norm
|
412 |
+
|
413 |
+
def forward(
|
414 |
+
self,
|
415 |
+
src: Tensor,
|
416 |
+
mask: Optional[Tensor] = None,
|
417 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
418 |
+
return_layer_states: bool = False,
|
419 |
+
need_weights:Optional[bool] = False,
|
420 |
+
past: Optional[Tensor] = None,
|
421 |
+
) -> Tensor:
|
422 |
+
r"""Pass the input through the encoder layers in turn.
|
423 |
+
Args:
|
424 |
+
src: the sequence to the encoder (required).
|
425 |
+
mask: the mask for the src sequence (optional).
|
426 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
427 |
+
return_layer_states: return layers' state (optional).
|
428 |
+
Shape:
|
429 |
+
see the docs in Transformer class.
|
430 |
+
"""
|
431 |
+
if return_layer_states:
|
432 |
+
assert not need_weights
|
433 |
+
layer_states = [] # layers' output
|
434 |
+
output = src
|
435 |
+
for mod in self.layers:
|
436 |
+
output = mod(
|
437 |
+
output,
|
438 |
+
src_mask=mask,
|
439 |
+
src_key_padding_mask=src_key_padding_mask,
|
440 |
+
past=past
|
441 |
+
)
|
442 |
+
layer_states.append(output[0])
|
443 |
+
|
444 |
+
if self.norm is not None:
|
445 |
+
output = self.norm(output)
|
446 |
+
|
447 |
+
return layer_states, output
|
448 |
+
if need_weights:
|
449 |
+
assert not return_layer_states
|
450 |
+
layer_attn = [] # layers' output
|
451 |
+
output = src
|
452 |
+
for mod in self.layers:
|
453 |
+
output = mod(
|
454 |
+
output,
|
455 |
+
src_mask=mask,
|
456 |
+
src_key_padding_mask=src_key_padding_mask,
|
457 |
+
need_weights=True,
|
458 |
+
past=past
|
459 |
+
)
|
460 |
+
layer_attn.append(output[1])
|
461 |
+
|
462 |
+
if self.norm is not None:
|
463 |
+
output = self.norm(output)
|
464 |
+
|
465 |
+
return layer_attn, output
|
466 |
+
|
467 |
+
output = src
|
468 |
+
all_present = []
|
469 |
+
for n_layer, mod in enumerate(self.layers):
|
470 |
+
output = mod(
|
471 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer]
|
472 |
+
)
|
473 |
+
if isinstance(output, list):
|
474 |
+
output, present = output
|
475 |
+
all_present.append(present)
|
476 |
+
|
477 |
+
if self.norm is not None:
|
478 |
+
output = self.norm(output)
|
479 |
+
if all_present != []:
|
480 |
+
all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
|
481 |
+
output = [output, all_present]
|
482 |
+
return output
|
483 |
+
|
484 |
+
|
485 |
+
class TransformerDecoderLayer(nn.Module):
|
486 |
+
__constants__ = ["batch_first", "norm_first"]
|
487 |
+
|
488 |
+
def __init__(
|
489 |
+
self,
|
490 |
+
d_model: int,
|
491 |
+
nhead: int,
|
492 |
+
dim_feedforward: int = 2048,
|
493 |
+
dropout: float = 0.1,
|
494 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
495 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
496 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
497 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
498 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
499 |
+
batch_first: bool = False,
|
500 |
+
norm_first: bool = False,
|
501 |
+
device=None,
|
502 |
+
dtype=None,
|
503 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
504 |
+
layer_norm_eps: float = 1e-5,
|
505 |
+
adaptive_layer_norm=False,
|
506 |
+
) -> None:
|
507 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
508 |
+
super(TransformerDecoderLayer, self).__init__()
|
509 |
+
self.self_attn = MultiheadAttention(
|
510 |
+
d_model,
|
511 |
+
nhead,
|
512 |
+
dropout=dropout,
|
513 |
+
batch_first=batch_first,
|
514 |
+
linear1_cls=linear1_self_attention_cls,
|
515 |
+
linear2_cls=linear2_self_attention_cls,
|
516 |
+
**factory_kwargs,
|
517 |
+
)
|
518 |
+
self.multihead_attn = MultiheadAttention(
|
519 |
+
d_model,
|
520 |
+
nhead,
|
521 |
+
dropout=dropout,
|
522 |
+
batch_first=batch_first,
|
523 |
+
linear1_cls=linear1_self_attention_cls,
|
524 |
+
linear2_cls=linear2_self_attention_cls,
|
525 |
+
**factory_kwargs,
|
526 |
+
)
|
527 |
+
# Implementation of Feedforward model
|
528 |
+
self.linear1 = linear1_feedforward_cls(
|
529 |
+
d_model, dim_feedforward, **factory_kwargs
|
530 |
+
)
|
531 |
+
self.dropout = nn.Dropout(dropout)
|
532 |
+
self.linear2 = linear2_feedforward_cls(
|
533 |
+
dim_feedforward, d_model, **factory_kwargs
|
534 |
+
)
|
535 |
+
|
536 |
+
self.norm_first = norm_first
|
537 |
+
self.dropout1 = nn.Dropout(dropout)
|
538 |
+
self.dropout2 = nn.Dropout(dropout)
|
539 |
+
self.dropout3 = nn.Dropout(dropout)
|
540 |
+
|
541 |
+
# Legacy string support for activation function.
|
542 |
+
if isinstance(activation, str):
|
543 |
+
self.activation = _get_activation_fn(activation)
|
544 |
+
elif isinstance(activation, partial):
|
545 |
+
self.activation = activation(d_model)
|
546 |
+
elif activation == BalancedDoubleSwish:
|
547 |
+
self.activation = BalancedDoubleSwish(d_model)
|
548 |
+
else:
|
549 |
+
self.activation = activation
|
550 |
+
|
551 |
+
if adaptive_layer_norm:
|
552 |
+
norm1 = layer_norm_cls(
|
553 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
554 |
+
)
|
555 |
+
norm2 = layer_norm_cls(
|
556 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
557 |
+
)
|
558 |
+
norm3 = layer_norm_cls(
|
559 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
560 |
+
)
|
561 |
+
|
562 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
563 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
564 |
+
self.norm3 = AdaptiveLayerNorm(d_model, norm3)
|
565 |
+
else:
|
566 |
+
self.norm1 = layer_norm_cls(
|
567 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
568 |
+
)
|
569 |
+
self.norm2 = layer_norm_cls(
|
570 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
571 |
+
)
|
572 |
+
if layer_norm_cls == IdentityNorm:
|
573 |
+
self.norm3 = BalancedBasicNorm(
|
574 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
575 |
+
)
|
576 |
+
else:
|
577 |
+
self.norm3 = layer_norm_cls(
|
578 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
579 |
+
)
|
580 |
+
|
581 |
+
def forward(
|
582 |
+
self,
|
583 |
+
tgt: Tensor,
|
584 |
+
memory: Tensor,
|
585 |
+
tgt_mask: Optional[Tensor] = None,
|
586 |
+
memory_mask: Optional[Tensor] = None,
|
587 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
588 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
589 |
+
) -> Tensor:
|
590 |
+
r"""Pass the inputs (and mask) through the decoder layer.
|
591 |
+
Args:
|
592 |
+
tgt: the sequence to the decoder layer (required).
|
593 |
+
memory: the sequence from the last layer of the encoder (required).
|
594 |
+
tgt_mask: the mask for the tgt sequence (optional).
|
595 |
+
memory_mask: the mask for the memory sequence (optional).
|
596 |
+
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
597 |
+
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
598 |
+
Shape:
|
599 |
+
see the docs in Transformer class.
|
600 |
+
"""
|
601 |
+
tgt_is_tuple = False
|
602 |
+
if isinstance(tgt, tuple):
|
603 |
+
x, stage_embedding = tgt
|
604 |
+
tgt_is_tuple = True
|
605 |
+
else:
|
606 |
+
x, stage_embedding = tgt, None
|
607 |
+
|
608 |
+
if self.norm_first:
|
609 |
+
x = x + self._sa_block(
|
610 |
+
self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
|
611 |
+
)
|
612 |
+
x = x + self._mha_block(
|
613 |
+
self.norm2(x, stage_embedding),
|
614 |
+
memory,
|
615 |
+
memory_mask,
|
616 |
+
memory_key_padding_mask,
|
617 |
+
)
|
618 |
+
x = x + self._ff_block(self.norm3(x, stage_embedding))
|
619 |
+
else:
|
620 |
+
x = self.norm1(
|
621 |
+
x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
|
622 |
+
stage_embedding,
|
623 |
+
)
|
624 |
+
x = self.norm2(
|
625 |
+
x
|
626 |
+
+ self._mha_block(
|
627 |
+
x, memory, memory_mask, memory_key_padding_mask
|
628 |
+
),
|
629 |
+
stage_embedding,
|
630 |
+
)
|
631 |
+
x = self.norm3(x + self._ff_block(x), stage_embedding)
|
632 |
+
|
633 |
+
if tgt_is_tuple:
|
634 |
+
return (x, stage_embedding)
|
635 |
+
return x
|
636 |
+
|
637 |
+
# self-attention block
|
638 |
+
def _sa_block(
|
639 |
+
self,
|
640 |
+
x: Tensor,
|
641 |
+
attn_mask: Optional[Tensor],
|
642 |
+
key_padding_mask: Optional[Tensor],
|
643 |
+
) -> Tensor:
|
644 |
+
x = self.self_attn(
|
645 |
+
x,
|
646 |
+
x,
|
647 |
+
x,
|
648 |
+
attn_mask=attn_mask,
|
649 |
+
key_padding_mask=key_padding_mask,
|
650 |
+
need_weights=False,
|
651 |
+
)[0]
|
652 |
+
return self.dropout1(x)
|
653 |
+
|
654 |
+
# multihead attention block
|
655 |
+
def _mha_block(
|
656 |
+
self,
|
657 |
+
x: Tensor,
|
658 |
+
mem: Tensor,
|
659 |
+
attn_mask: Optional[Tensor],
|
660 |
+
key_padding_mask: Optional[Tensor],
|
661 |
+
) -> Tensor:
|
662 |
+
x = self.multihead_attn(
|
663 |
+
x,
|
664 |
+
mem,
|
665 |
+
mem,
|
666 |
+
attn_mask=attn_mask,
|
667 |
+
key_padding_mask=key_padding_mask,
|
668 |
+
need_weights=False,
|
669 |
+
)[0]
|
670 |
+
return self.dropout2(x)
|
671 |
+
|
672 |
+
# feed forward block
|
673 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
674 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
675 |
+
return self.dropout3(x)
|
676 |
+
|
677 |
+
|
678 |
+
def _get_clones(module, N):
|
679 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
680 |
+
|
681 |
+
|
682 |
+
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
|
683 |
+
if activation == "relu":
|
684 |
+
return F.relu
|
685 |
+
elif activation == "gelu":
|
686 |
+
return F.gelu
|
687 |
+
|
688 |
+
raise RuntimeError(
|
689 |
+
"activation should be relu/gelu, not {}".format(activation)
|
690 |
+
)
|
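To see how these pieces are meant to be driven, here is a minimal sketch (illustrative only, not part of the commit) that mirrors how voicecraft.py below constructs and calls this encoder: the input is passed as an (x, stage_embedding) tuple, and supplying `past` additionally returns the stacked per-layer key/value cache. The hyperparameter values and the seed-cache shape are placeholders copied from VoiceCraft's own usage.

import torch
from src.model.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer

d_model, nhead, num_layers = 512, 8, 2
layer = TransformerEncoderLayer(
    d_model,
    nhead,
    dim_feedforward=4 * d_model,
    batch_first=True,
    norm_first=True,
    layer_norm_cls=LayerNorm,
)
decoder = TransformerEncoder(layer, num_layers=num_layers, norm=LayerNorm(d_model))

x = torch.rand(1, 10, d_model)  # [batch, seq, d_model]
causal = torch.triu(torch.ones(10, 10), diagonal=1).bool()
attn_mask = torch.zeros(10, 10).masked_fill(causal, float("-inf"))

out, _ = decoder((x, None), mask=attn_mask)            # no kv-cache
past = torch.ones(num_layers, 2, 1)                    # placeholder seed cache, as in VoiceCraft.inference
out, present = decoder((x, None), mask=attn_mask, past=past)
if isinstance(out, tuple):                             # drop the stage-embedding slot, as dec_forward does
    out = out[0]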
src/model/modules/voicecraft.py
ADDED
@@ -0,0 +1,1999 @@
1 |
+
# cp from https://github.com/jasonppy/VoiceCraft/blob/master/models/voicecraft.py
|
2 |
+
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import logging
|
7 |
+
import argparse, copy
|
8 |
+
from typing import Dict, Optional
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torchmetrics.classification import MulticlassAccuracy
|
13 |
+
|
14 |
+
from .codebooks_patterns import DelayedPatternProvider
|
15 |
+
|
16 |
+
from ...utils.util import make_pad_mask
|
17 |
+
|
18 |
+
from .embedding import SinePositionalEmbedding, TokenEmbedding
|
19 |
+
from .transformer import (
|
20 |
+
LayerNorm,
|
21 |
+
TransformerEncoder,
|
22 |
+
TransformerEncoderLayer,
|
23 |
+
)
|
24 |
+
|
25 |
+
from argparse import Namespace
|
26 |
+
from huggingface_hub import PyTorchModelHubMixin
|
27 |
+
|
28 |
+
|
29 |
+
def top_k_top_p_filtering(
|
30 |
+
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
31 |
+
):
|
32 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
33 |
+
Args:
|
34 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
35 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
36 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
37 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
38 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
39 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
40 |
+
"""
|
41 |
+
if top_k > 0:
|
42 |
+
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
|
43 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
44 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
45 |
+
logits[indices_to_remove] = filter_value
|
46 |
+
|
47 |
+
if top_p < 1.0:
|
48 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
49 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
50 |
+
|
51 |
+
# Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
|
52 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
53 |
+
if min_tokens_to_keep > 1:
|
54 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
55 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
56 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
57 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
58 |
+
sorted_indices_to_remove[..., 0] = 0
|
59 |
+
|
60 |
+
# scatter sorted tensors to original indexing
|
61 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
62 |
+
1, sorted_indices, sorted_indices_to_remove
|
63 |
+
)
|
64 |
+
logits[indices_to_remove] = filter_value
|
65 |
+
return logits
|
66 |
+
|
67 |
+
|
68 |
+
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
69 |
+
# temperature: (`optional`) float
|
70 |
+
# The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
|
71 |
+
# top_k: (`optional`) int
|
72 |
+
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
73 |
+
# top_p: (`optional`) float
|
74 |
+
# The cumulative probability mass of the highest-probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
75 |
+
|
76 |
+
# Temperature (higher temperature => more likely to sample low probability tokens)
|
77 |
+
if temperature != 1.0:
|
78 |
+
logits = logits / temperature
|
79 |
+
# Top-p/top-k filtering
|
80 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
81 |
+
# Sample
|
82 |
+
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
|
83 |
+
return token
|
84 |
+
|
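A quick, self-contained illustration (not part of the file) of how these two sampling helpers combine; the vocabulary size and hyperparameters below are arbitrary:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 2048)  # [batch, vocab]
filtered = top_k_top_p_filtering(logits.clone(), top_k=50, top_p=0.9)
# everything outside the top-50 tokens / 90% nucleus is now -inf,
# so sampling can only pick one of the kept candidates
token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
# or, in one call, with temperature applied before filtering
token = topk_sampling(torch.randn(1, 2048), top_k=50, top_p=0.9, temperature=0.8)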
85 |
+
|
86 |
+
class VoiceCraft(
|
87 |
+
nn.Module,
|
88 |
+
PyTorchModelHubMixin,
|
89 |
+
library_name="voicecraft",
|
90 |
+
repo_url="https://github.com/jasonppy/VoiceCraft",
|
91 |
+
tags=["text-to-speech"],
|
92 |
+
):
|
93 |
+
def __new__(
|
94 |
+
cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs
|
95 |
+
) -> "VoiceCraft":
|
96 |
+
# If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json
|
97 |
+
# Won't affect instance initialization
|
98 |
+
if args is not None:
|
99 |
+
if config is not None:
|
100 |
+
raise ValueError("Cannot provide both `args` and `config`.")
|
101 |
+
config = vars(args)
|
102 |
+
return super().__new__(cls, args=args, config=config, **kwargs)
|
103 |
+
|
104 |
+
def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
|
105 |
+
super().__init__()
|
106 |
+
|
107 |
+
# If loaded from HF Hub => convert config.json to Namespace args before initializing
|
108 |
+
if args is None:
|
109 |
+
if config is None:
|
110 |
+
raise ValueError("Either `args` or `config` must be provided.")
|
111 |
+
args = Namespace(**config)
|
112 |
+
|
113 |
+
self.args = copy.copy(args)
|
114 |
+
self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
|
115 |
+
if not getattr(self.args, "special_first", False):
|
116 |
+
self.args.special_first = 0
|
117 |
+
if not getattr(self.args, "n_special", False):
|
118 |
+
self.args.n_special = 3
|
119 |
+
self.args.eos = getattr(self.args, "eos", -1)
|
120 |
+
self.eog = nn.Parameter(
|
121 |
+
torch.full((self.args.n_codebooks, 1), self.args.eog, dtype=torch.long),
|
122 |
+
requires_grad=False,
|
123 |
+
) # [K 1]
|
124 |
+
if self.args.eos > 0:
|
125 |
+
assert (
|
126 |
+
self.args.eos != self.args.audio_pad_token
|
127 |
+
and self.args.eos != self.args.empty_token
|
128 |
+
), self.args.eos
|
129 |
+
self.eos = nn.Parameter(
|
130 |
+
torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long),
|
131 |
+
requires_grad=False,
|
132 |
+
) # [K 1]
|
133 |
+
if isinstance(self.args.audio_vocab_size, str):
|
134 |
+
self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
|
135 |
+
|
136 |
+
self.n_text_tokens = self.args.text_vocab_size + 1
|
137 |
+
assert (
|
138 |
+
self.args.text_pad_token == self.args.text_vocab_size
|
139 |
+
), f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}"
|
140 |
+
|
141 |
+
self.n_audio_tokens = [
|
142 |
+
self.args.audio_vocab_size + self.args.n_special
|
143 |
+
] * self.args.n_codebooks # special tokens: empty token, EOG token, audio pad token
|
144 |
+
assert (
|
145 |
+
self.args.audio_vocab_size == self.args.empty_token
|
146 |
+
), self.args.empty_token
|
147 |
+
assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog
|
148 |
+
assert (
|
149 |
+
self.args.audio_pad_token == self.args.audio_vocab_size + 2
|
150 |
+
), self.args.audio_pad_token
|
151 |
+
|
152 |
+
self.text_embedding = TokenEmbedding(
|
153 |
+
dim_model=self.args.d_model,
|
154 |
+
vocab_size=self.n_text_tokens,
|
155 |
+
dropout=self.args.text_embedding_dropout,
|
156 |
+
)
|
157 |
+
|
158 |
+
self.audio_embedding = nn.ModuleList(
|
159 |
+
[
|
160 |
+
TokenEmbedding(
|
161 |
+
dim_model=self.args.audio_embedding_dim,
|
162 |
+
vocab_size=self.n_audio_tokens[k],
|
163 |
+
dropout=self.args.audio_embedding_dropout,
|
164 |
+
)
|
165 |
+
for k in range(self.args.n_codebooks)
|
166 |
+
]
|
167 |
+
)
|
168 |
+
self.mask_embedding = nn.Parameter(
|
169 |
+
torch.randn(self.args.max_n_spans, self.args.d_model), requires_grad=True
|
170 |
+
)
|
171 |
+
self.text_positional_embedding = SinePositionalEmbedding(
|
172 |
+
self.args.d_model,
|
173 |
+
dropout=self.args.text_positional_embedding_dropout,
|
174 |
+
scale=False,
|
175 |
+
alpha=True, # learnable scaler, scale the volume of positional embedding
|
176 |
+
)
|
177 |
+
self.audio_positional_embedding = SinePositionalEmbedding(
|
178 |
+
self.args.d_model,
|
179 |
+
dropout=self.args.audio_positional_embedding_dropout,
|
180 |
+
scale=False,
|
181 |
+
alpha=True, # learnable scaler, scale the volume of positional embedding
|
182 |
+
)
|
183 |
+
|
184 |
+
dec_layer = TransformerEncoderLayer(
|
185 |
+
self.args.d_model,
|
186 |
+
self.args.nhead,
|
187 |
+
dim_feedforward=self.args.d_model * 4,
|
188 |
+
dropout=self.args.trm_dropout,
|
189 |
+
batch_first=True,
|
190 |
+
norm_first=True,
|
191 |
+
layer_norm_cls=LayerNorm,
|
192 |
+
)
|
193 |
+
self.decoder = TransformerEncoder(
|
194 |
+
dec_layer,
|
195 |
+
num_layers=self.args.num_decoder_layers,
|
196 |
+
norm=LayerNorm(self.args.d_model),
|
197 |
+
)
|
198 |
+
|
199 |
+
self.predict_layer = nn.ModuleList(
|
200 |
+
[
|
201 |
+
nn.Sequential(
|
202 |
+
nn.Linear(self.args.d_model, self.args.audio_vocab_size // 2),
|
203 |
+
nn.GELU(),
|
204 |
+
nn.Linear(self.args.audio_vocab_size // 2, self.n_audio_tokens[k]),
|
205 |
+
)
|
206 |
+
for k in range(self.args.n_codebooks)
|
207 |
+
]
|
208 |
+
)
|
209 |
+
|
210 |
+
self.accuracy_metrics = nn.ModuleList(
|
211 |
+
[
|
212 |
+
MulticlassAccuracy(
|
213 |
+
self.n_audio_tokens[k],
|
214 |
+
top_k=10,
|
215 |
+
average="micro",
|
216 |
+
multidim_average="global",
|
217 |
+
ignore_index=None,
|
218 |
+
)
|
219 |
+
for k in range(self.args.n_codebooks)
|
220 |
+
]
|
221 |
+
)
|
222 |
+
|
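Because the class mixes in PyTorchModelHubMixin and serializes its Namespace args as config.json (see __new__ above), a published checkpoint can be restored without re-specifying the architecture. A sketch; the repo id here is hypothetical:

from src.model.modules.voicecraft import VoiceCraft

# any Hub repo whose config.json stores the Namespace fields used above would work
model = VoiceCraft.from_pretrained("some-user/voicecraft-checkpoint")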
223 |
+
def prepare_mask_intervals(self, y_lens):
|
224 |
+
mask_intervals = []
|
225 |
+
non_mask_intervals = []
|
226 |
+
|
227 |
+
for i, y_len in enumerate(y_lens):
|
228 |
+
if self.args.mask_sample_dist == "uniform":
|
229 |
+
n_spans = random.choice(range(1, self.args.max_n_spans + 1))
|
230 |
+
elif "poisson" in self.args.mask_sample_dist.lower():
|
231 |
+
param = float(self.args.mask_sample_dist[len("poisson") :])
|
232 |
+
poisson_sample = torch.poisson(torch.tensor([param]))
|
233 |
+
n_spans = int(poisson_sample.clamp(1, self.args.max_n_spans).item())
|
234 |
+
|
235 |
+
starts = random.sample(
|
236 |
+
range(1, y_len - 1 - self.args.mask_len_min), n_spans
|
237 |
+
)
|
238 |
+
starts = sorted(starts)
|
239 |
+
|
240 |
+
for j in range(len(starts) - 1, 0, -1):
|
241 |
+
if starts[j] - starts[j - 1] < self.args.min_gap:
|
242 |
+
del starts[j] # If elements are too close, delete the later one
|
243 |
+
assert (
|
244 |
+
len(starts) > 0
|
245 |
+
), f"there is no masked span left, y_len: {y_len}, sampled n_spans: {n_spans}"
|
246 |
+
|
247 |
+
temp_starts = starts + [y_len]
|
248 |
+
gaps = [
|
249 |
+
temp_starts[j + 1] - temp_starts[j] for j in range(len(temp_starts) - 1)
|
250 |
+
]
|
251 |
+
|
252 |
+
ends = []
|
253 |
+
|
254 |
+
for j, (start, gap) in enumerate(zip(starts, gaps)):
|
255 |
+
mask_len = random.randint(
|
256 |
+
self.args.mask_len_min, self.args.mask_len_max
|
257 |
+
)
|
258 |
+
# if mask_len > gap * self.args.max_mask_portion: # make sure the masks are not overlapping with each other
|
259 |
+
if (
|
260 |
+
mask_len > gap - 1
|
261 |
+
): # make sure the masks are not overlapping with each other
|
262 |
+
# temp_mask_start = int(0.6*gap*self.args.max_mask_portion)
|
263 |
+
# temp_mask_end = int(gap*self.args.max_mask_portion)
|
264 |
+
temp_mask_start = 1
|
265 |
+
temp_mask_end = gap - 1
|
266 |
+
mask_len = random.randint(temp_mask_start, temp_mask_end)
|
267 |
+
ends.append(start + mask_len)
|
268 |
+
|
269 |
+
mask_intervals.append([(s, e) for s, e in zip(starts, ends)])
|
270 |
+
non_mask_intervals.append(
|
271 |
+
[(ns, ne) for ns, ne in zip([0] + ends, starts + [y_len])]
|
272 |
+
)
|
273 |
+
|
274 |
+
return mask_intervals, non_mask_intervals
|
275 |
+
|
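To make the interval bookkeeping concrete, here is a purely illustrative outcome for y_len = 100 with two sampled spans:

y_len = 100
starts, ends = [20, 60], [30, 75]
mask_intervals = [(20, 30), (60, 75)]                # the spans the model must infill
non_mask_intervals = [(0, 20), (30, 60), (75, 100)]  # the surrounding unmasked context
# note len(non_mask_intervals) == len(mask_intervals) + 1, which rearrange() relies on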
276 |
+
def rearrange(self, y, non_mask_intervals, mask_intervals):
|
277 |
+
reduced_eog = getattr(self.args, "reduced_eog", 0)
|
278 |
+
rearranged_y = []
|
279 |
+
for i in range(len(y)):
|
280 |
+
if self.args.eos > 0:
|
281 |
+
assert reduced_eog
|
282 |
+
cur_y = (
|
283 |
+
[y[i, :, item[0] : item[1]] for item in non_mask_intervals[i][:-1]]
|
284 |
+
+ [
|
285 |
+
torch.cat(
|
286 |
+
[
|
287 |
+
y[
|
288 |
+
i,
|
289 |
+
:,
|
290 |
+
non_mask_intervals[i][-1][0] : non_mask_intervals[
|
291 |
+
i
|
292 |
+
][-1][1],
|
293 |
+
],
|
294 |
+
self.eos,
|
295 |
+
],
|
296 |
+
dim=-1,
|
297 |
+
)
|
298 |
+
]
|
299 |
+
+ [
|
300 |
+
torch.cat([y[i, :, item[0] : item[1]], self.eog], dim=-1)
|
301 |
+
for item in mask_intervals[i]
|
302 |
+
]
|
303 |
+
) # only insert eog to the last non-mask-interval, which is when the utterance actually ends
|
304 |
+
else:
|
305 |
+
if reduced_eog:
|
306 |
+
cur_y = (
|
307 |
+
[
|
308 |
+
y[i, :, item[0] : item[1]]
|
309 |
+
for item in non_mask_intervals[i][:-1]
|
310 |
+
]
|
311 |
+
+ [
|
312 |
+
torch.cat(
|
313 |
+
[
|
314 |
+
y[
|
315 |
+
i,
|
316 |
+
:,
|
317 |
+
non_mask_intervals[i][-1][
|
318 |
+
0
|
319 |
+
] : non_mask_intervals[i][-1][1],
|
320 |
+
],
|
321 |
+
self.eog,
|
322 |
+
],
|
323 |
+
dim=-1,
|
324 |
+
)
|
325 |
+
]
|
326 |
+
+ [
|
327 |
+
torch.cat([y[i, :, item[0] : item[1]], self.eog], dim=-1)
|
328 |
+
for item in mask_intervals[i]
|
329 |
+
]
|
330 |
+
) # only insert eog to the last non-mask-interval, which is when the utterance actually ends
|
331 |
+
else:
|
332 |
+
cur_y = [
|
333 |
+
torch.cat([y[i, :, item[0] : item[1]], self.eog], dim=-1)
|
334 |
+
for item in non_mask_intervals[i]
|
335 |
+
] + [
|
336 |
+
torch.cat([y[i, :, item[0] : item[1]], self.eog], dim=-1)
|
337 |
+
for item in mask_intervals[i]
|
338 |
+
] # eog is added to each section. TODO: this is not correct; eog should not be added to a non_mask_interval unless it is the ending segment (there is no way for the model to predict eog for the other segments, and this harms the tts experiment, where the model randomly outputs eog for the first segment)
|
339 |
+
rearranged_y.append(cur_y)
|
340 |
+
return rearranged_y
|
341 |
+
|
342 |
+
def shift(self, rearranged_y):
|
343 |
+
shifted_y = []
|
344 |
+
patterns = []
|
345 |
+
for i in range(len(rearranged_y)):
|
346 |
+
cur_patterns = [
|
347 |
+
self.pattern.get_pattern(cur_y.shape[1]) for cur_y in rearranged_y[i]
|
348 |
+
]
|
349 |
+
out = [
|
350 |
+
cur_pattern.build_pattern_sequence(
|
351 |
+
z=cur_y.unsqueeze(0).contiguous(),
|
352 |
+
special_token=self.args.empty_token,
|
353 |
+
keep_only_valid_steps=False,
|
354 |
+
)
|
355 |
+
for cur_pattern, cur_y in zip(cur_patterns, rearranged_y[i])
|
356 |
+
]
|
357 |
+
shifted_y.append(
|
358 |
+
[item[0].squeeze(0) for item in out]
|
359 |
+
) # the first item is the values; the latter two are the indices and the mask
|
360 |
+
patterns.append(cur_patterns)
|
361 |
+
return shifted_y, patterns
|
362 |
+
|
363 |
+
def insert_mask(self, shifted_y):
|
364 |
+
inserted_y = []
|
365 |
+
mask_position = []
|
366 |
+
mask_value = []
|
367 |
+
for i in range(len(shifted_y)):
|
368 |
+
num_masks = (len(shifted_y[i]) - 1) // 2
|
369 |
+
assert num_masks == (len(shifted_y[i]) - 1) / 2, len(shifted_y[i])
|
370 |
+
emb_inds = list(range(self.args.max_n_spans))
|
371 |
+
if self.args.shuffle_mask_embedding:
|
372 |
+
random.shuffle(emb_inds)
|
373 |
+
emb_inds_use = emb_inds[:num_masks]
|
374 |
+
emb_inds_use = emb_inds_use + emb_inds_use
|
375 |
+
mask_value.append(emb_inds_use)
|
376 |
+
cur_inserted_y = []
|
377 |
+
cur_mask_position = []
|
378 |
+
for j in range(len(shifted_y[i]) - 1):
|
379 |
+
cur_inserted_y.append(shifted_y[i][j])
|
380 |
+
cur_mask_position.append(
|
381 |
+
sum([item.shape[1] for item in cur_inserted_y])
|
382 |
+
) # each item is of shape [K S], so take shape[1]
|
383 |
+
cur_inserted_y.append(
|
384 |
+
self.eog
|
385 |
+
) # insert mask token of shape [K, 1]; BUT we are actually using the eog token as a placeholder here, as the real mask will be inserted in the embed_y function
|
386 |
+
|
387 |
+
cur_inserted_y.append(shifted_y[i][-1])
|
388 |
+
|
389 |
+
inserted_y.append(cur_inserted_y)
|
390 |
+
mask_position.append(cur_mask_position)
|
391 |
+
return inserted_y, mask_position, mask_value
|
392 |
+
|
393 |
+
def cat_y(self, inserted_y, mask_position, y_lens):
|
394 |
+
reduced_eog = getattr(self.args, "reduced_eog", 0)
|
395 |
+
cated_y = []
|
396 |
+
new_y_lens = []
|
397 |
+
for i in range(len(inserted_y)):
|
398 |
+
cur_cated_y = torch.cat(inserted_y[i], dim=1) # [K S]
|
399 |
+
cur_cated_y = cur_cated_y.transpose(1, 0) # [S K]
|
400 |
+
cur_cated_y_len = cur_cated_y.shape[0]
|
401 |
+
if reduced_eog:
|
402 |
+
assert cur_cated_y_len == y_lens[i] + len(mask_position[i]) + (
|
403 |
+
len(mask_position[i]) + 1
|
404 |
+
) * self.args.n_codebooks + (
|
405 |
+
len(mask_position[i]) / 2 + 1
|
406 |
+
), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({len(mask_position[i])}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({(len(mask_position[i]) + 1) * self.args.n_codebooks}) + (len(mask_position[i])/2 + 1) ({len(mask_position[i])/2 + 1})={y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i])/2 + 1)}"
|
407 |
+
else:
|
408 |
+
assert cur_cated_y_len == y_lens[i] + len(mask_position[i]) + (
|
409 |
+
len(mask_position[i]) + 1
|
410 |
+
) * self.args.n_codebooks + (
|
411 |
+
len(mask_position[i]) + 1
|
412 |
+
), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({len(mask_position[i])}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({(len(mask_position[i]) + 1) * self.args.n_codebooks}) + (len(mask_position[i]) + 1) ({len(mask_position[i]) + 1})" # the last term represent the inserted eog token, originally it's inserted at the end of every token, but this is wrong
|
413 |
+
new_y_lens.append(cur_cated_y_len)
|
414 |
+
cated_y.append(cur_cated_y)
|
415 |
+
|
416 |
+
cated_y = torch.nn.utils.rnn.pad_sequence(
|
417 |
+
cated_y, batch_first=False, padding_value=self.args.audio_pad_token
|
418 |
+
)
|
419 |
+
assert cated_y.shape == torch.Size(
|
420 |
+
[max(new_y_lens), len(inserted_y), self.args.n_codebooks]
|
421 |
+
), f"cated_y.shape: {cated_y.shape}, but it should be {torch.Size([max(new_y_lens,len(inserted_y), self.args.n_codebooks)])}"
|
422 |
+
cated_y = cated_y.permute(2, 0, 1) # [T,B,K]->[K,T,B]
|
423 |
+
assert cated_y.shape[0] == self.args.n_codebooks, cated_y.shape
|
424 |
+
return cated_y, torch.LongTensor(new_y_lens).to(cated_y.device)
|
425 |
+
|
426 |
+
def embed_y(self, cated_y, mask_position, mask_value):
|
427 |
+
embedded_y = torch.stack(
|
428 |
+
[self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)],
|
429 |
+
dim=0,
|
430 |
+
) # [K, T, B, D]
|
431 |
+
assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
|
432 |
+
assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
|
433 |
+
embedded_y = embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D]
|
434 |
+
embedded_y = embedded_y.transpose(1, 0) # [T,B,D]->[B,T,D]
|
435 |
+
for i in range(len(embedded_y)):
|
436 |
+
if len(mask_position[i]) > 0:
|
437 |
+
embedded_y[i, mask_position[i]] = self.mask_embedding[mask_value[i]]
|
438 |
+
return embedded_y
|
439 |
+
|
440 |
+
def prepare_input_target(self, y, y_lens):
|
441 |
+
# rearrange y
|
442 |
+
# assume y shape: [B T K], K is n_codebooks
|
443 |
+
assert y.shape[1] == self.args.n_codebooks, y.shape
|
444 |
+
# sample mask_intervals
|
445 |
+
mask_intervals, non_mask_intervals = self.prepare_mask_intervals(y_lens)
|
446 |
+
|
447 |
+
# need to have EOG in each section (SOG will be generated by the pattern class)
|
448 |
+
# but mask can be inserted later after we have shifted the input
|
449 |
+
# y could be rearranged in this way:
|
450 |
+
# [
|
451 |
+
# [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32], tensor[4, 22]],
|
452 |
+
# [tensor[4, 44], tensor[4, 56], tensor[4, 19]],
|
453 |
+
# ...
|
454 |
+
# ]
|
455 |
+
# for the first list of tensors (5 tensors), the first 3 tensors are the non-masked part, the last 2 are the masked part.
|
456 |
+
# NOTE #non_masked_part = #masked_part + 1
|
457 |
+
# NOTE *these are also the targets*
|
458 |
+
# added eog at the end of each segment (masked segment and unmasked segment)
|
459 |
+
rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals)
|
460 |
+
targets = rearranged_y # each element in each sample is of shape [K T]
|
461 |
+
assert targets[0][0].shape[0] == self.args.n_codebooks, targets[0][0].shape
|
462 |
+
|
463 |
+
# next we need to apply pattern shifting to each tensor, after which, we'll replace the starting tokens of each section with a token that's different from the special padding token
|
464 |
+
# [[5, 1, 2, 3, 4, 5, 5],
|
465 |
+
# [5, 5, 1, 2, 3, 4, 5],
|
466 |
+
# [5, 5, 5, 1, 2, 3, 4]]
|
467 |
+
shifted_y, patterns = self.shift(rearranged_y) # each element [K S]
|
468 |
+
assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape[
|
469 |
+
0
|
470 |
+
]
|
471 |
+
|
472 |
+
# then, insert mask token at the intersection of each tensor (we want to decide the arrangement of the mask (shuffle or not)), we better have a separate nn.embedding for it
|
473 |
+
# we also need to record the position of the inserted mask
|
474 |
+
inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
|
475 |
+
assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][
|
476 |
+
0
|
477 |
+
].shape[0]
|
478 |
+
assert inserted_y[0][1].shape == torch.Size(
|
479 |
+
(self.args.n_codebooks, 1)
|
480 |
+
), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
|
481 |
+
|
482 |
+
# then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
|
483 |
+
cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
|
484 |
+
assert cated_y.shape == torch.Size(
|
485 |
+
(self.args.n_codebooks, cated_y.shape[1], len(inserted_y))
|
486 |
+
)
|
487 |
+
|
488 |
+
# embed remember to separately embed the mask tokens
|
489 |
+
embedded_y = self.embed_y(cated_y, mask_position, mask_value) # BTD
|
490 |
+
assert embedded_y.shape[1:] == torch.Size(
|
491 |
+
(max(new_y_lens), self.args.d_model)
|
492 |
+
), embedded_y.shape
|
493 |
+
|
494 |
+
# positional embedding
|
495 |
+
y_input = self.audio_positional_embedding(embedded_y)
|
496 |
+
|
497 |
+
# make attention mask and padding mask
|
498 |
+
y_padding_mask = make_pad_mask(new_y_lens).to(y.device)
|
499 |
+
y_attention_mask = (
|
500 |
+
torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
|
501 |
+
.bool()
|
502 |
+
.to(y_padding_mask.device)
|
503 |
+
)
|
504 |
+
return (
|
505 |
+
y_input,
|
506 |
+
new_y_lens,
|
507 |
+
targets,
|
508 |
+
y_padding_mask,
|
509 |
+
y_attention_mask,
|
510 |
+
mask_position,
|
511 |
+
patterns,
|
512 |
+
)
|
513 |
+
|
514 |
+
def remove_mask(self, logits, mask_position, new_y_lens):
|
515 |
+
# logits: [B K S card]
|
516 |
+
logits_use = []
|
517 |
+
for i in range(len(logits)):
|
518 |
+
non_mask_positions = [-1] + mask_position[i] + [new_y_lens[i]]
|
519 |
+
non_mask_intervals = [
|
520 |
+
[non_mask_positions[i] + 1, non_mask_positions[i + 1]]
|
521 |
+
for i in range(len(non_mask_positions) - 1)
|
522 |
+
]
|
523 |
+
cur_logits_use = [logits[i, :, l:r] for l, r in non_mask_intervals]
|
524 |
+
logits_use.append(cur_logits_use)
|
525 |
+
|
526 |
+
return logits_use
|
527 |
+
|
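A toy trace of the slicing above, with illustrative numbers:

# mask_position[i] = [10, 25], new_y_lens[i] = 40
# non_mask_positions = [-1, 10, 25, 40]
# non_mask_intervals = [[0, 10], [11, 25], [26, 40]]
# so the logit columns at the mask-embedding positions 10 and 25 are dropped,
# and only the non-masked stretches are kept for the loss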
528 |
+
def revert_pattern(self, patterns, logits_use):
|
529 |
+
logits_final = []
|
530 |
+
logit_masks = []
|
531 |
+
for i in range(len(logits_use)):
|
532 |
+
cur_logits = [
|
533 |
+
item.unsqueeze(0).permute(0, 3, 1, 2).contiguous()
|
534 |
+
for item in logits_use[i]
|
535 |
+
] # each item is of shape [1 K S card] [1 card K S]
|
536 |
+
cur_logits_final = [
|
537 |
+
cur_pattern.revert_pattern_logits(item, 0, keep_only_valid_steps=False)
|
538 |
+
for cur_pattern, item in zip(patterns[i], cur_logits)
|
539 |
+
] # if input output order doesn't match, this step will give an error
|
540 |
+
cur_logits_final_ret = [
|
541 |
+
item[0].permute(0, 2, 3, 1).squeeze(0) for item in cur_logits_final
|
542 |
+
] # each element is of shape [K,T,card]
|
543 |
+
logits_final.append(cur_logits_final_ret)
|
544 |
+
logit_masks.append([item[2] for item in cur_logits_final])
|
545 |
+
|
546 |
+
return logits_final, logit_masks
|
547 |
+
|
548 |
+
def dec_forward(
|
549 |
+
self,
|
550 |
+
x_input,
|
551 |
+
x_lens,
|
552 |
+
x_attention_mask,
|
553 |
+
x_padding_mask,
|
554 |
+
y_input,
|
555 |
+
new_y_lens,
|
556 |
+
y_attention_mask,
|
557 |
+
y_padding_mask,
|
558 |
+
past=None,
|
559 |
+
last_3_tokens=False,
|
560 |
+
):
|
561 |
+
x_attn_mask = F.pad(
|
562 |
+
x_attention_mask,
|
563 |
+
(0, new_y_lens.max()),
|
564 |
+
value=True,
|
565 |
+
) # x attends to all x, doesn't attend to any y; this follows figure 3 of the VALL-E paper
|
566 |
+
y_attn_mask = F.pad(
|
567 |
+
y_attention_mask,
|
568 |
+
(x_lens.max(), 0), # y is padded at the front
|
569 |
+
value=False,
|
570 |
+
) # y attends to all x; for y itself, use a lower-triangular mask to keep it autoregressive
|
571 |
+
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
|
572 |
+
|
573 |
+
# merge key padding and attention masks
|
574 |
+
bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max()
|
575 |
+
xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
|
576 |
+
_xy_padding_mask = (
|
577 |
+
xy_padding_mask.view(bsz, 1, 1, src_len)
|
578 |
+
.expand(-1, self.args.nhead, -1, -1)
|
579 |
+
.reshape(bsz * self.args.nhead, 1, src_len)
|
580 |
+
)
|
581 |
+
# Check shapes and resize+broadcast as necessary
|
582 |
+
if xy_attn_mask.shape != _xy_padding_mask.shape:
|
583 |
+
assert (
|
584 |
+
xy_attn_mask.ndim + 1 == _xy_padding_mask.ndim
|
585 |
+
), f"xy_attn_mask.shape: {xy_attn_mask.shape}, _xy_padding_mask: {_xy_padding_mask.shape}"
|
586 |
+
xy_attn_mask = xy_attn_mask.unsqueeze(0).repeat(
|
587 |
+
_xy_padding_mask.shape[0], 1, 1
|
588 |
+
) # Example approach
|
589 |
+
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
590 |
+
|
591 |
+
new_attn_mask = torch.zeros_like(xy_attn_mask)
|
592 |
+
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
|
593 |
+
xy_attn_mask = new_attn_mask
|
594 |
+
|
595 |
+
xy_input = torch.cat([x_input, y_input], dim=1)
|
596 |
+
|
597 |
+
if past is None:  # do not use kvcache
|
598 |
+
out, _ = self.decoder((xy_input, None), mask=xy_attn_mask)
|
599 |
+
return out[:, x_lens.max() :], None
|
600 |
+
else: # use kvcache
|
601 |
+
if (
|
602 |
+
past.ndim > 3
|
603 |
+
): # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet
|
604 |
+
if last_3_tokens:
|
605 |
+
xy_input = xy_input[:, -3:]
|
606 |
+
xy_attn_mask = xy_attn_mask[:, -3:]
|
607 |
+
else:
|
608 |
+
xy_input = xy_input[:, -1:]
|
609 |
+
xy_attn_mask = xy_attn_mask[:, -1:]
|
610 |
+
|
611 |
+
out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past)
|
612 |
+
if isinstance(out, tuple): # get rid of stage_embedding
|
613 |
+
out = out[0]
|
614 |
+
|
615 |
+
if out.shape[1] > x_lens.max(): # the first pass, not kvcache yet
|
616 |
+
return out[:, x_lens.max() :], present
|
617 |
+
else: # used kvcache
|
618 |
+
return out, present
|
619 |
+
|
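For intuition, the joint text/audio mask assembled above looks as follows for tiny lengths (2 text positions, 3 audio positions), before the padding merge; True marks blocked positions. A sketch:

import torch
import torch.nn.functional as F

x_len, y_len = 2, 3
x_attention_mask = torch.triu(torch.ones(x_len, x_len), diagonal=1).bool()
y_attention_mask = torch.triu(torch.ones(y_len, y_len), diagonal=1).bool()
x_attn_mask = F.pad(x_attention_mask, (0, y_len), value=True)   # text rows never look at audio
y_attn_mask = F.pad(y_attention_mask, (x_len, 0), value=False)  # audio rows see all text, causal over audio
xy_attn_mask = torch.cat([x_attn_mask, y_attn_mask], dim=0)
# tensor([[False,  True,  True,  True,  True],
#         [False, False,  True,  True,  True],
#         [False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])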
620 |
+
def forward(self, batch):
|
621 |
+
"""
|
622 |
+
Args:
|
623 |
+
x:
|
624 |
+
A 2-D tensor of shape (N, S).
|
625 |
+
x_lens:
|
626 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
627 |
+
before padding.
|
628 |
+
y:
|
629 |
+
A 3-D tensor of shape (N, K, T).
|
630 |
+
where K is the number of codebooks
|
631 |
+
y_lens:
|
632 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `y`
|
633 |
+
before padding.
|
634 |
+
"""
|
635 |
+
x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
|
636 |
+
if len(x) == 0:
|
637 |
+
return None
|
638 |
+
x = x[
|
639 |
+
:, : x_lens.max()
|
640 |
+
] # this deals with gradient accumulation, where x_lens.max() might be shorter than the length of the current slice of x
|
641 |
+
y = y[:, :, : y_lens.max()]
|
642 |
+
assert x.ndim == 2, x.shape
|
643 |
+
assert x_lens.ndim == 1, x_lens.shape
|
644 |
+
assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
|
645 |
+
assert y_lens.ndim == 1, y_lens.shape
|
646 |
+
# makes attention mask and padding mask for x
|
647 |
+
x_padding_mask = make_pad_mask(x_lens).to(x.device)
|
648 |
+
x_attention_mask = (
|
649 |
+
torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1)
|
650 |
+
.bool()
|
651 |
+
.to(x_padding_mask.device)
|
652 |
+
)
|
653 |
+
x_input = self.text_embedding(x)
|
654 |
+
x_input = self.text_positional_embedding(x_input)
|
655 |
+
(
|
656 |
+
y_input,
|
657 |
+
new_y_lens,
|
658 |
+
targets,
|
659 |
+
y_padding_mask,
|
660 |
+
y_attention_mask,
|
661 |
+
mask_position,
|
662 |
+
patterns,
|
663 |
+
) = self.prepare_input_target(y, y_lens)
|
664 |
+
y_out = self.dec_forward(
|
665 |
+
x_input,
|
666 |
+
x_lens,
|
667 |
+
x_attention_mask,
|
668 |
+
x_padding_mask,
|
669 |
+
y_input,
|
670 |
+
new_y_lens,
|
671 |
+
y_attention_mask,
|
672 |
+
y_padding_mask,
|
673 |
+
)
|
674 |
+
y_out = y_out[0] # no kv-caching during training
|
675 |
+
assert (
|
676 |
+
y_out.shape == y_input.shape
|
677 |
+
), f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D]
|
678 |
+
|
679 |
+
logits = torch.stack(
|
680 |
+
[self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1
|
681 |
+
) # [B K S card]
|
682 |
+
# take out the mask token (using mask_position and new_y_lens) and revert (using function provided by self.pattern)
|
683 |
+
assert (
|
684 |
+
logits.shape[1] == self.args.n_codebooks
|
685 |
+
and logits.shape[3] == self.n_audio_tokens[0]
|
686 |
+
), logits.shape
|
687 |
+
|
688 |
+
logits_use = self.remove_mask(logits, mask_position, new_y_lens)
|
689 |
+
|
690 |
+
# revert the pattern shift for each logits section in each sample
|
691 |
+
logits_final, logit_masks = self.revert_pattern(patterns, logits_use)
|
692 |
+
assert (
|
693 |
+
logits_final[0][0].shape[0] == self.args.n_codebooks
|
694 |
+
and logits_final[0][0].shape[2] == self.n_audio_tokens[0]
|
695 |
+
), f"it is: {logits_final[0][0].shape}, but should be [K, T, card]"
|
696 |
+
# testing
|
697 |
+
sample_to_test = 0
|
698 |
+
assert len(logits_final[sample_to_test]) == len(
|
699 |
+
targets[sample_to_test]
|
700 |
+
), f"{len(logits_final[sample_to_test])}, {len(targets[sample_to_test])}"
|
701 |
+
temp = sum(
|
702 |
+
[
|
703 |
+
logits_final[sample_to_test][i].shape[:-1]
|
704 |
+
!= targets[sample_to_test][i].shape
|
705 |
+
for i in range(len(targets[sample_to_test]))
|
706 |
+
]
|
707 |
+
)
|
708 |
+
assert (
|
709 |
+
temp == 0
|
710 |
+
), f"none equal positions: {temp}, total number of elements: {len(targets[sample_to_test])}"
|
711 |
+
|
712 |
+
logit_masked = sum(
|
713 |
+
[(item == False).any() for cur_mask in logit_masks for item in cur_mask]
|
714 |
+
)
|
715 |
+
assert logit_masked == 0, logit_masks
|
716 |
+
|
717 |
+
logits = torch.cat(
|
718 |
+
[torch.cat(item, dim=1) for item in logits_final], dim=1
|
719 |
+
) # [K, T1+T2+T3+..., card]
|
720 |
+
targets = torch.cat(
|
721 |
+
[torch.cat(item, dim=1) for item in targets], dim=1
|
722 |
+
) # [K, T1+T2+T3+...]
|
723 |
+
assert targets.shape[0] == logits.shape[0], f"{targets.shape}, {logits.shape}"
|
724 |
+
loss = []
|
725 |
+
ntokens = []
|
726 |
+
top10acc = []
|
727 |
+
for k, (logit, target) in enumerate(zip(logits, targets)):
|
728 |
+
loss.append(F.cross_entropy(logit, target, reduction="mean"))
|
729 |
+
top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
|
730 |
+
ntokens.append(len(logit))
|
731 |
+
|
732 |
+
all_ntokens = sum(ntokens)
|
733 |
+
if self.args.codebook_weight is not None:
|
734 |
+
codebook_weight = eval(self.args.codebook_weight)
|
735 |
+
else:
|
736 |
+
codebook_weight = [1.0] * self.args.n_codebooks
|
737 |
+
loss = sum([l * nt * cw for l, nt, cw in zip(loss, ntokens, codebook_weight)])
|
738 |
+
top10acc_by_codebook = [t10a * nt for t10a, nt in zip(top10acc, ntokens)]
|
739 |
+
top10acc = sum(top10acc_by_codebook)
|
740 |
+
ntokens = torch.tensor(all_ntokens).to(logits.device)
|
741 |
+
|
742 |
+
return {
|
743 |
+
"loss": loss,
|
744 |
+
"top10acc": top10acc,
|
745 |
+
"top10acc_by_codebook": top10acc_by_codebook,
|
746 |
+
"effective_ntoken": ntokens,
|
747 |
+
}
|
748 |
+
|
749 |
+
def inference(
|
750 |
+
self,
|
751 |
+
x: torch.Tensor,
|
752 |
+
x_lens: torch.Tensor,
|
753 |
+
y: torch.Tensor,
|
754 |
+
mask_interval: list[torch.Tensor],
|
755 |
+
top_k: int = -100,
|
756 |
+
top_p: float = 1.0,
|
757 |
+
temperature: float = 1.0,
|
758 |
+
stop_repetition: int = -1,
|
759 |
+
kvcache: int = 1,
|
760 |
+
silence_tokens: list[int] = [1388, 1898, 131],
|
761 |
+
) -> torch.Tensor:
|
762 |
+
"""
|
763 |
+
Args:
|
764 |
+
x:
|
765 |
+
A 2-D tensor of shape (1, L).
|
766 |
+
x_lens:
|
767 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
768 |
+
before padding.
|
769 |
+
y:
|
770 |
+
A 3-D tensor of shape (1, T, K).
|
771 |
+
mask_interval:
|
772 |
+
a list of tensors of shape (M, 2); each contains M (mask_start, mask_end) pairs. The list length is actually 1, because we only support single-sample inference for now
|
773 |
+
top_k: (`optional`) int
|
774 |
+
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
775 |
+
top_p: (`optional`) float
|
776 |
+
For nucleus (top-p) sampling
|
777 |
+
temperature: (`optional`) float
|
778 |
+
The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
|
779 |
+
eog_coef: (`optional`) float
|
780 |
+
if 0, no change to eog token logits, otherwise, will adjust eog token logit based on the difference between acoustic token and phn token length
|
781 |
+
stop_repetition (`optional`) int
|
782 |
+
if not -1, will set the logits of a token that has repeated this many times to -100000, to avoid generating it again. This only applies to tokens from the first codebook
|
783 |
+
allowed_repeat_tokens (`optional`) list of ints
|
784 |
+
by inspecting the validation set, get a few tokens that indeed repeat a significant amount of the time, and exclude those tokens from repetition prevention
|
785 |
+
ultimate_stop_repetition (`optional`) int
|
786 |
+
no matter what token it is, stop the repetition after this number of repeats
|
787 |
+
"""
|
788 |
+
assert x.ndim == 2, x.shape
|
789 |
+
assert x_lens.ndim == 1, x_lens.shape
|
790 |
+
assert y.ndim == 3, y.shape
|
791 |
+
if self.args.special_first:
|
792 |
+
y = y + int(self.args.n_special)
|
793 |
+
y = y.transpose(2, 1) # [1,T,K] -> [1,K,T]
|
794 |
+
assert (
|
795 |
+
y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks
|
796 |
+
), y.shape # there is no padding
|
797 |
+
assert mask_interval.shape == torch.Size(
|
798 |
+
(1, mask_interval.shape[1], 2)
|
799 |
+
), mask_interval
|
800 |
+
|
801 |
+
# make x attention mask and x_input
|
802 |
+
x_attention_mask = (
|
803 |
+
torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1)
|
804 |
+
.bool()
|
805 |
+
.to(x.device)
|
806 |
+
)
|
807 |
+
# x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
|
808 |
+
x_input = self.text_embedding(x)
|
809 |
+
x_input = self.text_positional_embedding(x_input)
|
810 |
+
|
811 |
+
# make initial y_input
|
812 |
+
|
813 |
+
# make mask_interval and non_mask_interval
|
814 |
+
y_len = y.shape[2]
|
815 |
+
y_lens = torch.LongTensor([y_len]).to(y.device)
|
816 |
+
mask_interval = mask_interval[0]
|
817 |
+
starts = [item[0].item() for item in mask_interval] + [y_len]
|
818 |
+
ends = [0] + [item[1].item() for item in mask_interval]
|
819 |
+
mask_intervals = [
|
820 |
+
[(item[0].item(), item[1].item()) for item in mask_interval]
|
821 |
+
] # a weird name change: mask_interval is the input; mask_intervals, with one more dimension, is used from here on
|
822 |
+
non_mask_intervals = [[(ns, ne) for ns, ne in zip(ends, starts)]]
|
823 |
+
|
824 |
+
# rearrange y
|
825 |
+
# will have EOG in each section (SOG will be generated by the pattern class)
|
826 |
+
# but mask can be inserted later after we have shifted the input
|
827 |
+
# y could be rearranged in this way:
|
828 |
+
# [
|
829 |
+
# [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32], tensor[4, 22]],
|
830 |
+
# [tensor[4, 44], tensor[4, 56], tensor[4, 19]],
|
831 |
+
# ...
|
832 |
+
# ]
|
833 |
+
# for the first list of tensors (5 tensors), the first 3 tensors are the non-masked part, the last 2 are the masked part.
|
834 |
+
# NOTE #non_masked_part = #masked_part + 1
|
835 |
+
rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals)
|
836 |
+
assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][
|
837 |
+
0
|
838 |
+
].shape
|
839 |
+
|
840 |
+
# shift each element of y
|
841 |
+
# next we need to apply pattern shifting to each tensor, after which, we'll replace the starting tokens of each section with a token that's different from the special padding token
|
842 |
+
# [
|
843 |
+
# [empty, 1, 2, 3, eog, empty, empty, empty],
|
844 |
+
# [empty, empty, 1, 2, 3, eog, empty, empty],
|
845 |
+
# [empty, empty, empty, 1, 2, 3, eog, empty],
|
846 |
+
# [empty, empty, empty, empty, 1, 2, 3, eog]
|
847 |
+
# ]
|
848 |
+
shifted_y, patterns = self.shift(
|
849 |
+
rearranged_y
|
850 |
+
) # each element [K S], patterns is not used, as we directly use the original input y
|
851 |
+
assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
|
852 |
+
|
853 |
+
# insert mask token at the intersection of each tensor, but *actually insert eog as a placeholder*
|
854 |
+
# the position of inserted mask is also recorded
|
855 |
+
# and the mask_value, the index of the mask emb is recorded
|
856 |
+
inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
|
857 |
+
assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][
|
858 |
+
0
|
859 |
+
].shape[0]
|
860 |
+
assert inserted_y[0][1].shape == torch.Size(
|
861 |
+
(self.args.n_codebooks, 1)
|
862 |
+
), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
|
863 |
+
|
864 |
+
# then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
|
865 |
+
cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
|
866 |
+
assert cated_y.shape == torch.Size(
|
867 |
+
(self.args.n_codebooks, cated_y.shape[1], len(inserted_y))
|
868 |
+
)
|
869 |
+
assert not (cated_y == self.args.audio_pad_token).any(), cated_y
|
870 |
+
|
871 |
+
### NOTE this is different from forward, as we will remove the masked tokens
|
872 |
+
### say there are two masked region
|
873 |
+
### the cated_y should be like
|
874 |
+
### [empty a a a a mask0 empty b b b mask1 empty c c mask0 empty]
|
875 |
+
### which means we need to take the part after the last empty out
|
876 |
+
num_mask = len(mask_position[0]) // 2
|
877 |
+
assert num_mask == len(mask_position[0]) / 2, mask_position
|
878 |
+
cated_y = cated_y[:, : mask_position[0][num_mask] + 2] # of shape [K,T,B]
|
879 |
+
# logging.info(f"mask_position[0][num_mask]+2: {mask_position[0][num_mask]+2}")
|
880 |
+
more_mask_value = mask_value[0][
|
881 |
+
num_mask + 1 :
|
882 |
+
] # NOTE this will be used in the generation loop as a reference for inserting the mask embedding
|
883 |
+
new_y_lens[0] = mask_position[0][num_mask] + 2
|
884 |
+
mask_position[0] = mask_position[0][: num_mask + 1]
|
885 |
+
assert (
|
886 |
+
mask_position[0][num_mask] + 2 == cated_y.shape[1]
|
887 |
+
), f"num_mask: {num_mask}, mask_position: {mask_position}, cated_y.shape: {cated_y.shape}"
|
888 |
+
|
889 |
+
# embed: remember to separately embed the mask tokens
|
890 |
+
embedded_y = self.embed_y(
|
891 |
+
cated_y, mask_position, [mask_value[0][: num_mask + 1]]
|
892 |
+
) # BTD
|
893 |
+
# assert embedded_y.shape == torch.Size((y.shape[0], max(new_y_lens), self.args.d_model)), embedded_y.shape
|
894 |
+
|
895 |
+
# positional embedding
|
896 |
+
y_input = self.audio_positional_embedding(embedded_y)
|
897 |
+
|
898 |
+
# make attention mask and padding mask
|
899 |
+
y_attention_mask = (
|
900 |
+
torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
|
901 |
+
.bool()
|
902 |
+
.to(y.device)
|
903 |
+
)
|
904 |
+
# y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
|
905 |
+
|
906 |
+
x_padding_mask = torch.full((1, x_lens[0]), False).to(x.device)
|
907 |
+
y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)
|
908 |
+
|
909 |
+
codebook_eog = [False] * self.args.n_codebooks
|
910 |
+
generated = [] # doesn't contain any empty_token, contains eog
|
911 |
+
cur_generated = []
|
912 |
+
# say 0 is empty, 4 is eog
|
913 |
+
# tensor([[ 1, 2, 3, 4, 0, 0],
|
914 |
+
# [ 0, 1, 2, 3, 4, 0],
|
915 |
+
# [ 0, 0, 1, 2, 3, 4]])
|
916 |
+
num_gen = []
|
917 |
+
cur_num_gen = 0
|
918 |
+
##################### silence repetition handling #####################
|
919 |
+
##################### silence repetition handling #####################
|
920 |
+
logging.info(
|
921 |
+
f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default"
|
922 |
+
)
|
923 |
+
consec_silence_count = 0
|
924 |
+
prev_token = None
|
925 |
+
##################### silence repetition handling #####################
|
926 |
+
##################### silence repetition handling #####################
|
927 |
+
# prepare the cache placeholder
|
928 |
+
# n_layers, 2, bsz, num_heads, src_len, head_dim
|
929 |
+
past = (
|
930 |
+
torch.ones(
|
931 |
+
[self.args.num_decoder_layers, 2, x.shape[0]],
|
932 |
+
device=x.device,
|
933 |
+
dtype=torch.float32,
|
934 |
+
)
|
935 |
+
if kvcache
|
936 |
+
else None
|
937 |
+
)
|
938 |
+
# handle multi-span kv-cache
|
939 |
+
new_masked_span = False
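As a rough sketch of how the past placeholder above is expected to evolve (shapes are assumptions taken from the comment "n_layers, 2, bsz, num_heads, src_len, head_dim"; the actual cache handling happens inside dec_forward):

import torch

def update_kv_cache(past, present):
    # the placeholder stays 3-D until the first decoder pass returns a real cache
    if past is None:
        return None                                # kvcache disabled
    if past.ndim <= 3:
        return present                             # first step: adopt the returned cache
    return torch.cat([past, present], dim=-2)      # later steps: grow along src_len

past = torch.ones([16, 2, 1])                                    # [n_layers, 2, bsz] placeholder
past = update_kv_cache(past, torch.randn(16, 2, 1, 12, 7, 64))   # -> [..., 7, 64]
past = update_kv_cache(past, torch.randn(16, 2, 1, 12, 1, 64))   # -> [..., 8, 64]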
|
940 |
+
|
941 |
+
def sample_helper(
|
942 |
+
n_eog,
|
943 |
+
logits,
|
944 |
+
codebook_eog,
|
945 |
+
top_k,
|
946 |
+
top_p,
|
947 |
+
temperature,
|
948 |
+
prev_token,
|
949 |
+
consec_silence_count,
|
950 |
+
stop_repetition,
|
951 |
+
silence_tokens,
|
952 |
+
cur_num_gen,
|
953 |
+
):
|
954 |
+
if n_eog == 0:
|
955 |
+
logits_adjust = logits
|
956 |
+
for jj in range(1, self.args.n_codebooks):
|
957 |
+
logits_adjust[jj][self.args.eog] = -10000
|
958 |
+
logits_adjust[jj][self.args.empty_token] = -10000
|
959 |
+
##################### silence repetition handling #####################
|
960 |
+
if (
|
961 |
+
stop_repetition > 0
|
962 |
+
and prev_token in silence_tokens
|
963 |
+
and consec_silence_count > stop_repetition
|
964 |
+
):
|
965 |
+
if logits_adjust[0, prev_token] < 0:
|
966 |
+
logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (
|
967 |
+
consec_silence_count - (stop_repetition - 1)
|
968 |
+
)
|
969 |
+
else:
|
970 |
+
logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (
|
971 |
+
consec_silence_count - (stop_repetition - 1)
|
972 |
+
)
|
973 |
+
##################### silence repetition handling #####################
|
974 |
+
if type(logits_adjust) == list:
|
975 |
+
samples_list = []
|
976 |
+
for logit in logits_adjust:
|
977 |
+
# print(logit)
|
978 |
+
# print(logit.shape)
|
979 |
+
cur_sample = topk_sampling(
|
980 |
+
logit.unsqueeze(0),
|
981 |
+
top_k=top_k,
|
982 |
+
top_p=top_p,
|
983 |
+
temperature=temperature,
|
984 |
+
) # [1, 1]
|
985 |
+
samples_list.append(cur_sample)
|
986 |
+
samples = torch.cat(samples_list, dim=0) # [K, 1]
|
987 |
+
else:
|
988 |
+
samples = topk_sampling(
|
989 |
+
logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
|
990 |
+
) # [K, 1]
|
991 |
+
assert samples.shape == torch.Size(
|
992 |
+
(self.args.n_codebooks, 1)
|
993 |
+
), f"samples.shape: {samples.shape}"
|
994 |
+
if cur_num_gen < self.args.n_codebooks - 1:
|
995 |
+
for jj in range(1, self.args.n_codebooks - cur_num_gen):
|
996 |
+
samples[-jj, 0] = self.args.empty_token
|
997 |
+
|
998 |
+
if (
|
999 |
+
samples[0, 0] == self.args.eog
|
1000 |
+
or torch.argmax(logits[0], dim=-1) == self.args.eog
|
1001 |
+
or y_input.shape[1] > x_lens[0] * 10
|
1002 |
+
): # last one means y is already too long, shouldn't happen, but put it here
|
1003 |
+
samples[0, 0] = self.args.eog
|
1004 |
+
codebook_eog[0] = True
|
1005 |
+
##################### silence repetition handling #####################
|
1006 |
+
##################### silence repetition handling #####################
|
1007 |
+
if samples[0, 0] in silence_tokens and samples[0, 0] == prev_token:
|
1008 |
+
consec_silence_count += 1
|
1009 |
+
else:
|
1010 |
+
consec_silence_count = 0
|
1011 |
+
prev_token = samples[0, 0]
|
1012 |
+
##################### silence repetition handling #####################
|
1013 |
+
##################### silence repetition handling #####################
|
1014 |
+
return samples, codebook_eog, prev_token, consec_silence_count
|
1015 |
+
else:
|
1016 |
+
assert (
|
1017 |
+
sum(codebook_eog[i] for i in range(n_eog)) == n_eog
|
1018 |
+
), f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
|
1019 |
+
logits_adjust = logits
|
1020 |
+
for jj in range(n_eog + 1, self.args.n_codebooks):
|
1021 |
+
logits_adjust[jj][self.args.eog] = -10000
|
1022 |
+
logits_adjust[jj][self.args.empty_token] = -10000
|
1023 |
+
if type(logits_adjust) == list:
|
1024 |
+
samples_list = []
|
1025 |
+
for logit in logits_adjust:
|
1026 |
+
cur_sample = topk_sampling(
|
1027 |
+
logit.unsqueeze(0),
|
1028 |
+
top_k=top_k,
|
1029 |
+
top_p=top_p,
|
1030 |
+
temperature=temperature,
|
1031 |
+
) # [1, 1]
|
1032 |
+
samples_list.append(cur_sample)
|
1033 |
+
samples = torch.cat(samples_list, dim=0) # [K, 1]
|
1034 |
+
else:
|
1035 |
+
samples = topk_sampling(
|
1036 |
+
logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
|
1037 |
+
) # [K, 1]
|
1038 |
+
for jj in range(n_eog):
|
1039 |
+
samples[jj, 0] = self.args.empty_token
|
1040 |
+
samples[n_eog, 0] = self.args.eog
|
1041 |
+
codebook_eog[n_eog] = True
|
1042 |
+
return samples, codebook_eog, prev_token, consec_silence_count
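topk_sampling itself is defined elsewhere in the repo; a generic top-k / top-p (nucleus) filter of the kind it is assumed to implement could look like the sketch below (illustrative only, not the committed helper):

import torch
import torch.nn.functional as F

def sample_top_k_top_p(logits, top_k=0, top_p=1.0, temperature=1.0):
    # logits: [K, card]; returns one sampled token id per row -> [K, 1]
    logits = logits / max(temperature, 1e-5)
    if top_k > 0:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()   # always keep the top-1 token
        remove[..., 0] = False
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float("-inf"))
    return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

print(sample_top_k_top_p(torch.randn(4, 2048), top_k=50, top_p=0.9).shape)  # [4, 1]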
|
1043 |
+
|
1044 |
+
while True:
|
1045 |
+
y_out, present = self.dec_forward(
|
1046 |
+
x_input,
|
1047 |
+
x_lens,
|
1048 |
+
x_attention_mask,
|
1049 |
+
x_padding_mask,
|
1050 |
+
y_input,
|
1051 |
+
new_y_lens,
|
1052 |
+
y_attention_mask,
|
1053 |
+
y_padding_mask,
|
1054 |
+
past=past,
|
1055 |
+
last_3_tokens=new_masked_span,
|
1056 |
+
)
|
1057 |
+
if new_masked_span:
|
1058 |
+
new_masked_span = False
|
1059 |
+
|
1060 |
+
if past != None:
|
1061 |
+
past = (
|
1062 |
+
torch.cat([past, present.to(past.dtype)], dim=-2)
|
1063 |
+
if past.ndim > 3
|
1064 |
+
else present.to(past.dtype)
|
1065 |
+
)
|
1066 |
+
|
1067 |
+
y_out = y_out[:, -1:] # only take the last one
|
1068 |
+
|
1069 |
+
logits = torch.stack(
|
1070 |
+
[self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)],
|
1071 |
+
dim=1,
|
1072 |
+
) # [B K S card], B==S==1, so [1 K 1 card]
|
1073 |
+
logits = logits.squeeze(0).squeeze(1) # [K card]
|
1074 |
+
assert logits.shape == torch.Size(
|
1075 |
+
(self.args.n_codebooks, self.n_audio_tokens[0])
|
1076 |
+
), f"{logits.shape}"
|
1077 |
+
|
1078 |
+
n_eog = sum(codebook_eog)
|
1079 |
+
assert n_eog < self.args.n_codebooks
|
1080 |
+
if (
|
1081 |
+
self.args.eos > 0
|
1082 |
+
): # eos stands for end-of-sentence, which shouldn't be used as we are doing speech editing
|
1083 |
+
for jj in range(self.args.n_codebooks):
|
1084 |
+
logits[jj][self.args.eos] = -10000.0
|
1085 |
+
# need to use a helper function to handle different n_eog cases
|
1086 |
+
samples, codebook_eog, prev_token, consec_silence_count = sample_helper(
|
1087 |
+
n_eog,
|
1088 |
+
logits,
|
1089 |
+
codebook_eog,
|
1090 |
+
top_k,
|
1091 |
+
top_p,
|
1092 |
+
temperature,
|
1093 |
+
prev_token,
|
1094 |
+
consec_silence_count,
|
1095 |
+
stop_repetition,
|
1096 |
+
silence_tokens,
|
1097 |
+
cur_num_gen,
|
1098 |
+
)
|
1099 |
+
cur_num_gen += 1
|
1100 |
+
cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
|
1101 |
+
# get samples_emb
|
1102 |
+
samples_emb = torch.stack(
|
1103 |
+
[
|
1104 |
+
self.audio_embedding[k](samples[k])
|
1105 |
+
for k in range(self.args.n_codebooks)
|
1106 |
+
],
|
1107 |
+
dim=0,
|
1108 |
+
) # [K,1,D]
|
1109 |
+
samples_emb = samples_emb.sum(dim=0, keepdim=True) # [1,1,D]
|
1110 |
+
|
1111 |
+
if (
|
1112 |
+
sum(codebook_eog) == self.args.n_codebooks
|
1113 |
+
): # generation for the current span is done
|
1114 |
+
# re-init
|
1115 |
+
codebook_eog = [False] * self.args.n_codebooks
|
1116 |
+
num_gen.append(cur_num_gen)
|
1117 |
+
cur_num_gen = 0
|
1118 |
+
generated.append(cur_generated)
|
1119 |
+
cur_generated = []
|
1120 |
+
|
1121 |
+
# if the current mask span is the last span, then all done
|
1122 |
+
# else
|
1123 |
+
# append the next mask token and the four empty tokens to start the next generation
|
1124 |
+
if len(more_mask_value) > 0:
|
1125 |
+
next_mask_ind = more_mask_value.pop(0)
|
1126 |
+
mask_emb = (
|
1127 |
+
self.mask_embedding[next_mask_ind].unsqueeze(0).unsqueeze(0)
|
1128 |
+
) # [1,1,D]
|
1129 |
+
assert mask_emb.shape == torch.Size(
|
1130 |
+
(1, 1, self.args.d_model)
|
1131 |
+
), mask_emb.shape
|
1132 |
+
empty_token = torch.LongTensor([self.args.empty_token]).to(y.device)
|
1133 |
+
empty_emb = torch.stack(
|
1134 |
+
[
|
1135 |
+
self.audio_embedding[k](empty_token)
|
1136 |
+
for k in range(self.args.n_codebooks)
|
1137 |
+
],
|
1138 |
+
dim=0,
|
1139 |
+
).sum(
|
1140 |
+
dim=0, keepdim=True
|
1141 |
+
) # [1,1,D]
|
1142 |
+
assert empty_emb.shape == torch.Size(
|
1143 |
+
(1, 1, self.args.d_model)
|
1144 |
+
), empty_emb.shape
|
1145 |
+
extra_emb = torch.cat([mask_emb, empty_emb], dim=1) # [1,2,D]
|
1146 |
+
samples_emb = torch.cat(
|
1147 |
+
[samples_emb, extra_emb], dim=1
|
1148 |
+
) # [1,3,D] # prev_last_token, mask_token, empty token
|
1149 |
+
assert samples_emb.shape == torch.Size(
|
1150 |
+
(1, 3, self.args.d_model)
|
1151 |
+
), f"samples_emb.shape: {samples_emb.shape}"
|
1152 |
+
##################### silence repetition handling #####################
|
1153 |
+
##################### silence repetition handling #####################
|
1154 |
+
consec_silence_count = 0
|
1155 |
+
prev_token = None
|
1156 |
+
##################### silence repetition handling #####################
|
1157 |
+
##################### silence repetition handling #####################
|
1158 |
+
|
1159 |
+
# handling kv-caching for multi-span editing
|
1160 |
+
new_masked_span = True
|
1161 |
+
else:
|
1162 |
+
break
|
1163 |
+
else:
|
1164 |
+
assert samples_emb.shape == torch.Size(
|
1165 |
+
(1, 1, self.args.d_model)
|
1166 |
+
), f"samples_emb.shape: {samples_emb.shape}"
|
1167 |
+
|
1168 |
+
embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
|
1169 |
+
# positional embedding
|
1170 |
+
y_input = self.audio_positional_embedding(embedded_y) # [B T D]
|
1171 |
+
# make attention mask and padding mask
|
1172 |
+
y_attention_mask = (
|
1173 |
+
torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
|
1174 |
+
.bool()
|
1175 |
+
.to(y.device)
|
1176 |
+
)
|
1177 |
+
new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
|
1178 |
+
y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)
|
1179 |
+
|
1180 |
+
assert (
|
1181 |
+
len(generated) == num_mask
|
1182 |
+
), f"len(generated): {len(generated)}, num_mask: {num_mask}"
|
1183 |
+
|
1184 |
+
# # combine non_masked_span with generated spans
|
1185 |
+
# first need to shift the generated part back
|
1186 |
+
flatten_gen = []
|
1187 |
+
for l, orig_span in enumerate(generated):
|
1188 |
+
span = torch.stack(orig_span, dim=0) # [T K]
|
1189 |
+
span = span.transpose(1, 0) # [K, T]
|
1190 |
+
assert span.shape[0] == self.args.n_codebooks, span.shape
|
1191 |
+
unshifted_span = []
|
1192 |
+
for j, s in enumerate(span):
|
1193 |
+
start_from = j
|
1194 |
+
end_at = -(self.args.n_codebooks - start_from)
|
1195 |
+
unshifted_span.append(s[start_from:end_at])
|
1196 |
+
unshifted_span = torch.stack(unshifted_span, dim=0)
|
1197 |
+
|
1198 |
+
assert (
|
1199 |
+
unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks
|
1200 |
+
), f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
|
1201 |
+
flatten_gen.append(unshifted_span)
|
1202 |
+
# logging.info(f"unshfited_span: {unshifted_span.shape}")
|
1203 |
+
# raise
|
1204 |
+
assert len(non_mask_intervals[0]) - 1 == len(
|
1205 |
+
flatten_gen
|
1206 |
+
), f"len(non_mask_intervals[0]): {len(non_mask_intervals[0])}, len(flatten_gen): {len(flatten_gen)}"
|
1207 |
+
res = []
|
1208 |
+
for orig_interval, gen in zip(non_mask_intervals[0], flatten_gen):
|
1209 |
+
res.append(y[0, :, orig_interval[0] : orig_interval[1]])
|
1210 |
+
res.append(gen)
|
1211 |
+
res.append(y[0, :, non_mask_intervals[0][-1][0] : non_mask_intervals[0][-1][1]])
|
1212 |
+
res = torch.cat(res, dim=1).unsqueeze(0) # [K,new_T] -> [1, K, new_T]
|
1213 |
+
|
1214 |
+
expected_y_len = (
|
1215 |
+
y_len
|
1216 |
+
- sum([item[1] - item[0] for item in mask_intervals[0]])
|
1217 |
+
+ sum([item - self.args.n_codebooks for item in num_gen])
|
1218 |
+
)
|
1219 |
+
assert res.shape == torch.Size(
|
1220 |
+
(1, self.args.n_codebooks, expected_y_len)
|
1221 |
+
), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len - sum([item[1] - item[0] for item in mask_interval]) + sum([item - self.args.n_codebooks for item in num_gen]): {y_len}-{sum([item[1] - item[0] for item in mask_interval])} + {sum([item - self.args.n_codebooks for item in num_gen])}"
|
1222 |
+
|
1223 |
+
if self.args.special_first:
|
1224 |
+
res = res - int(self.args.n_special)
|
1225 |
+
|
1226 |
+
return res
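The unshifting loop above (start_from = j, end_at = -(n_codebooks - j)) can be summarized by a small stand-alone sketch (hypothetical helper): row j of a generated span starts j steps late and ends n_codebooks - j steps early, so every codebook ends up with the same length again.

import torch

def revert_delay_pattern(span: torch.Tensor, n_codebooks: int) -> torch.Tensor:
    # span: [K, T_gen] generated tokens still in the delayed layout
    rows = [span[j, j : span.shape[1] - (n_codebooks - j)] for j in range(n_codebooks)]
    return torch.stack(rows, dim=0)   # [K, T_gen - n_codebooks]

print(revert_delay_pattern(torch.randint(0, 2048, (4, 10)), 4).shape)  # [4, 6]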
|
1227 |
+
|
1228 |
+
def inference_tts(
|
1229 |
+
self,
|
1230 |
+
x: torch.Tensor,
|
1231 |
+
x_lens: torch.Tensor,
|
1232 |
+
y: torch.Tensor,
|
1233 |
+
top_k: int = -100,
|
1234 |
+
top_p: float = 1.0,
|
1235 |
+
temperature: float = 1.0,
|
1236 |
+
stop_repetition: int = 3,
|
1237 |
+
kvcache: int = 1,
|
1238 |
+
silence_tokens: list[int] = [1388, 1898, 131],
|
1239 |
+
*kargs,
|
1240 |
+
) -> torch.Tensor:
|
1241 |
+
"""
|
1242 |
+
Unlike a plain implementation without caching, this implementation uses kvcache, which should give a significant speed-up.
|
1243 |
+
Args:
|
1244 |
+
x:
|
1245 |
+
A 2-D tensor of shape (1, L).
|
1246 |
+
x_lens:
|
1247 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
1248 |
+
before padding.
|
1249 |
+
y:
|
1250 |
+
A 3-D tensor of shape (1, T, K).
|
1251 |
+
top_k: (`optional`) int
|
1252 |
+
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
1253 |
+
top_p: (`optional`) float
|
1254 |
+
For nucleus (top-p) sampling.
|
1255 |
+
temperature: (`optional`) float
|
1256 |
+
The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
1257 |
+
"""
|
1258 |
+
eog_inference = self.args.eos if self.args.eos > 0 else self.args.eog
|
1259 |
+
assert x.ndim == 2, x.shape
|
1260 |
+
assert x_lens.ndim == 1, x_lens.shape
|
1261 |
+
assert y.ndim == 3, y.shape
|
1262 |
+
if self.args.special_first:
|
1263 |
+
y = y + int(self.args.n_special)
|
1264 |
+
y = y.transpose(2, 1) # [1,T,K] -> [1,K,T]
|
1265 |
+
assert (
|
1266 |
+
y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks
|
1267 |
+
), y.shape # there is no padding
|
1268 |
+
|
1269 |
+
# make x attention mask and x_input
|
1270 |
+
x_attention_mask = (
|
1271 |
+
torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1)
|
1272 |
+
.bool()
|
1273 |
+
.to(x.device)
|
1274 |
+
)
|
1275 |
+
# x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
|
1276 |
+
x_input = self.text_embedding(x)
|
1277 |
+
x_input = self.text_positional_embedding(x_input)
|
1278 |
+
|
1279 |
+
y_len = y.shape[2]
|
1280 |
+
y_lens = torch.LongTensor([y_len]).to(y.device)
|
1281 |
+
|
1282 |
+
# rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
|
1283 |
+
rearranged_y = [[y[0]]]
|
1284 |
+
assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][
|
1285 |
+
0
|
1286 |
+
].shape
|
1287 |
+
|
1288 |
+
# shift y to create the delayed pattern
|
1289 |
+
shifted_y, patterns = self.shift(
|
1290 |
+
rearranged_y
|
1291 |
+
) # each element [K S], patterns is not used, as we directly use the original input y
|
1292 |
+
assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
|
1293 |
+
assert len(shifted_y[0]) == 1, len(shifted_y[0])
|
1294 |
+
|
1295 |
+
# below is different from forward or inference
|
1296 |
+
# where we cut this shifted part
|
1297 |
+
shifted_y[0][0] = shifted_y[0][0][:, : -(self.args.n_codebooks - 1)]
|
1298 |
+
assert (
|
1299 |
+
not (
|
1300 |
+
shifted_y[0][0][self.args.n_codebooks :] == self.args.empty_token
|
1301 |
+
).any()
|
1302 |
+
and not (shifted_y[0][0][self.args.n_codebooks :] == self.args.eog).any()
|
1303 |
+
), shifted_y[0][0]
|
1304 |
+
|
1305 |
+
# the next section in `inference` inserts a mask at the intersection of each tensor in a sample, but we don't need to do that here
|
1306 |
+
# the next section concatenates the tensors of each sample into one tensor, which we also don't need
|
1307 |
+
cated_y = shifted_y[0][0].unsqueeze(-1) # [K,S]->[K,S,B]
|
1308 |
+
new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
|
1309 |
+
assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
|
1310 |
+
assert not (cated_y == self.args.audio_pad_token).any(), cated_y
|
1311 |
+
|
1312 |
+
# replace tokens in y with the embeddings, add sum codebooks up
|
1313 |
+
embedded_y = torch.stack(
|
1314 |
+
[self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)],
|
1315 |
+
dim=0,
|
1316 |
+
) # [K, S, B, D]
|
1317 |
+
assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
|
1318 |
+
assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
|
1319 |
+
embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
|
1320 |
+
embedded_y = embedded_y.transpose(1, 0) # [S,B,D]->[B,S,D]
|
1321 |
+
|
1322 |
+
# positional embedding
|
1323 |
+
y_input = self.audio_positional_embedding(embedded_y)
|
1324 |
+
|
1325 |
+
# make attention mask and padding mask
|
1326 |
+
y_attention_mask = (
|
1327 |
+
torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
|
1328 |
+
.bool()
|
1329 |
+
.to(y.device)
|
1330 |
+
)
|
1331 |
+
|
1332 |
+
x_padding_mask = torch.full((1, x_lens[0]), False).to(x.device)
|
1333 |
+
y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)
|
1334 |
+
|
1335 |
+
# entering the generation stage
|
1336 |
+
# starting from line 708
|
1337 |
+
codebook_eog = [False] * self.args.n_codebooks
|
1338 |
+
generated = [] # doesn't contain any empty token, contain eog
|
1339 |
+
cur_generated = []
|
1340 |
+
# say 0 is empty, 4 is eog
|
1341 |
+
# tensor([[ 1, 2, 3, 4, 0, 0],
|
1342 |
+
# [ 0, 1, 2, 3, 4, 0],
|
1343 |
+
# [ 0, 0, 1, 2, 3, 4]])
|
1344 |
+
num_gen = []
|
1345 |
+
cur_num_gen = 0
|
1346 |
+
##################### silence repetition handling #####################
|
1347 |
+
##################### silence repetition handling #####################
|
1348 |
+
logging.info(
|
1349 |
+
f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default"
|
1350 |
+
)
|
1351 |
+
consec_silence_count = 0
|
1352 |
+
prev_token = None
|
1353 |
+
##################### silence repetition handling #####################
|
1354 |
+
##################### silence repetition handling #####################
|
1355 |
+
|
1356 |
+
# prepare the cache placeholder
|
1357 |
+
# n_layers, 2, bsz, num_heads, src_len, head_dim
|
1358 |
+
past = (
|
1359 |
+
torch.ones(
|
1360 |
+
[self.args.num_decoder_layers, 2, x.shape[0]],
|
1361 |
+
device=x.device,
|
1362 |
+
dtype=torch.float32,
|
1363 |
+
)
|
1364 |
+
if kvcache
|
1365 |
+
else None
|
1366 |
+
)
|
1367 |
+
|
1368 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
1369 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
1370 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
1371 |
+
def sample_helper(
|
1372 |
+
n_eog,
|
1373 |
+
logits,
|
1374 |
+
codebook_eog,
|
1375 |
+
top_k,
|
1376 |
+
top_p,
|
1377 |
+
temperature,
|
1378 |
+
prev_token,
|
1379 |
+
consec_silence_count,
|
1380 |
+
stop_repetition,
|
1381 |
+
silence_tokens,
|
1382 |
+
cur_num_gen,
|
1383 |
+
):
|
1384 |
+
if n_eog == 0:
|
1385 |
+
logits_adjust = logits
|
1386 |
+
for jj in range(1, self.args.n_codebooks):
|
1387 |
+
logits_adjust[jj][eog_inference] = -10000
|
1388 |
+
logits_adjust[jj][self.args.empty_token] = -10000
|
1389 |
+
if (
|
1390 |
+
cur_num_gen <= self.args.encodec_sr // 5
|
1391 |
+
): # this shouldn't happen, but just in case the model stopped too early
|
1392 |
+
logits_adjust[0][eog_inference] = -10000
|
1393 |
+
##################### silence repetition handling #####################
|
1394 |
+
if (
|
1395 |
+
stop_repetition > 0
|
1396 |
+
and prev_token in silence_tokens
|
1397 |
+
and consec_silence_count > stop_repetition
|
1398 |
+
):
|
1399 |
+
if logits_adjust[0, prev_token] < 0:
|
1400 |
+
logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (
|
1401 |
+
consec_silence_count - (stop_repetition - 1)
|
1402 |
+
)
|
1403 |
+
else:
|
1404 |
+
logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (
|
1405 |
+
consec_silence_count - (stop_repetition - 1)
|
1406 |
+
)
|
1407 |
+
##################### silence repetition handling #####################
|
1408 |
+
samples = topk_sampling(
|
1409 |
+
logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
|
1410 |
+
) # [K, 1]
|
1411 |
+
assert samples.shape == torch.Size(
|
1412 |
+
(self.args.n_codebooks, 1)
|
1413 |
+
), f"samples.shape: {samples.shape}"
|
1414 |
+
if cur_num_gen < self.args.n_codebooks - 1:
|
1415 |
+
for jj in range(1, self.args.n_codebooks - cur_num_gen):
|
1416 |
+
samples[-jj, 0] = self.args.empty_token
|
1417 |
+
|
1418 |
+
if (
|
1419 |
+
samples[0, 0] == eog_inference
|
1420 |
+
or torch.argmax(logits[0], dim=-1) == eog_inference
|
1421 |
+
or y_input.shape[1] > x_lens[0] * (self.args.encodec_sr // 5)
|
1422 |
+
): # last one means y is already too long, shouldn't happen, but put it here
|
1423 |
+
samples[0, 0] = eog_inference
|
1424 |
+
codebook_eog[0] = True
|
1425 |
+
##################### silence repetition handling #####################
|
1426 |
+
if samples[0, 0] in silence_tokens and samples[0, 0] == prev_token:
|
1427 |
+
consec_silence_count += 1
|
1428 |
+
else:
|
1429 |
+
consec_silence_count = 0
|
1430 |
+
prev_token = samples[0, 0]
|
1431 |
+
##################### silence repetition handling #####################
|
1432 |
+
return samples, codebook_eog, prev_token, consec_silence_count
|
1433 |
+
else:
|
1434 |
+
assert (
|
1435 |
+
sum(codebook_eog[i] for i in range(n_eog)) == n_eog
|
1436 |
+
), f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
|
1437 |
+
logits_adjust = logits
|
1438 |
+
for jj in range(n_eog + 1, self.args.n_codebooks):
|
1439 |
+
logits_adjust[jj][eog_inference] = -10000
|
1440 |
+
logits_adjust[jj][self.args.empty_token] = -10000
|
1441 |
+
samples = topk_sampling(
|
1442 |
+
logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
|
1443 |
+
) # [K, 1]
|
1444 |
+
for jj in range(n_eog):
|
1445 |
+
samples[jj, 0] = self.args.empty_token
|
1446 |
+
samples[n_eog, 0] = eog_inference
|
1447 |
+
codebook_eog[n_eog] = True
|
1448 |
+
return samples, codebook_eog, prev_token, consec_silence_count
|
1449 |
+
|
1450 |
+
while True:
|
1451 |
+
y_out, present = self.dec_forward(
|
1452 |
+
x_input,
|
1453 |
+
x_lens,
|
1454 |
+
x_attention_mask,
|
1455 |
+
x_padding_mask,
|
1456 |
+
y_input,
|
1457 |
+
new_y_lens,
|
1458 |
+
y_attention_mask,
|
1459 |
+
y_padding_mask,
|
1460 |
+
past=past,
|
1461 |
+
)
|
1462 |
+
if past != None:
|
1463 |
+
past = (
|
1464 |
+
torch.cat([past, present.to(past.dtype)], dim=-2)
|
1465 |
+
if past.ndim > 3
|
1466 |
+
else present.to(past.dtype)
|
1467 |
+
)
|
1468 |
+
|
1469 |
+
y_out = y_out[:, -1:] # only take the last token
|
1470 |
+
logits = torch.stack(
|
1471 |
+
[self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)],
|
1472 |
+
dim=1,
|
1473 |
+
) # [B K S card], B==S==1, so [1 K 1 card]
|
1474 |
+
logits = logits.squeeze(0).squeeze(1) # [K card]
|
1475 |
+
assert logits.shape == torch.Size(
|
1476 |
+
(self.args.n_codebooks, self.n_audio_tokens[0])
|
1477 |
+
), f"{logits.shape}"
|
1478 |
+
|
1479 |
+
n_eog = sum(codebook_eog)
|
1480 |
+
assert n_eog < self.args.n_codebooks
|
1481 |
+
if (
|
1482 |
+
self.args.eos > 0
|
1483 |
+
): # if we are using the end-of-sentence token (which is used by default), eog shouldn't be used here, as there are no masked spans
|
1484 |
+
for jj in range(self.args.n_codebooks):
|
1485 |
+
logits[jj][self.args.eog] = -10000.0
|
1486 |
+
|
1487 |
+
samples, codebook_eog, prev_token, consec_silence_count = sample_helper(
|
1488 |
+
n_eog,
|
1489 |
+
logits,
|
1490 |
+
codebook_eog,
|
1491 |
+
top_k,
|
1492 |
+
top_p,
|
1493 |
+
temperature,
|
1494 |
+
prev_token,
|
1495 |
+
consec_silence_count,
|
1496 |
+
stop_repetition,
|
1497 |
+
silence_tokens,
|
1498 |
+
cur_num_gen,
|
1499 |
+
)
|
1500 |
+
|
1501 |
+
cur_num_gen += 1
|
1502 |
+
cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
|
1503 |
+
|
1504 |
+
# samples.shape is [K,1]
|
1505 |
+
# get samples_emb
|
1506 |
+
samples_emb = torch.stack(
|
1507 |
+
[
|
1508 |
+
self.audio_embedding[k](samples[k])
|
1509 |
+
for k in range(self.args.n_codebooks)
|
1510 |
+
],
|
1511 |
+
dim=0,
|
1512 |
+
) # [K,1,D]
|
1513 |
+
samples_emb = samples_emb.sum(dim=0, keepdim=True) # [1,1,D]
|
1514 |
+
|
1515 |
+
if (
|
1516 |
+
sum(codebook_eog) == self.args.n_codebooks
|
1517 |
+
): # generation for the current span is done
|
1518 |
+
codebook_eog = [False] * self.args.n_codebooks
|
1519 |
+
num_gen.append(cur_num_gen)
|
1520 |
+
cur_num_gen = 0
|
1521 |
+
generated.append(cur_generated)
|
1522 |
+
cur_generated = []
|
1523 |
+
break
|
1524 |
+
else:
|
1525 |
+
assert samples_emb.shape == torch.Size(
|
1526 |
+
(1, 1, self.args.d_model)
|
1527 |
+
), f"samples_emb.shape: {samples_emb.shape}"
|
1528 |
+
|
1529 |
+
embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
|
1530 |
+
y_input = self.audio_positional_embedding(embedded_y) # [B T D]
|
1531 |
+
# make attention mask and padding mask
|
1532 |
+
y_attention_mask = (
|
1533 |
+
torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
|
1534 |
+
.bool()
|
1535 |
+
.to(y.device)
|
1536 |
+
)
|
1537 |
+
new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
|
1538 |
+
y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)
|
1539 |
+
|
1540 |
+
assert len(generated) == 1, f"len(generated): {len(generated)}"
|
1541 |
+
|
1542 |
+
# revert the pattern
|
1543 |
+
flatten_gen = []
|
1544 |
+
for l, orig_span in enumerate(generated):
|
1545 |
+
span = torch.stack(orig_span, dim=0) # [T, K]
|
1546 |
+
span = span.transpose(1, 0) # [K, T]
|
1547 |
+
assert span.shape[0] == self.args.n_codebooks, span.shape
|
1548 |
+
unshifted_span = []
|
1549 |
+
for j, s in enumerate(span):
|
1550 |
+
start_from = j
|
1551 |
+
end_at = -(self.args.n_codebooks - start_from)
|
1552 |
+
unshifted_span.append(s[start_from:end_at])
|
1553 |
+
unshifted_span = torch.stack(unshifted_span, dim=0)
|
1554 |
+
|
1555 |
+
assert (
|
1556 |
+
unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks
|
1557 |
+
), f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
|
1558 |
+
|
1559 |
+
flatten_gen.append(unshifted_span)
|
1560 |
+
assert len(flatten_gen) == 1, len(flatten_gen)
|
1561 |
+
|
1562 |
+
# combine
|
1563 |
+
res = [y[0], flatten_gen[0]]
|
1564 |
+
res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
|
1565 |
+
|
1566 |
+
expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
|
1567 |
+
assert res.shape == torch.Size(
|
1568 |
+
(1, self.args.n_codebooks, expected_y_len)
|
1569 |
+
), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"
|
1570 |
+
|
1571 |
+
if self.args.special_first:
|
1572 |
+
res = res - int(self.args.n_special)
|
1573 |
+
flatten_gen = [item - int(self.args.n_special) for item in flatten_gen]  # flatten_gen is a list of tensors, so subtract element-wise
|
1574 |
+
|
1575 |
+
return res, flatten_gen[0].unsqueeze(0)
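A hypothetical call to inference_tts with dummy tensors standing in for phoneme ids and Encodec prompt codes; the call itself is left commented out because it needs a trained model instance and its real vocabularies:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randint(0, 100, (1, 42), device=device)       # phoneme ids of the full transcript
x_lens = torch.LongTensor([x.shape[-1]]).to(device)
y = torch.randint(0, 2048, (1, 250, 4), device=device)  # [1, T, K] codec tokens of the voice prompt

# concat_codes, gen_codes = model.inference_tts(
#     x, x_lens, y,
#     top_k=0, top_p=0.9, temperature=1.0,
#     stop_repetition=3, kvcache=1, silence_tokens=[1388, 1898, 131],
# )
# concat_codes holds prompt + generated frames, gen_codes only the generated frames.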
|
1576 |
+
|
1577 |
+
def inference_tts_batch(
|
1578 |
+
self,
|
1579 |
+
x: torch.Tensor,
|
1580 |
+
x_lens: torch.Tensor,
|
1581 |
+
y: torch.Tensor,
|
1582 |
+
top_k: int = -100,
|
1583 |
+
top_p: float = 1.0,
|
1584 |
+
temperature: float = 1.0,
|
1585 |
+
stop_repetition: int = 3,
|
1586 |
+
kvcache: int = 1,
|
1587 |
+
batch_size: int = 5,
|
1588 |
+
silence_tokens: list[int] = [1388, 1898, 131],
|
1589 |
+
*kargs,
|
1590 |
+
) -> torch.Tensor:
|
1591 |
+
"""
|
1592 |
+
Use a batch in the forward pass, but all batch elements are the same example under different random seeds; therefore, as soon as one example generates eog, we can drop all the other samples.
|
1593 |
+
different from inference_tts, this implementation uses kvcache, which should have significant speed up
|
1594 |
+
Args:
|
1595 |
+
x:
|
1596 |
+
A 2-D tensor of shape (1, L).
|
1597 |
+
x_lens:
|
1598 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
1599 |
+
before padding.
|
1600 |
+
y:
|
1601 |
+
A 3-D tensor of shape (1, T, K).
|
1602 |
+
top_k: (`optional`) int
|
1603 |
+
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
1604 |
+
top_p: (`optional`) float
|
1605 |
+
For nucleus (top-p) sampling.
|
1606 |
+
temperature: (`optional`) float
|
1607 |
+
The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
1608 |
+
"""
|
1609 |
+
eog_inference = self.args.eos if self.args.eos > 0 else self.args.eog
|
1610 |
+
assert x.ndim == 2, x.shape
|
1611 |
+
assert x_lens.ndim == 1, x_lens.shape
|
1612 |
+
assert y.ndim == 3, y.shape
|
1613 |
+
if self.args.special_first:
|
1614 |
+
y = y + int(self.args.n_special)
|
1615 |
+
y = y.transpose(2, 1) # [1,T,K] -> [1,K,T]
|
1616 |
+
assert (
|
1617 |
+
y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks
|
1618 |
+
), y.shape # there is no padding
|
1619 |
+
|
1620 |
+
# make x attention mask and x_input
|
1621 |
+
x_attention_mask = (
|
1622 |
+
torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1)
|
1623 |
+
.bool()
|
1624 |
+
.to(x.device)
|
1625 |
+
)
|
1626 |
+
# x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
|
1627 |
+
x_input = self.text_embedding(x)
|
1628 |
+
x_input = self.text_positional_embedding(x_input)
|
1629 |
+
|
1630 |
+
y_len = y.shape[2]
|
1631 |
+
y_lens = torch.LongTensor([y_len]).to(y.device)
|
1632 |
+
|
1633 |
+
# rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
|
1634 |
+
rearranged_y = [[y[0]]]
|
1635 |
+
assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][
|
1636 |
+
0
|
1637 |
+
].shape
|
1638 |
+
|
1639 |
+
# shift y to create the delayed pattern
|
1640 |
+
shifted_y, patterns = self.shift(
|
1641 |
+
rearranged_y
|
1642 |
+
) # each element [K S], patterns is not used, as we directly use the original input y
|
1643 |
+
assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
|
1644 |
+
assert len(shifted_y[0]) == 1, len(shifted_y[0])
|
1645 |
+
|
1646 |
+
# below is different from forward or inference
|
1647 |
+
# where we cut this shifted part
|
1648 |
+
shifted_y[0][0] = shifted_y[0][0][:, : -(self.args.n_codebooks - 1)]
|
1649 |
+
assert (
|
1650 |
+
not (
|
1651 |
+
shifted_y[0][0][self.args.n_codebooks :] == self.args.empty_token
|
1652 |
+
).any()
|
1653 |
+
and not (shifted_y[0][0][self.args.n_codebooks :] == self.args.eog).any()
|
1654 |
+
), shifted_y[0][0]
|
1655 |
+
|
1656 |
+
# the next section in `inference` inserts a mask at the intersection of each tensor in a sample, but we don't need to do that here
|
1657 |
+
# the next section concatenates the tensors of each sample into one tensor, which we also don't need
|
1658 |
+
cated_y = shifted_y[0][0].unsqueeze(-1) # [K,S]->[K,S,B]
|
1659 |
+
new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
|
1660 |
+
assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
|
1661 |
+
assert not (cated_y == self.args.audio_pad_token).any(), cated_y
|
1662 |
+
|
1663 |
+
# replace tokens in y with the embeddings, add sum codebooks up
|
1664 |
+
embedded_y = torch.stack(
|
1665 |
+
[self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)],
|
1666 |
+
dim=0,
|
1667 |
+
) # [K, S, B, D]
|
1668 |
+
assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
|
1669 |
+
assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
|
1670 |
+
embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
|
1671 |
+
embedded_y = embedded_y.transpose(1, 0) # [S,B,D]->[B,S,D]
|
1672 |
+
|
1673 |
+
# positional embedding
|
1674 |
+
y_input = self.audio_positional_embedding(embedded_y)
|
1675 |
+
|
1676 |
+
# make attention mask and padding mask
|
1677 |
+
y_attention_mask = (
|
1678 |
+
torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
|
1679 |
+
.bool()
|
1680 |
+
.to(y.device)
|
1681 |
+
)
|
1682 |
+
|
1683 |
+
x_padding_mask = torch.full((1, x_lens[0]), False).to(x.device)
|
1684 |
+
y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)
|
1685 |
+
|
1686 |
+
# entering the generation stage
|
1687 |
+
# starting from line 708
|
1688 |
+
codebook_eog = [False] * self.args.n_codebooks
|
1689 |
+
generated = [] # doesn't contain any empty token, contain eog
|
1690 |
+
cur_generated = [[] for _ in range(batch_size)]
|
1691 |
+
# say 0 is empty, 4 is eog
|
1692 |
+
# tensor([[ 1, 2, 3, 4, 0, 0],
|
1693 |
+
# [ 0, 1, 2, 3, 4, 0],
|
1694 |
+
# [ 0, 0, 1, 2, 3, 4]])
|
1695 |
+
num_gen = []
|
1696 |
+
cur_num_gen = 0
|
1697 |
+
##################### silence repetition handling #####################
|
1698 |
+
##################### silence repetition handling #####################
|
1699 |
+
logging.info(
|
1700 |
+
f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default"
|
1701 |
+
)
|
1702 |
+
consec_silence_counts = [0 for _ in range(batch_size)]
|
1703 |
+
prev_tokens = [None for _ in range(batch_size)]
|
1704 |
+
##################### silence repetition handling #####################
|
1705 |
+
##################### silence repetition handling #####################
|
1706 |
+
|
1707 |
+
# prepare the cache placeholder
|
1708 |
+
# n_layers, 2, bsz, num_heads, src_len, head_dim
|
1709 |
+
past = (
|
1710 |
+
torch.ones(
|
1711 |
+
[self.args.num_decoder_layers, 2, x.shape[0]],
|
1712 |
+
device=x.device,
|
1713 |
+
dtype=torch.float32,
|
1714 |
+
)
|
1715 |
+
if kvcache
|
1716 |
+
else None
|
1717 |
+
)
|
1718 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
1719 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
1720 |
+
# logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
|
1721 |
+
keep = None # NOTE: this is very important, it tells which sample to keep
|
1722 |
+
|
1723 |
+
def sample_helper(
|
1724 |
+
n_eog,
|
1725 |
+
logits,
|
1726 |
+
codebook_eog,
|
1727 |
+
top_k,
|
1728 |
+
top_p,
|
1729 |
+
temperature,
|
1730 |
+
prev_tokens,
|
1731 |
+
consec_silence_counts,
|
1732 |
+
stop_repetition,
|
1733 |
+
silence_tokens,
|
1734 |
+
cur_num_gen,
|
1735 |
+
keep,
|
1736 |
+
):
|
1737 |
+
if n_eog == 0:
|
1738 |
+
logits_adjust = logits
|
1739 |
+
for jj in range(1, self.args.n_codebooks):
|
1740 |
+
logits_adjust[:, jj, eog_inference] = -10000
|
1741 |
+
logits_adjust[:, jj, self.args.empty_token] = -10000
|
1742 |
+
if (
|
1743 |
+
cur_num_gen <= self.args.encodec_sr // 5
|
1744 |
+
): # this shouldn't happen, but just in case the model stopped too early
|
1745 |
+
logits_adjust[:, :, eog_inference] = -10000
|
1746 |
+
##################### silence repetition handling #####################
|
1747 |
+
for b in range(batch_size):
|
1748 |
+
prev_token = prev_tokens[b]
|
1749 |
+
consec_silence_count = consec_silence_counts[b]
|
1750 |
+
if (
|
1751 |
+
stop_repetition > 0
|
1752 |
+
and prev_token in silence_tokens
|
1753 |
+
and consec_silence_count > stop_repetition
|
1754 |
+
):
|
1755 |
+
if logits_adjust[b, 0, prev_token] < 0:
|
1756 |
+
logits_adjust[b, 0, prev_token] = logits_adjust[
|
1757 |
+
b, 0, prev_token
|
1758 |
+
] * (consec_silence_count - (stop_repetition - 1))
|
1759 |
+
else:
|
1760 |
+
logits_adjust[b, 0, prev_token] = logits_adjust[
|
1761 |
+
b, 0, prev_token
|
1762 |
+
] / (consec_silence_count - (stop_repetition - 1))
|
1763 |
+
##################### silence repetition handling #####################
|
1764 |
+
samples = topk_sampling(
|
1765 |
+
logits_adjust.reshape(
|
1766 |
+
batch_size * self.args.n_codebooks, logits_adjust.shape[-1]
|
1767 |
+
),
|
1768 |
+
top_k=top_k,
|
1769 |
+
top_p=top_p,
|
1770 |
+
temperature=temperature,
|
1771 |
+
) # [B*K, 1]
|
1772 |
+
samples = samples.reshape(batch_size, self.args.n_codebooks, 1)
|
1773 |
+
assert samples.shape == torch.Size(
|
1774 |
+
(batch_size, self.args.n_codebooks, 1)
|
1775 |
+
), f"samples.shape: {samples.shape}"
|
1776 |
+
for b in range(batch_size):
|
1777 |
+
if cur_num_gen < self.args.n_codebooks - 1:
|
1778 |
+
for jj in range(1, self.args.n_codebooks - cur_num_gen):
|
1779 |
+
samples[b, -jj, 0] = self.args.empty_token
|
1780 |
+
|
1781 |
+
if (
|
1782 |
+
samples[b, 0, 0] == eog_inference
|
1783 |
+
or torch.argmax(logits[b, 0], dim=-1) == eog_inference
|
1784 |
+
or y_input.shape[1] > x_lens[b] * (self.args.encodec_sr // 5)
|
1785 |
+
): # last one means y is already too long, shouldn't happen, but put it here
|
1786 |
+
samples[b, 0, 0] = eog_inference
|
1787 |
+
codebook_eog[0] = True
|
1788 |
+
keep = b # NOTE keep is a very important variable, we only return this one, note that if eog shows up in two samples, keep will be overwritten by the later one (or the last one)
|
1789 |
+
##################### silence repetition handling #####################
|
1790 |
+
if (
|
1791 |
+
samples[b, 0, 0] in silence_tokens
|
1792 |
+
and samples[b, 0, 0] == prev_tokens[b]
|
1793 |
+
):
|
1794 |
+
consec_silence_counts[b] += 1
|
1795 |
+
else:
|
1796 |
+
consec_silence_counts[b] = 0
|
1797 |
+
prev_tokens[b] = samples[b, 0, 0]
|
1798 |
+
##################### silence repetition handling #####################
|
1799 |
+
return samples, codebook_eog, prev_tokens, consec_silence_counts, keep
|
1800 |
+
else:
|
1801 |
+
assert (
|
1802 |
+
sum(codebook_eog[i] for i in range(n_eog)) == n_eog
|
1803 |
+
), f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
|
1804 |
+
logits_adjust = logits
|
1805 |
+
for jj in range(n_eog + 1, self.args.n_codebooks):
|
1806 |
+
logits_adjust[:, jj, eog_inference] = -10000
|
1807 |
+
logits_adjust[:, jj, self.args.empty_token] = -10000
|
1808 |
+
samples = topk_sampling(
|
1809 |
+
logits_adjust.reshape(
|
1810 |
+
batch_size * self.args.n_codebooks, logits_adjust.shape[-1]
|
1811 |
+
),
|
1812 |
+
top_k=top_k,
|
1813 |
+
top_p=top_p,
|
1814 |
+
temperature=temperature,
|
1815 |
+
) # [B, K, 1]
|
1816 |
+
samples = samples.reshape(batch_size, self.args.n_codebooks, 1)
|
1817 |
+
for jj in range(n_eog):
|
1818 |
+
samples[keep, jj, 0] = self.args.empty_token
|
1819 |
+
samples[keep, n_eog, 0] = eog_inference
|
1820 |
+
codebook_eog[n_eog] = True
|
1821 |
+
return samples, codebook_eog, prev_tokens, consec_silence_counts, keep
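The keep bookkeeping above boils down to this: every batch element is the same prompt sampled with different noise, and whichever element reaches EOG is the one that gets returned (a later hit overwrites an earlier one). A toy version of that selection, with a hypothetical eog id:

import torch

def pick_finished_sample(samples: torch.Tensor, eog: int):
    # samples: [B, K, 1]; return the index of a batch element whose first codebook
    # emitted eog, or None if the whole batch should keep generating
    hits = (samples[:, 0, 0] == eog).nonzero(as_tuple=True)[0]
    return int(hits[-1]) if len(hits) > 0 else None

samples = torch.tensor([[[7]], [[2047]], [[7]]])   # pretend eog == 2047
print(pick_finished_sample(samples, eog=2047))     # -> 1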
|
1822 |
+
|
1823 |
+
while True:
|
1824 |
+
# if cur_num_gen > 0, should have everything in kvcache, so only pass in the last token
|
1825 |
+
# in the first generation step, we repeat each tensor so that its first dimension equals the batch size
|
1826 |
+
if cur_num_gen == 0:
|
1827 |
+
assert x_input.ndim == 3 and x_input.shape[0] == 1, x_input.shape
|
1828 |
+
assert (
|
1829 |
+
x_padding_mask.ndim == 2 and x_padding_mask.shape[0] == 1
|
1830 |
+
), x_padding_mask.shape
|
1831 |
+
assert (
|
1832 |
+
y_input.ndim == 3
|
1833 |
+
and y_input.shape[0] == 1
|
1834 |
+
and y_input.shape[1] == new_y_lens[0]
|
1835 |
+
), y_input.shape
|
1836 |
+
assert (
|
1837 |
+
embedded_y.ndim == 3
|
1838 |
+
and embedded_y.shape[0] == 1
|
1839 |
+
and embedded_y.shape[1] == new_y_lens[0]
|
1840 |
+
), embedded_y.shape
|
1841 |
+
x_input = x_input.repeat(batch_size, 1, 1)
|
1842 |
+
x_lens = x_lens.repeat(batch_size)
|
1843 |
+
# x_attention_mask = x_attention_mask.repeat(batch_size, 1, 1) # no need to work with attention mask, it doesn't contain batch dimension
|
1844 |
+
x_padding_mask = x_padding_mask.repeat(batch_size, 1)
|
1845 |
+
y_input = y_input.repeat(batch_size, 1, 1)
|
1846 |
+
new_y_lens = new_y_lens.repeat(batch_size)
|
1847 |
+
# y_attention_mask = y_attention_mask.repeat(batch_size, 1, 1) # no need to work with attention mask, it doesn't contain batch dimension
|
1848 |
+
y_padding_mask = y_padding_mask.repeat(batch_size, 1)
|
1849 |
+
embedded_y = embedded_y.repeat(
|
1850 |
+
batch_size, 1, 1
|
1851 |
+
) # will be used to concat with newly generated token embedding
|
1852 |
+
past = past.repeat(1, 1, batch_size) if past != None else None
|
1853 |
+
else:
|
1854 |
+
assert (
|
1855 |
+
x_input.shape[0] == batch_size
|
1856 |
+
and x_padding_mask.shape[0] == batch_size
|
1857 |
+
and y_input.shape[0] == batch_size
|
1858 |
+
and new_y_lens.shape[0] == batch_size
|
1859 |
+
), f"x_input.shape: {x_input.shape}, x_padding_mask.shape: {x_padding_mask.shape}, y_input.shape: {y_input.shape}, new_y_lens.shape: {new_y_lens.shape}"
|
1860 |
+
y_out, present = self.dec_forward(
|
1861 |
+
x_input,
|
1862 |
+
x_lens,
|
1863 |
+
x_attention_mask,
|
1864 |
+
x_padding_mask,
|
1865 |
+
y_input,
|
1866 |
+
new_y_lens,
|
1867 |
+
y_attention_mask,
|
1868 |
+
y_padding_mask,
|
1869 |
+
past=past,
|
1870 |
+
)
|
1871 |
+
if past != None:
|
1872 |
+
past = (
|
1873 |
+
torch.cat([past, present.to(past.dtype)], dim=-2)
|
1874 |
+
if past.ndim > 3
|
1875 |
+
else present.to(past.dtype)
|
1876 |
+
)
|
1877 |
+
|
1878 |
+
# if no eog emerges, y_out should have batch size of batch_size
|
1879 |
+
if sum(codebook_eog) == 0:
|
1880 |
+
assert y_out.shape[0] == batch_size and y_out.ndim == 3, y_out.shape
|
1881 |
+
y_out = y_out[:, -1:] # only take the last token
|
1882 |
+
logits = torch.stack(
|
1883 |
+
[self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)],
|
1884 |
+
dim=1,
|
1885 |
+
) # [B K S card], S==1, so [B K 1 card]
|
1886 |
+
logits = logits.squeeze(2) # [B K card]
|
1887 |
+
assert logits.shape == torch.Size(
|
1888 |
+
(batch_size, self.args.n_codebooks, self.n_audio_tokens[0])
|
1889 |
+
), f"{logits.shape}"
|
1890 |
+
|
1891 |
+
n_eog = sum(codebook_eog)
|
1892 |
+
if self.args.eos > 0:
|
1893 |
+
for jj in range(self.args.n_codebooks):
|
1894 |
+
logits[:, jj, self.args.eog] = -10000.0
|
1895 |
+
samples, codebook_eog, prev_tokens, consec_silence_counts, keep = (
|
1896 |
+
sample_helper(
|
1897 |
+
n_eog,
|
1898 |
+
logits,
|
1899 |
+
codebook_eog,
|
1900 |
+
top_k,
|
1901 |
+
top_p,
|
1902 |
+
temperature,
|
1903 |
+
prev_tokens,
|
1904 |
+
consec_silence_counts,
|
1905 |
+
stop_repetition,
|
1906 |
+
silence_tokens,
|
1907 |
+
cur_num_gen,
|
1908 |
+
keep,
|
1909 |
+
)
|
1910 |
+
)
|
1911 |
+
|
1912 |
+
cur_num_gen += 1
|
1913 |
+
if sum(codebook_eog) == 0: # no eog yet, keep batch_size of samples
|
1914 |
+
assert keep == None
|
1915 |
+
for b in range(batch_size):
|
1916 |
+
cur_generated[b].append(samples[b].squeeze(-1))
|
1917 |
+
elif sum(codebook_eog) == 1: # the first eog just showed up in this step
|
1918 |
+
assert keep != None
|
1919 |
+
cur_generated = cur_generated[keep]
|
1920 |
+
cur_generated.append(samples[keep].squeeze(-1))
|
1921 |
+
else: # we are generating the rest eogs for the 'keep' sample
|
1922 |
+
cur_generated.append(samples[keep].squeeze(-1))
|
1923 |
+
|
1924 |
+
# samples.shape is [B, K, 1]
|
1925 |
+
# get samples_emb
|
1926 |
+
samples_emb = torch.stack(
|
1927 |
+
[
|
1928 |
+
self.audio_embedding[k](samples[:, k])
|
1929 |
+
for k in range(self.args.n_codebooks)
|
1930 |
+
],
|
1931 |
+
dim=1,
|
1932 |
+
) # [B, K,1,D]
|
1933 |
+
assert samples_emb.shape == torch.Size(
|
1934 |
+
[batch_size, self.args.n_codebooks, 1, self.args.d_model]
|
1935 |
+
)
|
1936 |
+
samples_emb = samples_emb.sum(dim=1, keepdim=False) # [B,1,D]
|
1937 |
+
if (
|
1938 |
+
sum(codebook_eog) == self.args.n_codebooks
|
1939 |
+
): # generation for the current span is done
|
1940 |
+
codebook_eog = [False] * self.args.n_codebooks
|
1941 |
+
num_gen.append(cur_num_gen)
|
1942 |
+
cur_num_gen = 0
|
1943 |
+
generated.append(cur_generated)
|
1944 |
+
cur_generated = [[] for _ in range(batch_size)]
|
1945 |
+
break
|
1946 |
+
else:
|
1947 |
+
assert samples_emb.shape == torch.Size(
|
1948 |
+
(batch_size, 1, self.args.d_model)
|
1949 |
+
), f"samples_emb.shape: {samples_emb.shape}"
|
1950 |
+
|
1951 |
+
embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
|
1952 |
+
y_input = self.audio_positional_embedding(embedded_y) # [B T D]
|
1953 |
+
# make attention mask and padding mask
|
1954 |
+
y_attention_mask = (
|
1955 |
+
torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
|
1956 |
+
.bool()
|
1957 |
+
.to(y.device)
|
1958 |
+
)
|
1959 |
+
new_y_lens = (
|
1960 |
+
torch.LongTensor([y_input.shape[1]]).to(y.device).repeat(batch_size)
|
1961 |
+
)
|
1962 |
+
y_padding_mask = torch.full((batch_size, new_y_lens[0]), False).to(y.device)
|
1963 |
+
|
1964 |
+
assert len(generated) == 1, f"len(generated): {len(generated)}"
|
1965 |
+
|
1966 |
+
# revert the pattern
|
1967 |
+
flatten_gen = []
|
1968 |
+
for l, orig_span in enumerate(generated):
|
1969 |
+
span = torch.stack(orig_span, dim=0) # [T, K]
|
1970 |
+
span = span.transpose(1, 0) # [K, T]
|
1971 |
+
assert span.shape[0] == self.args.n_codebooks, span.shape
|
1972 |
+
unshifted_span = []
|
1973 |
+
for j, s in enumerate(span):
|
1974 |
+
start_from = j
|
1975 |
+
end_at = -(self.args.n_codebooks - start_from)
|
1976 |
+
unshifted_span.append(s[start_from:end_at])
|
1977 |
+
unshifted_span = torch.stack(unshifted_span, dim=0)
|
1978 |
+
|
1979 |
+
assert (
|
1980 |
+
unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks
|
1981 |
+
), f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
|
1982 |
+
|
1983 |
+
flatten_gen.append(unshifted_span)
|
1984 |
+
assert len(flatten_gen) == 1, len(flatten_gen)
|
1985 |
+
|
1986 |
+
# combine
|
1987 |
+
res = [y[0], flatten_gen[0]]
|
1988 |
+
res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
|
1989 |
+
|
1990 |
+
expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
|
1991 |
+
assert res.shape == torch.Size(
|
1992 |
+
(1, self.args.n_codebooks, expected_y_len)
|
1993 |
+
), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"
|
1994 |
+
|
1995 |
+
if self.args.special_first:
|
1996 |
+
res = res - int(self.args.n_special)
|
1997 |
+
flatten_gen = [item - int(self.args.n_special) for item in flatten_gen]  # flatten_gen is a list of tensors, so subtract element-wise
|
1998 |
+
|
1999 |
+
return res, flatten_gen[0].unsqueeze(0)
|
src/model/modules/voicecraftconfig.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
class VoiceCraftConfig:
|
2 |
+
|
3 |
+
def __init__(
|
4 |
+
self,
|
5 |
+
model_name="330M_TTSEnhanced.pth", # "gigaHalfLibri330M_TTSEnhanced_max16s.pth",
|
6 |
+
encodec="encodec_4cb2048_giga.th",
|
7 |
+
top_k=0,
|
8 |
+
top_p=0.9,
|
9 |
+
temperature=1,
|
10 |
+
kvcache=1,
|
11 |
+
codec_sr=50,
|
12 |
+
codec_audio_sr=16000,
|
13 |
+
silence_tokens=[1388, 1898, 131],
|
14 |
+
stop_repetition=3,
|
15 |
+
sample_batch_size=2,
|
16 |
+
seed=1,
|
17 |
+
cut_off_sec=7.87,
|
18 |
+
voice_audio_path="84_121550_000074_000000.wav",
|
19 |
+
voice_audio_transcript="But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks",
|
20 |
+
**kwargs,
|
21 |
+
):
|
22 |
+
super().__init__()
|
23 |
+
self.model_name = model_name
|
24 |
+
self.encodec = encodec
|
25 |
+
self.top_k = top_k
|
26 |
+
self.top_p = top_p
|
27 |
+
self.temperature = temperature
|
28 |
+
self.kvcache = kvcache
|
29 |
+
self.codec_sr = codec_sr
|
30 |
+
self.codec_audio_sr = codec_audio_sr
|
31 |
+
self.silence_tokens = silence_tokens
|
32 |
+
self.stop_repetition = stop_repetition
|
33 |
+
self.sample_batch_size = sample_batch_size
|
34 |
+
self.seed = seed
|
35 |
+
self.cut_off_sec = cut_off_sec
|
36 |
+
self.voice_audio_path = voice_audio_path
|
37 |
+
self.voice_audio_transcript = voice_audio_transcript
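A possible way to instantiate the config with a few overridden decoding parameters, assuming the repository root is on PYTHONPATH; unspecified fields keep the defaults above:

from src.model.modules.voicecraftconfig import VoiceCraftConfig

config = VoiceCraftConfig(top_p=0.8, temperature=1.0, sample_batch_size=4)
print(config.model_name, config.codec_sr, config.top_p)  # 330M_TTSEnhanced.pth 50 0.8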
|
src/utils/__init__.py
ADDED
File without changes
|
src/utils/image_utils.py
ADDED
@@ -0,0 +1,100 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import base64
|
17 |
+
import logging
|
18 |
+
import os
|
19 |
+
from io import BytesIO
|
20 |
+
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import PIL
|
23 |
+
import numpy as np
|
24 |
+
import requests
|
25 |
+
from packaging import version
|
26 |
+
|
27 |
+
|
28 |
+
def _is_numpy(x):
|
29 |
+
return isinstance(x, np.ndarray)
|
30 |
+
|
31 |
+
|
32 |
+
def is_numpy_array(x):
|
33 |
+
"""
|
34 |
+
Tests if `x` is a numpy array or not.
|
35 |
+
"""
|
36 |
+
return _is_numpy(x)
|
37 |
+
|
38 |
+
|
39 |
+
def is_pil_image(img):
|
40 |
+
return isinstance(img, PIL.Image.Image)
|
41 |
+
|
42 |
+
|
43 |
+
def is_valid_image(img):
|
44 |
+
return is_pil_image(img) or is_numpy_array(img)
|
45 |
+
|
46 |
+
|
47 |
+
def valid_images(imgs):
|
48 |
+
# If we have a list of images, make sure every image is valid
|
49 |
+
if isinstance(imgs, (list, tuple)):
|
50 |
+
for img in imgs:
|
51 |
+
if not valid_images(img):
|
52 |
+
return False
|
53 |
+
# If not a list or tuple, we have been given a single image or a batched tensor of images
|
54 |
+
elif not is_valid_image(imgs):
|
55 |
+
return False
|
56 |
+
return True
|
57 |
+
|
58 |
+
|
59 |
+
def is_batched(img):
|
60 |
+
if isinstance(img, (list, tuple)):
|
61 |
+
return is_valid_image(img[0])
|
62 |
+
return False
|
63 |
+
|
64 |
+
|
65 |
+
def is_scaled_image(image: np.ndarray) -> bool:
|
66 |
+
"""
|
67 |
+
Checks to see whether the pixel values have already been rescaled to [0, 1].
|
68 |
+
"""
|
69 |
+
if image.dtype == np.uint8:
|
70 |
+
return False
|
71 |
+
|
72 |
+
# It's possible the image has pixel values in [0, 255] but is of floating type
|
73 |
+
return np.min(image) >= 0 and np.max(image) <= 1
|
74 |
+
|
75 |
+
|
76 |
+
def make_batched_images(images):
|
77 |
+
"""
|
78 |
+
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
82 |
+
The input image.
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
list: A list of images.
|
86 |
+
"""
|
87 |
+
if (
|
88 |
+
isinstance(images, (list, tuple))
|
89 |
+
and isinstance(images[0], (list, tuple))
|
90 |
+
and is_valid_image(images[0][0])
|
91 |
+
):
|
92 |
+
return [img for img_list in images for img in img_list]
|
93 |
+
|
94 |
+
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
|
95 |
+
return images
|
96 |
+
|
97 |
+
elif is_valid_image(images):
|
98 |
+
return [images]
|
99 |
+
|
100 |
+
raise ValueError(f"Could not make batched images from {images}")
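A quick illustration of what make_batched_images accepts and returns (assuming Pillow is installed and the repository root is importable):

from PIL import Image
from src.utils.image_utils import make_batched_images

img = Image.new("RGB", (8, 8))
assert make_batched_images(img) == [img]                      # single image -> one-element list
assert make_batched_images([img, img]) == [img, img]          # flat list -> unchanged
assert make_batched_images([[img], [img, img]]) == [img] * 3  # nested lists -> flattened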
|
src/utils/model_utils.py
ADDED
@@ -0,0 +1,73 @@
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from typing import Optional
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
|
7 |
+
from src.model.modules.imagecraftconfig import ImageCraftConfig
|
8 |
+
from src.model.modules.imagecraftprocessor import (
|
9 |
+
ImageCraftProcessor,
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
def move_inputs_to_device(model_inputs: dict, device: str):
|
14 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
15 |
+
return model_inputs
|
16 |
+
|
17 |
+
|
18 |
+
def get_model_inputs(
|
19 |
+
processor: ImageCraftProcessor,
|
20 |
+
prompt: str,
|
21 |
+
image: Image,
|
22 |
+
suffix: Optional[str] = None,
|
23 |
+
device: str = "cuda",
|
24 |
+
):
|
25 |
+
images = [image]
|
26 |
+
prompts = [prompt]
|
27 |
+
if suffix is not None:
|
28 |
+
suffix = [suffix]
|
29 |
+
model_inputs = processor(text=prompts, images=images)
|
30 |
+
model_inputs = move_inputs_to_device(model_inputs, device)
|
31 |
+
return model_inputs
|
32 |
+
|
33 |
+
|
34 |
+
def get_config(config_file="config.json"):
|
35 |
+
config = None
|
36 |
+
with open(config_file, "r") as f:
|
37 |
+
model_config_file = json.load(f)
|
38 |
+
config = ImageCraftConfig(**model_config_file)
|
39 |
+
|
40 |
+
return config
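For example, the helper can be pointed at the config.json committed at the repository root (a sketch; like get_config itself, it assumes that file's fields are accepted by ImageCraftConfig):

from src.utils.model_utils import get_config

config = get_config("config.json")   # path relative to the repo root
print(type(config).__name__)         # ImageCraftConfig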
|
41 |
+
|
42 |
+
|
43 |
+
# def load_hf_model(model_path: str, device: str) -> Tuple[ImageCraft, AutoTokenizer]:
|
44 |
+
|
45 |
+
# # Load the tokenizer
|
46 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
|
47 |
+
# assert tokenizer.padding_side == "right"
|
48 |
+
|
49 |
+
# # Find all the *.safetensors files
|
50 |
+
# safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors"))
|
51 |
+
|
52 |
+
# # ... and load them one by one in the tensors dictionary
|
53 |
+
# tensors = {}
|
54 |
+
# for safetensors_file in safetensors_files:
|
55 |
+
# with safe_open(safetensors_file, framework="pt", device="cpu") as f:
|
56 |
+
# for key in f.keys():
|
57 |
+
# tensors[key] = f.get_tensor(key)
|
58 |
+
|
59 |
+
# # Load the model's config
|
60 |
+
# with open(os.path.join(model_path, "config.json"), "r") as f:
|
61 |
+
# model_config_file = json.load(f)
|
62 |
+
# config = ImageCraftConfig(**model_config_file)
|
63 |
+
|
64 |
+
# # Create the model using the configuration
|
65 |
+
# model = ImageCraft(config).to(device)
|
66 |
+
|
67 |
+
# # Load the state dict of the model
|
68 |
+
# model.load_state_dict(tensors, strict=False)
|
69 |
+
|
70 |
+
# # Tie weights
|
71 |
+
# model.tie_weights()
|
72 |
+
|
73 |
+
# return (model, tokenizer)
|
src/utils/tools.py
ADDED
@@ -0,0 +1,22 @@
1 |
+
import yaml
|
2 |
+
import pickle
|
3 |
+
|
4 |
+
|
5 |
+
def load_config():
|
6 |
+
# Read in the configuration file
|
7 |
+
with open("config.yaml") as p:
|
8 |
+
config = yaml.safe_load(p)
|
9 |
+
return config
|
10 |
+
|
11 |
+
|
12 |
+
def pickle_dump(path, variable):
|
13 |
+
# Serialize data from memory to file
|
14 |
+
with open(path, "wb") as handle:
|
15 |
+
pickle.dump(variable, handle)
|
16 |
+
|
17 |
+
|
18 |
+
def pickle_load(path):
|
19 |
+
# Read and load serialized data from file
|
20 |
+
with open(path, "rb") as handle:
|
21 |
+
loaded = pickle.load(handle)
|
22 |
+
return loaded
|
src/utils/util.py
ADDED
@@ -0,0 +1,305 @@
import io
import os
from pathlib import Path
from tempfile import TemporaryDirectory
import torch
import torchaudio
import random
import numpy as np
from PIL import Image
from urllib.parse import urlparse
from os.path import exists
import re
from num2words import num2words
import uuid

from typing import List, Optional, Dict, Union, Tuple, Iterable

from src.utils.image_utils import is_valid_image


IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]


def is_local(url):
    url_parsed = urlparse(url)
    if url_parsed.scheme in ("file", ""):
        return exists(url_parsed.path)
    return False


def replace_numbers_with_words(sentence):
    sentence = re.sub(r"(\d+)", r" \1 ", sentence)

    def replace_with_words(match):
        num = match.group(0)
        try:
            return num2words(num)
        except Exception:
            return num

    return re.sub(r"\b\d+\b", replace_with_words, sentence)
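
# Example (editorial illustration, not part of the committed file): digits are first padded
# with spaces and then spelled out, so the output may contain doubled spaces.
#
#   replace_numbers_with_words("Track 7 runs 45 seconds")
#   -> roughly "Track  seven  runs  forty-five  seconds"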


def save_to_buffer(audio_tensors, codec_audio_sr):

    result = torch.cat(audio_tensors, 1)
    buffer = io.BytesIO()
    torchaudio.save(buffer, result, int(codec_audio_sr), format="wav")
    buffer.seek(0)
    return buffer.read()


def save_to_file(audio_tensors, codec_audio_sr):
    generated_audio_dir = "media/voicecraft/generated"
    Path(generated_audio_dir).mkdir(parents=True, exist_ok=True)
    filename = f"{generated_audio_dir}/{str(uuid.uuid4())}.wav"
    tensors = torch.cat(audio_tensors, 1)
    torchaudio.save(filename, tensors, int(codec_audio_sr), format="wav")
    return filename


def split_line_to_sentences(line):
    line = line.strip().capitalize()
    line = line + "." if line and line[-1] not in (".", "!", "?") else line
    sentences = re.findall(r"\w+.*?[.?!]", line.replace("\n", " "), flags=re.S)
    return sentences
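
# Example (editorial illustration): a terminal period is appended when missing, newlines are
# flattened, and the text is split on sentence-ending punctuation. Note that capitalize()
# lowercases everything after the first character.
#
#   split_line_to_sentences("the dog runs. the cat sleeps")
#   -> ["The dog runs.", "the cat sleeps."]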


def seed_everything(seed=1):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_length, image_token):
    return f"{image_token * image_seq_length}{bos_token}{prefix_prompt}\n"


def rescale(
    image: np.ndarray, scale: float, dtype: np.dtype = np.float32
) -> np.ndarray:
    rescaled_image = image * scale
    rescaled_image = rescaled_image.astype(dtype)
    return rescaled_image


def resize(
    image: Image.Image,
    size: Tuple[int, int],
    resample: Image.Resampling = None,
    reducing_gap: Optional[int] = None,
) -> Image.Image:
    height, width = size
    resized_image = image.resize(
        (width, height), resample=resample, reducing_gap=reducing_gap
    )
    return resized_image


def normalize(
    image: np.ndarray,
    mean: Union[float, Iterable[float]],
    std: Union[float, Iterable[float]],
) -> np.ndarray:
    mean = np.array(mean, dtype=image.dtype)
    std = np.array(std, dtype=image.dtype)
    image = (image - mean) / std
    return image


def process_images(
    images: List[Image.Image],
    size: Tuple[int, int] = None,
    resample: Image.Resampling = None,
    rescale_factor: float = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
) -> List[np.ndarray]:
    height, width = size[0], size[1]
    images = [
        resize(image=image, size=(height, width), resample=resample) for image in images
    ]
    # Convert each image to a numpy array
    images = [np.array(image) for image in images]
    # Rescale the pixel values to be in the range [0, 1]
    images = [rescale(image, scale=rescale_factor) for image in images]
    # Normalize the images to have mean 0 and standard deviation 1
    images = [normalize(image, mean=image_mean, std=image_std) for image in images]
    # Move the channel dimension to the first dimension. The model expects images in the format [Channel, Height, Width]
    images = [image.transpose(2, 0, 1) for image in images]
    return images
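
# Example (editorial sketch): preprocessing a single PIL image into the CHW float array the
# vision tower expects, using the normalization constants defined above. The 224x224 size,
# resampling choice, and image path are illustrative assumptions.
#
#   img = Image.open("media/example.jpg").convert("RGB")
#   pixel_values = process_images(
#       [img],
#       size=(224, 224),
#       resample=Image.Resampling.BICUBIC,
#       rescale_factor=1 / 255.0,
#       image_mean=IMAGENET_STANDARD_MEAN,
#       image_std=IMAGENET_STANDARD_STD,
#   )[0]  # np.ndarray of shape (3, 224, 224)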


def sample_top_p(probs: torch.Tensor, p: float):
    # (B, vocab_size)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    # (B, vocab_size)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # (B, vocab_size)
    # (Subtracting "probs_sort" shifts the cumulative sum by 1 position to the right before masking)
    mask = probs_sum - probs_sort > p
    # Zero out all the probabilities of tokens that are not selected by the Top P
    probs_sort[mask] = 0.0
    # Redistribute the probabilities so that they sum up to 1.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # Sample a token (its index) from the top p distribution
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Get the token position in the vocabulary corresponding to the sampled index
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
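
# Example (editorial sketch): with p=0.7 only the two most likely tokens survive the nucleus
# cut-off, so the sampled index is always 0 or 1 for this toy distribution.
#
#   probs = torch.tensor([[0.5, 0.3, 0.1, 0.1]])  # batch of 1, vocab of 4
#   next_token = sample_top_p(probs, p=0.7)       # tensor([[0]]) or tensor([[1]])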


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False, True, True, True, True],
            [False, False, False, True, True],
            [False, False, True, True, True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)

    return expanded_lengths >= lengths.unsqueeze(-1)


def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
    is_training: bool = False,
    token_type_ids: torch.Tensor = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
        is_training (`bool`):
            Whether the model is in training mode or in inference mode. The condition is checked by the presence/absence of `token_type_ids`/labels.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
        if sequence_length != 1:
            if is_training:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            else:
                causal_mask[:, :sequence_length] = 0.0

        causal_mask *= torch.arange(
            target_length, device=cache_position.device
        ) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
                :, None, None, :
            ].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)
            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
            if is_training:
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(
                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
                )
    return causal_mask
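
# Example (editorial sketch): prefill-time call for a 4-token prompt against a static cache of
# length 8; every value below is illustrative, not a fixed API of this repository.
#
#   mask = _prepare_4d_causal_attention_mask_with_cache_position(
#       attention_mask=torch.ones(1, 4, dtype=torch.long),
#       sequence_length=4,
#       target_length=8,
#       dtype=torch.float32,
#       device="cpu",
#       min_dtype=torch.finfo(torch.float32).min,
#       cache_position=torch.arange(4),
#       batch_size=1,
#   )
#   # mask.shape == (1, 1, 4, 8); 0 where attention is allowed, min_dtype where it is blocked.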


# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
    return isinstance(val, str) and val.startswith("http")


# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
def is_image_or_image_url(elem):
    return is_url(elem) or is_valid_image(elem)


def _is_str_or_image(elem):
    return isinstance(elem, str) or is_image_or_image_url(elem)


def generate_partial_autoregressive_mask(sz, start, end):
    mask = torch.zeros(sz, sz).bool()
    mask[start:end, start:end] = torch.triu(
        torch.ones(end - start, end - start, dtype=torch.bool), diagonal=1
    )
    mask[:start, start:end] = True
    mask[end:, start:end] = True
    return mask
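
# Example (editorial illustration): with sz=5, start=1, end=3, positions 1-2 attend to each
# other causally, while every other position is blocked from attending into that span
# (True = masked).
#
#   generate_partial_autoregressive_mask(5, 1, 3)
#   -> tensor([[False,  True,  True, False, False],
#              [False, False,  True, False, False],
#              [False, False, False, False, False],
#              [False,  True,  True, False, False],
#              [False,  True,  True, False, False]])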


def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):

    return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
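
# Example (editorial illustration, toy values): the prompt is prefixed with one image-token
# placeholder per image position, followed by the BOS token and a trailing newline.
#
#   build_string_from_input("caption en", "<bos>", 3, "<image>", 1)
#   -> "<image><image><image><bos>caption en\n"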


def is_torchdynamo_compiling():

    try:
        import torch

        return torch.compiler.is_compiling()
    except Exception:
        try:
            import torch._dynamo as dynamo  # noqa: F401

            return dynamo.is_compiling()
        except Exception:
            return False