Add application file
This view is limited to 50 files because it contains too many changes.
- Dockerfile +9 -2
- customs/customsf +0 -0
- data/__init__.py +3 -0
- data/__pycache__/__init__.cpython-311.pyc +0 -0
- data/__pycache__/collation.cpython-311.pyc +0 -0
- data/__pycache__/input_strategies.cpython-311.pyc +0 -0
- data/__pycache__/tokenizer.cpython-311.pyc +0 -0
- data/collation.py +120 -0
- data/datamodule.py +419 -0
- data/dataset.py +242 -0
- data/fbank.py +212 -0
- data/input_strategies.py +159 -0
- data/tokenizer.py +126 -0
- macros.py +44 -0
- main.py +2 -2
- models/__init__.py +136 -0
- models/__pycache__/__init__.cpython-311.pyc +0 -0
- models/__pycache__/macros.cpython-311.pyc +0 -0
- models/__pycache__/transformer.cpython-311.pyc +0 -0
- models/__pycache__/vallex.cpython-311.pyc +0 -0
- models/__pycache__/visualizer.cpython-311.pyc +0 -0
- models/macros.py +11 -0
- models/transformer.py +394 -0
- models/vallex.py +853 -0
- models/visualizer.py +106 -0
- modules/__init__.py +0 -0
- modules/__pycache__/__init__.cpython-311.pyc +0 -0
- modules/__pycache__/activation.cpython-311.pyc +0 -0
- modules/__pycache__/embedding.cpython-311.pyc +0 -0
- modules/__pycache__/scaling.cpython-311.pyc +0 -0
- modules/__pycache__/transformer.cpython-311.pyc +0 -0
- modules/activation.py +612 -0
- modules/embedding.py +97 -0
- modules/optim.py +1105 -0
- modules/scaling.py +1401 -0
- modules/scheduler.py +78 -0
- modules/transformer.py +683 -0
- prompts/promptsf +0 -0
- requirements.txt +1 -1
- utils/__init__.py +15 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/generation.cpython-311.pyc +0 -0
- utils/__pycache__/prompt_making.cpython-311.pyc +0 -0
- utils/__pycache__/sentence_cutter.cpython-311.pyc +0 -0
- utils/__pycache__/symbol_table.cpython-311.pyc +0 -0
- utils/download.py +49 -0
- utils/g2p/__init__.py +72 -0
- utils/g2p/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/g2p/__pycache__/cleaners.cpython-311.pyc +0 -0
- utils/g2p/__pycache__/english.cpython-311.pyc +0 -0
Dockerfile
CHANGED
@@ -3,6 +3,13 @@ WORKDIR /code
 COPY ./requirements.txt /code/requirements.txt
 RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
 COPY ./s2smodels.py /code/
-COPY ./
+COPY ./macros.py /code/
+COPY ./utils/ . /code/
+COPY ./modules/ . /code/
+COPY ./models/ . /code/
+COPY ./data/ . /code/
+COPY ./prompts/ . /code/
+COPY ./customs/ . /code/
+COPY ./main.py /code/

-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]

customs/customsf
ADDED
File without changes
data/__init__.py
ADDED
@@ -0,0 +1,3 @@
# from .datamodule import *
# from .tokenizer import *
from .collation import *

data/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (208 Bytes).

data/__pycache__/collation.cpython-311.pyc
ADDED
Binary file (7.2 kB).

data/__pycache__/input_strategies.cpython-311.pyc
ADDED
Binary file (1.8 kB).

data/__pycache__/tokenizer.cpython-311.pyc
ADDED
Binary file (6.77 kB).

data/collation.py
ADDED
@@ -0,0 +1,120 @@
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch

from utils import SymbolTable


class TextTokenCollater:
    """Collate list of text tokens

    Map sentences to integers. Sentences are padded to equal length.
    Beginning and end-of-sequence symbols can be added.

    Example:
        >>> token_collater = TextTokenCollater(text_tokens)
        >>> tokens_batch, tokens_lens = token_collater(text)

    Returns:
        tokens_batch: IntTensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <eos> and <bos>
            but before padding.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = [token for token in unique_tokens]

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            assert (
                all([True if s in self.token2idx else False for s in tokens])
                is True
            )
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        max_len = max(seq_lens)
        for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        tokens_seqs = [[p for p in text] for text in texts]
        max_len = len(max(tokens_seqs, key=len))

        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + list(seq)
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in tokens_seqs
        ]

        tokens_batch = torch.from_numpy(
            np.array(
                [seq for seq in seqs],
                dtype=np.int64,
            )
        )

        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in tokens_seqs
            ]
        )

        return tokens_batch, tokens_lens


def get_text_token_collater() -> TextTokenCollater:
    collater = TextTokenCollater(
        ['0'], add_bos=False, add_eos=False
    )
    return collater

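For reference, a minimal usage sketch (not part of the commit) of how the collater above is meant to be driven; the token-id strings are purely illustrative, and the module path follows this repository's layout:

    from data.collation import get_text_token_collater

    # BOS/EOS are disabled here, so a single sequence of numeric token-id
    # strings is padded (if needed) and packed into an int64 tensor.
    text_collater = get_text_token_collater()
    tokens_batch, tokens_lens = text_collater([["23", "7", "105", "7"]])
    # tokens_batch -> tensor([[ 23,   7, 105,   7]])
    # tokens_lens  -> tensor([4], dtype=torch.int32)
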
data/datamodule.py
ADDED
@@ -0,0 +1,419 @@
# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
# from icefall.utils import str2bool
# from lhotse import CutSet, load_manifest_lazy
# from lhotse.dataset import (
#     CutConcatenate,
#     DynamicBucketingSampler,
#     PrecomputedFeatures,
#     SingleCutSampler,
#     SpecAugment,
# )
# from lhotse.dataset.input_strategies import OnTheFlyFeatures
# from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from data.collation import get_text_token_collater
# from data.dataset import SpeechSynthesisDataset
from data.fbank import get_fbank_extractor
from data.input_strategies import PromptedPrecomputedFeatures

# PrecomputedFeatures = PrecomputedFeatures


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


def _get_input_strategy(input_strategy, dataset, cuts):
    if input_strategy == "PromptedPrecomputedFeatures":
        return PromptedPrecomputedFeatures(dataset, cuts)

    return eval(input_strategy)()


class TtsDataModule:
    """
    DataModule for VALL-E TTS experiments.
    It assumes there is always one train and valid dataloader.

    It contains all the common data pipeline modules used in TTS
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation[not used & tested yet],
    - augmentation[not used & tested yet],
    - on-the-fly feature extraction[not used & tested yet]

    This class should be derived for specific corpora used in TTS tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/tokenized"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=40.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=10,
            help="The number of buckets for the DynamicBucketingSampler"
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--concatenate-cuts",
            type=str2bool,
            default=False,
            help="When enabled, utterances (cuts) will be concatenated "
            "to minimize the amount of padding.",
        )
        group.add_argument(
            "--duration-factor",
            type=float,
            default=1.0,
            help="Determines the maximum duration of a concatenated cut "
            "relative to the duration of the longest cut in a batch.",
        )
        group.add_argument(
            "--gap",
            type=float,
            default=0.1,
            help="The amount of padding (in seconds) inserted between "
            "concatenated cuts. This padding is filled with noise when "
            "noise augmentation is used.",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=False,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=8,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=False,
            help="When enabled, use SpecAugment for training dataset.",
        )

        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 means to disable time warp.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures",
        )

        group.add_argument(
            "--dataset",
            type=str,
            default="ljspeech",
            help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
        )

        parser.add_argument(
            "--text-tokens",
            type=str,
            default="data/tokenized/unique_text_tokens.k2symbols",
            help="Path to the unique text tokens file",
        )

        parser.add_argument(
            "--sampling-rate",
            type=int,
            default=24000,
            help="""Audio sampling rate.""",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        transforms = []

        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between
            # different utterances.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(
                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
            )
            # Set the value of num_frame_masks according to Lhotse's version.
            # In different Lhotse's versions, the default of num_frame_masks is
            # different.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        if self.args.on_the_fly_feats:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
            # Drop feats to be on the safe side.
            train = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                cut_transforms=transforms,
                feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
                feature_transforms=input_transforms,
            )
        else:
            train = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                feature_input_strategy=_get_input_strategy(
                    self.args.input_strategy, self.args.dataset, cuts_train
                ),
                cut_transforms=transforms,
                feature_transforms=input_transforms,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info(
                "Using SingleCutSampler and sort by duraton(ascending=True)."
            )
            cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True)
            train_sampler = SingleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
                cut_transforms=[],
            )
        else:
            validate = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                feature_input_strategy=_get_input_strategy(
                    self.args.input_strategy, self.args.dataset, cuts_valid
                ),
                cut_transforms=[],
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=4,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = SpeechSynthesisDataset(
            get_text_token_collater(self.args.text_tokens),
            feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor())
            if self.args.on_the_fly_feats
            else _get_input_strategy(
                self.args.input_strategy, self.args.dataset, cuts
            ),
            cut_transforms=[],
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "cuts_train.jsonl.gz"
        )

    @lru_cache()
    def dev_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")

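For reference, a brief sketch (not part of the commit) of how TtsDataModule is typically wired into a training script. It assumes the commented-out lhotse/icefall imports above are restored, since str2bool, CutSet, and the samplers are otherwise undefined at runtime:

    import argparse
    from data.datamodule import TtsDataModule

    parser = argparse.ArgumentParser()
    TtsDataModule.add_arguments(parser)  # registers --manifest-dir, --max-duration, ...
    args = parser.parse_args([])         # defaults only, for illustration
    dm = TtsDataModule(args)
    # With the lhotse-backed pieces enabled, the loaders would be built as:
    # train_dl = dm.train_dataloaders(dm.train_cuts())
    # valid_dl = dm.valid_dataloaders(dm.dev_cuts())
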
data/dataset.py
ADDED
@@ -0,0 +1,242 @@
# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
modified from lhoste.dataset.speech_synthesis.py
"""

import torch
import math
import h5py
from tokenizers import Tokenizer
from typing import Union, List
import numpy as np
from tqdm import tqdm

_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
symbols = [_pad] + list(_punctuation) + list(_letters)

language_dict = {
    'en': 0,
    'zh': 1,
    'ja': 2,
}

def seq2phone(tokens: Union[List, np.ndarray]):
    """
    Convert tokenized phoneme ID sequence back to phoneme string
    :param tokens: phoneme tokens
    :return: recovered phoneme sequence
    """
    phones = "".join([symbols[i] for i in tokens])
    return phones

class DynamicBatchSampler(torch.utils.data.Sampler):
    def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000,
                 max_tokens=None, max_sentences=None, drop_last=False):
        """

        :param sampler:
        :param num_tokens_fn: function that returns the length of the sample at a given index
        :param num_buckets: number of buckets; samples of similar length are grouped into the same batch via bucketing
        :param min_size: minimum sample length; samples shorter than this are filtered out (buckets are built from this value)
        :param max_size: maximum sample length
        :param max_sentences: batch_size; the final size is controlled jointly by max_sentences and max_tokens
        """
        super(DynamicBatchSampler, self).__init__(sampler)
        self.sampler = sampler
        self.num_tokens_fn = num_tokens_fn
        self.num_buckets = num_buckets

        self.min_size = min_size
        self.max_size = max_size

        assert max_size <= max_tokens, "max_size should be smaller than max tokens"
        assert max_tokens is not None or max_sentences is not None, \
            "max_tokens and max_sentences should not be null at the same time, please specify one parameter at least"
        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
        self.drop_last = drop_last

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def is_batch_full(self, num_tokens, batch):
        if len(batch) == 0:
            return False
        if len(batch) == self.max_sentences:
            return True
        if num_tokens > self.max_tokens:
            return True
        return False

    def __iter__(self):
        buckets = [[] for _ in range(self.num_buckets)]
        sample_len = [0] * self.num_buckets

        for idx in self.sampler:
            idx_length = self.num_tokens_fn(idx)
            if not (self.min_size <= idx_length <= self.max_size):
                print("sentence at index {} of size {} exceeds max_tokens, the sentence is ignored".format(idx, idx_length))
                continue

            index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1)
                                       * self.num_buckets)
            sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)

            num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets]
            if self.is_batch_full(num_tokens, buckets[index_buckets]):
                # yield this batch
                yield buckets[index_buckets]
                buckets[index_buckets] = []
                sample_len[index_buckets] = 0

            buckets[index_buckets].append(idx)

        # process left-over
        leftover_batch = []
        leftover_sample_len = 0
        leftover = [idx for bucket in buckets for idx in bucket]
        for idx in leftover:
            idx_length = self.num_tokens_fn(idx)
            leftover_sample_len = max(leftover_sample_len, idx_length)
            num_tokens = (len(leftover_batch) + 1) * leftover_sample_len
            if self.is_batch_full(num_tokens, leftover_batch):
                yield leftover_batch
                leftover_batch = []
                leftover_sample_len = 0
            leftover_batch.append(idx)

        if len(leftover_batch) > 0 and not self.drop_last:
            yield leftover_batch

    def __len__(self):
        # we do not know the exactly batch size, so do not call len(dataloader)
        pass


class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, h5_path, ann_path, tokenizer_path):
        self.h5_path = h5_path
        with open(ann_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        ls = [l.split("|") for l in lines]
        ls_T = list(zip(*ls))
        del ls_T[-1]
        self.h5_paths, self.durations, self.langs, self.texts = \
            list(ls_T[0]), list(ls_T[1]), list(ls_T[2]), list(ls_T[3])
        self.durations = [float(dur) for dur in self.durations]
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        self._archive = None

    def __len__(self):
        return len(self.h5_paths)

    def get_dur(self, idx):
        return self.durations[idx]

    @property
    def archive(self):
        if self._archive is None:  # lazy loading here!
            self._archive = h5py.File(self.h5_path, "r")
        return self._archive

    def __getitem__(self, idx):
        archive = self.archive
        h5_path = self.h5_paths[idx]
        sub = archive[h5_path]
        audio_tokens = sub['audio'][()]
        phone_tokens = sub['text'][()]
        dur = self.durations[idx]
        lang = self.langs[idx]
        text = self.texts[idx]
        # tokenization should be done within dataloader
        phones = seq2phone(phone_tokens)
        phones = phones.replace(" ", "_")
        if not len(phones):
            cptpho_tokens = self.tokenizer.encode(text).ids
        else:
            cptpho_tokens = self.tokenizer.encode(phones).ids
        assert len(cptpho_tokens)
        return {
            'utt_id': h5_path,
            'text': text,
            'audio': None,
            'audio_lens': None,
            'audio_features': audio_tokens,
            'audio_features_lens': len(audio_tokens.T),
            'text_tokens': np.array(cptpho_tokens),
            'text_tokens_lens': len(cptpho_tokens),
            'language': language_dict[lang],
        }

def collate(batch):
    utt_id_s = [b['utt_id'] for b in batch]
    text_s = [b['text'] for b in batch]

    audio_s = [b['audio'] for b in batch]
    audio_lens_s = [b['audio_lens'] for b in batch]

    audio_features_lens_s = [b['audio_features_lens'] for b in batch]
    # create an empty tensor with maximum audio feature length
    audio_features_s = torch.zeros([len(batch), max(audio_features_lens_s), 8], dtype=torch.int64) - 1  # audio pad with -1

    text_tokens_lens_s = [b['text_tokens_lens'] for b in batch]
    # create an empty tensor with maximum text tokens length
    text_tokens_s = torch.zeros([len(batch), max(text_tokens_lens_s)], dtype=torch.int64) + 3  # [PAD] token id 3

    language_s = [b['language'] for b in batch]

    for i, b in enumerate(batch):
        audio_features = b['audio_features']
        audio_features_lens = b['audio_features_lens']
        audio_features_s[i, :audio_features_lens, :] = torch.LongTensor(audio_features.T)

        text_tokens = b['text_tokens']
        text_tokens_lens = b['text_tokens_lens']
        text_tokens_s[i, :text_tokens_lens] = torch.LongTensor(text_tokens)

    batch = {
        'utt_id': utt_id_s,
        'text': text_s,
        'audio': audio_s,
        'audio_lens': audio_lens_s,
        'audio_features': audio_features_s,
        'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)),
        'text_tokens': text_tokens_s,
        'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)),
        'languages': torch.LongTensor(np.array(language_s)),
    }
    return batch

def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10, max_duration=120):
    train_dataset = AudioDataset(h5_path=f"{data_dir}/audio_sum.hdf5",
                                 ann_path=f"{data_dir}/audio_ann_sum.txt",
                                 tokenizer_path=f"{data_dir}/bpe_69.json")
    ran_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    dynamic_sampler = DynamicBatchSampler(ran_sampler, train_dataset.get_dur, num_buckets=num_buckets, max_size=20,
                                          max_tokens=max_duration)


    train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=num_workers, collate_fn=collate,
                                               batch_sampler=dynamic_sampler)

    return train_loader

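For reference, a minimal sketch (not part of the commit) of driving create_dataloader above in a single-process run; the default data_dir and the audio_sum.hdf5 / audio_ann_sum.txt / bpe_69.json files it expects must already exist:

    from data.dataset import create_dataloader

    train_loader = create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0)
    for batch in train_loader:
        # 'audio_features' is padded with -1, 'text_tokens' with the [PAD] id 3
        print(batch['text_tokens'].shape, batch['audio_features'].shape)
        break
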
data/fbank.py
ADDED
@@ -0,0 +1,212 @@
# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
# from lhotse.features.base import FeatureExtractor
# from lhotse.utils import EPSILON, Seconds, compute_num_frames
from librosa.filters import mel as librosa_mel_fn


@dataclass
class BigVGANFbankConfig:
    # Spectogram-related part
    # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
    frame_length: Seconds = 1024 / 24000.0
    frame_shift: Seconds = 256 / 24000.0
    remove_dc_offset: bool = True
    round_to_power_of_two: bool = True

    # Fbank-related part
    low_freq: float = 0.0
    high_freq: float = 12000.0
    num_mel_bins: int = 100
    use_energy: bool = False

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
        return BigVGANFbankConfig(**data)


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


# https://github.com/NVIDIA/BigVGAN
# bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz
class BigVGANFbank(FeatureExtractor):
    name = "fbank"
    config_type = BigVGANFbankConfig

    def __init__(self, config: Optional[Any] = None):
        super(BigVGANFbank, self).__init__(config)
        sampling_rate = 24000
        self.mel_basis = torch.from_numpy(
            librosa_mel_fn(
                sampling_rate,
                1024,
                self.config.num_mel_bins,
                self.config.low_freq,
                self.config.high_freq,
            ).astype(np.float32)
        )
        self.hann_window = torch.hann_window(1024)

    def _feature_fn(self, samples, **kwargs):
        win_length, n_fft = 1024, 1024
        hop_size = 256
        if True:
            sampling_rate = 24000
            duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
            expected_num_frames = compute_num_frames(
                duration=duration,
                frame_shift=self.frame_shift,
                sampling_rate=sampling_rate,
            )
            pad_size = (
                (expected_num_frames - 1) * hop_size
                + win_length
                - samples.shape[-1]
            )
            assert pad_size >= 0

            y = torch.nn.functional.pad(
                samples,
                (0, pad_size),
                mode="constant",
            )
        else:
            y = torch.nn.functional.pad(
                samples,
                (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
                mode="reflect",
            )

        y = y.squeeze(1)

        # complex tensor as default, then use view_as_real for future pytorch compatibility
        spec = torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_length,
            window=self.hann_window,
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(self.mel_basis, spec)
        spec = spectral_normalize_torch(spec)

        return spec.transpose(2, 1).squeeze(0)

    def extract(
        self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
    ) -> np.ndarray:
        assert sampling_rate == 24000
        params = asdict(self.config)
        params.update({"sample_frequency": sampling_rate, "snip_edges": False})
        params["frame_shift"] *= 1000.0
        params["frame_length"] *= 1000.0
        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
        # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
        if len(samples.shape) == 1:
            samples = samples.unsqueeze(0)
        features = self._feature_fn(samples, **params).to(torch.float32)
        return features.numpy()

    @property
    def frame_shift(self) -> Seconds:
        return self.config.frame_shift

    def feature_dim(self, sampling_rate: int) -> int:
        return self.config.num_mel_bins

    @staticmethod
    def mix(
        features_a: np.ndarray,
        features_b: np.ndarray,
        energy_scaling_factor_b: float,
    ) -> np.ndarray:
        return np.log(
            np.maximum(
                # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
                EPSILON,
                np.exp(features_a)
                + energy_scaling_factor_b * np.exp(features_b),
            )
        )

    @staticmethod
    def compute_energy(features: np.ndarray) -> float:
        return float(np.sum(np.exp(features)))


def get_fbank_extractor() -> BigVGANFbank:
    return BigVGANFbank(BigVGANFbankConfig())


if __name__ == "__main__":
    extractor = BigVGANFbank(BigVGANFbankConfig())

    samples = torch.from_numpy(np.random.random([1000]).astype(np.float32))
    samples = torch.clip(samples, -1.0, 1.0)
    fbank = extractor.extract(samples, 24000.0)
    print(f"fbank {fbank.shape}")

    from scipy.io.wavfile import read

    MAX_WAV_VALUE = 32768.0

    sampling_rate, samples = read(
        "egs/libritts/prompts/5639_40744_000000_000002.wav"
    )
    print(f"samples: [{samples.min()}, {samples.max()}]")
    fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000)
    print(f"fbank {fbank.shape}")

    import matplotlib.pyplot as plt

    _ = plt.figure(figsize=(18, 10))
    plt.imshow(
        X=fbank.transpose(1, 0),
        cmap=plt.get_cmap("jet"),
        aspect="auto",
        interpolation="nearest",
    )
    plt.gca().invert_yaxis()
    plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png")
    plt.close()

    print("fbank test PASS!")

data/input_strategies.py
ADDED
@@ -0,0 +1,159 @@
import random
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Type

# from lhotse import CutSet
# from lhotse.dataset.collation import collate_features
# from lhotse.dataset.input_strategies import (
#     ExecutorType,
#     PrecomputedFeatures,
#     _get_executor,
# )
# from lhotse.utils import fastcopy


class PromptedFeatures:
    def __init__(self, prompts, features):
        self.prompts = prompts
        self.features = features

    def to(self, device):
        return PromptedFeatures(
            self.prompts.to(device), self.features.to(device)
        )

    def sum(self):
        return self.features.sum()

    @property
    def ndim(self):
        return self.features.ndim

    @property
    def data(self):
        return (self.prompts, self.features)


# class PromptedPrecomputedFeatures(PrecomputedFeatures):
#     """
#     :class:`InputStrategy` that reads pre-computed features, whose manifests
#     are attached to cuts, from disk.
#
#     It automatically pads the feature matrices with pre or post feature.
#
#     .. automethod:: __call__
#     """
#
#     def __init__(
#         self,
#         dataset: str,
#         cuts: CutSet,
#         num_workers: int = 0,
#         executor_type: Type[ExecutorType] = ThreadPoolExecutor,
#     ) -> None:
#         super(PromptedPrecomputedFeatures, self).__init__(
#             num_workers, executor_type
#         )
#
#         self.utt2neighbors = defaultdict(lambda: [])
#
#         if dataset.lower() == "libritts":
#             # 909_131041_000013_000002
#             # 909_131041_000013_000003
#             speaker2utts = defaultdict(lambda: [])
#
#             utt2cut = {}
#             for cut in cuts:
#                 speaker = cut.supervisions[0].speaker
#                 speaker2utts[speaker].append(cut.id)
#                 utt2cut[cut.id] = cut
#
#             for spk in speaker2utts:
#                 uttids = sorted(speaker2utts[spk])
#                 # Using the property of sorted keys to find previous utterance
#                 # The keys has structure speaker_book_x_y e.g. 1089_134691_000004_000001
#                 if len(uttids) == 1:
#                     self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
#                     continue
#
#                 utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
#                 utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
#
#                 for utt in utt2prevutt:
#                     self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
#
#                 for utt in utt2postutt:
#                     self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
#         elif dataset.lower() == "ljspeech":
#             utt2cut = {}
#             uttids = []
#             for cut in cuts:
#                 uttids.append(cut.id)
#                 utt2cut[cut.id] = cut
#
#             if len(uttids) == 1:
#                 self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
#             else:
#                 # Using the property of sorted keys to find previous utterance
#                 # The keys has structure: LJ001-0010
#                 utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
#                 utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
#
#                 for utt in utt2postutt:
#                     postutt = utt2postutt[utt]
#                     if utt[:5] == postutt[:5]:
#                         self.utt2neighbors[utt].append(utt2cut[postutt])
#
#                 for utt in utt2prevutt:
#                     prevutt = utt2prevutt[utt]
#                     if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
#                         self.utt2neighbors[utt].append(utt2cut[prevutt])
#         else:
#             raise ValueError
#
#     def __call__(
#         self, cuts: CutSet
#     ) -> Tuple[PromptedFeatures, PromptedFeatures]:
#         """
#         Reads the pre-computed features from disk/other storage.
#         The returned shape is``(B, T, F) => (batch_size, num_frames, num_features)``.
#
#         :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
#         """
#         features, features_lens = collate_features(
#             cuts,
#             executor=_get_executor(
#                 self.num_workers, executor_type=self._executor_type
#             ),
#         )
#
#         prompts_cuts = []
#         for k, cut in enumerate(cuts):
#             prompts_cut = random.choice(self.utt2neighbors[cut.id])
#             prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
#
#         mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
#         # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
#         #     max_duration=mini_duration,
#         #     offset_type="random",
#         #     preserve_id=True,
#         # )
#         prompts_cuts = CutSet(
#             cuts={k: cut for k, cut in enumerate(prompts_cuts)}
#         ).truncate(
#             max_duration=mini_duration,
#             offset_type="random",
#             preserve_id=False,
#         )
#
#         prompts, prompts_lens = collate_features(
#             prompts_cuts,
#             executor=_get_executor(
#                 self.num_workers, executor_type=self._executor_type
#             ),
#         )
#
#         return PromptedFeatures(prompts, features), PromptedFeatures(
#             prompts_lens, features_lens
#         )

data/tokenizer.py
ADDED
@@ -0,0 +1,126 @@
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union

import numpy as np
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio

try:
    from pypinyin import Style, pinyin
    from pypinyin.style._utils import get_finals, get_initials
except Exception:
    pass


def remove_encodec_weight_norm(model):
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    encoder = model.encoder.model
    for key in encoder._modules:
        if isinstance(encoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
            block_modules = encoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(encoder._modules[key], SConv1d):
            remove_weight_norm(encoder._modules[key].conv.conv)

    decoder = model.decoder.model
    for key in decoder._modules:
        if isinstance(decoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
            block_modules = decoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(decoder._modules[key], SConvTranspose1d):
            remove_weight_norm(decoder._modules[key].convtr.convtr)
        elif isinstance(decoder._modules[key], SConv1d):
            remove_weight_norm(decoder._modules[key].conv.conv)


class AudioTokenizer:
    """EnCodec audio."""

    def __init__(
        self,
        device: Any = None,
    ) -> None:
        # Instantiate a pretrained EnCodec model
        model = EncodecModel.encodec_model_24khz()
        model.set_target_bandwidth(6.0)
        remove_encodec_weight_norm(model)

        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")
            if torch.backends.mps.is_available():
                device = torch.device("mps")

        self._device = device

        self.codec = model.to(device)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

    @property
    def device(self):
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        return self.codec.encode(wav.to(self.device))

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        return self.codec.decode(frames)


def tokenize_audio(tokenizer: AudioTokenizer, audio):
    # Load and pre-process the audio waveform
    if isinstance(audio, str):
        wav, sr = torchaudio.load(audio)
    else:
        wav, sr = audio
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    wav = wav.unsqueeze(0)

    # Extract discrete codes from EnCodec
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames


if __name__ == "__main__":
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)

    samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
        torch.float32
    )
    codes_raw = model.encode(samples)

    remove_encodec_weight_norm(model)
    codes_norm = model.encode(samples)

    assert torch.allclose(codes_raw[0][0], codes_norm[0][0])

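For reference, a minimal sketch (not part of the commit) of encoding a prompt waveform with the tokenizer above; "prompt.wav" is a placeholder path, and per the EnCodec API each entry of the returned list is a (codes, scale) pair:

    from data.tokenizer import AudioTokenizer, tokenize_audio

    tokenizer = AudioTokenizer()            # picks cuda / mps / cpu automatically
    encoded_frames = tokenize_audio(tokenizer, "prompt.wav")
    codes = encoded_frames[0][0]            # shape (batch, n_quantizers, frames)
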
macros.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
NUM_LAYERS = 12
NUM_HEAD = 16
N_DIM = 1024
PREFIX_MODE = 1
NUM_QUANTIZERS = 8
SAMPLE_RATE = 24000

lang2token = {
    'zh': "[ZH]",
    'ja': "[JA]",
    "en": "[EN]",
    "AR": "[AR]",
    'mix': "",
}

lang2code = {
    'zh': 0,
    'ja': 1,
    "en": 2,
    "ar": 3,
}

token2lang = {
    '[ZH]': "zh",
    '[JA]': "ja",
    "[EN]": "en",
    "[AR]": "ar",
    "": "mix"
}

code2lang = {
    0: 'zh',
    1: 'ja',
    2: "en",
    3: "ar",
}

langdropdown2token = {
    'English': "[EN]",
    '中文': "[ZH]",
    '日本語': "[JA]",
    'عربي': "[AR]",
    'Mix': "",
}
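As a quick reference, a small hypothetical sketch (not part of the commit) of how these tables compose; `tag_text` is an illustrative helper, not a function defined in this repo:

# Illustrative only: wrapping text with the language tags defined above.
from macros import lang2token, token2lang, lang2code, code2lang

def tag_text(text: str, lang: str) -> str:
    token = lang2token.get(lang, "")
    return f"{token}{text}{token}"

assert token2lang[lang2token['en']] == 'en'
assert code2lang[lang2code['ja']] == 'ja'
print(tag_text("Hello there.", "en"))  # -> [EN]Hello there.[EN]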
main.py
CHANGED
@@ -7,9 +7,9 @@ import os
 from fastapi import FastAPI, Response
 import torch
 from fastapi.responses import JSONResponse
-from
+from utils.prompt_making import make_prompt
 from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
-from
+from utils.generation import SAMPLE_RATE, generate_audio, preload_models
 from io import BytesIO
 from pyannote.audio import Pipeline
 import soundfile as sf
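For context, a hedged sketch of how the two newly imported helpers are typically driven from inside main.py; the argument names (`name`, `audio_prompt_path`, `prompt`) are assumptions about the VALL-E X utilities, not something shown in this diff, and `sf` is the soundfile alias already imported above:

# Sketch only: driving the newly imported helpers (argument names assumed).
preload_models()                                        # load VALL-E X weights once at startup
make_prompt(name="speaker1", audio_prompt_path="speaker1.wav")
audio_array = generate_audio("Hello from the API.", prompt="speaker1")
sf.write("generated.wav", audio_array, SAMPLE_RATE)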
models/__init__.py
ADDED
@@ -0,0 +1,136 @@
import argparse

import torch.nn as nn
# from icefall.utils import AttributeDict, str2bool

from .macros import (
    NUM_AUDIO_TOKENS,
    NUM_MEL_BINS,
    NUM_SPEAKER_CLASSES,
    NUM_TEXT_TOKENS,
    SPEAKER_EMBEDDING_DIM,
)
from .transformer import Transformer
from .vallex import VALLE, VALLF
from .visualizer import visualize


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--model-name",
        type=str,
        default="VALL-E",
        help="VALL-E, VALL-F, Transformer.",
    )
    parser.add_argument(
        "--decoder-dim",
        type=int,
        default=1024,
        help="Embedding dimension in the decoder model.",
    )
    parser.add_argument(
        "--nhead",
        type=int,
        default=16,
        help="Number of attention heads in the Decoder layers.",
    )
    parser.add_argument(
        "--num-decoder-layers",
        type=int,
        default=12,
        help="Number of Decoder layers.",
    )
    parser.add_argument(
        "--scale-factor",
        type=float,
        default=1.0,
        help="Model scale factor which will be assigned different meanings in different models.",
    )
    parser.add_argument(
        "--norm-first",
        type=bool,
        default=True,
        help="Pre or Post Normalization.",
    )
    parser.add_argument(
        "--add-prenet",
        type=bool,
        default=False,
        help="Whether add PreNet after Inputs.",
    )

    # VALL-E & F
    parser.add_argument(
        "--prefix-mode",
        type=int,
        default=1,
        help="The mode for how to prefix VALL-E NAR Decoder, "
        "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
    )
    parser.add_argument(
        "--share-embedding",
        type=bool,
        default=True,
        help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
    )
    parser.add_argument(
        "--prepend-bos",
        type=bool,
        default=False,
        help="Whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs.",
    )
    parser.add_argument(
        "--num-quantizers",
        type=int,
        default=8,
        help="Number of Audio/Semantic quantization layers.",
    )

    # Transformer
    parser.add_argument(
        "--scaling-xformers",
        type=bool,
        default=False,
        help="Apply Reworked Conformer scaling on Transformers.",
    )


def get_model(params) -> nn.Module:
    if params.model_name.lower() in ["vall-f", "vallf"]:
        model = VALLF(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            prefix_mode=params.prefix_mode,
            share_embedding=params.share_embedding,
            nar_scale_factor=params.scale_factor,
            prepend_bos=params.prepend_bos,
            num_quantizers=params.num_quantizers,
        )
    elif params.model_name.lower() in ["vall-e", "valle"]:
        model = VALLE(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            prefix_mode=params.prefix_mode,
            share_embedding=params.share_embedding,
            nar_scale_factor=params.scale_factor,
            prepend_bos=params.prepend_bos,
            num_quantizers=params.num_quantizers,
        )
    else:
        assert params.model_name in ["Transformer"]
        model = Transformer(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            scaling_xformers=params.scaling_xformers,
        )

    return model
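A minimal sketch (not part of the commit) of how these two helpers are wired together, assuming the repository root is on `PYTHONPATH` so the package imports as `models`:

# Sketch: build a model from the CLI defaults defined above.
import argparse
from models import add_model_arguments, get_model

parser = argparse.ArgumentParser()
add_model_arguments(parser)
params = parser.parse_args([])      # all defaults: VALL-E, 1024-dim, 16 heads, 12 layers
model = get_model(params)
print(params.model_name, sum(p.numel() for p in model.parameters()) / 1e6, "M params")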
models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (4.4 kB). View file
models/__pycache__/macros.cpython-311.pyc
ADDED
Binary file (335 Bytes). View file
models/__pycache__/transformer.cpython-311.pyc
ADDED
Binary file (15.1 kB). View file
models/__pycache__/vallex.cpython-311.pyc
ADDED
Binary file (37.6 kB). View file
models/__pycache__/visualizer.cpython-311.pyc
ADDED
Binary file (5.17 kB). View file
models/macros.py
ADDED
@@ -0,0 +1,11 @@
# Text
NUM_TEXT_TOKENS = 2048

# Audio
NUM_AUDIO_TOKENS = 1024  # EnCodec RVQ bins
NUM_MEL_BINS = 100  # BigVGAN bigvgan_24khz_100band


# Speaker
NUM_SPEAKER_CLASSES = 4096
SPEAKER_EMBEDDING_DIM = 64
models/transformer.py
ADDED
@@ -0,0 +1,394 @@
1 |
+
# Copyright 2023 (authors: Feiteng Li)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from functools import partial
|
16 |
+
from typing import Any, Dict, List, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
# from icefall.utils import make_pad_mask
|
22 |
+
# from torchmetrics.classification import BinaryAccuracy
|
23 |
+
|
24 |
+
from models.vallex import Transpose
|
25 |
+
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
26 |
+
from modules.scaling import BalancedDoubleSwish, ScaledLinear
|
27 |
+
from modules.transformer import (
|
28 |
+
BalancedBasicNorm,
|
29 |
+
IdentityNorm,
|
30 |
+
TransformerDecoderLayer,
|
31 |
+
TransformerEncoder,
|
32 |
+
TransformerEncoderLayer,
|
33 |
+
)
|
34 |
+
|
35 |
+
from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
|
36 |
+
from .visualizer import visualize
|
37 |
+
|
38 |
+
IdentityNorm = IdentityNorm
|
39 |
+
|
40 |
+
|
41 |
+
class Transformer(nn.Module):
|
42 |
+
"""It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding)
|
43 |
+
Neural Speech Synthesis with Transformer Network
|
44 |
+
https://arxiv.org/abs/1809.08895
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
d_model: int,
|
50 |
+
nhead: int,
|
51 |
+
num_layers: int,
|
52 |
+
norm_first: bool = True,
|
53 |
+
add_prenet: bool = False,
|
54 |
+
scaling_xformers: bool = False,
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
Args:
|
58 |
+
d_model:
|
59 |
+
The number of expected features in the input (required).
|
60 |
+
nhead:
|
61 |
+
The number of heads in the multiheadattention models (required).
|
62 |
+
num_layers:
|
63 |
+
The number of sub-decoder-layers in the decoder (required).
|
64 |
+
"""
|
65 |
+
super().__init__()
|
66 |
+
self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
|
67 |
+
|
68 |
+
if add_prenet:
|
69 |
+
self.encoder_prenet = nn.Sequential(
|
70 |
+
Transpose(),
|
71 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
72 |
+
nn.BatchNorm1d(d_model),
|
73 |
+
nn.ReLU(),
|
74 |
+
nn.Dropout(0.5),
|
75 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
76 |
+
nn.BatchNorm1d(d_model),
|
77 |
+
nn.ReLU(),
|
78 |
+
nn.Dropout(0.5),
|
79 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
80 |
+
nn.BatchNorm1d(d_model),
|
81 |
+
nn.ReLU(),
|
82 |
+
nn.Dropout(0.5),
|
83 |
+
Transpose(),
|
84 |
+
nn.Linear(d_model, d_model),
|
85 |
+
)
|
86 |
+
|
87 |
+
self.decoder_prenet = nn.Sequential(
|
88 |
+
nn.Linear(NUM_MEL_BINS, 256),
|
89 |
+
nn.ReLU(),
|
90 |
+
nn.Dropout(0.5),
|
91 |
+
nn.Linear(256, 256),
|
92 |
+
nn.ReLU(),
|
93 |
+
nn.Dropout(0.5),
|
94 |
+
nn.Linear(256, d_model),
|
95 |
+
)
|
96 |
+
|
97 |
+
assert scaling_xformers is False # TODO: update this block
|
98 |
+
else:
|
99 |
+
self.encoder_prenet = nn.Identity()
|
100 |
+
if scaling_xformers:
|
101 |
+
self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
|
102 |
+
else:
|
103 |
+
self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)
|
104 |
+
|
105 |
+
self.encoder_position = SinePositionalEmbedding(
|
106 |
+
d_model,
|
107 |
+
dropout=0.1,
|
108 |
+
scale=False,
|
109 |
+
)
|
110 |
+
self.decoder_position = SinePositionalEmbedding(
|
111 |
+
d_model, dropout=0.1, scale=False
|
112 |
+
)
|
113 |
+
|
114 |
+
if scaling_xformers:
|
115 |
+
self.encoder = TransformerEncoder(
|
116 |
+
TransformerEncoderLayer(
|
117 |
+
d_model,
|
118 |
+
nhead,
|
119 |
+
dim_feedforward=d_model * 4,
|
120 |
+
dropout=0.1,
|
121 |
+
batch_first=True,
|
122 |
+
norm_first=norm_first,
|
123 |
+
linear1_self_attention_cls=ScaledLinear,
|
124 |
+
linear2_self_attention_cls=partial(
|
125 |
+
ScaledLinear, initial_scale=0.01
|
126 |
+
),
|
127 |
+
linear1_feedforward_cls=ScaledLinear,
|
128 |
+
linear2_feedforward_cls=partial(
|
129 |
+
ScaledLinear, initial_scale=0.01
|
130 |
+
),
|
131 |
+
activation=partial(
|
132 |
+
BalancedDoubleSwish,
|
133 |
+
channel_dim=-1,
|
134 |
+
max_abs=10.0,
|
135 |
+
min_prob=0.25,
|
136 |
+
),
|
137 |
+
layer_norm_cls=IdentityNorm,
|
138 |
+
),
|
139 |
+
num_layers=num_layers,
|
140 |
+
norm=BalancedBasicNorm(d_model) if norm_first else None,
|
141 |
+
)
|
142 |
+
|
143 |
+
self.decoder = nn.TransformerDecoder(
|
144 |
+
TransformerDecoderLayer(
|
145 |
+
d_model,
|
146 |
+
nhead,
|
147 |
+
dim_feedforward=d_model * 4,
|
148 |
+
dropout=0.1,
|
149 |
+
batch_first=True,
|
150 |
+
norm_first=norm_first,
|
151 |
+
linear1_self_attention_cls=ScaledLinear,
|
152 |
+
linear2_self_attention_cls=partial(
|
153 |
+
ScaledLinear, initial_scale=0.01
|
154 |
+
),
|
155 |
+
linear1_feedforward_cls=ScaledLinear,
|
156 |
+
linear2_feedforward_cls=partial(
|
157 |
+
ScaledLinear, initial_scale=0.01
|
158 |
+
),
|
159 |
+
activation=partial(
|
160 |
+
BalancedDoubleSwish,
|
161 |
+
channel_dim=-1,
|
162 |
+
max_abs=10.0,
|
163 |
+
min_prob=0.25,
|
164 |
+
),
|
165 |
+
layer_norm_cls=IdentityNorm,
|
166 |
+
),
|
167 |
+
num_layers=num_layers,
|
168 |
+
norm=BalancedBasicNorm(d_model) if norm_first else None,
|
169 |
+
)
|
170 |
+
|
171 |
+
self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
|
172 |
+
self.stop_layer = nn.Linear(d_model, 1)
|
173 |
+
else:
|
174 |
+
self.encoder = nn.TransformerEncoder(
|
175 |
+
nn.TransformerEncoderLayer(
|
176 |
+
d_model,
|
177 |
+
nhead,
|
178 |
+
dim_feedforward=d_model * 4,
|
179 |
+
activation=F.relu,
|
180 |
+
dropout=0.1,
|
181 |
+
batch_first=True,
|
182 |
+
norm_first=norm_first,
|
183 |
+
),
|
184 |
+
num_layers=num_layers,
|
185 |
+
norm=nn.LayerNorm(d_model) if norm_first else None,
|
186 |
+
)
|
187 |
+
|
188 |
+
self.decoder = nn.TransformerDecoder(
|
189 |
+
nn.TransformerDecoderLayer(
|
190 |
+
d_model,
|
191 |
+
nhead,
|
192 |
+
dim_feedforward=d_model * 4,
|
193 |
+
activation=F.relu,
|
194 |
+
dropout=0.1,
|
195 |
+
batch_first=True,
|
196 |
+
norm_first=norm_first,
|
197 |
+
),
|
198 |
+
num_layers=num_layers,
|
199 |
+
norm=nn.LayerNorm(d_model) if norm_first else None,
|
200 |
+
)
|
201 |
+
|
202 |
+
self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
|
203 |
+
self.stop_layer = nn.Linear(d_model, 1)
|
204 |
+
|
205 |
+
self.stop_accuracy_metric = BinaryAccuracy(
|
206 |
+
threshold=0.5, multidim_average="global"
|
207 |
+
)
|
208 |
+
|
209 |
+
# self.apply(self._init_weights)
|
210 |
+
|
211 |
+
# def _init_weights(self, module):
|
212 |
+
# if isinstance(module, (nn.Linear)):
|
213 |
+
# module.weight.data.normal_(mean=0.0, std=0.02)
|
214 |
+
# if isinstance(module, nn.Linear) and module.bias is not None:
|
215 |
+
# module.bias.data.zero_()
|
216 |
+
# elif isinstance(module, nn.LayerNorm):
|
217 |
+
# module.bias.data.zero_()
|
218 |
+
# module.weight.data.fill_(1.0)
|
219 |
+
# elif isinstance(module, nn.Embedding):
|
220 |
+
# module.weight.data.normal_(mean=0.0, std=0.02)
|
221 |
+
|
222 |
+
def forward(
|
223 |
+
self,
|
224 |
+
x: torch.Tensor,
|
225 |
+
x_lens: torch.Tensor,
|
226 |
+
y: torch.Tensor,
|
227 |
+
y_lens: torch.Tensor,
|
228 |
+
reduction: str = "sum",
|
229 |
+
train_stage: int = 0,
|
230 |
+
**kwargs,
|
231 |
+
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
232 |
+
"""
|
233 |
+
Args:
|
234 |
+
x:
|
235 |
+
A 2-D tensor of shape (N, S).
|
236 |
+
x_lens:
|
237 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
238 |
+
before padding.
|
239 |
+
y:
|
240 |
+
A 3-D tensor of shape (N, T, 8).
|
241 |
+
y_lens:
|
242 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
243 |
+
before padding.
|
244 |
+
train_stage:
|
245 |
+
Not used in this model.
|
246 |
+
Returns:
|
247 |
+
Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
|
248 |
+
"""
|
249 |
+
del train_stage
|
250 |
+
|
251 |
+
assert x.ndim == 2, x.shape
|
252 |
+
assert x_lens.ndim == 1, x_lens.shape
|
253 |
+
assert y.ndim == 3, y.shape
|
254 |
+
assert y_lens.ndim == 1, y_lens.shape
|
255 |
+
|
256 |
+
assert torch.all(x_lens > 0)
|
257 |
+
|
258 |
+
# NOTE: x has been padded in TextTokenCollater
|
259 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
260 |
+
|
261 |
+
x = self.text_embedding(x)
|
262 |
+
x = self.encoder_prenet(x)
|
263 |
+
x = self.encoder_position(x)
|
264 |
+
x = self.encoder(x, src_key_padding_mask=x_mask)
|
265 |
+
|
266 |
+
total_loss, metrics = 0.0, {}
|
267 |
+
|
268 |
+
y_mask = make_pad_mask(y_lens).to(y.device)
|
269 |
+
y_mask_float = y_mask.type(torch.float32)
|
270 |
+
data_mask = 1.0 - y_mask_float.unsqueeze(-1)
|
271 |
+
|
272 |
+
# Training
|
273 |
+
# AR Decoder
|
274 |
+
def pad_y(y):
|
275 |
+
y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
|
276 |
+
# inputs, targets
|
277 |
+
return y[:, :-1], y[:, 1:]
|
278 |
+
|
279 |
+
y, targets = pad_y(y * data_mask) # mask padding as zeros
|
280 |
+
|
281 |
+
y_emb = self.decoder_prenet(y)
|
282 |
+
y_pos = self.decoder_position(y_emb)
|
283 |
+
|
284 |
+
y_len = y_lens.max()
|
285 |
+
tgt_mask = torch.triu(
|
286 |
+
torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
|
287 |
+
diagonal=1,
|
288 |
+
)
|
289 |
+
y_dec = self.decoder(
|
290 |
+
y_pos,
|
291 |
+
x,
|
292 |
+
tgt_mask=tgt_mask,
|
293 |
+
memory_key_padding_mask=x_mask,
|
294 |
+
)
|
295 |
+
|
296 |
+
predict = self.predict_layer(y_dec)
|
297 |
+
# loss
|
298 |
+
total_loss = F.mse_loss(predict, targets, reduction=reduction)
|
299 |
+
|
300 |
+
logits = self.stop_layer(y_dec).squeeze(-1)
|
301 |
+
stop_loss = F.binary_cross_entropy_with_logits(
|
302 |
+
logits,
|
303 |
+
y_mask_float.detach(),
|
304 |
+
weight=1.0 + y_mask_float.detach() * 4.0,
|
305 |
+
reduction=reduction,
|
306 |
+
)
|
307 |
+
metrics["stop_loss"] = stop_loss.detach()
|
308 |
+
|
309 |
+
stop_accuracy = self.stop_accuracy_metric(
|
310 |
+
(torch.sigmoid(logits) >= 0.5).type(torch.int64),
|
311 |
+
y_mask.type(torch.int64),
|
312 |
+
)
|
313 |
+
# icefall MetricsTracker.norm_items()
|
314 |
+
metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
|
315 |
+
torch.float32
|
316 |
+
)
|
317 |
+
|
318 |
+
return ((x, predict), total_loss + 100.0 * stop_loss, metrics)
|
319 |
+
|
320 |
+
def inference(
|
321 |
+
self,
|
322 |
+
x: torch.Tensor,
|
323 |
+
x_lens: torch.Tensor,
|
324 |
+
y: Any = None,
|
325 |
+
**kwargs,
|
326 |
+
) -> torch.Tensor:
|
327 |
+
"""
|
328 |
+
Args:
|
329 |
+
x:
|
330 |
+
A 2-D tensor of shape (1, S).
|
331 |
+
x_lens:
|
332 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
333 |
+
before padding.
|
334 |
+
Returns:
|
335 |
+
Return the predicted audio code matrix and cross-entropy loss.
|
336 |
+
"""
|
337 |
+
assert x.ndim == 2, x.shape
|
338 |
+
assert x_lens.ndim == 1, x_lens.shape
|
339 |
+
|
340 |
+
assert torch.all(x_lens > 0)
|
341 |
+
|
342 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
343 |
+
|
344 |
+
x = self.text_embedding(x)
|
345 |
+
x = self.encoder_prenet(x)
|
346 |
+
x = self.encoder_position(x)
|
347 |
+
x = self.encoder(x, src_key_padding_mask=x_mask)
|
348 |
+
|
349 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
350 |
+
|
351 |
+
# AR Decoder
|
352 |
+
# TODO: Managing decoder steps avoid repetitive computation
|
353 |
+
y = torch.zeros(
|
354 |
+
[x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
|
355 |
+
)
|
356 |
+
while True:
|
357 |
+
y_emb = self.decoder_prenet(y)
|
358 |
+
y_pos = self.decoder_position(y_emb)
|
359 |
+
|
360 |
+
tgt_mask = torch.triu(
|
361 |
+
torch.ones(
|
362 |
+
y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
|
363 |
+
),
|
364 |
+
diagonal=1,
|
365 |
+
)
|
366 |
+
|
367 |
+
y_dec = self.decoder(
|
368 |
+
y_pos,
|
369 |
+
x,
|
370 |
+
tgt_mask=tgt_mask,
|
371 |
+
memory_mask=None,
|
372 |
+
memory_key_padding_mask=x_mask,
|
373 |
+
)
|
374 |
+
predict = self.predict_layer(y_dec[:, -1:])
|
375 |
+
|
376 |
+
logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5
|
377 |
+
if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
|
378 |
+
print(
|
379 |
+
f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
|
380 |
+
)
|
381 |
+
break
|
382 |
+
|
383 |
+
y = torch.concat([y, predict], dim=1)
|
384 |
+
|
385 |
+
return y[:, 1:]
|
386 |
+
|
387 |
+
def visualize(
|
388 |
+
self,
|
389 |
+
predicts: Tuple[torch.Tensor],
|
390 |
+
batch: Dict[str, Union[List, torch.Tensor]],
|
391 |
+
output_dir: str,
|
392 |
+
limit: int = 4,
|
393 |
+
) -> None:
|
394 |
+
visualize(predicts, batch, output_dir, limit=limit)
|
models/vallex.py
ADDED
@@ -0,0 +1,853 @@
1 |
+
# Copyright 2023 (authors: Feiteng Li)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import random
|
16 |
+
from typing import Dict, Iterator, List, Tuple, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
# from icefall.utils import make_pad_mask
|
23 |
+
# from torchmetrics.classification import MulticlassAccuracy
|
24 |
+
|
25 |
+
from data.input_strategies import PromptedFeatures
|
26 |
+
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
27 |
+
from modules.transformer import (
|
28 |
+
AdaptiveLayerNorm,
|
29 |
+
LayerNorm,
|
30 |
+
TransformerDecoderLayer,
|
31 |
+
TransformerEncoder,
|
32 |
+
TransformerEncoderLayer,
|
33 |
+
)
|
34 |
+
|
35 |
+
from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
|
36 |
+
from .visualizer import visualize
|
37 |
+
|
38 |
+
|
39 |
+
class Transpose(nn.Identity):
|
40 |
+
"""(N, T, D) -> (N, D, T)"""
|
41 |
+
|
42 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
43 |
+
return input.transpose(1, 2)
|
44 |
+
|
45 |
+
|
46 |
+
# NOTE: There are two ways to implement the model
|
47 |
+
# 1) [VALL-F] standard TransformerDecoder, use x as memory
|
48 |
+
# 2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder),
|
49 |
+
# use x as the prefix of decoder inputs
|
50 |
+
class VALLF(nn.Module):
|
51 |
+
"""It implements https://arxiv.org/abs/2301.02111
|
52 |
+
"Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
d_model: int,
|
58 |
+
nhead: int,
|
59 |
+
num_layers: int,
|
60 |
+
norm_first: bool = True,
|
61 |
+
add_prenet: bool = False,
|
62 |
+
decoder_cls: Union[
|
63 |
+
nn.TransformerDecoder, nn.TransformerEncoder
|
64 |
+
] = nn.TransformerDecoder,
|
65 |
+
decoder_layer_cls: Union[
|
66 |
+
TransformerDecoderLayer, TransformerEncoderLayer
|
67 |
+
] = TransformerDecoderLayer,
|
68 |
+
prefix_mode: int = 0,
|
69 |
+
share_embedding: bool = True,
|
70 |
+
nar_scale_factor: float = 1.0,
|
71 |
+
prepend_bos: bool = True,
|
72 |
+
num_quantizers: int = 8,
|
73 |
+
):
|
74 |
+
"""
|
75 |
+
Args:
|
76 |
+
d_model:
|
77 |
+
The number of expected features in the input (required).
|
78 |
+
nhead:
|
79 |
+
The number of heads in the multiheadattention models (required).
|
80 |
+
num_layers:
|
81 |
+
The number of sub-decoder-layers in the decoder (required).
|
82 |
+
"""
|
83 |
+
super().__init__()
|
84 |
+
nar_d_model = int(d_model * nar_scale_factor)
|
85 |
+
|
86 |
+
self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
|
87 |
+
self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)
|
88 |
+
|
89 |
+
# ID NUM_AUDIO_TOKENS -> PAD
|
90 |
+
# ID NUM_AUDIO_TOKENS + 1 -> BOS
|
91 |
+
self.ar_audio_prepend_bos = prepend_bos
|
92 |
+
self.ar_audio_embedding = TokenEmbedding(
|
93 |
+
d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
|
94 |
+
)
|
95 |
+
|
96 |
+
# PreNet
|
97 |
+
if add_prenet:
|
98 |
+
self.ar_text_prenet = nn.Sequential(
|
99 |
+
Transpose(),
|
100 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
101 |
+
nn.BatchNorm1d(d_model),
|
102 |
+
nn.ReLU(),
|
103 |
+
nn.Dropout(0.5),
|
104 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
105 |
+
nn.BatchNorm1d(d_model),
|
106 |
+
nn.ReLU(),
|
107 |
+
nn.Dropout(0.5),
|
108 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
109 |
+
nn.BatchNorm1d(d_model),
|
110 |
+
nn.ReLU(),
|
111 |
+
nn.Dropout(0.5),
|
112 |
+
Transpose(),
|
113 |
+
nn.Linear(d_model, d_model),
|
114 |
+
)
|
115 |
+
|
116 |
+
self.ar_audio_prenet = nn.Sequential(
|
117 |
+
nn.Linear(d_model, 256),
|
118 |
+
nn.ReLU(),
|
119 |
+
nn.Dropout(0.25),
|
120 |
+
nn.Linear(256, 256),
|
121 |
+
nn.ReLU(),
|
122 |
+
nn.Dropout(0.25),
|
123 |
+
nn.Linear(256, d_model),
|
124 |
+
)
|
125 |
+
else:
|
126 |
+
self.ar_text_prenet = nn.Identity()
|
127 |
+
self.ar_audio_prenet = nn.Identity()
|
128 |
+
|
129 |
+
self.ar_text_position = SinePositionalEmbedding(
|
130 |
+
d_model,
|
131 |
+
dropout=0.1,
|
132 |
+
scale=False,
|
133 |
+
alpha=True,
|
134 |
+
)
|
135 |
+
self.ar_audio_position = SinePositionalEmbedding(
|
136 |
+
d_model,
|
137 |
+
dropout=0.1,
|
138 |
+
scale=False,
|
139 |
+
alpha=True,
|
140 |
+
)
|
141 |
+
|
142 |
+
self.ar_decoder = decoder_cls(
|
143 |
+
decoder_layer_cls(
|
144 |
+
d_model,
|
145 |
+
nhead,
|
146 |
+
dim_feedforward=d_model * 4,
|
147 |
+
dropout=0.1,
|
148 |
+
batch_first=True,
|
149 |
+
norm_first=norm_first,
|
150 |
+
),
|
151 |
+
num_layers=num_layers,
|
152 |
+
norm=LayerNorm(d_model) if norm_first else None,
|
153 |
+
)
|
154 |
+
self.ar_predict_layer = nn.Linear(
|
155 |
+
d_model, NUM_AUDIO_TOKENS + 1, bias=False
|
156 |
+
)
|
157 |
+
|
158 |
+
self.rng = random.Random(0)
|
159 |
+
self.num_heads = nhead
|
160 |
+
self.prefix_mode = prefix_mode
|
161 |
+
self.num_quantizers = num_quantizers
|
162 |
+
|
163 |
+
assert num_quantizers >= 1
|
164 |
+
if num_quantizers > 1:
|
165 |
+
self.nar_audio_embeddings = nn.ModuleList(
|
166 |
+
[TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
|
167 |
+
+ [
|
168 |
+
TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
|
169 |
+
for i in range(num_quantizers - 1)
|
170 |
+
]
|
171 |
+
) # W_a
|
172 |
+
|
173 |
+
# PreNet
|
174 |
+
if add_prenet:
|
175 |
+
self.nar_text_prenet = nn.Sequential(
|
176 |
+
Transpose(),
|
177 |
+
nn.Conv1d(
|
178 |
+
nar_d_model, nar_d_model, kernel_size=5, padding="same"
|
179 |
+
),
|
180 |
+
nn.BatchNorm1d(nar_d_model),
|
181 |
+
nn.ReLU(),
|
182 |
+
nn.Dropout(0.5),
|
183 |
+
nn.Conv1d(
|
184 |
+
nar_d_model, nar_d_model, kernel_size=5, padding="same"
|
185 |
+
),
|
186 |
+
nn.BatchNorm1d(nar_d_model),
|
187 |
+
nn.ReLU(),
|
188 |
+
nn.Dropout(0.5),
|
189 |
+
nn.Conv1d(
|
190 |
+
nar_d_model, nar_d_model, kernel_size=5, padding="same"
|
191 |
+
),
|
192 |
+
nn.BatchNorm1d(nar_d_model),
|
193 |
+
nn.ReLU(),
|
194 |
+
nn.Dropout(0.5),
|
195 |
+
Transpose(),
|
196 |
+
nn.Linear(nar_d_model, nar_d_model),
|
197 |
+
)
|
198 |
+
self.nar_audio_prenet = nn.Sequential(
|
199 |
+
nn.Linear(nar_d_model, 256),
|
200 |
+
nn.ReLU(),
|
201 |
+
nn.Dropout(0.25),
|
202 |
+
nn.Linear(256, 256),
|
203 |
+
nn.ReLU(),
|
204 |
+
nn.Dropout(0.25),
|
205 |
+
nn.Linear(256, nar_d_model),
|
206 |
+
)
|
207 |
+
else:
|
208 |
+
self.nar_text_prenet = nn.Identity()
|
209 |
+
self.nar_audio_prenet = nn.Identity()
|
210 |
+
|
211 |
+
self.nar_text_position = SinePositionalEmbedding(
|
212 |
+
nar_d_model,
|
213 |
+
dropout=0.0,
|
214 |
+
scale=False,
|
215 |
+
alpha=False,
|
216 |
+
)
|
217 |
+
self.nar_audio_position = SinePositionalEmbedding(
|
218 |
+
nar_d_model,
|
219 |
+
dropout=0.1,
|
220 |
+
scale=False,
|
221 |
+
alpha=False,
|
222 |
+
)
|
223 |
+
|
224 |
+
self.nar_decoder = decoder_cls(
|
225 |
+
decoder_layer_cls(
|
226 |
+
nar_d_model,
|
227 |
+
int(nhead * nar_scale_factor),
|
228 |
+
dim_feedforward=nar_d_model * 4,
|
229 |
+
dropout=0.1,
|
230 |
+
batch_first=True,
|
231 |
+
norm_first=norm_first,
|
232 |
+
adaptive_layer_norm=True,
|
233 |
+
),
|
234 |
+
num_layers=int(num_layers * nar_scale_factor),
|
235 |
+
norm=AdaptiveLayerNorm(
|
236 |
+
nar_d_model, norm=nn.LayerNorm(nar_d_model)
|
237 |
+
)
|
238 |
+
if norm_first
|
239 |
+
else None,
|
240 |
+
)
|
241 |
+
self.nar_predict_layers = nn.ModuleList(
|
242 |
+
[
|
243 |
+
nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
|
244 |
+
for i in range(num_quantizers - 1)
|
245 |
+
]
|
246 |
+
)
|
247 |
+
self.nar_stage_embeddings = nn.ModuleList(
|
248 |
+
[
|
249 |
+
TokenEmbedding(nar_d_model, 1)
|
250 |
+
for i in range(num_quantizers - 1)
|
251 |
+
]
|
252 |
+
)
|
253 |
+
|
254 |
+
if share_embedding:
|
255 |
+
# We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
|
256 |
+
# NOTE(Feiteng): In the experiment, this undermines accuracy
|
257 |
+
# self.ar_predict_layer.weight = self.ar_audio_embedding.weight
|
258 |
+
|
259 |
+
# We also share the parameters of the acoustic embedding layer and the output prediction layer,
|
260 |
+
# which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
|
261 |
+
for j in range(0, num_quantizers - 2):
|
262 |
+
self.nar_predict_layers[
|
263 |
+
j
|
264 |
+
].weight = self.nar_audio_embeddings[j + 2].weight
|
265 |
+
|
266 |
+
def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
|
267 |
+
assert stage > 0
|
268 |
+
if stage == 1:
|
269 |
+
for name, param in self.named_parameters():
|
270 |
+
if name.startswith("ar_"):
|
271 |
+
print(f" AR parameter: {name}")
|
272 |
+
yield param
|
273 |
+
|
274 |
+
if stage == 2:
|
275 |
+
for name, param in self.named_parameters():
|
276 |
+
if name.startswith("nar_"):
|
277 |
+
print(f"NAR parameter: {name}")
|
278 |
+
yield param
|
279 |
+
|
280 |
+
def stage_named_parameters(
|
281 |
+
self, stage: int = 1
|
282 |
+
) -> Iterator[Tuple[str, nn.Parameter]]:
|
283 |
+
assert stage > 0
|
284 |
+
if stage == 1:
|
285 |
+
for pair in self.named_parameters():
|
286 |
+
if pair[0].startswith("ar_"):
|
287 |
+
yield pair
|
288 |
+
|
289 |
+
if stage == 2:
|
290 |
+
for pair in self.named_parameters():
|
291 |
+
if pair[0].startswith("nar_"):
|
292 |
+
yield pair
|
293 |
+
|
294 |
+
def pad_y_eos(self, y, y_mask_int, eos_id):
|
295 |
+
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
|
296 |
+
y_mask_int, (0, 1), value=1
|
297 |
+
)
|
298 |
+
# inputs, targets
|
299 |
+
if self.ar_audio_prepend_bos:
|
300 |
+
return (
|
301 |
+
F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
|
302 |
+
targets,
|
303 |
+
)
|
304 |
+
|
305 |
+
return targets[:, :-1], targets[:, 1:]
|
306 |
+
|
307 |
+
def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode):
|
308 |
+
# 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
|
309 |
+
# from the same utterance.
|
310 |
+
# We implement this differently.
|
311 |
+
if prefix_mode == 0:
|
312 |
+
# no prefix
|
313 |
+
prefix_len = 0
|
314 |
+
y_emb = self.nar_audio_embeddings[0](y)
|
315 |
+
for j in range(1, nar_stage):
|
316 |
+
# Formula (4) (5)
|
317 |
+
y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
|
318 |
+
elif prefix_mode == 1:
|
319 |
+
# prefix at begining
|
320 |
+
int_low = (0.25 * y_lens.min()).type(torch.int64).item()
|
321 |
+
prefix_len = torch.randint(0, int_low * 2, size=()).item()
|
322 |
+
prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames
|
323 |
+
|
324 |
+
y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
|
325 |
+
y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
|
326 |
+
for j in range(1, self.num_quantizers):
|
327 |
+
y_prompts += self.nar_audio_embeddings[j](
|
328 |
+
codes[:, :prefix_len, j]
|
329 |
+
)
|
330 |
+
if j < nar_stage:
|
331 |
+
y_emb += self.nar_audio_embeddings[j](
|
332 |
+
codes[:, prefix_len:, j]
|
333 |
+
)
|
334 |
+
y_emb = torch.concat([y_prompts, y_emb], axis=1)
|
335 |
+
elif prefix_mode in [2, 4]:
|
336 |
+
if prefix_mode == 2:
|
337 |
+
# random prefix
|
338 |
+
prefix_len = min(225, int(0.25 * y_lens.min().item()))
|
339 |
+
|
340 |
+
y_prompts_codes = []
|
341 |
+
for b in range(codes.shape[0]):
|
342 |
+
start = self.rng.randint(0, y_lens[b].item() - prefix_len)
|
343 |
+
y_prompts_codes.append(
|
344 |
+
torch.clone(codes[b, start : start + prefix_len])
|
345 |
+
)
|
346 |
+
codes[
|
347 |
+
b, start : start + prefix_len, nar_stage
|
348 |
+
] = NUM_AUDIO_TOKENS
|
349 |
+
y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
|
350 |
+
else:
|
351 |
+
prefix_len = y_prompts_codes.shape[1]
|
352 |
+
|
353 |
+
y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
|
354 |
+
y_emb = self.nar_audio_embeddings[0](y)
|
355 |
+
for j in range(1, self.num_quantizers):
|
356 |
+
y_prompts += self.nar_audio_embeddings[j](
|
357 |
+
y_prompts_codes[..., j]
|
358 |
+
)
|
359 |
+
if j < nar_stage:
|
360 |
+
y_emb += self.nar_audio_embeddings[j](codes[..., j])
|
361 |
+
y_emb = torch.concat([y_prompts, y_emb], axis=1)
|
362 |
+
else:
|
363 |
+
raise ValueError
|
364 |
+
|
365 |
+
return y_emb, prefix_len
|
366 |
+
|
367 |
+
def forward(
|
368 |
+
self,
|
369 |
+
x: torch.Tensor,
|
370 |
+
x_lens: torch.Tensor,
|
371 |
+
y: Union[torch.Tensor, PromptedFeatures],
|
372 |
+
y_lens: Union[torch.Tensor, PromptedFeatures],
|
373 |
+
reduction: str = "sum",
|
374 |
+
train_stage: int = 0,
|
375 |
+
**kwargs,
|
376 |
+
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
377 |
+
raise NotImplementedError
|
378 |
+
|
379 |
+
def inference(
|
380 |
+
self,
|
381 |
+
x: torch.Tensor,
|
382 |
+
x_lens: torch.Tensor,
|
383 |
+
y: torch.Tensor,
|
384 |
+
enroll_x_lens: Union[torch.Tensor, None] = None,
|
385 |
+
top_k: int = -100,
|
386 |
+
temperature: float = 1.0,
|
387 |
+
) -> torch.Tensor:
|
388 |
+
raise NotImplementedError
|
389 |
+
|
390 |
+
def visualize(
|
391 |
+
self,
|
392 |
+
predicts: Tuple[torch.Tensor],
|
393 |
+
batch: Dict[str, Union[List, torch.Tensor]],
|
394 |
+
output_dir: str,
|
395 |
+
limit: int = 4,
|
396 |
+
) -> None:
|
397 |
+
raise NotImplementedError
|
398 |
+
|
399 |
+
|
400 |
+
class VALLE(VALLF):
|
401 |
+
"""It implements https://arxiv.org/abs/2301.02111
|
402 |
+
"Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
|
403 |
+
"""
|
404 |
+
|
405 |
+
def __init__(
|
406 |
+
self,
|
407 |
+
d_model: int,
|
408 |
+
nhead: int,
|
409 |
+
num_layers: int,
|
410 |
+
norm_first: bool = True,
|
411 |
+
add_prenet: bool = False,
|
412 |
+
prefix_mode: int = 0,
|
413 |
+
share_embedding: bool = True,
|
414 |
+
nar_scale_factor: float = 1.0,
|
415 |
+
**kwargs,
|
416 |
+
):
|
417 |
+
"""
|
418 |
+
Args:
|
419 |
+
d_model:
|
420 |
+
The number of expected features in the input (required).
|
421 |
+
nhead:
|
422 |
+
The number of heads in the multiheadattention models (required).
|
423 |
+
num_layers:
|
424 |
+
The number of sub-decoder-layers in the decoder (required).
|
425 |
+
"""
|
426 |
+
super(VALLE, self).__init__(
|
427 |
+
d_model,
|
428 |
+
nhead,
|
429 |
+
num_layers,
|
430 |
+
norm_first=norm_first,
|
431 |
+
add_prenet=add_prenet,
|
432 |
+
decoder_cls=TransformerEncoder,
|
433 |
+
decoder_layer_cls=TransformerEncoderLayer,
|
434 |
+
prefix_mode=prefix_mode,
|
435 |
+
share_embedding=share_embedding,
|
436 |
+
nar_scale_factor=nar_scale_factor,
|
437 |
+
**kwargs,
|
438 |
+
)
|
439 |
+
self.language_ID = {
|
440 |
+
'en': 0,
|
441 |
+
'zh': 1,
|
442 |
+
'ja': 2,
|
443 |
+
}
|
444 |
+
self.ar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
|
445 |
+
self.nar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
|
446 |
+
|
447 |
+
def forward(
|
448 |
+
self,
|
449 |
+
x: torch.Tensor,
|
450 |
+
x_lens: torch.Tensor,
|
451 |
+
y: Union[torch.Tensor, PromptedFeatures],
|
452 |
+
y_lens: Union[torch.Tensor, PromptedFeatures],
|
453 |
+
reduction: str = "sum",
|
454 |
+
train_stage: int = 0,
|
455 |
+
**kwargs,
|
456 |
+
):
|
457 |
+
raise NotImplementedError
|
458 |
+
def inference(
|
459 |
+
self,
|
460 |
+
x: torch.Tensor,
|
461 |
+
x_lens: torch.Tensor,
|
462 |
+
y: torch.Tensor,
|
463 |
+
enroll_x_lens: torch.Tensor,
|
464 |
+
top_k: int = -100,
|
465 |
+
temperature: float = 1.0,
|
466 |
+
prompt_language: str = None,
|
467 |
+
text_language: str = None,
|
468 |
+
best_of: int = 1,
|
469 |
+
length_penalty: float = 1.0,
|
470 |
+
return_worst: bool = False,
|
471 |
+
) -> torch.Tensor:
|
472 |
+
"""
|
473 |
+
Args:
|
474 |
+
x:
|
475 |
+
A 2-D tensor of shape (1, S).
|
476 |
+
x_lens:
|
477 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
478 |
+
before padding.
|
479 |
+
y:
|
480 |
+
A 3-D tensor of shape (1, T, 8).
|
481 |
+
top_k: (`optional`) int
|
482 |
+
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
483 |
+
temperature: (`optional`) float
|
484 |
+
The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
485 |
+
Returns:
|
486 |
+
Return the predicted audio code matrix.
|
487 |
+
"""
|
488 |
+
assert x.ndim == 2, x.shape
|
489 |
+
assert x_lens.ndim == 1, x_lens.shape
|
490 |
+
assert y.ndim == 3, y.shape
|
491 |
+
assert y.shape[0] == 1, y.shape
|
492 |
+
|
493 |
+
assert torch.all(x_lens > 0)
|
494 |
+
|
495 |
+
# NOTE: x has been padded in TextTokenCollater
|
496 |
+
text = x
|
497 |
+
x = self.ar_text_embedding(text)
|
498 |
+
# Add language embedding
|
499 |
+
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
500 |
+
if isinstance(text_language, str):
|
501 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
|
502 |
+
elif isinstance(text_language, List):
|
503 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
|
504 |
+
x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
|
505 |
+
x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
|
506 |
+
x = self.ar_text_prenet(x)
|
507 |
+
x = self.ar_text_position(x)
|
508 |
+
|
509 |
+
text_len = x_lens.max()
|
510 |
+
prompts = y
|
511 |
+
prefix_len = y.shape[1]
|
512 |
+
|
513 |
+
# AR Decoder
|
514 |
+
# TODO: Managing decoder steps avoid repetitive computation
|
515 |
+
y = prompts[..., 0]
|
516 |
+
if self.ar_audio_prepend_bos:
|
517 |
+
y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)
|
518 |
+
|
519 |
+
x_len = x_lens.max()
|
520 |
+
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
521 |
+
|
522 |
+
kv_cache = None
|
523 |
+
use_kv_caching = True
|
524 |
+
|
525 |
+
sum_logprobs = torch.zeros(best_of, device=y.device) # implement batch decoding here
|
526 |
+
x = x.repeat(best_of, 1, 1)
|
527 |
+
y = y.repeat(best_of, 1)
|
528 |
+
while True:
|
529 |
+
y_emb = self.ar_audio_embedding(y)
|
530 |
+
y_emb = self.ar_audio_prenet(y_emb)
|
531 |
+
y_pos = self.ar_audio_position(y_emb)
|
532 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
533 |
+
|
534 |
+
y_len = y.shape[1]
|
535 |
+
x_attn_mask_pad = F.pad(
|
536 |
+
x_attn_mask,
|
537 |
+
(0, y_len),
|
538 |
+
value=True,
|
539 |
+
)
|
540 |
+
y_attn_mask = F.pad(
|
541 |
+
torch.triu(
|
542 |
+
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
|
543 |
+
),
|
544 |
+
(x_len, 0),
|
545 |
+
value=False,
|
546 |
+
)
|
547 |
+
xy_attn_mask = torch.concat(
|
548 |
+
[x_attn_mask_pad, y_attn_mask], dim=0
|
549 |
+
).to(y.device)
|
550 |
+
|
551 |
+
|
552 |
+
if use_kv_caching and kv_cache is not None:
|
553 |
+
xy_pos = xy_pos[:, [-1]]
|
554 |
+
else:
|
555 |
+
pass
|
556 |
+
|
557 |
+
xy_dec, kv_cache = self.ar_decoder.infer(
|
558 |
+
xy_pos,
|
559 |
+
mask=xy_attn_mask,
|
560 |
+
past_kv=kv_cache,
|
561 |
+
use_cache=use_kv_caching,
|
562 |
+
)
|
563 |
+
# xy_dec, _ = self.ar_decoder(
|
564 |
+
# (xy_pos, None),
|
565 |
+
# mask=xy_attn_mask,
|
566 |
+
# )
|
567 |
+
|
568 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
569 |
+
samples, current_logprobs = topk_sampling(
|
570 |
+
logits, top_k=top_k, top_p=1, temperature=temperature
|
571 |
+
)
|
572 |
+
sum_logprobs += current_logprobs * (y[:, -1] != NUM_AUDIO_TOKENS)
|
573 |
+
samples[y[:, -1] == NUM_AUDIO_TOKENS] = NUM_AUDIO_TOKENS
|
574 |
+
completed = (samples[:, -1] == NUM_AUDIO_TOKENS).all()
|
575 |
+
if (
|
576 |
+
completed
|
577 |
+
or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
|
578 |
+
):
|
579 |
+
if prompts.shape[1] == y.shape[1]:
|
580 |
+
raise SyntaxError(
|
581 |
+
"well trained model shouldn't reach here."
|
582 |
+
)
|
583 |
+
lengths = torch.sum(y != NUM_AUDIO_TOKENS, dim=1)
|
584 |
+
avg_logprobs = sum_logprobs / lengths ** length_penalty
|
585 |
+
# choose the best beam according to sum_logprobs
|
586 |
+
best_beam = y[torch.argmax(avg_logprobs), :]
|
587 |
+
worst_beam = y[torch.argmin(avg_logprobs), :]
|
588 |
+
# strip all eos tokens
|
589 |
+
best_beam = best_beam[best_beam != NUM_AUDIO_TOKENS]
|
590 |
+
worst_beam = worst_beam[worst_beam != NUM_AUDIO_TOKENS]
|
591 |
+
if return_worst:
|
592 |
+
y = worst_beam.unsqueeze(0)
|
593 |
+
else:
|
594 |
+
y = best_beam.unsqueeze(0)
|
595 |
+
print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
|
596 |
+
break
|
597 |
+
|
598 |
+
y = torch.concat([y, samples], dim=1)
|
599 |
+
|
600 |
+
codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
|
601 |
+
if self.num_quantizers == 1:
|
602 |
+
return torch.stack(codes, dim=-1)
|
603 |
+
|
604 |
+
# Non-AR Decoders
|
605 |
+
y_emb = self.nar_audio_embeddings[0](
|
606 |
+
y[:, int(self.ar_audio_prepend_bos) :]
|
607 |
+
)
|
608 |
+
|
609 |
+
if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
|
610 |
+
enrolled_len = enroll_x_lens.max().item()
|
611 |
+
# SOS + Synthesis Text + EOS
|
612 |
+
text = torch.concat(
|
613 |
+
[
|
614 |
+
text[:, :1],
|
615 |
+
text[:, enrolled_len - 1 :],
|
616 |
+
],
|
617 |
+
dim=1,
|
618 |
+
)
|
619 |
+
text_len = text_len - (enrolled_len - 2)
|
620 |
+
assert text.shape[0] == 1
|
621 |
+
|
622 |
+
x = self.nar_text_embedding(text)
|
623 |
+
# Add language embedding
|
624 |
+
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
625 |
+
if isinstance(text_language, str):
|
626 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
|
627 |
+
elif isinstance(text_language, List):
|
628 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
|
629 |
+
x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
|
630 |
+
x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
|
631 |
+
x = self.nar_text_prenet(x)
|
632 |
+
x = self.nar_text_position(x)
|
633 |
+
|
634 |
+
if self.prefix_mode == 0:
|
635 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
636 |
+
zip(
|
637 |
+
self.nar_predict_layers,
|
638 |
+
self.nar_audio_embeddings[1:],
|
639 |
+
)
|
640 |
+
):
|
641 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
642 |
+
y_pos = self.nar_audio_position(y_pos)
|
643 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
644 |
+
|
645 |
+
xy_dec, _ = self.nar_decoder(
|
646 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
647 |
+
)
|
648 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
649 |
+
|
650 |
+
samples = torch.argmax(logits, dim=-1)
|
651 |
+
codes.append(samples)
|
652 |
+
|
653 |
+
if i < self.num_quantizers - 2:
|
654 |
+
y_emb[:, :prefix_len] += embedding_layer(
|
655 |
+
prompts[..., i + 1]
|
656 |
+
)
|
657 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
658 |
+
else:
|
659 |
+
for j in range(1, self.num_quantizers):
|
660 |
+
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
|
661 |
+
prompts[..., j]
|
662 |
+
)
|
663 |
+
|
664 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
665 |
+
zip(
|
666 |
+
self.nar_predict_layers,
|
667 |
+
self.nar_audio_embeddings[1:],
|
668 |
+
)
|
669 |
+
):
|
670 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
671 |
+
y_pos = self.nar_audio_position(y_pos)
|
672 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
673 |
+
|
674 |
+
xy_dec, _ = self.nar_decoder(
|
675 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
676 |
+
)
|
677 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
678 |
+
|
679 |
+
samples = torch.argmax(logits, dim=-1)
|
680 |
+
codes.append(samples)
|
681 |
+
|
682 |
+
if i < self.num_quantizers - 2:
|
683 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
684 |
+
|
685 |
+
assert len(codes) == self.num_quantizers
|
686 |
+
return torch.stack(codes, dim=-1)
|
687 |
+
|
688 |
+
def continual(
|
689 |
+
self,
|
690 |
+
x: torch.Tensor,
|
691 |
+
x_lens: torch.Tensor,
|
692 |
+
y: torch.Tensor,
|
693 |
+
) -> torch.Tensor:
|
694 |
+
"""
|
695 |
+
Args:
|
696 |
+
x:
|
697 |
+
A 2-D tensor of shape (1, S).
|
698 |
+
x_lens:
|
699 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
700 |
+
before padding.
|
701 |
+
y:
|
702 |
+
A 3-D tensor of shape (1, T, 8).
|
703 |
+
Returns:
|
704 |
+
Return the predicted audio code matrix.
|
705 |
+
"""
|
706 |
+
assert x.ndim == 2, x.shape
|
707 |
+
assert x_lens.ndim == 1, x_lens.shape
|
708 |
+
assert y.ndim == 3, y.shape
|
709 |
+
assert y.shape[0] == 1, y.shape
|
710 |
+
|
711 |
+
assert torch.all(x_lens > 0)
|
712 |
+
assert self.num_quantizers == 8
|
713 |
+
|
714 |
+
# NOTE: x has been padded in TextTokenCollater
|
715 |
+
text = x
|
716 |
+
x = self.ar_text_embedding(text)
|
717 |
+
x = self.ar_text_prenet(x)
|
718 |
+
x = self.ar_text_position(x)
|
719 |
+
|
720 |
+
text_len = x_lens.max()
|
721 |
+
|
722 |
+
prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
|
723 |
+
|
724 |
+
# AR Decoder
|
725 |
+
prompts = y[:, :prefix_len]
|
726 |
+
|
727 |
+
codes = [y[:, prefix_len:, 0]]
|
728 |
+
# Non-AR Decoders
|
729 |
+
x = self.nar_text_embedding(text)
|
730 |
+
x = self.nar_text_prenet(x)
|
731 |
+
x = self.nar_text_position(x)
|
732 |
+
|
733 |
+
y_emb = self.nar_audio_embeddings[0](y[..., 0])
|
734 |
+
|
735 |
+
if self.prefix_mode == 0:
|
736 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
737 |
+
zip(
|
738 |
+
self.nar_predict_layers,
|
739 |
+
self.nar_audio_embeddings[1:],
|
740 |
+
)
|
741 |
+
):
|
742 |
+
y_pos = self.nar_audio_position(y_emb)
|
743 |
+
y_pos = self.nar_audio_prenet(y_pos)
|
744 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
745 |
+
|
746 |
+
xy_dec, _ = self.nar_decoder(
|
747 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
748 |
+
)
|
749 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
750 |
+
|
751 |
+
samples = torch.argmax(logits, dim=-1)
|
752 |
+
codes.append(samples)
|
753 |
+
|
754 |
+
if i < 6:
|
755 |
+
y_emb[:, :prefix_len] += embedding_layer(
|
756 |
+
prompts[..., i + 1]
|
757 |
+
)
|
758 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
759 |
+
else:
|
760 |
+
for j in range(1, 8):
|
761 |
+
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
|
762 |
+
prompts[..., j]
|
763 |
+
)
|
764 |
+
|
765 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
766 |
+
zip(
|
767 |
+
self.nar_predict_layers,
|
768 |
+
self.nar_audio_embeddings[1:],
|
769 |
+
)
|
770 |
+
):
|
771 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
772 |
+
y_pos = self.nar_audio_position(y_pos)
|
773 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
774 |
+
|
775 |
+
xy_dec, _ = self.nar_decoder(
|
776 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
777 |
+
)
|
778 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
779 |
+
|
780 |
+
samples = torch.argmax(logits, dim=-1)
|
781 |
+
codes.append(samples)
|
782 |
+
|
783 |
+
if i < 6:
|
784 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
785 |
+
|
786 |
+
assert len(codes) == 8
|
787 |
+
return torch.stack(codes, dim=-1)
|
788 |
+
|
789 |
+
|
790 |
+
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
791 |
+
def top_k_top_p_filtering(
|
792 |
+
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
793 |
+
):
|
794 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
795 |
+
Args:
|
796 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
797 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
798 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
799 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
800 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
801 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
802 |
+
"""
|
803 |
+
if top_k > 0:
|
804 |
+
top_k = min(
|
805 |
+
max(top_k, min_tokens_to_keep), logits.size(-1)
|
806 |
+
) # Safety check
|
807 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
808 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
809 |
+
logits[indices_to_remove] = filter_value
|
810 |
+
|
811 |
+
if top_p < 1.0:
|
812 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
813 |
+
cumulative_probs = torch.cumsum(
|
814 |
+
F.softmax(sorted_logits, dim=-1), dim=-1
|
815 |
+
)
|
816 |
+
|
817 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
818 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
819 |
+
if min_tokens_to_keep > 1:
|
820 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
821 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
822 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
823 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
824 |
+
..., :-1
|
825 |
+
].clone()
|
826 |
+
sorted_indices_to_remove[..., 0] = 0
|
827 |
+
|
828 |
+
# scatter sorted tensors to original indexing
|
829 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
830 |
+
1, sorted_indices, sorted_indices_to_remove
|
831 |
+
)
|
832 |
+
logits[indices_to_remove] = filter_value
|
833 |
+
return logits
|
834 |
+
|
835 |
+
|
836 |
+
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
837 |
+
# temperature: (`optional`) float
|
838 |
+
# The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
839 |
+
# top_k: (`optional`) int
|
840 |
+
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
841 |
+
# top_p: (`optional`) float
|
842 |
+
# The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
843 |
+
|
844 |
+
# Temperature (higher temperature => more likely to sample low probability tokens)
|
845 |
+
if temperature != 1.0:
|
846 |
+
logits = logits / temperature
|
847 |
+
# Top-p/top-k filtering
|
848 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
849 |
+
# Sample
|
850 |
+
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
|
851 |
+
logprobs = F.log_softmax(logits.float(), dim=-1)
|
852 |
+
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
|
853 |
+
return token, current_logprobs
|
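To make the sampling path above concrete, a small hypothetical sketch (not part of the commit) of `topk_sampling` on a single AR decoding step; the logits tensor is a dummy, and `top_k=-100` / `top_p=1.0` reproduce the inference defaults shown above, which disable both filters:

# Sketch: one sampling step over the AR head's vocabulary (NUM_AUDIO_TOKENS codes + EOS).
import torch
from models.macros import NUM_AUDIO_TOKENS
from models.vallex import topk_sampling

logits = torch.randn(1, NUM_AUDIO_TOKENS + 1)
token, logprob = topk_sampling(logits, top_k=-100, top_p=1.0, temperature=1.0)
print(token.shape, logprob.shape)   # torch.Size([1, 1]) torch.Size([1])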
models/visualizer.py
ADDED
@@ -0,0 +1,106 @@
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch


def visualize(
    predicts: Tuple[torch.Tensor],
    batch: Dict[str, Union[List, torch.Tensor]],
    output_dir: str,
    limit: int = 4,
) -> None:
    text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
    text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
    audio_features = batch["audio_features"].to("cpu").detach().numpy()
    audio_features_lens = (
        batch["audio_features_lens"].to("cpu").detach().numpy()
    )
    assert text_tokens.ndim == 2

    utt_ids, texts = batch["utt_id"], batch["text"]

    encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
    decoder_outputs = predicts[1]
    if isinstance(decoder_outputs, list):
        decoder_outputs = decoder_outputs[-1]
    decoder_outputs = (
        decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
    )

    vmin, vmax = 0, 1024  # Encodec
    if decoder_outputs.dtype == np.float32:
        vmin, vmax = -6, 0  # Fbank

    num_figures = 3
    for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
        _ = plt.figure(figsize=(14, 8 * num_figures))

        S = text_tokens_lens[b]
        T = audio_features_lens[b]

        # encoder
        plt.subplot(num_figures, 1, 1)
        plt.title(f"Text: {text}")
        plt.imshow(
            X=np.transpose(encoder_outputs[b]),
            cmap=plt.get_cmap("jet"),
            aspect="auto",
            interpolation="nearest",
        )
        plt.gca().invert_yaxis()
        plt.axvline(x=S - 0.4, linewidth=2, color="r")
        plt.xlabel("Encoder Output")
        plt.colorbar()

        # decoder
        plt.subplot(num_figures, 1, 2)
        plt.imshow(
            X=np.transpose(decoder_outputs[b]),
            cmap=plt.get_cmap("jet"),
            aspect="auto",
            interpolation="nearest",
            vmin=vmin,
            vmax=vmax,
        )
        plt.gca().invert_yaxis()
        plt.axvline(x=T - 0.4, linewidth=2, color="r")
        plt.xlabel("Decoder Output")
        plt.colorbar()

        # target
        plt.subplot(num_figures, 1, 3)
        plt.imshow(
            X=np.transpose(audio_features[b]),
            cmap=plt.get_cmap("jet"),
            aspect="auto",
            interpolation="nearest",
            vmin=vmin,
            vmax=vmax,
        )
        plt.gca().invert_yaxis()
        plt.axvline(x=T - 0.4, linewidth=2, color="r")
        plt.xlabel("Decoder Target")
        plt.colorbar()

        plt.savefig(f"{output_dir}/{utt_id}.png")
        plt.close()
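A hedged sketch of how visualize() might be driven with a fake single-utterance batch; all shapes, keys, and the output directory are assumptions chosen to satisfy the function above, not values taken from the training pipeline.

import torch
from models.visualizer import visualize  # assumed import path

batch = {
    "utt_id": ["utt0"],
    "text": ["hello"],
    "text_tokens": torch.zeros(1, 10, dtype=torch.long),
    "text_tokens_lens": torch.tensor([10]),
    "audio_features": torch.zeros(1, 50, 8),
    "audio_features_lens": torch.tensor([50]),
}
predicts = (torch.zeros(1, 10, 8), torch.zeros(1, 50, 8))  # fake encoder/decoder outputs
visualize(predicts, batch, output_dir=".", limit=1)  # writes ./utt0.png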
modules/__init__.py
ADDED
File without changes
modules/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (175 Bytes). View file
modules/__pycache__/activation.cpython-311.pyc
ADDED
Binary file (27.5 kB). View file
modules/__pycache__/embedding.cpython-311.pyc
ADDED
Binary file (6.15 kB). View file
modules/__pycache__/scaling.cpython-311.pyc
ADDED
Binary file (69 kB). View file
modules/__pycache__/transformer.cpython-311.pyc
ADDED
Binary file (28.2 kB). View file
modules/activation.py
ADDED
@@ -0,0 +1,612 @@
from typing import Optional, Tuple, List
import math

import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter

def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.

    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension

        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
          same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # self-attention
            return F.linear(q, w, b).chunk(3, dim=-1)
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r"""
    Computes scaled dot product attention on query, key and value tensors, using
    an optional attention mask if passed, and applying dropout if a probability
    greater than 0.0 is specified.
    Returns a tensor pair containing attended values and attention weights.

    Args:
        q, k, v: query, key and value tensors. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.

    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
          and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
          and E is embedding dimension.
        - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
          and E is embedding dimension.
        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
          shape :math:`(Nt, Ns)`.

        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
          have shape :math:`(B, Nt, Ns)`
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    if attn_mask is not None:
        attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
    else:
        attn = torch.bmm(q, k.transpose(-2, -1))

    attn = F.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn

def multi_head_attention_forward(
    x,
    ipw,
    ipb,
    opw,
    opb,
    n_head,
    attn_mask,
    past_kv=None,
    use_cache=False,
):
    # x = x.transpose(1, 0)
    # tgt_len, bsz, embed_dim = x.shape
    # head_dim = embed_dim // n_head
    # q, k, v = _in_projection_packed(x, x, x, ipw, ipb)
    # q = q.contiguous().view(tgt_len, bsz * n_head, head_dim).transpose(0, 1)
    # k = k.contiguous().view(k.shape[0], bsz * n_head, head_dim).transpose(0, 1)
    # v = v.contiguous().view(v.shape[0], bsz * n_head, head_dim).transpose(0, 1)

    # new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
    # new_attn_mask.masked_fill_(attn_mask, float("-inf"))
    # attn_mask = new_attn_mask
    #
    # attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, 0.0)
    # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
    # attn_output = torch._C._nn.linear(attn_output, opw, opb)
    # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

    B, T, C = x.size()

    q, k, v = torch._C._nn.linear(x, ipw, ipb).chunk(3, dim=-1)
    k = k.view(B, T, n_head, C // n_head).transpose(1, 2)  # (B, nh, T, hs)
    q = q.view(B, T, n_head, C // n_head).transpose(1, 2)  # (B, nh, T, hs)
    v = v.view(B, T, n_head, C // n_head).transpose(1, 2)  # (B, nh, T, hs)
    if past_kv is not None:
        past_key = past_kv[0]
        past_value = past_kv[1]
        k = torch.cat((past_key, k), dim=-2)
        v = torch.cat((past_value, v), dim=-2)

    FULL_T = k.shape[-2]

    if use_cache is True:
        present = (k, v)
    else:
        present = None

    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf'))
    att = F.softmax(att, dim=-1)
    y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
    y = torch._C._nn.linear(y, opw, opb)
    return (y, present)


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    ``forward()`` will use a special optimized implementation if all of the following
    conditions are met:

    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
      restriction will be loosened in the future.)
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - dropout is 0
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``batch_first`` is ``True`` and the input is batched
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
      nor ``attn_mask`` is passed

    If the optimized implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the input
    that is padding can be expected.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Examples::

        >>> # xdoctest: +SKIP
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    """
    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = (
            self.kdim == embed_dim and self.vdim == embed_dim
        )

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_v = Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

                self.out_proj = linear2_cls(
                    embed_dim, embed_dim, bias=bias, **factory_kwargs
                )

                if self.bias_k is not None:
                    xavier_normal_(self.bias_k)
                if self.bias_v is not None:
                    xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
                :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
                Queries are compared against key-value pairs to produce the output.
                See "Attention Is All You Need" for more details.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
                or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
                :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
                See "Attention Is All You Need" for more details.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
                sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
                See "Attention Is All You Need" for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
                Binary and byte masks are supported.
                For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
                :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
                broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
                Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
                corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
                corresponding position is not allowed to attend. For a float mask, the mask values will be added to
                the attention weight.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
                heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
                effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)

        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
              embedding dimension ``embed_dim``.
            - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
              returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
              :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        is_batched = query.dim() == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif (
            self.in_proj_bias is not None
            and query.dtype != self.in_proj_bias.dtype
        ):
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
            self.in_proj_weight is not None
            and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.dropout:
            why_not_fast_path = f"dropout was {self.dropout}, required zero"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif attn_mask is not None:
            why_not_fast_path = "attn_mask was not None"
        elif query.is_nested and key_padding_mask is not None:
            why_not_fast_path = (
                "key_padding_mask is not supported with NestedTensor input"
            )
        elif self.num_heads % 2 == 1:
            why_not_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                [
                    (x is None or x.is_cuda or "cpu" in str(x.device))
                    for x in tensor_args
                ]
            ):
                why_not_fast_path = (
                    "some Tensor argument is neither CUDA nor CPU"
                )
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                return torch._native_multi_head_attention(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    self.in_proj_weight,
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    key_padding_mask
                    if key_padding_mask is not None
                    else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
                    if key_padding_mask is not None
                    else 0
                    if attn_mask is not None
                    else None,
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
            "MultiheadAttention does not support NestedTensor outside of its fast path. "
            + f"The fast path was not hit because {why_not_fast_path}"
        )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [
                    x.transpose(1, 0) for x in (query, key, value)
                ]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
            )
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
            )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

    def infer(self,
              x: Tensor,
              key_padding_mask: Optional[Tensor] = None,
              need_weights: bool = True,
              attn_mask: Optional[Tensor] = None,
              average_attn_weights: bool = True,
              past_kv = None,
              use_cache = False
              ):
        # x = x.transpose(1, 0)
        y, kv = multi_head_attention_forward(
            x=x,
            ipw=self.in_proj_weight,
            ipb=self.in_proj_bias,
            opw=self.out_proj.weight,
            opb=self.out_proj.bias,
            n_head=self.num_heads,
            attn_mask=attn_mask,
            past_kv=past_kv,
            use_cache=use_cache,
        )
        return (y, kv)
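A minimal sketch of the incremental-decoding path exposed by MultiheadAttention.infer above; the sizes and the causal mask are assumptions for illustration only, not part of this commit.

import torch
from modules.activation import MultiheadAttention  # assumed import path

attn = MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True).eval()
x = torch.randn(2, 5, 16)  # (batch, time, dim)
causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)  # True = masked
y, present = attn.infer(x, attn_mask=causal, use_cache=True)
print(y.shape)           # torch.Size([2, 5, 16])
print(present[0].shape)  # cached keys: torch.Size([2, 4, 5, 4])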
modules/embedding.py
ADDED
@@ -0,0 +1,97 @@
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        X = self.word_embeddings(x)
        X = self.dropout(X)

        return X


class SinePositionalEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.dim_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(
                0, x.size(1), dtype=torch.float32
            ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)
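A small sketch combining the two embedding modules above; the vocabulary size, model dimension, and sequence length are arbitrary illustrative values, not taken from the model configuration.

import torch
from modules.embedding import TokenEmbedding, SinePositionalEmbedding  # assumed import path

tok = TokenEmbedding(dim_model=32, vocab_size=100)
pos = SinePositionalEmbedding(dim_model=32, scale=True, alpha=True)
ids = torch.randint(0, 100, (2, 12))  # (batch, time) token ids
x = pos(tok(ids))  # token embeddings plus learnably-scaled sinusoidal positions
print(x.shape)  # torch.Size([2, 12, 32])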
modules/optim.py
ADDED
@@ -0,0 +1,1105 @@
1 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
2 |
+
#
|
3 |
+
# See ../LICENSE for clarification regarding multiple authors
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import contextlib
|
18 |
+
import logging
|
19 |
+
import random
|
20 |
+
from collections import defaultdict
|
21 |
+
from typing import List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from lhotse.utils import fix_random_seed
|
25 |
+
from torch import Tensor
|
26 |
+
from torch.optim import Optimizer
|
27 |
+
|
28 |
+
|
29 |
+
class BatchedOptimizer(Optimizer):
|
30 |
+
"""
|
31 |
+
This class adds to class Optimizer the capability to optimize parameters in batches:
|
32 |
+
it will stack the parameters and their grads for you so the optimizer can work
|
33 |
+
on tensors with an extra leading dimension. This is intended for speed with GPUs,
|
34 |
+
as it reduces the number of kernels launched in the optimizer.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
params:
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, params, defaults):
|
41 |
+
super(BatchedOptimizer, self).__init__(params, defaults)
|
42 |
+
|
43 |
+
@contextlib.contextmanager
|
44 |
+
def batched_params(self, param_group, group_params_names):
|
45 |
+
"""
|
46 |
+
This function returns (technically, yields) a list of
|
47 |
+
of tuples (p, state), where
|
48 |
+
p is a `fake` parameter that is stacked (over axis 0) from real parameters
|
49 |
+
that share the same shape, and its gradient is also stacked;
|
50 |
+
`state` is the state corresponding to this batch of parameters
|
51 |
+
(it will be physically located in the "state" for one of the real
|
52 |
+
parameters, the last one that has any particular shape and dtype).
|
53 |
+
|
54 |
+
This function is decorated as a context manager so that it can
|
55 |
+
write parameters back to their "real" locations.
|
56 |
+
|
57 |
+
The idea is, instead of doing:
|
58 |
+
<code>
|
59 |
+
for p in group["params"]:
|
60 |
+
state = self.state[p]
|
61 |
+
...
|
62 |
+
</code>
|
63 |
+
you can do:
|
64 |
+
<code>
|
65 |
+
with self.batched_params(group["params"]) as batches:
|
66 |
+
for p, state, p_names in batches:
|
67 |
+
...
|
68 |
+
</code>
|
69 |
+
|
70 |
+
Args:
|
71 |
+
group: a parameter group, which is a list of parameters; should be
|
72 |
+
one of self.param_groups.
|
73 |
+
group_params_names: name for each parameter in group,
|
74 |
+
which is List[str].
|
75 |
+
"""
|
76 |
+
batches = defaultdict(
|
77 |
+
list
|
78 |
+
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
79 |
+
batches_names = defaultdict(
|
80 |
+
list
|
81 |
+
) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
|
82 |
+
|
83 |
+
assert len(param_group) == len(group_params_names)
|
84 |
+
for p, named_p in zip(param_group, group_params_names):
|
85 |
+
key = (str(p.dtype), *p.shape)
|
86 |
+
batches[key].append(p)
|
87 |
+
batches_names[key].append(named_p)
|
88 |
+
|
89 |
+
batches_names_keys = list(batches_names.keys())
|
90 |
+
sorted_idx = sorted(
|
91 |
+
range(len(batches_names)), key=lambda i: batches_names_keys[i]
|
92 |
+
)
|
93 |
+
batches_names = [
|
94 |
+
batches_names[batches_names_keys[idx]] for idx in sorted_idx
|
95 |
+
]
|
96 |
+
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
|
97 |
+
|
98 |
+
stacked_params_dict = dict()
|
99 |
+
|
100 |
+
# turn batches into a list, in deterministic order.
|
101 |
+
# tuples will contain tuples of (stacked_param, state, stacked_params_names),
|
102 |
+
# one for each batch in `batches`.
|
103 |
+
tuples = []
|
104 |
+
|
105 |
+
for batch, batch_names in zip(batches, batches_names):
|
106 |
+
p = batch[0]
|
107 |
+
# we arbitrarily store the state in the
|
108 |
+
# state corresponding to the 1st parameter in the
|
109 |
+
# group. class Optimizer will take care of saving/loading state.
|
110 |
+
state = self.state[p]
|
111 |
+
p_stacked = torch.stack(batch)
|
112 |
+
grad = torch.stack(
|
113 |
+
[
|
114 |
+
torch.zeros_like(p) if p.grad is None else p.grad
|
115 |
+
for p in batch
|
116 |
+
]
|
117 |
+
)
|
118 |
+
p_stacked.grad = grad
|
119 |
+
stacked_params_dict[key] = p_stacked
|
120 |
+
tuples.append((p_stacked, state, batch_names))
|
121 |
+
|
122 |
+
yield tuples # <-- calling code will do the actual optimization here!
|
123 |
+
|
124 |
+
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
|
125 |
+
for i, p in enumerate(batch): # batch is list of Parameter
|
126 |
+
p.copy_(stacked_params[i])
|
127 |
+
|
128 |
+
|
129 |
+
class ScaledAdam(BatchedOptimizer):
|
130 |
+
"""
|
131 |
+
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
132 |
+
proportional to the norm of that parameter; and also learn the scale of the parameter,
|
133 |
+
in log space, subject to upper and lower limits (as if we had factored each parameter as
|
134 |
+
param = underlying_param * log_scale.exp())
|
135 |
+
|
136 |
+
|
137 |
+
Args:
|
138 |
+
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
|
139 |
+
lr: The learning rate. We will typically use a learning rate schedule that starts
|
140 |
+
at 0.03 and decreases over time, i.e. much higher than other common
|
141 |
+
optimizers.
|
142 |
+
clipping_scale: (e.g. 2.0)
|
143 |
+
A scale for gradient-clipping: if specified, the normalized gradients
|
144 |
+
over the whole model will be clipped to have 2-norm equal to
|
145 |
+
`clipping_scale` times the median 2-norm over the most recent period
|
146 |
+
of `clipping_update_period` minibatches. By "normalized gradients",
|
147 |
+
we mean after multiplying by the rms parameter value for this tensor
|
148 |
+
[for non-scalars]; this is appropriate because our update is scaled
|
149 |
+
by this quantity.
|
150 |
+
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
|
151 |
+
Must satisfy 0 < beta <= beta2 < 1.
|
152 |
+
scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
|
153 |
+
scale of each parameter tensor and scalar parameters of the mode..
|
154 |
+
If each parameter were decomposed
|
155 |
+
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
|
156 |
+
would be a the scaling factor on the learning rate of p_scale.
|
157 |
+
eps: A general-purpose epsilon to prevent division by zero
|
158 |
+
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
|
159 |
+
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
160 |
+
parameter tensor to be >= this value)
|
161 |
+
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
|
162 |
+
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
163 |
+
parameter tensor to be <= this value)
|
164 |
+
scalar_max: Maximum absolute value for scalar parameters (applicable if your
|
165 |
+
model has any parameters with numel() == 1).
|
166 |
+
size_update_period: The periodicity, in steps, with which we update the size (scale)
|
167 |
+
of the parameter tensor. This is provided to save a little time
|
168 |
+
in the update.
|
169 |
+
clipping_update_period: if clipping_scale is specified, this is the period
|
170 |
+
"""
|
171 |
+
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
params,
|
175 |
+
lr=3e-02,
|
176 |
+
clipping_scale=None,
|
177 |
+
betas=(0.9, 0.98),
|
178 |
+
scalar_lr_scale=0.1,
|
179 |
+
eps=1.0e-08,
|
180 |
+
param_min_rms=1.0e-05,
|
181 |
+
param_max_rms=3.0,
|
182 |
+
scalar_max=10.0,
|
183 |
+
size_update_period=4,
|
184 |
+
clipping_update_period=100,
|
185 |
+
parameters_names=None,
|
186 |
+
show_dominant_parameters=True,
|
187 |
+
):
|
188 |
+
|
189 |
+
assert parameters_names is not None, (
|
190 |
+
"Please prepare parameters_names,"
|
191 |
+
"which is a List[List[str]]. Each List[str] is for a group"
|
192 |
+
"and each str is for a parameter"
|
193 |
+
)
|
194 |
+
defaults = dict(
|
195 |
+
lr=lr,
|
196 |
+
clipping_scale=clipping_scale,
|
197 |
+
betas=betas,
|
198 |
+
scalar_lr_scale=scalar_lr_scale,
|
199 |
+
eps=eps,
|
200 |
+
param_min_rms=param_min_rms,
|
201 |
+
param_max_rms=param_max_rms,
|
202 |
+
scalar_max=scalar_max,
|
203 |
+
size_update_period=size_update_period,
|
204 |
+
clipping_update_period=clipping_update_period,
|
205 |
+
)
|
206 |
+
|
207 |
+
super(ScaledAdam, self).__init__(params, defaults)
|
208 |
+
assert len(self.param_groups) == len(parameters_names)
|
209 |
+
self.parameters_names = parameters_names
|
210 |
+
self.show_dominant_parameters = show_dominant_parameters
|
211 |
+
|
212 |
+
def __setstate__(self, state):
|
213 |
+
super(ScaledAdam, self).__setstate__(state)
|
214 |
+
|
215 |
+
@torch.no_grad()
|
216 |
+
def step(self, closure=None):
|
217 |
+
"""Performs a single optimization step.
|
218 |
+
|
219 |
+
Arguments:
|
220 |
+
closure (callable, optional): A closure that reevaluates the model
|
221 |
+
and returns the loss.
|
222 |
+
"""
|
223 |
+
loss = None
|
224 |
+
if closure is not None:
|
225 |
+
with torch.enable_grad():
|
226 |
+
loss = closure()
|
227 |
+
|
228 |
+
batch = True
|
229 |
+
|
230 |
+
for group, group_params_names in zip(
|
231 |
+
self.param_groups, self.parameters_names
|
232 |
+
):
|
233 |
+
|
234 |
+
with self.batched_params(
|
235 |
+
group["params"], group_params_names
|
236 |
+
) as batches:
|
237 |
+
|
238 |
+
# batches is list of pairs (stacked_param, state). stacked_param is like
|
239 |
+
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
240 |
+
# a stacking dim, it is not a real dim.
|
241 |
+
|
242 |
+
if (
|
243 |
+
len(batches[0][1]) == 0
|
244 |
+
): # if len(first state) == 0: not yet initialized
|
245 |
+
clipping_scale = 1
|
246 |
+
else:
|
247 |
+
clipping_scale = self._get_clipping_scale(group, batches)
|
248 |
+
|
249 |
+
for p, state, _ in batches:
|
250 |
+
# Perform optimization step.
|
251 |
+
# grad is not going to be None, we handled that when creating the batches.
|
252 |
+
grad = p.grad
|
253 |
+
if grad.is_sparse:
|
254 |
+
raise RuntimeError(
|
255 |
+
"ScaledAdam optimizer does not support sparse gradients"
|
256 |
+
)
|
257 |
+
# State initialization
|
258 |
+
if len(state) == 0:
|
259 |
+
self._init_state(group, p, state)
|
260 |
+
|
261 |
+
self._step_one_batch(group, p, state, clipping_scale)
|
262 |
+
|
263 |
+
return loss
|
264 |
+
|
265 |
+
def _init_state(self, group: dict, p: Tensor, state: dict):
|
266 |
+
"""
|
267 |
+
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
|
268 |
+
is actually the batch dimension, corresponding to batched-together
|
269 |
+
parameters of a given shape.
|
270 |
+
|
271 |
+
|
272 |
+
Args:
|
273 |
+
group: Dict to look up configuration values.
|
274 |
+
p: The parameter that we are initializing the state for
|
275 |
+
state: Dict from string to whatever state we are initializing
|
276 |
+
"""
|
277 |
+
size_update_period = group["size_update_period"]
|
278 |
+
|
279 |
+
state["step"] = 0
|
280 |
+
|
281 |
+
kwargs = {"device": p.device, "dtype": p.dtype}
|
282 |
+
|
283 |
+
# 'delta' implements conventional momentum. There are
|
284 |
+
# several different kinds of update going on, so rather than
|
285 |
+
# compute "exp_avg" like in Adam, we store and decay a
|
286 |
+
# parameter-change "delta", which combines all forms of
|
287 |
+
# update. this is equivalent to how it's done in Adam,
|
288 |
+
# except for the first few steps.
|
289 |
+
state["delta"] = torch.zeros_like(
|
290 |
+
p, memory_format=torch.preserve_format
|
291 |
+
)
|
292 |
+
|
293 |
+
batch_size = p.shape[0]
|
294 |
+
numel = p.numel() // batch_size
|
295 |
+
numel = p.numel()
|
296 |
+
|
297 |
+
if numel > 1:
|
298 |
+
# "param_rms" just periodically records the scalar root-mean-square value of
|
299 |
+
# the parameter tensor.
|
300 |
+
# it has a shape like (batch_size, 1, 1, 1, 1)
|
301 |
+
param_rms = (
|
302 |
+
(p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
303 |
+
)
|
304 |
+
state["param_rms"] = param_rms
|
305 |
+
|
306 |
+
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
307 |
+
state["scale_grads"] = torch.zeros(
|
308 |
+
size_update_period, *param_rms.shape, **kwargs
|
309 |
+
)
|
310 |
+
|
311 |
+
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
312 |
+
state["exp_avg_sq"] = torch.zeros_like(
|
313 |
+
p, memory_format=torch.preserve_format
|
314 |
+
)
|
315 |
+
|
316 |
+
def _get_clipping_scale(
|
317 |
+
self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
|
318 |
+
) -> float:
|
319 |
+
"""
|
320 |
+
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
321 |
+
by this amount before applying the rest of the update.
|
322 |
+
|
323 |
+
Args:
|
324 |
+
group: the parameter group, an item in self.param_groups
|
325 |
+
tuples: a list of tuples of (param, state, param_names)
|
326 |
+
where param is a batched set of parameters,
|
327 |
+
with a .grad (1st dim is batch dim)
|
328 |
+
and state is the state-dict where optimization parameters are kept.
|
329 |
+
param_names is a List[str] while each str is name for a parameter
|
330 |
+
in batched set of parameters "param".
|
331 |
+
"""
|
332 |
+
assert len(tuples) >= 1
|
333 |
+
clipping_scale = group["clipping_scale"]
|
334 |
+
(first_p, first_state, _) = tuples[0]
|
335 |
+
step = first_state["step"]
|
336 |
+
if clipping_scale is None or step == 0:
|
337 |
+
# no clipping. return early on step == 0 because the other
|
338 |
+
# parameters' state won't have been initialized yet.
|
339 |
+
return 1.0
|
340 |
+
clipping_update_period = group["clipping_update_period"]
|
341 |
+
|
342 |
+
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
343 |
+
for (p, state, param_names) in tuples:
|
344 |
+
grad = p.grad
|
345 |
+
if grad.is_sparse:
|
346 |
+
raise RuntimeError(
|
347 |
+
"ScaledAdam optimizer does not support sparse gradients"
|
348 |
+
)
|
349 |
+
if p.numel() == p.shape[0]: # a batch of scalars
|
350 |
+
tot_sumsq += (
|
351 |
+
grad ** 2
|
352 |
+
).sum() # sum() to change shape [1] to []
|
353 |
+
else:
|
354 |
+
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
|
355 |
+
|
356 |
+
tot_norm = tot_sumsq.sqrt()
|
357 |
+
if "model_norms" not in first_state:
|
358 |
+
first_state["model_norms"] = torch.zeros(
|
359 |
+
clipping_update_period, device=p.device
|
360 |
+
)
|
361 |
+
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
362 |
+
|
363 |
+
if step % clipping_update_period == 0:
|
364 |
+
# Print some stats.
|
365 |
+
# We don't reach here if step == 0 because we would have returned
|
366 |
+
# above.
|
367 |
+
sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
|
368 |
+
quartiles = []
|
369 |
+
for n in range(0, 5):
|
370 |
+
index = min(
|
371 |
+
clipping_update_period - 1,
|
372 |
+
(clipping_update_period // 4) * n,
|
373 |
+
)
|
374 |
+
quartiles.append(sorted_norms[index].item())
|
375 |
+
|
376 |
+
median = quartiles[2]
|
377 |
+
threshold = clipping_scale * median
|
378 |
+
first_state["model_norm_threshold"] = threshold
|
379 |
+
percent_clipped = (
|
380 |
+
first_state["num_clipped"] * 100.0 / clipping_update_period
|
381 |
+
if "num_clipped" in first_state
|
382 |
+
else 0.0
|
383 |
+
)
|
384 |
+
first_state["num_clipped"] = 0
|
385 |
+
quartiles = " ".join(["%.3e" % x for x in quartiles])
|
386 |
+
logging.info(
|
387 |
+
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
388 |
+
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
389 |
+
)
|
390 |
+
|
391 |
+
if step < clipping_update_period:
|
392 |
+
return 1.0 # We have not yet estimated a norm to clip to.
|
393 |
+
else:
|
394 |
+
try:
|
395 |
+
model_norm_threshold = first_state["model_norm_threshold"]
|
396 |
+
except KeyError:
|
397 |
+
logging.info(
|
398 |
+
"Warning: model_norm_threshold not in state: possibly "
|
399 |
+
"you changed config when restarting, adding clipping_scale option?"
|
400 |
+
)
|
401 |
+
return 1.0
|
402 |
+
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
403 |
+
if ans < 1.0:
|
404 |
+
first_state["num_clipped"] += 1
|
405 |
+
if ans < 0.1:
|
406 |
+
logging.warn(
|
407 |
+
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
|
408 |
+
)
|
409 |
+
if self.show_dominant_parameters:
|
410 |
+
assert p.shape[0] == len(param_names)
|
411 |
+
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
|
412 |
+
return ans
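Note: the method above clips to a multiple of the median of recently observed gradient norms rather than to a fixed value. A minimal standalone sketch of that rule with made-up numbers (not taken from this repo):

    import torch
    recent_norms = torch.tensor([1.2, 0.9, 1.1, 35.0, 1.0])      # last few total grad norms
    median = recent_norms.sort()[0][recent_norms.numel() // 2]    # robust scale estimate (1.1)
    threshold = 2.0 * median                                      # with clipping_scale = 2.0
    tot_norm = torch.tensor(35.0)                                 # an outlier batch
    ans = min(1.0, (threshold / (tot_norm + 1.0e-20)).item())     # ~0.06: gradients get scaled down

Typical batches (tot_norm close to the median) give ans == 1.0, i.e. no clipping at all.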

    def _show_gradient_dominating_parameter(
        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
    ):
        """
        Show information of the parameter which dominates tot_sumsq.

        Args:
            tuples: a list of tuples of (param, state, param_names)
                where param is a batched set of parameters,
                with a .grad (1st dim is batch dim)
                and state is the state-dict where optimization parameters are kept.
                param_names is a List[str] while each str is name for a parameter
                in batched set of parameters "param".
            tot_sumsq: sumsq of all parameters. Though it could be calculated
                from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
        for (p, state, batch_param_names) in tuples:
            # p is a stacked batch parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
                batch_sumsq_orig = batch_grad ** 2
                # Dummy values used by following `zip` statement.
                batch_rms_orig = torch.ones(p.shape[0])
            else:
                batch_rms_orig = state["param_rms"]
                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
                    dim=list(range(1, batch_grad.ndim))
                )

            for name, sumsq_orig, rms, grad in zip(
                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
            ):

                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

        assert torch.isclose(
            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
            torch.tensor(1.0),
        )
        sorted_by_proportion = {
            k: v
            for k, v in sorted(
                all_sumsq_orig.items(),
                key=lambda item: item[1][0],
                reverse=True,
            )
        }
        dominant_param_name = next(iter(sorted_by_proportion))
        (
            dominant_proportion,
            dominant_sumsq,
            dominant_rms,
            dominant_grad,
        ) = sorted_by_proportion[dominant_param_name]
        logging.info(
            f"Parameter Dominanting tot_sumsq {dominant_param_name}"
            f" with proportion {dominant_proportion:.2f},"
            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
            f"={dominant_sumsq:.3e},"
            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
        )

    def _step_one_batch(
        self, group: dict, p: Tensor, state: dict, clipping_scale: float
    ):
        """
        Do the step for one parameter, which is actually going to be a batch of
        `real` parameters, with dim 0 as the batch dim.
        Args:
            group: dict to look up configuration values
            p: parameter to update (actually multiple parameters stacked together
               as a batch)
            state: state-dict for p, to look up the optimizer state
        """
        lr = group["lr"]
        size_update_period = group["size_update_period"]
        beta1 = group["betas"][0]

        grad = p.grad
        if clipping_scale != 1.0:
            grad = grad * clipping_scale
        step = state["step"]
        delta = state["delta"]

        delta.mul_(beta1)
        batch_size = p.shape[0]
        numel = p.numel() // batch_size
        if numel > 1:
            # Update the size/scale of p, and set param_rms
            scale_grads = state["scale_grads"]
            scale_grads[step % size_update_period] = (p * grad).sum(
                dim=list(range(1, p.ndim)), keepdim=True
            )
            if step % size_update_period == size_update_period - 1:
                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
                param_rms.copy_(
                    (p ** 2)
                    .mean(dim=list(range(1, p.ndim)), keepdim=True)
                    .sqrt()
                )
                if step > 0:
                    # self._size_update() learns the overall scale on the
                    # parameter, by shrinking or expanding it.
                    self._size_update(group, scale_grads, p, state)

        if numel == 1:
            # For parameters with 1 element we just use regular Adam.
            # Updates delta.
            self._step_scalar(group, p, state)
        else:
            self._step(group, p, state)

        state["step"] = step + 1

    def _size_update(
        self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
    ) -> None:
        """
        Called only where p.numel() > 1, this updates the scale of the parameter.
        If we imagine: p = underlying_param * scale.exp(), and we are doing
        gradient descent on underlying param and on scale, this function does the update
        on `scale`.

        Args:
            group: dict to look up configuration values
            scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
                grads w.r.t. the scales.
            p: The parameter to update
            state: The state-dict of p
        """

        param_rms = state["param_rms"]
        beta1, beta2 = group["betas"]
        size_lr = group["lr"] * group["scalar_lr_scale"]
        param_min_rms = group["param_min_rms"]
        param_max_rms = group["param_max_rms"]
        eps = group["eps"]
        step = state["step"]
        batch_size = p.shape[0]

        size_update_period = scale_grads.shape[0]
        # correct beta2 for the size update period: we will have
        # faster decay at this level.
        beta2_corr = beta2 ** size_update_period

        scale_exp_avg_sq = state[
            "scale_exp_avg_sq"
        ]  # shape: (batch_size, 1, 1, ..)
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads ** 2).mean(
                dim=0
            ),  # mean over dim `size_update_period`
            alpha=1 - beta2_corr,
        )  # shape is (batch_size, 1, 1, ...)

        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
        bias_correction2 = 1 - beta2_corr ** size_step
        # we don't bother with bias_correction1; this will help prevent divergence
        # at the start of training.

        denom = scale_exp_avg_sq.sqrt() + eps

        scale_step = (
            -size_lr
            * (bias_correction2 ** 0.5)
            * scale_grads.sum(dim=0)
            / denom
        )

        is_too_small = param_rms < param_min_rms
        is_too_large = param_rms > param_max_rms

        # when the param gets too small, just don't shrink it any further.
        scale_step.masked_fill_(is_too_small, 0.0)
        # when it gets too large, stop it from getting any larger.
        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
        delta = state["delta"]
        # the factor of (1-beta1) relates to momentum.
        delta.add_(p * scale_step, alpha=(1 - beta1))

    def _step(self, group: dict, p: Tensor, state: dict):
        """
        This function does the core update of self.step(), in the case where the members of
        the batch have more than 1 element.

        Args:
            group: A dict which will be used to look up configuration values
            p: The parameter to be updated
            grad: The grad of p
            state: The state-dict corresponding to parameter p

        This function modifies p.
        """
        grad = p.grad
        lr = group["lr"]
        beta1, beta2 = group["betas"]
        eps = group["eps"]
        param_min_rms = group["param_min_rms"]
        step = state["step"]

        exp_avg_sq = state["exp_avg_sq"]
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

        this_step = state["step"] - (
            state["zero_step"] if "zero_step" in state else 0
        )
        bias_correction2 = 1 - beta2 ** (this_step + 1)
        if bias_correction2 < 0.99:
            # note: not in-place.
            exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)

        denom = exp_avg_sq.sqrt()
        denom += eps
        grad = grad / denom

        alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)

        delta = state["delta"]
        delta.add_(grad * alpha)
        p.add_(delta)

    def _step_scalar(self, group: dict, p: Tensor, state: dict):
        """
        A simplified form of the core update for scalar tensors, where we cannot get a good
        estimate of the parameter rms.
        """
        beta1, beta2 = group["betas"]
        scalar_max = group["scalar_max"]
        eps = group["eps"]
        lr = group["lr"] * group["scalar_lr_scale"]
        grad = p.grad

        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # bias_correction2 is like in Adam.  Don't bother with bias_correction1;
        # slower update at the start will help stability anyway.
        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
        denom = (exp_avg_sq / bias_correction2).sqrt() + eps

        delta = state["delta"]
        delta.add_(grad / denom, alpha=-lr * (1 - beta1))
        p.clamp_(min=-scalar_max, max=scalar_max)
        p.add_(delta)
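Note: the distinguishing feature of the `_step()` update above is that the step size is proportional to the parameter's own RMS (`alpha = -lr * (1 - beta1) * param_rms`), so the optimizer makes roughly fixed *relative* changes per step rather than fixed absolute ones. A rough sketch of the per-element step, under simplifying assumptions (steady state, bias corrections ignored, hypothetical numbers):

    import torch
    lr, beta1 = 0.045, 0.9
    param_rms = torch.tensor(0.5)   # RMS of this parameter tensor
    grad = torch.tensor(0.02)
    denom = torch.tensor(0.02)      # ~ sqrt(exp_avg_sq), same order of magnitude as |grad|
    step = -lr * (1 - beta1) * param_rms * grad / denom
    # step ~= -lr * (1 - beta1) * param_rms ~= -0.00225, i.e. a relative change of ~0.45% * (1 - beta1);
    # the (1 - beta1) factor compensates for `delta` being a momentum average that is re-added each step.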


class LRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
    batch and the epoch.
    """

    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(
                "{} is not an Optimizer".format(type(optimizer).__name__)
            )
        self.optimizer = optimizer
        self.verbose = verbose

        for group in optimizer.param_groups:
            group.setdefault("base_lr", group["lr"])

        self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]

        self.epoch = 0
        self.batch = 0

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {
            "base_lrs": self.base_lrs,
            "epoch": self.epoch,
            "batch": self.batch,
        }

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self) -> List[float]:
        """Return last computed learning rate by current scheduler.  Will be a list of float."""
        return self._last_lr

    def get_lr(self):
        # Compute list of learning rates from self.epoch and self.batch and
        # self.base_lrs; this must be overloaded by the user.
        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
        raise NotImplementedError

    def step_batch(self, batch: Optional[int] = None) -> None:
        # Step the batch index, or just set it.  If `batch` is specified, it
        # must be the batch index from the start of training, i.e. summed over
        # all epochs.
        # You can call this in any order; if you don't provide 'batch', it should
        # of course be called once per batch.
        if batch is not None:
            self.batch = batch
        else:
            self.batch = self.batch + 1
        self._set_lrs()

    def step_epoch(self, epoch: Optional[int] = None):
        # Step the epoch index, or just set it.  If you provide the 'epoch' arg,
        # you should call this at the start of the epoch; if you don't provide the 'epoch'
        # arg, you should call it at the end of the epoch.
        if epoch is not None:
            self.epoch = epoch
        else:
            self.epoch = self.epoch + 1
        self._set_lrs()

    def _set_lrs(self):
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group["lr"] = lr
            self.print_lr(self.verbose, i, lr)
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
        """Display the current learning rate."""
        if is_verbose:
            logging.info(
                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                f" of group {group} to {lr:.4e}."
            )


class Eden(LRScheduler):
    """
    Eden scheduler.
    The basic formula (before warmup) is:
      lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
                     (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
    where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
    and then stays constant at 1.

    E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam

    Args:
        optimizer: the optimizer to change the learning rates on
        lr_batches: the number of batches after which we start significantly
            decreasing the learning rate, suggest 5000.
        lr_epochs: the number of epochs after which we start significantly
            decreasing the learning rate, suggest 6 if you plan to do e.g.
            20 to 40 epochs, but may need a smaller number if the dataset is huge
            and you will do few epochs.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        lr_epochs: Union[int, float],
        warmup_batches: Union[int, float] = 500.0,
        verbose: bool = False,
    ):
        super(Eden, self).__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs
        self.warmup_batches = warmup_batches

    def get_lr(self):
        factor = (
            (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
        ) ** -0.25 * (
            ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
            ** -0.25
        )
        warmup_factor = (
            1.0
            if self.batch >= self.warmup_batches
            else 0.5 + 0.5 * (self.batch / self.warmup_batches)
        )

        return [x * factor * warmup_factor for x in self.base_lrs]
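Note: for intuition about the formula above, plugging in sample values (not from this repo):

    # Assuming batch == lr_batches, epoch == lr_epochs, and warmup already finished:
    #   factor = (2.0 ** -0.25) * (2.0 ** -0.25) ~= 0.707
    # so the learning rate has only decayed to ~0.71 * base_lr at that point; much later the
    # decay approaches (batch / lr_batches) ** -0.5 * (epoch / lr_epochs) ** -0.5.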


def _test_eden():
    m = torch.nn.Linear(100, 100)
    optim = ScaledAdam(m.parameters(), lr=0.03)

    scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)

    for epoch in range(10):
        scheduler.step_epoch(epoch)  # sets epoch to `epoch`

        for step in range(20):
            x = torch.randn(200, 100).detach()
            x.requires_grad = True
            y = m(x)
            dy = torch.randn(200, 100).detach()
            f = (y * dy).sum()
            f.backward()

            optim.step()
            scheduler.step_batch()
            optim.zero_grad()

    logging.info(f"last lr = {scheduler.get_last_lr()}")
    logging.info(f"state dict = {scheduler.state_dict()}")


# This is included mostly as a baseline for ScaledAdam.
class Eve(Optimizer):
    """
    Implements Eve algorithm.  This is a modified version of AdamW with a special
    way of setting the weight-decay / shrinkage-factor, which is designed to make the
    rms of the parameters approach a particular target_rms (default: 0.1).  This is
    for use with networks with 'scaled' versions of modules (see scaling.py), which
    will be close to invariant to the absolute scale on the parameter matrix.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
    Eve is unpublished so far.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 3e-4;
            this value means that the weight would decay significantly after
            about 3k minibatches.  It is not multiplied by the learning rate, and
            is conditional on the RMS-value of the parameter being > target_rms.)
        target_rms (float, optional): target root-mean-square value of
            parameters; if they fall below this we will stop applying weight decay.


    .. _Adam: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.98),
        eps=1e-8,
        weight_decay=1e-3,
        target_rms=0.1,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if not 0 <= weight_decay <= 0.1:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay)
            )
        if not 0 < target_rms <= 10.0:
            raise ValueError("Invalid target_rms value: {}".format(target_rms))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            target_rms=target_rms,
        )
        super(Eve, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Eve, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        "AdamW does not support sparse gradients"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                beta1, beta2 = group["betas"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
                    group["eps"]
                )

                step_size = group["lr"] / bias_correction1
                target_rms = group["target_rms"]
                weight_decay = group["weight_decay"]

                if p.numel() > 1:
                    # avoid applying this weight-decay on "scaling factors"
                    # (which are scalar).
                    is_above_target_rms = p.norm() > (
                        target_rms * (p.numel() ** 0.5)
                    )
                    p.mul_(1 - (weight_decay * is_above_target_rms))

                p.addcdiv_(exp_avg, denom, value=-step_size)

                # if random.random() < 0.0005:
                #     step = (exp_avg / denom) * step_size
                #     logging.info(
                #         f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
                #     )

        return loss
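Note: Eve's weight decay is only applied while a tensor's norm exceeds target_rms * sqrt(numel), i.e. while its RMS is above target_rms; below that the decay switches off. A small self-contained check of that condition (illustrative values only, not from this repo):

    import torch
    p = torch.randn(256, 256) * 0.2                          # parameter with RMS ~0.2
    target_rms, weight_decay = 0.1, 3e-4
    is_above = p.norm() > target_rms * (p.numel() ** 0.5)    # True here, since 0.2 > 0.1
    p.mul_(1 - weight_decay * is_above)                      # shrink by 0.03% this step; a no-op once RMS <= 0.1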


def _test_scaled_adam(hidden_dim: int):
    import timeit

    from scaling import ScaledLinear

    E = 100
    B = 4
    T = 2
    logging.info("in test_eve_cain")
    # device = torch.device('cuda')
    device = torch.device("cpu")
    dtype = torch.float32

    fix_random_seed(42)
    # these input_magnitudes and output_magnitudes are to test that
    # Abel is working as we expect and is able to adjust scales of
    # different dims differently.
    input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
    output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()

    for iter in [1, 0]:
        fix_random_seed(42)
        Linear = torch.nn.Linear if iter == 0 else ScaledLinear

        m = torch.nn.Sequential(
            Linear(E, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, E),
        ).to(device)

        train_pairs = [
            (
                100.0
                * torch.randn(B, T, E, device=device, dtype=dtype)
                * input_magnitudes,
                torch.randn(B, T, E, device=device, dtype=dtype)
                * output_magnitudes,
            )
            for _ in range(20)
        ]

        if iter == 0:
            optim = Eve(m.parameters(), lr=0.003)
        elif iter == 1:
            optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
        scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

        start = timeit.default_timer()
        avg_loss = 0.0
        for epoch in range(180):
            scheduler.step_epoch()
            # if epoch == 100 and iter in [2,3]:
            #     optim.reset_speedup()  # check it doesn't crash.

            # if epoch == 130:
            #     opts = diagnostics.TensorDiagnosticOptions(
            #         2 ** 22
            #     )  # allow 4 megabytes per sub-module
            #     diagnostic = diagnostics.attach_diagnostics(m, opts)

            for n, (x, y) in enumerate(train_pairs):
                y_out = m(x)
                loss = ((y_out - y) ** 2).mean() * 100.0
                if epoch == 0 and n == 0:
                    avg_loss = loss.item()
                else:
                    avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
                if n == 0 and epoch % 5 == 0:
                    # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                    # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
                    # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
                    # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
                    # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
                    # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                    # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                    # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                    lr = scheduler.get_last_lr()[0]
                    logging.info(
                        f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
                    )  # , norms={norm1,norm1b,norm2,norm2b}")  # scales={scale1,scale1b,scale2,scale2b}
                loss.log().backward()
                optim.step()
                optim.zero_grad()
                scheduler.step_batch()

        # diagnostic.print_diagnostics()

        stop = timeit.default_timer()
        logging.info(f"Iter={iter}, Time taken: {stop - start}")

        logging.info(f"last lr = {scheduler.get_last_lr()}")
        # logging.info("state dict = ", scheduler.state_dict())
        # logging.info("optim state_dict = ", optim.state_dict())
        logging.info(f"input_magnitudes = {input_magnitudes}")
        logging.info(f"output_magnitudes = {output_magnitudes}")


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    logging.getLogger().setLevel(logging.INFO)
    import subprocess

    s = subprocess.check_output(
        "git status -uno .; git log -1; git diff HEAD .", shell=True
    )
    logging.info(s)
    import sys

    if len(sys.argv) > 1:
        hidden_dim = int(sys.argv[1])
    else:
        hidden_dim = 200

    _test_scaled_adam(hidden_dim)
    _test_eden()
modules/scaling.py
ADDED
@@ -0,0 +1,1401 @@
# Copyright    2022  Xiaomi Corp.        (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
import logging
import random
import math
from functools import reduce
from itertools import repeat
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Embedding as ScaledEmbedding

from utils import Transpose


class ActivationBalancerFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        scale_factor: Tensor,
        sign_factor: Optional[Tensor],
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgt0 = x > 0
        if sign_factor is None:
            ctx.save_for_backward(xgt0, scale_factor)
        else:
            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        if len(ctx.saved_tensors) == 3:
            xgt0, scale_factor, sign_factor = ctx.saved_tensors
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
                sign_factor = sign_factor.unsqueeze(-1)
            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        else:
            xgt0, scale_factor = ctx.saved_tensors
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )


def _compute_scale_factor(
    x: Tensor,
    channel_dim: int,
    min_abs: float,
    max_abs: float,
    gain_factor: float,
    max_factor: float,
) -> Tensor:
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)

    if min_abs == 0.0:
        below_threshold = 0.0
    else:
        # below_threshold is 0 if x_abs_mean > min_abs, and can be at most
        # max_factor when x_abs_mean is well below min_abs.
        below_threshold = (
            (min_abs - x_abs_mean) * (gain_factor / min_abs)
        ).clamp(min=0, max=max_factor)

    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
        min=0, max=max_factor
    )

    return below_threshold - above_threshold


def _compute_sign_factor(
    x: Tensor,
    channel_dim: int,
    min_positive: float,
    max_positive: float,
    gain_factor: float,
    max_factor: float,
) -> Tensor:
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
    if min_positive == 0.0:
        factor1 = 0.0
    else:
        # 0 if proportion_positive >= min_positive, else can be
        # as large as max_factor.
        factor1 = (
            (min_positive - proportion_positive) * (gain_factor / min_positive)
        ).clamp_(min=0, max=max_factor)

    if max_positive == 1.0:
        factor2 = 0.0
    else:
        # 0 if self.proportion_positive <= max_positive, else can be
        # as large as -max_factor.
        factor2 = (
            (proportion_positive - max_positive)
            * (gain_factor / (1.0 - max_positive))
        ).clamp_(min=0, max=max_factor)
    sign_factor = factor1 - factor2
    # require min_positive != 0 or max_positive != 1:
    assert not isinstance(sign_factor, float)
    return sign_factor
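Note: both helpers above return one value per channel, computed from statistics over all the other dims. A toy run showing the sign factor for a channel that is never positive (hypothetical values, not from this repo):

    import torch
    x = torch.randn(4, 8, 16)                      # (batch, time, channels)
    x[..., 0] = -x[..., 0].abs()                   # channel 0: never positive
    prop_pos = (x > 0).float().mean(dim=(0, 1))    # per-channel proportion of positive values
    # with min_positive=0.05, gain_factor=0.01, max_factor=0.04:
    factor1 = ((0.05 - prop_pos) * (0.01 / 0.05)).clamp(min=0, max=0.04)
    # factor1[0] == 0.01 while the other channels are ~0, so only channel 0's
    # backward pass gets nudged toward producing more positive activations.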


class ActivationScaleBalancerFunction(torch.autograd.Function):
    """
    This object is used in class ActivationBalancer when the user specified
    min_positive=0, max_positive=1, so there are no constraints on the signs
    of the activations and only the absolute value has a constraint.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        sign_factor: Tensor,
        scale_factor: Tensor,
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgt0 = x > 0
        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        xgt0, sign_factor, scale_factor = ctx.saved_tensors
        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
            sign_factor = sign_factor.unsqueeze(-1)
            scale_factor = scale_factor.unsqueeze(-1)

        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )


class RandomClampFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        min: Optional[float],
        max: Optional[float],
        prob: float,
        reflect: float,
    ) -> Tensor:
        x_clamped = torch.clamp(x, min=min, max=max)
        mask = torch.rand_like(x) < prob
        ans = torch.where(mask, x_clamped, x)
        if x.requires_grad:
            ctx.save_for_backward(ans == x)
            ctx.reflect = reflect
        if reflect != 0.0:
            ans = ans * (1.0 + reflect) - (x * reflect)
        return ans

    @staticmethod
    def backward(
        ctx, ans_grad: Tensor
    ) -> Tuple[Tensor, None, None, None, None]:
        (is_same,) = ctx.saved_tensors
        x_grad = ans_grad * is_same.to(ans_grad.dtype)
        reflect = ctx.reflect
        if reflect != 0.0:
            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
        return x_grad, None, None, None, None


def random_clamp(
    x: Tensor,
    min: Optional[float] = None,
    max: Optional[float] = None,
    prob: float = 0.5,
    reflect: float = 0.0,
):
    return RandomClampFunction.apply(x, min, max, prob, reflect)
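Note: RandomClampFunction only clamps a random subset of elements (each with probability `prob`), and the backward pass zeroes the gradient only where clamping actually changed the value. A tiny illustration under those assumptions:

    import torch
    x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
    y = random_clamp(x, min=-1.0, max=1.0, prob=1.0)   # prob=1.0 behaves like an ordinary clamp
    y.sum().backward()
    # x.grad == tensor([0., 1., 1., 0.]): clamped elements pass no gradient.
    # With prob=0.5, roughly half of the out-of-range elements would be left untouched.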


def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    """
    A randomized way of casting a floating point value to half precision.
    """
    if x.dtype == torch.float16:
        return x
    x_abs = x.abs()
    is_too_small = x_abs < min_abs
    # for elements where is_too_small is true, random_val will contain +-min_abs with
    # probability (x.abs() / min_abs), and 0.0 otherwise.  [so this preserves expectations,
    # for those elements].
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)


class RandomGradFunction(torch.autograd.Function):
    """
    Does nothing in forward pass; in backward pass, gets rid of very small grads using
    a randomized approach that preserves expectations (intended to reduce roundoff).
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
        ctx.min_abs = min_abs
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
        if ans_grad.dtype == torch.float16:
            return (
                random_cast_to_half(
                    ans_grad.to(torch.float32), min_abs=ctx.min_abs
                ),
                None,
            )
        else:
            return ans_grad, None


class RandomGrad(torch.nn.Module):
    """
    Gets rid of very small gradients using an expectation-preserving method, intended to increase
    accuracy of training when using amp (automatic mixed precision)
    """

    def __init__(self, min_abs: float = 5.0e-06):
        super(RandomGrad, self).__init__()
        self.min_abs = min_abs

    def forward(self, x: Tensor):
        if (
            torch.jit.is_scripting()
            or not self.training
            or torch.jit.is_tracing()
        ):
            return x
        else:
            return RandomGradFunction.apply(x, self.min_abs)


class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.
    """

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None


def softmax(x: Tensor, dim: int):
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x.softmax(dim)

    return SoftmaxFunction.apply(x, dim)


class MaxEigLimiterFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        coeffs: Tensor,
        direction: Tensor,
        channel_dim: int,
        grad_scale: float,
    ) -> Tensor:
        ctx.channel_dim = channel_dim
        ctx.grad_scale = grad_scale
        ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
        return x

    @staticmethod
    def backward(ctx, x_grad, *args):
        with torch.enable_grad():
            (x_orig, coeffs, new_direction) = ctx.saved_tensors
            x_orig.requires_grad = True
            num_channels = x_orig.shape[ctx.channel_dim]
            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
            new_direction.requires_grad = False
            x = x - x.mean(dim=0)
            x_var = (x ** 2).mean()
            x_residual = x - coeffs * new_direction
            x_residual_var = (x_residual ** 2).mean()
            # `variance_proportion` is the proportion of the variance accounted for
            # by the top eigen-direction.  This is to be minimized.
            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
            variance_proportion.backward()
        x_orig_grad = x_orig.grad
        x_extra_grad = (
            x_orig.grad
            * ctx.grad_scale
            * x_grad.norm()
            / (x_orig_grad.norm() + 1.0e-20)
        )
        return x_grad + x_extra_grad.detach(), None, None, None, None


class BasicNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm.  The observation this is based on, is that Transformer-type
    networks, especially with pre-norm, sometimes seem to set one of the
    feature dimensions to a large constant value (e.g. 50), which "defeats"
    the LayerNorm because the output magnitude is then not strongly dependent
    on the other (useful) features.  Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

    So the idea is to introduce this large constant value as an explicit
    parameter, that takes the role of the "eps" in LayerNorm, so the network
    doesn't have to do this trick.  We make the "eps" learnable.

    Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
          interpreted as an offset from the input's ndim if negative.
          This is NOT the num_channels; it should typically be one of
          {-2, -1, 0, 1, 2, 3}.
       eps: the initial "epsilon" that we add as ballast in:
             scale = ((input_vec**2).mean() + epsilon)**-0.5
          Note: our epsilon is actually large, but we keep the name
          to indicate the connection with conventional LayerNorm.
       learn_eps: if true, we learn epsilon; if false, we keep it
          at the initial value.
       eps_min: float
       eps_max: float
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
        eps_min: float = -3.0,
        eps_max: float = 3.0,
    ) -> None:
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        if learn_eps:
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
        else:
            self.register_buffer("eps", torch.tensor(eps).log().detach())
        self.eps_min = eps_min
        self.eps_max = eps_max

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        eps = self.eps
        if self.training and random.random() < 0.25:
            # with probability 0.25, in training mode, clamp eps between the min
            # and max; this will encourage it to learn parameters within the
            # allowed range by making parameters that are outside the allowed
            # range noisy.

            # gradients to allow the parameter to get back into the allowed
            # region if it happens to exit it.
            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
        ) ** -0.5
        return x * scales
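Note: BasicNorm rescales each vector by its root-mean-square (plus the learned eps ballast) and, unlike LayerNorm, neither subtracts the mean nor applies a per-channel weight/bias. A minimal usage sketch, assuming the default channel_dim=-1 and the default eps=0.25:

    import torch
    norm = BasicNorm(num_channels=512)
    x = torch.randn(8, 100, 512) * 3.0          # per-vector mean-square ~9
    y = norm(x)
    # each vector is scaled by ((x ** 2).mean(-1) + 0.25) ** -0.5,
    # so (y ** 2).mean(-1) is roughly 9 / 9.25 ~= 0.97 here, i.e. close to unit RMS.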


def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Behaves like a constructor of a modified version of nn.Linear
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts
        e.g. in_features, out_features, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    ans = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(
                ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
            )
    return ans


def ScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Conv1d:
    """
    Behaves like a constructor of a modified version of nn.Conv1d
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Conv1d accepts
        e.g. in_channels, out_channels, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(
                ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
            )
    return ans


def TransposeScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    Transpose -> ScaledConv1d
    """
    return nn.Sequential(
        Transpose(),
        ScaledConv1d(
            *args,
            initial_scale=initial_scale,
            kernel_size=kernel_size,
            padding=padding,
            **kwargs,
        ),
    )


def ScaledConv1dTranspose(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    ScaledConv1d -> Transpose
    """
    return nn.Sequential(
        ScaledConv1d(
            *args,
            initial_scale=initial_scale,
            kernel_size=kernel_size,
            padding=padding,
            **kwargs,
        ),
        Transpose(),
    )


def TransposeConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Transpose -> Conv1d
    """
    return nn.Sequential(
        Transpose(),
        nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
    )


def Conv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Conv1d -> Transpose
    """
    return nn.Sequential(
        nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
        Transpose(),
    )


class SRLinear(nn.Linear):
    """https://arxiv.org/abs/2303.06296
    Stabilizing Transformer Training by Preventing Attention Entropy Collapse
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        self.register_buffer(
            "u", nn.functional.normalize(torch.randn(in_features), dim=0)
        )
        with torch.no_grad():
            sigma = self.get_sigma()
            self.register_buffer("spectral_norm", sigma)
        self.sigma = nn.Parameter(torch.ones(1))

    def get_sigma(self):
        with torch.no_grad():
            u = self.u
            v = self.weight.mv(u)
            v = nn.functional.normalize(v, dim=0)
            u = self.weight.T.mv(v)
            u = nn.functional.normalize(u, dim=0)
            self.u.data.copy_(u)
        return torch.einsum("c,cd,d->", v, self.weight, u)

    def get_weight(self):
        sigma = self.get_sigma()
        if self.training:
            self.spectral_norm.data.copy_(sigma)
        weight = (self.sigma / sigma) * self.weight
        return weight

    def forward(self, x):
        return nn.functional.linear(x, self.get_weight(), self.bias)
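Note: ScaledLinear simply returns an nn.Linear whose initial weights (and bias range) are multiplied by initial_scale, which is an easy way to start a module with smaller-magnitude outputs, while SRLinear instead reparameterizes the weight by an estimate of its spectral norm (one power-iteration step per call) with a learnable overall sigma. A hedged usage sketch, not taken from this repo's model code:

    import torch
    lin = ScaledLinear(512, 512, initial_scale=0.25)   # plain nn.Linear, weights scaled at init only
    sr = SRLinear(512, 512)                            # effective weight = (sigma / spectral_norm) * weight
    x = torch.randn(10, 512)
    y1, y2 = lin(x), sr(x)
    # lin's outputs start roughly 4x smaller than a default nn.Linear's;
    # sr keeps its effective spectral norm tied to the learnable scalar `sigma` (initialized to 1.0).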
|
585 |
+
|
586 |
+
|
587 |
+
class SRConv1d(SRLinear):
|
588 |
+
def __init__(
|
589 |
+
self,
|
590 |
+
in_features,
|
591 |
+
out_features,
|
592 |
+
kernel_size,
|
593 |
+
stride: int = 1,
|
594 |
+
padding: str = "same",
|
595 |
+
bias: bool = True,
|
596 |
+
**kwargs,
|
597 |
+
):
|
598 |
+
in_features = in_features * kernel_size
|
599 |
+
super().__init__(in_features, out_features, bias=bias, **kwargs)
|
600 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
601 |
+
self.kernel_size = kernel_size
|
602 |
+
self.stride = stride
|
603 |
+
self.padding = padding
|
604 |
+
|
605 |
+
def forward(self, x):
|
606 |
+
in_features = self.in_features // self.kernel_size
|
607 |
+
weight = self.get_weight().view(
|
608 |
+
self.out_features, in_features, self.kernel_size
|
609 |
+
)
|
610 |
+
return nn.functional.conv1d(
|
611 |
+
x, weight, bias=self.bias, stride=self.stride, padding=self.padding
|
612 |
+
)
|
613 |
+
|
614 |
+
|
615 |
+
def TransposeSRConv1d(
|
616 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
617 |
+
) -> nn.Sequential:
|
618 |
+
"""
|
619 |
+
Transpose -> SRConv1d
|
620 |
+
"""
|
621 |
+
return nn.Sequential(
|
622 |
+
Transpose(),
|
623 |
+
SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
624 |
+
)
|
625 |
+
|
626 |
+
|
627 |
+
def SRConv1dTranspose(
|
628 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
629 |
+
) -> nn.Sequential:
|
630 |
+
"""
|
631 |
+
SRConv1d -> Transpose
|
632 |
+
"""
|
633 |
+
return nn.Sequential(
|
634 |
+
SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
635 |
+
Transpose(),
|
636 |
+
)
|
637 |
+
|
638 |
+
|
639 |
+
class ActivationBalancer(torch.nn.Module):
|
640 |
+
"""
|
641 |
+
Modifies the backpropped derivatives of a function to try to encourage, for
|
642 |
+
each channel, that it is positive at least a proportion `threshold` of the
|
643 |
+
time. It does this by multiplying negative derivative values by up to
|
644 |
+
(1+max_factor), and positive derivative values by up to (1-max_factor),
|
645 |
+
interpolated from 1 at the threshold to those extremal values when none
|
646 |
+
of the inputs are positive.
|
647 |
+
|
648 |
+
Args:
|
649 |
+
num_channels: the number of channels
|
650 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
651 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
652 |
+
min_positive: the minimum, per channel, of the proportion of the time
|
653 |
+
that (x > 0), below which we start to modify the derivatives.
|
654 |
+
max_positive: the maximum, per channel, of the proportion of the time
|
655 |
+
that (x > 0), above which we start to modify the derivatives.
|
656 |
+
max_factor: the maximum factor by which we modify the derivatives for
|
657 |
+
either the sign constraint or the magnitude constraint;
|
658 |
+
e.g. with max_factor=0.02, the derivatives would be multiplied by
|
659 |
+
values in the range [0.98..1.02].
|
660 |
+
sign_gain_factor: determines the 'gain' with which we increase the
|
661 |
+
change in gradient once the constraints on min_positive and max_positive
|
662 |
+
are violated.
|
663 |
+
scale_gain_factor: determines the 'gain' with which we increase the
|
664 |
+
change in gradient once the constraints on min_abs and max_abs
|
665 |
+
are violated.
|
666 |
+
min_abs: the minimum average-absolute-value difference from the mean
|
667 |
+
value per channel, which we allow, before we start to modify
|
668 |
+
the derivatives to prevent this.
|
669 |
+
max_abs: the maximum average-absolute-value difference from the mean
|
670 |
+
value per channel, which we allow, before we start to modify
|
671 |
+
the derivatives to prevent this.
|
672 |
+
min_prob: determines the minimum probability with which we modify the
|
673 |
+
gradients for the {min,max}_positive and {min,max}_abs constraints,
|
674 |
+
on each forward(). This is done randomly to prevent all layers
|
675 |
+
from doing it at the same time. Early in training we may use
|
676 |
+
higher probabilities than this; it will decay to this value.
|
677 |
+
"""
|
678 |
+
|
679 |
+
def __init__(
|
680 |
+
self,
|
681 |
+
num_channels: int,
|
682 |
+
channel_dim: int,
|
683 |
+
min_positive: float = 0.05,
|
684 |
+
max_positive: float = 0.95,
|
685 |
+
max_factor: float = 0.04,
|
686 |
+
sign_gain_factor: float = 0.01,
|
687 |
+
scale_gain_factor: float = 0.02,
|
688 |
+
min_abs: float = 0.2,
|
689 |
+
max_abs: float = 100.0,
|
690 |
+
min_prob: float = 0.1,
|
691 |
+
):
|
692 |
+
super(ActivationBalancer, self).__init__()
|
693 |
+
self.num_channels = num_channels
|
694 |
+
self.channel_dim = channel_dim
|
695 |
+
self.min_positive = min_positive
|
696 |
+
self.max_positive = max_positive
|
697 |
+
self.max_factor = max_factor
|
698 |
+
self.min_abs = min_abs
|
699 |
+
self.max_abs = max_abs
|
700 |
+
self.min_prob = min_prob
|
701 |
+
self.sign_gain_factor = sign_gain_factor
|
702 |
+
self.scale_gain_factor = scale_gain_factor
|
703 |
+
|
704 |
+
# count measures how many times the forward() function has been called.
|
705 |
+
# We occasionally sync this to a tensor called `count`, that exists to
|
706 |
+
# make sure it is synced to disk when we load and save the model.
|
707 |
+
self.cpu_count = 0
|
708 |
+
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
|
709 |
+
|
710 |
+
def forward(self, x: Tensor) -> Tensor:
|
711 |
+
if (
|
712 |
+
torch.jit.is_scripting()
|
713 |
+
or not x.requires_grad
|
714 |
+
or torch.jit.is_tracing()
|
715 |
+
):
|
716 |
+
return _no_op(x)
|
717 |
+
|
718 |
+
count = self.cpu_count
|
719 |
+
self.cpu_count += 1
|
720 |
+
|
721 |
+
if random.random() < 0.01:
|
722 |
+
# Occasionally sync self.cpu_count with self.count.
|
723 |
+
# count affects the decay of 'prob'. don't do this on every iter,
|
724 |
+
# because syncing with the GPU is slow.
|
725 |
+
self.cpu_count = max(self.cpu_count, self.count.item())
|
726 |
+
self.count.fill_(self.cpu_count)
|
727 |
+
|
728 |
+
# the prob of doing some work exponentially decreases from 0.5 till it hits
|
729 |
+
# a floor at min_prob (==0.1, by default)
|
730 |
+
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
|
731 |
+
|
732 |
+
if random.random() < prob:
|
733 |
+
sign_gain_factor = 0.5
|
734 |
+
if self.min_positive != 0.0 or self.max_positive != 1.0:
|
735 |
+
sign_factor = _compute_sign_factor(
|
736 |
+
x,
|
737 |
+
self.channel_dim,
|
738 |
+
self.min_positive,
|
739 |
+
self.max_positive,
|
740 |
+
gain_factor=self.sign_gain_factor / prob,
|
741 |
+
max_factor=self.max_factor,
|
742 |
+
)
|
743 |
+
else:
|
744 |
+
sign_factor = None
|
745 |
+
|
746 |
+
scale_factor = _compute_scale_factor(
|
747 |
+
x.detach(),
|
748 |
+
self.channel_dim,
|
749 |
+
min_abs=self.min_abs,
|
750 |
+
max_abs=self.max_abs,
|
751 |
+
gain_factor=self.scale_gain_factor / prob,
|
752 |
+
max_factor=self.max_factor,
|
753 |
+
)
|
754 |
+
return ActivationBalancerFunction.apply(
|
755 |
+
x,
|
756 |
+
scale_factor,
|
757 |
+
sign_factor,
|
758 |
+
self.channel_dim,
|
759 |
+
)
|
760 |
+
else:
|
761 |
+
return _no_op(x)
|
762 |
+
|
763 |
+
|
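# --- Hypothetical usage sketch (not part of this diff) -------------------
# ActivationBalancer is an identity in the forward pass; it only rewrites
# gradients so that each channel stays positive often enough and keeps a
# reasonable mean absolute value. It is typically dropped into a Sequential
# right after the layer whose activations it should regularize.
def _activation_balancer_usage_sketch():
    import torch
    import torch.nn as nn

    block = nn.Sequential(
        nn.Linear(256, 256),
        ActivationBalancer(num_channels=256, channel_dim=-1),
    )
    x = torch.randn(32, 256)
    y = block(x)        # forward values are identical to nn.Linear's output
    y.sum().backward()  # the gradient shaping only happens in backward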
764 |
+
def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
|
765 |
+
"""
|
766 |
+
Returns x unmodified, but in backprop will put a penalty for the excess of
|
767 |
+
the absolute values of elements of x over the limit "limit". E.g. if
|
768 |
+
limit == 10.0, then if x has any values over 10 it will get a penalty.
|
769 |
+
|
770 |
+
Caution: the value of this penalty will be affected by grad scaling used
|
771 |
+
in automatic mixed precision training. For the way we use this,
|
772 |
+
it shouldn't really matter, or may even be helpful; we just use this
|
773 |
+
to disallow really implausible values of scores to be given to softmax.
|
774 |
+
"""
|
775 |
+
x_sign = x.sign()
|
776 |
+
over_limit = (x.abs() - limit) > 0
|
777 |
+
# The following is a memory efficient way to penalize the absolute values of
|
778 |
+
# x that's over the limit. (The memory efficiency comes when you think
|
779 |
+
# about which items torch needs to cache for the autograd, and which ones it
|
780 |
+
# can throw away). The numerical value of aux_loss as computed here will
|
781 |
+
# actually be larger than it should be, by limit * over_limit.sum(), but it
|
782 |
+
# has the same derivative as the real aux_loss which is penalty * (x.abs() -
|
783 |
+
# limit).relu().
|
784 |
+
aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
|
785 |
+
# note: we don't do sum() here on aux_loss, but it's as if we had done
|
786 |
+
# sum() due to how with_loss() works.
|
787 |
+
x = with_loss(x, aux_loss)
|
788 |
+
# you must use x for something, or this will be ineffective.
|
789 |
+
return x
|
790 |
+
|
791 |
+
|
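# --- Hypothetical usage sketch (not part of this diff) -------------------
# penalize_abs_values_gt() leaves the forward value untouched; it only adds
# a gradient term that pushes entries whose magnitude exceeds `limit` back
# toward the limit, which is how implausibly large attention scores get
# discouraged before a softmax.
def _penalize_abs_values_sketch():
    import torch

    scores = (30.0 * torch.randn(4, 16)).requires_grad_(True)
    out = penalize_abs_values_gt(scores, limit=25.0, penalty=1.0e-04)
    assert torch.equal(out, scores)  # forward pass is unchanged
    out.sum().backward()             # gradients include the penalty term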
792 |
+
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
|
793 |
+
if x.ndim == 2:
|
794 |
+
return x.diag()
|
795 |
+
else:
|
796 |
+
(batch, dim, dim) = x.shape
|
797 |
+
x = x.reshape(batch, dim * dim)
|
798 |
+
x = x[:, :: dim + 1]
|
799 |
+
assert x.shape == (batch, dim)
|
800 |
+
return x
|
801 |
+
|
802 |
+
|
803 |
+
def _whitening_metric(x: Tensor, num_groups: int):
|
804 |
+
"""
|
805 |
+
Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues
|
806 |
+
of the centered feature covariance are the same within each group's covariance matrix
|
807 |
+
and also between groups.
|
808 |
+
Args:
|
809 |
+
x: a Tensor of shape (*, num_channels)
|
810 |
+
num_groups: the number of groups of channels, a number >=1 that divides num_channels
|
811 |
+
Returns:
|
812 |
+
Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
|
813 |
+
greater than 1.0 otherwise.
|
814 |
+
"""
|
815 |
+
assert x.dtype != torch.float16
|
816 |
+
x = x.reshape(-1, x.shape[-1])
|
817 |
+
(num_frames, num_channels) = x.shape
|
818 |
+
assert num_channels % num_groups == 0
|
819 |
+
channels_per_group = num_channels // num_groups
|
820 |
+
x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
|
821 |
+
# x now has shape (num_groups, num_frames, channels_per_group)
|
822 |
+
# subtract the mean so we use the centered, not uncentered, covariance.
|
823 |
+
# My experience has been that when we "mess with the gradients" like this,
|
824 |
+
# it's better not do anything that tries to move the mean around, because
|
825 |
+
# that can easily cause instability.
|
826 |
+
x = x - x.mean(dim=1, keepdim=True)
|
827 |
+
# x_covar: (num_groups, channels_per_group, channels_per_group)
|
828 |
+
x_covar = torch.matmul(x.transpose(1, 2), x)
|
829 |
+
x_covar_mean_diag = _diag(x_covar).mean()
|
830 |
+
# the following expression is what we'd get if we took the matrix product
|
831 |
+
# of each covariance and measured the mean of its trace, i.e.
|
832 |
+
# the same as _diag(torch.matmul(x_covar, x_covar)).mean().
|
833 |
+
x_covarsq_mean_diag = (x_covar ** 2).sum() / (
|
834 |
+
num_groups * channels_per_group
|
835 |
+
)
|
836 |
+
# this metric will be >= 1.0; the larger it is, the less 'white' the data was.
|
837 |
+
metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
|
838 |
+
return metric
|
839 |
+
|
840 |
+
|
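# --- Hypothetical numeric illustration (not part of this diff) -----------
# For roughly i.i.d. Gaussian features the centered covariance is close to a
# multiple of the identity, so the metric sits near 1.0; features dominated
# by a single direction score much higher.
def _whitening_metric_sketch():
    import torch

    white = torch.randn(10000, 64)
    skewed = white + 4.0 * torch.randn(10000, 1) * torch.randn(64)
    print(_whitening_metric(white, num_groups=1))   # ~1.0
    print(_whitening_metric(skewed, num_groups=1))  # substantially > 1.0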
841 |
+
class WhiteningPenaltyFunction(torch.autograd.Function):
|
842 |
+
@staticmethod
|
843 |
+
def forward(
|
844 |
+
ctx,
|
845 |
+
x: Tensor,
|
846 |
+
num_groups: int,
|
847 |
+
whitening_limit: float,
|
848 |
+
grad_scale: float,
|
849 |
+
) -> Tensor:
|
850 |
+
ctx.save_for_backward(x)
|
851 |
+
ctx.num_groups = num_groups
|
852 |
+
ctx.whitening_limit = whitening_limit
|
853 |
+
ctx.grad_scale = grad_scale
|
854 |
+
return x
|
855 |
+
|
856 |
+
@staticmethod
|
857 |
+
def backward(ctx, x_grad: Tensor):
|
858 |
+
(x_orig,) = ctx.saved_tensors
|
859 |
+
with torch.enable_grad():
|
860 |
+
with torch.cuda.amp.autocast(enabled=False):
|
861 |
+
x_detached = x_orig.to(torch.float32).detach()
|
862 |
+
x_detached.requires_grad = True
|
863 |
+
|
864 |
+
metric = _whitening_metric(x_detached, ctx.num_groups)
|
865 |
+
|
866 |
+
if random.random() < 0.005 or __name__ == "__main__":
|
867 |
+
logging.info(
|
868 |
+
f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
|
869 |
+
f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
|
870 |
+
)
|
871 |
+
|
872 |
+
(metric - ctx.whitening_limit).relu().backward()
|
873 |
+
penalty_grad = x_detached.grad
|
874 |
+
scale = ctx.grad_scale * (
|
875 |
+
x_grad.to(torch.float32).norm()
|
876 |
+
/ (penalty_grad.norm() + 1.0e-20)
|
877 |
+
)
|
878 |
+
penalty_grad = penalty_grad * scale
|
879 |
+
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
|
880 |
+
|
881 |
+
|
882 |
+
class Whiten(nn.Module):
|
883 |
+
def __init__(
|
884 |
+
self,
|
885 |
+
num_groups: int,
|
886 |
+
whitening_limit: float,
|
887 |
+
prob: Union[float, Tuple[float, float]],
|
888 |
+
grad_scale: float,
|
889 |
+
):
|
890 |
+
"""
|
891 |
+
Args:
|
892 |
+
num_groups: the number of groups to divide the channel dim into before
|
893 |
+
whitening. We will attempt to make the feature covariance
|
894 |
+
within each group, after mean subtraction, as "white" as possible,
|
895 |
+
while having the same trace across all groups.
|
896 |
+
whitening_limit: a value greater than 1.0, that dictates how much
|
897 |
+
freedom we have to violate the constraints. 1.0 would mean perfectly
|
898 |
+
white, with exactly the same trace across groups; larger values
|
899 |
+
give more freedom. E.g. 2.0.
|
900 |
+
prob: the probability with which we apply the gradient modification
|
901 |
+
(also affects the grad scale). May be supplied as a float,
|
902 |
+
or as a pair (min_prob, max_prob)
|
903 |
+
|
904 |
+
grad_scale: determines the scale on the gradient term from this object,
|
905 |
+
relative to the rest of the gradient on the attention weights.
|
906 |
+
E.g. 0.02 (you may want to use smaller values than this if prob is large)
|
907 |
+
"""
|
908 |
+
super(Whiten, self).__init__()
|
909 |
+
assert num_groups >= 1
|
910 |
+
assert whitening_limit >= 1
|
911 |
+
assert grad_scale >= 0
|
912 |
+
self.num_groups = num_groups
|
913 |
+
self.whitening_limit = whitening_limit
|
914 |
+
if isinstance(prob, float):
|
915 |
+
assert 0 < prob <= 1
|
916 |
+
self.prob = prob
|
917 |
+
else:
|
918 |
+
(self.min_prob, self.max_prob) = prob
|
919 |
+
assert 0 < self.min_prob < self.max_prob <= 1
|
920 |
+
self.prob = self.max_prob
|
921 |
+
|
922 |
+
self.grad_scale = grad_scale
|
923 |
+
|
924 |
+
def forward(self, x: Tensor) -> Tensor:
|
925 |
+
"""
|
926 |
+
In the forward pass, this function just returns the input unmodified.
|
927 |
+
In the backward pass, it will modify the gradients to ensure that the
|
928 |
+
distribution in each group has close to (lambda times I) as the covariance
|
929 |
+
after mean subtraction, with the same lambda across groups.
|
930 |
+
For whitening_limit > 1, there will be more freedom to violate this
|
931 |
+
constraint.
|
932 |
+
|
933 |
+
Args:
|
934 |
+
x: the input of shape (*, num_channels)
|
935 |
+
|
936 |
+
Returns:
|
937 |
+
x, unmodified. You should make sure
|
938 |
+
you use the returned value, or the graph will be freed
|
939 |
+
and nothing will happen in backprop.
|
940 |
+
"""
|
941 |
+
if (
|
942 |
+
not x.requires_grad
|
943 |
+
or random.random() > self.prob
|
944 |
+
or self.grad_scale == 0
|
945 |
+
):
|
946 |
+
return _no_op(x)
|
947 |
+
else:
|
948 |
+
if hasattr(self, "min_prob") and random.random() < 0.25:
|
949 |
+
# occasionally switch between min_prob and max_prob, based on whether
|
950 |
+
# we are above or below the threshold.
|
951 |
+
if (
|
952 |
+
_whitening_metric(x.to(torch.float32), self.num_groups)
|
953 |
+
> self.whitening_limit
|
954 |
+
):
|
955 |
+
# there would be a change to the grad.
|
956 |
+
self.prob = self.max_prob
|
957 |
+
else:
|
958 |
+
self.prob = self.min_prob
|
959 |
+
|
960 |
+
return WhiteningPenaltyFunction.apply(
|
961 |
+
x, self.num_groups, self.whitening_limit, self.grad_scale
|
962 |
+
)
|
963 |
+
|
964 |
+
|
965 |
+
class WithLoss(torch.autograd.Function):
|
966 |
+
@staticmethod
|
967 |
+
def forward(ctx, x: Tensor, y: Tensor):
|
968 |
+
ctx.y_shape = y.shape
|
969 |
+
return x
|
970 |
+
|
971 |
+
@staticmethod
|
972 |
+
def backward(ctx, ans_grad: Tensor):
|
973 |
+
return ans_grad, torch.ones(
|
974 |
+
ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
|
975 |
+
)
|
976 |
+
|
977 |
+
|
978 |
+
def with_loss(x, y):
|
979 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
980 |
+
return x
|
981 |
+
# returns x but adds y.sum() to the loss function.
|
982 |
+
return WithLoss.apply(x, y)
|
983 |
+
|
984 |
+
|
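# --- Hypothetical illustration (not part of this diff) -------------------
# with_loss(x, y) returns x unchanged but behaves in backprop as if y.sum()
# had been added to the training loss, so y receives an all-ones gradient.
def _with_loss_sketch():
    import torch

    x = torch.randn(3, requires_grad=True)
    y = torch.randn(3, requires_grad=True)
    with_loss(x, y).sum().backward()
    assert torch.equal(x.grad, torch.ones_like(x))
    assert torch.equal(y.grad, torch.ones_like(y))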
985 |
+
def _no_op(x: Tensor) -> Tensor:
|
986 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
987 |
+
return x
|
988 |
+
else:
|
989 |
+
# a no-op function that will have a node in the autograd graph,
|
990 |
+
# to avoid certain bugs relating to backward hooks
|
991 |
+
return x.chunk(1, dim=-1)[0]
|
992 |
+
|
993 |
+
|
994 |
+
class Identity(torch.nn.Module):
|
995 |
+
def __init__(self):
|
996 |
+
super(Identity, self).__init__()
|
997 |
+
|
998 |
+
def forward(self, x):
|
999 |
+
return _no_op(x)
|
1000 |
+
|
1001 |
+
|
1002 |
+
class MaxEig(torch.nn.Module):
|
1003 |
+
"""
|
1004 |
+
Modifies the backpropped derivatives of a function to try to discourage
|
1005 |
+
that any given direction in activation space accounts for more than
|
1006 |
+
a specified proportion of the covariance (e.g. 0.2).
|
1007 |
+
|
1008 |
+
|
1009 |
+
Args:
|
1010 |
+
num_channels: the number of channels
|
1011 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
1012 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
1013 |
+
max_var_per_eig: the maximum proportion of the variance of the
|
1014 |
+
features/channels, after mean subtraction, that can come from
|
1015 |
+
any given eigenvalue.
|
1016 |
+
min_prob: the minimum probability with which we apply this during any invocation
|
1017 |
+
of forward(), assuming last time we applied the constraint it was
|
1018 |
+
not active; supplied for speed.
|
1019 |
+
scale: determines the scale with which we modify the gradients, relative
|
1020 |
+
to the existing / unmodified gradients
|
1021 |
+
"""
|
1022 |
+
|
1023 |
+
def __init__(
|
1024 |
+
self,
|
1025 |
+
num_channels: int,
|
1026 |
+
channel_dim: int,
|
1027 |
+
max_var_per_eig: float = 0.2,
|
1028 |
+
min_prob: float = 0.01,
|
1029 |
+
scale: float = 0.01,
|
1030 |
+
):
|
1031 |
+
super(MaxEig, self).__init__()
|
1032 |
+
self.num_channels = num_channels
|
1033 |
+
self.channel_dim = channel_dim
|
1034 |
+
self.scale = scale
|
1035 |
+
assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
|
1036 |
+
self.max_var_per_eig = max_var_per_eig
|
1037 |
+
|
1038 |
+
# we figure out the dominant direction using the power method: starting with
|
1039 |
+
# a random vector, keep multiplying by the covariance and renormalizing.
|
1040 |
+
with torch.no_grad():
|
1041 |
+
# arbitrary.. would use randn() but want to leave the rest of the model's
|
1042 |
+
# random parameters unchanged for comparison
|
1043 |
+
direction = torch.arange(num_channels).to(torch.float)
|
1044 |
+
direction = direction / direction.norm()
|
1045 |
+
self.register_buffer("max_eig_direction", direction)
|
1046 |
+
|
1047 |
+
self.min_prob = min_prob
|
1048 |
+
# cur_prob is the current probability we'll use to apply the ActivationBalancer.
|
1049 |
+
# We'll regress this towards prob, each time we try to apply it and it is not
|
1050 |
+
# active.
|
1051 |
+
self.cur_prob = 1.0
|
1052 |
+
|
1053 |
+
def forward(self, x: Tensor) -> Tensor:
|
1054 |
+
if (
|
1055 |
+
torch.jit.is_scripting()
|
1056 |
+
or self.max_var_per_eig <= 0
|
1057 |
+
or random.random() > self.cur_prob
|
1058 |
+
or torch.jit.is_tracing()
|
1059 |
+
):
|
1060 |
+
return _no_op(x)
|
1061 |
+
|
1062 |
+
with torch.cuda.amp.autocast(enabled=False):
|
1063 |
+
eps = 1.0e-20
|
1064 |
+
orig_x = x
|
1065 |
+
x = x.to(torch.float32)
|
1066 |
+
with torch.no_grad():
|
1067 |
+
x = x.transpose(self.channel_dim, -1).reshape(
|
1068 |
+
-1, self.num_channels
|
1069 |
+
)
|
1070 |
+
x = x - x.mean(dim=0)
|
1071 |
+
new_direction, coeffs = self._find_direction_coeffs(
|
1072 |
+
x, self.max_eig_direction
|
1073 |
+
)
|
1074 |
+
x_var = (x ** 2).mean()
|
1075 |
+
x_residual = x - coeffs * new_direction
|
1076 |
+
x_residual_var = (x_residual ** 2).mean()
|
1077 |
+
|
1078 |
+
# `variance_proportion` is the proportion of the variance accounted for
|
1079 |
+
# by the top eigen-direction.
|
1080 |
+
variance_proportion = (x_var - x_residual_var) / (
|
1081 |
+
x_var + 1.0e-20
|
1082 |
+
)
|
1083 |
+
|
1084 |
+
# ensure new direction is nonzero even if x == 0, by including `direction`.
|
1085 |
+
self._set_direction(
|
1086 |
+
0.1 * self.max_eig_direction + new_direction
|
1087 |
+
)
|
1088 |
+
|
1089 |
+
if random.random() < 0.01 or __name__ == "__main__":
|
1090 |
+
logging.info(
|
1091 |
+
f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
|
1092 |
+
)
|
1093 |
+
|
1094 |
+
if variance_proportion >= self.max_var_per_eig:
|
1095 |
+
# The constraint is active. Note, we should quite rarely
|
1096 |
+
# reach here, only near the beginning of training if we are
|
1097 |
+
# starting to diverge, should this constraint be active.
|
1098 |
+
cur_prob = self.cur_prob
|
1099 |
+
self.cur_prob = (
|
1100 |
+
1.0 # next time, do the update with probability 1.0.
|
1101 |
+
)
|
1102 |
+
return MaxEigLimiterFunction.apply(
|
1103 |
+
orig_x, coeffs, new_direction, self.channel_dim, self.scale
|
1104 |
+
)
|
1105 |
+
else:
|
1106 |
+
# let self.cur_prob exponentially approach self.min_prob, as
|
1107 |
+
# long as the constraint is inactive.
|
1108 |
+
self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
|
1109 |
+
return orig_x
|
1110 |
+
|
1111 |
+
def _set_direction(self, direction: Tensor):
|
1112 |
+
"""
|
1113 |
+
Sets self.max_eig_direction to a normalized version of `direction`
|
1114 |
+
"""
|
1115 |
+
direction = direction.detach()
|
1116 |
+
direction = direction / direction.norm()
|
1117 |
+
direction_sum = direction.sum().item()
|
1118 |
+
if direction_sum - direction_sum == 0: # no inf/nan
|
1119 |
+
self.max_eig_direction[:] = direction
|
1120 |
+
else:
|
1121 |
+
logging.info(
|
1122 |
+
f"Warning: sum of direction in MaxEig is {direction_sum}, "
|
1123 |
+
"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
def _find_direction_coeffs(
|
1127 |
+
self, x: Tensor, prev_direction: Tensor
|
1128 |
+
) -> Tuple[Tensor, Tensor]:
|
1129 |
+
"""
|
1130 |
+
Figure out (an approximation to) the proportion of the variance of a set of
|
1131 |
+
feature vectors that can be attributed to the top eigen-direction.
|
1132 |
+
Args:
|
1133 |
+
x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
|
1134 |
+
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
|
1135 |
+
of the top eigen-direction, or a random direction if this is the first
|
1136 |
+
iteration. Does not have to be normalized, but should be nonzero.
|
1137 |
+
|
1138 |
+
Returns: (cur_direction, coeffs), where:
|
1139 |
+
cur_direction: a Tensor of shape (num_channels,) that is the current
|
1140 |
+
estimate of the top eigen-direction.
|
1141 |
+
coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
|
1142 |
+
approximately minimizes, (x - coeffs * cur_direction).norm()
|
1143 |
+
"""
|
1144 |
+
(num_frames, num_channels) = x.shape
|
1145 |
+
assert num_channels > 1 and num_frames > 1
|
1146 |
+
assert prev_direction.shape == (num_channels,)
|
1147 |
+
# `coeffs` are the coefficients of `prev_direction` in x.
|
1148 |
+
# actually represent the coeffs up to a constant positive factor.
|
1149 |
+
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
|
1150 |
+
cur_direction = (x * coeffs).sum(dim=0) / (
|
1151 |
+
(coeffs ** 2).sum() + 1.0e-20
|
1152 |
+
)
|
1153 |
+
return cur_direction, coeffs
|
1154 |
+
|
1155 |
+
|
1156 |
+
class DoubleSwishFunction(torch.autograd.Function):
|
1157 |
+
"""
|
1158 |
+
double_swish(x) = x * torch.sigmoid(x-1)
|
1159 |
+
This is a definition, originally motivated by its close numerical
|
1160 |
+
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
|
1161 |
+
|
1162 |
+
Memory-efficient derivative computation:
|
1163 |
+
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
1164 |
+
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
1165 |
+
Now, s'(x) = s(x) * (1-s(x)).
|
1166 |
+
double_swish'(x) = x * s'(x) + s(x).
|
1167 |
+
= x * s(x) * (1-s(x)) + s(x).
|
1168 |
+
= double_swish(x) * (1-s(x)) + s(x)
|
1169 |
+
... so we just need to remember s(x) but not x itself.
|
1170 |
+
"""
|
1171 |
+
|
1172 |
+
@staticmethod
|
1173 |
+
def forward(ctx, x: Tensor) -> Tensor:
|
1174 |
+
requires_grad = x.requires_grad
|
1175 |
+
x_dtype = x.dtype
|
1176 |
+
if x.dtype == torch.float16:
|
1177 |
+
x = x.to(torch.float32)
|
1178 |
+
|
1179 |
+
s = torch.sigmoid(x - 1.0)
|
1180 |
+
y = x * s
|
1181 |
+
|
1182 |
+
if requires_grad:
|
1183 |
+
deriv = y * (1 - s) + s
|
1184 |
+
# notes on derivative of x * sigmoid(x - 1):
|
1185 |
+
# https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
|
1186 |
+
# min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound.
|
1187 |
+
# max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
|
1188 |
+
# the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
|
1189 |
+
# floors), should be expectation-preserving.
|
1190 |
+
floor = -0.043637
|
1191 |
+
ceil = 1.2
|
1192 |
+
d_scaled = (deriv - floor) * (
|
1193 |
+
255.0 / (ceil - floor)
|
1194 |
+
) + torch.rand_like(deriv)
|
1195 |
+
if __name__ == "__main__":
|
1196 |
+
# for self-testing only.
|
1197 |
+
assert d_scaled.min() >= 0.0
|
1198 |
+
assert d_scaled.max() < 256.0
|
1199 |
+
d_int = d_scaled.to(torch.uint8)
|
1200 |
+
ctx.save_for_backward(d_int)
|
1201 |
+
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
1202 |
+
y = y.to(torch.float16)
|
1203 |
+
return y
|
1204 |
+
|
1205 |
+
@staticmethod
|
1206 |
+
def backward(ctx, y_grad: Tensor) -> Tensor:
|
1207 |
+
(d,) = ctx.saved_tensors
|
1208 |
+
# the same constants as used in forward pass.
|
1209 |
+
floor = -0.043637
|
1210 |
+
ceil = 1.2
|
1211 |
+
d = d * ((ceil - floor) / 255.0) + floor
|
1212 |
+
return y_grad * d
|
1213 |
+
|
1214 |
+
|
1215 |
+
class DoubleSwish(torch.nn.Module):
|
1216 |
+
def forward(self, x: Tensor) -> Tensor:
|
1217 |
+
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
1218 |
+
that we approximate closely with x * sigmoid(x-1).
|
1219 |
+
"""
|
1220 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
1221 |
+
return x * torch.sigmoid(x - 1.0)
|
1222 |
+
return DoubleSwishFunction.apply(x)
|
1223 |
+
|
1224 |
+
|
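# --- Hypothetical numeric check (not part of this diff) ------------------
# The backward pass stores the derivative quantized to uint8, so the gradient
# should match the closed-form y * (1 - s) + s from the docstring above to
# within one quantization step of (1.2 - (-0.043637)) / 255.
def _double_swish_grad_sketch():
    import torch

    x = torch.randn(1000, requires_grad=True)
    y = DoubleSwish()(x)
    y.backward(torch.ones_like(y))
    s = torch.sigmoid(x.detach() - 1.0)
    analytic = x.detach() * s * (1.0 - s) + s
    step = (1.2 - (-0.043637)) / 255.0
    assert (x.grad - analytic).abs().max() <= step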
1225 |
+
def BalancedDoubleSwish(
|
1226 |
+
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
|
1227 |
+
) -> nn.Sequential:
|
1228 |
+
"""
|
1229 |
+
ActivationBalancer -> DoubleSwish
|
1230 |
+
"""
|
1231 |
+
balancer = ActivationBalancer(
|
1232 |
+
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
|
1233 |
+
)
|
1234 |
+
return nn.Sequential(
|
1235 |
+
balancer,
|
1236 |
+
DoubleSwish(),
|
1237 |
+
)
|
1238 |
+
|
1239 |
+
|
1240 |
+
def _test_max_eig():
|
1241 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1242 |
+
logging.info(f"proportion = {proportion}")
|
1243 |
+
x = torch.randn(100, 128)
|
1244 |
+
direction = torch.randn(128)
|
1245 |
+
coeffs = torch.randn(100, 1)
|
1246 |
+
x += proportion * direction * coeffs
|
1247 |
+
|
1248 |
+
x.requires_grad = True
|
1249 |
+
|
1250 |
+
num_channels = 128
|
1251 |
+
m = MaxEig(
|
1252 |
+
num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
|
1253 |
+
) # grad_scale
|
1254 |
+
|
1255 |
+
for _ in range(4):
|
1256 |
+
y = m(x)
|
1257 |
+
|
1258 |
+
y_grad = torch.randn_like(x)
|
1259 |
+
y.backward(gradient=y_grad)
|
1260 |
+
|
1261 |
+
if proportion < 0.2:
|
1262 |
+
assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
|
1263 |
+
elif proportion > 1.0:
|
1264 |
+
assert not torch.allclose(x.grad, y_grad)
|
1265 |
+
|
1266 |
+
|
1267 |
+
def _test_whiten():
|
1268 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1269 |
+
logging.info(f"_test_whiten(): proportion = {proportion}")
|
1270 |
+
x = torch.randn(100, 128)
|
1271 |
+
direction = torch.randn(128)
|
1272 |
+
coeffs = torch.randn(100, 1)
|
1273 |
+
x += proportion * direction * coeffs
|
1274 |
+
|
1275 |
+
x.requires_grad = True
|
1276 |
+
|
1277 |
+
num_channels = 128
|
1278 |
+
m = Whiten(
|
1279 |
+
1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
|
1280 |
+
) # grad_scale
|
1281 |
+
|
1282 |
+
for _ in range(4):
|
1283 |
+
y = m(x)
|
1284 |
+
|
1285 |
+
y_grad = torch.randn_like(x)
|
1286 |
+
y.backward(gradient=y_grad)
|
1287 |
+
|
1288 |
+
if proportion < 0.2:
|
1289 |
+
assert torch.allclose(x.grad, y_grad)
|
1290 |
+
elif proportion > 1.0:
|
1291 |
+
assert not torch.allclose(x.grad, y_grad)
|
1292 |
+
|
1293 |
+
|
1294 |
+
def _test_activation_balancer_sign():
|
1295 |
+
probs = torch.arange(0, 1, 0.01)
|
1296 |
+
N = 1000
|
1297 |
+
x = 1.0 * (
|
1298 |
+
(2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
|
1299 |
+
)
|
1300 |
+
x = x.detach()
|
1301 |
+
x.requires_grad = True
|
1302 |
+
m = ActivationBalancer(
|
1303 |
+
probs.numel(),
|
1304 |
+
channel_dim=0,
|
1305 |
+
min_positive=0.05,
|
1306 |
+
max_positive=0.95,
|
1307 |
+
max_factor=0.2,
|
1308 |
+
min_abs=0.0,
|
1309 |
+
)
|
1310 |
+
|
1311 |
+
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
1312 |
+
|
1313 |
+
y = m(x)
|
1314 |
+
y.backward(gradient=y_grad)
|
1315 |
+
print("_test_activation_balancer_sign: x = ", x)
|
1316 |
+
print("_test_activation_balancer_sign: y grad = ", y_grad)
|
1317 |
+
print("_test_activation_balancer_sign: x grad = ", x.grad)
|
1318 |
+
|
1319 |
+
|
1320 |
+
def _test_activation_balancer_magnitude():
|
1321 |
+
magnitudes = torch.arange(0, 1, 0.01)
|
1322 |
+
N = 1000
|
1323 |
+
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
|
1324 |
+
-1
|
1325 |
+
)
|
1326 |
+
x = x.detach()
|
1327 |
+
x.requires_grad = True
|
1328 |
+
m = ActivationBalancer(
|
1329 |
+
magnitudes.numel(),
|
1330 |
+
channel_dim=0,
|
1331 |
+
min_positive=0.0,
|
1332 |
+
max_positive=1.0,
|
1333 |
+
max_factor=0.2,
|
1334 |
+
min_abs=0.2,
|
1335 |
+
max_abs=0.8,
|
1336 |
+
min_prob=1.0,
|
1337 |
+
)
|
1338 |
+
|
1339 |
+
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
1340 |
+
|
1341 |
+
y = m(x)
|
1342 |
+
y.backward(gradient=y_grad)
|
1343 |
+
print("_test_activation_balancer_magnitude: x = ", x)
|
1344 |
+
print("_test_activation_balancer_magnitude: y grad = ", y_grad)
|
1345 |
+
print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
1346 |
+
|
1347 |
+
|
1348 |
+
def _test_basic_norm():
|
1349 |
+
num_channels = 128
|
1350 |
+
m = BasicNorm(num_channels=num_channels, channel_dim=1)
|
1351 |
+
|
1352 |
+
x = torch.randn(500, num_channels)
|
1353 |
+
|
1354 |
+
y = m(x)
|
1355 |
+
|
1356 |
+
assert y.shape == x.shape
|
1357 |
+
x_rms = (x ** 2).mean().sqrt()
|
1358 |
+
y_rms = (y ** 2).mean().sqrt()
|
1359 |
+
print("x rms = ", x_rms)
|
1360 |
+
print("y rms = ", y_rms)
|
1361 |
+
assert y_rms < x_rms
|
1362 |
+
assert y_rms > 0.5 * x_rms
|
1363 |
+
|
1364 |
+
|
1365 |
+
def _test_double_swish_deriv():
|
1366 |
+
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
1367 |
+
x.requires_grad = True
|
1368 |
+
m = DoubleSwish()
|
1369 |
+
|
1370 |
+
tol = (1.2 - (-0.043637)) / 255.0
|
1371 |
+
torch.autograd.gradcheck(m, x, atol=tol)
|
1372 |
+
|
1373 |
+
# for self-test.
|
1374 |
+
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
|
1375 |
+
x.requires_grad = True
|
1376 |
+
y = m(x)
|
1377 |
+
|
1378 |
+
|
1379 |
+
def _test_softmax():
|
1380 |
+
a = torch.randn(2, 10, dtype=torch.float64)
|
1381 |
+
b = a.clone()
|
1382 |
+
a.requires_grad = True
|
1383 |
+
b.requires_grad = True
|
1384 |
+
a.softmax(dim=1)[:, 0].sum().backward()
|
1385 |
+
print("a grad = ", a.grad)
|
1386 |
+
softmax(b, dim=1)[:, 0].sum().backward()
|
1387 |
+
print("b grad = ", b.grad)
|
1388 |
+
assert torch.allclose(a.grad, b.grad)
|
1389 |
+
|
1390 |
+
|
1391 |
+
if __name__ == "__main__":
|
1392 |
+
logging.getLogger().setLevel(logging.INFO)
|
1393 |
+
torch.set_num_threads(1)
|
1394 |
+
torch.set_num_interop_threads(1)
|
1395 |
+
_test_softmax()
|
1396 |
+
_test_whiten()
|
1397 |
+
_test_max_eig()
|
1398 |
+
_test_activation_balancer_sign()
|
1399 |
+
_test_activation_balancer_magnitude()
|
1400 |
+
_test_basic_norm()
|
1401 |
+
_test_double_swish_deriv()
|
modules/scheduler.py
ADDED
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright    2023    (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List

import torch

from modules.optim import Eden


def calc_lr(step, dim_embed, warmup_steps):
    # Classic "Noam" schedule: linear warmup followed by inverse-sqrt decay.
    return dim_embed ** (-0.5) * min(
        step ** (-0.5), step * warmup_steps ** (-1.5)
    )


class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        base_lr: float,
        optimizer: torch.optim.Optimizer,
        dim_embed: int,
        warmup_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
    ) -> None:

        self.dim_embed = dim_embed
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.num_param_groups = len(optimizer.param_groups)

        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self) -> List[float]:
        lr = self.base_lr * calc_lr(
            self._step_count, self.dim_embed, self.warmup_steps
        )
        return [lr] * self.num_param_groups

    def set_step(self, step: int):
        self._step_count = step


def get_scheduler(params, optimizer):
    if params.scheduler_name.lower() == "eden":
        scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
    elif params.scheduler_name.lower() == "noam":
        scheduler = NoamScheduler(
            params.base_lr,
            optimizer,
            params.decoder_dim,
            warmup_steps=params.warmup_steps,
        )
        # scheduler.set_step(params.start_batch or params.batch_idx_train)
    elif params.scheduler_name.lower() == "cosine":
        # CosineAnnealingLR takes the optimizer first, then T_max.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            params.warmup_steps,
            eta_min=params.base_lr,
        )
    else:
        raise NotImplementedError(f"{params.scheduler_name}")

    return scheduler
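# --- Hypothetical usage sketch (not part of this diff; hyper-parameters are
# illustrative only) ------------------------------------------------------
# NoamScheduler multiplies base_lr by calc_lr(step, dim_embed, warmup_steps),
# i.e. a linear warmup for `warmup_steps` steps followed by 1/sqrt(step) decay.
def _noam_scheduler_sketch():
    import torch

    model = torch.nn.Linear(512, 512)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
    scheduler = NoamScheduler(
        base_lr=0.05, optimizer=optimizer, dim_embed=512, warmup_steps=200
    )
    for _ in range(1000):
        optimizer.step()
        scheduler.step()  # learning rate peaks near step == warmup_steps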
modules/transformer.py
ADDED
@@ -0,0 +1,683 @@
1 |
+
import copy
|
2 |
+
import numbers
|
3 |
+
from functools import partial
|
4 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import Tensor, nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
from .activation import MultiheadAttention
|
11 |
+
from .scaling import ActivationBalancer, BalancedDoubleSwish
|
12 |
+
from .scaling import BasicNorm as _BasicNorm
|
13 |
+
|
14 |
+
_shape_t = Union[int, List[int], torch.Size]
|
15 |
+
|
16 |
+
|
17 |
+
class LayerNorm(nn.Module):
|
18 |
+
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
|
19 |
+
normalized_shape: Tuple[int, ...]
|
20 |
+
eps: float
|
21 |
+
elementwise_affine: bool
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
normalized_shape: _shape_t,
|
26 |
+
eps: float = 1e-5,
|
27 |
+
elementwise_affine: bool = True,
|
28 |
+
device=None,
|
29 |
+
dtype=None,
|
30 |
+
) -> None:
|
31 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
32 |
+
super(LayerNorm, self).__init__()
|
33 |
+
if isinstance(normalized_shape, numbers.Integral):
|
34 |
+
# mypy error: incompatible types in assignment
|
35 |
+
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
36 |
+
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
37 |
+
self.eps = eps
|
38 |
+
self.elementwise_affine = elementwise_affine
|
39 |
+
if self.elementwise_affine:
|
40 |
+
self.weight = nn.Parameter(
|
41 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
42 |
+
)
|
43 |
+
self.bias = nn.Parameter(
|
44 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
45 |
+
)
|
46 |
+
else:
|
47 |
+
self.register_parameter("weight", None)
|
48 |
+
self.register_parameter("bias", None)
|
49 |
+
|
50 |
+
self.reset_parameters()
|
51 |
+
|
52 |
+
def reset_parameters(self) -> None:
|
53 |
+
if self.elementwise_affine:
|
54 |
+
nn.init.ones_(self.weight)
|
55 |
+
nn.init.zeros_(self.bias)
|
56 |
+
|
57 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
58 |
+
if isinstance(input, tuple):
|
59 |
+
input, embedding = input
|
60 |
+
return (
|
61 |
+
F.layer_norm(
|
62 |
+
input,
|
63 |
+
self.normalized_shape,
|
64 |
+
self.weight,
|
65 |
+
self.bias,
|
66 |
+
self.eps,
|
67 |
+
),
|
68 |
+
embedding,
|
69 |
+
)
|
70 |
+
|
71 |
+
assert embedding is None
|
72 |
+
return F.layer_norm(
|
73 |
+
input, self.normalized_shape, self.weight, self.bias, self.eps
|
74 |
+
)
|
75 |
+
|
76 |
+
def extra_repr(self) -> str:
|
77 |
+
return (
|
78 |
+
"{normalized_shape}, eps={eps}, "
|
79 |
+
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
class AdaptiveLayerNorm(nn.Module):
|
84 |
+
r"""Adaptive Layer Normalization"""
|
85 |
+
|
86 |
+
def __init__(self, d_model, norm) -> None:
|
87 |
+
super(AdaptiveLayerNorm, self).__init__()
|
88 |
+
self.project_layer = nn.Linear(d_model, 2 * d_model)
|
89 |
+
self.norm = norm
|
90 |
+
self.d_model = d_model
|
91 |
+
self.eps = self.norm.eps
|
92 |
+
|
93 |
+
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
|
94 |
+
if isinstance(input, tuple):
|
95 |
+
input, embedding = input
|
96 |
+
weight, bias = torch.split(
|
97 |
+
self.project_layer(embedding),
|
98 |
+
split_size_or_sections=self.d_model,
|
99 |
+
dim=-1,
|
100 |
+
)
|
101 |
+
return (weight * self.norm(input) + bias, embedding)
|
102 |
+
|
103 |
+
weight, bias = torch.split(
|
104 |
+
self.project_layer(embedding),
|
105 |
+
split_size_or_sections=self.d_model,
|
106 |
+
dim=-1,
|
107 |
+
)
|
108 |
+
return weight * self.norm(input) + bias
|
109 |
+
|
110 |
+
|
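# --- Hypothetical usage sketch (not part of this diff) -------------------
# AdaptiveLayerNorm projects a conditioning embedding into a per-feature
# scale and shift around the wrapped norm; the (input, embedding) tuple form
# lets that embedding ride along through a stack of layers.
def _adaptive_layer_norm_sketch():
    import torch

    d_model = 16
    ada_ln = AdaptiveLayerNorm(d_model, LayerNorm(d_model))
    x = torch.randn(2, 5, d_model)     # (batch, time, d_model)
    cond = torch.randn(2, 5, d_model)  # conditioning ("stage") embedding
    y, cond_out = ada_ln((x, cond))
    assert y.shape == x.shape and cond_out is cond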
111 |
+
class BasicNorm(_BasicNorm):
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
d_model: int,
|
115 |
+
eps: float = 1e-5,
|
116 |
+
device=None,
|
117 |
+
dtype=None,
|
118 |
+
):
|
119 |
+
super(BasicNorm, self).__init__(d_model, eps=eps)
|
120 |
+
|
121 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
122 |
+
if isinstance(input, tuple):
|
123 |
+
input, embedding = input
|
124 |
+
return (
|
125 |
+
super(BasicNorm, self).forward(input),
|
126 |
+
embedding,
|
127 |
+
)
|
128 |
+
|
129 |
+
assert embedding is None
|
130 |
+
return super(BasicNorm, self).forward(input)
|
131 |
+
|
132 |
+
|
133 |
+
class BalancedBasicNorm(nn.Module):
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
d_model: int,
|
137 |
+
eps: float = 1e-5,
|
138 |
+
device=None,
|
139 |
+
dtype=None,
|
140 |
+
):
|
141 |
+
super(BalancedBasicNorm, self).__init__()
|
142 |
+
self.balancer = ActivationBalancer(
|
143 |
+
d_model,
|
144 |
+
channel_dim=-1,
|
145 |
+
min_positive=0.45,
|
146 |
+
max_positive=0.55,
|
147 |
+
max_abs=6.0,
|
148 |
+
)
|
149 |
+
self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
|
150 |
+
|
151 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
152 |
+
if isinstance(input, tuple):
|
153 |
+
input, embedding = input
|
154 |
+
return self.norm((self.balancer(input), embedding))
|
155 |
+
|
156 |
+
assert embedding is None
|
157 |
+
return self.norm(self.balancer(input))
|
158 |
+
|
159 |
+
|
160 |
+
class IdentityNorm(nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
d_model: int,
|
164 |
+
eps: float = 1e-5,
|
165 |
+
device=None,
|
166 |
+
dtype=None,
|
167 |
+
) -> None:
|
168 |
+
super(IdentityNorm, self).__init__()
|
169 |
+
|
170 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
171 |
+
if isinstance(input, tuple):
|
172 |
+
return input
|
173 |
+
|
174 |
+
assert embedding is None
|
175 |
+
return input
|
176 |
+
|
177 |
+
|
178 |
+
class TransformerEncoderLayer(nn.Module):
|
179 |
+
__constants__ = ["batch_first", "norm_first"]
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
d_model: int,
|
184 |
+
nhead: int,
|
185 |
+
dim_feedforward: int = 2048,
|
186 |
+
dropout: float = 0.1,
|
187 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
188 |
+
batch_first: bool = False,
|
189 |
+
norm_first: bool = False,
|
190 |
+
device=None,
|
191 |
+
dtype=None,
|
192 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
193 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
194 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
195 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
196 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
197 |
+
layer_norm_eps: float = 1e-5,
|
198 |
+
adaptive_layer_norm=False,
|
199 |
+
) -> None:
|
200 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
201 |
+
super(TransformerEncoderLayer, self).__init__()
|
202 |
+
self.self_attn = MultiheadAttention(
|
203 |
+
d_model,
|
204 |
+
nhead,
|
205 |
+
dropout=dropout,
|
206 |
+
batch_first=batch_first,
|
207 |
+
linear1_cls=linear1_self_attention_cls,
|
208 |
+
linear2_cls=linear2_self_attention_cls,
|
209 |
+
**factory_kwargs,
|
210 |
+
)
|
211 |
+
|
212 |
+
# Implementation of Feedforward model
|
213 |
+
self.linear1 = linear1_feedforward_cls(
|
214 |
+
d_model, dim_feedforward, **factory_kwargs
|
215 |
+
)
|
216 |
+
self.dropout = nn.Dropout(dropout)
|
217 |
+
self.linear2 = linear2_feedforward_cls(
|
218 |
+
dim_feedforward, d_model, **factory_kwargs
|
219 |
+
)
|
220 |
+
|
221 |
+
self.norm_first = norm_first
|
222 |
+
self.dropout1 = nn.Dropout(dropout)
|
223 |
+
self.dropout2 = nn.Dropout(dropout)
|
224 |
+
|
225 |
+
# Legacy string support for activation function.
|
226 |
+
if isinstance(activation, str):
|
227 |
+
activation = _get_activation_fn(activation)
|
228 |
+
elif isinstance(activation, partial):
|
229 |
+
activation = activation(d_model)
|
230 |
+
elif activation == BalancedDoubleSwish:
|
231 |
+
activation = BalancedDoubleSwish(d_model)
|
232 |
+
|
233 |
+
# # We can't test self.activation in forward() in TorchScript,
|
234 |
+
# # so stash some information about it instead.
|
235 |
+
# if activation is F.relu or isinstance(activation, torch.nn.ReLU):
|
236 |
+
# self.activation_relu_or_gelu = 1
|
237 |
+
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
|
238 |
+
# self.activation_relu_or_gelu = 2
|
239 |
+
# else:
|
240 |
+
# self.activation_relu_or_gelu = 0
|
241 |
+
self.activation = activation
|
242 |
+
|
243 |
+
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
244 |
+
if layer_norm_cls == IdentityNorm:
|
245 |
+
norm2 = BalancedBasicNorm(
|
246 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
247 |
+
)
|
248 |
+
else:
|
249 |
+
norm2 = layer_norm_cls(
|
250 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
251 |
+
)
|
252 |
+
|
253 |
+
if adaptive_layer_norm:
|
254 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
255 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
256 |
+
else:
|
257 |
+
self.norm1 = norm1
|
258 |
+
self.norm2 = norm2
|
259 |
+
|
260 |
+
def __setstate__(self, state):
|
261 |
+
super(TransformerEncoderLayer, self).__setstate__(state)
|
262 |
+
if not hasattr(self, "activation"):
|
263 |
+
self.activation = F.relu
|
264 |
+
|
265 |
+
def forward(
|
266 |
+
self,
|
267 |
+
src: Tensor,
|
268 |
+
src_mask: Optional[Tensor] = None,
|
269 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
270 |
+
) -> Tensor:
|
271 |
+
r"""Pass the input through the encoder layer.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
src: the sequence to the encoder layer (required).
|
275 |
+
src_mask: the mask for the src sequence (optional).
|
276 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
277 |
+
|
278 |
+
Shape:
|
279 |
+
see the docs in Transformer class.
|
280 |
+
"""
|
281 |
+
x, stage_embedding = src, None
|
282 |
+
is_src_tuple = False
|
283 |
+
if isinstance(src, tuple):
|
284 |
+
x, stage_embedding = src
|
285 |
+
is_src_tuple = True
|
286 |
+
|
287 |
+
if src_key_padding_mask is not None:
|
288 |
+
_skpm_dtype = src_key_padding_mask.dtype
|
289 |
+
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
290 |
+
src_key_padding_mask
|
291 |
+
):
|
292 |
+
raise AssertionError(
|
293 |
+
"only bool and floating types of key_padding_mask are supported"
|
294 |
+
)
|
295 |
+
|
296 |
+
if self.norm_first:
|
297 |
+
x = x + self._sa_block(
|
298 |
+
self.norm1(x, stage_embedding),
|
299 |
+
src_mask,
|
300 |
+
src_key_padding_mask,
|
301 |
+
)
|
302 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
303 |
+
else:
|
304 |
+
x = self.norm1(
|
305 |
+
x + self._sa_block(x, src_mask, src_key_padding_mask),
|
306 |
+
stage_embedding,
|
307 |
+
)
|
308 |
+
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
309 |
+
|
310 |
+
if is_src_tuple:
|
311 |
+
return (x, stage_embedding)
|
312 |
+
return x
|
313 |
+
|
314 |
+
def infer(
|
315 |
+
self,
|
316 |
+
src: Tensor,
|
317 |
+
src_mask: Optional[Tensor] = None,
|
318 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
319 |
+
past_kv: Optional[Tensor] = None,
|
320 |
+
use_cache: bool = False,
|
321 |
+
):
|
322 |
+
x, stage_embedding = src, None
|
323 |
+
is_src_tuple = False
|
324 |
+
if isinstance(src, tuple):
|
325 |
+
x, stage_embedding = src
|
326 |
+
is_src_tuple = True
|
327 |
+
|
328 |
+
if src_key_padding_mask is not None:
|
329 |
+
_skpm_dtype = src_key_padding_mask.dtype
|
330 |
+
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
331 |
+
src_key_padding_mask
|
332 |
+
):
|
333 |
+
raise AssertionError(
|
334 |
+
"only bool and floating types of key_padding_mask are supported"
|
335 |
+
)
|
336 |
+
|
337 |
+
if self.norm_first:
|
338 |
+
x_attn_out, kv = self.self_attn.infer(
|
339 |
+
self.norm1(x, stage_embedding),
|
340 |
+
attn_mask=src_mask,
|
341 |
+
key_padding_mask=src_key_padding_mask,
|
342 |
+
need_weights=False,
|
343 |
+
past_kv=past_kv,
|
344 |
+
use_cache=use_cache,
|
345 |
+
)
|
346 |
+
x = x + x_attn_out
|
347 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
348 |
+
|
349 |
+
if is_src_tuple:
|
350 |
+
return (x, stage_embedding)
|
351 |
+
return (x, kv)
|
352 |
+
|
353 |
+
# self-attention block
|
354 |
+
def _sa_block(
|
355 |
+
self,
|
356 |
+
x: Tensor,
|
357 |
+
attn_mask: Optional[Tensor],
|
358 |
+
key_padding_mask: Optional[Tensor],
|
359 |
+
) -> Tensor:
|
360 |
+
x = self.self_attn(
|
361 |
+
x,
|
362 |
+
x,
|
363 |
+
x,
|
364 |
+
attn_mask=attn_mask,
|
365 |
+
key_padding_mask=key_padding_mask,
|
366 |
+
need_weights=False,
|
367 |
+
)[0]
|
368 |
+
return self.dropout1(x)
|
369 |
+
|
370 |
+
# feed forward block
|
371 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
372 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
373 |
+
return self.dropout2(x)
|
374 |
+
|
375 |
+
|
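# --- Hypothetical configuration sketch (not part of this diff) -----------
# A pre-norm encoder layer wired with the pieces defined in this repo:
# BalancedDoubleSwish as the feed-forward activation and adaptive layer norm
# driven by a stage embedding. All sizes are illustrative.
def _encoder_layer_sketch():
    import torch

    layer = TransformerEncoderLayer(
        d_model=256,
        nhead=4,
        dim_feedforward=1024,
        activation=BalancedDoubleSwish,
        norm_first=True,
        adaptive_layer_norm=True,
    )
    x = torch.randn(10, 2, 256)            # (time, batch, d_model); batch_first=False
    stage_embedding = torch.randn(2, 256)  # hypothetical conditioning vector
    y, _ = layer((x, stage_embedding))
    assert y.shape == x.shape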
376 |
+
class TransformerEncoder(nn.Module):
|
377 |
+
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
378 |
+
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
379 |
+
|
380 |
+
Args:
|
381 |
+
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
382 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
383 |
+
norm: the layer normalization component (optional).
|
384 |
+
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
385 |
+
(and convert back on output). This will improve the overall performance of
|
386 |
+
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
387 |
+
|
388 |
+
Examples::
|
389 |
+
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
|
390 |
+
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
|
391 |
+
>>> src = torch.rand(10, 32, 512)
|
392 |
+
>>> out = transformer_encoder(src)
|
393 |
+
"""
|
394 |
+
__constants__ = ["norm"]
|
395 |
+
|
396 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
397 |
+
super(TransformerEncoder, self).__init__()
|
398 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
399 |
+
self.num_layers = num_layers
|
400 |
+
self.norm = norm
|
401 |
+
|
402 |
+
def forward(
|
403 |
+
self,
|
404 |
+
src: Tensor,
|
405 |
+
mask: Optional[Tensor] = None,
|
406 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
407 |
+
return_layer_states: bool = False,
|
408 |
+
) -> Tensor:
|
409 |
+
r"""Pass the input through the encoder layers in turn.
|
410 |
+
|
411 |
+
Args:
|
412 |
+
src: the sequence to the encoder (required).
|
413 |
+
mask: the mask for the src sequence (optional).
|
414 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
415 |
+
return_layer_states: return layers' state (optional).
|
416 |
+
|
417 |
+
Shape:
|
418 |
+
see the docs in Transformer class.
|
419 |
+
"""
|
420 |
+
if return_layer_states:
|
421 |
+
layer_states = [] # layers' output
|
422 |
+
output = src
|
423 |
+
for mod in self.layers:
|
424 |
+
output = mod(
|
425 |
+
output,
|
426 |
+
src_mask=mask,
|
427 |
+
src_key_padding_mask=src_key_padding_mask,
|
428 |
+
)
|
429 |
+
layer_states.append(output[0])
|
430 |
+
|
431 |
+
if self.norm is not None:
|
432 |
+
output = self.norm(output)
|
433 |
+
|
434 |
+
return layer_states, output
|
435 |
+
|
436 |
+
output = src
|
437 |
+
for mod in self.layers:
|
438 |
+
output = mod(
|
439 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
|
440 |
+
)
|
441 |
+
|
442 |
+
if self.norm is not None:
|
443 |
+
output = self.norm(output)
|
444 |
+
|
445 |
+
return output
|
446 |
+
|
447 |
+
def infer(
|
448 |
+
self,
|
449 |
+
src: Tensor,
|
450 |
+
mask: Optional[Tensor] = None,
|
451 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
452 |
+
return_layer_states: bool = False,
|
453 |
+
past_kv: Optional[Tensor] = None,
|
454 |
+
use_cache: bool = False,
|
455 |
+
):
|
456 |
+
if past_kv is None:
|
457 |
+
past_length = 0
|
458 |
+
past_kv = tuple([None] * self.num_layers)
|
459 |
+
else:
|
460 |
+
past_length = past_kv[0][0].size(-2)
|
461 |
+
new_kv = () if use_cache else None
|
462 |
+
output = src
|
463 |
+
for mod, past_layer_kv in zip(self.layers, past_kv):
|
464 |
+
output, kv = mod.infer(
|
465 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache
|
466 |
+
)
|
467 |
+
if use_cache:
|
468 |
+
new_kv = new_kv + (kv,)
|
469 |
+
|
470 |
+
if self.norm is not None:
|
471 |
+
output = self.norm(output)
|
472 |
+
|
473 |
+
return output, new_kv
|
474 |
+
|
475 |
+
|
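# --- Hypothetical incremental-inference sketch (not part of this diff) ----
# infer() with use_cache=True returns one (key, value) entry per layer;
# feeding it back as past_kv on the next call means each new frame is
# attended to incrementally instead of re-running the whole prefix. The
# layer must be pre-norm (norm_first=True), since only that branch is
# implemented in TransformerEncoderLayer.infer().
def _incremental_infer_sketch():
    import torch

    layer = TransformerEncoderLayer(d_model=256, nhead=4, norm_first=True)
    encoder = TransformerEncoder(layer, num_layers=2)
    first = torch.randn(1, 1, 256)  # a single new frame
    out, past_kv = encoder.infer(first, past_kv=None, use_cache=True)
    nxt = torch.randn(1, 1, 256)
    out, past_kv = encoder.infer(nxt, past_kv=past_kv, use_cache=True)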
476 |
+
class TransformerDecoderLayer(nn.Module):
|
477 |
+
    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        self.multihead_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            self.activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            self.activation = BalancedDoubleSwish(d_model)
        else:
            self.activation = activation

        if adaptive_layer_norm:
            norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm3 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
            self.norm3 = AdaptiveLayerNorm(d_model, norm3)
        else:
            self.norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            self.norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            if layer_norm_cls == IdentityNorm:
                self.norm3 = BalancedBasicNorm(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )
            else:
                self.norm3 = layer_norm_cls(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        tgt_is_tuple = False
        if isinstance(tgt, tuple):
            x, stage_embedding = tgt
            tgt_is_tuple = True
        else:
            x, stage_embedding = tgt, None

        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
            )
            x = x + self._mha_block(
                self.norm2(x, stage_embedding),
                memory,
                memory_mask,
                memory_key_padding_mask,
            )
            x = x + self._ff_block(self.norm3(x, stage_embedding))
        else:
            x = self.norm1(
                x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
                stage_embedding,
            )
            x = self.norm2(
                x
                + self._mha_block(
                    x, memory, memory_mask, memory_key_padding_mask
                ),
                stage_embedding,
            )
            x = self.norm3(x + self._ff_block(x), stage_embedding)

        if tgt_is_tuple:
            return (x, stage_embedding)
        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(
        self,
        x: Tensor,
        mem: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        x = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError(
        "activation should be relu/gelu, not {}".format(activation)
    )
prompts/promptsf
ADDED
File without changes
requirements.txt
CHANGED
@@ -33,5 +33,5 @@ pytest
 fastapi-cors
 sqlalchemy
 sqlalchemy.orm
-
 git+https://github.com/Plachtaa/VALL-E-X.git
+
utils/__init__.py
ADDED
@@ -0,0 +1,15 @@
import torch
import torch.nn as nn
# from icefall.utils import make_pad_mask

from .symbol_table import SymbolTable

# make_pad_mask = make_pad_mask
SymbolTable = SymbolTable


class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)
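Not part of the commit: an illustrative sketch of the Transpose helper above, which lets (N, T, D) features feed Conv1d layers that expect (N, D, T). The layer sizes are arbitrary placeholders.

import torch
import torch.nn as nn
from utils import Transpose

net = nn.Sequential(
    Transpose(),                                  # (N, T, D) -> (N, D, T)
    nn.Conv1d(80, 80, kernel_size=3, padding=1),  # convolve over the time axis
    Transpose(),                                  # back to (N, T, D)
)
x = torch.randn(4, 100, 80)
assert net(x).shape == (4, 100, 80)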
utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (915 Bytes). View file
utils/__pycache__/generation.cpython-311.pyc
ADDED
Binary file (15.1 kB). View file
utils/__pycache__/prompt_making.cpython-311.pyc
ADDED
Binary file (7 kB). View file
utils/__pycache__/sentence_cutter.cpython-311.pyc
ADDED
Binary file (3.5 kB). View file
utils/__pycache__/symbol_table.cpython-311.pyc
ADDED
Binary file (12.8 kB). View file
utils/download.py
ADDED
@@ -0,0 +1,49 @@
import sys
import requests


def download_file_from_google_drive(id, destination):
    URL = "https://docs.google.com/uc?export=download&confirm=1"

    session = requests.Session()

    response = session.get(URL, params={"id": id}, stream=True)
    token = get_confirm_token(response)

    if token:
        params = {"id": id, "confirm": token}
        response = session.get(URL, params=params, stream=True)

    save_response_content(response, destination)


def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith("download_warning"):
            return value

    return None


def save_response_content(response, destination):
    CHUNK_SIZE = 32768

    with open(destination, "wb") as f:  # binary mode takes no text encoding
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)


def main():
    if len(sys.argv) >= 3:
        file_id = sys.argv[1]
        destination = sys.argv[2]
    else:
        file_id = "TAKE_ID_FROM_SHAREABLE_LINK"
        destination = "DESTINATION_FILE_ON_YOUR_DISK"
    print(f"download {file_id} to {destination}")
    download_file_from_google_drive(file_id, destination)


if __name__ == "__main__":
    main()
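Not part of the commit: the downloader above can be run as python utils/download.py <file_id> <destination>, or the function can be called directly as sketched below. The file id and path are placeholders, not values from this repository, and the destination directory must already exist.

from utils.download import download_file_from_google_drive

# Both arguments are hypothetical examples.
download_file_from_google_drive(
    "FILE_ID_FROM_SHAREABLE_LINK",     # Google Drive file id
    "downloads/model-checkpoint.bin",  # local destination path
)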
utils/g2p/__init__.py
ADDED
@@ -0,0 +1,72 @@
""" from https://github.com/keithito/tacotron """
import utils.g2p.cleaners
from utils.g2p.symbols import symbols
from tokenizers import Tokenizer

# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}


class PhonemeBpeTokenizer:
    def __init__(self, tokenizer_path="./utils/g2p/bpe_1024.json"):
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

    def tokenize(self, text):
        # 1. convert text to phoneme
        phonemes, langs = _clean_text(text, ['cje_cleaners'])
        # 2. replace blank space " " with "_"
        phonemes = phonemes.replace(" ", "_")
        # 3. tokenize phonemes
        phoneme_tokens = self.tokenizer.encode(phonemes).ids
        assert(len(phoneme_tokens) == len(langs))
        if not len(phoneme_tokens):
            raise ValueError("Empty text is given")
        return phoneme_tokens, langs


def text_to_sequence(text, cleaner_names):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through
    Returns:
        List of integers corresponding to the symbols in the text
    '''
    sequence = []
    symbol_to_id = {s: i for i, s in enumerate(symbols)}
    clean_text = _clean_text(text, cleaner_names)
    for symbol in clean_text:
        if symbol not in symbol_to_id.keys():
            continue
        symbol_id = symbol_to_id[symbol]
        sequence += [symbol_id]
    return sequence


def cleaned_text_to_sequence(cleaned_text):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
        text: string to convert to a sequence
    Returns:
        List of integers corresponding to the symbols in the text
    '''
    sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
    return sequence


def sequence_to_text(sequence):
    '''Converts a sequence of IDs back to a string'''
    result = ''
    for symbol_id in sequence:
        s = _id_to_symbol[symbol_id]
        result += s
    return result


def _clean_text(text, cleaner_names):
    for name in cleaner_names:
        cleaner = getattr(utils.g2p.cleaners, name)
        if not cleaner:
            raise Exception('Unknown cleaner: %s' % name)
        text, langs = cleaner(text)
    return text, langs
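Not part of the commit: a sketch of how the PhonemeBpeTokenizer above might be used, assuming ./utils/g2p/bpe_1024.json is present and the cje_cleaners pipeline (with its language resources) is installed; the input sentence is an arbitrary example.

from utils.g2p import PhonemeBpeTokenizer

tokenizer = PhonemeBpeTokenizer()                  # loads ./utils/g2p/bpe_1024.json
tokens, langs = tokenizer.tokenize("Hello world.")
# tokens: BPE ids over the phoneme string; langs: one language id per token
# (the assert in tokenize() guarantees the two lists have equal length).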
utils/g2p/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (4.49 kB). View file
utils/g2p/__pycache__/cleaners.cpython-311.pyc
ADDED
Binary file (4.66 kB). View file
utils/g2p/__pycache__/english.cpython-311.pyc
ADDED
Binary file (8.53 kB). View file