yuta0306 committed
Commit 565faca (0 parents)

first commit

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ /.venv/
+ /models/
+ /**/__pycache__/
+ /**/*.pyc
Dockerfile ADDED
@@ -0,0 +1,66 @@
+ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
+
+ ENV DEBIAN_FRONTEND=noninteractive
+ RUN apt-get update && \
+     apt-get upgrade -y && \
+     apt-get install -y --no-install-recommends \
+     libgirepository1.0-dev \
+     git \
+     git-lfs \
+     wget \
+     curl \
+     # python build dependencies \
+     build-essential \
+     libssl-dev \
+     zlib1g-dev \
+     libbz2-dev \
+     libreadline-dev \
+     libsqlite3-dev \
+     libncursesw5-dev \
+     xz-utils \
+     tk-dev \
+     libxml2-dev \
+     libxmlsec1-dev \
+     libffi-dev \
+     liblzma-dev \
+     # nightly dependencies \
+     libdbus-glib-1-dev \
+     libpng-dev \
+     libjpeg-dev \
+     libcairo2-dev \
+     # gradio dependencies \
+     ffmpeg \
+     # fairseq2 dependencies \
+     libsndfile-dev && \
+     apt-get clean && \
+     rm -rf /var/lib/apt/lists/*
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:${PATH}
+ WORKDIR ${HOME}/app
+
+ COPY --chown=user . $HOME/app
+ RUN curl https://pyenv.run | bash
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
+ ARG PYTHON_VERSION=3.10.13
+ RUN pyenv install ${PYTHON_VERSION} && \
+     pyenv global ${PYTHON_VERSION} && \
+     pyenv rehash && \
+     pip install --no-cache-dir -U pip setuptools wheel
+ RUN pip install packaging && \
+     pip install -r ${HOME}/app/requirements.txt && \
+     pip install -U flash-attn gradio spacy transformers fastapi tyro julius audiocraft tiktoken hf-transfer && \
+     pip install -U --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
+ ENV PYTHONPATH=${HOME}/app \
+     PYTHONUNBUFFERED=1 \
+     HF_HUB_ENABLE_HF_TRANSFER=1 \
+     GRADIO_ALLOW_FLAGGING=never \
+     GRADIO_NUM_PORTS=1 \
+     GRADIO_SERVER_NAME=0.0.0.0 \
+     GRADIO_THEME=huggingface \
+     TQDM_POSITION=-1 \
+     TQDM_MININTERVAL=1 \
+     SYSTEM=spaces
+ CMD python fam/llm/serving.py --huggingface_repo_id kotoba-tech/kotoba-speech-v0.1 & python fam/ui/app.py
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: Kotoba Voice Testing
+ emoji: 🚀
+ colorFrom: red
+ colorTo: green
+ sdk: docker
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
fam/__init__.py ADDED
File without changes
fam/llm/.fast_inference.py.swp ADDED
Binary file (16.4 kB)
fam/llm/__init__.py ADDED
File without changes
fam/llm/adapters/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
+ from fam.llm.adapters.tilted_encodec import TiltedEncodec
fam/llm/adapters/base.py ADDED
@@ -0,0 +1,5 @@
+ from abc import ABC
+
+
+ class BaseDataAdapter(ABC):
+     pass
fam/llm/adapters/flattened_encodec.py ADDED
@@ -0,0 +1,38 @@
+ from fam.llm.adapters.base import BaseDataAdapter
+
+
+ class FlattenedInterleavedEncodec2Codebook(BaseDataAdapter):
+     def __init__(self, end_of_audio_token):
+         self._end_of_audio_token = end_of_audio_token
+
+     def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
+         assert len(tokens) == 1
+         tokens = tokens[0]
+
+         text_ids = []
+         extracted_audio_ids = [[], []]
+
+         for t in tokens:
+             if t < self._end_of_audio_token:
+                 extracted_audio_ids[0].append(t)
+             elif t >= self._end_of_audio_token and t < 2 * self._end_of_audio_token:
+                 extracted_audio_ids[1].append(t - self._end_of_audio_token)
+             # We ignore t = 2 * self._end_of_audio_token, as it is the end of audio token
+             elif t > 2 * self._end_of_audio_token:
+                 text_ids.append(t)
+
+         if len(set([len(x) for x in extracted_audio_ids])) != 1:
+             min_len = min([len(x) for x in extracted_audio_ids])
+             max_len = max([len(x) for x in extracted_audio_ids])
+             print("WARNING: Number of tokens at each hierarchy must be of the same length!")
+             print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
+             print([len(x) for x in extracted_audio_ids])
+             extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]
+
+         return text_ids[:-1], extracted_audio_ids
+
+     def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
+         """
+         Performs the required combination and padding as needed.
+         """
+         raise NotImplementedError
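Illustrative note (not part of the committed file): the two-codebook decode above can be exercised in isolation. The sketch below uses made-up token values and a deliberately tiny end_of_audio_token of 4, so tokens below 4 go to codebook 0, tokens in [4, 8) go to codebook 1 with the offset removed, and tokens above 8 are text ids.

from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook

adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=4)
text_ids, audio_ids = adapter.decode([[9, 10, 1, 5, 2, 6]])
print(text_ids)   # [9]  -- decode always drops the last text id
print(audio_ids)  # [[1, 2], [1, 2]]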
fam/llm/adapters/tilted_encodec.py ADDED
@@ -0,0 +1,45 @@
+ from fam.llm.adapters.base import BaseDataAdapter
+
+
+ class TiltedEncodec(BaseDataAdapter):
+     def __init__(self, end_of_audio_token):
+         self._end_of_audio_token = end_of_audio_token
+
+     def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
+         assert len(tokens) > 1
+
+         text_ids = []
+         extracted_audio_ids = []
+
+         extracted_audio_ids.append([])
+         # Handle first hierarchy as special case as it contains text tokens as well
+         # TODO: maybe it doesn't need a special case, and can be handled on its own :)
+         for t in tokens[0]:
+             if t > self._end_of_audio_token:
+                 text_ids.append(t)
+             elif t < self._end_of_audio_token:
+                 extracted_audio_ids[0].append(t)
+
+         # Handle the rest of the hierarchies
+         for i in range(1, len(tokens)):
+             token_hierarchy_ids = tokens[i]
+             extracted_audio_ids.append([])
+             for t in token_hierarchy_ids:
+                 if t < self._end_of_audio_token:
+                     extracted_audio_ids[i].append(t)
+
+         if len(set([len(x) for x in extracted_audio_ids])) != 1:
+             min_len = min([len(x) for x in extracted_audio_ids])
+             max_len = max([len(x) for x in extracted_audio_ids])
+             print("WARNING: Number of tokens at each hierarchy must be of the same length!")
+             print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
+             print([len(x) for x in extracted_audio_ids])
+             extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]
+
+         return text_ids[:-1], extracted_audio_ids
+
+     def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
+         """
+         Performs the required combination and padding as needed.
+         """
+         raise NotImplementedError
fam/llm/decoders.py ADDED
@@ -0,0 +1,103 @@
+ import os
+ import pathlib
+ import uuid
+ from abc import ABC, abstractmethod
+ from typing import Callable, Optional, Union
+
+ import julius
+ import torch
+ from audiocraft.data.audio import audio_read, audio_write
+ from audiocraft.models import MultiBandDiffusion  # type: ignore
+
+ from IPython import embed
+
+ class Decoder(ABC):
+     @abstractmethod
+     def decode(self, tokens: list[int], ref_audio_path: Optional[str] = None, causal: Optional[bool] = None):
+         raise NotImplementedError
+
+
+ class EncodecDecoder(Decoder):
+     def __init__(
+         self,
+         tokeniser_decode_fn: Callable[[list[int]], str],
+         data_adapter_fn: Callable[[list[list[int]]], tuple[list[int], list[list[int]]]],
+         output_dir: str,
+     ):
+         self._mbd_bandwidth = 6  # 1.5
+         self._mbd_sample_rate = 24_000
+         self._end_of_audio_token = 1024
+         self._num_codebooks = 8
+         self.mbd = MultiBandDiffusion.get_mbd_24khz(bw=self._mbd_bandwidth)
+
+         self.tokeniser_decode_fn = tokeniser_decode_fn
+         self._data_adapter_fn = data_adapter_fn
+
+         self.output_dir = pathlib.Path(output_dir).resolve()
+         os.makedirs(self.output_dir, exist_ok=True)
+
+     def _save_audio(self, name: str, wav: torch.Tensor):
+         audio_write(
+             name,
+             wav.squeeze(0).cpu(),
+             self._mbd_sample_rate,
+             strategy="loudness",
+             loudness_compressor=True,
+         )
+
+     def get_tokens(self, audio_path: str) -> list[list[int]]:
+         """
+         Utility method to get tokens from audio. Useful when you want to test reconstruction in some form (e.g.
+         limited codebook reconstruction or sampling from second stage model only).
+         """
+         pass
+         wav, sr = audio_read(audio_path)
+         if sr != self._mbd_sample_rate:
+             wav = julius.resample_frac(wav, sr, self._mbd_sample_rate)
+         if wav.ndim == 2:
+             wav = wav.unsqueeze(1)
+         wav = wav.to("cuda")
+         tokens = self.mbd.codec_model.encode(wav)
+         tokens = tokens[0][0]
+         # embed()
+         return tokens.tolist()
+
+     def decode(
+         self, tokens: list[list[int]], causal: bool = True, ref_audio_path: Optional[str] = None
+     ) -> Union[str, torch.Tensor]:
+         # TODO: this has strange behaviour -- if causal is True, it returns tokens. if causal is False, it SAVES the audio file.
+         text_ids, extracted_audio_ids = self._data_adapter_fn(tokens)
+         text = self.tokeniser_decode_fn(text_ids)
+         print(f"Text: {text}")
+
+         tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0)
+
+         if tokens.shape[1] < self._num_codebooks:
+             tokens = torch.cat(
+                 [tokens, *[torch.ones_like(tokens[0:1, 0:1]) * 0] * (self._num_codebooks - tokens.shape[1])], dim=1
+             )
+
+         if causal:
+             return tokens
+         else:
+             with torch.amp.autocast(device_type="cuda", dtype=torch.float32):
+                 # embed()
+                 wav = self.mbd.tokens_to_wav(tokens)
+                 # NOTE: we couldn't just return wav here as it goes through loudness compression etc :)
+
+                 if wav.shape[-1] < 9600:
+                     # this causes a problem for the code below, and is also odd :)
+                     # first happened for tokens (1, 8, 28) -> wav (1, 1, 8960) (~320x factor in time dimension!)
+                     raise Exception("wav predicted is shorter than 400ms!")
+
+                 try:
+                     wav_file_name = self.output_dir / f"synth_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}"
+                     self._save_audio(wav_file_name, wav)
+                     print(f"\nSaved audio to {wav_file_name}.wav")
+                     return wav_file_name
+                 except Exception as e:
+                     print(f"Failed to save audio! Reason: {e}")
+                     wav_file_name = self.output_dir / f"synth_{uuid.uuid4()}"
+                     self._save_audio(wav_file_name, wav)
+                     print(f"\nSaved audio to {wav_file_name}.wav")
+                     return wav_file_name
fam/llm/enhancers.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ from abc import ABC
+ from typing import Literal, Optional
+
+ from df.enhance import enhance, init_df, load_audio, save_audio
+ from pydub import AudioSegment
+
+
+ def convert_to_wav(input_file: str, output_file: str):
+     """Convert an audio file to WAV format
+
+     Args:
+         input_file (str): path to input audio file
+         output_file (str): path to output WAV file
+
+     """
+     # Detect the format of the input file
+     format = input_file.split(".")[-1].lower()
+
+     # Read the audio file
+     audio = AudioSegment.from_file(input_file, format=format)
+
+     # Export as WAV
+     audio.export(output_file, format="wav")
+
+
+ def make_output_file_path(audio_file: str, tag: str, ext: Optional[str] = None) -> str:
+     """Generate the output file path
+
+     Args:
+         audio_file (str): path to input audio file
+         tag (str): tag to append to the output file name
+         ext (str, optional): extension of the output file. Defaults to None.
+
+     Returns:
+         str: path to output file
+     """
+
+     directory = "./enhanced"
+     # Get the name of the input file
+     filename = os.path.basename(audio_file)
+
+     # Get the name of the input file without the extension
+     filename_without_extension = os.path.splitext(filename)[0]
+
+     # Get the extension of the input file
+     extension = ext or os.path.splitext(filename)[1]
+
+     # Generate the output file path
+     output_file = os.path.join(directory, filename_without_extension + tag + extension)
+
+     return output_file
+
+
+ class BaseEnhancer(ABC):
+     """Base class for audio enhancers"""
+
+     def __init__(self, *args, **kwargs):
+         raise NotImplementedError
+
+     def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
+         raise NotImplementedError
+
+     def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str:
+         output_file = make_output_file_path(audio_file, tag, ext=ext)
+         os.makedirs(os.path.dirname(output_file), exist_ok=True)
+         return output_file
+
+
+ class DFEnhancer(BaseEnhancer):
+     def __init__(self, *args, **kwargs):
+         self.model, self.df_state, _ = init_df()
+
+     def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
+         output_file = output_file or self.get_output_file(audio_file, "_df")
+
+         audio, _ = load_audio(audio_file, sr=self.df_state.sr())
+
+         enhanced = enhance(self.model, self.df_state, audio)
+
+         save_audio(output_file, enhanced, self.df_state.sr())
+
+         return output_file
+
+
+ def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
+     """Get an audio enhancer
+
+     Args:
+         enhancer_name (Literal["df"]): name of the audio enhancer
+
+     Raises:
+         ValueError: if the enhancer name is not recognised
+
+     Returns:
+         BaseEnhancer: audio enhancer
+     """
+
+     if enhancer_name == "df":
+         return DFEnhancer()
+     else:
+         raise ValueError(f"Unknown enhancer name: {enhancer_name}")
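A short illustrative sketch (not part of the committed file) of how the enhancer above is typically driven. It assumes the df (DeepFilterNet) package and its weights are available, and the input path is hypothetical.

from fam.llm.enhancers import get_enhancer

enhancer = get_enhancer("df")        # DeepFilterNet-based enhancer
cleaned = enhancer("raw_synth.wav")  # writes ./enhanced/raw_synth_df.wav by default
print(cleaned)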
fam/llm/fast_inference.py ADDED
@@ -0,0 +1,151 @@
+ import os
+ import shutil
+ import tempfile
+ import time
+ from pathlib import Path
+
+ import librosa
+ import torch
+ from huggingface_hub import snapshot_download
+
+ from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
+ from fam.llm.decoders import EncodecDecoder
+ from fam.llm.fast_inference_utils import build_model, main
+ from fam.llm.inference import (
+     EncodecDecoder,
+     InferenceConfig,
+     Model,
+     TiltedEncodec,
+     TrainedBPETokeniser,
+     get_cached_embedding,
+     get_cached_file,
+     get_enhancer,
+ )
+ from fam.llm.utils import (
+     check_audio_file,
+     get_default_dtype,
+     get_device,
+     normalize_text,
+ )
+ import argparse
+
+
+ class TTS:
+     def __init__(
+         self, model_name: str = "kotoba-tech/kotoba-speech-v0.1", *, seed: int = 1337, output_dir: str = "outputs", first_model_path: str = None,
+     ):
+         """
+         model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/kotoba-tech/)
+         """
+
+         # NOTE: this needs to come first so that we don't change global state when we want to use
+         # the torch.compiled model.
+         self._dtype = get_default_dtype()
+         self._device = get_device()
+         self._model_dir = snapshot_download(repo_id=model_name)
+         self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
+         self.output_dir = output_dir
+         os.makedirs(self.output_dir, exist_ok=True)
+
+         second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
+         config_second_stage = InferenceConfig(
+             ckpt_path=second_stage_ckpt_path,
+             num_samples=1,
+             seed=seed,
+             device=self._device,
+             dtype=self._dtype,
+             compile=False,
+             init_from="resume",
+             output_dir=self.output_dir,
+         )
+         data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
+         self.llm_second_stage = Model(
+             config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
+         )
+         self.enhancer = get_enhancer("df")
+
+         self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
+         self.model, self.tokenizer, self.smodel, self.model_size = build_model(
+             precision=self.precision,
+             checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"),
+             spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
+             device=self._device,
+             compile=True,
+             compile_prefill=True,
+             first_model_path=first_model_path,
+         )
+
+
+     def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
+         """
+         text: Text to speak
+         spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3
+         top_p: Top p for sampling applied to the first-stage model. Values in [0.9, 1.0] work well. This is a measure of speech stability - it improves text following for a challenging speaker
+         guidance_scale: Guidance scale [1.0, 3.0] for sampling. This is a measure of speaker similarity - how closely to match speaker identity and speech style.
+         temperature: Temperature for sampling applied to both LLMs (first & second stage)
+
+         returns: path to speech .wav file
+         """
+         text = normalize_text(text)
+         spk_ref_path = get_cached_file(spk_ref_path)
+         check_audio_file(spk_ref_path)
+         spk_emb = get_cached_embedding(
+             spk_ref_path,
+             self.smodel,
+         ).to(device=self._device, dtype=self.precision)
+
+         start = time.time()
+         # first stage LLM
+         tokens = main(
+             model=self.model,
+             tokenizer=self.tokenizer,
+             model_size=self.model_size,
+             prompt=text,
+             spk_emb=spk_emb,
+             top_p=torch.tensor(top_p, device=self._device, dtype=self.precision),
+             guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision),
+             temperature=torch.tensor(temperature, device=self._device, dtype=self.precision),
+         )
+         text_ids, extracted_audio_ids = self.first_stage_adapter.decode([tokens])
+
+         b_speaker_embs = spk_emb.unsqueeze(0)
+
+         # second stage LLM + multi-band diffusion model
+         wav_files = self.llm_second_stage(
+             texts=[text],
+             encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)],
+             speaker_embs=b_speaker_embs,
+             batch_size=1,
+             guidance_scale=None,
+             top_p=None,
+             top_k=200,
+             temperature=1.0,
+             max_new_tokens=None,
+         )
+
+         # enhance using deepfilternet
+         wav_file = wav_files[0]
+         with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
+             self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
+             shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
+             print(f"\nSaved audio to {wav_file}.wav")
+
+         # calculating real-time factor (RTF)
+         time_to_synth_s = time.time() - start
+         audio, sr = librosa.load(str(wav_file) + ".wav")
+         duration_s = librosa.get_duration(y=audio, sr=sr)
+         print(f"\nTotal time to synth (s): {time_to_synth_s}")
+         print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}")
+
+         return str(wav_file) + ".wav"
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Kotoba-Speech text-to-speech demo")
+     parser.add_argument("--in_file", default="/home/data/reazon_large-v2_denoise/large.jsonl", help="Name of the file")
+     parser.add_argument("--text", type=str, default="This is a demo for text-to-speech.", help="input text")
+     parser.add_argument("--spk_ref_path", type=str, default="assets/bria.mp3", help="speaker path")
+     parser.add_argument("--first_model_path", type=str, default=None, help="path to the first model")
+     args = parser.parse_args()
+     tts = TTS(first_model_path=args.first_model_path)
+     tts.synthesise(text=args.text, spk_ref_path=args.spk_ref_path)
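For reference, a minimal programmatic sketch (not part of the committed file) mirroring the __main__ block above. It assumes a CUDA GPU; the weights for kotoba-tech/kotoba-speech-v0.1 are fetched via snapshot_download on first use, and the reference clip path is the script's own default (at least 30 s of wav/flac/mp3 audio).

from fam.llm.fast_inference import TTS

tts = TTS()  # downloads and compiles the first-stage model, loads the second stage and enhancer
wav_path = tts.synthesise(
    text="This is a demo for text-to-speech.",
    spk_ref_path="assets/bria.mp3",  # speaker reference clip
)
print(wav_path)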
fam/llm/fast_inference_utils.py ADDED
@@ -0,0 +1,453 @@
+ # Copyright (c) Kotoba Technologies, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without modification, are permitted
+ # provided that the following conditions are met:
+ #
+ # 1. Redistributions of source code must retain the above copyright notice, this list of
+ # conditions and the following disclaimer.
+ #
+ # 2. Redistributions in binary form must reproduce the above copyright notice, this
+ # list of conditions and the following disclaimer in the documentation and/or other
+ # materials provided with the distribution.
+ #
+ # 3. Neither the name of the copyright holder nor the names of its contributors
+ # may be used to endorse or promote products derived from this software without
+ # specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
+ # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ import itertools
+ import gc
+ import time
+ from pathlib import Path
+ from typing import Optional, Tuple
+
+ import torch
+ import torch._dynamo.config
+ import torch._inductor.config
+ import tqdm
+
+
+ def device_sync(device):
+     if "cuda" in device:
+         torch.cuda.synchronize()
+     elif "cpu" in device:
+         pass
+     else:
+         print(f"device={device} is not yet supported")
+
+
+ torch._inductor.config.coordinate_descent_tuning = True
+ torch._inductor.config.triton.unique_kernel_names = True
+ torch._inductor.config.fx_graph_cache = (
+     True  # Experimental feature to reduce compilation times, will be on by default in future
+ )
+
+ # imports need to happen after setting above flags
+ from fam.llm.fast_model import Transformer
+ from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
+ from fam.quantiser.text.tokenise import TrainedBPETokeniser
+
+
+ def multinomial_sample_one_no_sync(
+     probs_sort,
+ ):  # Does multinomial sampling without a cuda synchronization
+     q = torch.empty_like(probs_sort).exponential_(1)
+     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+ def top_p_sample(logits: torch.Tensor, top_p: torch.Tensor):
+     # ref: huggingface/transformers
+
+     sorted_logits, sorted_indices = torch.sort(logits, descending=False)
+     cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+     # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+     sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+     # Keep at least min_tokens_to_keep
+     sorted_indices_to_remove[-1:] = 0
+
+     # scatter sorted tensors to original indexing
+     indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
+     scores = logits.masked_fill(indices_to_remove, -float("Inf"))
+     return scores
+
+
+ def logits_to_probs(
+     logits,
+     *,
+     temperature: torch.Tensor,
+     top_p: Optional[torch.Tensor] = None,
+     top_k: Optional[torch.Tensor] = None,
+ ):
+     logits = logits / torch.max(temperature, 1e-5 * torch.ones_like(temperature))
+
+     if top_k is not None:
+         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+         pivot = v.select(-1, -1).unsqueeze(-1)
+         logits = torch.where(logits < pivot, -float("Inf"), logits)
+
+     if top_p is not None:
+         logits = top_p_sample(logits, top_p)
+
+     probs = torch.nn.functional.softmax(logits, dim=-1)
+
+     return probs
+
+
+ def sample(
+     logits,
+     guidance_scale: torch.Tensor,
+     temperature: torch.Tensor,
+     top_p: Optional[torch.Tensor] = None,
+     top_k: Optional[torch.Tensor] = None,
+ ):
+     # (b, t, vocab_size)
+     logits = logits[:, -1]
+     logits_cond, logits_uncond_spkemb = logits.split(logits.size(0) // 2, dim=0)
+     logits = guidance_scale * logits_cond + (1 - guidance_scale) * logits_uncond_spkemb
+     probs = logits_to_probs(logits[0], temperature=temperature, top_p=top_p, top_k=top_k)
+     idx_next = multinomial_sample_one_no_sync(probs)
+     return idx_next, probs
+
+
+ def prefill(
+     model: Transformer,
+     x: torch.Tensor,
+     spk_emb: torch.Tensor,
+     input_pos: torch.Tensor,
+     **sampling_kwargs,
+ ) -> torch.Tensor:
+     # input_pos: [B, S]
+     logits = model(x, spk_emb, input_pos)
+     return sample(logits, **sampling_kwargs)[0]
+
+
+ def decode_one_token(
+     model: Transformer,
+     x: torch.Tensor,
+     spk_emb: torch.Tensor,
+     input_pos: torch.Tensor,
+     **sampling_kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     # input_pos: [B, 1]
+     assert input_pos.shape[-1] == 1
+     logits = model(x, spk_emb, input_pos)
+     return sample(logits, **sampling_kwargs)
+
+
+ def decode_n_tokens(
+     model: Transformer,
+     cur_token: torch.Tensor,
+     spk_emb: torch.Tensor,
+     input_pos: torch.Tensor,
+     num_new_tokens: int,
+     callback=lambda _: _,
+     return_probs: bool = False,
+     end_of_audio_token: int = 2048,
+     **sampling_kwargs,
+ ):
+     new_tokens, new_probs = [], []
+     for i in tqdm.tqdm(range(num_new_tokens)):
+         if (cur_token == end_of_audio_token).any():
+             break
+         with torch.backends.cuda.sdp_kernel(
+             enable_flash=False, enable_mem_efficient=False, enable_math=True
+         ):  # Actually better for Inductor to codegen attention here
+             next_token, next_prob = decode_one_token(model, cur_token, spk_emb, input_pos, **sampling_kwargs)
+             input_pos += 1
+             new_tokens.append(next_token.clone())
+             callback(new_tokens[-1])
+             if return_probs:
+                 new_probs.append(next_prob.clone())
+             cur_token = next_token.view(1, -1).repeat(2, 1)
+
+     return new_tokens, new_probs
+
+
+ def model_forward(model, x, spk_emb, input_pos):
+     return model(x, spk_emb, input_pos)
+
+
+ @torch.no_grad()
+ def generate(
+     model: Transformer,
+     prompt: torch.Tensor,
+     spk_emb: torch.Tensor,
+     *,
+     max_new_tokens: Optional[int] = None,
+     callback=lambda x: x,
+     end_of_audio_token: int = 2048,
+     **sampling_kwargs,
+ ) -> torch.Tensor:
+     """
+     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+     """
+     # create an empty tensor of the expected final shape and fill in the current tokens
+     T = prompt.size(0)
+     if max_new_tokens is None:
+         max_seq_length = model.config.block_size
+     else:
+         max_seq_length = T + max_new_tokens
+     max_seq_length = min(max_seq_length, model.config.block_size)
+     max_new_tokens = max_seq_length - T
+     if max_new_tokens <= 0:
+         raise ValueError("Prompt is too long to generate more tokens")
+
+     device, dtype = prompt.device, prompt.dtype
+
+     seq = torch.clone(prompt)
+     input_pos = torch.arange(0, T, device=device)
+
+     next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs)
+     seq = torch.cat([seq, next_token.view(1)])
+
+     input_pos = torch.tensor([T], device=device, dtype=torch.int)
+
+     generated_tokens, _ = decode_n_tokens(
+         model,
+         next_token.view(1, -1).repeat(2, 1),
+         spk_emb,
+         input_pos,
+         max_new_tokens - 1,
+         callback=callback,
+         end_of_audio_token=end_of_audio_token,
+         **sampling_kwargs,
+     )
+     seq = torch.cat([seq, torch.cat(generated_tokens)])
+
+     return seq
+
+
+ def encode_tokens(tokenizer, string, device="cuda"):
+     tokens = tokenizer.encode(string)
+     return torch.tensor(tokens, dtype=torch.int, device=device)
+
+
+ def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision, first_model_path=None, unwanted_prefix="_orig_mod."):
+     ##### MODEL
+     with torch.device("meta"):
+         model = Transformer.from_name("kotoba-speech-v0.1")
+
+     # TODO(quantization): enable
+     # if "int8" in str(checkpoint_path):
+     # print("Using int8 weight-only quantization!")
+     # from quantize import WeightOnlyInt8QuantHandler
+     # simple_quantizer = WeightOnlyInt8QuantHandler(model)
+     # model = simple_quantizer.convert_for_runtime()
+     # from quantize import WeightOnlyInt8QuantHandler
+
+     # if "int4" in str(checkpoint_path):
+     # print("Using int4 quantization!")
+     # path_comps = checkpoint_path.name.split(".")
+     # assert path_comps[-2].startswith("g")
+     # groupsize = int(path_comps[-2][1:])
+     # from quantize import WeightOnlyInt4QuantHandler
+     # simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+     # model = simple_quantizer.convert_for_runtime()
+
+     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
+
+     ###### TOKENIZER
+     tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {})
+     tokenizer = TrainedBPETokeniser(**tokenizer_info)
+
+     if first_model_path is not None:
+         trained_ckpt = torch.load(str(first_model_path), mmap=True, weights_only=False)
+         state_dict = trained_ckpt["state_dict"]
+         del checkpoint
+         gc.collect()
+         torch.cuda.empty_cache()
+     else:
+         checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
+         if "state_dict" in checkpoint.keys():
+             state_dict = checkpoint["state_dict"]
+         else:
+             state_dict = checkpoint["model"]
+     # convert Kotoba-Speech model weights naming to gptfast naming
+     for k, v in list(state_dict.items()):
+         if k.startswith(unwanted_prefix):
+             state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
+     state_dict["tok_embeddings.weight"] = state_dict.pop("transformer.wtes.0.weight")
+     state_dict["pos_embeddings.weight"] = state_dict.pop("transformer.wpe.weight")
+     state_dict["output.weight"] = state_dict.pop("lm_heads.0.weight")
+     state_dict["norm.weight"] = state_dict.pop("transformer.ln_f.weight")
+     for k, v in list(state_dict.items()):
+         if k.startswith("transformer.h."):
+             state_dict[k.replace("transformer.h.", "layers.")] = state_dict.pop(k)
+             k = k.replace("transformer.h.", "layers.")
+         if ".attn.c_attn." in k:
+             state_dict[k.replace(".attn.c_attn.", ".attention.wqkv.")] = state_dict.pop(k)
+             k = k.replace(".attn.c_attn.", ".attention.wqkv.")
+         if ".attn.c_proj." in k:
+             state_dict[k.replace(".attn.c_proj.", ".attention.wo.")] = state_dict.pop(k)
+             k = k.replace(".attn.c_proj.", ".attention.wo.")
+         if ".mlp.swiglu.w1." in k:
+             state_dict[k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")] = state_dict.pop(k)
+             k = k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")
+         if ".mlp.swiglu.w3." in k:
+             state_dict[k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")] = state_dict.pop(k)
+             k = k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")
+         if ".ln_1." in k:
+             state_dict[k.replace(".ln_1.", ".attention_norm.")] = state_dict.pop(k)
+             k = k.replace(".ln_1.", ".attention_norm.")
+         if ".ln_2." in k:
+             state_dict[k.replace(".ln_2.", ".ffn_norm.")] = state_dict.pop(k)
+             k = k.replace(".ln_2.", ".ffn_norm.")
+         if ".mlp.c_proj." in k:
+             state_dict[k.replace(".mlp.c_proj.", ".feed_forward.w2.")] = state_dict.pop(k)
+             k = k.replace(".mlp.c_proj.", ".feed_forward.w2.")
+
+     model.load_state_dict(state_dict, assign=True)
+     # simple_quantizer = WeightOnlyInt8QuantHandler(model)
+     # quantized_state_dict = simple_quantizer.create_quantized_state_dict()
+     # model = simple_quantizer.convert_for_runtime()
+     # model.load_state_dict(quantized_state_dict, assign=True)
+     model = model.to(device=device, dtype=precision)
+
+     ###### SPEAKER EMBEDDER
+     # TODO: fix!
+     smodel = SpeakerEncoder(
+         weights_fpath=spk_emb_ckpt_path,
+         device=device,
+         eval=True,
+         verbose=False,
+     )
+     return model.eval(), tokenizer, smodel
+
+
+ def build_model(
+     *,
+     precision: torch.dtype,
+     checkpoint_path: Path = Path(""),
+     spk_emb_ckpt_path: Path = Path(""),
+     compile_prefill: bool = False,
+     compile: bool = True,
+     device: str = "cuda",
+     first_model_path: str = None,
+ ):
+     assert checkpoint_path.is_file(), checkpoint_path
+
+     print(f"Using device={device}")
+
+     print("Loading model ...")
+     t0 = time.time()
+     if first_model_path is None:
+         # model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision)
+         model, tokenizer, smodel = _load_model(
+             checkpoint_path, spk_emb_ckpt_path, device, precision, unwanted_prefix="first_stage_model_transformer."
+         )
+
+     else:
+         model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision, first_model_path, unwanted_prefix="first_stage_model_transformer.")
+
+
+     device_sync(device=device)  # MKG
+     print(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+     torch.manual_seed(1234)
+     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
+
+     with torch.device(device):
+         model.setup_spk_cond_mask()
+         model.setup_caches(max_batch_size=2, max_seq_length=model.config.block_size)
+
+     if compile:
+         print("Compiling...Can take up to 2 mins.")
+         global decode_one_token, prefill
+         decode_one_token = torch.compile(
+             decode_one_token,
+             mode="max-autotune",
+             fullgraph=True,
+         )
+
+         if compile_prefill:
+             prefill = torch.compile(
+                 prefill,
+                 fullgraph=True,
+                 dynamic=True,
+             )
+
+     encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device)
+     spk_emb = torch.randn((1, 256), device=device, dtype=precision)
+
+     device_sync(device=device)  # MKG
+     t0 = time.perf_counter()
+     y = generate(
+         model,
+         encoded,
+         spk_emb,
+         max_new_tokens=200,
+         callback=lambda x: x,
+         temperature=torch.tensor(1.0, device=device, dtype=precision),
+         top_k=None,
+         top_p=torch.tensor(0.95, device=device, dtype=precision),
+         guidance_scale=torch.tensor(3.0, device=device, dtype=precision),
+         end_of_audio_token=9999,  # don't end early for compilation stage.
+     )
+
+     device_sync(device=device)  # MKG
+
+     print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+
+     return model, tokenizer, smodel, model_size
+
+
+ def main(
+     *,
+     model,
+     tokenizer,
+     model_size,
+     prompt: str,
+     guidance_scale: torch.Tensor,
+     temperature: torch.Tensor,
+     spk_emb: torch.Tensor,
+     top_k: Optional[torch.Tensor] = None,
+     top_p: Optional[torch.Tensor] = None,
+     device: str = "cuda",
+ ) -> list:
+     """Generates text samples based on a pre-trained Transformer model and tokenizer."""
+
+     encoded = encode_tokens(tokenizer, prompt, device=device)
+     prompt_length = encoded.size(0)
+
+     aggregate_metrics: dict = {
+         "tokens_per_sec": [],
+     }
+
+     device_sync(device=device)  # MKG
+
+     if True:
+         callback = lambda x: x
+         t0 = time.perf_counter()
+
+         y = generate(
+             model,
+             encoded,
+             spk_emb,
+             callback=callback,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             guidance_scale=guidance_scale,
+         )
+
+         device_sync(device=device)  # MKG
+         t = time.perf_counter() - t0
+
+         tokens_generated = y.size(0) - prompt_length
+         tokens_sec = tokens_generated / t
+         aggregate_metrics["tokens_per_sec"].append(tokens_sec)
+         print(f"Time for 1st stage LLM inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
+         print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+         # print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
+         print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB\n")
+
+     return y.tolist()
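As a side note, sample() above applies classifier-free guidance over a batched (speaker-conditional, speaker-unconditional) pair of logits. A tiny self-contained sketch (not part of the committed file, with made-up numbers) of just that mixing step:

import torch

guidance_scale = torch.tensor(3.0)
logits_cond = torch.tensor([[2.0, 0.5, -1.0]])   # speaker-conditioned logits
logits_uncond = torch.tensor([[1.0, 1.0, 1.0]])  # speaker-unconditional logits

# Same blend as in sample(): scale * cond + (1 - scale) * uncond
mixed = guidance_scale * logits_cond + (1 - guidance_scale) * logits_uncond
print(mixed)  # tensor([[ 4.0000, -0.5000, -5.0000]])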
fam/llm/fast_model.py ADDED
@@ -0,0 +1,261 @@
+ # Copyright (c) Kotoba Technologies, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without modification, are permitted
+ # provided that the following conditions are met:
+ #
+ # 1. Redistributions of source code must retain the above copyright notice, this list of
+ # conditions and the following disclaimer.
+ #
+ # 2. Redistributions in binary form must reproduce the above copyright notice, this
+ # list of conditions and the following disclaimer in the documentation and/or other
+ # materials provided with the distribution.
+ #
+ # 3. Neither the name of the copyright holder nor the names of its contributors
+ # may be used to endorse or promote products derived from this software without
+ # specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
+ # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ from dataclasses import dataclass
+ from functools import reduce
+ from math import gcd
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from torch.nn import functional as F
+
+ from fam.llm.utils import get_default_dtype
+
+ import logging
+
+ # Adjust the logging level
+ logger = logging.getLogger("torch")
+ logger.setLevel(logging.ERROR)
+
+
+ def find_multiple(n: int, *args: Tuple[int]) -> int:
+     k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
+     if n % k == 0:
+         return n
+     return n + k - (n % k)
+
+
+ @dataclass
+ class ModelArgs:
+     block_size: int = 2048
+     vocab_size: int = 32000
+     n_layer: int = 32
+     n_head: int = 32
+     dim: int = 4096
+     speaker_emb_dim: int = 256
+     intermediate_size: int = None
+     n_local_heads: int = -1
+     head_dim: int = 64
+     norm_eps: float = 1e-5
+     dtype: torch.dtype = torch.bfloat16
+
+     def __post_init__(self):
+         if self.n_local_heads == -1:
+             self.n_local_heads = self.n_head
+         if self.intermediate_size is None:
+             hidden_dim = 4 * self.dim
+             n_hidden = int(2 * hidden_dim / 3)
+             self.intermediate_size = find_multiple(n_hidden, 256)
+         self.head_dim = self.dim // self.n_head
+
+         self.dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[get_default_dtype()]
+
+     @classmethod
+     def from_name(cls, name: str):
+         if name in transformer_configs:
+             return cls(**transformer_configs[name])
+         # fuzzy search
+         config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
+         assert len(config) == 1, name
+         return cls(**transformer_configs[config[0]])
+
+
+ transformer_configs = {
+     "kotoba-speech-v0.1": dict(
+         n_layer=24,
+         n_head=16,
+         dim=2048,
+         vocab_size=2562,
+     ),
+ }
+
+
+ class KVCache(nn.Module):
+     def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
+         super().__init__()
+         cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+         self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+         self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+     def update(self, input_pos, k_val, v_val):
+         # input_pos: [S], k_val: [B, H, S, D]
+         assert input_pos.shape[0] == k_val.shape[2]
+
+         k_out = self.k_cache
+         v_out = self.v_cache
+         k_out[:, :, input_pos] = k_val
+         v_out[:, :, input_pos] = v_val
+
+         return k_out, v_out
+
+
+ class Transformer(nn.Module):
+     def __init__(self, config: ModelArgs) -> None:
+         super().__init__()
+         self.config = config
+
+         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+         self.pos_embeddings = nn.Embedding(config.block_size, config.dim)
+         self.speaker_cond_pos = nn.Linear(config.speaker_emb_dim, config.dim, bias=False)
+         self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+         self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+
+         self.mask_cache: Optional[Tensor] = None
+         self.max_batch_size = -1
+         self.max_seq_length = -1
+
+     def setup_spk_cond_mask(self):
+         self.spk_cond_mask = torch.zeros((2, 1, self.config.dim), dtype=torch.bool)
+         self.spk_cond_mask[0] = 1
+
+     def setup_caches(self, max_batch_size, max_seq_length):
+         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+             return
+         head_dim = self.config.dim // self.config.n_head
+         max_seq_length = find_multiple(max_seq_length, 8)
+         self.max_seq_length = max_seq_length
+         self.max_batch_size = max_batch_size
+         for b in self.layers:
+             b.attention.kv_cache = KVCache(
+                 max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=self.config.dtype
+             )
+
+         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
+
+     def forward(self, idx: Tensor, spk_emb: Tensor, input_pos: Tensor) -> Tensor:
+         mask = self.causal_mask[None, None, input_pos]
+         x = (
+             self.tok_embeddings(idx)
+             + self.pos_embeddings(input_pos)
+             # masking for speaker condition free guidance
+             + self.speaker_cond_pos(spk_emb) * self.spk_cond_mask
+         )
+
+         for i, layer in enumerate(self.layers):
+             x = layer(x, input_pos, mask)
+         x = self.norm(x)
+         logits = self.output(x)
+         return logits
+
+     @classmethod
+     def from_name(cls, name: str):
+         return cls(ModelArgs.from_name(name))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: ModelArgs) -> None:
+         super().__init__()
+         self.attention = Attention(config)
+         self.feed_forward = FeedForward(config)
+         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+     def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
+         h = x + self.attention(self.attention_norm(x), mask, input_pos)
+         out = h + self.feed_forward(self.ffn_norm(h))
+         return out
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: ModelArgs):
+         super().__init__()
+         assert config.dim % config.n_head == 0
+
+         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+         # key, query, value projections for all heads, but in a batch
+         self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+         self.wo = nn.Linear(config.dim, config.dim, bias=False)
+         self.kv_cache = None
+
+         self.n_head = config.n_head
+         self.head_dim = config.head_dim
+         self.n_local_heads = config.n_local_heads
+         self.dim = config.dim
+
+     def forward(
+         self,
+         x: Tensor,
+         mask: Tensor,
+         input_pos: Optional[Tensor] = None,
+     ) -> Tensor:
+         bsz, seqlen, _ = x.shape
+
+         kv_size = self.n_local_heads * self.head_dim
+         q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+         if self.kv_cache is not None:
+             k, v = self.kv_cache.update(input_pos, k, v)
+
+         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+         y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+         y = self.wo(y)
+         return y
+
+
+ class SwiGLU(nn.Module):
+     def __init__(self, config: ModelArgs) -> None:
+         super().__init__()
+         self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+         self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return F.silu(self.w1(x)) * self.w3(x)
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, config: ModelArgs) -> None:
+         super().__init__()
+         self.swiglu = SwiGLU(config)
+         self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.w2(self.swiglu(x))
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+     def forward(self, x: Tensor) -> Tensor:
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import hashlib
3
+ import json
4
+ import os
5
+ import pathlib
6
+ import shutil
7
+ import subprocess
8
+ import tempfile
9
+ import time
10
+ from contextlib import nullcontext
11
+ from dataclasses import dataclass
12
+ from typing import List, Literal, Optional, Tuple, Type, Union
13
+
14
+ import torch
15
+ import tqdm
16
+ import tqdm.contrib.concurrent
17
+ import tyro
18
+ from huggingface_hub import snapshot_download
19
+
20
+ from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook, TiltedEncodec
21
+ from fam.llm.decoders import Decoder, EncodecDecoder
22
+ from fam.llm.enhancers import BaseEnhancer, get_enhancer
23
+ from fam.llm.model import GPT, GPTConfig
24
+ from fam.llm.utils import check_audio_file, get_default_dtype, normalize_text
25
+ from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
26
+ from fam.quantiser.text.tokenise import TrainedBPETokeniser
27
+
28
+
29
+ @dataclass
30
+ class InferenceConfig:
31
+ ckpt_path: str # path to checkpoint
32
+ output_dir: str
33
+ num_samples: int = 10 # number of samples to draw
34
+ seed: int = 1337 # random seed
35
+ device: str = "cuda"
36
+ dtype: str = "bfloat16"
37
+ compile: bool = False
38
+ init_from: str = "resume" # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
39
+
40
+ def __str__(self):
41
+ field_strs = []
42
+ for field in dataclasses.fields(self):
43
+ value = getattr(self, field.name)
44
+ field_strs.append(f" {field.name}: {value}")
45
+
46
+ return "InferenceConfig:\n" + "\n".join(field_strs)
47
+
48
+
49
+ class Model:
50
+ def __init__(
51
+ self,
52
+ config: InferenceConfig,
53
+ tokenizer_cls: Type[TrainedBPETokeniser],
54
+ decoder_cls: Type[Decoder],
55
+ data_adapter_fn,
56
+ use_kv_cache: Optional[Literal["vanilla"]] = None,
57
+ ):
58
+ # TODO: disentangle the encodec stuff and numbers etc with rest of this code (esp at encoder-only / second stage model inference)
59
+ # TODO: remove magic number
60
+ self._encodec_codes_pad_token = 1024
61
+ self._num_encodec_codebooks = 8
62
+ self.config = config
63
+ self.use_kv_cache = use_kv_cache
64
+
65
+ torch.manual_seed(config.seed)
66
+ torch.cuda.manual_seed(config.seed)
67
+ torch.backends.cuda.matmul.allow_tf32 = True if config.dtype != "float32" else False # allow tf32 on matmul
68
+ torch.backends.cudnn.allow_tf32 = True if config.dtype != "float32" else False # allow tf32 on cudnn
69
+ device_type = "cuda" if "cuda" in config.device else "cpu" # for later use in torch.autocast
70
+ self.ptdtype = {
71
+ "float32": torch.float32,
72
+ "tfloat32": torch.float32,
73
+ "bfloat16": torch.bfloat16,
74
+ "float16": torch.float16,
75
+ }[config.dtype]
76
+ self._ctx = (
77
+ nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=self.ptdtype)
78
+ )
79
+
80
+ self.use_bpe_tokenizer = False
81
+ self.load_meta = None
82
+ self.speaker_cond = None
83
+ self.meta = None
84
+ self.model = None
85
+ self.checkpoint_config = None
86
+ self.vocab_sizes = None
87
+ self.smodel = None
88
+
89
+ self._init_model()
90
+
91
+ self.tokenizer = tokenizer_cls(**self.meta["tokenizer"])
92
+ self.decoder = decoder_cls(
93
+ tokeniser_decode_fn=self.tokenizer.decode,
94
+ output_dir=self.config.output_dir,
95
+ data_adapter_fn=data_adapter_fn,
96
+ )
97
+
98
+ def _init_model(self):
99
+ if self.config.init_from == "resume":
100
+ # init from a model saved in a specific directory
101
+ checkpoint = torch.load(self.config.ckpt_path, map_location=self.config.device)
102
+ self.vocab_sizes = checkpoint["model_args"]["vocab_sizes"]
103
+
104
+ self.load_meta = False
105
+ self.speaker_cond = False
106
+
107
+ if "config" in checkpoint:
108
+ self.checkpoint_config = checkpoint["config"]
109
+
110
+ self.meta = checkpoint["meta"]
111
+ load_meta = True
112
+
113
+ if load_meta:
114
+ self.use_bpe_tokenizer = "stoi" not in self.meta or "itos" not in self.meta
115
+ self.speaker_cond = self.meta.get("speaker_cond")
116
+
117
+ if self.speaker_cond:
118
+ speaker_emb_size = self.meta["speaker_emb_size"]
119
+
120
+ model_args = checkpoint["model_args"]
121
+ if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
122
+ self._encodec_ctx_window = model_args["block_size"]
123
+
124
+ gptconf = GPTConfig(**model_args)
125
+
126
+ # TODO: rename `speaker_emb_dim` to `speaker_emb_size`.
127
+ self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size if self.speaker_cond else None)
128
+ state_dict = checkpoint["model"]
129
+ unwanted_prefix = "_orig_mod."
130
+ for k, v in list(state_dict.items()):
131
+ if k.startswith(unwanted_prefix):
132
+ state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
133
+ self.model.load_state_dict(state_dict)
134
+
135
+ # model
136
+ self.model.eval()
137
+ self.model.to(self.config.device)
138
+
139
+ if self.config.compile:
140
+ from einops._torch_specific import allow_ops_in_compiled_graph
141
+
142
+ allow_ops_in_compiled_graph()
143
+ self.model = torch.compile(self.model) # type: ignore
144
+
145
+ if self.use_kv_cache is not None:
146
+ if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
147
+ raise Exception("kv_cache not supported for non-causal models!")
148
+
149
+ if self.use_kv_cache == "vanilla":
150
+ self.model.enable_kv_cache()
151
+ else:
152
+ raise NotImplementedError(f"kv_cache type {self.use_kv_cache} not implemented!")
153
+
154
+ def causal_sample(
155
+ self,
156
+ *,
157
+ texts: list[str],
158
+ batch_size: int,
159
+ max_new_tokens: int,
160
+ temperature: Optional[float],
161
+ top_k: Optional[int],
162
+ top_p: Optional[float],
163
+ speaker_embs: Optional[torch.Tensor] = None,
164
+ guidance_scale: Optional[float] = None,
165
+ ) -> list[torch.Tensor]:
166
+ """
167
+ Returns list of torch.Tensors of tokens. Each tensor is of shape (1, c, t) where c is the number of codebooks.
168
+ Any flattening / interleaving / tilting gets reversed before the output is returned.
169
+ """
170
+ if speaker_embs is not None:
171
+ assert len(texts) == len(speaker_embs)
172
+
173
+ encoded_texts = [self.tokenizer.encode(text) for text in texts]
174
+
175
+ ## create multiple hierarchies and get seq_lens
176
+ seq_lens = []
177
+ xs = []
178
+ for i, encoded_text in enumerate(encoded_texts):
179
+ encoded_text = torch.tensor([encoded_text], dtype=torch.long, device=self.config.device)
180
+ # TODO: remove magic number
181
+ xs.append(
182
+ torch.cat(
183
+ # [1st hierarchy of text, *remaining hierarchies of padded tokens]
184
+ # TODO: self.vocab_sizes should be from the model config?
185
+ [encoded_text, *[torch.ones_like(encoded_text) * 1024] * (len(self.vocab_sizes) - 1)],
186
+ dim=0,
187
+ ).unsqueeze(0)
188
+ ) # b x [(b=1, c, t)]
189
+ seq_lens.append(xs[-1].shape[-1])
190
+ max_len = max(seq_lens)
191
+ assert len(xs) == len(seq_lens)
192
+
193
+ ## equalise the shapes in the batch. we can use torch.zeros as tokens > seq_lens will be masked out.
194
+ x = torch.zeros((len(encoded_texts), xs[0].shape[1], max_len), dtype=torch.long, device=self.config.device)
195
+ for i, _xs in enumerate(xs):
196
+ assert _xs.shape[-1] == seq_lens[i]
197
+ x[i, :, : seq_lens[i]] = _xs
198
+
199
+ ## check that the input is correct
200
+ for i in range(x.shape[0]):
201
+ assert x[i, 0, : seq_lens[i]].tolist() == encoded_texts[i]
202
+
203
+ # TODO: remove magic number
204
+ if x.shape[1] > 1:
205
+ assert set(x[i, 1, : seq_lens[i]].tolist()) == set([1024])
206
+
207
+ assert speaker_embs is None or x.shape[0] == speaker_embs.shape[0]
208
+
209
+ if self.speaker_cond is False:
210
+ speaker_embs = None
211
+
212
+ # run sampling loop
213
+ with torch.no_grad():
214
+ with self._ctx: # type: ignore
215
+ to_return = []
216
+ for k in range(self.config.num_samples):
217
+ assert seq_lens is not None
218
+ assert batch_size is not None
219
+
220
+ if max(seq_lens) + max_new_tokens >= self.model.config.block_size:
221
+ raise Exception(
222
+ f"max_new_tokens {max_new_tokens} too large! Choose {self.model.config.block_size - max(seq_lens) - 1} instead."
223
+ )
224
+
225
+ y = self.model.generate(
226
+ x,
227
+ max_new_tokens,
228
+ seq_lens=seq_lens,
229
+ temperature=temperature,
230
+ top_k=top_k,
231
+ top_p=top_p,
232
+ speaker_embs=speaker_embs,
233
+ batch_size=batch_size,
234
+ guidance_scale=guidance_scale,
235
+ dtype=self.ptdtype,
236
+ end_of_audio_token=self.tokenizer.offset - 1,
237
+ end_of_text_token=self.tokenizer.eot_token,
238
+ )
239
+ for i in range(len(y)):
240
+ to_return.append(self.decoder.decode(tokens=y[i].tolist(), causal=True))
241
+
242
+ return to_return
243
+
244
+ def non_causal_sample(
245
+ self,
246
+ *,
247
+ texts: list[str],
248
+ encodec_tokens: list[torch.Tensor],
249
+ batch_size: int,
250
+ top_k: Optional[int],
251
+ temperature: Optional[float],
252
+ speaker_embs: Optional[torch.Tensor] = None,
253
+ ) -> list[str]:
254
+ """
255
+ Returns paths to saved audio files.
256
+ """
257
+ if speaker_embs is not None:
258
+ assert len(texts) == len(speaker_embs)
259
+
260
+ encoded_texts = [self.tokenizer.encode(text) for text in texts]
261
+
262
+ # setup input
263
+ # TODO: same code is used during data prep. refactor
264
+ padded_hierarchies_inputs = []
265
+ for encoded_text, encodec_token in zip(encoded_texts, encodec_tokens):
266
+ x = torch.tensor(encoded_text, dtype=torch.long, device=self.config.device)[
267
+ None, None, ...
268
+ ] # (b=1, c=1, t)
269
+
270
+ # TODO: should only happen if decoder is EncodecDecoder?
271
+ assert encodec_token.shape[0] == 1
272
+ encodec_token = encodec_token[0].tolist() # (b=1, c, t) -> (c, t)
273
+ assert len(encodec_token) >= 1 and len(encodec_token) <= self._num_encodec_codebooks
274
+
275
+ ## setup hierarchies of tokens
276
+ # TODO: refactor and merge with code in processing.py
277
+ text_tokens = encoded_text # (t,)
278
+
279
+ hierarchies_in = []
280
+ hierarchies_in.append(text_tokens + encodec_token[0] + [self._encodec_codes_pad_token])
281
+ hierarchies_in.append(
282
+ [self._encodec_codes_pad_token] * len(text_tokens) + encodec_token[1] + [self._encodec_codes_pad_token]
283
+ )
284
+
285
+ ## adding padding / cutting to the right size as needed
286
+ # TODO: refactor and merge with code in processing.py
287
+ padded_hierarchies_input = []
288
+ for _, t_hierarchy in enumerate(hierarchies_in):
289
+ assert len(t_hierarchy) == len(hierarchies_in[0])
290
+ if len(t_hierarchy) < self._encodec_ctx_window:
291
+ padded_hierarchies_input.append(
292
+ t_hierarchy + [self._encodec_codes_pad_token] * (self._encodec_ctx_window - len(t_hierarchy))
293
+ )
294
+ elif len(t_hierarchy) > self._encodec_ctx_window:
295
+ padded_hierarchies_input.append(t_hierarchy[: self._encodec_ctx_window])
296
+ else:
297
+ padded_hierarchies_input.append(t_hierarchy)
298
+
299
+ padded_hierarchies_inputs.append(padded_hierarchies_input)
300
+
301
+ ## check that the input is correct
302
+ in_x = torch.tensor(padded_hierarchies_inputs, dtype=torch.long, device=self.config.device)
303
+ assert speaker_embs is None or in_x.shape[0] == speaker_embs.shape[0]
304
+
305
+ if self.speaker_cond is False:
306
+ speaker_embs = None
307
+
308
+ # run sampling loop
309
+ with torch.no_grad():
310
+ with self._ctx: # type: ignore
311
+ to_return = []
312
+ for k in range(self.config.num_samples):
313
+ y = self.model.generate(
314
+ in_x,
315
+ None,
316
+ temperature=temperature,
317
+ top_k=top_k,
318
+ # TODO: handle separate top_p for this model explicitly
319
+ top_p=None,
320
+ speaker_embs=speaker_embs,
321
+ batch_size=batch_size,
322
+ guidance_scale=None,
323
+ )
324
+
325
+ b_tokens = torch.cat([in_x, y], dim=1)
326
+ for tokens in b_tokens:
327
+ try:
328
+ to_return.append(self.decoder.decode(tokens=tokens.tolist(), causal=False))
329
+ except Exception as e:
330
+ print("failed to run MBD.")
331
+ print(f"reason: {str(e)}")
332
+ to_return.append(None)
333
+
334
+ return to_return
335
+
336
+ def __call__(
337
+ self,
338
+ *,
339
+ texts: list[str],
340
+ batch_size: int,
341
+ max_new_tokens: Optional[int],
342
+ top_k: Optional[int],
343
+ top_p: Optional[float],
344
+ temperature: Optional[float],
345
+ encodec_tokens: Optional[list[torch.Tensor]] = None,
346
+ speaker_embs: Optional[torch.Tensor] = None,
347
+ guidance_scale: Optional[float] = None,
348
+ ):
349
+ if self.checkpoint_config.get("causal", True):
350
+ return self.causal_sample(
351
+ texts=texts,
352
+ batch_size=batch_size,
353
+ speaker_embs=speaker_embs,
354
+ guidance_scale=guidance_scale,
355
+ max_new_tokens=max_new_tokens,
356
+ top_k=top_k,
357
+ top_p=top_p,
358
+ temperature=temperature,
359
+ )
360
+ else:
361
+ assert encodec_tokens is not None
362
+ assert guidance_scale is None
363
+ assert max_new_tokens is None
364
+ assert top_p is None
365
+
366
+ return self.non_causal_sample(
367
+ texts=texts,
368
+ encodec_tokens=encodec_tokens,
369
+ batch_size=batch_size,
370
+ speaker_embs=speaker_embs,
371
+ top_k=top_k,
372
+ temperature=temperature,
373
+ )
374
+
375
+
376
+ def save_result_metadata(wav_path, ref_path, text, first_stage_ckpt_path, second_stage_ckpt_path):
377
+ if first_stage_ckpt_path is None or second_stage_ckpt_path is None:
378
+ return
379
+ json.dump(
380
+ {
381
+ "speaker": ref_path,
382
+ "text": text,
383
+ },
384
+ pathlib.Path(str(wav_path) + ".json").open("w"),
385
+ )
386
+
387
+
388
+ def get_cached_file(file_or_uri: str):
389
+ """
390
+ If it's a remote URI (http/https), download it to a local cache file and return that path.
391
+ Otherwise return the path as is.
392
+ """
393
+ is_uri = file_or_uri.startswith("http")
394
+
395
+ cache_path = None
396
+ if is_uri:
397
+ ext = pathlib.Path(file_or_uri).suffix
398
+ # hash the file path to get the cache name
399
+ _cache_name = "audio_" + hashlib.md5(file_or_uri.encode("utf-8")).hexdigest() + ext
400
+
401
+ os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
402
+ cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")
403
+
404
+ if not os.path.exists(cache_path):
405
+ command = f"curl -o {cache_path} {file_or_uri}"
406
+ subprocess.run(command, shell=True, check=True)
407
+ else:
408
+ if os.path.exists(file_or_uri):
409
+ cache_path = file_or_uri
410
+ else:
411
+ raise FileNotFoundError(f"File {file_or_uri} not found!")
412
+ return cache_path
413
+
414
+
415
+ def get_cached_embedding(local_file_path: str, spkemb_model):
416
+ if not os.path.exists(local_file_path):
417
+ raise FileNotFoundError(f"File {local_file_path} not found!")
418
+
419
+ # hash the file path to get the cache name
420
+ _cache_name = "embedding_" + hashlib.md5(local_file_path.encode("utf-8")).hexdigest() + ".pt"
421
+
422
+ os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
423
+ cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")
424
+
425
+ if not os.path.exists(cache_path):
426
+ spk_emb = spkemb_model.embed_utterance_from_file(local_file_path, numpy=False).unsqueeze(0) # (b=1, c)
427
+ torch.save(spk_emb, cache_path)
428
+ else:
429
+ spk_emb = torch.load(cache_path)
430
+
431
+ return spk_emb
432
+
433
+
434
+ def _sample_utterance_batch(
435
+ texts: list[str],
436
+ spk_cond_paths: list[Optional[str]],
437
+ spkemb_model,
438
+ first_stage_model,
439
+ second_stage_model,
440
+ enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
441
+ first_stage_ckpt_path: str,
442
+ second_stage_ckpt_path: str,
443
+ guidance_scale: Optional[Tuple[float, float]],
444
+ max_new_tokens: int,
445
+ top_k: Optional[int],
446
+ top_p: Optional[float],
447
+ temperature: Optional[float],
448
+ batch_size: int = 128,
449
+ ) -> List[str]:
450
+
451
+ speaker_embs = []
452
+ refs = spk_cond_paths.copy()
453
+
454
+ # multithreaded loop to cache all the files
455
+ spk_cond_paths = tqdm.contrib.concurrent.thread_map(
456
+ get_cached_file, spk_cond_paths, desc="getting cached speaker ref files"
457
+ )
458
+
459
+ for i, (text, spk_cond_path) in tqdm.tqdm(
460
+ enumerate(zip(texts, spk_cond_paths)), total=len(texts), desc="calculating speaker embeddings"
461
+ ):
462
+ texts[i] = normalize_text(text)
463
+ speaker_embs.append(get_cached_embedding(spk_cond_path, spkemb_model) if spk_cond_path else None)
464
+
465
+ b_speaker_embs = torch.cat(speaker_embs, dim=0)
466
+
467
+ start = time.time()
468
+ b_tokens = first_stage_model(
469
+ texts=texts,
470
+ speaker_embs=b_speaker_embs,
471
+ batch_size=batch_size,
472
+ guidance_scale=guidance_scale,
473
+ top_p=top_p,
474
+ top_k=top_k,
475
+ temperature=temperature,
476
+ max_new_tokens=max_new_tokens,
477
+ )
478
+
479
+ # TODO: set batch size for second stage model!
480
+ wav_files = second_stage_model(
481
+ texts=texts,
482
+ encodec_tokens=b_tokens,
483
+ speaker_embs=b_speaker_embs,
484
+ batch_size=batch_size,
485
+ guidance_scale=None,
486
+ top_p=None,
487
+ top_k=top_k,
488
+ temperature=temperature,
489
+ max_new_tokens=None,
490
+ )
491
+
492
+ for text, tokens, speaker_embs, ref_name, wav_file in zip(texts, b_tokens, b_speaker_embs, refs, wav_files):
493
+ if wav_file is None:
494
+ continue
495
+
496
+ with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
497
+ if enhancer is not None:
498
+ enhancer = get_enhancer(enhancer) if isinstance(enhancer, str) else enhancer
499
+ enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
500
+ # copy enhanced_tmp.name back to wav_file
501
+ print(f"copying enhanced file from {enhanced_tmp.name} to {str(wav_file) + '.wav'}.")
502
+ shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
503
+
504
+ save_result_metadata(
505
+ wav_file,
506
+ ref_name,
507
+ text,
508
+ first_stage_ckpt_path,
509
+ second_stage_ckpt_path,
510
+ )
511
+
512
+ print(f"time_to_synth_s: {time.time() - start}")
513
+ return [str(w) + ".wav" if not str(w).endswith(".wav") else str(w) for w in wav_files]
514
+
515
+
516
+ def sample_utterance(
517
+ text: str,
518
+ spk_cond_path: Optional[str],
519
+ spkemb_model,
520
+ first_stage_model,
521
+ second_stage_model,
522
+ enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
523
+ first_stage_ckpt_path: str,
524
+ second_stage_ckpt_path: str,
525
+ guidance_scale: Optional[Tuple[float, float]],
526
+ max_new_tokens: int,
527
+ top_k: Optional[int],
528
+ top_p: Optional[float],
529
+ temperature: Optional[float],
530
+ ) -> str:
531
+ # NOTE: supports max. 220 characters atm.
532
+ # Long form synthesis coming soon...
533
+ MAX_CHARS = 220
534
+ if len(text) > MAX_CHARS:
535
+ print(
536
+ f"\n***WARNING: Max {MAX_CHARS} characters supported. Provided: {len(text)}. Truncating and generating speech...Can lead to unpredictable speech at the end.***"
537
+ )
538
+
539
+ return _sample_utterance_batch(
540
+ texts=[text],
541
+ spk_cond_paths=[spk_cond_path],
542
+ spkemb_model=spkemb_model,
543
+ first_stage_model=first_stage_model,
544
+ second_stage_model=second_stage_model,
545
+ enhancer=enhancer,
546
+ first_stage_ckpt_path=first_stage_ckpt_path,
547
+ second_stage_ckpt_path=second_stage_ckpt_path,
548
+ batch_size=1,
549
+ guidance_scale=guidance_scale,
550
+ max_new_tokens=max_new_tokens,
551
+ top_k=top_k,
552
+ top_p=top_p,
553
+ temperature=temperature,
554
+ )[0]
555
+
556
+
557
+ def build_models(config_first_stage, config_second_stage, model_dir, device, use_kv_cache):
558
+ smodel = SpeakerEncoder(
559
+ weights_fpath=os.path.join(model_dir, "speaker_encoder.pt"), device=device, eval=True, verbose=False
560
+ )
561
+ data_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
562
+ llm_first_stage = Model(
563
+ config_first_stage,
564
+ TrainedBPETokeniser,
565
+ EncodecDecoder,
566
+ data_adapter_fn=data_adapter.decode,
567
+ use_kv_cache=use_kv_cache,
568
+ )
569
+ data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
570
+ llm_second_stage = Model(
571
+ config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
572
+ )
573
+ return smodel, llm_first_stage, llm_second_stage
574
+
575
+
576
+ def get_first_stage_path(model_dir: str):
577
+ """Absolute path to checkpoint for the first stage model."""
578
+ return os.path.join(os.path.expanduser(model_dir), "first_stage.pt")
579
+
580
+
581
+ def get_second_stage_path(model_dir: str):
582
+ """Absolute path to checkpoint for the second stage model."""
583
+ return os.path.join(os.path.expanduser(model_dir), "second_stage.pt")
584
+
585
+
586
+ @dataclass
587
+ class SamplingControllerConfig:
588
+ """
589
+ Sample from a trained model.
590
+ """
591
+
592
+ spk_cond_path: str
593
+ """Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3"""
594
+
595
+ huggingface_repo_id: str = "kotoba-tech/kotoba-speech-v0.1"
596
+ """Absolute path to the model directory."""
597
+
598
+ text: str = (
599
+ "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model by MetaVoice."
600
+ )
601
+ """Text to synthesise."""
602
+
603
+ num_samples: int = 1
604
+ """Number of samples to generate from each model."""
605
+
606
+ max_new_tokens: int = 864
607
+ """Maximum number of new tokens to generate from the first stage model."""
608
+
609
+ temperature: float = 1.0
610
+ """Temperature for sampling applied to both models."""
611
+
612
+ top_k: Optional[int] = None
613
+ """Top k for sampling applied to both models."""
614
+
615
+ top_p: Optional[float] = 0.95
616
+ """Top p for sampling applied to first-stage model."""
617
+
618
+ seed: int = 1337
619
+ """Random seed for sampling."""
620
+
621
+ device: Literal["cuda", "cpu"] = "cuda"
622
+ """Device to use for sampling."""
623
+
624
+ dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = get_default_dtype()
625
+ """Data type to use for sampling."""
626
+
627
+ compile: bool = False
628
+ """Whether to compile the model using PyTorch 2.0."""
629
+
630
+ enhancer: Optional[Literal["df"]] = "df"
631
+ """Enhancer to use for post-processing."""
632
+
633
+ init_from: str = "resume"
634
+ """Either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')."""
635
+
636
+ use_kv_cache: Optional[Literal["vanilla"]] = "vanilla"
637
+ """Type of kv caching to use for inference: 1) [none] no kv caching, 2) [vanilla] use torch attention with hand implemented kv-cache."""
638
+
639
+ output_dir: str = "samples/"
640
+ """Relative path to output directory"""
641
+
642
+ guidance_scale: Optional[Tuple[float, float]] = (3.0, 1.0)
643
+ """Guidance scale for sampling: (speaker conditioning guidance_scale, prompt conditioning guidance scale)."""
644
+
645
+ batch_size: int = 128
646
+ """Batch size to use for sampling. Note that the batch size gets doubled when guidance is used. For H100, and 1B model,
647
+ 1 w/ guidance and 1 w/o guidance work well (without kv-caching). With kv-caching, 128 (w/o guidance) and
648
+ 64 (w/ guidance) works well."""
649
+
650
+
651
+ if __name__ == "__main__":
652
+ # TODO: add support for batch sampling via CLI. Function has been implemented above.
653
+ sampling_config = tyro.cli(SamplingControllerConfig, use_underscores=True)
654
+
655
+ check_audio_file(sampling_config.spk_cond_path)
656
+
657
+ model_dir = snapshot_download(repo_id=sampling_config.huggingface_repo_id)
658
+ first_stage_ckpt_path = get_first_stage_path(model_dir)
659
+ second_stage_ckpt_path = get_second_stage_path(model_dir)
660
+
661
+ config_first_stage = InferenceConfig(
662
+ ckpt_path=first_stage_ckpt_path,
663
+ num_samples=sampling_config.num_samples,
664
+ seed=sampling_config.seed,
665
+ device=sampling_config.device,
666
+ dtype=sampling_config.dtype,
667
+ compile=sampling_config.compile,
668
+ init_from=sampling_config.init_from,
669
+ output_dir=sampling_config.output_dir,
670
+ )
671
+
672
+ config_second_stage = InferenceConfig(
673
+ ckpt_path=second_stage_ckpt_path,
674
+ num_samples=sampling_config.num_samples,
675
+ seed=sampling_config.seed,
676
+ device=sampling_config.device,
677
+ dtype=sampling_config.dtype,
678
+ compile=sampling_config.compile,
679
+ init_from=sampling_config.init_from,
680
+ output_dir=sampling_config.output_dir,
681
+ )
682
+
683
+ sampling_config.max_new_tokens *= (
684
+ 2 # deal with max_new_tokens for flattened interleaving! (should scale with num_codebooks?)
685
+ )
686
+
687
+ # define models
688
+ smodel, llm_first_stage, llm_second_stage = build_models(
689
+ config_first_stage,
690
+ config_second_stage,
691
+ model_dir=model_dir,
692
+ device=sampling_config.device,
693
+ use_kv_cache=sampling_config.use_kv_cache,
694
+ )
695
+
696
+ sample_utterance(
697
+ sampling_config.text,
698
+ os.path.expanduser(sampling_config.spk_cond_path),
699
+ smodel,
700
+ llm_first_stage,
701
+ llm_second_stage,
702
+ sampling_config.enhancer,
703
+ first_stage_ckpt_path,
704
+ second_stage_ckpt_path,
705
+ sampling_config.guidance_scale,
706
+ max_new_tokens=sampling_config.max_new_tokens,
707
+ top_k=sampling_config.top_k,
708
+ top_p=sampling_config.top_p,
709
+ temperature=sampling_config.temperature,
710
+ )
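
For reference, the same two-stage flow driven above by tyro can also be exercised programmatically. The sketch below is a minimal outline only: the import path `fam.llm.sample`, the reference audio file `ref.wav`, and the example text are assumptions, not part of this commit.

# Minimal sketch of the two-stage pipeline above, driven from Python instead of the CLI.
# The import path, reference audio path and text below are assumptions.
from huggingface_hub import snapshot_download

from fam.llm.sample import (  # assumed module path for the script above
    InferenceConfig, build_models, get_first_stage_path,
    get_second_stage_path, sample_utterance,
)

model_dir = snapshot_download(repo_id="kotoba-tech/kotoba-speech-v0.1")
common = dict(num_samples=1, seed=1337, device="cuda", dtype="bfloat16",
              compile=False, init_from="resume", output_dir="samples/")
cfg1 = InferenceConfig(ckpt_path=get_first_stage_path(model_dir), **common)
cfg2 = InferenceConfig(ckpt_path=get_second_stage_path(model_dir), **common)

smodel, llm1, llm2 = build_models(cfg1, cfg2, model_dir=model_dir,
                                  device="cuda", use_kv_cache="vanilla")
wav_path = sample_utterance(
    "This is a short synthesis test.", "ref.wav", smodel, llm1, llm2,
    enhancer="df", first_stage_ckpt_path=cfg1.ckpt_path,
    second_stage_ckpt_path=cfg2.ckpt_path, guidance_scale=(3.0, 1.0),
    max_new_tokens=864 * 2,  # doubled for flattened interleaving, as in __main__ above
    top_k=None, top_p=0.95, temperature=1.0,
)
print(wav_path)  # path to the generated .wav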
fam/llm/layers/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from fam.llm.layers.attn import SelfAttention
2
+ from fam.llm.layers.combined import Block
3
+ from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm, SwiGLU
fam/llm/layers/attn.py ADDED
@@ -0,0 +1,185 @@
1
+ import warnings
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ class SelfAttention(nn.Module):
9
+ def __init__(self, config):
10
+ """
11
+ Initializes the SelfAttention module.
12
+
13
+ Args:
14
+ config: An object containing the configuration parameters for the SelfAttention module.
15
+ """
16
+ super().__init__()
17
+ self._validate_config(config)
18
+ self._initialize_parameters(config)
19
+
20
+ def empty_kv_cache(self, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype):
21
+ """
22
+ Empties the key-value cache.
23
+
24
+ Args:
25
+ batch_size: The batch size.
26
+ kv_cache_maxlen: The maximum length of the key-value cache.
27
+ dtype: The data type of the cache.
28
+
29
+ Raises:
30
+ Exception: If trying to empty the KV cache when it is disabled.
31
+ """
32
+ if self.kv_cache_enabled is False:
33
+ raise Exception("Trying to empty KV cache when it is disabled")
34
+
35
+ # register so that the cache moves devices along with the module
36
+ # TODO: get rid of re-allocation.
37
+ self.register_buffer(
38
+ "kv_cache",
39
+ torch.zeros(
40
+ 2,
41
+ batch_size,
42
+ kv_cache_maxlen,
43
+ self.n_head,
44
+ self.n_embd // self.n_head,
45
+ dtype=dtype,
46
+ device=self.c_attn.weight.device,
47
+ ),
48
+ persistent=False,
49
+ )
50
+
51
+ self.kv_cache_first_empty_index = 0
52
+
53
+ def _initialize_parameters(self, config):
54
+ """
55
+ Initializes the parameters of the SelfAttention module.
56
+
57
+ Args:
58
+ config: An object containing the configuration parameters for the SelfAttention module.
59
+ """
60
+ # key, query, value projections for all heads, but in a batch
61
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
62
+
63
+ # output projection
64
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
65
+
66
+ # regularization
67
+ self.resid_dropout = nn.Dropout(config.dropout)
68
+ self.n_head = config.n_head
69
+ self.n_embd = config.n_embd
70
+ self.dropout = config.dropout
71
+ self.causal = config.causal
72
+ self.attn_kernel_type = config.attn_kernel_type
73
+ self.attn_dropout = nn.Dropout(config.dropout)
74
+
75
+ self.kv_cache_enabled = False
76
+
77
+ def _validate_config(self, config):
78
+ """
79
+ Validates the configuration parameters.
80
+
81
+ Args:
82
+ config: An object containing the configuration parameters for the SelfAttention module.
83
+
84
+ Raises:
85
+ AssertionError: If the embedding dimension is not divisible by the number of heads.
86
+ """
87
+ assert config.n_embd % config.n_head == 0, "Embedding dimension must be divisible by number of heads"
88
+
89
+ def _update_kv_cache(self, q, k, v):
90
+ """
91
+ Updates the key-value cache.
92
+
93
+ Args:
94
+ q: The query tensor.
95
+ k: The key tensor.
96
+ v: The value tensor.
97
+
98
+ Returns:
99
+ The updated key and value tensors.
100
+
101
+ Raises:
102
+ AssertionError: If the dimensions of the query, key, and value tensors are not compatible.
103
+ """
104
+ q_time, k_time, v_time = q.shape[1], k.shape[1], v.shape[1]
105
+
106
+ if self.kv_cache_first_empty_index == 0:
107
+ assert q_time == k_time and q_time == v_time
108
+ else:
109
+ assert (
110
+ q_time == 1
111
+ ), f"Only one query at a time is supported, but got q_time={q_time} for kv_cache_first_empty_index={self.kv_cache_first_empty_index}"
112
+
113
+ self.kv_cache[0, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = k
114
+ self.kv_cache[1, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = v
115
+ self.kv_cache_first_empty_index += q_time
116
+
117
+ k = self.kv_cache[0, :, : self.kv_cache_first_empty_index]
118
+ v = self.kv_cache[1, :, : self.kv_cache_first_empty_index]
119
+
120
+ return k, v
121
+
122
+ def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
123
+ """
124
+ Performs attention using the torch.nn.functional.scaled_dot_product_attention function.
125
+
126
+ Args:
127
+ c_x: The input tensor.
128
+
129
+ Returns:
130
+ The output tensor.
131
+ """
132
+ q, k, v = c_x.split(1, dim=2) # q, k, v of shape (B, T, 1, nh, hs)
133
+ q = q.squeeze(2) # (B, T, nh, hs)
134
+ k = k.squeeze(2) # (B, T, nh, hs)
135
+ v = v.squeeze(2) # (B, T, nh, hs)
136
+
137
+ # if kv-caching and causal, for the "prefill" stage, we need to use a causal mask, and
138
+ # use no mask for the "one time step" parts.
139
+ # calculate this before updating kv_caching so we have the right value for kv_cache_first_empty_index
140
+ is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0)
141
+
142
+ if self.kv_cache_enabled:
143
+ k, v = self._update_kv_cache(q, k, v)
144
+
145
+ q = q.transpose(1, 2) # (B, nh, T, hs)
146
+ k = k.transpose(1, 2) # (B, nh, T, hs)
147
+ v = v.transpose(1, 2) # (B, nh, T, hs)
148
+ y = torch.nn.functional.scaled_dot_product_attention(
149
+ q,
150
+ k,
151
+ v,
152
+ attn_mask=None,
153
+ dropout_p=self.dropout if self.training else 0,
154
+ is_causal=is_causal_attn_mask,
155
+ ).transpose(
156
+ 1, 2
157
+ ) # (B, nh, T, hs) -> (B, T, nh, hs)
158
+
159
+ return y
160
+
161
+ def forward(self, x):
162
+ """
163
+ Performs the forward pass of the SelfAttention module.
164
+
165
+ Args:
166
+ x: The input tensor.
167
+
168
+ Returns:
169
+ The output tensor.
170
+ """
171
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
172
+
173
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
174
+ c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head) # (B, T, 3, nh, hs)
175
+
176
+ # causal self-attention;
177
+ if self.attn_kernel_type == "torch_attn":
178
+ y = self._torch_attn(c_x)
179
+ else:
180
+ raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}")
181
+
182
+ y = y.contiguous().view(B, T, C) # re-assemble all head outputs side by side: (B, T, nh, hs) -> (B, T, hs * nh)
183
+ # output projection
184
+ y = self.resid_dropout(self.c_proj(y))
185
+ return y
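
To make the KV-cache contract above concrete, here is a minimal sketch: a full-prompt "prefill" pass uses the causal mask, and later single-token steps append to the cache and attend over everything cached so far. The `SimpleNamespace` config is a stand-in for the real GPTConfig and its values are illustrative.

# Sketch of the prefill + single-step KV-cache usage pattern for SelfAttention.
# SimpleNamespace stands in for GPTConfig; all values are illustrative.
from types import SimpleNamespace

import torch

from fam.llm.layers.attn import SelfAttention

cfg = SimpleNamespace(n_embd=64, n_head=4, bias=False, dropout=0.0,
                      causal=True, attn_kernel_type="torch_attn")
attn = SelfAttention(cfg).eval()
attn.kv_cache_enabled = True
attn.empty_kv_cache(batch_size=2, kv_cache_maxlen=32, dtype=torch.float32)

with torch.no_grad():
    _ = attn(torch.randn(2, 8, cfg.n_embd))   # prefill: causal mask, fills cache slots 0..7
    y = attn(torch.randn(2, 1, cfg.n_embd))   # decode step: one query attends over the cache

assert y.shape == (2, 1, cfg.n_embd)
assert attn.kv_cache_first_empty_index == 9   # 8 prefill tokens + 1 decode step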
fam/llm/layers/combined.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch.nn as nn
2
+
3
+ from fam.llm.layers.attn import SelfAttention
4
+ from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm
5
+
6
+
7
+ class Block(nn.Module):
8
+ """
9
+ Block class represents a single block in the model.
10
+
11
+ Args:
12
+ config (object): Configuration object containing parameters for the block.
13
+
14
+ Attributes:
15
+ ln_1 (object): Layer normalization for the attention layer.
16
+ ln_2 (object): Layer normalization for the feed-forward layer.
17
+ attn (object): Self-attention layer.
18
+ mlp (object): Multi-layer perceptron layer.
19
+
20
+ Methods:
21
+ forward(x): Performs forward pass through the block.
22
+ """
23
+
24
+ def __init__(self, config):
25
+ super().__init__()
26
+ if config.norm_type == "rmsnorm":
27
+ if config.rmsnorm_eps is None:
28
+ raise Exception("RMSNorm requires rmsnorm_eps to be set")
29
+ self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # attn norm
30
+ self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # ffn norm
31
+ elif config.norm_type == "layernorm":
32
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) # attn norm
33
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) # ffn norm
34
+ else:
35
+ raise Exception(f"Unknown norm type: {config.norm_type}")
36
+ self.attn = SelfAttention(config)
37
+
38
+ self.mlp = MLP(config)
39
+
40
+ def forward(self, x):
41
+ """
42
+ Performs forward pass through the block.
43
+
44
+ Args:
45
+ x (tensor): Input tensor.
46
+
47
+ Returns:
48
+ tensor: Output tensor after passing through the block.
49
+ """
50
+ x = x + self.attn(self.ln_1(x))
51
+ x = x + self.mlp(self.ln_2(x))
52
+ return x
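
A quick shape check for the pre-norm residual block above; again the config is a stand-in with illustrative values (layernorm + gelu variant).

# Minimal forward-pass sketch for Block; SimpleNamespace stands in for GPTConfig.
from types import SimpleNamespace

import torch

from fam.llm.layers import Block

cfg = SimpleNamespace(n_embd=64, n_head=4, bias=False, dropout=0.0,
                      causal=True, attn_kernel_type="torch_attn",
                      norm_type="layernorm", rmsnorm_eps=None,
                      nonlinearity_type="gelu", swiglu_multiple_of=None)
block = Block(cfg).eval()
x = torch.randn(2, 16, cfg.n_embd)
with torch.no_grad():
    out = block(x)           # x + attn(ln_1(x)), then + mlp(ln_2(x))
assert out.shape == x.shape  # the block preserves (batch, time, n_embd)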
fam/llm/layers/layers.py ADDED
@@ -0,0 +1,72 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ class LayerNorm(nn.Module):
9
+ """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
10
+
11
+ def __init__(self, ndim, bias):
12
+ super().__init__()
13
+ self.weight = nn.Parameter(torch.ones(ndim))
14
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
15
+
16
+ def forward(self, input):
17
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
18
+
19
+
20
+ class RMSNorm(torch.nn.Module):
21
+ def __init__(self, ndim: int, eps: float):
22
+ super().__init__()
23
+ self.eps = eps
24
+ self.weight = nn.Parameter(torch.ones(ndim))
25
+
26
+ def _norm(self, x):
27
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
28
+
29
+ def forward(self, x):
30
+ return self._norm(x) * self.weight
31
+
32
+
33
+ class SwiGLU(nn.Module):
34
+ def __init__(self, in_dim, out_dim, bias) -> None:
35
+ super().__init__()
36
+ self.w1 = nn.Linear(in_dim, out_dim, bias=bias)
37
+ self.w3 = nn.Linear(in_dim, out_dim, bias=bias)
38
+
39
+ def forward(self, x):
40
+ return F.silu(self.w1(x)) * self.w3(x)
41
+
42
+
43
+ class MLP(nn.Module):
44
+ def __init__(self, config):
45
+ super().__init__()
46
+ self.non_linearity = config.nonlinearity_type
47
+ hidden_dim = 4 * config.n_embd
48
+ if config.nonlinearity_type == "gelu":
49
+ self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
50
+ self.gelu = nn.GELU()
51
+ self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
52
+ elif config.nonlinearity_type == "swiglu":
53
+ if config.swiglu_multiple_of is None:
54
+ raise Exception("SwiGLU requires swiglu_multiple_of to be set")
55
+ hidden_dim = int(2 * hidden_dim / 3)
56
+ hidden_dim = config.swiglu_multiple_of * math.ceil(hidden_dim / config.swiglu_multiple_of)
57
+ # set name to `c_proj` so that the right initialisation gets applied to it in GPT.__init__()
58
+ self.swiglu = SwiGLU(config.n_embd, hidden_dim, bias=config.bias)
59
+ self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
60
+ else:
61
+ raise Exception(f"Unknown nonlinearity type: {config.nonlinearity_type}")
62
+ self.dropout = nn.Dropout(config.dropout)
63
+
64
+ def forward(self, x):
65
+ if self.non_linearity == "gelu":
66
+ x = self.c_fc(x)
67
+ x = self.gelu(x)
68
+ elif self.non_linearity == "swiglu":
69
+ x = self.swiglu(x)
70
+ x = self.c_proj(x)
71
+ x = self.dropout(x)
72
+ return x
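
As a worked example of the SwiGLU sizing logic in MLP above (the n_embd and swiglu_multiple_of values here are illustrative, not taken from a released checkpoint):

# Worked example of the MLP swiglu hidden-dim arithmetic above (illustrative values).
import math

n_embd, swiglu_multiple_of = 2048, 256
hidden_dim = 4 * n_embd                    # 8192
hidden_dim = int(2 * hidden_dim / 3)       # 5461: keeps parameter count comparable to a GELU MLP
hidden_dim = swiglu_multiple_of * math.ceil(hidden_dim / swiglu_multiple_of)
print(hidden_dim)                          # 5632 (rounded up to a multiple of 256)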
fam/llm/mixins/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from fam.llm.mixins.causal import CausalInferenceMixin
2
+ from fam.llm.mixins.non_causal import NonCausalInferenceMixin
fam/llm/mixins/causal.py ADDED
@@ -0,0 +1,511 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import tqdm
6
+ from torch.nn import functional as F
7
+ from IPython import embed
8
+
9
+ def top_p_sample(prob_dist: torch.Tensor, top_p: float):
10
+ sorted_probs, sorted_indices = torch.sort(prob_dist, descending=True, dim=-1)
11
+ cum_sum_probs = torch.cumsum(sorted_probs, dim=-1) # (b, vocab_size)
12
+
13
+ sorted_indices_to_remove = cum_sum_probs > top_p
14
+
15
+ # Shift the indices to the right to keep also the first token above the threshold
16
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
17
+ sorted_indices_to_remove[:, 0] = 0
18
+ sorted_indices_to_remove = sorted_indices_to_remove.bool()
19
+
20
+ # replace probs to be removed with 0 in the sorted_probs
21
+ sorted_probs[sorted_indices_to_remove] = 0
22
+
23
+ # reverse the sorting process
24
+ reversed_indices = torch.argsort(sorted_indices)
25
+ prob_dist = torch.gather(sorted_probs, -1, reversed_indices)
26
+
27
+ # normalize
28
+ prob_dist = prob_dist / prob_dist.sum(dim=-1, keepdim=True)
29
+
30
+ return prob_dist
31
+
32
+
33
+ class CausalInferenceMixin:
34
+ """
35
+ Mixin class for performing inference in a causal language model.
36
+
37
+ This mixin provides methods for predicting the next token in a sequence, sampling from the model,
38
+ and applying token prediction masks.
39
+
40
+ Attributes:
41
+ None
42
+
43
+ Methods:
44
+ _sample_next_token: Predicts the next token in the sequence.
45
+ _create_token_pred_mask: Creates a token prediction mask based on sequence lengths.
46
+ _apply_token_pred_mask: Applies a token prediction mask to the next token predictions.
47
+ _sample_batch: Samples a batch of tokens from the model.
48
+ _sort_for_batching: Sorts the input sequences for efficient batching.
49
+ _causal_sample: Generates a sequence of tokens using causal sampling.
50
+
51
+ """
52
+
53
+ @torch.no_grad()
54
+ def _sample_next_token(
55
+ self,
56
+ *,
57
+ idx: torch.Tensor,
58
+ speaker_embs: Optional[torch.Tensor],
59
+ temperature: float,
60
+ top_k: Optional[int],
61
+ top_p: Optional[float],
62
+ guidance_scale: Optional[float],
63
+ ) -> torch.Tensor:
64
+ """
65
+ Predict the next token in the sequence.
66
+
67
+ Args:
68
+ idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
69
+ speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
70
+ temperature (float): Sampling temperature.
71
+ top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
72
+ top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
73
+ guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance.
74
+
75
+ Returns:
76
+ torch.Tensor: Next index in the sequence after sampling. Shape: (batch, num_hierarchies).
77
+ """
78
+ if top_k is not None and top_p is not None:
79
+ raise ValueError("Only one of top_k and top_p can be set")
80
+
81
+ # if the sequence context is growing too long we must crop it at block_size
82
+ idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, :, -self.config.block_size :]
83
+
84
+ # forward the model to get the logits for the index in the sequence
85
+ list_logits, _ = self(
86
+ idx_cond, speaker_embs=speaker_embs
87
+ ) # list with len num_hierarchies of (b,1,vocab_size) tensors
88
+ # print(f'{list_logits[0].shape=}, {len(list_logits)=}')
89
+ # print(f'{list_logits[0][:,:,:10]}')
90
+
91
+ if guidance_scale is not None:
92
+ assert idx_cond.shape[0] % 2 == 0
93
+ assert list_logits[0].shape[0] % 2 == 0
94
+
95
+ for i, logits in enumerate(list_logits):
96
+ logits_cond, logits_uncond = logits.split(logits.shape[0] // 2, dim=0)
97
+ list_logits[i] = (guidance_scale) * logits_cond + (1 - guidance_scale) * logits_uncond
98
+
99
+ assert list_logits[0].shape[0] == idx_cond.shape[0] // 2
100
+
101
+ # pluck the logits at the final step and scale by desired temperature
102
+ list_logits = [
103
+ logits[:, -1, :] / temperature for logits in list_logits
104
+ ] # list with len num_hierarchies of (b,vocab_size) tensors
105
+
106
+ # optionally crop the logits to only the top k options
107
+ if top_k is not None:
108
+ for i in range(len(list_logits)):
109
+ logits = list_logits[i]
110
+ v, _ = torch.topk(
111
+ logits, min(top_k, logits.size(-1))
112
+ ) # returns a descending sorted list of values and indices of top_k values
113
+ logits[logits < v[:, [-1]]] = -float("Inf") # set all logits below the smallest top_k value to -Inf
114
+ list_logits[i] = logits
115
+
116
+ # apply softmax to convert logits to (normalized) probabilities
117
+ # embed()
118
+ probs = [
119
+ F.softmax(logits, dim=-1) for logits in list_logits
120
+ ] # list of len num_hierarchies of (b,vocab_size) tensors
121
+ # print(f'{probs[0].shape=}')
122
+ # print(f'{probs[0][:,:,:10]}')
123
+ if top_p is not None:
124
+ for i in range(len(probs)):
125
+ probs[i] = top_p_sample(probs[i], top_p)
126
+
127
+ # sample from the distribution
128
+ idx_next = [
129
+ torch.multinomial(prob, num_samples=1) for prob in probs
130
+ ] # list of len num_hierarchies of (b,1) tensors
131
+ idx_next = torch.cat(idx_next, dim=-1) # (b, num_hierarchies) tensor
132
+
133
+ return idx_next # (b, num_hierarchies) tensor
134
+
135
+ @torch.no_grad()
136
+ def _create_token_pred_mask(self, idx: torch.Tensor, seq_lens: list[int]) -> torch.Tensor:
137
+ """
138
+ Creates a token prediction mask based on sequence lengths.
139
+
140
+ Args:
141
+ idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
142
+ seq_lens (list[int]): List of sequence lengths for each sequence in idx.
143
+
144
+ Returns:
145
+ torch.Tensor: Token prediction mask of shape (batch, time).
146
+ """
147
+ token_pred_mask = torch.zeros((idx.shape[0], idx.shape[-1]), dtype=torch.bool, device=idx.device)
148
+ for i in range(len(seq_lens)):
149
+ token_pred_mask[i, : seq_lens[i]] = True
150
+
151
+ assert (token_pred_mask[:, : min(seq_lens)] == 1).all()
152
+
153
+ return token_pred_mask
154
+
155
+ @torch.no_grad()
156
+ def _apply_token_pred_mask(
157
+ self, *, idx_next: torch.Tensor, orig_input_at_t: torch.Tensor, token_pred_mask_at_t: torch.Tensor
158
+ ) -> torch.Tensor:
159
+ """
160
+ Applies a token prediction mask to the next token predictions.
161
+
162
+ Args:
163
+ idx_next (torch.Tensor): Next token predictions of shape (batch, num_hierarchies).
164
+ orig_input_at_t (torch.Tensor): Original input at time step t of shape (batch, num_hierarchies).
165
+ token_pred_mask_at_t (torch.Tensor): Token prediction mask at time step t of shape (batch, 1).
166
+
167
+ Returns:
168
+ torch.Tensor: Updated next token predictions after applying the token prediction mask.
169
+ """
170
+ idx_next = idx_next * (~token_pred_mask_at_t) + orig_input_at_t * token_pred_mask_at_t
171
+
172
+ return idx_next
173
+
174
+ @torch.no_grad()
175
+ def _sample_batch(
176
+ self,
177
+ *,
178
+ idx: torch.Tensor,
179
+ max_new_tokens: int,
180
+ seq_lens: list[int],
181
+ temperature: float,
182
+ top_k: Optional[int],
183
+ top_p: Optional[float],
184
+ speaker_embs: Optional[torch.Tensor],
185
+ guidance_scale: Optional[float],
186
+ ):
187
+ """
188
+ Samples a batch of tokens from the model.
189
+
190
+ Args:
191
+ idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
192
+ max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).
193
+ seq_lens (list[int]): List of sequence lengths for each sequence in idx.
194
+ temperature (float): Sampling temperature.
195
+ top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
196
+ top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
197
+ speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
198
+ guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance.
199
+
200
+ Returns:
201
+ torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time).
202
+ """
203
+ assert max(seq_lens) <= idx.shape[-1]
204
+ token_pred_mask = self._create_token_pred_mask(idx, seq_lens)
205
+ input = torch.clone(idx)
206
+
207
+ min_seq_lens = min(seq_lens)
208
+ idx = idx[:, :, :min_seq_lens]
209
+
210
+ if guidance_scale is not None:
211
+ if speaker_embs is None:
212
+ raise Exception("Guidance is only supported for conditional models")
213
+
214
+ # create speaker embeddings equivalent to the batch size, filling with None
215
+ # for second half to do unconditional generation.
216
+ speaker_embs = list(speaker_embs) + [None] * (speaker_embs.shape[0])
217
+
218
+ for timestep in tqdm.tqdm(range(min_seq_lens, min_seq_lens + max_new_tokens), desc="tokens: "):
219
+ if (self.kv_cache_enabled is True) and (timestep > min_seq_lens):
220
+ idx_input = idx[:, :, -1:]
221
+ else:
222
+ idx_input = idx
223
+
224
+ if guidance_scale is not None:
225
+ # TODO: fix: will cause a problem with kv-caching as it's not expecting larger batch-size.
226
+ if timestep == min_seq_lens:
227
+ print("[hack!!!!] Guidance is on, so we're doubling batch size!")
228
+
229
+ # replicate idx in the batch dimension
230
+ idx_input = (
231
+ idx_input.unsqueeze(0).repeat(2, 1, 1, 1).reshape(-1, idx_input.shape[1], idx_input.shape[2])
232
+ )
233
+
234
+ # sanity checks
235
+ assert idx_input.shape[0] % 2 == 0
236
+
237
+ idx_next = self._sample_next_token(
238
+ idx=idx_input,
239
+ speaker_embs=speaker_embs,
240
+ temperature=temperature,
241
+ top_k=top_k,
242
+ top_p=top_p,
243
+ guidance_scale=guidance_scale,
244
+ ) # (b, num_hierarchies)
245
+
246
+ assert idx_next.shape[0] == idx.shape[0]
247
+
248
+ if timestep < token_pred_mask.shape[-1]:
249
+ idx_next = self._apply_token_pred_mask(
250
+ idx_next=idx_next,
251
+ orig_input_at_t=input[:, :, timestep],
252
+ token_pred_mask_at_t=token_pred_mask[:, [timestep]],
253
+ )
254
+
255
+ idx_next = idx_next.unsqueeze(-1) # (b, num_hierarchies, T=1) tensor
256
+ # append sampled index to the running sequence and continue
257
+ idx = torch.cat((idx, idx_next), dim=2)
258
+
259
+ return idx
260
+
261
+ @torch.no_grad()
262
+ def _sort_for_batching(
263
+ self,
264
+ *,
265
+ idx: torch.Tensor,
266
+ seq_lens: list[int],
267
+ speaker_embs: Optional[torch.Tensor],
268
+ batch_size: int,
269
+ max_new_tokens: int,
270
+ ) -> Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]:
271
+ """
272
+ Sorts the input sequences for efficient batching.
273
+
274
+ Args:
275
+ idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
276
+ seq_lens (list[int]): List of sequence lengths for each sequence in idx.
277
+ speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
278
+ batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling.
279
+ max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).
280
+
281
+ Returns:
282
+ Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]:
283
+ - sorted_indices (list[int]): List of indices of the input sequences that transform it into sorted order.
284
+ - invert_sorted_indices (list[int]): List of indices to invert the sorted sequences back to the original order.
285
+ - idx (torch.Tensor): Input sequence indices in sorted order.
286
+ - seq_lens (list[int]): Sequence lengths in sorted order.
287
+ - speaker_embs (Optional[torch.Tensor]): speaker embeddings in sorted order.
288
+ - max_token_len (int): Effective maximum number of tokens to generate.
289
+ """
290
+ assert len(seq_lens) == idx.shape[0]
291
+ assert max(seq_lens) <= idx.shape[-1]
292
+
293
+ sorted_indices = np.argsort(seq_lens)
294
+ inverted_sorted_indices = np.zeros(len(seq_lens), dtype=np.int32)
295
+ inverted_sorted_indices[sorted_indices] = np.arange(len(seq_lens), dtype=np.int32)
296
+
297
+ idx = idx[sorted_indices]
298
+ seq_lens = [seq_lens[i] for i in sorted_indices]
299
+ speaker_embs = speaker_embs[sorted_indices] if speaker_embs is not None else None
300
+ max_token_len = 0
301
+
302
+ # figure out effective max_tokens to generate
303
+ for start_index in range(0, len(seq_lens), batch_size):
304
+ end_index = min(start_index + batch_size, len(seq_lens))
305
+ batch_seq_lens = seq_lens[start_index:end_index]
306
+ # random heuristic...
307
+ # # TODO: fix!
308
+ max_token_len = max(max_token_len, min(batch_seq_lens) + max_new_tokens)
309
+
310
+ return sorted_indices, inverted_sorted_indices, idx, seq_lens, speaker_embs, max_token_len
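
The sort / inverse-sort bookkeeping used in `_sort_for_batching` above can be summarised on a toy input (the sequence lengths are made up):

# Tiny illustration of the argsort + inverse-permutation trick used above.
import numpy as np

seq_lens = [7, 3, 5]
sorted_indices = np.argsort(seq_lens)                                 # [1, 2, 0]
inverted = np.zeros(len(seq_lens), dtype=np.int32)
inverted[sorted_indices] = np.arange(len(seq_lens), dtype=np.int32)   # [2, 0, 1]
sorted_lens = [seq_lens[i] for i in sorted_indices]                   # [3, 5, 7]
assert [sorted_lens[i] for i in inverted] == seq_lens                 # original order restored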
311
+
312
+ @torch.no_grad()
313
+ def _causal_sample(
314
+ self,
315
+ *,
316
+ idx: torch.Tensor,
317
+ max_new_tokens: int,
318
+ seq_lens: list[int],
319
+ temperature: float,
320
+ top_k: Optional[int],
321
+ top_p: Optional[float],
322
+ speaker_embs: Optional[torch.Tensor],
323
+ batch_size: int,
324
+ guidance_scale: Optional[float] = None,
325
+ ) -> torch.Tensor:
326
+ """
327
+ Generates a sequence of tokens using causal sampling.
328
+
329
+ Args:
330
+ idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
331
+ max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).
332
+ seq_lens (list[int]): List of sequence lengths for each sequence in idx.
333
+ temperature (float): Sampling temperature.
334
+ top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
335
+ top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
336
+ speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
337
+ batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling.
338
+ guidance_scale (Optional[float]): Scale factor for the guidance loss. Set to `None` to disable guidance.
339
+
340
+ Returns:
341
+ torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time).
342
+ """
343
+ (
344
+ _,
345
+ invert_sorted_indices,
346
+ idx,
347
+ seq_lens,
348
+ speaker_embs,
349
+ max_token_len,
350
+ ) = self._sort_for_batching(
351
+ idx=idx, seq_lens=seq_lens, speaker_embs=speaker_embs, batch_size=batch_size, max_new_tokens=max_new_tokens
352
+ )
353
+
354
+ return_idx = torch.zeros((len(seq_lens), idx.size(1), max_token_len), dtype=torch.long, device=idx.device)
355
+
356
+ for start_index in tqdm.tqdm(range(0, len(seq_lens), batch_size), desc="batch: "):
357
+ end_index = min(start_index + batch_size, len(seq_lens))
358
+
359
+ kv_batch_size = end_index - start_index
360
+ if guidance_scale is not None:
361
+ kv_batch_size = 2 * kv_batch_size
362
+
363
+ if self.kv_cache_enabled:
364
+ print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16")
365
+ self.empty_kv_cache(
366
+ batch_size=kv_batch_size,
367
+ kv_cache_maxlen=self.config.block_size,
368
+ dtype=torch.bfloat16,
369
+ )
370
+
371
+ batch_seq_lens = seq_lens[start_index:end_index]
372
+ batch_max_new_tokens = max_token_len - min(batch_seq_lens)
373
+
374
+ batch_idx = idx[start_index:end_index]
375
+ batch_speaker_embs = speaker_embs[start_index:end_index] if speaker_embs is not None else None
376
+
377
+ batch_idx = self._sample_batch(
378
+ idx=batch_idx,
379
+ max_new_tokens=batch_max_new_tokens,
380
+ seq_lens=batch_seq_lens,
381
+ temperature=temperature,
382
+ top_k=top_k,
383
+ top_p=top_p,
384
+ speaker_embs=batch_speaker_embs,
385
+ guidance_scale=guidance_scale,
386
+ )
387
+ return_idx[start_index:end_index] = batch_idx
388
+
389
+ return return_idx[invert_sorted_indices]
390
+
391
+ def empty_kv_cache(self, *, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype):
392
+ """
393
+ Empties key-value (KV) cache for causal attention.
394
+
395
+ Args:
396
+ batch_size (int): The batch size.
397
+ kv_cache_maxlen (int): The maximum length of the KV cache.
398
+ dtype (torch.dtype): The data type of the KV cache.
399
+
400
+ Raises:
401
+ Exception: If KV cache is enabled for non-causal attention.
402
+
403
+ """
404
+ if self.kv_cache_enabled is False:
405
+ raise Exception("KV cache is not enabled")
406
+ if self.config.causal is False:
407
+ raise Exception("KV cache is not supported for non-causal attention")
408
+
409
+ self.kv_pos = 0
410
+ for block in self.transformer.h:
411
+ block.attn.empty_kv_cache(batch_size=batch_size, kv_cache_maxlen=kv_cache_maxlen, dtype=dtype)
412
+
413
+ def enable_kv_cache(self):
414
+ """
415
+ Enables key-value (KV) cache for causal attention.
416
+
417
+ Raises:
418
+ Exception: If KV cache is enabled for non-causal attention.
419
+
420
+ """
421
+ if self.config.causal is False:
422
+ raise Exception("KV cache is not supported for non-causal attention")
423
+
424
+ self.kv_cache_enabled = True
425
+ for block in self.transformer.h:
426
+ block.attn.kv_cache_enabled = True
427
+
428
+ def disable_kv_cache(self):
429
+ """
430
+ Disables the key-value cache for the transformer and all its blocks.
431
+ """
432
+ self.kv_cache_enabled = False
433
+ for block in self.transformer.h:
434
+ block.attn.kv_cache_enabled = False
435
+ block.attn.kv_cache = None
436
+ block.attn.kv_cache_first_empty_index = 0
437
+
438
+ @torch.no_grad()
439
+ def _slow_causal_sampling_loop(
440
+ self,
441
+ idx: torch.Tensor,
442
+ max_new_tokens: int,
443
+ temperature: float = 1.0,
444
+ top_k: Optional[int] = None,
445
+ top_p: Optional[float] = None,
446
+ speaker_embs: Optional[torch.Tensor] = None,
447
+ guidance_scale: Optional[float] = None,
448
+ ):
449
+ """
450
+ Old non-batched version of causal sampling. Kept for testing / reference.
451
+
452
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,n_head,t)) and complete
453
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
454
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
455
+ """
456
+ assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens"
457
+ assert idx.size(0) == 1, "can only do one sequence at a time for now"
458
+ assert top_p is None, "nucleus sampling not supported yet with _slow_causal_sampling_loop"
459
+
460
+ if self.config.causal is not True:
461
+ raise Exception("Causal sampling is only supported for causal models")
462
+
463
+ if self.kv_cache_enabled:
464
+ print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16")
465
+ self.empty_kv_cache(
466
+ batch_size=1,
467
+ kv_cache_maxlen=self.config.block_size,
468
+ dtype=torch.bfloat16,
469
+ )
470
+
471
+ for i in range(max_new_tokens):
472
+ # if the sequence context is growing too long we must crop it at block_size
473
+ idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, -self.config.block_size :]
474
+
475
+ if self.kv_cache_enabled:
476
+ if i > 0:
477
+ idx_cond = idx_cond[:, :, -1:]
478
+
479
+ # forward the model to get the logits for the index in the sequence
480
+ list_logits, _ = self(idx_cond, speaker_embs=speaker_embs)
481
+
482
+ if guidance_scale is not None:
483
+ # we've already checked that kv-caching is not switched on
484
+ # so this should be ok.
485
+ list_logits_uncond, _ = self(idx_cond, speaker_embs=None)
486
+ list_logits = [
487
+ (guidance_scale) * logits + (1 - guidance_scale) * logits_uncond
488
+ for logits, logits_uncond in zip(list_logits, list_logits_uncond)
489
+ ]
490
+
491
+ # pluck the logits at the final step and scale by desired temperature
492
+ list_logits = [logits[:, -1, :] / temperature for logits in list_logits]
493
+
494
+ # optionally crop the logits to only the top k options
495
+ if top_k is not None:
496
+ for i in range(len(list_logits)):
497
+ logits = list_logits[i]
498
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
499
+ logits[logits < v[:, [-1]]] = -float("Inf")
500
+ list_logits[i] = logits
501
+
502
+ # apply softmax to convert logits to (normalized) probabilities
503
+ probs = [F.softmax(logits, dim=-1) for logits in list_logits]
504
+ # sample from the distribution
505
+ idx_next = torch.tensor(
506
+ [torch.multinomial(prob, num_samples=1) for prob in probs], device=idx.device
507
+ ) # (c, 1)
508
+ # append sampled index to the running sequence and continue
509
+ idx = torch.cat((idx, idx_next.unsqueeze(0).unsqueeze(-1)), dim=2)
510
+
511
+ return idx
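
A toy run of the `top_p_sample` helper defined at the top of this file shows which tokens survive the nucleus cut (the probabilities below are made up):

# Toy demonstration of top_p_sample. With top_p=0.9 the cumulative mass
# 0.60 + 0.25 + 0.10 crosses 0.9, so only the 0.05 token is dropped and the
# remaining probabilities are renormalised. Values are illustrative.
import torch

from fam.llm.mixins.causal import top_p_sample

probs = torch.tensor([[0.05, 0.60, 0.25, 0.10]])
filtered = top_p_sample(probs, top_p=0.9)
print(filtered)                     # ~[[0.0000, 0.6316, 0.2632, 0.1053]]
assert torch.isclose(filtered.sum(), torch.tensor(1.0))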
fam/llm/mixins/non_causal.py ADDED
@@ -0,0 +1,67 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ class NonCausalInferenceMixin:
8
+ """
9
+ Mixin class for non-causal inference in a language model.
10
+
11
+ This class provides methods for performing non-causal sampling using a language model.
12
+ """
13
+
14
+ @torch.no_grad()
15
+ def _non_causal_sample(
16
+ self, *, idx: torch.Tensor, speaker_embs: Optional[torch.Tensor], temperature: float, top_k: int
17
+ ):
18
+ """
19
+ Perform non-causal sampling.
20
+
21
+ Args:
22
+ idx (torch.Tensor): Input tensor of shape (batch_size, num_in_hierarchies, sequence_length).
23
+ speaker_embs (Optional[torch.Tensor]): Speaker embeddings tensor of shape (batch_size, embedding_size).
24
+ temperature (float): Temperature parameter for scaling the logits.
25
+ top_k (int): Number of top options to consider.
26
+
27
+ Returns:
28
+ torch.Tensor: Sampled output tensor of shape (batch_size, num_out_hierarchies, sequence_length).
29
+ """
30
+ b, c, t = idx.size()
31
+ assert t == self.config.block_size, f"input size {t} != config.block_size {self.config.block_size}"
32
+ # forward the model to get the logits for the index in the sequence
33
+ list_logits, _ = self(idx, speaker_embs=speaker_embs) # c x (b, t, vocab_size)
34
+
35
+ # scale by desired temperature
36
+ list_logits = [logits / temperature for logits in list_logits] # c x (b, t, vocab_size)
37
+
38
+ # optionally crop the logits to only the top k options
39
+ if top_k is not None:
40
+ for i in range(len(list_logits)):
41
+ logits = list_logits[i] # (b, t, vocab_size)
42
+
43
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1))) # (b, t, top_k)
44
+ logits[logits < v[:, :, [-1]]] = -float("Inf")
45
+ list_logits[i] = logits # (b, t, vocab_size)
46
+ assert logits.shape[0] == b and logits.shape[1] == t
47
+
48
+ # apply softmax to convert logits to (normalized) probabilities
49
+ # TODO: check shapes here!
50
+ probs = [F.softmax(logits, dim=-1) for logits in list_logits] # c x (b, t, top_k)
51
+ assert probs[0].shape[0] == b and probs[0].shape[1] == t
52
+
53
+ # TODO: output shape is as expected
54
+ outs = []
55
+ for b_prob in probs: # c x (b, t, top_k) -> (b, t, top_k)
56
+ out = [
57
+ torch.multinomial(prob, num_samples=1).transpose(0, 1).unsqueeze(0) for prob in b_prob
58
+ ] # b x (t, top_k) -> b x (t, 1) -> b x (1, t) -> b x (1, 1, t)
59
+ assert len(out) == b and out[0].shape[0] == 1 and out[0].shape[1] == 1 and out[0].shape[2] == t
60
+ out = torch.cat(out, dim=0) # (b, 1, t)
61
+ assert out.shape[0] == b and out.shape[1] == 1 and out.shape[2] == t
62
+ outs.append(out)
63
+
64
+ out = torch.cat(outs, dim=1) # (b, c, t)
65
+ assert out.shape[0] == b and out.shape[2] == t
66
+
67
+ return out
fam/llm/model.py ADDED
@@ -0,0 +1,524 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass, field
4
+ from typing import Literal, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import tqdm
9
+ from einops import rearrange
10
+ from torch.nn import functional as F
11
+
12
+ from fam.llm.layers import Block, LayerNorm, RMSNorm
13
+ from fam.llm.mixins import CausalInferenceMixin, NonCausalInferenceMixin
14
+
15
+ # from IPython import embed  # debugging helper; unused (all embed() calls below are commented out)
16
+ END_OF_TEXT_TOKEN = 1537
17
+
18
+
19
+ def _select_spkemb(spkemb, mask):
20
+ _, examples, _ = spkemb.shape
21
+ mask = torch.nn.functional.one_hot(mask.long(), num_classes=examples).to(spkemb) # shape: (batch, time, examples)
22
+ spkemb = spkemb.transpose(1, 2) # b ex c -> b c ex
23
+ mask = mask.transpose(1, 2) # b t ex -> b ex t
24
+ return torch.bmm(spkemb, mask).transpose(1, 2) # b c t -> b t c
25
+
26
+
27
+ @dataclass
28
+ class GPTConfig:
29
+ block_size: int = 1024
30
+ vocab_sizes: list = field(default_factory=list)
31
+ target_vocab_sizes: Optional[list] = None
32
+ n_layer: int = 12
33
+ n_head: int = 12
34
+ n_embd: int = 768
35
+ dropout: float = 0.0
36
+ spkemb_dropout: float = 0.0
37
+ bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
38
+ causal: bool = (
39
+ True # auto-regressive or not, i.e. whether to have attention mask that prevents attending to future tokens
40
+ )
41
+ spk_emb_on_text: bool = True # whether to add speaker embedding conditioning to text tokens or not
42
+ norm_type: str = "layernorm" # "rmsnorm" or "layernorm
43
+ rmsnorm_eps: Optional[float] = None # only used for rmsnorm
44
+ nonlinearity_type: str = "gelu" # "gelu" or "swiglu"
45
+ swiglu_multiple_of: Optional[int] = None # MLP hidden layer (using SwiGLU) will be multiple of this
46
+ attn_kernel_type: Literal["torch_attn"] = "torch_attn"
47
+ #Literal["fa2", "torch_attn", "hand"] = "fa2"
48
+ kv_cache_enabled: bool = False # whether to use key-value cache for attention
49
+
50
+
51
+ def _check_speaker_emb_dims(
52
+ speaker_embs: Union[list, torch.Tensor], expected_speaker_emb_dim: int, expected_batch_size: int
53
+ ) -> Union[torch.Tensor, list]:
54
+ """
55
+ Checks that the speaker embedding dimensions are correct, and reshapes them if necessary.
56
+ """
57
+ if type(speaker_embs) == list:
58
+ b_se = len(speaker_embs)
59
+ for i, s in enumerate(speaker_embs):
60
+ if s is not None:
61
+ emb_dim = s.shape[-1]
62
+ if s.ndim == 1:
63
+ speaker_embs[i] = speaker_embs[i].unsqueeze(0)
64
+ else:
65
+ if speaker_embs.ndim == 2:
66
+ # if we have a single speaker embedding for the whole sequence,
67
+ # add a dummy dimension for backwards compatibility
68
+ speaker_embs = speaker_embs[:, None, :]
69
+
70
+ # num_examples is the number of utterances packed into this sequence
71
+ b_se, num_examples, emb_dim = speaker_embs.size()
72
+
73
+ assert b_se == expected_batch_size, f"Batch size mismatch: {b_se} != {expected_batch_size}"
74
+ assert (
75
+ emb_dim == expected_speaker_emb_dim
76
+ ), f"Speaker embedding dimension mismatch: {emb_dim} != {expected_speaker_emb_dim}"
77
+
78
+ return speaker_embs
79
+
80
+
81
+ class GPT(nn.Module, NonCausalInferenceMixin, CausalInferenceMixin):
82
+ def __init__(self, config: GPTConfig, speaker_emb_dim: Optional[int] = None):
83
+ """
84
+ Initialize the GPT model.
85
+
86
+ Args:
87
+ config (GPTConfig): Configuration object for the model.
88
+ speaker_emb_dim (Optional[int]): Dimension of the speaker embedding. Default is None.
89
+ """
90
+ super().__init__()
91
+ assert config.vocab_sizes is not None
92
+ assert config.block_size is not None
93
+ self.config = config
94
+
95
+ self.kv_cache_enabled = False # disabled by default
96
+ self.kv_pos = 0
97
+
98
+ self.speaker_emb_dim = speaker_emb_dim
99
+ self.spk_emb_on_text = config.spk_emb_on_text
100
+ if self.config.causal is True and self.spk_emb_on_text is False:
101
+ print("!!!!!!!!!!!!!!!!!!")
102
+ print(
103
+ f"!!!!!!!! Using DEFAULT of {END_OF_TEXT_TOKEN} as end of text token to find speaker cond masking!! You likely need to change this."
104
+ )
105
+ print("!!!!!!!!!!!!!!!!!!")
106
+ if self.config.causal is False and self.spk_emb_on_text is False:
107
+ raise Exception(
108
+ "Cannot use speaker embedding masking with non-causal model. This is unexpected. Check for relevant changes required in code before proceeding."
109
+ )
110
+
111
+ if config.norm_type == "rmsnorm":
112
+ if config.rmsnorm_eps is None:
113
+ raise Exception("RMSNorm requires rmsnorm_eps to be set")
114
+ ln_f = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
115
+ elif config.norm_type == "layernorm":
116
+ ln_f = LayerNorm(config.n_embd, bias=config.bias)
117
+ else:
118
+ raise Exception(f"Unknown norm type: {config.norm_type}")
119
+
120
+ self.transformer = nn.ModuleDict(
121
+ dict(
122
+ wtes=nn.ModuleList([nn.Embedding(vsize, config.n_embd,) for vsize in config.vocab_sizes]),
123
+ wpe=nn.Embedding(config.block_size, config.n_embd),
124
+ drop=nn.Dropout(config.dropout),
125
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
126
+ ln_f=ln_f,
127
+ )
128
+ )
129
+ if speaker_emb_dim is not None:
130
+ self.speaker_cond_pos = nn.Linear(speaker_emb_dim, config.n_embd, bias=False) # projects the speaker embedding (e.g. 256) up to n_embd (e.g. 2048)
131
+
132
+ self.lm_heads = nn.ModuleList()
133
+ if config.target_vocab_sizes is not None:
134
+ assert config.causal is False
135
+ else:
136
+ assert config.causal is True
137
+
138
+ for vsize in config.vocab_sizes if config.target_vocab_sizes is None else config.target_vocab_sizes:
139
+ self.lm_heads.append(nn.Linear(config.n_embd, vsize, bias=False))
140
+
141
+ if config.target_vocab_sizes is None:
142
+ for i in range(len(config.vocab_sizes)):
143
+ # TODO: do we not need to take the transpose here?
144
+ # https://paperswithcode.com/method/weight-tying
145
+ self.lm_heads[i].weight = self.transformer.wtes[i].weight # type: ignore
146
+ assert len(self.lm_heads) == len(
147
+ self.transformer.wtes # type: ignore
148
+ ), f"Number of heads ({len(self.lm_heads)}) must match number of one-hot embedding matrics ({len(self.transformer.wtes)})." # type: ignore
149
+ # - causal
150
+ # GPT(
151
+ # (transformer): ModuleDict(
152
+ # (wtes): ModuleList(
153
+ # (0): Embedding(2562, 2048)
154
+ # )
155
+ # (wpe): Embedding(2048, 2048)
156
+ # (drop): Dropout(p=0.0, inplace=False)
157
+ # (h): ModuleList(
158
+ # (0-23): 24 x Block(
159
+ # (ln_1): RMSNorm()
160
+ # (ln_2): RMSNorm()
161
+ # (attn): SelfAttention(
162
+ # (c_attn): Linear(in_features=2048, out_features=6144, bias=False)
163
+ # (c_proj): Linear(in_features=2048, out_features=2048, bias=False)
164
+ # (resid_dropout): Dropout(p=0.0, inplace=False)
165
+ # )
166
+ # (mlp): MLP(
167
+ # (swiglu): SwiGLU(
168
+ # (w1): Linear(in_features=2048, out_features=5632, bias=False)
169
+ # (w3): Linear(in_features=2048, out_features=5632, bias=False)
170
+ # )
171
+ # (c_proj): Linear(in_features=5632, out_features=2048, bias=False)
172
+ # (dropout): Dropout(p=0.0, inplace=False)
173
+ # )
174
+ # )
175
+ # )
176
+ # (ln_f): RMSNorm()
177
+ # )
178
+ # (speaker_cond_pos): Linear(in_features=256, out_features=2048, bias=False)
179
+ # (lm_heads): ModuleList(
180
+ # (0): Linear(in_features=2048, out_features=2562, bias=False)
181
+ # )
182
+ # )
183
+ # GPTConfig(block_size=2048, vocab_sizes=[2562], target_vocab_sizes=None, n_layer=24, n_head=16, n_embd=2048, dropout=0.0, spkemb_dropout=0.1, bias=False, causal=True, spk_emb_on_text=True, norm_type='rmsnorm', rmsnorm_eps=1e-05, nonlinearity_type='swiglu', swiglu_multiple_of=256, attn_kernel_type='torch_attn', kv_cache_enabled=False)
184
+ #
185
+ # - non causal
186
+ # GPT(
187
+ # (transformer): ModuleDict(
188
+ # (wtes): ModuleList(
189
+ # (0): Embedding(1538, 384)
190
+ # (1): Embedding(1025, 384)
191
+ # )
192
+ # (wpe): Embedding(1024, 384)
193
+ # (drop): Dropout(p=0.0, inplace=False)
194
+ # (h): ModuleList(
195
+ # (0-5): 6 x Block(
196
+ # (ln_1): LayerNorm()
197
+ # (ln_2): LayerNorm()
198
+ # (attn): SelfAttention(
199
+ # (c_attn): Linear(in_features=384, out_features=1152, bias=False)
200
+ # (c_proj): Linear(in_features=384, out_features=384, bias=False)
201
+ # (resid_dropout): Dropout(p=0.0, inplace=False)
202
+ # )
203
+ # (mlp): MLP(
204
+ # (c_fc): Linear(in_features=384, out_features=1536, bias=False)
205
+ # (gelu): GELU(approximate='none')
206
+ # (c_proj): Linear(in_features=1536, out_features=384, bias=False)
207
+ # (dropout): Dropout(p=0.0, inplace=False)
208
+ # )
209
+ # )
210
+ # )
211
+ # (ln_f): LayerNorm()
212
+ # )
213
+ # (speaker_cond_pos): Linear(in_features=256, out_features=384, bias=False)
214
+ # (lm_heads): ModuleList(
215
+ # (0-5): 6 x Linear(in_features=384, out_features=1025, bias=False)
216
+ # )
217
+ # )
218
+ # GPTConfig(block_size=1024, vocab_sizes=[1538, 1025], target_vocab_sizes=[1025, 1025, 1025, 1025, 1025, 1025], n_layer=6, n_head=6, n_embd=384, dropout=0.0, spkemb_dropout=0.0, bias=False, causal=False, spk_emb_on_text=True, norm_type='layernorm', rmsnorm_eps=None, nonlinearity_type='gelu', swiglu_multiple_of=None, attn_kernel_type='fa2', kv_cache_enabled=False)
219
+ # if config.causal is False:
220
+ # embed()
221
+ # init all weights
222
+ self.apply(self._init_weights)
223
+ # apply special scaled init to the residual projections, per GPT-2 paper
224
+ for pn, p in self.named_parameters():
225
+ if pn.endswith("c_proj.weight"):
226
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
227
+
228
+ # report number of parameters
229
+ print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
230
+
231
+ def get_num_params(self, non_embedding=True):
232
+ """
233
+ Return the number of parameters in the model.
234
+ For non-embedding count (default), the position embeddings get subtracted.
235
+ The token embeddings would too, except due to the parameter sharing these
236
+ params are actually used as weights in the final layer, so we include them.
237
+ """
238
+ n_params = sum(p.numel() for p in self.parameters())
239
+ if non_embedding:
240
+ n_params -= self.transformer.wpe.weight.numel()
241
+ return n_params
242
+
243
+ def _init_weights(self, module):
244
+ if isinstance(module, nn.Linear):
245
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
246
+ if module.bias is not None:
247
+ torch.nn.init.zeros_(module.bias)
248
+ elif isinstance(module, nn.Embedding):
249
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
250
+
251
+ def _mask_spk_emb_on_text(self, idx: torch.Tensor, spk_emb: torch.Tensor) -> torch.Tensor:
252
+ """
253
+ This is in a separate function so we can test it easily.
254
+ """
255
+ # find index of end of text token in each sequence, then generate a binary mask
256
+ # of shape (b, 1, t) to mask out the speaker embedding for all tokens before the end of text token.
257
+ # Note: this does NOT mask the <end_of_text_token> token. This is important so that the first audio token predicted
258
+ # has speaker information to use.
259
+
260
+ # Check in channel dimension 0 as this is usually the first hierarchy where we put the text tokens.
261
+ is_end_of_text = idx[:, 0, :] == END_OF_TEXT_TOKEN
262
+ # use > 0, in case end_of_text_token is repeated for any reason.
263
+ mask = (torch.cumsum(is_end_of_text, dim=-1) > 0).float()
264
+ spk_emb = spk_emb * mask[:, :, None]
265
+
266
+ return spk_emb
267
+
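
A tiny worked example (illustrative token ids, not repository code) of the cumulative-sum masking described in the comments above: positions before the end-of-text token get mask 0, while the end-of-text token itself and everything after it get mask 1, so the first predicted audio token still sees the speaker embedding.

import torch

END_OF_TEXT_TOKEN = 1537
idx_h0 = torch.tensor([[10, 11, 1537, 500, 501]])            # hierarchy 0 of one sequence, (b=1, t=5)
is_end_of_text = idx_h0 == END_OF_TEXT_TOKEN
mask = (torch.cumsum(is_end_of_text, dim=-1) > 0).float()
print(mask)  # tensor([[0., 0., 1., 1., 1.]])
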
268
+ def forward(
269
+ self,
270
+ idx,
271
+ targets=None,
272
+ speaker_embs=None,
273
+ embedding=None,
274
+ speaker_emb_mask=None,
275
+ loss_reduce: Literal["mean", "none"] = "mean",
276
+ ):
277
+ # print(f'{idx.shape}')
278
+ device = idx.device
279
+ b, num_hierarchies, t = idx.size()
280
+
281
+ if speaker_embs is not None:
282
+ speaker_embs = _check_speaker_emb_dims(
283
+ speaker_embs=speaker_embs, expected_speaker_emb_dim=self.speaker_emb_dim, expected_batch_size=b
284
+ )
285
+
286
+ assert (
287
+ t <= self.config.block_size
288
+ ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
289
+
290
+ if self.kv_cache_enabled:
291
+ if self.kv_pos == 0:
292
+ pos = torch.arange(0, t, dtype=torch.long, device=device)
293
+ self.kv_pos += t
294
+ else:
295
+ assert t == 1, "KV cache is only supported for single token inputs"
296
+ pos = torch.tensor([self.kv_pos], dtype=torch.long, device=device) # shape (1)
297
+ self.kv_pos += 1
298
+ else:
299
+ pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
300
+
301
+ # print("here1")
302
+ # forward the GPT model itself
303
+ # assert num_hierarchies == len(
304
+ # self.transformer.wtes
305
+ # ), f"Input tensor has {num_hierarchies} hierarchies, but model has {len(self.transformer.wtes)} set of input embeddings."
306
+
307
+ # embed the tokens, positional encoding, and speaker embedding
308
+ tok_emb = torch.zeros((b, t, self.config.n_embd), device=device)
309
+ # ends up swapping (B, num_hierarchies, t) tokens -> (B, t, c) embeddings.
310
+ # print(f'{len(self.transformer.wtes)=}')
311
+ # print(f'{self.transformer.wtes[0]=}')
312
+ wte = self.transformer.wtes[0]
313
+ #for i in range(num_hierarchies):
314
+ for i, wte in enumerate(self.transformer.wtes):
315
+ # print(f'{idx[:,i,:].shape=}')
316
+ # print(f'{idx[:,i,:]=}')
317
+ # print(f'{torch.max(idx[:,i,:])}')
318
+ # print(f'{torch.min(idx[:,i,:])}')
319
+ # print(f'{wte(idx[:,i,:]).shape=}')
320
+ # print(f'{tok_emb.shape=}')
321
+ mask_pad = idx[:, i, :] == -1 # positions whose token id is -1 (padding)
322
+ masked_idx = idx[:, i, :].clone()
323
+ masked_idx[mask_pad] = 0
324
+ # embed(header='a')
325
+ embedded_idx = wte(masked_idx)
326
+ # embed(header='b')
327
+ # embedding_dim = embedded_idx.shape[-1]
328
+ # mask_expanded = mask_pad.unsqueeze(-1).expand(-1, -1, embedding_dim)
329
+ embedded_idx[mask_pad] = 0
330
+ # embedded_idx = wte(idx[:, i, :])
331
+ # print(embedded_idx[:,:,:10])
332
+ # embed(header='c')
333
+ # embed()
334
+ # masked_embedded_idx =
335
+ tok_emb += embedded_idx
336
+ # tok_emb += wte(idx[:, i, :])
337
+ # embed()
338
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
339
+
340
+ spk_emb = 0.0
341
+ if speaker_embs is not None:
342
+ if type(speaker_embs) == list:
343
+ assert speaker_emb_mask is None
344
+ assert self.training is False
345
+ assert self.spk_emb_on_text is True
346
+ # print(f'{self.config.n_embd=}')
347
+ spk_emb = []
348
+ for speaker_emb_row in speaker_embs:
349
+ if speaker_emb_row is not None:
350
+ spk_emb.append(self.speaker_cond_pos(speaker_emb_row.unsqueeze(0)))
351
+ assert spk_emb[-1].shape == (1, 1, self.config.n_embd), f"spk_emb[-1].shape={spk_emb[-1].shape}"
352
+ else:
353
+ spk_emb.append(torch.zeros((1, 1, self.config.n_embd), device=device, dtype=pos_emb.dtype))
354
+ # print(f'{len(spk_emb)}, {[v.shape for v in spk_emb]=}')
355
+ spk_emb = torch.cat(spk_emb, dim=0)
356
+
357
+ assert (
358
+ spk_emb.ndim == 3 and spk_emb.shape[1] == 1 and spk_emb.shape[0] == b
359
+ ), f"spk_emb.ndim={spk_emb.ndim}, spk_emb.shape={spk_emb.shape}, len(speaker_embs)={len(speaker_embs)}"
360
+ else:
361
+ speakers_embedded = self.speaker_cond_pos(speaker_embs) # shape (b, num_examples, c)
362
+
363
+ if speaker_emb_mask is not None:
364
+ spk_emb = _select_spkemb(speakers_embedded, speaker_emb_mask)
365
+ assert spk_emb.shape == (b, t, self.config.n_embd)
366
+ else:
367
+ spk_emb = speakers_embedded
368
+ # if we don't have a mask, we assume that the speaker embedding is the same for all tokens
369
+ # then num_examples dimension just becomes the time dimension
370
+ assert spk_emb.ndim == 3 and spk_emb.shape[1] == 1
371
+
372
+ if self.training and self.config.spkemb_dropout > 0.0:
373
+ # Remove speaker conditioning at random.
374
+ dropout = torch.ones_like(speakers_embedded) * (
375
+ torch.rand(speakers_embedded.shape[0], 1, 1, device=device) >= self.config.spkemb_dropout
376
+ )
377
+ spk_emb = torch.where(dropout == 0, torch.zeros_like(speakers_embedded), speakers_embedded)
378
+
379
+ if self.spk_emb_on_text is False:
380
+ assert speaker_emb_mask is None, "Not implemented for spk_emb_on_text=False"
381
+ spk_emb = self._mask_spk_emb_on_text(idx, spk_emb)
382
+ elif embedding is not None:
383
+ # spk_emb = embedding
384
+ # spk_emb = torch.zeros((b, t, self.config.n_embd), device=device)
385
+ # for i, wte in enumerate(self.transformer.wtes):
386
+ # print(f'{embedding[:, i, :].shape=}, {embedding.shape=}')
387
+ # print(f'{wte(embedding[:, i, :]).shape=}')
388
+ # spk_emb += wte(embedding[:, i, :])
389
+ spk_emb = self.speaker_cond_pos(embedding)
390
+ # TODO: implement causal attention mask here
391
+ # memo:
392
+ # with b, t, d=2048: tok_emb=(b,t,d), pos_emb=(t,d), spk_emb=(b,1,d)
393
+ # train: tok_emb.shape=torch.Size([128, 187, 2048]), pos_emb.shape=torch.Size([187, 2048]), spk_emb.shape=torch.Size([128, 1, 1, 187]) <- should spk_emb be (b, 1, 2048) here?
394
+ # sample: tok_emb.shape=torch.Size([2, 369, 2048]), pos_emb.shape=torch.Size([369, 2048]), spk_emb.shape=torch.Size([2, 1, 2048])
395
+ # print(f'{tok_emb.shape=}, {pos_emb.shape=}, {spk_emb.shape=}')
396
+ x = self.transformer.drop(tok_emb + pos_emb + spk_emb)
397
+ for block in self.transformer.h:
398
+ x = block(x)
399
+ x = self.transformer.ln_f(x)
400
+
401
+ if targets is not None:
402
+ # if we are given some desired targets also calculate the loss
403
+ list_logits = [lm_head(x) for lm_head in self.lm_heads]
404
+ # print(f'{len(list_logits)=}, {list_logits[0].shape=}')
405
+ # embed(header='cc')
406
+
407
+ losses = [
408
+ F.cross_entropy(
409
+ logits.view(-1, logits.size(-1)),
410
+ targets[:, i, :].contiguous().view(-1),
411
+ ignore_index=-1,
412
+ reduction=loss_reduce,
413
+ )
414
+ for i, logits in enumerate(list_logits)
415
+ ]
416
+ # TODO: should we do this better without stack somehow?
417
+ # embed(header='bb')
418
+ losses = torch.stack(losses)
419
+ if loss_reduce == "mean":
420
+ # embed(header='aa')
421
+ losses = losses.mean()
422
+ else:
423
+ losses = rearrange(losses, "h (b t) -> b h t", h=len(self.lm_heads), b=b, t=t)
424
+ else:
425
+ # inference-time mini-optimization: only forward the lm_head on the very last position
426
+ if self.config.causal:
427
+ list_logits = [
428
+ lm_head(x[:, [-1], :]) for lm_head in self.lm_heads
429
+ ] # note: using list [-1] to preserve the time dim
430
+ # print(f'{len(list_logits)=}, {list_logits[0].shape=}')
431
+ else:
432
+ list_logits = [lm_head(x) for lm_head in self.lm_heads]
433
+ losses = None
434
+
435
+ return list_logits, losses
436
+
437
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
438
+ # start with all of the candidate parameters
439
+ param_dict = {pn: p for pn, p in self.named_parameters()}
440
+ # filter out those that do not require grad
441
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
442
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
443
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
444
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
445
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
446
+ optim_groups = [
447
+ {"params": decay_params, "weight_decay": weight_decay},
448
+ {"params": nodecay_params, "weight_decay": 0.0},
449
+ ]
450
+ num_decay_params = sum(p.numel() for p in decay_params)
451
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
452
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
453
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
454
+ # Create AdamW optimizer and use the fused version if it is available
455
+ fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
456
+ use_fused = fused_available and device_type == "cuda"
457
+ extra_args = dict(fused=True) if use_fused else dict()
458
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
459
+ print(f"using fused AdamW: {use_fused}")
460
+
461
+ return optimizer
462
+
463
+ @torch.no_grad()
464
+ def generate(
465
+ self,
466
+ idx: torch.Tensor,
467
+ max_new_tokens: int,
468
+ seq_lens: Optional[list] = None,
469
+ temperature: float = 1.0,
470
+ top_k: Optional[int] = None,
471
+ top_p: Optional[float] = None,
472
+ speaker_embs: Optional[torch.Tensor] = None,
473
+ batch_size: Optional[int] = None,
474
+ guidance_scale: Optional[float] = None,
475
+ ):
476
+ """
477
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,num_hierarchies,t)) and complete
478
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
479
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
480
+ """
481
+ assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens"
482
+
483
+ if self.config.causal:
484
+ if seq_lens is None or batch_size is None:
485
+ raise Exception("seq_lens and batch_size must be provided for causal sampling")
486
+
487
+ return self._causal_sample(
488
+ idx=idx,
489
+ max_new_tokens=max_new_tokens,
490
+ seq_lens=seq_lens,
491
+ temperature=temperature,
492
+ top_k=top_k,
493
+ top_p=top_p,
494
+ speaker_embs=speaker_embs,
495
+ batch_size=batch_size,
496
+ guidance_scale=guidance_scale,
497
+ )
498
+
499
+ else:
500
+ if seq_lens is not None:
501
+ raise Exception("seq_lens is not supported yet for non-causal sampling")
502
+
503
+ if batch_size is None:
504
+ raise Exception("batch_size must be provided for non-causal sampling")
505
+
506
+ if guidance_scale is not None:
507
+ raise Exception("guidance_scale is not supported for non-causal sampling")
508
+
509
+ if top_p is not None:
510
+ raise Exception("top_p is not supported for non-causal sampling")
511
+
512
+ out = []
513
+ for start_index in tqdm.tqdm(range(0, idx.shape[0], batch_size), desc="non-causal batching"):
514
+ end_index = min(start_index + batch_size, idx.shape[0])
515
+ out.append(
516
+ self._non_causal_sample(
517
+ idx=idx[start_index:end_index],
518
+ speaker_embs=speaker_embs[start_index:end_index] if speaker_embs is not None else None,
519
+ temperature=temperature,
520
+ top_k=top_k,
521
+ )
522
+ )
523
+ return torch.cat(out, dim=0)
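
A minimal usage sketch for the class above (not from the repository): it builds a scaled-down causal configuration and runs one forward pass. The layer sizes are illustrative rather than the shipped checkpoint's, and it assumes the `Block`/`RMSNorm` implementations in fam/llm/layers.py accept these settings. At inference time without `targets`, the causal path returns logits for the last position only, which is what `_causal_sample` consumes.

import torch
from fam.llm.model import GPT, GPTConfig

cfg = GPTConfig(
    block_size=256,
    vocab_sizes=[2562],
    n_layer=2,
    n_head=4,
    n_embd=128,
    causal=True,
    norm_type="rmsnorm",
    rmsnorm_eps=1e-5,
    nonlinearity_type="swiglu",
    swiglu_multiple_of=32,
    bias=False,
)
model = GPT(cfg, speaker_emb_dim=256).eval()

idx = torch.randint(0, 2562, (1, 1, 16))   # (b, num_hierarchies, t)
spk = torch.randn(1, 256)                  # (b, speaker_emb_dim)
with torch.no_grad():
    list_logits, losses = model(idx, speaker_embs=spk)
print(list_logits[0].shape)  # (1, 1, 2562): logits for the last position only
print(losses)                # None, since no targets were passed
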
fam/llm/sample.py ADDED
@@ -0,0 +1,731 @@
1
+ import dataclasses
2
+ import hashlib
3
+ import json
4
+ import os
5
+ import pathlib
6
+ import shutil
7
+ import subprocess
8
+ import tempfile
9
+ from contextlib import nullcontext
10
+ from dataclasses import dataclass
11
+ from typing import List, Literal, Optional, Type, Union
12
+
13
+ import librosa
14
+ import torch
15
+ import tqdm
16
+ import tqdm.contrib.concurrent
17
+ import tyro
18
+ from huggingface_hub import snapshot_download
19
+
20
+ from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook, TiltedEncodec
21
+ from fam.llm.decoders import Decoder, EncodecDecoder
22
+ from fam.llm.enhancers import BaseEnhancer, get_enhancer
23
+ from fam.llm.model import GPT, GPTConfig
24
+ from fam.llm.utils import normalize_text
25
+ from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
26
+ from fam.quantiser.text.tokenise import TrainedBPETokeniser
27
+
28
+ @dataclass
29
+ class InferenceConfig:
30
+ ckpt_path: str # path to checkpoint
31
+ output_dir: str
32
+ num_samples: int = 10 # number of samples to draw
33
+ seed: int = 1337 # random seed
34
+ device: str = "cuda"
35
+ dtype: str = "bfloat16"
36
+ compile: bool = False
37
+ init_from: str = "resume" # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
38
+
39
+ def __str__(self):
40
+ field_strs = []
41
+ for field in dataclasses.fields(self):
42
+ value = getattr(self, field.name)
43
+ field_strs.append(f" {field.name}: {value}")
44
+
45
+ return "InferenceConfig:\n" + "\n".join(field_strs)
46
+
47
+
48
+ class Model:
49
+ """
50
+ Class to sample from a trained model.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ config: InferenceConfig,
56
+ tokenizer_cls: Type[TrainedBPETokeniser],
57
+ decoder_cls: Type[Decoder],
58
+ data_adapter_fn,
59
+ use_kv_cache: Optional[Literal["none", "flash_decoding", "vanilla"]] = None,
60
+ first_model_path: Optional[str] = None,
61
+ ):
62
+ # TODO: disentangle the encodec-specific details and magic numbers from the rest of this code (esp. at encoder-only / second-stage model inference)
63
+ # TODO: remove magic number
64
+ self._encodec_codes_pad_token = 1024
65
+ self._num_encodec_codebooks = 8
66
+ self.config = config
67
+ self.use_kv_cache = use_kv_cache
68
+
69
+ torch.manual_seed(config.seed)
70
+ torch.cuda.manual_seed(config.seed)
71
+ torch.backends.cuda.matmul.allow_tf32 = True if config.dtype != "float32" else False # allow tf32 on matmul
72
+ torch.backends.cudnn.allow_tf32 = True if config.dtype != "float32" else False # allow tf32 on cudnn
73
+ device_type = "cuda" if "cuda" in config.device else "cpu" # for later use in torch.autocast
74
+ ptdtype = {
75
+ "float32": torch.float32,
76
+ "tfloat32": torch.float32,
77
+ "bfloat16": torch.bfloat16,
78
+ "float16": torch.float16,
79
+ }[config.dtype]
80
+ self._ctx = (
81
+ nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
82
+ )
83
+
84
+ self.use_bpe_tokenizer = False
85
+ self.load_meta = None
86
+ self.speaker_cond = None
87
+ self.meta = None
88
+ self.model = None
89
+ self.checkpoint_config = None
90
+ self.vocab_sizes = None
91
+ self.smodel = None
92
+ self.first_model_path = first_model_path
93
+
94
+ self._init_model()
95
+
96
+ self.tokenizer = tokenizer_cls(**self.meta["tokenizer"])
97
+ self.decoder = decoder_cls(
98
+ tokeniser_decode_fn=self.tokenizer.decode,
99
+ output_dir=self.config.output_dir,
100
+ data_adapter_fn=data_adapter_fn,
101
+ )
102
+
103
+ def _init_model(self):
104
+ if self.config.init_from == "resume":
105
+ # init from a model saved in a specific directory
106
+ checkpoint = torch.load(self.config.ckpt_path, map_location=self.config.device)
107
+ self.vocab_sizes = checkpoint["model_args"]["vocab_sizes"]
108
+
109
+ self.load_meta = False
110
+ self.speaker_cond = False
111
+
112
+ if "config" in checkpoint:
113
+ self.checkpoint_config = checkpoint["config"]
114
+
115
+ self.meta = checkpoint["meta"]
116
+ load_meta = True
117
+
118
+ if load_meta:
119
+ self.use_bpe_tokenizer = "stoi" not in self.meta or "itos" not in self.meta
120
+ self.speaker_cond = self.meta.get("speaker_cond")
121
+
122
+ if self.speaker_cond:
123
+ speaker_emb_size = self.meta["speaker_emb_size"]
124
+
125
+ model_args = checkpoint["model_args"]
126
+ if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
127
+ self._encodec_ctx_window = model_args["block_size"]
128
+
129
+ gptconf = GPTConfig(**model_args)
130
+
131
+ # TODO: rename `speaker_emb_dim` to `speaker_emb_size`.
132
+ self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size if self.speaker_cond else None)
133
+ if not getattr(self.config, 'train_from_scratch', False):
134
+ state_dict = checkpoint["model"]
135
+ unwanted_prefix = "_orig_mod."
136
+ for k, v in list(state_dict.items()):
137
+ if k.startswith(unwanted_prefix):
138
+ state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
139
+ self.model.load_state_dict(state_dict)
140
+ else:
141
+ print("training a model from scratch!!")
142
+
143
+ if self.first_model_path is not None:
144
+ if self.checkpoint_config.get("causal", True):
145
+ new_dict = {}
146
+ state_dict = torch.load(self.first_model_path)['state_dict']
147
+ for key, val in state_dict.items():
148
+ if 'first_stage_model_transformer' in key:
149
+ new_dict[key.replace('first_stage_model_transformer.', '')] = val
150
+ self.model.load_state_dict(new_dict)
151
+
152
+ # model
153
+ self.model.eval()
154
+ self.model.to(self.config.device)
155
+
156
+ if self.config.compile:
157
+ from einops._torch_specific import allow_ops_in_compiled_graph
158
+
159
+ allow_ops_in_compiled_graph()
160
+ self.model = torch.compile(self.model) # type: ignore
161
+
162
+ if self.use_kv_cache is not None:
163
+ if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
164
+ raise Exception("kv_cache not supported for non-causal models!")
165
+
166
+ if self.use_kv_cache == "flash_decoding":
167
+ self.model.enable_kv_cache()
168
+ for block in self.model.transformer.h:
169
+ block.attn.attn_kernel_type = "fd"
170
+ elif self.use_kv_cache == "vanilla":
171
+ for block in self.model.transformer.h:
172
+ if block.attn.attn_kernel_type != "fa2":
173
+ raise Exception(
174
+ f"kv_cache only supported for flash attention 2 but found {block.attn.attn_kernel_type} inside model!"
175
+ )
176
+ self.model.enable_kv_cache()
177
+ else:
178
+ raise NotImplementedError(f"kv_cache type {self.use_kv_cache} not implemented!")
179
+
180
+ def causal_sample(
181
+ self,
182
+ *,
183
+ texts: list[str],
184
+ batch_size: int,
185
+ max_new_tokens: int,
186
+ temperature: Optional[float],
187
+ top_k: Optional[int],
188
+ top_p: Optional[float],
189
+ speaker_embs: Optional[torch.Tensor] = None,
190
+ guidance_scale: Optional[float] = None,
191
+ ) -> list[torch.Tensor]:
192
+ """
193
+ Returns list of torch.Tensors of tokens. Each tensor is of shape (1, c, t) where c is the number of codebooks.
194
+ Any flattening / interleaving / tilting gets reversed before the output is returned.
195
+ """
196
+ if speaker_embs is not None:
197
+ assert len(texts) == len(speaker_embs)
198
+
199
+ encoded_texts = [self.tokenizer.encode(text) for text in texts]
200
+
201
+ ## create multiple hierarchies and get seq_lens
202
+ seq_lens = []
203
+ xs = []
204
+ for i, encoded_text in enumerate(encoded_texts):
205
+ encoded_text = torch.tensor([encoded_text], dtype=torch.long, device=self.config.device)
206
+ # TODO: remove magic number
207
+ xs.append(
208
+ torch.cat(
209
+ # [1st hierarchy of text, *remaining hierarchies of padded tokens]
210
+ # TODO: self.vocab_sizes should be from the model config?
211
+ [encoded_text, *[torch.ones_like(encoded_text) * 1024] * (len(self.vocab_sizes) - 1)],
212
+ dim=0,
213
+ ).unsqueeze(0)
214
+ ) # b x [(b=1, c, t)]
215
+ seq_lens.append(xs[-1].shape[-1])
216
+ max_len = max(seq_lens)
217
+ assert len(xs) == len(seq_lens)
218
+
219
+ ## equalise the shapes in the batch. we can use torch.zeros as tokens > seq_lens will be masked out.
220
+ x = torch.zeros((len(encoded_texts), xs[0].shape[1], max_len), dtype=torch.long, device=self.config.device)
221
+ for i, _xs in enumerate(xs):
222
+ assert _xs.shape[-1] == seq_lens[i]
223
+ x[i, :, : seq_lens[i]] = _xs
224
+
225
+ ## check that the input is correct
226
+ for i in range(x.shape[0]):
227
+ assert x[i, 0, : seq_lens[i]].tolist() == encoded_texts[i]
228
+
229
+ # TODO: remove magic number
230
+ if x.shape[1] > 1:
231
+ assert set(x[i, 1, : seq_lens[i]].tolist()) == set([1024])
232
+
233
+ assert x.shape[0] == speaker_embs.shape[0] if speaker_embs is not None else True
234
+
235
+ if self.speaker_cond is False:
236
+ speaker_embs = None
237
+
238
+ # run sampling loop
239
+ with torch.no_grad():
240
+ with self._ctx: # type: ignore
241
+ to_return = []
242
+ for k in range(self.config.num_samples):
243
+ assert seq_lens is not None
244
+ assert batch_size is not None
245
+
246
+ if max(seq_lens) + max_new_tokens >= self.model.config.block_size:
247
+ raise Exception(
248
+ f"max_new_tokens {max_new_tokens} too large! Choose {self.model.config.block_size - max(seq_lens) - 1} instead."
249
+ )
250
+
251
+ y = self.model.generate(
252
+ x,
253
+ max_new_tokens,
254
+ seq_lens=seq_lens,
255
+ temperature=temperature,
256
+ top_k=top_k,
257
+ top_p=top_p,
258
+ speaker_embs=speaker_embs,
259
+ batch_size=batch_size,
260
+ guidance_scale=guidance_scale,
261
+ )
262
+ for i in range(len(y)):
263
+ to_return.append(self.decoder.decode(tokens=y[i].tolist(), causal=True))
264
+
265
+ return to_return
266
+
267
+ def non_causal_sample(
268
+ self,
269
+ *,
270
+ texts: list[str],
271
+ encodec_tokens: list[torch.Tensor],
272
+ batch_size: int,
273
+ top_k: Optional[int],
274
+ temperature: Optional[float],
275
+ speaker_embs: Optional[torch.Tensor] = None,
276
+ ) -> list[str]:
277
+ """
278
+ Returns paths to saved audio files.
279
+ """
280
+ if speaker_embs is not None:
281
+ assert len(texts) == len(speaker_embs)
282
+
283
+ encoded_texts = [self.tokenizer.encode(text) for text in texts]
284
+
285
+ # setup input
286
+ # TODO: same code is used during data prep. refactor
287
+ padded_hierarchies_inputs = []
288
+ for encoded_text, encodec_token in zip(encoded_texts, encodec_tokens):
289
+ x = torch.tensor(encoded_text, dtype=torch.long, device=self.config.device)[
290
+ None, None, ...
291
+ ] # (b=1, c=1, t)
292
+
293
+ # TODO: should only happen if decoder is EncodecDecoder?
294
+ assert encodec_token.shape[0] == 1
295
+ encodec_token = encodec_token[0].tolist() # (b=1, c, t) -> (c, t)
296
+ assert len(encodec_token) >= 1 and len(encodec_token) <= self._num_encodec_codebooks
297
+
298
+ ## setup hierarchies of tokens
299
+ # TODO: refactor and merge with code in processing.py
300
+ text_tokens = encoded_text # (t,)
301
+
302
+ # print(f'{encodec_tokens[0].shape=}, {len(encodec_tokens)}')
303
+ hierarchies_in = []
304
+ hierarchies_in.append(text_tokens + encodec_token[0] + [self._encodec_codes_pad_token])
305
+ hierarchies_in.append(
306
+ [self._encodec_codes_pad_token] * len(text_tokens) + encodec_token[1] + [self._encodec_codes_pad_token]
307
+ )
308
+
309
+ ## adding padding / cutting to the right size as needed
310
+ # TODO: refactor and merge with code in processing.py
311
+ padded_hierarchies_input = []
312
+ for _, t_hierarchy in enumerate(hierarchies_in):
313
+ assert len(t_hierarchy) == len(hierarchies_in[0])
314
+ if len(t_hierarchy) < self._encodec_ctx_window:
315
+ padded_hierarchies_input.append(
316
+ t_hierarchy + [self._encodec_codes_pad_token] * (self._encodec_ctx_window - len(t_hierarchy))
317
+ )
318
+ elif len(t_hierarchy) > self._encodec_ctx_window:
319
+ padded_hierarchies_input.append(t_hierarchy[: self._encodec_ctx_window])
320
+ else:
321
+ padded_hierarchies_input.append(t_hierarchy)
322
+
323
+ padded_hierarchies_inputs.append(padded_hierarchies_input)
324
+
325
+ ## check that the input is correct
326
+ in_x = torch.tensor(padded_hierarchies_inputs, dtype=torch.long, device=self.config.device)
327
+ assert in_x.shape[0] == speaker_embs.shape[0] if speaker_embs is not None else True
328
+
329
+ if self.speaker_cond is False:
330
+ speaker_embs = None
331
+
332
+ # run sampling loop
333
+ with torch.no_grad():
334
+ with self._ctx: # type: ignore
335
+ to_return = []
336
+ for k in range(self.config.num_samples):
337
+ y = self.model.generate(
338
+ in_x,
339
+ None,
340
+ temperature=temperature,
341
+ top_k=top_k,
342
+ # TODO: handle separate top_p for this model explicitly
343
+ top_p=None,
344
+ speaker_embs=speaker_embs,
345
+ batch_size=batch_size,
346
+ guidance_scale=None,
347
+ )
348
+
349
+ b_tokens = torch.cat([in_x, y], dim=1)
350
+ for tokens in b_tokens:
351
+ try:
352
+ to_return.append(self.decoder.decode(tokens=tokens.tolist(), causal=False))
353
+ except Exception as e:
354
+ print("failed to run MBD.")
355
+ print(f"reason: {str(e)}")
356
+ to_return.append(None)
357
+
358
+ return to_return
359
+
360
+ def __call__(
361
+ self,
362
+ *,
363
+ texts: list[str],
364
+ batch_size: int,
365
+ max_new_tokens: Optional[int],
366
+ top_k: Optional[int],
367
+ top_p: Optional[float],
368
+ temperature: Optional[float],
369
+ encodec_tokens: Optional[list[torch.Tensor]] = None,
370
+ speaker_embs: Optional[torch.Tensor] = None,
371
+ guidance_scale: Optional[float] = None,
372
+ ):
373
+ if self.checkpoint_config.get("causal", True):
374
+ return self.causal_sample(
375
+ texts=texts,
376
+ batch_size=batch_size,
377
+ speaker_embs=speaker_embs,
378
+ guidance_scale=guidance_scale,
379
+ max_new_tokens=max_new_tokens,
380
+ top_k=top_k,
381
+ top_p=top_p,
382
+ temperature=temperature,
383
+ )
384
+ else:
385
+ assert encodec_tokens is not None
386
+ assert guidance_scale is None
387
+ assert max_new_tokens is None
388
+ assert top_p is None
389
+
390
+ return self.non_causal_sample(
391
+ texts=texts,
392
+ encodec_tokens=encodec_tokens,
393
+ batch_size=batch_size,
394
+ speaker_embs=speaker_embs,
395
+ top_k=top_k,
396
+ temperature=temperature,
397
+ )
398
+
399
+
400
+ def save_result_metadata(wav_path, ref_path, text, first_stage_ckpt_path, second_stage_ckpt_path):
401
+ if first_stage_ckpt_path is None or second_stage_ckpt_path is None:
402
+ return
403
+ json.dump(
404
+ {
405
+ "speaker": ref_path,
406
+ "text": text,
407
+ },
408
+ pathlib.Path(str(wav_path) + ".json").open("w"),
409
+ )
410
+
411
+
412
+ def get_cached_file(file_or_uri: str):
413
+ """
414
+ If it's an s3 file, download it to a local temporary file and return that path.
415
+ Otherwise return the path as is.
416
+ """
417
+ is_uri = file_or_uri.startswith("http")
418
+
419
+ cache_path = None
420
+ if is_uri:
421
+ ext = pathlib.Path(file_or_uri).suffix
422
+ # hash the file path to get the cache name
423
+ _cache_name = "audio_" + hashlib.md5(file_or_uri.encode("utf-8")).hexdigest() + ext
424
+
425
+ os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
426
+ cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")
427
+
428
+ if not os.path.exists(cache_path):
429
+ command = f"curl -o {cache_path} {file_or_uri}"
430
+ subprocess.run(command, shell=True, check=True)
431
+ else:
432
+ if os.path.exists(file_or_uri):
433
+ cache_path = file_or_uri
434
+ else:
435
+ raise FileNotFoundError(f"File {file_or_uri} not found!")
436
+
437
+ # load the audio to check that the file is readable; the original >= 30s duration check below is disabled
438
+ audio, sr = librosa.load(cache_path)
439
+ # assert librosa.get_duration(y=audio, sr=sr) >= 30, "Speaker reference audio file needs to be >= 30s in duration."
440
+
441
+ return cache_path
442
+
443
+
444
+ def get_cached_embedding(local_file_path: str, spkemb_model):
445
+ if not os.path.exists(local_file_path):
446
+ raise FileNotFoundError(f"File {local_file_path} not found!")
447
+
448
+ # hash the file path to get the cache name
449
+ _cache_name = "embedding_" + hashlib.md5(local_file_path.encode("utf-8")).hexdigest() + ".pt"
450
+
451
+ os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
452
+ cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")
453
+
454
+ if not os.path.exists(cache_path):
455
+ spk_emb = spkemb_model.embed_utterance_from_file(local_file_path, numpy=False).unsqueeze(0) # (b=1, c)
456
+ torch.save(spk_emb, cache_path)
457
+ else:
458
+ spk_emb = torch.load(cache_path)
459
+
460
+ return spk_emb
461
+
462
+
463
+ def _sample_utterance_batch(
464
+ texts: list[str],
465
+ spk_cond_paths: list[Optional[str]],
466
+ spkemb_model,
467
+ first_stage_model,
468
+ second_stage_model,
469
+ enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
470
+ first_stage_ckpt_path: str,
471
+ second_stage_ckpt_path: str,
472
+ guidance_scale: Optional[float],
473
+ max_new_tokens: int,
474
+ top_k: Optional[int],
475
+ top_p: Optional[float],
476
+ temperature: Optional[float],
477
+ batch_size: int = 128,
478
+ ) -> List[str]:
479
+
480
+ speaker_embs = []
481
+ refs = spk_cond_paths.copy()
482
+
483
+ # multithreaded loop to cache all the files
484
+ spk_cond_paths = tqdm.contrib.concurrent.thread_map(
485
+ get_cached_file, spk_cond_paths, desc="getting cached speaker ref files"
486
+ )
487
+
488
+ for i, (text, spk_cond_path) in tqdm.tqdm(
489
+ enumerate(zip(texts, spk_cond_paths)), total=len(texts), desc="calculating speaker embeddings"
490
+ ):
491
+ texts[i] = normalize_text(text)
492
+ speaker_embs.append(get_cached_embedding(spk_cond_path, spkemb_model) if spk_cond_path else None)
493
+
494
+ b_speaker_embs = torch.cat(speaker_embs, dim=0)
495
+ b_tokens = first_stage_model(
496
+ texts=texts,
497
+ speaker_embs=b_speaker_embs,
498
+ batch_size=batch_size,
499
+ guidance_scale=guidance_scale,
500
+ top_p=top_p,
501
+ top_k=top_k,
502
+ temperature=temperature,
503
+ max_new_tokens=max_new_tokens,
504
+ )
505
+
506
+ # TODO: set batch size for second stage model!
507
+ wav_files = second_stage_model(
508
+ texts=texts,
509
+ encodec_tokens=b_tokens,
510
+ speaker_embs=b_speaker_embs,
511
+ batch_size=batch_size,
512
+ guidance_scale=None,
513
+ top_p=None,
514
+ top_k=top_k,
515
+ temperature=temperature,
516
+ max_new_tokens=None,
517
+ )
518
+
519
+ for text, tokens, speaker_embs, ref_name, wav_file in zip(texts, b_tokens, b_speaker_embs, refs, wav_files):
520
+ if wav_file is None:
521
+ continue
522
+
523
+ with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
524
+ if enhancer is not None:
525
+ enhancer = get_enhancer(enhancer) if isinstance(enhancer, str) else enhancer
526
+ enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
527
+ # copy enhanced_tmp.name back to wav_file
528
+ print(f"copying enhanced file from {enhanced_tmp.name} to {str(wav_file) + '.wav'}.")
529
+ shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
530
+
531
+ save_result_metadata(
532
+ wav_file,
533
+ ref_name,
534
+ text,
535
+ first_stage_ckpt_path,
536
+ second_stage_ckpt_path,
537
+ )
538
+ return [str(w) + ".wav" if not str(w).endswith(".wav") else str(w) for w in wav_files]
539
+
540
+
541
+ def sample_utterance(
542
+ text: str,
543
+ spk_cond_path: Optional[str],
544
+ spkemb_model,
545
+ first_stage_model,
546
+ second_stage_model,
547
+ enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
548
+ first_stage_ckpt_path: str,
549
+ second_stage_ckpt_path: str,
550
+ guidance_scale: Optional[float],
551
+ max_new_tokens: int,
552
+ top_k: Optional[int],
553
+ top_p: Optional[float],
554
+ temperature: Optional[float],
555
+ ) -> str:
556
+ # NOTE: supports max. 220 characters atm.
557
+ # Long form synthesis coming soon...
558
+ MAX_CHARS = 220
559
+ if len(text) > MAX_CHARS:
560
+ print(
561
+ f"\n***WARNING: Max {MAX_CHARS} characters supported. Provided: {len(text)}. Truncating and generating speech...Can lead to unpredictable speech at the end.***"
562
+ )
563
+
564
+ return _sample_utterance_batch(
565
+ texts=[text],
566
+ spk_cond_paths=[spk_cond_path],
567
+ spkemb_model=spkemb_model,
568
+ first_stage_model=first_stage_model,
569
+ second_stage_model=second_stage_model,
570
+ enhancer=enhancer,
571
+ first_stage_ckpt_path=first_stage_ckpt_path,
572
+ second_stage_ckpt_path=second_stage_ckpt_path,
573
+ batch_size=1,
574
+ guidance_scale=guidance_scale,
575
+ max_new_tokens=max_new_tokens,
576
+ top_k=top_k,
577
+ top_p=top_p,
578
+ temperature=temperature,
579
+ )[0]
580
+
581
+
582
+ def build_models(config_first_stage, config_second_stage, device, use_kv_cache, first_model_path=None):
583
+ smodel = SpeakerEncoder(device=device, eval=True, verbose=False)
584
+ data_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
585
+ llm_first_stage = Model(
586
+ config_first_stage,
587
+ TrainedBPETokeniser,
588
+ EncodecDecoder,
589
+ data_adapter_fn=data_adapter.decode,
590
+ use_kv_cache=use_kv_cache,
591
+ first_model_path=first_model_path,
592
+ )
593
+ data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
594
+ llm_second_stage = Model(
595
+ config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
596
+ )
597
+ return smodel, llm_first_stage, llm_second_stage
598
+
599
+
600
+ def get_first_stage_path(model_dir: str):
601
+ """Absolute path to checkpoint for the first stage model."""
602
+ return os.path.join(os.path.expanduser(model_dir), "first_stage.pt")
603
+
604
+
605
+ def get_second_stage_path(model_dir: str):
606
+ """Absolute path to checkpoint for the second stage model."""
607
+ return os.path.join(os.path.expanduser(model_dir), "second_stage.pt")
608
+
609
+
610
+ @dataclass
611
+ class SamplingControllerConfig:
612
+ """Sample from a trained model."""
613
+ huggingface_repo_id: str
614
+ """Hugging Face Hub repo id to download the model checkpoints from."""
615
+
616
+ spk_cond_path: str
617
+ """Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3"""
618
+
619
+ text: str = "コトバテクノロジーズのミッションは、音声基盤モデルを作ることです。"
620
+ """Text to synthesise."""
621
+
622
+ num_samples: int = 1
623
+ """Number of samples to generate from each model."""
624
+
625
+ max_new_tokens: int = 864
626
+ """Maximum number of new tokens to generate from the first stage model."""
627
+
628
+ temperature: float = 1.0
629
+ """Temperature for sampling applied to both models."""
630
+
631
+ top_k: Optional[int] = None
632
+ """Top k for sampling applied to both models."""
633
+
634
+ top_p: Optional[float] = 0.95
635
+ """Top p for sampling applied to first-stage model."""
636
+
637
+ seed: int = 1337
638
+ """Random seed for sampling."""
639
+
640
+ device: Literal["cuda", "cpu"] = "cuda"
641
+ """Device to use for sampling."""
642
+
643
+ dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = "bfloat16"
644
+ """Data type to use for sampling."""
645
+
646
+ compile: bool = False
647
+ """Whether to compile the model using PyTorch 2.0."""
648
+
649
+ enhancer: Optional[Literal["df"]] = "df"
650
+ """Enhancer to use for post-processing."""
651
+
652
+ init_from: str = "resume"
653
+ """Either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')."""
654
+
655
+ use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = None
656
+ """Type of kv caching to use for inference: 1) [none] no kv caching, 2) [flash_decoding] use the
657
+ flash decoding kernel, 3) [vanilla] use flash attention 2 with hand implemented kv-cache."""
658
+
659
+ output_dir: str = "samples/"
660
+ """Relative path to output directory"""
661
+
662
+ guidance_scale: Optional[float] = 3.0
663
+ """Guidance scale for sampling."""
664
+
665
+ batch_size: int = 128
666
+ """Batch size to use for sampling. Note that the batch size gets doubled when guidance is used. For H100, and 1B model,
667
+ 1 w/ guidance and 1 w/o guidance work well (without kv-caching). With kv-caching, 128 (w/o guidance) and
668
+ 64 (w/ guidance) works well."""
669
+ """
670
+ Sample from a trained model.
671
+ """
672
+ first_model_path: str = None
673
+ """first model path"""
674
+
675
+
676
+
677
+ if __name__ == "__main__":
678
+ # TODO: add support for batch sampling via CLI. Function has been implemented above.
679
+ sampling_config = tyro.cli(SamplingControllerConfig, use_underscores=True)
680
+
681
+ model_dir = snapshot_download(repo_id=sampling_config.huggingface_repo_id)
682
+ first_stage_ckpt_path = get_first_stage_path(model_dir)
683
+ second_stage_ckpt_path = get_second_stage_path(model_dir)
684
+
685
+ config_first_stage = InferenceConfig(
686
+ ckpt_path=first_stage_ckpt_path,
687
+ num_samples=sampling_config.num_samples,
688
+ seed=sampling_config.seed,
689
+ device=sampling_config.device,
690
+ dtype=sampling_config.dtype,
691
+ compile=sampling_config.compile,
692
+ init_from=sampling_config.init_from,
693
+ output_dir=sampling_config.output_dir,
694
+ )
695
+
696
+ config_second_stage = InferenceConfig(
697
+ ckpt_path=second_stage_ckpt_path,
698
+ num_samples=sampling_config.num_samples,
699
+ seed=sampling_config.seed,
700
+ device=sampling_config.device,
701
+ dtype=sampling_config.dtype,
702
+ compile=sampling_config.compile,
703
+ init_from=sampling_config.init_from,
704
+ output_dir=sampling_config.output_dir,
705
+ )
706
+
707
+ sampling_config.max_new_tokens *= (
708
+ 2 # deal with max_new_tokens for flattened interleaving! (should scale with num_codebooks?)
709
+ )
710
+
711
+ # define models
712
+ smodel, llm_first_stage, llm_second_stage = build_models(
713
+ config_first_stage, config_second_stage, sampling_config.device, sampling_config.use_kv_cache, sampling_config.first_model_path,
714
+ )
715
+
716
+ print(f"Synthesising utterance...")
717
+ sample_utterance(
718
+ sampling_config.text,
719
+ os.path.expanduser(sampling_config.spk_cond_path),
720
+ smodel,
721
+ llm_first_stage,
722
+ llm_second_stage,
723
+ sampling_config.enhancer,
724
+ first_stage_ckpt_path,
725
+ second_stage_ckpt_path,
726
+ sampling_config.guidance_scale,
727
+ max_new_tokens=sampling_config.max_new_tokens,
728
+ top_k=sampling_config.top_k,
729
+ top_p=sampling_config.top_p,
730
+ temperature=sampling_config.temperature,
731
+ )
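
For reference, a small sketch (illustrative token ids, not repository code) of the input layout that `causal_sample` builds before calling `generate`: the first hierarchy carries the BPE text tokens, and every remaining hierarchy is filled with the padding / end-of-audio id 1024.

import torch

encoded_text = [5, 17, 301, 9]                         # pretend BPE ids for one utterance
row = torch.tensor([encoded_text], dtype=torch.long)   # (1, t)
num_hierarchies = 2                                    # len(self.vocab_sizes) in the real model
x = torch.cat([row] + [torch.ones_like(row) * 1024] * (num_hierarchies - 1), dim=0)  # (c, t)
x = x.unsqueeze(0)                                     # (b=1, c, t), as expected by GPT.generate
print(x)
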
fam/llm/serving.py ADDED
@@ -0,0 +1,197 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import shlex
5
+ import subprocess
6
+ import tempfile
7
+ from pathlib import Path
8
+ from typing import Literal, Optional
9
+
10
+ import fastapi
11
+ import fastapi.middleware.cors
12
+ import torch
13
+ import tyro
14
+ import uvicorn
15
+ from attr import dataclass
16
+ from fastapi import Request
17
+ from fastapi.responses import Response
18
+ from huggingface_hub import snapshot_download
19
+
20
+ from fam.llm.sample import (
21
+ InferenceConfig,
22
+ Model,
23
+ build_models,
24
+ get_first_stage_path,
25
+ get_second_stage_path,
26
+ # sample_utterance,
27
+ )
28
+ from fam.llm.fast_inference import TTS
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ ## Setup FastAPI server.
34
+ app = fastapi.FastAPI()
35
+
36
+
37
+ @dataclass
38
+ class ServingConfig:
39
+ huggingface_repo_id: str
40
+ """Absolute path to the model directory."""
41
+
42
+ max_new_tokens: int = 864 * 2
43
+ """Maximum number of new tokens to generate from the first stage model."""
44
+
45
+ temperature: float = 1.0
46
+ """Temperature for sampling applied to both models."""
47
+
48
+ top_k: int = 200
49
+ """Top k for sampling applied to both models."""
50
+
51
+ seed: int = 1337
52
+ """Random seed for sampling."""
53
+
54
+ dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = "bfloat16"
55
+ """Data type to use for sampling."""
56
+
57
+ enhancer: Optional[Literal["df"]] = "df"
58
+ """Enhancer to use for post-processing."""
59
+
60
+ port: int = 58003
61
+
62
+
63
+ # Singleton
64
+ class _GlobalState:
65
+ config: ServingConfig
66
+ tts: TTS
67
+
68
+
69
+ GlobalState = _GlobalState()
70
+
71
+ @dataclass(frozen=True)
72
+ class TTSRequest:
73
+ text: str
74
+ guidance: Optional[float] = 3.0
75
+ top_p: Optional[float] = 0.95
76
+ speaker_ref_path: Optional[str] = None
77
+ top_k: Optional[int] = None
78
+
79
+
80
+ def sample_utterance(
81
+ text: str,
82
+ spk_cond_path: str | None,
83
+ guidance_scale,
84
+ max_new_tokens,
85
+ top_k,
86
+ top_p,
87
+ temperature,
88
+ ) -> str:
89
+ return GlobalState.tts.synthesise(
90
+ text,
91
+ spk_cond_path,
92
+ top_p=top_p,
93
+ guidance_scale=guidance_scale,
94
+ temperature=temperature,
95
+ )
96
+
97
+
98
+ @app.post("/tts", response_class=Response)
99
+ async def text_to_speech(req: Request):
100
+ audiodata = await req.body()
101
+ payload = None
102
+ wav_out_path = None
103
+
104
+ try:
105
+ headers = req.headers
106
+ payload = headers["X-Payload"]
107
+ payload = json.loads(payload)
108
+ tts_req = TTSRequest(**payload)
109
+ with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
110
+ if tts_req.speaker_ref_path is None:
111
+ wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
112
+ else:
113
+ wav_path = tts_req.speaker_ref_path
114
+ wav_out_path = sample_utterance(
115
+ tts_req.text,
116
+ wav_path,
117
+ guidance_scale=tts_req.guidance,
118
+ max_new_tokens=GlobalState.config.max_new_tokens,
119
+ temperature=GlobalState.config.temperature,
120
+ top_k=tts_req.top_k,
121
+ top_p=tts_req.top_p,
122
+ )
123
+ with open(wav_out_path, "rb") as f:
124
+ return Response(content=f.read(), media_type="audio/wav")
125
+ except Exception as e:
126
+ # traceback_str = "".join(traceback.format_tb(e.__traceback__))
127
+ logger.exception(f"Error processing request {payload}")
128
+ return Response(
129
+ content="Something went wrong. Please try again in a few mins or contact us on Discord",
130
+ status_code=500,
131
+ )
132
+ finally:
133
+ if wav_out_path is not None:
134
+ Path(wav_out_path).unlink(missing_ok=True)
135
+
136
+
137
+ def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
138
+ with tempfile.NamedTemporaryFile() as unknown_format_tmp:
139
+ assert unknown_format_tmp.write(audiodata) > 0
140
+ unknown_format_tmp.flush()
141
+
142
+ subprocess.check_output(
143
+ # arbitrary 2 minute cutoff
144
+ shlex.split(f"ffmpeg -t 120 -y -i {unknown_format_tmp.name} -f wav {wav_tmp.name}")
145
+ )
146
+
147
+ return wav_tmp.name
148
+
149
+
150
+ if __name__ == "__main__":
151
+ # This has to be here to avoid some weird audiocraft shenanigans messing up matplotlib
152
+ from fam.llm.enhancers import get_enhancer
153
+
154
+ for name in logging.root.manager.loggerDict:
155
+ logger = logging.getLogger(name)
156
+ logger.setLevel(logging.INFO)
157
+ logging.root.setLevel(logging.INFO)
158
+
159
+ GlobalState.config = tyro.cli(ServingConfig)
160
+ app.add_middleware(
161
+ fastapi.middleware.cors.CORSMiddleware,
162
+ allow_origins=["*", f"http://localhost:{GlobalState.config.port}", "http://localhost:3000"],
163
+ allow_credentials=True,
164
+ allow_methods=["*"],
165
+ allow_headers=["*"],
166
+ )
167
+
168
+ device = "cuda" if torch.cuda.is_available() else "cpu"
169
+ common_config = dict(
170
+ num_samples=1,
171
+ seed=1337,
172
+ device=device,
173
+ dtype=GlobalState.config.dtype,
174
+ compile=False,
175
+ init_from="resume",
176
+ output_dir=tempfile.mkdtemp(),
177
+ )
178
+ model_dir = snapshot_download(repo_id=GlobalState.config.huggingface_repo_id)
179
+ config1 = InferenceConfig(
180
+ ckpt_path=get_first_stage_path(model_dir),
181
+ **common_config,
182
+ )
183
+
184
+ config2 = InferenceConfig(
185
+ ckpt_path=get_second_stage_path(model_dir),
186
+ **common_config,
187
+ )
188
+
189
+ GlobalState.tts = TTS()
190
+
191
+ # start server
192
+ uvicorn.run(
193
+ app,
194
+ host="127.0.0.1",
195
+ port=GlobalState.config.port,
196
+ log_level="info",
197
+ )
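
A hypothetical client for the /tts endpoint above (server address, file names and text are placeholders): the JSON request is sent in the X-Payload header, and the raw speaker-reference audio, if any, goes in the request body.

import json
import requests

payload = {"text": "コトバテクノロジーズのミッションは、音声基盤モデルを作ることです。", "guidance": 3.0, "top_p": 0.95}
with open("reference.wav", "rb") as ref:
    resp = requests.post(
        "http://127.0.0.1:58003/tts",                # default port from ServingConfig
        headers={"X-Payload": json.dumps(payload)},  # json.dumps escapes non-ASCII, so the header stays ASCII
        data=ref.read(),                             # used as the speaker reference when speaker_ref_path is not set
    )
resp.raise_for_status()
with open("out.wav", "wb") as out:
    out.write(resp.content)
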
fam/llm/utils.py ADDED
@@ -0,0 +1,91 @@
1
+ import os
2
+ import re
3
+ import subprocess
4
+ import tempfile
5
+
6
+ import librosa
7
+ import torch
8
+
9
+
10
+ def normalize_text(text: str) -> str:
11
+ unicode_conversion = {
12
+ 8175: "'",
13
+ 8189: "'",
14
+ 8190: "'",
15
+ 8208: "-",
16
+ 8209: "-",
17
+ 8210: "-",
18
+ 8211: "-",
19
+ 8212: "-",
20
+ 8213: "-",
21
+ 8214: "||",
22
+ 8216: "'",
23
+ 8217: "'",
24
+ 8218: ",",
25
+ 8219: "`",
26
+ 8220: '"',
27
+ 8221: '"',
28
+ 8222: ",,",
29
+ 8223: '"',
30
+ 8228: ".",
31
+ 8229: "..",
32
+ 8230: "...",
33
+ 8242: "'",
34
+ 8243: '"',
35
+ 8245: "'",
36
+ 8246: '"',
37
+ 180: "'",
38
+ 2122: "TM", # Trademark
39
+ }
40
+
41
+ text = text.translate(unicode_conversion)
42
+
43
+ non_bpe_chars = set([c for c in list(text) if ord(c) >= 256])
44
+ #if len(non_bpe_chars) > 0:
45
+ # non_bpe_points = [(c, ord(c)) for c in non_bpe_chars]
46
+ # raise ValueError(f"Non-BPE single token characters found: {non_bpe_points}")
47
+
48
+ text = text.replace("\t", " ")
49
+ text = text.replace("\n", " ")
50
+ text = text.replace("*", " ")
51
+ text = text.strip()
52
+ text = re.sub("\s\s+", " ", text) # remove multiple spaces
53
+ return text
54
+
55
+ def check_audio_file(path_or_uri, threshold_s=10): # default 30
56
+ if "http" in path_or_uri:
57
+ temp_fd, filepath = tempfile.mkstemp()
58
+ os.close(temp_fd) # Close the file descriptor, curl will create a new connection
59
+ curl_command = ["curl", "-L", path_or_uri, "-o", filepath]
60
+ subprocess.run(curl_command, check=True)
61
+
62
+ else:
63
+ filepath = path_or_uri
64
+
65
+ audio, sr = librosa.load(filepath)
66
+ duration_s = librosa.get_duration(y=audio, sr=sr)
67
+ if duration_s < threshold_s:
68
+ raise Exception(
69
+ f"The audio file is too short. Please provide an audio file that is at least {threshold_s} seconds long to proceed."
70
+ )
71
+
72
+ # Clean up the temporary file if it was created
73
+ if "http" in path_or_uri:
74
+ os.remove(filepath)
75
+
76
+
77
+ def get_default_dtype() -> str:
78
+ """Compute default 'dtype' based on GPU architecture"""
79
+ if torch.cuda.is_available():
80
+ for i in range(torch.cuda.device_count()):
81
+ device_properties = torch.cuda.get_device_properties(i)
82
+ dtype = "float16" if device_properties.major <= 7 else "bfloat16" # tesla and turing architectures
83
+ else:
84
+ dtype = "float16"
85
+
86
+ print(f"using dtype={dtype}")
87
+ return dtype
88
+
89
+
90
+ def get_device() -> str:
91
+ return "cuda" if torch.cuda.is_available() else "cpu"
fam/py.typed ADDED
File without changes
fam/quantiser/__init__.py ADDED
File without changes
fam/quantiser/audio/__init__.py ADDED
File without changes
fam/quantiser/audio/speaker_encoder/__init__.py ADDED
File without changes
fam/quantiser/audio/speaker_encoder/audio.py ADDED
@@ -0,0 +1,22 @@
+ import librosa
+ import numpy as np
+
+ mel_window_length = 25  # ms
+ mel_window_step = 10  # ms
+ mel_n_channels = 40
+ sampling_rate = 16000
+
+
+ def wav_to_mel_spectrogram(wav):
+     """
+     Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
+     Note: this is not a log-mel spectrogram.
+     """
+     frames = librosa.feature.melspectrogram(
+         y=wav,
+         sr=sampling_rate,
+         n_fft=int(sampling_rate * mel_window_length / 1000),
+         hop_length=int(sampling_rate * mel_window_step / 1000),
+         n_mels=mel_n_channels,
+     )
+     return frames.astype(np.float32).T
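As a quick illustration of the expected input and output (the file name is illustrative; the 16 kHz rate matches the module-level sampling_rate):

    import librosa

    from fam.quantiser.audio.speaker_encoder.audio import wav_to_mel_spectrogram

    wav, _ = librosa.load("speaker.wav", sr=16000)  # mono float waveform at 16 kHz
    mel = wav_to_mel_spectrogram(wav)
    # One 40-channel mel frame every 10 ms: shape (n_frames, 40),
    # e.g. roughly (1001, 40) for a 10-second clip.
    print(mel.shape, mel.dtype)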
fam/quantiser/audio/speaker_encoder/ckpt/.gitattributes ADDED
@@ -0,0 +1 @@
+ *.pt filter=lfs diff=lfs merge=lfs -text
fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca
+ size 17090379
fam/quantiser/audio/speaker_encoder/model.py ADDED
@@ -0,0 +1,123 @@
+ import os
+ from time import perf_counter as timer
+ from typing import List, Optional, Union
+
+ import librosa
+ import numpy as np
+ import torch
+ from torch import nn
+
+ from fam.quantiser.audio.speaker_encoder import audio
+
+ DEFAULT_SPKENC_CKPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ckpt/ckpt.pt")
+
+ mel_window_step = 10
+ mel_n_channels = 40
+ sampling_rate = 16000
+ partials_n_frames = 160
+ model_hidden_size = 256
+ model_embedding_size = 256
+ model_num_layers = 3
+
+
+ class SpeakerEncoder(nn.Module):
+     def __init__(
+         self,
+         weights_fpath: Optional[str] = None,
+         device: Optional[Union[str, torch.device]] = None,
+         verbose: bool = True,
+         eval: bool = False,
+     ):
+         super().__init__()
+
+         # Define the network
+         self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
+         self.linear = nn.Linear(model_hidden_size, model_embedding_size)
+         self.relu = nn.ReLU()
+
+         # Get the target device
+         if device is None:
+             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         elif isinstance(device, str):
+             device = torch.device(device)
+         self.device = device
+
+         start = timer()
+         if eval and weights_fpath is None:
+             weights_fpath = DEFAULT_SPKENC_CKPT_PATH
+
+         if weights_fpath is not None:
+             checkpoint = torch.load(weights_fpath, map_location="cpu")
+
+             self.load_state_dict(checkpoint["model_state"], strict=False)
+             self.to(device)
+
+             if eval:
+                 self.eval()
+
+         if verbose:
+             print("Loaded the speaker embedding model on %s in %.2f seconds." % (device.type, timer() - start))
+
+     def forward(self, mels: torch.FloatTensor):
+         _, (hidden, _) = self.lstm(mels)
+         embeds_raw = self.relu(self.linear(hidden[-1]))
+         return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
+
+     @staticmethod
+     def compute_partial_slices(n_samples: int, rate, min_coverage):
+         # Compute how many frames separate two partial utterances
+         samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+         n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+         frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
+
+         # Compute the slices
+         wav_slices, mel_slices = [], []
+         steps = max(1, n_frames - partials_n_frames + frame_step + 1)
+         for i in range(0, steps, frame_step):
+             mel_range = np.array([i, i + partials_n_frames])
+             wav_range = mel_range * samples_per_frame
+             mel_slices.append(slice(*mel_range))
+             wav_slices.append(slice(*wav_range))
+
+         # Evaluate whether extra padding is warranted or not
+         last_wav_range = wav_slices[-1]
+         coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
+         if coverage < min_coverage and len(mel_slices) > 1:
+             mel_slices = mel_slices[:-1]
+             wav_slices = wav_slices[:-1]
+
+         return wav_slices, mel_slices
+
+     def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75, numpy: bool = True):
+         wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
+         max_wave_length = wav_slices[-1].stop
+         if max_wave_length >= len(wav):
+             wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
+
+         mel = audio.wav_to_mel_spectrogram(wav)
+         mels = np.array([mel[s] for s in mel_slices])
+         with torch.no_grad():
+             mels = torch.from_numpy(mels).to(self.device)  # type: ignore
+             partial_embeds = self(mels)
+
+         if numpy:
+             partial_embeds = partial_embeds.cpu().numpy()
+             raw_embed = np.mean(partial_embeds, axis=0)
+             embed = raw_embed / np.linalg.norm(raw_embed, 2)
+         else:
+             raw_embed = partial_embeds.mean(dim=0)
+             embed = raw_embed / torch.linalg.norm(raw_embed, 2)
+
+         if return_partials:
+             return embed, partial_embeds, wav_slices
+         return embed
+
+     def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
+         raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0)
+         return raw_embed / np.linalg.norm(raw_embed, 2)
+
+     def embed_utterance_from_file(self, fpath: str, numpy: bool) -> torch.Tensor:
+         wav_tgt, _ = librosa.load(fpath, sr=16000)
+         wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
+         embedding = self.embed_utterance(wav_tgt, numpy=numpy)
+         return embedding
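A minimal sketch of computing a speaker embedding with this module (the audio path is illustrative; eval=True loads the bundled checkpoint at DEFAULT_SPKENC_CKPT_PATH):

    import numpy as np

    from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder

    encoder = SpeakerEncoder(eval=True)  # loads ckpt/ckpt.pt and moves the model to GPU if available
    embed = encoder.embed_utterance_from_file("speaker.wav", numpy=True)  # a few seconds of speech
    print(embed.shape)             # (256,)
    print(np.linalg.norm(embed))   # ~1.0, embeddings are L2-normalised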
fam/quantiser/text/tokenise.py ADDED
@@ -0,0 +1,32 @@
+ import tiktoken
+
+
+ class TrainedBPETokeniser:
+     def __init__(self, name, pat_str, mergeable_ranks, special_tokens, offset=None) -> None:
+         self.tokenizer = tiktoken.Encoding(
+             name=name,
+             pat_str=pat_str,
+             mergeable_ranks=mergeable_ranks,
+             special_tokens=special_tokens,
+         )
+         self.offset = offset
+
+     def encode(self, text: str) -> list[int]:
+         # note: we add an end-of-text token!
+         tokens = self.tokenizer.encode(text) + [self.tokenizer.eot_token]
+         if self.offset is not None:
+             tokens = [x + self.offset for x in tokens]
+
+         return tokens
+
+     def decode(self, tokens: list[int]):
+         if self.offset is not None:
+             tokens = [x - self.offset for x in tokens]
+         return self.tokenizer.decode(tokens)
+
+     @property
+     def eot_token(self):
+         if self.offset is not None:
+             return self.tokenizer.eot_token + self.offset
+         else:
+             return self.tokenizer.eot_token
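A round-trip sketch of the tokeniser. In the repo the constructor arguments come from the trained checkpoint; here an existing tiktoken encoding is reused purely for illustration (note that _pat_str, _mergeable_ranks and _special_tokens are private tiktoken attributes, used only to make the demo self-contained).

    import tiktoken

    from fam.quantiser.text.tokenise import TrainedBPETokeniser

    base = tiktoken.get_encoding("gpt2")  # stand-in for the checkpoint's BPE data
    tok = TrainedBPETokeniser(
        name="demo",
        pat_str=base._pat_str,
        mergeable_ranks=base._mergeable_ranks,
        special_tokens=base._special_tokens,
        offset=1024,  # every id is shifted by this constant on encode and shifted back on decode
    )

    ids = tok.encode("hello world")                       # BPE ids + end-of-text token, all offset by 1024
    assert ids[-1] == tok.eot_token
    assert tok.decode(ids) == "hello world<|endoftext|>"  # decode keeps the appended end-of-text marker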
fam/ui/app.py ADDED
@@ -0,0 +1,201 @@
+ import io
+ import json
+ import os
+
+ import gradio as gr
+ import requests
+ import soundfile as sf
+
+ API_SERVER_URL = "http://127.0.0.1:58003/tts"
+ RADIO_CHOICES = ["Preset voices", "Upload target voice", "Record your voice"]
+ MAX_CHARS = 220
+ PRESET_VOICES = {
+     # female
+     "Ava": "https://cdn.themetavoice.xyz/speakers/ava.flac",
+     "Bria": "https://cdn.themetavoice.xyz/speakers/bria.mp3",
+     # male
+     "Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3",
+     "Jacob": "https://cdn.themetavoice.xyz/speakers/jacob.wav",
+ }
+
+
+ def denormalise_top_p(top_p):
+     # returns top_p in the range [0.9, 1.0]
+     return round(0.9 + top_p / 100, 2)
+
+
+ def denormalise_guidance(guidance):
+     # returns guidance in the range [1.0, 3.0]
+     return 1 + ((guidance - 1) * (3 - 1)) / (5 - 1)
+
+
+ def _handle_edge_cases(to_say, upload_target):
+     if not to_say:
+         raise gr.Error("Please provide text to synthesise")
+
+     def _check_file_size(path):
+         if not path:
+             return
+         filesize = os.path.getsize(path)
+         filesize_mb = filesize / 1024 / 1024
+         if filesize_mb >= 50:
+             raise gr.Error(
+                 f"Please upload a sample less than 50 MB for voice cloning. Provided: {round(filesize_mb)} MB"
+             )
+
+     _check_file_size(upload_target)
+
+
+ def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target, record_target):
+     d_top_p = denormalise_top_p(top_p)
+     d_guidance = denormalise_guidance(guidance)
+
+     _handle_edge_cases(to_say, upload_target)
+
+     to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS]
+
+     custom_target_path = None
+     if toggle == RADIO_CHOICES[1]:
+         custom_target_path = upload_target
+     elif toggle == RADIO_CHOICES[2]:
+         custom_target_path = record_target
+
+     config = {
+         "text": to_say,
+         "guidance": d_guidance,
+         "top_p": d_top_p,
+         "speaker_ref_path": PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else None,
+     }
+     headers = {"Content-Type": "audio/wav", "X-Payload": json.dumps(config)}
+     if not custom_target_path:
+         response = requests.post(API_SERVER_URL, headers=headers, data=None)
+     else:
+         with open(custom_target_path, "rb") as f:
+             data = f.read()
+             response = requests.post(API_SERVER_URL, headers=headers, data=data)
+
+     wav, sr = None, None
+     if response.status_code == 200:
+         audio_buffer = io.BytesIO(response.content)
+         audio_buffer.seek(0)
+         wav, sr = sf.read(audio_buffer, dtype="float32")
+     else:
+         print(f"Something went wrong. response status code: {response.status_code}")
+
+     return sr, wav
+
+
+ def change_voice_selection_layout(choice):
+     index = RADIO_CHOICES.index(choice)
+     return [
+         gr.update(visible=True)
+         if i == index else gr.update(visible=False)
+         for i in range(len(RADIO_CHOICES))
+     ]
+
+
+ title = "# TTS by Kotoba-Speech"
+
+ description = """
+ <strong>Kotoba-Speech v0.1</strong>は、1.2Bのトランスフォーマーに基づく音声生成モデルです。
+ 以下の機能をサポートしています:
+ \n
+ * 日本語における滑らかなテキスト読み上げ生成
+ * スピーチプロンプトを通じたOne-shot音声クローニング
+
+ Kotoba Technologiesは、公開されたモデルを商用可能なApache 2.0ライセンスで公開します。
+ 推論およびモデルコードは、Meta-Voiceをベースに作られており、学習コードは弊社のGitHubで近日中に公開する予定です。
+ Kotoba Technologiesは、音声基盤モデルの開発に取り組んでおり、今後もモデルの公開を行なっていきます。是非、[Discord Community](https://discord.gg/qPVFqhGN7Z)に参加してご意見ください!
+
+ <strong>Kotoba-Speech v0.1</strong> is a 1.2B Transformer-based speech generative model. It supports the following features:
+ \n
+ * Fluent text-to-speech generation in Japanese
+ * One-shot voice cloning through a speech prompt
+
+ We are releasing our model under the Apache 2.0 license. Our inference and model code is adapted from Meta-Voice, and we will release our training code on our GitHub repository shortly.
+ Kotoba Technologies is committed to developing speech foundation models, and we'll continue releasing our models. Please join [our Discord](https://discord.gg/qPVFqhGN7Z) to contribute to our community.
+ """
+
+ with gr.Blocks(title="TTS by Kotoba-Speech") as demo:
+     gr.Markdown(title)
+
+     with gr.Row():
+         gr.Markdown(description)
+
+     with gr.Row():
+         with gr.Column():
+             to_say = gr.TextArea(
+                 label="What should I say!?",
+                 lines=4,
+                 value="コトバテクノロジーズのミッションは、音声基盤モデルを作ることです。",
+             )
+
+             with gr.Row(), gr.Column():
+                 # voice settings
+                 top_p = gr.Slider(
+                     value=5.0,
+                     minimum=0.0,
+                     maximum=10.0,
+                     step=1.0,
+                     label="Speech Stability - improves text following for a challenging speaker",
+                 )
+                 guidance = gr.Slider(
+                     value=5.0,
+                     minimum=1.0,
+                     maximum=5.0,
+                     step=1.0,
+                     label="Speaker similarity - How closely to match speaker identity and speech style.",
+                 )
+
+                 # voice select
+                 toggle = gr.Radio(choices=RADIO_CHOICES, label="Choose voice", value=RADIO_CHOICES[0])
+
+             with gr.Row(visible=True) as row_1:
+                 preset_dropdown = gr.Dropdown(
+                     PRESET_VOICES.keys(), label="Preset voices", value=list(PRESET_VOICES.keys())[0]
+                 )
+                 with gr.Accordion("Preview: Preset voices", open=False):
+                     for label, path in PRESET_VOICES.items():
+                         gr.Audio(value=path, label=label)
+
+             with gr.Row(visible=False) as row_2:
+                 upload_target = gr.Audio(
+                     sources=["upload"],
+                     type="filepath",
+                     label="Upload a clean sample to clone. Sample should contain 1 speaker, be between 10-90 seconds and not contain background noise.",
+                     min_length=10,
+                     max_length=90,
+                 )
+
+             with gr.Row(visible=False) as row_3:
+                 record_target = gr.Audio(
+                     sources=["microphone"],
+                     type="filepath",
+                     label="Record your voice with a microphone to clone. Sample should contain 1 speaker, be between 10-90 seconds and not contain background noise.",
+                     min_length=10,
+                     max_length=90,
+                 )
+
+             toggle.change(
+                 change_voice_selection_layout,
+                 inputs=toggle,
+                 outputs=[row_1, row_2, row_3],
+             )
+
+         with gr.Column():
+             speech = gr.Audio(
+                 type="numpy",
+                 label="Kotoba-Speech says...",
+             )
+
+     submit = gr.Button("Generate Speech")
+     submit.click(
+         fn=tts,
+         inputs=[to_say, top_p, guidance, toggle, preset_dropdown, upload_target, record_target],
+         outputs=speech,
+     )
+
+
+ demo.queue(default_concurrency_limit=2)
+ # demo.launch()
+ demo.launch(server_name="0.0.0.0", server_port=3000, share=True)
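Note how the sliders above map onto the API parameters: the 0-10 "Speech Stability" slider is compressed into top_p in [0.9, 1.0], and the 1-5 "Speaker similarity" slider is stretched into guidance in [1.0, 3.0]. A worked check of the UI defaults, mirroring denormalise_top_p and denormalise_guidance above (pure arithmetic):

    # Default slider values in the UI above are 5.0 for both controls.
    top_p_slider, guidance_slider = 5.0, 5.0
    top_p = round(0.9 + top_p_slider / 100, 2)                  # 0.95
    guidance = 1 + ((guidance_slider - 1) * (3 - 1)) / (5 - 1)  # 3.0
    print(top_p, guidance)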
requirements.txt ADDED
@@ -0,0 +1,138 @@
+ anyio==4.0.0
+ argon2-cffi==23.1.0
+ argon2-cffi-bindings==21.2.0
+ arrow==1.3.0
+ asttokens==2.4.1
+ async-lru==2.0.4
+ attrs==23.1.0
+ Babel==2.13.1
+ beautifulsoup4==4.12.2
+ bleach==6.1.0
+ blinker==1.4
+ certifi==2022.12.7
+ cffi==1.16.0
+ charset-normalizer==2.1.1
+ comm==0.2.0
+ cryptography==3.4.8
+ dbus-python==1.2.18
+ debugpy==1.8.0
+ decorator==5.1.1
+ deepfilternet
+ defusedxml==0.7.1
+ distro==1.7.0
+ entrypoints==0.4
+ exceptiongroup==1.1.3
+ executing==2.0.1
+ fastjsonschema==2.18.1
+ filelock==3.9.0
+ fqdn==1.5.1
+ fsspec
+ gradio==4.21.0
+ httplib2==0.20.2
+ idna==3.4
+ importlib-metadata==4.6.4
+ ipykernel==6.26.0
+ ipython==8.17.2
+ ipython-genutils==0.2.0
+ ipywidgets==8.1.1
+ isoduration==20.11.0
+ jedi==0.19.1
+ jeepney==0.7.1
+ Jinja2==3.1.2
+ json5==0.9.14
+ jsonpointer==2.4
+ jsonschema==4.19.2
+ jsonschema-specifications==2023.7.1
+ jupyter-archive==3.4.0
+ jupyter-contrib-core==0.4.2
+ jupyter-contrib-nbextensions==0.7.0
+ jupyter-events==0.9.0
+ jupyter-highlight-selected-word==0.2.0
+ jupyter-lsp==2.2.0
+ jupyter-nbextensions-configurator==0.6.3
+ jupyter_client==7.4.9
+ jupyter_core==5.5.0
+ jupyter_server==2.10.0
+ jupyter_server_terminals==0.4.4
+ jupyterlab==4.0.8
+ jupyterlab-pygments==0.2.2
+ jupyterlab-widgets==3.0.9
+ jupyterlab_server==2.25.0
+ keyring==23.5.0
+ launchpadlib==1.10.16
+ lazr.restfulclient==0.14.4
+ lazr.uri==1.0.6
+ librosa
+ lxml==4.9.3
+ Mako==1.1.3
+ Markdown==3.3.6
+ MarkupSafe==2.1.2
+ matplotlib-inline==0.1.6
+ mistune==3.0.2
+ more-itertools==8.10.0
+ mpmath==1.3.0
+ nbclassic==1.0.0
+ nbclient==0.9.0
+ nbconvert==7.11.0
+ nbformat==5.9.2
+ nest-asyncio==1.5.8
+ networkx==3.0
+ notebook==6.5.5
+ notebook_shim==0.2.3
+ numpy==1.24.1
+ oauthlib==3.2.0
+ overrides==7.4.0
+ packaging==23.2
+ pandocfilters==1.5.0
+ parso==0.8.3
+ pexpect==4.8.0
+ Pillow==9.3.0
+ platformdirs==3.11.0
+ prometheus-client==0.18.0
+ prompt-toolkit==3.0.39
+ psutil==5.9.6
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ pycparser==2.21
+ Pygments==2.16.1
+ PyGObject==3.42.1
+ PyJWT==2.3.0
+ pyparsing==2.4.7
+ # python-apt==2.4.0
+ python-dateutil==2.8.2
+ python-json-logger==2.0.7
+ PyYAML==6.0.1
+ pyzmq==24.0.1
+ referencing==0.30.2
+ requests==2.31.0
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rpds-py==0.12.0
+ SecretStorage==3.3.1
+ Send2Trash==1.8.2
+ six==1.16.0
+ sniffio==1.3.0
+ soundfile==0.12.1
+ soupsieve==2.5
+ stack-data==0.6.3
+ sympy==1.12
+ terminado==0.17.1
+ tinycss2==1.2.1
+ tomli==2.0.1
+ torch
+ torchaudio
+ torchvision
+ tornado==6.3.3
+ traitlets==5.13.0
+ triton==2.1.0
+ types-python-dateutil==2.8.19.14
+ typing_extensions
+ uri-template==1.3.0
+ urllib3==1.26.13
+ wadllib==1.3.6
+ wcwidth==0.2.9
+ webcolors==1.13
+ webencodings==0.5.1
+ websocket-client==1.6.4
+ widgetsnbextension==4.0.9
+ zipp==1.0.0
setup.py ADDED
@@ -0,0 +1,6 @@
+ from setuptools import find_packages, setup  # type: ignore
+
+ setup(
+     name="fam",
+     packages=find_packages(".", exclude=["tests"]),
+ )