---
license: agpl-3.0
---

This repo catalogs my weights for use with my [VALL-E](https://github.com/e-c-k-e-r/vall-e) implementation as I try and iron out the kinks.

The model currently is in a *usable* state under `ar+nar-llama-8` (the default model that's downloaded).

## Models

This repo contains the following configurations under `./models/`:

* `config.retnet.yaml` / `ar+nar-retnet-8`: The previously released weights.
    + This configuration utilizes a RetNet (retention-based "transformer") as the underlying architecture due to a number of misleading interpretations with comparisons, for better or for worse.
        + Prompt and response embeddings are summed (further RVQ levels get the previous RVQ levels' embeddings factored in).
		+ Tokenizer is a homebrewed "naive" implementation.
    + This model received the most training time between my 4070Ti, 7900XTX, and a few rental rigs to further training progress, entirely at `bfloat16` with `prodigyopt` (and a few optimizer restarts).
	+ The later part of training aimed to shuffle between speakers rather than the global pool of utterances to better focus on zero-shot performance. Due to this, I feel it achieved *decent* zero-shot performance.
    + However, due to the dataset being aggressively trimmed under 12 seconds for memory savings during training, it suffers when trying to inference non-short utterances. Additional training may fix this, as the following models seemed to adapt well to longer utterances.
        + From the `ar+nar-llama-8` experiment, I believe this can be "fixed" with additional training on the currently processed dataset.
    + Prior testing showed that longer prompt durations result in better utterances.
    + *Can* benefit from additional training, but I recall the average loss being around `1.9` to `2.1`.
        + However, due to regressions (or bias from working under `llama`), I don't think I can optimally train with a RetNet again (both in terms of VRAM consumption and throughput).
        + I would love to revisit this with my more-better-er training paradigms.
    + Currently does not seem to work anymore due to regressions in the code.

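To make the summed-embedding setup above concrete, here is a toy sketch in plain Python. It assumes one embedding table per RVQ level and that the input for a given level is the sum of the embeddings of every level already known for that frame. The table sizes, `embed_frame`, and the `summed=False` fallback (keeping only the most recent level's embedding) are hypothetical illustrations, not code from the VALL-E repo.

```python
import random

# Hypothetical toy sizes; the real model's dimensions differ.
NUM_LEVELS = 8       # number of RVQ levels (the "-8" in the model names)
CODEBOOK_SIZE = 1024
DIM = 4

rng = random.Random(0)
# One embedding table per RVQ level: tables[level][code] is a DIM-length vector.
tables = [
    [[rng.uniform(-1.0, 1.0) for _ in range(DIM)] for _ in range(CODEBOOK_SIZE)]
    for _ in range(NUM_LEVELS)
]


def embed_frame(codes, summed=True):
    """Embed one frame given its known RVQ codes (levels 0..len(codes)-1).

    summed=True adds the embeddings of every known level, so later levels
    "see" the earlier ones. summed=False (an assumed alternative) keeps only
    the embedding of the highest known level.
    """
    levels = range(len(codes)) if summed else [len(codes) - 1]
    out = [0.0] * DIM
    for level in levels:
        vec = tables[level][codes[level]]
        for i in range(DIM):
            out[i] += vec[i]
    return out


# Input for predicting RVQ level 2 of a frame whose levels 0 and 1 are codes 5 and 7.
print(embed_frame([5, 7]))
```

For prompt frames, every RVQ level is already known, so under this sketch the sum simply covers all eight levels.
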
* `config.llama.yaml` / `ar+nar-llama-8`: The most recent-ishly trained weights after learning from my mistakes.
	+ This configuration utilizes Llama's attention-based transformer as the underlying architecture, making use of creature comforts like RoPE, GQA, and memory-efficient attention (trained under `xformers`, shouldn't really affect things).
		+ Prompt and response embeddings ARE summed (half the model was trained without summing, but enabling it seemed to make the most sense, and it didn't affect anything to do so).
          + However, the opposite is not true: a model trained with summed embeddings does not function after disabling this.
		+ Utilizes a HF tokenizer for "optimal" vocab.
          + Optimal in the sense that it uses the remaining portion of the 256 indices for merged phonemes (although I imagine it would be better NOT to merge, as the model's focus isn't in phoneme output).
		+ The current RVQ level is included as a token as well to help guide NAR tasks better.
	+ This model received a few days of training on my 4xV100s, stepping up the duration window to *try* and better make the model inference for longer utterances.
		+ Some sessions end up training the current duration window for a few epochs, but I don't know how much it affected things.
    + This model *actually* received additional post-training for a variety of issues that needed to be addressed:
        + Training on shuffled batches of durations to have it better generalize on a variety of durations.
        + Non-naive prompt sampling for similar utterances to try and give better prompt adherence.
        + Additional languages (Japanese, French, and German) and an additional task: Speech-to-Text (phonemes)
        + etc.
	+ ~~However, it seems to *only* do well with long utterances. Short utterances fumble. I believe further training with a variety of durations should allow the AR to handle a variety of durations.~~
		- ~~I believe the "slowly stepping up the context length" only works for text, and not audio.~~
        - Addendum: Additional brief training for a variety of duration lengths seemed to have mostly fixed this issue.
        - Addendum addendum: Properly creating the position IDs per-segment, rather than for the whole sequence, also helps a lot.
	+ Zero-shot performance leaves a bit to be desired, as it did not receive the special training prioritizing shuffling between speakers rather than the global pool of utterances.
        - Addendum: Additional brief training for sampling based on speaker per "epoch" (per dataloader, not dataset) seemed to slightly improve it.
        - Addendum addendum: non-naive prompt sampling with a similar utterance to the output helps a non-negligible amount.
    + Testing showed that~~, despite also stepping up the prompt duration, it *really* likes three-second prompts.~~ longer input prompts do actually help.
        + Giving a wide coverage of phonemes to directly reference goes a long way.
	+ Definitely needs additional training, but the next way to go is unknown.
        + Naturally, training it on a "next RVQ level is half as likely" distribution introduces some crust as the later RVQ levels are less accurate, introducing noise and artifacts.
        + Additional training on the AR will ~~see huge diminishing returns, so I don't know if it's worth doing so.~~ see slight improvements over additional epochs with different training/sampling paradigms.
    + Seems to be a decent foundation for "distillation", at the very least for LoRA training.
    	- Addendum: it seems to serve fine for patch-training a few extra tweaks, to non-unified position IDs, split classifier heads, and para-parallel decoding for the AR.
    + Addendum: This received a lot of additional training (~60k more steps).
      + This post-training *was* intended to teach the model a pure NAR-RVQ-level-0 task for parallel decoding, but an error turned it into what proved to be decent AR training instead.
      + Classifier-free-guidance-aware-training was also performed, really helping the prompt adherence even at ar-temperature=1.0.
      + Regression tests are needed just in case I did botch something, but it seems really nice so far.
        + The old weights are saved as `ar+nar-old-llama-8` in the event of a nasty regression, but I doubt it's necessary.

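The "next RVQ level is half as likely" distribution mentioned above can be sketched as follows. This assumes the target RVQ level for a training sample is drawn from all eight levels with weights 1, 1/2, 1/4, and so on; the function names are hypothetical and this is not the repo's actual sampler (the config elsewhere mentions an `rvq_levels_p` setting, which presumably plays this role).

```python
import random
from collections import Counter


def rvq_level_probs(num_levels=8):
    """Probability of training on each RVQ level when each level is half as
    likely as the one before it (weights 1, 1/2, 1/4, ...), normalized."""
    weights = [0.5 ** level for level in range(num_levels)]
    total = sum(weights)
    return [w / total for w in weights]


def sample_rvq_level(rng, num_levels=8):
    """Draw the RVQ level a training sample will target."""
    return rng.choices(range(num_levels), weights=rvq_level_probs(num_levels), k=1)[0]


probs = rvq_level_probs()
print([round(p, 4) for p in probs])
# With 8 levels, level 0 gets 128/255 (~50%) and level 7 only 1/255.

rng = random.Random(0)
counts = Counter(sample_rvq_level(rng) for _ in range(10_000))
print(sorted(counts.items()))
```

Under this distribution the last few RVQ levels receive only a small share of updates, which lines up with the noise and artifacts described above for the later levels.
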
* ~~`config.llama-tts+stt.yaml` / `ar+nar-tts+stt-llama-8`~~: The above, but partially trained for STT.
    + These weights use the above weights but with additional training for the default `tts` task and a new `stt` task (at a 3:1 ratio).
    + Was initially trained with `duration_range: [3.0, 60.0]` and `sample_shuffle: True` for a few hours, but then pivoted to my standard `duration_range: [3.0, 12.0]` and `sample_shuffle: False`.
      + Will need the former training to "undo" any issues with durations, as it usually came up before.
    + `stt` task simply takes a piece of audio and outputs a transcription using IPA phonemes (that the model already is trained against for its text inputs).
      + Can be done with `--task=stt` and an empty (`""`) text input through the CLI interface or the `Speech-to-Text` tab in the web UI.
    + This mainly serves as a stepping stone before pivoting towards SpeechX tasks.
      + I first need a good mechanism to make sure I *can* extend existing weights with additional tasks, but with a simple enough task.
      + This also *maybe* seems to help bolster the initial TTS task by helping the model have a better internal state (or something to that tune).
    + STT is not perfect against voices that aren't close to a normal speaking voice (as per the dataset), unlike TTS where you can easily have "sounds close enough" and room for errors.
    + Addendum: this replaced the `ar+nar-llama-8` as the de facto model (taking its name), so the above does apply.

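A minimal sketch of the data mix described above: tasks drawn at a 3:1 `tts`:`stt` ratio, and utterances kept only if their duration falls inside `duration_range`. The helper names, the inclusive bounds, and the `(id, seconds)` utterance records are assumptions for illustration; the actual repo sets these through its config (e.g. `duration_range`) rather than through functions like these.

```python
import random
from collections import Counter

TASK_WEIGHTS = {"tts": 3, "stt": 1}  # the 3:1 ratio described above
DURATION_RANGE = (3.0, 12.0)          # mirrors `duration_range: [3.0, 12.0]`


def sample_task(rng):
    """Pick a training task according to TASK_WEIGHTS."""
    tasks = list(TASK_WEIGHTS)
    return rng.choices(tasks, weights=[TASK_WEIGHTS[t] for t in tasks], k=1)[0]


def in_duration_range(duration, duration_range=DURATION_RANGE):
    """Keep an utterance only if its duration (in seconds) is inside the window.
    Inclusive bounds are an assumption of this sketch."""
    low, high = duration_range
    return low <= duration <= high


utterances = [("a", 2.5), ("b", 4.0), ("c", 11.9), ("d", 30.0)]  # (id, seconds)
kept = [uid for uid, dur in utterances if in_duration_range(dur)]
print(kept)

rng = random.Random(0)
mix = Counter(sample_task(rng) for _ in range(8_000))
print(mix)
```

For `stt`, the roles of the inputs flip: audio goes in and an IPA phoneme string comes out, mirroring the empty text input used for that task at inference time.
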
* `config.llama[layerskip].yaml` / `ar+nar-layerskip-llama-8`: The above, but with very brief training for LayerSkip:
    + Post-trained on a small English subset of Emilia and a small private corpus, and Japanese+French+German from Emilia.
    + Using shuffled batches (where each batch has the same durations) and a modified `rvq_levels_p` to help the NAR.
    + Initially trained with LayerSkip hyperparameters `R=4` and `e_scale=0.2`, but midway through swapped to `R=2` and `e_scale=0.1` to maintain stability.
    + This model received LayerSkip-aware training, with layer dropout and early-exit loss to help try and bolster the model and enable self-speculation sampling.
    + Goal is to utilize self-speculation sampling to enable speedups when possible.
      + Current implementation will early-exit if the entropy/varentropy of the logits are low enough (<0.1).
      + Speedups seem to shave off a second of inference time.
    + Training is a pain.
      + LayerSkip-aware training does *not* like to train under ROCm.
      + Training under float16+AMP with loss scaling will fry the model with a large enough de facto batch size (>512 samples/update step) and/or too low of a loss scale (<=8K).
    + LayerSkip-aware training seems to degrade the model enough to where it harms the model's ability to sound similar to the reference prompt the more it trains.
      + I imagine this technique only really works for "large" enough models (be it wide and/or deep enough) that may cause it to second-guess in the later levels.
      + The current size of VALL-E doesn't seem to necessitate LayerSkip, as it seems to instead dumb the model down to ~9 layers instead of 12 (as it typically exits early at layer 9, and the remaining layers offer little additional benefit).
        + This *does* seem to prove to be a nice way to shrink models, and perhaps even grow them? I recall finding that trying to grow a model caused the extra layers to be useless.
    + Unless I get a revelation, this experiment is bunk unless it can magically live through a LoRA.
    + Experiments have shown that this actively harms the model for a very negligible speed gain, as LayerSkip-aware training shifts most of the intelligence down a few layers, and keeps the last couple of layers to further build upon the confidence of the outputs, or something.
      + Despite being a failure, this does pave a nice way to shrink models from an existing model. However, this does not seem to be useful, as even dropping two or three layers really does harm how well the prompt is followed.

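The early-exit rule above can be sketched as follows: compute the entropy of the softmaxed logits and their varentropy (the variance of each token's surprisal, `-log p`), and stop at the first layer where both fall under the threshold. Requiring *both* values to be low, and scanning precomputed per-layer logits, are assumptions made for this sketch; an actual self-speculative implementation would run layers lazily and stop computing once it exits.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_varentropy(logits):
    """Entropy of the token distribution and varentropy (variance of -log p)."""
    probs = [p for p in softmax(logits) if p > 0.0]
    entropy = -sum(p * math.log(p) for p in probs)
    varentropy = sum(p * (-math.log(p) - entropy) ** 2 for p in probs)
    return entropy, varentropy


def should_exit_early(logits, threshold=0.1):
    """Assumed rule: exit only if both entropy and varentropy are below threshold."""
    entropy, varentropy = entropy_varentropy(logits)
    return entropy < threshold and varentropy < threshold


def pick_exit_layer(per_layer_logits, threshold=0.1):
    """Index of the first layer confident enough to exit at, else the last layer."""
    for index, logits in enumerate(per_layer_logits):
        if should_exit_early(logits, threshold):
            return index
    return len(per_layer_logits) - 1


# Hypothetical logits read out after each of four layers.
layers = [
    [0.0, 0.0, 0.0, 0.0],   # uniform: high entropy, zero varentropy
    [1.0, 0.0, 0.0, 0.0],   # leaning toward token 0, still uncertain
    [10.0, 0.0, 0.0, 0.0],  # confident
    [12.0, 0.0, 0.0, 0.0],
]
print(pick_exit_layer(layers))
```

The first layer shows why this sketch checks both values: a uniform distribution has zero varentropy but maximal entropy, so varentropy alone would trigger an exit on a pure guess.
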
* `config.llama[nar-len].yaml` / `nar-len-llama-8`: A fully non-autoregressive model.
  * These weights are a work in progress, but currently are a good proof-of-concept so far until training is on-par with the base `ar+nar-llama-8` model.
  * A ***lot*** of pain was put into trying to get something working, through implementation issues to dumb mistakes, until the best option of just training from scratch was picked.
    * Technically, the `ar+nar-llama-8` can be modified to be a pure non-autoregressive model, but I needed to start from scratch before dumping more time again trying to adapt it.
  * Speedups are immense compared to the `ar+nar-llama-8`, as the entire audio output is decoded in parallel rather than causally.
    * Throughput and memory usage should be constant between inferencing steps.
    * The model only needs to be invoked about 5+25+7 times (duration inferencing + RVQ level 0 inferencing + remaining RVQ levels) instead.
  * Seems to absolutely require classifier-free-guidance to keep the output stable.
  * The "confidence" issue on voices it hasn't seen / hasn't seen much of is much more noticeable, as RVQ level 0 is much more susceptible to it.
  * Unlike the base model, this is trained with the current dataset without iteratively dripfeeding additional sources (like tacking on Emilia afterwards).
    * ...except STT, this received no STT training out of fear of botching the model.
  * Weights will be added as the model is trained.
    * This *was* expected to be a dud, but one very, very small oversight in the sampling code proved to be the culprit......
    * In other words, the model *does* work.

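Classifier-free guidance, which the NAR-len model above seems to require (and which the `ar+nar-llama-8` post-training was made aware of), combines a conditional and an unconditional prediction at each step. CFG-aware training typically means randomly dropping the conditioning during training so that one model can produce both predictions. Below is the standard formulation as a sketch; how the repo builds its unconditional input and which guidance scale it defaults to are not stated here, so `null` and `scale` are placeholder values.

```python
import math


def cfg_logits(cond_logits, null_logits, scale):
    """Standard classifier-free guidance on logits: null + scale * (cond - null).
    scale=1.0 returns the conditional logits; larger values push further away
    from the unconditional prediction."""
    return [n + scale * (c - n) for c, n in zip(cond_logits, null_logits)]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


cond = [2.0, 1.0, 0.0]  # logits with conditioning (hypothetical values)
null = [1.0, 1.0, 0.0]  # logits with conditioning dropped (hypothetical values)

for scale in (1.0, 3.0):
    probs = softmax(cfg_logits(cond, null, scale))
    print(scale, [round(p, 3) for p in probs])
```

Each guided step therefore needs an extra forward pass for the unconditional branch (commonly batched together with the conditional one), and that cost applies to every one of the model's inference invocations.
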
Some additional configurations have been explored, but experiments have not been fruitful:
* Exotic wrappers like `BitNet` seemed to yield little gain in inferencing, somehow. The memory savings are pretty much unnecessary, as the models are already manageable at ~200M parameters.
* Mamba / Mamba2-based models have shown that it's ***really*** hard to have an AR+NAR model. I really do not want to bother throwing the compute at another ~~meme~~ arch I can't easily make use of all the other tech to throw at.
* a model using [Descript-Audio-Codec](https://github.com/descriptinc/descript-audio-codec/):
  + the 24KHz model will *not* converge no matter what. However, naively using just the first 8 RVQ levels might not be good enough, as there are too many codebooks for viable use.
  + the 44KHz model was erroneously assumed to be an even 44KHz, when in reality it's 44.1KHz. *All* of my audio has to be requantized, as there's some stuttering in it.
  	+ Because of this, training losses are high and it's having a hard time trying to converge.
  + It has *sub-serviceable* output for the first 4 RVQ levels, but it's massive cope to try and use it as a model.
  + ~~I believe there's hope to use it when I requantize my audio properly.~~
    + Addendum: even after properly processing my audio, the loss is actually *worse* than before. I imagine DAC just cannot be used as an intermediary for an LM.
* a model with a causal size >1 (sampling more than one token for the AR):
  + re-using an existing model or training from scratch does not have fruitful results.
  + there's an inherent periodic stutter that doesn't seem to be able to be trained out, but it might require exotic sampling methods.
  + unfortunately it requires:
    + either something similar to Medusa heads, where there's additional parameters to perform speculative sampling,
    + a solution similar to what VALL-E 2 uses with group token embeddings or whatever, which *will* harm the NAR tasks in an AR+NAR model.
  + I just don't understand where the issue lies, since parallel decoding does work, as evidenced by the NAR.

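The size of the 44 kHz vs. 44.1 kHz mix-up with DAC is easy to quantify. The sketch below only shows the timing arithmetic of treating 44.1 kHz audio as if it were 44 kHz; it does not model how DAC or the quantization pipeline actually reacted to the mismatch, and the helper name is hypothetical.

```python
ASSUMED_RATE = 44_000  # the rate the model was assumed to run at
ACTUAL_RATE = 44_100   # the rate it actually runs at


def misread_seconds(num_samples, assumed_rate=ASSUMED_RATE):
    """Duration computed for `num_samples` when using the assumed (wrong) rate."""
    return num_samples / assumed_rate


true_seconds = 10.0
samples = int(true_seconds * ACTUAL_RATE)        # samples in 10 s of real audio
drift = misread_seconds(samples) - true_seconds  # extra time attributed to the clip
print(f"{drift:.4f} s drift per {true_seconds:.0f} s "
      f"({(ACTUAL_RATE / ASSUMED_RATE - 1) * 100:.3f}%)")
```

The mismatch is 100 samples per second (about 0.23%): small, but applied systematically to every clip.
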
Some current "architectural features" are in use, but their effects need to be experimented with further:
* `split_classifier_heads` is still a mystery whether it's truly helpful or not (each RVQ level gets its own output head).
* `audio_embeddings_sum` is also a mystery whether it matters if each later RVQ level should "see" the past levels through summing embeddings, or if not doing it is preferable.
* Disabling `unified_position_ids` seems to help quality more often than not, but I'm still unsure if it's beneficial in practice.

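The difference behind `unified_position_ids` (and the per-segment position IDs mentioned for `ar+nar-llama-8`) can be sketched as follows: either one running position count spans the whole input, or each segment (text, prompt audio, response audio, and so on) restarts at 0. The segment layout and the helper name are hypothetical; the repo's actual sequence layout is not described here.

```python
def build_position_ids(segment_lengths, unified=True):
    """Position IDs for a sequence made of consecutive segments.

    unified=True: one running count over the whole sequence.
    unified=False: every segment starts again from position 0.
    """
    if unified:
        return list(range(sum(segment_lengths)))
    ids = []
    for length in segment_lengths:
        ids.extend(range(length))
    return ids


# Hypothetical layout: 4 text tokens, 3 prompt frames, 2 response frames.
segments = [4, 3, 2]
print(build_position_ids(segments, unified=True))
print(build_position_ids(segments, unified=False))
```

With per-segment IDs, the response's positions no longer depend on how long the text and prompt happen to be, which is one plausible reason disabling unified IDs could help.
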
## LoRAs

This repo also contains some LoRAs to serve as a reference under `./loras/`.

Using a LoRA is the same as a base model, except you're required to have the base model already (obviously). Just use the LoRA's config YAML to load from instead to use it.

The only caveat is that my original dataset *does* contain (most of) these samples already, but given the sheer size of it, they're probably underutilized.
* However, the base model already has *almost adequate* output from these speakers, but not enough to be satisfactory.