diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..a1c6d672057887110965281855b391f309911df1
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,33 @@
+FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS base
+
+# Install system dependencies in a single RUN command to reduce layers.
+# apt-get update, upgrade, install, and cleanup all happen in one layer to keep the image small.
+RUN apt-get update && \
+ apt-get upgrade -y && \
+ apt-get install -y python3.10 python3-pip git wget curl build-essential && \
+ apt-get autoremove -y && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+# Install a static ffmpeg build and verify its md5 checksum
+RUN wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz &&\
+ wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz.md5 &&\
+ md5sum -c ffmpeg-git-amd64-static.tar.xz.md5 &&\
+ tar xvf ffmpeg-git-amd64-static.tar.xz &&\
+ mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/ &&\
+ rm -rf ffmpeg-git-*
+
+WORKDIR /app
+
+COPY requirements.txt requirements.txt
+
+RUN pip install --no-cache-dir packaging wheel torch
+RUN pip install --no-cache-dir audiocraft  # HACK: audiocraft fails to install from within requirements.txt, so install it separately first
+RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install --no-cache-dir --upgrade torch torchaudio
+
+COPY . .
+
+RUN pip install --no-cache-dir -e .
+
+ENTRYPOINT ["python3.10", "serving.py"]
\ No newline at end of file
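As a sanity check after building this image, one can run a short Python snippet inside the container to confirm that the static ffmpeg build and the CUDA runtime are usable. This is an illustrative sketch, not part of the repository; it relies only on the standard library and torch, both of which the image installs.

    # Illustrative container smoke test (not part of the repo): verify the static ffmpeg
    # binaries are on PATH and that PyTorch can see the GPU.
    import shutil
    import subprocess

    import torch

    assert shutil.which("ffmpeg") and shutil.which("ffprobe"), "static ffmpeg build not found"
    subprocess.run(["ffmpeg", "-version"], check=True)
    print("CUDA available:", torch.cuda.is_available())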
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/all_emo_dirs.pkl b/all_emo_dirs.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..6964ac97dc1b88f65d1100e4c3713180a6efde37
--- /dev/null
+++ b/all_emo_dirs.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:beadd1f3c7eada0fa99dbdecc5c370036c1c044955a02f019f879bdc6f5fefcb
+size 20343
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..e612fb0abab323e69cea4497277e3f5603be2559
--- /dev/null
+++ b/app.py
@@ -0,0 +1,364 @@
+import gradio as gr
+import os
+
+
+is_prod = True
+if os.environ.get('PROD_MODE') == 'local':
+ is_prod = False
+
+import pickle
+
+if not is_prod:
+    # Point the Hugging Face caches and ffmpeg at local paths when running outside prod.
+ os.environ['HF_HOME'] = '/proj/afosr/metavoice/cache'
+ os.environ['TRANSFORMERS_CACHE'] = '/proj/afosr/metavoice/cache'
+ os.environ['HF_DATASETS_CACHE'] = '/proj/afosr/metavoice/cache'
+ os.environ['HF_METRICS_CACHE'] = '/proj/afosr/metavoice/cache'
+ os.environ['HF_MODULES_CACHE'] = '/proj/afosr/metavoice/cache'
+ ffmpeg_path = '/home/hc3295/ffmpeg_build/bin'
+ os.environ['PATH'] += os.pathsep + ffmpeg_path
+
+
+import shutil
+import tempfile
+import time
+from pathlib import Path
+
+import librosa
+import torch
+from huggingface_hub import snapshot_download
+
+from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
+from fam.llm.decoders import EncodecDecoder
+from fam.llm.fast_inference_utils import build_model, main
+from fam.llm.inference import (
+ EncodecDecoder,
+ InferenceConfig,
+ Model,
+ TiltedEncodec,
+ TrainedBPETokeniser,
+ get_cached_embedding,
+ get_cached_file,
+ get_enhancer,
+)
+from fam.llm.utils import (
+ check_audio_file,
+ get_default_dtype,
+ get_device,
+ normalize_text,
+)
+
+debug = False
+if not debug:
+ model_name = "metavoiceio/metavoice-1B-v0.1"
+ seed = 1337
+ output_dir = "outputs"
+ _dtype = get_default_dtype()
+ _device = 'cuda:0'
+ _model_dir = snapshot_download(repo_id=model_name)
+ first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
+    # Generated audio is written under output_dir.
+ os.makedirs(output_dir, exist_ok=True)
+
+ second_stage_ckpt_path = f"{_model_dir}/second_stage.pt"
+ config_second_stage = InferenceConfig(
+ ckpt_path=second_stage_ckpt_path,
+ num_samples=1,
+ seed=seed,
+ device=_device,
+ dtype=_dtype,
+ compile=False,
+ init_from="resume",
+ output_dir=output_dir,
+ )
+ data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
+ llm_second_stage = Model(
+ config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
+ )
+ enhancer = get_enhancer("df")
+
+ precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[_dtype]
+ model, tokenizer, smodel, model_size = build_model(
+ precision=precision,
+ checkpoint_path=Path(f"{_model_dir}/first_stage.pt"),
+ spk_emb_ckpt_path=Path(f"{_model_dir}/speaker_encoder.pt"),
+ device=_device,
+ compile=True,
+ compile_prefill=True,
+ )
+
+
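+# generate_sample runs the full EmoKnob pipeline: it derives an emotion direction from an
+# emotional/neutral embedding pair (or looks a precomputed one up in ALL_EMO_DIRS), shifts the
+# source speaker embedding along that direction, and synthesises speech with the two-stage
+# MetaVoice stack loaded above (first-stage LLM, second-stage LLM + multi-band diffusion,
+# followed by the enhancer).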
+def generate_sample(text, emo_dir=None, source_path=None, emo_path=None, neutral_path=None, strength=0.1, top_p=0.95, guidance_scale=3.0, preset_dropdown=None, toggle=None):
+
+ print('text', text)
+ print('emo_dir', emo_dir)
+ print('source_path', source_path)
+ print('emo_path', emo_path)
+ print('neutral_path', neutral_path)
+ print('strength', strength)
+ print('top_p', top_p)
+ print('guidance_scale', guidance_scale)
+
+ if toggle == RADIO_CHOICES[0]:
+ source_path = PRESET_VOICES[preset_dropdown]
+ source_path = get_cached_file(source_path)
+ check_audio_file(source_path)
+ source_emb = get_cached_embedding(source_path, smodel).to(device=_device, dtype=precision)
+
+ if emo_dir == EMO_NAMES[0]:
+ emo_path = get_cached_file(emo_path)
+ check_audio_file(emo_path)
+ emo_emb = get_cached_embedding(emo_path, smodel).to(device=_device, dtype=precision)
+
+ neutral_path = get_cached_file(neutral_path)
+ check_audio_file(neutral_path)
+ neutral_emb = get_cached_embedding(neutral_path, smodel).to(device=_device, dtype=precision)
+
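+        # The emotion direction is the L2-normalised vector pointing from the neutral
+        # embedding towards the emotional embedding of the same speaker.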
+ emo_dir = emo_emb - neutral_emb
+ emo_dir = emo_dir / torch.norm(emo_dir, p=2)
+ else:
+ emo_dir = torch.tensor(ALL_EMO_DIRS[emo_dir], device=_device, dtype=precision)
+
+
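+    # Shift the source speaker embedding along the emotion direction; `strength` controls
+    # how strongly the emotion is expressed.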
+ edited_emb = source_emb + strength * emo_dir
+ edited_emb = edited_emb.to(device=_device, dtype=precision)
+
+ temperature=1.0
+ text = normalize_text(text)
+
+ start = time.time()
+ # first stage LLM
+ tokens = main(
+ model=model,
+ tokenizer=tokenizer,
+ model_size=model_size,
+ prompt=text,
+ spk_emb=edited_emb,
+ top_p=torch.tensor(top_p, device=_device, dtype=precision),
+ guidance_scale=torch.tensor(guidance_scale, device=_device, dtype=precision),
+ temperature=torch.tensor(temperature, device=_device, dtype=precision),
+ )
+ text_ids, extracted_audio_ids = first_stage_adapter.decode([tokens])
+
+ b_speaker_embs = edited_emb.unsqueeze(0)
+
+ # second stage LLM + multi-band diffusion model
+ wav_files = llm_second_stage(
+ texts=[text],
+ encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=_device).unsqueeze(0)],
+ speaker_embs=b_speaker_embs,
+ batch_size=1,
+ guidance_scale=None,
+ top_p=None,
+ top_k=200,
+ temperature=1.0,
+ max_new_tokens=None,
+ )
+
+ wav_file = wav_files[0]
+ with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
+ enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
+ shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
+ print(f"\nSaved audio to {wav_file}.wav")
+
+ output_path = str(wav_file) + ".wav"
+ return output_path
+
+
+ALL_EMO_DIRS = pickle.load(open('all_emo_dirs.pkl', 'rb'))
+EMO_NAMES = ['Upload your own sample'] + list(ALL_EMO_DIRS.keys())
+
+RADIO_CHOICES = ["Preset voices", "Upload your voice"]
+MAX_CHARS = 220
+PRESET_VOICES = {
+ # female
+ "Bria": "https://cdn.themetavoice.xyz/speakers%2Fbria.mp3",
+ # male
+ "Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3",
+ "Jacob": "https://cdn.themetavoice.xyz/speakers/jacob.wav",
+}
+
+
+def denormalise_top_p(top_p):
+ # returns top_p in the range [0.9, 1.0]
+ return round(0.9 + top_p / 100, 2)
+
+
+def denormalise_guidance(guidance):
+ # returns guidance in the range [1.0, 3.0]
+ return 1 + ((guidance - 1) * (3 - 1)) / (5 - 1)
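+# Example mappings for the two sliders above: a Speech Stability value of 5 gives
+# top_p = 0.95, and a Speaker similarity value of 3 gives a guidance scale of 2.0.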
+
+
+def _check_file_size(path):
+ if not path:
+ return
+ filesize = os.path.getsize(path)
+ filesize_mb = filesize / 1024 / 1024
+ if filesize_mb >= 50:
+        raise gr.Error(f"Please upload a sample smaller than 50MB for voice cloning. Provided: {round(filesize_mb)} MB")
+
+
+def _handle_edge_cases(to_say, upload_target):
+ if not to_say:
+ raise gr.Error("Please provide text to synthesise")
+
+ if len(to_say) > MAX_CHARS:
+ gr.Warning(
+ f"Max {MAX_CHARS} characters allowed. Provided: {len(to_say)} characters. Truncating and generating speech...Result at the end can be unstable as a result."
+ )
+
+ if not upload_target:
+ return
+
+    check_audio_file(upload_target)  # check that the file duration is at least 30s
+ _check_file_size(upload_target)
+
+
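+# NOTE: this plain voice-cloning path references a global TTS_MODEL that is not defined in
+# this file; the demo's "Generate Speech" button is wired to generate_sample instead.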
+def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target):
+ try:
+ d_top_p = denormalise_top_p(top_p)
+ d_guidance = denormalise_guidance(guidance)
+
+ _handle_edge_cases(to_say, upload_target)
+
+ to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS]
+
+ return TTS_MODEL.synthesise(
+ text=to_say,
+ spk_ref_path=PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else upload_target,
+ top_p=d_top_p,
+ guidance_scale=d_guidance,
+ )
+ except Exception as e:
+ raise gr.Error(f"Something went wrong. Reason: {str(e)}")
+
+
+def change_voice_selection_layout(choice):
+ if choice == RADIO_CHOICES[0]:
+ return [gr.update(visible=True), gr.update(visible=False)]
+
+ return [gr.update(visible=False), gr.update(visible=True)]
+
+def change_emotion_selection_layout(choice):
+ if choice == EMO_NAMES[0]:
+ return [gr.update(visible=True)]
+
+ return [gr.update(visible=False)]
+
+title = """
+
+
+Demo for 🎛️ EmoKnob
+"""
+
+description = """
+- While existing TTS services do not allow fine-grained control over emotions, EmoKnob lets users control emotion in speech with few-shot samples.
+- In this demo, you can select from a few preset voices or upload your own voice sample to clone.
+- You can then use a preset emotion or upload your own emotional/neutral sample pair to control the emotion.
+- You can adjust the strength of the emotion with the slider.
+
+
+EmoKnob uses [MetaVoice](https://github.com/metavoiceio/metavoice-src) as its voice cloning backbone.
+"""
+
+with gr.Blocks(title="EmoKnob Demo") as demo:
+ gr.Markdown(title)
+ gr.Image("emo-knob-teaser-1.svg", show_label=False, container=False)
+
+ with gr.Row():
+ gr.Markdown(description)
+
+ with gr.Row():
+ with gr.Column():
+ to_say = gr.TextArea(
+ label=f"What should I say!? (max {MAX_CHARS} characters).",
+ lines=4,
+ value="To be or not to be, that is the question.",
+ )
+
+
+
+ with gr.Row(), gr.Column():
+ # voice settings
+ top_p = gr.Slider(
+ value=0.95,
+ minimum=0.0,
+ maximum=10.0,
+ step=1.0,
+ label="Speech Stability - improves text following for a challenging speaker",
+ )
+ guidance = gr.Slider(
+ value=3.0,
+ minimum=1.0,
+ maximum=5.0,
+ step=1.0,
+ label="Speaker similarity - How closely to match speaker identity and speech style.",
+ )
+
+ strength = gr.Slider(
+ value=0.1,
+ minimum=0.0,
+ maximum=5.0,
+ step=0.01,
+ label="Strength - how strong the emotion is. Setting it to too large a value may result in unstable output.",
+ )
+
+
+
+ # voice select
+ toggle = gr.Radio(choices=RADIO_CHOICES, label="Choose voice", value=RADIO_CHOICES[0])
+
+ with gr.Row(visible=True) as row_1:
+ preset_dropdown = gr.Dropdown(
+ PRESET_VOICES.keys(), label="Preset voices", value=list(PRESET_VOICES.keys())[0]
+ )
+ with gr.Accordion("Preview: Preset voices", open=False):
+ for label, path in PRESET_VOICES.items():
+ gr.Audio(value=path, label=label)
+
+ with gr.Row(visible=False) as row_2:
+ upload_target = gr.Audio(
+ sources=["upload"],
+ type="filepath",
+ label="Upload a clean sample to clone.",
+ )
+ with gr.Row():
+ emotion_name = gr.Radio(choices=EMO_NAMES, label="Emotion", value=EMO_NAMES[0])
+ with gr.Row(visible=True) as row_3:
+ upload_neutral = gr.Audio(
+ sources=["upload"],
+ type="filepath",
+ label="Upload a neutral sample to compute the emotion direction. Should be same speaker as the emotional sample.",
+ )
+
+ upload_emo = gr.Audio(
+ sources=["upload"],
+ type="filepath",
+ label="Upload an emotional sample to compute the emotion direction. Should be same speaker as the neutral sample.",
+ )
+
+ toggle.change(
+ change_voice_selection_layout,
+ inputs=toggle,
+ outputs=[row_1, row_2],
+ )
+
+ # emotion_name.change(
+ # change_emotion_selection_layout,
+ # inputs=emotion_name,
+ # outputs=[row_3],
+ # )
+
+ with gr.Column():
+ speech = gr.Audio(
+ type="filepath",
+ label="Model says...",
+ )
+
+ submit = gr.Button("Generate Speech")
+ submit.click(
+ fn=generate_sample,
+ inputs=[to_say, emotion_name, upload_target, upload_emo, upload_neutral, strength, top_p, guidance, preset_dropdown, toggle],
+ outputs=speech,
+ )
+
+demo.launch()
\ No newline at end of file
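The heart of app.py is the embedding arithmetic inside generate_sample: an emotion direction is the normalised difference between an emotional and a neutral embedding of the same speaker, and the source speaker embedding is nudged along that direction. Below is a minimal, self-contained sketch of that arithmetic with dummy tensors; the 256-dimensional embeddings and helper names are illustrative only and do not appear in the repository.

    import torch

    def emotion_direction(emo_emb: torch.Tensor, neutral_emb: torch.Tensor) -> torch.Tensor:
        # Unit vector pointing from the neutral embedding towards the emotional one.
        direction = emo_emb - neutral_emb
        return direction / torch.norm(direction, p=2)

    def apply_emotion(source_emb: torch.Tensor, direction: torch.Tensor, strength: float = 0.1) -> torch.Tensor:
        # Shift the source speaker embedding along the emotion direction.
        return source_emb + strength * direction

    # Toy 256-dimensional embeddings stand in for real speaker-encoder outputs.
    source, emo, neutral = (torch.randn(256) for _ in range(3))
    edited = apply_emotion(source, emotion_direction(emo, neutral), strength=0.3)
    print(edited.shape)  # torch.Size([256])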
diff --git a/assets/favicon.ico b/assets/favicon.ico
new file mode 100644
index 0000000000000000000000000000000000000000..490e482591afd1626e1f01754917f9a27c8597e0
Binary files /dev/null and b/assets/favicon.ico differ
diff --git a/assets/logo.png b/assets/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..1d5a10eb1e55801bd7dd22b8750cc416acae29c0
Binary files /dev/null and b/assets/logo.png differ
diff --git a/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/42dcdfcaf9e42a488d4be06500dd771d7aa11e83.lock b/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/42dcdfcaf9e42a488d4be06500dd771d7aa11e83.lock
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/afbfcebcf9df8c0af538cd5b6f616bd1d7a9739eba4b81d871545b1b562d6b0a.lock b/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/afbfcebcf9df8c0af538cd5b6f616bd1d7a9739eba4b81d871545b1b562d6b0a.lock
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/c19160bba3c1267f959caf6d13fb07f9ea232e04.lock b/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/c19160bba3c1267f959caf6d13fb07f9ea232e04.lock
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/ef62bf21fb2396937098b86ae80c68813b229c18.lock b/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/ef62bf21fb2396937098b86ae80c68813b229c18.lock
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/f7640f94e81bb7f4f04daf1668850b38763a13d9.lock b/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/f7640f94e81bb7f4f04daf1668850b38763a13d9.lock
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/f8194e4e9432d287bf257d4a7d4a0f2446c32da8.lock b/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/f8194e4e9432d287bf257d4a7d4a0f2446c32da8.lock
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/feb95adc7e79e878999ba5a1d3ddfe9f16eff0f1.lock b/cache/.locks/models--Salesforce--SFR-Embedding-Mistral/feb95adc7e79e878999ba5a1d3ddfe9f16eff0f1.lock
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/.no_exist/938c560d1c236aa563b2dbdf084f28ab28bccb11/model.safetensors b/cache/models--Salesforce--SFR-Embedding-Mistral/.no_exist/938c560d1c236aa563b2dbdf084f28ab28bccb11/model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/42dcdfcaf9e42a488d4be06500dd771d7aa11e83 b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/42dcdfcaf9e42a488d4be06500dd771d7aa11e83
new file mode 100644
index 0000000000000000000000000000000000000000..42dcdfcaf9e42a488d4be06500dd771d7aa11e83
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/42dcdfcaf9e42a488d4be06500dd771d7aa11e83
@@ -0,0 +1,4 @@
+{
+ "max_seq_length": 4096,
+ "do_lower_case": false
+}
\ No newline at end of file
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/c19160bba3c1267f959caf6d13fb07f9ea232e04 b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/c19160bba3c1267f959caf6d13fb07f9ea232e04
new file mode 100644
index 0000000000000000000000000000000000000000..c19160bba3c1267f959caf6d13fb07f9ea232e04
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/c19160bba3c1267f959caf6d13fb07f9ea232e04
@@ -0,0 +1,27 @@
+{
+ "_name_or_path": "intfloat/e5-mistral-7b-instruct",
+ "architectures": [
+ "MistralModel"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 32768,
+ "model_type": "mistral",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 8,
+ "pad_token_id": 2,
+ "rms_norm_eps": 1e-05,
+ "rope_theta": 10000.0,
+ "sliding_window": 4096,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float16",
+ "transformers_version": "4.37.0",
+ "use_cache": false,
+ "vocab_size": 32000
+}
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/ef62bf21fb2396937098b86ae80c68813b229c18 b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/ef62bf21fb2396937098b86ae80c68813b229c18
new file mode 100644
index 0000000000000000000000000000000000000000..ef62bf21fb2396937098b86ae80c68813b229c18
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/ef62bf21fb2396937098b86ae80c68813b229c18
@@ -0,0 +1,7 @@
+{
+ "__version__": {
+ "sentence_transformers": "2.2.2",
+ "transformers": "4.37.2",
+ "pytorch": "2.1.0+cu121"
+ }
+}
\ No newline at end of file
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/f7640f94e81bb7f4f04daf1668850b38763a13d9 b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/f7640f94e81bb7f4f04daf1668850b38763a13d9
new file mode 100644
index 0000000000000000000000000000000000000000..f7640f94e81bb7f4f04daf1668850b38763a13d9
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/f7640f94e81bb7f4f04daf1668850b38763a13d9
@@ -0,0 +1,14 @@
+[
+ {
+ "idx": 0,
+ "name": "0",
+ "path": "",
+ "type": "sentence_transformers.models.Transformer"
+ },
+ {
+ "idx": 1,
+ "name": "1",
+ "path": "1_Pooling",
+ "type": "sentence_transformers.models.Pooling"
+ }
+]
\ No newline at end of file
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/f8194e4e9432d287bf257d4a7d4a0f2446c32da8 b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/f8194e4e9432d287bf257d4a7d4a0f2446c32da8
new file mode 100644
index 0000000000000000000000000000000000000000..f8194e4e9432d287bf257d4a7d4a0f2446c32da8
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/f8194e4e9432d287bf257d4a7d4a0f2446c32da8
@@ -0,0 +1,297 @@
+{
+ "metadata": {
+ "total_size": 14221320192
+ },
+ "weight_map": {
+ "embed_tokens.weight": "model-00001-of-00003.safetensors",
+ "layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.10.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.10.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.10.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.11.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.11.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.11.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.11.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.11.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.11.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.12.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.12.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
+ "layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.22.input_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.22.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.22.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.22.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.22.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
+ "layers.23.input_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.23.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.23.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.23.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.23.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.23.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.23.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.23.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.23.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.24.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.24.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.24.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.24.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
+ "layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
+ "layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
+ "layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
+ "layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
+ "norm.weight": "model-00003-of-00003.safetensors"
+ }
+}
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/feb95adc7e79e878999ba5a1d3ddfe9f16eff0f1 b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/feb95adc7e79e878999ba5a1d3ddfe9f16eff0f1
new file mode 100644
index 0000000000000000000000000000000000000000..feb95adc7e79e878999ba5a1d3ddfe9f16eff0f1
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/blobs/feb95adc7e79e878999ba5a1d3ddfe9f16eff0f1
@@ -0,0 +1,3398 @@
+---
+tags:
+- mteb
+- sentence-transformers
+- transformers
+model-index:
+- name: SFR-Embedding-Mistral
+ results:
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/amazon_counterfactual
+ name: MTEB AmazonCounterfactualClassification (en)
+ config: en
+ split: test
+ revision: e8379541af4e31359cca9fbcf4b00f2671dba205
+ metrics:
+ - type: accuracy
+ value: 77.92537313432834
+ - type: ap
+ value: 40.86767661556651
+ - type: f1
+ value: 71.65758897929837
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/amazon_polarity
+ name: MTEB AmazonPolarityClassification
+ config: default
+ split: test
+ revision: e2d317d38cd51312af73b3d32a06d1a08b442046
+ metrics:
+ - type: accuracy
+ value: 95.967
+ - type: ap
+ value: 94.46300829592593
+ - type: f1
+ value: 95.96507173189292
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/amazon_reviews_multi
+ name: MTEB AmazonReviewsClassification (en)
+ config: en
+ split: test
+ revision: 1399c76144fd37290681b995c656ef9b2e06e26d
+ metrics:
+ - type: accuracy
+ value: 54.352000000000004
+ - type: f1
+ value: 53.636682615380174
+ - task:
+ type: Retrieval
+ dataset:
+ type: arguana
+ name: MTEB ArguAna
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 43.314
+ - type: ndcg_at_2
+ value: 54.757
+ - type: ndcg_at_3
+ value: 58.84700000000001
+ - type: ndcg_at_5
+ value: 63.634
+ - type: ndcg_at_7
+ value: 65.741
+ - type: ndcg_at_10
+ value: 67.171
+ - type: ndcg_at_20
+ value: 68.585
+ - type: ndcg_at_30
+ value: 68.81
+ - type: ndcg_at_50
+ value: 68.932
+ - type: ndcg_at_70
+ value: 68.992
+ - type: ndcg_at_100
+ value: 69.014
+ - type: ndcg_at_200
+ value: 69.014
+ - type: ndcg_at_300
+ value: 69.014
+ - type: ndcg_at_500
+ value: 69.014
+ - type: ndcg_at_700
+ value: 69.014
+ - type: ndcg_at_1000
+ value: 69.014
+ - type: map_at_1
+ value: 43.314
+ - type: map_at_2
+ value: 52.383
+ - type: map_at_3
+ value: 55.108999999999995
+ - type: map_at_5
+ value: 57.772999999999996
+ - type: map_at_7
+ value: 58.718
+ - type: map_at_10
+ value: 59.256
+ - type: map_at_20
+ value: 59.668
+ - type: map_at_30
+ value: 59.709999999999994
+ - type: map_at_50
+ value: 59.727
+ - type: map_at_70
+ value: 59.733999999999995
+ - type: map_at_100
+ value: 59.73500000000001
+ - type: map_at_200
+ value: 59.73500000000001
+ - type: map_at_300
+ value: 59.73500000000001
+ - type: map_at_500
+ value: 59.73500000000001
+ - type: map_at_700
+ value: 59.73500000000001
+ - type: map_at_1000
+ value: 59.73500000000001
+ - type: recall_at_1
+ value: 43.314
+ - type: recall_at_2
+ value: 61.451
+ - type: recall_at_3
+ value: 69.63000000000001
+ - type: recall_at_5
+ value: 81.223
+ - type: recall_at_7
+ value: 87.33999999999999
+ - type: recall_at_10
+ value: 92.034
+ - type: recall_at_20
+ value: 97.44
+ - type: recall_at_30
+ value: 98.506
+ - type: recall_at_50
+ value: 99.14699999999999
+ - type: recall_at_70
+ value: 99.502
+ - type: recall_at_100
+ value: 99.644
+ - type: recall_at_200
+ value: 99.644
+ - type: recall_at_300
+ value: 99.644
+ - type: recall_at_500
+ value: 99.644
+ - type: recall_at_700
+ value: 99.644
+ - type: recall_at_1000
+ value: 99.644
+ - type: precision_at_1
+ value: 43.314
+ - type: precision_at_2
+ value: 30.725
+ - type: precision_at_3
+ value: 23.21
+ - type: precision_at_5
+ value: 16.245
+ - type: precision_at_7
+ value: 12.477
+ - type: precision_at_10
+ value: 9.203
+ - type: precision_at_20
+ value: 4.872
+ - type: precision_at_30
+ value: 3.2840000000000003
+ - type: precision_at_50
+ value: 1.983
+ - type: precision_at_70
+ value: 1.421
+ - type: precision_at_100
+ value: 0.996
+ - type: precision_at_200
+ value: 0.498
+ - type: precision_at_300
+ value: 0.332
+ - type: precision_at_500
+ value: 0.199
+ - type: precision_at_700
+ value: 0.14200000000000002
+ - type: precision_at_1000
+ value: 0.1
+ - type: mrr_at_1
+ value: 44.666
+ - type: mrr_at_2
+ value: 52.418
+ - type: mrr_at_3
+ value: 55.595000000000006
+ - type: mrr_at_5
+ value: 58.205
+ - type: mrr_at_7
+ value: 59.202999999999996
+ - type: mrr_at_10
+ value: 59.727
+ - type: mrr_at_20
+ value: 60.133
+ - type: mrr_at_30
+ value: 60.178
+ - type: mrr_at_50
+ value: 60.192
+ - type: mrr_at_70
+ value: 60.19799999999999
+ - type: mrr_at_100
+ value: 60.199999999999996
+ - type: mrr_at_200
+ value: 60.199999999999996
+ - type: mrr_at_300
+ value: 60.199999999999996
+ - type: mrr_at_500
+ value: 60.199999999999996
+ - type: mrr_at_700
+ value: 60.199999999999996
+ - type: mrr_at_1000
+ value: 60.199999999999996
+ - task:
+ type: Clustering
+ dataset:
+ type: mteb/arxiv-clustering-p2p
+ name: MTEB ArxivClusteringP2P
+ config: default
+ split: test
+ revision: a122ad7f3f0291bf49cc6f4d32aa80929df69d5d
+ metrics:
+ - type: v_measure
+ value: 52.07508593014336
+ - task:
+ type: Clustering
+ dataset:
+ type: mteb/arxiv-clustering-s2s
+ name: MTEB ArxivClusteringS2S
+ config: default
+ split: test
+ revision: f910caf1a6075f7329cdf8c1a6135696f37dbd53
+ metrics:
+ - type: v_measure
+ value: 47.381339333240675
+ - task:
+ type: Reranking
+ dataset:
+ type: mteb/askubuntudupquestions-reranking
+ name: MTEB AskUbuntuDupQuestions
+ config: default
+ split: test
+ revision: 2000358ca161889fa9c082cb41daa8dcfb161a54
+ metrics:
+ - type: map
+ value: 67.58376647859171
+ - type: mrr
+ value: 80.56885635140483
+ - task:
+ type: STS
+ dataset:
+ type: mteb/biosses-sts
+ name: MTEB BIOSSES
+ config: default
+ split: test
+ revision: d3fb88f8f02e40887cd149695127462bbcf29b4a
+ metrics:
+ - type: cos_sim_pearson
+ value: 88.40107280274783
+ - type: cos_sim_spearman
+ value: 86.07003345325681
+ - type: euclidean_pearson
+ value: 87.1726034325395
+ - type: euclidean_spearman
+ value: 86.07003345325681
+ - type: manhattan_pearson
+ value: 87.25660625029772
+ - type: manhattan_spearman
+ value: 86.3808839096893
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/banking77
+ name: MTEB Banking77Classification
+ config: default
+ split: test
+ revision: 0fd18e25b25c072e09e0d92ab615fda904d66300
+ metrics:
+ - type: accuracy
+ value: 88.81168831168831
+ - type: f1
+ value: 88.76514496560141
+ - task:
+ type: Clustering
+ dataset:
+ type: mteb/biorxiv-clustering-p2p
+ name: MTEB BiorxivClusteringP2P
+ config: default
+ split: test
+ revision: 65b79d1d13f80053f67aca9498d9402c2d9f1f40
+ metrics:
+ - type: v_measure
+ value: 43.9382520874344
+ - task:
+ type: Clustering
+ dataset:
+ type: mteb/biorxiv-clustering-s2s
+ name: MTEB BiorxivClusteringS2S
+ config: default
+ split: test
+ revision: 258694dd0231531bc1fd9de6ceb52a0853c6d908
+ metrics:
+ - type: v_measure
+ value: 41.14351847240913
+ - task:
+ type: Retrieval
+ dataset:
+ type: BeIR/cqadupstack
+ name: MTEB CQADupstackRetrieval
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 34.51166666666667
+ - type: ndcg_at_2
+ value: 38.51591666666667
+ - type: ndcg_at_3
+ value: 40.95083333333333
+ - type: ndcg_at_5
+ value: 43.580666666666666
+ - type: ndcg_at_7
+ value: 45.0625
+ - type: ndcg_at_10
+ value: 46.49083333333333
+ - type: ndcg_at_20
+ value: 48.731333333333325
+ - type: ndcg_at_30
+ value: 49.78666666666667
+ - type: ndcg_at_50
+ value: 50.84049999999999
+ - type: ndcg_at_70
+ value: 51.393750000000004
+ - type: ndcg_at_100
+ value: 51.883333333333326
+ - type: ndcg_at_200
+ value: 52.65225
+ - type: ndcg_at_300
+ value: 52.98241666666669
+ - type: ndcg_at_500
+ value: 53.28541666666668
+ - type: ndcg_at_700
+ value: 53.49241666666668
+ - type: ndcg_at_1000
+ value: 53.63758333333334
+ - type: map_at_1
+ value: 29.10075
+ - type: map_at_2
+ value: 34.636500000000005
+ - type: map_at_3
+ value: 36.92033333333333
+ - type: map_at_5
+ value: 38.81641666666666
+ - type: map_at_7
+ value: 39.635416666666664
+ - type: map_at_10
+ value: 40.294583333333335
+ - type: map_at_20
+ value: 41.07574999999999
+ - type: map_at_30
+ value: 41.333
+ - type: map_at_50
+ value: 41.529333333333334
+ - type: map_at_70
+ value: 41.606833333333334
+ - type: map_at_100
+ value: 41.66224999999999
+ - type: map_at_200
+ value: 41.72691666666666
+ - type: map_at_300
+ value: 41.746583333333334
+ - type: map_at_500
+ value: 41.75983333333333
+ - type: map_at_700
+ value: 41.76558333333333
+ - type: map_at_1000
+ value: 41.769000000000005
+ - type: recall_at_1
+ value: 29.10075
+ - type: recall_at_2
+ value: 39.07658333333333
+ - type: recall_at_3
+ value: 44.93591666666667
+ - type: recall_at_5
+ value: 51.66883333333333
+ - type: recall_at_7
+ value: 55.881000000000014
+ - type: recall_at_10
+ value: 60.34691666666667
+ - type: recall_at_20
+ value: 68.44016666666667
+ - type: recall_at_30
+ value: 72.90766666666667
+ - type: recall_at_50
+ value: 77.843
+ - type: recall_at_70
+ value: 80.70366666666668
+ - type: recall_at_100
+ value: 83.42866666666667
+ - type: recall_at_200
+ value: 88.06816666666668
+ - type: recall_at_300
+ value: 90.249
+ - type: recall_at_500
+ value: 92.37616666666668
+ - type: recall_at_700
+ value: 93.978
+ - type: recall_at_1000
+ value: 95.12791666666666
+ - type: precision_at_1
+ value: 34.51166666666667
+ - type: precision_at_2
+ value: 24.326333333333327
+ - type: precision_at_3
+ value: 19.099249999999998
+ - type: precision_at_5
+ value: 13.672666666666666
+ - type: precision_at_7
+ value: 10.772
+ - type: precision_at_10
+ value: 8.302166666666668
+ - type: precision_at_20
+ value: 4.8960833333333325
+ - type: precision_at_30
+ value: 3.551083333333333
+ - type: precision_at_50
+ value: 2.3386666666666662
+ - type: precision_at_70
+ value: 1.7605833333333334
+ - type: precision_at_100
+ value: 1.2965
+ - type: precision_at_200
+ value: 0.7106666666666668
+ - type: precision_at_300
+ value: 0.4955
+ - type: precision_at_500
+ value: 0.3106666666666667
+ - type: precision_at_700
+ value: 0.22791666666666668
+ - type: precision_at_1000
+ value: 0.1635833333333333
+ - type: mrr_at_1
+ value: 34.51166666666667
+ - type: mrr_at_2
+ value: 39.954249999999995
+ - type: mrr_at_3
+ value: 41.93741666666668
+ - type: mrr_at_5
+ value: 43.487166666666674
+ - type: mrr_at_7
+ value: 44.14983333333333
+ - type: mrr_at_10
+ value: 44.62766666666666
+ - type: mrr_at_20
+ value: 45.15291666666668
+ - type: mrr_at_30
+ value: 45.317
+ - type: mrr_at_50
+ value: 45.42875
+ - type: mrr_at_70
+ value: 45.46966666666667
+ - type: mrr_at_100
+ value: 45.49716666666667
+ - type: mrr_at_200
+ value: 45.525166666666664
+ - type: mrr_at_300
+ value: 45.53233333333335
+ - type: mrr_at_500
+ value: 45.5365
+ - type: mrr_at_700
+ value: 45.538583333333335
+ - type: mrr_at_1000
+ value: 45.539583333333326
+ - task:
+ type: Retrieval
+ dataset:
+ type: climate-fever
+ name: MTEB ClimateFEVER
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 35.179
+ - type: ndcg_at_2
+ value: 31.243
+ - type: ndcg_at_3
+ value: 30.562
+ - type: ndcg_at_5
+ value: 32.409
+ - type: ndcg_at_7
+ value: 34.525
+ - type: ndcg_at_10
+ value: 36.415
+ - type: ndcg_at_20
+ value: 39.443
+ - type: ndcg_at_30
+ value: 40.796
+ - type: ndcg_at_50
+ value: 42.16
+ - type: ndcg_at_70
+ value: 42.971
+ - type: ndcg_at_100
+ value: 43.691
+ - type: ndcg_at_200
+ value: 45.004
+ - type: ndcg_at_300
+ value: 45.527
+ - type: ndcg_at_500
+ value: 46.072
+ - type: ndcg_at_700
+ value: 46.387
+ - type: ndcg_at_1000
+ value: 46.663
+ - type: map_at_1
+ value: 15.692
+ - type: map_at_2
+ value: 20.116
+ - type: map_at_3
+ value: 22.6
+ - type: map_at_5
+ value: 24.701
+ - type: map_at_7
+ value: 25.934
+ - type: map_at_10
+ value: 26.843
+ - type: map_at_20
+ value: 27.975
+ - type: map_at_30
+ value: 28.372000000000003
+ - type: map_at_50
+ value: 28.671000000000003
+ - type: map_at_70
+ value: 28.803
+ - type: map_at_100
+ value: 28.895
+ - type: map_at_200
+ value: 29.011
+ - type: map_at_300
+ value: 29.042
+ - type: map_at_500
+ value: 29.065
+ - type: map_at_700
+ value: 29.075
+ - type: map_at_1000
+ value: 29.081000000000003
+ - type: recall_at_1
+ value: 15.692
+ - type: recall_at_2
+ value: 22.602
+ - type: recall_at_3
+ value: 27.814
+ - type: recall_at_5
+ value: 33.756
+ - type: recall_at_7
+ value: 38.073
+ - type: recall_at_10
+ value: 42.553000000000004
+ - type: recall_at_20
+ value: 51.121
+ - type: recall_at_30
+ value: 55.523999999999994
+ - type: recall_at_50
+ value: 60.586
+ - type: recall_at_70
+ value: 63.94
+ - type: recall_at_100
+ value: 67.134
+ - type: recall_at_200
+ value: 73.543
+ - type: recall_at_300
+ value: 76.372
+ - type: recall_at_500
+ value: 79.60199999999999
+ - type: recall_at_700
+ value: 81.536
+ - type: recall_at_1000
+ value: 83.37400000000001
+ - type: precision_at_1
+ value: 35.179
+ - type: precision_at_2
+ value: 27.199
+ - type: precision_at_3
+ value: 22.953000000000003
+ - type: precision_at_5
+ value: 17.224999999999998
+ - type: precision_at_7
+ value: 14.238999999999999
+ - type: precision_at_10
+ value: 11.303
+ - type: precision_at_20
+ value: 6.954000000000001
+ - type: precision_at_30
+ value: 5.116
+ - type: precision_at_50
+ value: 3.395
+ - type: precision_at_70
+ value: 2.579
+ - type: precision_at_100
+ value: 1.9109999999999998
+ - type: precision_at_200
+ value: 1.065
+ - type: precision_at_300
+ value: 0.743
+ - type: precision_at_500
+ value: 0.46699999999999997
+ - type: precision_at_700
+ value: 0.344
+ - type: precision_at_1000
+ value: 0.247
+ - type: mrr_at_1
+ value: 35.179
+ - type: mrr_at_2
+ value: 41.792
+ - type: mrr_at_3
+ value: 44.484
+ - type: mrr_at_5
+ value: 46.39
+ - type: mrr_at_7
+ value: 47.125
+ - type: mrr_at_10
+ value: 47.711999999999996
+ - type: mrr_at_20
+ value: 48.214
+ - type: mrr_at_30
+ value: 48.325
+ - type: mrr_at_50
+ value: 48.392
+ - type: mrr_at_70
+ value: 48.418
+ - type: mrr_at_100
+ value: 48.44
+ - type: mrr_at_200
+ value: 48.46
+ - type: mrr_at_300
+ value: 48.461999999999996
+ - type: mrr_at_500
+ value: 48.466
+ - type: mrr_at_700
+ value: 48.466
+ - type: mrr_at_1000
+ value: 48.467
+ - task:
+ type: Retrieval
+ dataset:
+ type: dbpedia-entity
+ name: MTEB DBPedia
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 62.375
+ - type: ndcg_at_2
+ value: 56.286
+ - type: ndcg_at_3
+ value: 53.665
+ - type: ndcg_at_5
+ value: 51.139
+ - type: ndcg_at_7
+ value: 49.873
+ - type: ndcg_at_10
+ value: 49.056
+ - type: ndcg_at_20
+ value: 48.783
+ - type: ndcg_at_30
+ value: 49.166
+ - type: ndcg_at_50
+ value: 51.141999999999996
+ - type: ndcg_at_70
+ value: 52.774
+ - type: ndcg_at_100
+ value: 54.403
+ - type: ndcg_at_200
+ value: 57.419
+ - type: ndcg_at_300
+ value: 58.778
+ - type: ndcg_at_500
+ value: 60.228
+ - type: ndcg_at_700
+ value: 61.07599999999999
+ - type: ndcg_at_1000
+ value: 61.846000000000004
+ - type: map_at_1
+ value: 10.359
+ - type: map_at_2
+ value: 14.446
+ - type: map_at_3
+ value: 16.689
+ - type: map_at_5
+ value: 20.096
+ - type: map_at_7
+ value: 22.247
+ - type: map_at_10
+ value: 24.468999999999998
+ - type: map_at_20
+ value: 28.938000000000002
+ - type: map_at_30
+ value: 31.134
+ - type: map_at_50
+ value: 33.403
+ - type: map_at_70
+ value: 34.486
+ - type: map_at_100
+ value: 35.337
+ - type: map_at_200
+ value: 36.364999999999995
+ - type: map_at_300
+ value: 36.735
+ - type: map_at_500
+ value: 37.057
+ - type: map_at_700
+ value: 37.225
+ - type: map_at_1000
+ value: 37.379
+ - type: recall_at_1
+ value: 10.359
+ - type: recall_at_2
+ value: 14.945
+ - type: recall_at_3
+ value: 17.694
+ - type: recall_at_5
+ value: 22.677
+ - type: recall_at_7
+ value: 26.131
+ - type: recall_at_10
+ value: 30.053
+ - type: recall_at_20
+ value: 39.518
+ - type: recall_at_30
+ value: 44.925
+ - type: recall_at_50
+ value: 52.154
+ - type: recall_at_70
+ value: 56.729
+ - type: recall_at_100
+ value: 61.18900000000001
+ - type: recall_at_200
+ value: 70.407
+ - type: recall_at_300
+ value: 74.412
+ - type: recall_at_500
+ value: 78.891
+ - type: recall_at_700
+ value: 81.74
+ - type: recall_at_1000
+ value: 84.253
+ - type: precision_at_1
+ value: 75
+ - type: precision_at_2
+ value: 64.125
+ - type: precision_at_3
+ value: 57.833
+ - type: precision_at_5
+ value: 50.24999999999999
+ - type: precision_at_7
+ value: 44.75
+ - type: precision_at_10
+ value: 39.75
+ - type: precision_at_20
+ value: 30.412
+ - type: precision_at_30
+ value: 25.141999999999996
+ - type: precision_at_50
+ value: 19.2
+ - type: precision_at_70
+ value: 15.729000000000001
+ - type: precision_at_100
+ value: 12.552
+ - type: precision_at_200
+ value: 7.866
+ - type: precision_at_300
+ value: 5.9270000000000005
+ - type: precision_at_500
+ value: 4.1129999999999995
+ - type: precision_at_700
+ value: 3.2460000000000004
+ - type: precision_at_1000
+ value: 2.5260000000000002
+ - type: mrr_at_1
+ value: 75
+ - type: mrr_at_2
+ value: 78.625
+ - type: mrr_at_3
+ value: 79.708
+ - type: mrr_at_5
+ value: 80.446
+ - type: mrr_at_7
+ value: 80.862
+ - type: mrr_at_10
+ value: 81.161
+ - type: mrr_at_20
+ value: 81.3
+ - type: mrr_at_30
+ value: 81.348
+ - type: mrr_at_50
+ value: 81.361
+ - type: mrr_at_70
+ value: 81.361
+ - type: mrr_at_100
+ value: 81.361
+ - type: mrr_at_200
+ value: 81.367
+ - type: mrr_at_300
+ value: 81.367
+ - type: mrr_at_500
+ value: 81.368
+ - type: mrr_at_700
+ value: 81.368
+ - type: mrr_at_1000
+ value: 81.368
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/emotion
+ name: MTEB EmotionClassification
+ config: default
+ split: test
+ revision: 4f58c6b202a23cf9a4da393831edf4f9183cad37
+ metrics:
+ - type: accuracy
+ value: 50.239999999999995
+ - type: f1
+ value: 46.42361822342044
+ - task:
+ type: Retrieval
+ dataset:
+ type: fever
+ name: MTEB FEVER
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 83.723
+ - type: ndcg_at_2
+ value: 86.777
+ - type: ndcg_at_3
+ value: 87.997
+ - type: ndcg_at_5
+ value: 88.864
+ - type: ndcg_at_7
+ value: 89.143
+ - type: ndcg_at_10
+ value: 89.349
+ - type: ndcg_at_20
+ value: 89.709
+ - type: ndcg_at_30
+ value: 89.82900000000001
+ - type: ndcg_at_50
+ value: 89.923
+ - type: ndcg_at_70
+ value: 89.982
+ - type: ndcg_at_100
+ value: 90.026
+ - type: ndcg_at_200
+ value: 90.10000000000001
+ - type: ndcg_at_300
+ value: 90.12599999999999
+ - type: ndcg_at_500
+ value: 90.17399999999999
+ - type: ndcg_at_700
+ value: 90.19
+ - type: ndcg_at_1000
+ value: 90.208
+ - type: map_at_1
+ value: 77.64999999999999
+ - type: map_at_2
+ value: 83.769
+ - type: map_at_3
+ value: 85.041
+ - type: map_at_5
+ value: 85.736
+ - type: map_at_7
+ value: 85.924
+ - type: map_at_10
+ value: 86.032
+ - type: map_at_20
+ value: 86.177
+ - type: map_at_30
+ value: 86.213
+ - type: map_at_50
+ value: 86.233
+ - type: map_at_70
+ value: 86.24300000000001
+ - type: map_at_100
+ value: 86.249
+ - type: map_at_200
+ value: 86.256
+ - type: map_at_300
+ value: 86.258
+ - type: map_at_500
+ value: 86.26
+ - type: map_at_700
+ value: 86.26
+ - type: map_at_1000
+ value: 86.261
+ - type: recall_at_1
+ value: 77.64999999999999
+ - type: recall_at_2
+ value: 88.53999999999999
+ - type: recall_at_3
+ value: 91.696
+ - type: recall_at_5
+ value: 93.916
+ - type: recall_at_7
+ value: 94.731
+ - type: recall_at_10
+ value: 95.318
+ - type: recall_at_20
+ value: 96.507
+ - type: recall_at_30
+ value: 96.956
+ - type: recall_at_50
+ value: 97.34899999999999
+ - type: recall_at_70
+ value: 97.61
+ - type: recall_at_100
+ value: 97.83
+ - type: recall_at_200
+ value: 98.223
+ - type: recall_at_300
+ value: 98.374
+ - type: recall_at_500
+ value: 98.67899999999999
+ - type: recall_at_700
+ value: 98.787
+ - type: recall_at_1000
+ value: 98.919
+ - type: precision_at_1
+ value: 83.723
+ - type: precision_at_2
+ value: 48.425000000000004
+ - type: precision_at_3
+ value: 33.638
+ - type: precision_at_5
+ value: 20.843
+ - type: precision_at_7
+ value: 15.079
+ - type: precision_at_10
+ value: 10.674999999999999
+ - type: precision_at_20
+ value: 5.457999999999999
+ - type: precision_at_30
+ value: 3.6740000000000004
+ - type: precision_at_50
+ value: 2.2239999999999998
+ - type: precision_at_70
+ value: 1.599
+ - type: precision_at_100
+ value: 1.125
+ - type: precision_at_200
+ value: 0.5680000000000001
+ - type: precision_at_300
+ value: 0.38
+ - type: precision_at_500
+ value: 0.22999999999999998
+ - type: precision_at_700
+ value: 0.165
+ - type: precision_at_1000
+ value: 0.116
+ - type: mrr_at_1
+ value: 83.723
+ - type: mrr_at_2
+ value: 88.794
+ - type: mrr_at_3
+ value: 89.679
+ - type: mrr_at_5
+ value: 90.049
+ - type: mrr_at_7
+ value: 90.129
+ - type: mrr_at_10
+ value: 90.167
+ - type: mrr_at_20
+ value: 90.208
+ - type: mrr_at_30
+ value: 90.214
+ - type: mrr_at_50
+ value: 90.217
+ - type: mrr_at_70
+ value: 90.218
+ - type: mrr_at_100
+ value: 90.21900000000001
+ - type: mrr_at_200
+ value: 90.21900000000001
+ - type: mrr_at_300
+ value: 90.21900000000001
+ - type: mrr_at_500
+ value: 90.21900000000001
+ - type: mrr_at_700
+ value: 90.21900000000001
+ - type: mrr_at_1000
+ value: 90.21900000000001
+ - task:
+ type: Retrieval
+ dataset:
+ type: fiqa
+ name: MTEB FiQA2018
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 59.721999999999994
+ - type: ndcg_at_2
+ value: 56.85
+ - type: ndcg_at_3
+ value: 56.462999999999994
+ - type: ndcg_at_5
+ value: 57.75599999999999
+ - type: ndcg_at_7
+ value: 59.109
+ - type: ndcg_at_10
+ value: 60.402
+ - type: ndcg_at_20
+ value: 63.071999999999996
+ - type: ndcg_at_30
+ value: 64.302
+ - type: ndcg_at_50
+ value: 65.619
+ - type: ndcg_at_70
+ value: 66.161
+ - type: ndcg_at_100
+ value: 66.645
+ - type: ndcg_at_200
+ value: 67.353
+ - type: ndcg_at_300
+ value: 67.646
+ - type: ndcg_at_500
+ value: 67.852
+ - type: ndcg_at_700
+ value: 67.974
+ - type: ndcg_at_1000
+ value: 68.084
+ - type: map_at_1
+ value: 31.56
+ - type: map_at_2
+ value: 42.093
+ - type: map_at_3
+ value: 46.177
+ - type: map_at_5
+ value: 49.78
+ - type: map_at_7
+ value: 51.410999999999994
+ - type: map_at_10
+ value: 52.524
+ - type: map_at_20
+ value: 53.815000000000005
+ - type: map_at_30
+ value: 54.201
+ - type: map_at_50
+ value: 54.531
+ - type: map_at_70
+ value: 54.625
+ - type: map_at_100
+ value: 54.686
+ - type: map_at_200
+ value: 54.757999999999996
+ - type: map_at_300
+ value: 54.776
+ - type: map_at_500
+ value: 54.786
+ - type: map_at_700
+ value: 54.790000000000006
+ - type: map_at_1000
+ value: 54.793000000000006
+ - type: recall_at_1
+ value: 31.56
+ - type: recall_at_2
+ value: 44.858
+ - type: recall_at_3
+ value: 51.11
+ - type: recall_at_5
+ value: 58.394
+ - type: recall_at_7
+ value: 63.001
+ - type: recall_at_10
+ value: 66.81200000000001
+ - type: recall_at_20
+ value: 74.901
+ - type: recall_at_30
+ value: 79.218
+ - type: recall_at_50
+ value: 84.49
+ - type: recall_at_70
+ value: 87.003
+ - type: recall_at_100
+ value: 89.345
+ - type: recall_at_200
+ value: 93.173
+ - type: recall_at_300
+ value: 94.906
+ - type: recall_at_500
+ value: 96.223
+ - type: recall_at_700
+ value: 97.043
+ - type: recall_at_1000
+ value: 97.785
+ - type: precision_at_1
+ value: 59.721999999999994
+ - type: precision_at_2
+ value: 46.682
+ - type: precision_at_3
+ value: 37.602999999999994
+ - type: precision_at_5
+ value: 27.500000000000004
+ - type: precision_at_7
+ value: 21.847
+ - type: precision_at_10
+ value: 16.667
+ - type: precision_at_20
+ value: 9.545
+ - type: precision_at_30
+ value: 6.795
+ - type: precision_at_50
+ value: 4.38
+ - type: precision_at_70
+ value: 3.221
+ - type: precision_at_100
+ value: 2.319
+ - type: precision_at_200
+ value: 1.2149999999999999
+ - type: precision_at_300
+ value: 0.827
+ - type: precision_at_500
+ value: 0.504
+ - type: precision_at_700
+ value: 0.364
+ - type: precision_at_1000
+ value: 0.257
+ - type: mrr_at_1
+ value: 59.721999999999994
+ - type: mrr_at_2
+ value: 64.506
+ - type: mrr_at_3
+ value: 65.792
+ - type: mrr_at_5
+ value: 66.965
+ - type: mrr_at_7
+ value: 67.34700000000001
+ - type: mrr_at_10
+ value: 67.57
+ - type: mrr_at_20
+ value: 67.896
+ - type: mrr_at_30
+ value: 68.008
+ - type: mrr_at_50
+ value: 68.083
+ - type: mrr_at_70
+ value: 68.105
+ - type: mrr_at_100
+ value: 68.116
+ - type: mrr_at_200
+ value: 68.12700000000001
+ - type: mrr_at_300
+ value: 68.13
+ - type: mrr_at_500
+ value: 68.132
+ - type: mrr_at_700
+ value: 68.133
+ - type: mrr_at_1000
+ value: 68.133
+ - task:
+ type: Retrieval
+ dataset:
+ type: hotpotqa
+ name: MTEB HotpotQA
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 81.796
+ - type: ndcg_at_2
+ value: 67.999
+ - type: ndcg_at_3
+ value: 72.15599999999999
+ - type: ndcg_at_5
+ value: 74.99900000000001
+ - type: ndcg_at_7
+ value: 76.179
+ - type: ndcg_at_10
+ value: 77.022
+ - type: ndcg_at_20
+ value: 78.173
+ - type: ndcg_at_30
+ value: 78.648
+ - type: ndcg_at_50
+ value: 79.104
+ - type: ndcg_at_70
+ value: 79.335
+ - type: ndcg_at_100
+ value: 79.56
+ - type: ndcg_at_200
+ value: 79.911
+ - type: ndcg_at_300
+ value: 80.045
+ - type: ndcg_at_500
+ value: 80.19500000000001
+ - type: ndcg_at_700
+ value: 80.281
+ - type: ndcg_at_1000
+ value: 80.35
+ - type: map_at_1
+ value: 40.898
+ - type: map_at_2
+ value: 62.016000000000005
+ - type: map_at_3
+ value: 66.121
+ - type: map_at_5
+ value: 68.471
+ - type: map_at_7
+ value: 69.261
+ - type: map_at_10
+ value: 69.738
+ - type: map_at_20
+ value: 70.208
+ - type: map_at_30
+ value: 70.343
+ - type: map_at_50
+ value: 70.43700000000001
+ - type: map_at_70
+ value: 70.47099999999999
+ - type: map_at_100
+ value: 70.498
+ - type: map_at_200
+ value: 70.526
+ - type: map_at_300
+ value: 70.533
+ - type: map_at_500
+ value: 70.538
+ - type: map_at_700
+ value: 70.541
+ - type: map_at_1000
+ value: 70.542
+ - type: recall_at_1
+ value: 40.898
+ - type: recall_at_2
+ value: 63.964
+ - type: recall_at_3
+ value: 70.743
+ - type: recall_at_5
+ value: 76.36699999999999
+ - type: recall_at_7
+ value: 79.142
+ - type: recall_at_10
+ value: 81.404
+ - type: recall_at_20
+ value: 85.111
+ - type: recall_at_30
+ value: 86.92800000000001
+ - type: recall_at_50
+ value: 88.899
+ - type: recall_at_70
+ value: 90.01400000000001
+ - type: recall_at_100
+ value: 91.19500000000001
+ - type: recall_at_200
+ value: 93.234
+ - type: recall_at_300
+ value: 94.105
+ - type: recall_at_500
+ value: 95.159
+ - type: recall_at_700
+ value: 95.8
+ - type: recall_at_1000
+ value: 96.34700000000001
+ - type: precision_at_1
+ value: 81.796
+ - type: precision_at_2
+ value: 63.964
+ - type: precision_at_3
+ value: 47.162
+ - type: precision_at_5
+ value: 30.547
+ - type: precision_at_7
+ value: 22.612
+ - type: precision_at_10
+ value: 16.281000000000002
+ - type: precision_at_20
+ value: 8.511000000000001
+ - type: precision_at_30
+ value: 5.795
+ - type: precision_at_50
+ value: 3.556
+ - type: precision_at_70
+ value: 2.572
+ - type: precision_at_100
+ value: 1.8239999999999998
+ - type: precision_at_200
+ value: 0.932
+ - type: precision_at_300
+ value: 0.627
+ - type: precision_at_500
+ value: 0.381
+ - type: precision_at_700
+ value: 0.27399999999999997
+ - type: precision_at_1000
+ value: 0.193
+ - type: mrr_at_1
+ value: 81.796
+ - type: mrr_at_2
+ value: 85.69200000000001
+ - type: mrr_at_3
+ value: 86.52
+ - type: mrr_at_5
+ value: 86.973
+ - type: mrr_at_7
+ value: 87.13300000000001
+ - type: mrr_at_10
+ value: 87.208
+ - type: mrr_at_20
+ value: 87.303
+ - type: mrr_at_30
+ value: 87.32799999999999
+ - type: mrr_at_50
+ value: 87.347
+ - type: mrr_at_70
+ value: 87.35199999999999
+ - type: mrr_at_100
+ value: 87.355
+ - type: mrr_at_200
+ value: 87.357
+ - type: mrr_at_300
+ value: 87.357
+ - type: mrr_at_500
+ value: 87.358
+ - type: mrr_at_700
+ value: 87.358
+ - type: mrr_at_1000
+ value: 87.358
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/imdb
+ name: MTEB ImdbClassification
+ config: default
+ split: test
+ revision: 3d86128a09e091d6018b6d26cad27f2739fc2db7
+ metrics:
+ - type: accuracy
+ value: 94.79200000000002
+ - type: ap
+ value: 92.54484356773553
+ - type: f1
+ value: 94.78965313682525
+ - task:
+ type: Retrieval
+ dataset:
+ type: msmarco
+ name: MTEB MSMARCO
+ config: default
+ split: dev
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 24.398
+ - type: ndcg_at_2
+ value: 31.336000000000002
+ - type: ndcg_at_3
+ value: 35.266999999999996
+ - type: ndcg_at_5
+ value: 39.356
+ - type: ndcg_at_7
+ value: 41.562
+ - type: ndcg_at_10
+ value: 43.408
+ - type: ndcg_at_20
+ value: 46.107
+ - type: ndcg_at_30
+ value: 47.164
+ - type: ndcg_at_50
+ value: 48.126000000000005
+ - type: ndcg_at_70
+ value: 48.626999999999995
+ - type: ndcg_at_100
+ value: 49.043
+ - type: ndcg_at_200
+ value: 49.575
+ - type: ndcg_at_300
+ value: 49.794
+ - type: ndcg_at_500
+ value: 49.942
+ - type: ndcg_at_700
+ value: 50.014
+ - type: ndcg_at_1000
+ value: 50.077000000000005
+ - type: map_at_1
+ value: 23.723
+ - type: map_at_2
+ value: 29.593000000000004
+ - type: map_at_3
+ value: 32.273
+ - type: map_at_5
+ value: 34.587
+ - type: map_at_7
+ value: 35.589999999999996
+ - type: map_at_10
+ value: 36.296
+ - type: map_at_20
+ value: 37.059999999999995
+ - type: map_at_30
+ value: 37.265
+ - type: map_at_50
+ value: 37.402
+ - type: map_at_70
+ value: 37.454
+ - type: map_at_100
+ value: 37.486999999999995
+ - type: map_at_200
+ value: 37.516
+ - type: map_at_300
+ value: 37.524
+ - type: map_at_500
+ value: 37.528
+ - type: map_at_700
+ value: 37.529
+ - type: map_at_1000
+ value: 37.53
+ - type: recall_at_1
+ value: 23.723
+ - type: recall_at_2
+ value: 35.355
+ - type: recall_at_3
+ value: 43.22
+ - type: recall_at_5
+ value: 53.025
+ - type: recall_at_7
+ value: 59.327
+ - type: recall_at_10
+ value: 65.302
+ - type: recall_at_20
+ value: 75.765
+ - type: recall_at_30
+ value: 80.632
+ - type: recall_at_50
+ value: 85.63499999999999
+ - type: recall_at_70
+ value: 88.554
+ - type: recall_at_100
+ value: 91.16300000000001
+ - type: recall_at_200
+ value: 94.85
+ - type: recall_at_300
+ value: 96.532
+ - type: recall_at_500
+ value: 97.751
+ - type: recall_at_700
+ value: 98.383
+ - type: recall_at_1000
+ value: 98.97
+ - type: precision_at_1
+ value: 24.398
+ - type: precision_at_2
+ value: 18.274
+ - type: precision_at_3
+ value: 14.951999999999998
+ - type: precision_at_5
+ value: 11.052
+ - type: precision_at_7
+ value: 8.84
+ - type: precision_at_10
+ value: 6.8309999999999995
+ - type: precision_at_20
+ value: 3.978
+ - type: precision_at_30
+ value: 2.827
+ - type: precision_at_50
+ value: 1.807
+ - type: precision_at_70
+ value: 1.336
+ - type: precision_at_100
+ value: 0.964
+ - type: precision_at_200
+ value: 0.502
+ - type: precision_at_300
+ value: 0.34099999999999997
+ - type: precision_at_500
+ value: 0.208
+ - type: precision_at_700
+ value: 0.15
+ - type: precision_at_1000
+ value: 0.105
+ - type: mrr_at_1
+ value: 24.398
+ - type: mrr_at_2
+ value: 30.351
+ - type: mrr_at_3
+ value: 33.001000000000005
+ - type: mrr_at_5
+ value: 35.228
+ - type: mrr_at_7
+ value: 36.223
+ - type: mrr_at_10
+ value: 36.903999999999996
+ - type: mrr_at_20
+ value: 37.631
+ - type: mrr_at_30
+ value: 37.830000000000005
+ - type: mrr_at_50
+ value: 37.955
+ - type: mrr_at_70
+ value: 38.003
+ - type: mrr_at_100
+ value: 38.033
+ - type: mrr_at_200
+ value: 38.059
+ - type: mrr_at_300
+ value: 38.066
+ - type: mrr_at_500
+ value: 38.068999999999996
+ - type: mrr_at_700
+ value: 38.07
+ - type: mrr_at_1000
+ value: 38.07
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/mtop_domain
+ name: MTEB MTOPDomainClassification (en)
+ config: en
+ split: test
+ revision: d80d48c1eb48d3562165c59d59d0034df9fff0bf
+ metrics:
+ - type: accuracy
+ value: 96.35658914728683
+ - type: f1
+ value: 96.15039630903114
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/mtop_intent
+ name: MTEB MTOPIntentClassification (en)
+ config: en
+ split: test
+ revision: ae001d0e6b1228650b7bd1c2c65fb50ad11a8aba
+ metrics:
+ - type: accuracy
+ value: 86.29730962152303
+ - type: f1
+ value: 71.12166316567485
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/amazon_massive_intent
+ name: MTEB MassiveIntentClassification (en)
+ config: en
+ split: test
+ revision: 31efe3c427b0bae9c22cbb560b8f15491cc6bed7
+ metrics:
+ - type: accuracy
+ value: 79.98991257565568
+ - type: f1
+ value: 77.41680115095276
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/amazon_massive_scenario
+ name: MTEB MassiveScenarioClassification (en)
+ config: en
+ split: test
+ revision: 7d571f92784cd94a019292a1f45445077d0ef634
+ metrics:
+ - type: accuracy
+ value: 82.1990585070612
+ - type: f1
+ value: 82.23719179179362
+ - task:
+ type: Clustering
+ dataset:
+ type: mteb/medrxiv-clustering-p2p
+ name: MTEB MedrxivClusteringP2P
+ config: default
+ split: test
+ revision: e7a26af6f3ae46b30dde8737f02c07b1505bcc73
+ metrics:
+ - type: v_measure
+ value: 40.03019554933584
+ - task:
+ type: Clustering
+ dataset:
+ type: mteb/medrxiv-clustering-s2s
+ name: MTEB MedrxivClusteringS2S
+ config: default
+ split: test
+ revision: 35191c8c0dca72d8ff3efcd72aa802307d469663
+ metrics:
+ - type: v_measure
+ value: 38.999760551497815
+ - task:
+ type: Reranking
+ dataset:
+ type: mteb/mind_small
+ name: MTEB MindSmallReranking
+ config: default
+ split: test
+ revision: 3bdac13927fdc888b903db93b2ffdbd90b295a69
+ metrics:
+ - type: map
+ value: 32.72383151953079
+ - type: mrr
+ value: 33.93989699030721
+ - task:
+ type: Retrieval
+ dataset:
+ type: nfcorpus
+ name: MTEB NFCorpus
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 51.858000000000004
+ - type: ndcg_at_2
+ value: 49.675999999999995
+ - type: ndcg_at_3
+ value: 47.519
+ - type: ndcg_at_5
+ value: 45.198
+ - type: ndcg_at_7
+ value: 43.504
+ - type: ndcg_at_10
+ value: 41.88
+ - type: ndcg_at_20
+ value: 39.122
+ - type: ndcg_at_30
+ value: 37.95
+ - type: ndcg_at_50
+ value: 37.602999999999994
+ - type: ndcg_at_70
+ value: 37.836
+ - type: ndcg_at_100
+ value: 38.493
+ - type: ndcg_at_200
+ value: 40.187
+ - type: ndcg_at_300
+ value: 41.524
+ - type: ndcg_at_500
+ value: 43.657000000000004
+ - type: ndcg_at_700
+ value: 45.234
+ - type: ndcg_at_1000
+ value: 47.047
+ - type: map_at_1
+ value: 6.392
+ - type: map_at_2
+ value: 10.113
+ - type: map_at_3
+ value: 11.543000000000001
+ - type: map_at_5
+ value: 13.729
+ - type: map_at_7
+ value: 14.985000000000001
+ - type: map_at_10
+ value: 16.217000000000002
+ - type: map_at_20
+ value: 18.106
+ - type: map_at_30
+ value: 18.878
+ - type: map_at_50
+ value: 19.822
+ - type: map_at_70
+ value: 20.352999999999998
+ - type: map_at_100
+ value: 20.827
+ - type: map_at_200
+ value: 21.512
+ - type: map_at_300
+ value: 21.826
+ - type: map_at_500
+ value: 22.155
+ - type: map_at_700
+ value: 22.349
+ - type: map_at_1000
+ value: 22.531000000000002
+ - type: recall_at_1
+ value: 6.392
+ - type: recall_at_2
+ value: 11.215
+ - type: recall_at_3
+ value: 13.231000000000002
+ - type: recall_at_5
+ value: 16.66
+ - type: recall_at_7
+ value: 18.802
+ - type: recall_at_10
+ value: 21.185000000000002
+ - type: recall_at_20
+ value: 25.35
+ - type: recall_at_30
+ value: 27.91
+ - type: recall_at_50
+ value: 32.845
+ - type: recall_at_70
+ value: 35.789
+ - type: recall_at_100
+ value: 39.247
+ - type: recall_at_200
+ value: 46.655
+ - type: recall_at_300
+ value: 51.43299999999999
+ - type: recall_at_500
+ value: 59.472
+ - type: recall_at_700
+ value: 64.742
+ - type: recall_at_1000
+ value: 70.97099999999999
+ - type: precision_at_1
+ value: 53.559999999999995
+ - type: precision_at_2
+ value: 48.762
+ - type: precision_at_3
+ value: 44.169000000000004
+ - type: precision_at_5
+ value: 39.071
+ - type: precision_at_7
+ value: 35.161
+ - type: precision_at_10
+ value: 31.238
+ - type: precision_at_20
+ value: 23.064999999999998
+ - type: precision_at_30
+ value: 18.844
+ - type: precision_at_50
+ value: 14.601
+ - type: precision_at_70
+ value: 12.088000000000001
+ - type: precision_at_100
+ value: 9.844999999999999
+ - type: precision_at_200
+ value: 6.358
+ - type: precision_at_300
+ value: 4.915
+ - type: precision_at_500
+ value: 3.531
+ - type: precision_at_700
+ value: 2.8649999999999998
+ - type: precision_at_1000
+ value: 2.289
+ - type: mrr_at_1
+ value: 54.17999999999999
+ - type: mrr_at_2
+ value: 59.288
+ - type: mrr_at_3
+ value: 60.836
+ - type: mrr_at_5
+ value: 62.275999999999996
+ - type: mrr_at_7
+ value: 62.688
+ - type: mrr_at_10
+ value: 62.865
+ - type: mrr_at_20
+ value: 63.11
+ - type: mrr_at_30
+ value: 63.193999999999996
+ - type: mrr_at_50
+ value: 63.258
+ - type: mrr_at_70
+ value: 63.278
+ - type: mrr_at_100
+ value: 63.297000000000004
+ - type: mrr_at_200
+ value: 63.315999999999995
+ - type: mrr_at_300
+ value: 63.318
+ - type: mrr_at_500
+ value: 63.32299999999999
+ - type: mrr_at_700
+ value: 63.324000000000005
+ - type: mrr_at_1000
+ value: 63.324999999999996
+ - task:
+ type: Retrieval
+ dataset:
+ type: nq
+ name: MTEB NQ
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 50.897999999999996
+ - type: ndcg_at_2
+ value: 59.126
+ - type: ndcg_at_3
+ value: 63.093999999999994
+ - type: ndcg_at_5
+ value: 67.197
+ - type: ndcg_at_7
+ value: 68.719
+ - type: ndcg_at_10
+ value: 69.915
+ - type: ndcg_at_20
+ value: 71.229
+ - type: ndcg_at_30
+ value: 71.667
+ - type: ndcg_at_50
+ value: 71.98
+ - type: ndcg_at_70
+ value: 72.127
+ - type: ndcg_at_100
+ value: 72.217
+ - type: ndcg_at_200
+ value: 72.319
+ - type: ndcg_at_300
+ value: 72.347
+ - type: ndcg_at_500
+ value: 72.37
+ - type: ndcg_at_700
+ value: 72.379
+ - type: ndcg_at_1000
+ value: 72.381
+ - type: map_at_1
+ value: 45.297
+ - type: map_at_2
+ value: 55.596000000000004
+ - type: map_at_3
+ value: 58.724
+ - type: map_at_5
+ value: 61.387
+ - type: map_at_7
+ value: 62.173
+ - type: map_at_10
+ value: 62.69
+ - type: map_at_20
+ value: 63.125
+ - type: map_at_30
+ value: 63.223
+ - type: map_at_50
+ value: 63.27700000000001
+ - type: map_at_70
+ value: 63.295
+ - type: map_at_100
+ value: 63.303
+ - type: map_at_200
+ value: 63.31
+ - type: map_at_300
+ value: 63.31099999999999
+ - type: map_at_500
+ value: 63.312000000000005
+ - type: map_at_700
+ value: 63.312000000000005
+ - type: map_at_1000
+ value: 63.312000000000005
+ - type: recall_at_1
+ value: 45.297
+ - type: recall_at_2
+ value: 63.866
+ - type: recall_at_3
+ value: 71.898
+ - type: recall_at_5
+ value: 81.16600000000001
+ - type: recall_at_7
+ value: 85.301
+ - type: recall_at_10
+ value: 88.94800000000001
+ - type: recall_at_20
+ value: 93.719
+ - type: recall_at_30
+ value: 95.628
+ - type: recall_at_50
+ value: 97.14699999999999
+ - type: recall_at_70
+ value: 97.955
+ - type: recall_at_100
+ value: 98.48599999999999
+ - type: recall_at_200
+ value: 99.157
+ - type: recall_at_300
+ value: 99.355
+ - type: recall_at_500
+ value: 99.53699999999999
+ - type: recall_at_700
+ value: 99.62299999999999
+ - type: recall_at_1000
+ value: 99.638
+ - type: precision_at_1
+ value: 50.897999999999996
+ - type: precision_at_2
+ value: 36.703
+ - type: precision_at_3
+ value: 27.926000000000002
+ - type: precision_at_5
+ value: 19.276
+ - type: precision_at_7
+ value: 14.533999999999999
+ - type: precision_at_10
+ value: 10.678
+ - type: precision_at_20
+ value: 5.663
+ - type: precision_at_30
+ value: 3.8600000000000003
+ - type: precision_at_50
+ value: 2.358
+ - type: precision_at_70
+ value: 1.7000000000000002
+ - type: precision_at_100
+ value: 1.198
+ - type: precision_at_200
+ value: 0.603
+ - type: precision_at_300
+ value: 0.40299999999999997
+ - type: precision_at_500
+ value: 0.242
+ - type: precision_at_700
+ value: 0.173
+ - type: precision_at_1000
+ value: 0.121
+ - type: mrr_at_1
+ value: 50.897999999999996
+ - type: mrr_at_2
+ value: 59.994
+ - type: mrr_at_3
+ value: 62.553000000000004
+ - type: mrr_at_5
+ value: 64.307
+ - type: mrr_at_7
+ value: 64.864
+ - type: mrr_at_10
+ value: 65.22200000000001
+ - type: mrr_at_20
+ value: 65.499
+ - type: mrr_at_30
+ value: 65.561
+ - type: mrr_at_50
+ value: 65.592
+ - type: mrr_at_70
+ value: 65.602
+ - type: mrr_at_100
+ value: 65.607
+ - type: mrr_at_200
+ value: 65.61099999999999
+ - type: mrr_at_300
+ value: 65.61200000000001
+ - type: mrr_at_500
+ value: 65.61200000000001
+ - type: mrr_at_700
+ value: 65.61200000000001
+ - type: mrr_at_1000
+ value: 65.61200000000001
+ - task:
+ type: Retrieval
+ dataset:
+ type: quora
+ name: MTEB QuoraRetrieval
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 82.96
+ - type: ndcg_at_2
+ value: 85.614
+ - type: ndcg_at_3
+ value: 87.19
+ - type: ndcg_at_5
+ value: 88.654
+ - type: ndcg_at_7
+ value: 89.287
+ - type: ndcg_at_10
+ value: 89.785
+ - type: ndcg_at_20
+ value: 90.384
+ - type: ndcg_at_30
+ value: 90.589
+ - type: ndcg_at_50
+ value: 90.738
+ - type: ndcg_at_70
+ value: 90.789
+ - type: ndcg_at_100
+ value: 90.824
+ - type: ndcg_at_200
+ value: 90.869
+ - type: ndcg_at_300
+ value: 90.881
+ - type: ndcg_at_500
+ value: 90.886
+ - type: ndcg_at_700
+ value: 90.889
+ - type: ndcg_at_1000
+ value: 90.889
+ - type: map_at_1
+ value: 72.152
+ - type: map_at_2
+ value: 80.818
+ - type: map_at_3
+ value: 83.462
+ - type: map_at_5
+ value: 85.286
+ - type: map_at_7
+ value: 85.921
+ - type: map_at_10
+ value: 86.334
+ - type: map_at_20
+ value: 86.737
+ - type: map_at_30
+ value: 86.847
+ - type: map_at_50
+ value: 86.911
+ - type: map_at_70
+ value: 86.932
+ - type: map_at_100
+ value: 86.943
+ - type: map_at_200
+ value: 86.953
+ - type: map_at_300
+ value: 86.955
+ - type: map_at_500
+ value: 86.956
+ - type: map_at_700
+ value: 86.956
+ - type: map_at_1000
+ value: 86.956
+ - type: recall_at_1
+ value: 72.152
+ - type: recall_at_2
+ value: 84.129
+ - type: recall_at_3
+ value: 88.87
+ - type: recall_at_5
+ value: 93.067
+ - type: recall_at_7
+ value: 94.882
+ - type: recall_at_10
+ value: 96.353
+ - type: recall_at_20
+ value: 98.26700000000001
+ - type: recall_at_30
+ value: 98.92999999999999
+ - type: recall_at_50
+ value: 99.441
+ - type: recall_at_70
+ value: 99.619
+ - type: recall_at_100
+ value: 99.748
+ - type: recall_at_200
+ value: 99.911
+ - type: recall_at_300
+ value: 99.956
+ - type: recall_at_500
+ value: 99.98
+ - type: recall_at_700
+ value: 99.991
+ - type: recall_at_1000
+ value: 99.996
+ - type: precision_at_1
+ value: 82.96
+ - type: precision_at_2
+ value: 52.175000000000004
+ - type: precision_at_3
+ value: 38.223
+ - type: precision_at_5
+ value: 25.056
+ - type: precision_at_7
+ value: 18.717
+ - type: precision_at_10
+ value: 13.614999999999998
+ - type: precision_at_20
+ value: 7.208
+ - type: precision_at_30
+ value: 4.928
+ - type: precision_at_50
+ value: 3.024
+ - type: precision_at_70
+ value: 2.183
+ - type: precision_at_100
+ value: 1.54
+ - type: precision_at_200
+ value: 0.779
+ - type: precision_at_300
+ value: 0.521
+ - type: precision_at_500
+ value: 0.313
+ - type: precision_at_700
+ value: 0.22399999999999998
+ - type: precision_at_1000
+ value: 0.157
+ - type: mrr_at_1
+ value: 82.96
+ - type: mrr_at_2
+ value: 87.005
+ - type: mrr_at_3
+ value: 88.07199999999999
+ - type: mrr_at_5
+ value: 88.634
+ - type: mrr_at_7
+ value: 88.793
+ - type: mrr_at_10
+ value: 88.87899999999999
+ - type: mrr_at_20
+ value: 88.94999999999999
+ - type: mrr_at_30
+ value: 88.96
+ - type: mrr_at_50
+ value: 88.965
+ - type: mrr_at_70
+ value: 88.966
+ - type: mrr_at_100
+ value: 88.967
+ - type: mrr_at_200
+ value: 88.967
+ - type: mrr_at_300
+ value: 88.967
+ - type: mrr_at_500
+ value: 88.967
+ - type: mrr_at_700
+ value: 88.967
+ - type: mrr_at_1000
+ value: 88.967
+ - task:
+ type: Clustering
+ dataset:
+ type: mteb/reddit-clustering
+ name: MTEB RedditClustering
+ config: default
+ split: test
+ revision: 24640382cdbf8abc73003fb0fa6d111a705499eb
+ metrics:
+ - type: v_measure
+ value: 59.90388554491155
+ - task:
+ type: Clustering
+ dataset:
+ type: mteb/reddit-clustering-p2p
+ name: MTEB RedditClusteringP2P
+ config: default
+ split: test
+ revision: 282350215ef01743dc01b456c7f5241fa8937f16
+ metrics:
+ - type: v_measure
+ value: 67.64232539036783
+ - task:
+ type: Retrieval
+ dataset:
+ type: scidocs
+ name: MTEB SCIDOCS
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 22.6
+ - type: ndcg_at_2
+ value: 20.355999999999998
+ - type: ndcg_at_3
+ value: 18.536
+ - type: ndcg_at_5
+ value: 16.523
+ - type: ndcg_at_7
+ value: 17.979
+ - type: ndcg_at_10
+ value: 19.908
+ - type: ndcg_at_20
+ value: 22.887
+ - type: ndcg_at_30
+ value: 24.43
+ - type: ndcg_at_50
+ value: 25.959
+ - type: ndcg_at_70
+ value: 26.989
+ - type: ndcg_at_100
+ value: 27.977
+ - type: ndcg_at_200
+ value: 29.831000000000003
+ - type: ndcg_at_300
+ value: 30.787
+ - type: ndcg_at_500
+ value: 31.974999999999998
+ - type: ndcg_at_700
+ value: 32.554
+ - type: ndcg_at_1000
+ value: 33.277
+ - type: map_at_1
+ value: 4.593
+ - type: map_at_2
+ value: 6.923
+ - type: map_at_3
+ value: 8.3
+ - type: map_at_5
+ value: 10.072000000000001
+ - type: map_at_7
+ value: 10.782
+ - type: map_at_10
+ value: 11.72
+ - type: map_at_20
+ value: 12.838
+ - type: map_at_30
+ value: 13.257
+ - type: map_at_50
+ value: 13.569
+ - type: map_at_70
+ value: 13.733
+ - type: map_at_100
+ value: 13.858999999999998
+ - type: map_at_200
+ value: 14.018
+ - type: map_at_300
+ value: 14.072999999999999
+ - type: map_at_500
+ value: 14.126
+ - type: map_at_700
+ value: 14.145
+ - type: map_at_1000
+ value: 14.161999999999999
+ - type: recall_at_1
+ value: 4.593
+ - type: recall_at_2
+ value: 7.997999999999999
+ - type: recall_at_3
+ value: 10.563
+ - type: recall_at_5
+ value: 14.907
+ - type: recall_at_7
+ value: 17.4
+ - type: recall_at_10
+ value: 21.18
+ - type: recall_at_20
+ value: 28.144999999999996
+ - type: recall_at_30
+ value: 32.462
+ - type: recall_at_50
+ value: 37.267
+ - type: recall_at_70
+ value: 40.875
+ - type: recall_at_100
+ value: 44.641999999999996
+ - type: recall_at_200
+ value: 52.573
+ - type: recall_at_300
+ value: 57.089999999999996
+ - type: recall_at_500
+ value: 63.14300000000001
+ - type: recall_at_700
+ value: 66.313
+ - type: recall_at_1000
+ value: 70.458
+ - type: precision_at_1
+ value: 22.6
+ - type: precision_at_2
+ value: 19.7
+ - type: precision_at_3
+ value: 17.333000000000002
+ - type: precision_at_5
+ value: 14.680000000000001
+ - type: precision_at_7
+ value: 12.243
+ - type: precision_at_10
+ value: 10.440000000000001
+ - type: precision_at_20
+ value: 6.944999999999999
+ - type: precision_at_30
+ value: 5.333
+ - type: precision_at_50
+ value: 3.678
+ - type: precision_at_70
+ value: 2.881
+ - type: precision_at_100
+ value: 2.2030000000000003
+ - type: precision_at_200
+ value: 1.295
+ - type: precision_at_300
+ value: 0.9369999999999999
+ - type: precision_at_500
+ value: 0.622
+ - type: precision_at_700
+ value: 0.466
+ - type: precision_at_1000
+ value: 0.347
+ - type: mrr_at_1
+ value: 22.6
+ - type: mrr_at_2
+ value: 27.900000000000002
+ - type: mrr_at_3
+ value: 30.067
+ - type: mrr_at_5
+ value: 32.207
+ - type: mrr_at_7
+ value: 33.004
+ - type: mrr_at_10
+ value: 33.596
+ - type: mrr_at_20
+ value: 34.268
+ - type: mrr_at_30
+ value: 34.492
+ - type: mrr_at_50
+ value: 34.628
+ - type: mrr_at_70
+ value: 34.681
+ - type: mrr_at_100
+ value: 34.717
+ - type: mrr_at_200
+ value: 34.757
+ - type: mrr_at_300
+ value: 34.768
+ - type: mrr_at_500
+ value: 34.772
+ - type: mrr_at_700
+ value: 34.774
+ - type: mrr_at_1000
+ value: 34.775
+ - task:
+ type: STS
+ dataset:
+ type: mteb/sickr-sts
+ name: MTEB SICK-R
+ config: default
+ split: test
+ revision: a6ea5a8cab320b040a23452cc28066d9beae2cee
+ metrics:
+ - type: cos_sim_pearson
+ value: 86.90122745229677
+ - type: cos_sim_spearman
+ value: 82.92294737327579
+ - type: euclidean_pearson
+ value: 84.08979655773187
+ - type: euclidean_spearman
+ value: 82.92294657285412
+ - type: manhattan_pearson
+ value: 84.09347480531832
+ - type: manhattan_spearman
+ value: 82.91564613948087
+ - task:
+ type: STS
+ dataset:
+ type: mteb/sts12-sts
+ name: MTEB STS12
+ config: default
+ split: test
+ revision: a0d554a64d88156834ff5ae9920b964011b16384
+ metrics:
+ - type: cos_sim_pearson
+ value: 87.01218713698583
+ - type: cos_sim_spearman
+ value: 79.46865215168464
+ - type: euclidean_pearson
+ value: 83.22621889891909
+ - type: euclidean_spearman
+ value: 79.46853821709514
+ - type: manhattan_pearson
+ value: 83.69962580788805
+ - type: manhattan_spearman
+ value: 79.9561593356932
+ - task:
+ type: STS
+ dataset:
+ type: mteb/sts13-sts
+ name: MTEB STS13
+ config: default
+ split: test
+ revision: 7e90230a92c190f1bf69ae9002b8cea547a64cca
+ metrics:
+ - type: cos_sim_pearson
+ value: 88.98438696342964
+ - type: cos_sim_spearman
+ value: 89.15419511870839
+ - type: euclidean_pearson
+ value: 88.49646141802894
+ - type: euclidean_spearman
+ value: 89.15419503946019
+ - type: manhattan_pearson
+ value: 88.6420585616327
+ - type: manhattan_spearman
+ value: 89.42648950757743
+ - task:
+ type: STS
+ dataset:
+ type: mteb/sts14-sts
+ name: MTEB STS14
+ config: default
+ split: test
+ revision: 6031580fec1f6af667f0bd2da0a551cf4f0b2375
+ metrics:
+ - type: cos_sim_pearson
+ value: 87.30772547759544
+ - type: cos_sim_spearman
+ value: 84.93199878424691
+ - type: euclidean_pearson
+ value: 86.16266630395455
+ - type: euclidean_spearman
+ value: 84.93198798543634
+ - type: manhattan_pearson
+ value: 86.14285723189803
+ - type: manhattan_spearman
+ value: 85.0361672522687
+ - task:
+ type: STS
+ dataset:
+ type: mteb/sts15-sts
+ name: MTEB STS15
+ config: default
+ split: test
+ revision: ae752c7c21bf194d8b67fd573edf7ae58183cbe3
+ metrics:
+ - type: cos_sim_pearson
+ value: 90.21342071197127
+ - type: cos_sim_spearman
+ value: 90.7407512744838
+ - type: euclidean_pearson
+ value: 90.1517933113061
+ - type: euclidean_spearman
+ value: 90.74075125431919
+ - type: manhattan_pearson
+ value: 90.17963034676193
+ - type: manhattan_spearman
+ value: 90.88999275865135
+ - task:
+ type: STS
+ dataset:
+ type: mteb/sts16-sts
+ name: MTEB STS16
+ config: default
+ split: test
+ revision: 4d8694f8f0e0100860b497b999b3dbed754a0513
+ metrics:
+ - type: cos_sim_pearson
+ value: 86.82518054100498
+ - type: cos_sim_spearman
+ value: 87.81570533154735
+ - type: euclidean_pearson
+ value: 86.91684561573618
+ - type: euclidean_spearman
+ value: 87.81570533154735
+ - type: manhattan_pearson
+ value: 86.98311935744032
+ - type: manhattan_spearman
+ value: 87.9594667151966
+ - task:
+ type: STS
+ dataset:
+ type: mteb/sts17-crosslingual-sts
+ name: MTEB STS17 (en-en)
+ config: en-en
+ split: test
+ revision: af5e6fb845001ecf41f4c1e033ce921939a2a68d
+ metrics:
+ - type: cos_sim_pearson
+ value: 92.09578436612053
+ - type: cos_sim_spearman
+ value: 92.01519349090438
+ - type: euclidean_pearson
+ value: 92.07113635890894
+ - type: euclidean_spearman
+ value: 92.01519349090438
+ - type: manhattan_pearson
+ value: 91.89343820765625
+ - type: manhattan_spearman
+ value: 91.7443476810177
+ - task:
+ type: STS
+ dataset:
+ type: mteb/sts22-crosslingual-sts
+ name: MTEB STS22 (en)
+ config: en
+ split: test
+ revision: 6d1ba47164174a496b7fa5d3569dae26a6813b80
+ metrics:
+ - type: cos_sim_pearson
+ value: 69.29997751464549
+ - type: cos_sim_spearman
+ value: 68.36425436812782
+ - type: euclidean_pearson
+ value: 69.81381677661783
+ - type: euclidean_spearman
+ value: 68.36425436812782
+ - type: manhattan_pearson
+ value: 69.92823397008026
+ - type: manhattan_spearman
+ value: 68.35770640039254
+ - task:
+ type: STS
+ dataset:
+ type: mteb/stsbenchmark-sts
+ name: MTEB STSBenchmark
+ config: default
+ split: test
+ revision: b0fddb56ed78048fa8b90373c8a3cfc37b684831
+ metrics:
+ - type: cos_sim_pearson
+ value: 88.39126315452359
+ - type: cos_sim_spearman
+ value: 88.99708463265337
+ - type: euclidean_pearson
+ value: 88.60793820038607
+ - type: euclidean_spearman
+ value: 88.99708463265337
+ - type: manhattan_pearson
+ value: 88.69860633571047
+ - type: manhattan_spearman
+ value: 89.20094593888012
+ - task:
+ type: Reranking
+ dataset:
+ type: mteb/scidocs-reranking
+ name: MTEB SciDocsRR
+ config: default
+ split: test
+ revision: d3c5e1fc0b855ab6097bf1cda04dd73947d7caab
+ metrics:
+ - type: map
+ value: 86.58028062818582
+ - type: mrr
+ value: 96.53586790841693
+ - task:
+ type: Retrieval
+ dataset:
+ type: scifact
+ name: MTEB SciFact
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 66.333
+ - type: ndcg_at_2
+ value: 70.655
+ - type: ndcg_at_3
+ value: 72.801
+ - type: ndcg_at_5
+ value: 75.793
+ - type: ndcg_at_7
+ value: 76.946
+ - type: ndcg_at_10
+ value: 77.66199999999999
+ - type: ndcg_at_20
+ value: 78.786
+ - type: ndcg_at_30
+ value: 79.066
+ - type: ndcg_at_50
+ value: 79.255
+ - type: ndcg_at_70
+ value: 79.423
+ - type: ndcg_at_100
+ value: 79.476
+ - type: ndcg_at_200
+ value: 79.65299999999999
+ - type: ndcg_at_300
+ value: 79.696
+ - type: ndcg_at_500
+ value: 79.73599999999999
+ - type: ndcg_at_700
+ value: 79.77199999999999
+ - type: ndcg_at_1000
+ value: 79.77199999999999
+ - type: map_at_1
+ value: 63.383
+ - type: map_at_2
+ value: 68.144
+ - type: map_at_3
+ value: 70.19800000000001
+ - type: map_at_5
+ value: 72.38
+ - type: map_at_7
+ value: 72.955
+ - type: map_at_10
+ value: 73.312
+ - type: map_at_20
+ value: 73.678
+ - type: map_at_30
+ value: 73.72800000000001
+ - type: map_at_50
+ value: 73.75500000000001
+ - type: map_at_70
+ value: 73.771
+ - type: map_at_100
+ value: 73.776
+ - type: map_at_200
+ value: 73.783
+ - type: map_at_300
+ value: 73.784
+ - type: map_at_500
+ value: 73.785
+ - type: map_at_700
+ value: 73.786
+ - type: map_at_1000
+ value: 73.786
+ - type: recall_at_1
+ value: 63.383
+ - type: recall_at_2
+ value: 72.283
+ - type: recall_at_3
+ value: 77.183
+ - type: recall_at_5
+ value: 84.56099999999999
+ - type: recall_at_7
+ value: 87.67200000000001
+ - type: recall_at_10
+ value: 89.822
+ - type: recall_at_20
+ value: 94
+ - type: recall_at_30
+ value: 95.333
+ - type: recall_at_50
+ value: 96.333
+ - type: recall_at_70
+ value: 97.333
+ - type: recall_at_100
+ value: 97.667
+ - type: recall_at_200
+ value: 99
+ - type: recall_at_300
+ value: 99.333
+ - type: recall_at_500
+ value: 99.667
+ - type: recall_at_700
+ value: 100
+ - type: recall_at_1000
+ value: 100
+ - type: precision_at_1
+ value: 66.333
+ - type: precision_at_2
+ value: 38.667
+ - type: precision_at_3
+ value: 28.111000000000004
+ - type: precision_at_5
+ value: 18.933
+ - type: precision_at_7
+ value: 14.094999999999999
+ - type: precision_at_10
+ value: 10.167
+ - type: precision_at_20
+ value: 5.35
+ - type: precision_at_30
+ value: 3.611
+ - type: precision_at_50
+ value: 2.1870000000000003
+ - type: precision_at_70
+ value: 1.576
+ - type: precision_at_100
+ value: 1.107
+ - type: precision_at_200
+ value: 0.5599999999999999
+ - type: precision_at_300
+ value: 0.374
+ - type: precision_at_500
+ value: 0.22499999999999998
+ - type: precision_at_700
+ value: 0.161
+ - type: precision_at_1000
+ value: 0.11299999999999999
+ - type: mrr_at_1
+ value: 66.333
+ - type: mrr_at_2
+ value: 70.833
+ - type: mrr_at_3
+ value: 72.167
+ - type: mrr_at_5
+ value: 73.6
+ - type: mrr_at_7
+ value: 74.084
+ - type: mrr_at_10
+ value: 74.283
+ - type: mrr_at_20
+ value: 74.54499999999999
+ - type: mrr_at_30
+ value: 74.59599999999999
+ - type: mrr_at_50
+ value: 74.622
+ - type: mrr_at_70
+ value: 74.639
+ - type: mrr_at_100
+ value: 74.643
+ - type: mrr_at_200
+ value: 74.65
+ - type: mrr_at_300
+ value: 74.652
+ - type: mrr_at_500
+ value: 74.653
+ - type: mrr_at_700
+ value: 74.653
+ - type: mrr_at_1000
+ value: 74.653
+ - task:
+ type: PairClassification
+ dataset:
+ type: mteb/sprintduplicatequestions-pairclassification
+ name: MTEB SprintDuplicateQuestions
+ config: default
+ split: test
+ revision: d66bd1f72af766a5cc4b0ca5e00c162f89e8cc46
+ metrics:
+ - type: cos_sim_accuracy
+ value: 99.84554455445544
+ - type: cos_sim_ap
+ value: 96.31178339136798
+ - type: cos_sim_f1
+ value: 92.1921921921922
+ - type: cos_sim_precision
+ value: 92.28456913827655
+ - type: cos_sim_recall
+ value: 92.10000000000001
+ - type: dot_accuracy
+ value: 99.84554455445544
+ - type: dot_ap
+ value: 96.31178339136797
+ - type: dot_f1
+ value: 92.1921921921922
+ - type: dot_precision
+ value: 92.28456913827655
+ - type: dot_recall
+ value: 92.10000000000001
+ - type: euclidean_accuracy
+ value: 99.84554455445544
+ - type: euclidean_ap
+ value: 96.31178339136798
+ - type: euclidean_f1
+ value: 92.1921921921922
+ - type: euclidean_precision
+ value: 92.28456913827655
+ - type: euclidean_recall
+ value: 92.10000000000001
+ - type: manhattan_accuracy
+ value: 99.84752475247525
+ - type: manhattan_ap
+ value: 96.4591954606088
+ - type: manhattan_f1
+ value: 92.25352112676056
+ - type: manhattan_precision
+ value: 92.81376518218623
+ - type: manhattan_recall
+ value: 91.7
+ - type: max_accuracy
+ value: 99.84752475247525
+ - type: max_ap
+ value: 96.4591954606088
+ - type: max_f1
+ value: 92.25352112676056
+ - task:
+ type: Clustering
+ dataset:
+ type: mteb/stackexchange-clustering
+ name: MTEB StackExchangeClustering
+ config: default
+ split: test
+ revision: 6cbc1f7b2bc0622f2e39d2c77fa502909748c259
+ metrics:
+ - type: v_measure
+ value: 74.24659759283294
+ - task:
+ type: Clustering
+ dataset:
+ type: mteb/stackexchange-clustering-p2p
+ name: MTEB StackExchangeClusteringP2P
+ config: default
+ split: test
+ revision: 815ca46b2622cec33ccafc3735d572c266efdb44
+ metrics:
+ - type: v_measure
+ value: 46.77690051260451
+ - task:
+ type: Reranking
+ dataset:
+ type: mteb/stackoverflowdupquestions-reranking
+ name: MTEB StackOverflowDupQuestions
+ config: default
+ split: test
+ revision: e185fbe320c72810689fc5848eb6114e1ef5ec69
+ metrics:
+ - type: map
+ value: 55.68436757803185
+ - type: mrr
+ value: 56.82157711569475
+ - task:
+ type: Summarization
+ dataset:
+ type: mteb/summeval
+ name: MTEB SummEval
+ config: default
+ split: test
+ revision: cda12ad7615edc362dbf25a00fdd61d3b1eaf93c
+ metrics:
+ - type: cos_sim_pearson
+ value: 31.652482405629843
+ - type: cos_sim_spearman
+ value: 31.16341822347735
+ - type: dot_pearson
+ value: 31.652479892699837
+ - type: dot_spearman
+ value: 31.16341822347735
+ - task:
+ type: Retrieval
+ dataset:
+ type: trec-covid
+ name: MTEB TRECCOVID
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 92
+ - type: ndcg_at_2
+ value: 90.839
+ - type: ndcg_at_3
+ value: 90.642
+ - type: ndcg_at_5
+ value: 90.348
+ - type: ndcg_at_7
+ value: 89.015
+ - type: ndcg_at_10
+ value: 87.599
+ - type: ndcg_at_20
+ value: 84.434
+ - type: ndcg_at_30
+ value: 81.655
+ - type: ndcg_at_50
+ value: 77.278
+ - type: ndcg_at_70
+ value: 73.957
+ - type: ndcg_at_100
+ value: 69.56
+ - type: ndcg_at_200
+ value: 60.724000000000004
+ - type: ndcg_at_300
+ value: 57.245000000000005
+ - type: ndcg_at_500
+ value: 56.316
+ - type: ndcg_at_700
+ value: 58.399
+ - type: ndcg_at_1000
+ value: 62.21600000000001
+ - type: map_at_1
+ value: 0.247
+ - type: map_at_2
+ value: 0.488
+ - type: map_at_3
+ value: 0.7230000000000001
+ - type: map_at_5
+ value: 1.204
+ - type: map_at_7
+ value: 1.6500000000000001
+ - type: map_at_10
+ value: 2.292
+ - type: map_at_20
+ value: 4.274
+ - type: map_at_30
+ value: 6.027
+ - type: map_at_50
+ value: 9.083
+ - type: map_at_70
+ value: 11.751000000000001
+ - type: map_at_100
+ value: 14.912
+ - type: map_at_200
+ value: 22.213
+ - type: map_at_300
+ value: 26.667999999999996
+ - type: map_at_500
+ value: 31.556
+ - type: map_at_700
+ value: 34.221000000000004
+ - type: map_at_1000
+ value: 36.443999999999996
+ - type: recall_at_1
+ value: 0.247
+ - type: recall_at_2
+ value: 0.49899999999999994
+ - type: recall_at_3
+ value: 0.742
+ - type: recall_at_5
+ value: 1.247
+ - type: recall_at_7
+ value: 1.722
+ - type: recall_at_10
+ value: 2.405
+ - type: recall_at_20
+ value: 4.583
+ - type: recall_at_30
+ value: 6.587999999999999
+ - type: recall_at_50
+ value: 10.188
+ - type: recall_at_70
+ value: 13.496
+ - type: recall_at_100
+ value: 17.578
+ - type: recall_at_200
+ value: 28.158
+ - type: recall_at_300
+ value: 35.532000000000004
+ - type: recall_at_500
+ value: 45.31
+ - type: recall_at_700
+ value: 51.822
+ - type: recall_at_1000
+ value: 58.53
+ - type: precision_at_1
+ value: 96
+ - type: precision_at_2
+ value: 96
+ - type: precision_at_3
+ value: 95.333
+ - type: precision_at_5
+ value: 94.8
+ - type: precision_at_7
+ value: 93.429
+ - type: precision_at_10
+ value: 91.4
+ - type: precision_at_20
+ value: 87.7
+ - type: precision_at_30
+ value: 84.867
+ - type: precision_at_50
+ value: 80.24
+ - type: precision_at_70
+ value: 76.371
+ - type: precision_at_100
+ value: 71.08
+ - type: precision_at_200
+ value: 59.4
+ - type: precision_at_300
+ value: 51.459999999999994
+ - type: precision_at_500
+ value: 40.644000000000005
+ - type: precision_at_700
+ value: 33.889
+ - type: precision_at_1000
+ value: 27.250000000000004
+ - type: mrr_at_1
+ value: 96
+ - type: mrr_at_2
+ value: 98
+ - type: mrr_at_3
+ value: 98
+ - type: mrr_at_5
+ value: 98
+ - type: mrr_at_7
+ value: 98
+ - type: mrr_at_10
+ value: 98
+ - type: mrr_at_20
+ value: 98
+ - type: mrr_at_30
+ value: 98
+ - type: mrr_at_50
+ value: 98
+ - type: mrr_at_70
+ value: 98
+ - type: mrr_at_100
+ value: 98
+ - type: mrr_at_200
+ value: 98
+ - type: mrr_at_300
+ value: 98
+ - type: mrr_at_500
+ value: 98
+ - type: mrr_at_700
+ value: 98
+ - type: mrr_at_1000
+ value: 98
+ - task:
+ type: Retrieval
+ dataset:
+ type: webis-touche2020
+ name: MTEB Touche2020
+ config: default
+ split: test
+ revision: None
+ metrics:
+ - type: ndcg_at_1
+ value: 43.878
+ - type: ndcg_at_2
+ value: 37.956
+ - type: ndcg_at_3
+ value: 35.053
+ - type: ndcg_at_5
+ value: 32.59
+ - type: ndcg_at_7
+ value: 30.226
+ - type: ndcg_at_10
+ value: 29.005
+ - type: ndcg_at_20
+ value: 30.11
+ - type: ndcg_at_30
+ value: 32.019999999999996
+ - type: ndcg_at_50
+ value: 34.354
+ - type: ndcg_at_70
+ value: 36.665
+ - type: ndcg_at_100
+ value: 38.888
+ - type: ndcg_at_200
+ value: 43.435
+ - type: ndcg_at_300
+ value: 45.795
+ - type: ndcg_at_500
+ value: 48.699999999999996
+ - type: ndcg_at_700
+ value: 50.242
+ - type: ndcg_at_1000
+ value: 51.529
+ - type: map_at_1
+ value: 3.521
+ - type: map_at_2
+ value: 5.309
+ - type: map_at_3
+ value: 6.576
+ - type: map_at_5
+ value: 8.97
+ - type: map_at_7
+ value: 10.194
+ - type: map_at_10
+ value: 11.949
+ - type: map_at_20
+ value: 14.686
+ - type: map_at_30
+ value: 15.8
+ - type: map_at_50
+ value: 16.59
+ - type: map_at_70
+ value: 17.2
+ - type: map_at_100
+ value: 17.765
+ - type: map_at_200
+ value: 18.636
+ - type: map_at_300
+ value: 18.972
+ - type: map_at_500
+ value: 19.301
+ - type: map_at_700
+ value: 19.445
+ - type: map_at_1000
+ value: 19.546
+ - type: recall_at_1
+ value: 3.521
+ - type: recall_at_2
+ value: 5.848
+ - type: recall_at_3
+ value: 7.657
+ - type: recall_at_5
+ value: 11.368
+ - type: recall_at_7
+ value: 13.748
+ - type: recall_at_10
+ value: 18.061
+ - type: recall_at_20
+ value: 26.844
+ - type: recall_at_30
+ value: 31.186000000000003
+ - type: recall_at_50
+ value: 35.951
+ - type: recall_at_70
+ value: 40.961999999999996
+ - type: recall_at_100
+ value: 46.743
+ - type: recall_at_200
+ value: 58.483
+ - type: recall_at_300
+ value: 65.973
+ - type: recall_at_500
+ value: 75.233
+ - type: recall_at_700
+ value: 80.472
+ - type: recall_at_1000
+ value: 85.02
+ - type: precision_at_1
+ value: 46.939
+ - type: precision_at_2
+ value: 38.775999999999996
+ - type: precision_at_3
+ value: 34.694
+ - type: precision_at_5
+ value: 31.429000000000002
+ - type: precision_at_7
+ value: 27.697
+ - type: precision_at_10
+ value: 24.490000000000002
+ - type: precision_at_20
+ value: 18.776
+ - type: precision_at_30
+ value: 15.034
+ - type: precision_at_50
+ value: 10.857
+ - type: precision_at_70
+ value: 9.096
+ - type: precision_at_100
+ value: 7.51
+ - type: precision_at_200
+ value: 4.929
+ - type: precision_at_300
+ value: 3.7760000000000002
+ - type: precision_at_500
+ value: 2.6780000000000004
+ - type: precision_at_700
+ value: 2.085
+ - type: precision_at_1000
+ value: 1.5709999999999997
+ - type: mrr_at_1
+ value: 46.939
+ - type: mrr_at_2
+ value: 55.102
+ - type: mrr_at_3
+ value: 57.823
+ - type: mrr_at_5
+ value: 60.68
+ - type: mrr_at_7
+ value: 60.972
+ - type: mrr_at_10
+ value: 61.199000000000005
+ - type: mrr_at_20
+ value: 61.831
+ - type: mrr_at_30
+ value: 61.831
+ - type: mrr_at_50
+ value: 61.873
+ - type: mrr_at_70
+ value: 61.873
+ - type: mrr_at_100
+ value: 61.873
+ - type: mrr_at_200
+ value: 61.873
+ - type: mrr_at_300
+ value: 61.873
+ - type: mrr_at_500
+ value: 61.873
+ - type: mrr_at_700
+ value: 61.873
+ - type: mrr_at_1000
+ value: 61.873
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/toxic_conversations_50k
+ name: MTEB ToxicConversationsClassification
+ config: default
+ split: test
+ revision: d7c0de2777da35d6aae2200a62c6e0e5af397c4c
+ metrics:
+ - type: accuracy
+ value: 69.3294
+ - type: ap
+ value: 14.561333393364736
+ - type: f1
+ value: 53.992309820496466
+ - task:
+ type: Classification
+ dataset:
+ type: mteb/tweet_sentiment_extraction
+ name: MTEB TweetSentimentExtractionClassification
+ config: default
+ split: test
+ revision: d604517c81ca91fe16a244d1248fc021f9ecee7a
+ metrics:
+ - type: accuracy
+ value: 63.63893604980192
+ - type: f1
+ value: 63.92959380489434
+ - task:
+ type: Clustering
+ dataset:
+ type: mteb/twentynewsgroups-clustering
+ name: MTEB TwentyNewsgroupsClustering
+ config: default
+ split: test
+ revision: 6125ec4e24fa026cec8a478383ee943acfbd5449
+ metrics:
+ - type: v_measure
+ value: 56.270879258659775
+ - task:
+ type: PairClassification
+ dataset:
+ type: mteb/twittersemeval2015-pairclassification
+ name: MTEB TwitterSemEval2015
+ config: default
+ split: test
+ revision: 70970daeab8776df92f5ea462b6173c0b46fd2d1
+ metrics:
+ - type: cos_sim_accuracy
+ value: 88.71073493473207
+ - type: cos_sim_ap
+ value: 81.52392540284202
+ - type: cos_sim_f1
+ value: 74.71162377994676
+ - type: cos_sim_precision
+ value: 71.89558428885094
+ - type: cos_sim_recall
+ value: 77.75725593667546
+ - type: dot_accuracy
+ value: 88.71073493473207
+ - type: dot_ap
+ value: 81.52394754041109
+ - type: dot_f1
+ value: 74.71162377994676
+ - type: dot_precision
+ value: 71.89558428885094
+ - type: dot_recall
+ value: 77.75725593667546
+ - type: euclidean_accuracy
+ value: 88.71073493473207
+ - type: euclidean_ap
+ value: 81.52392035435321
+ - type: euclidean_f1
+ value: 74.71162377994676
+ - type: euclidean_precision
+ value: 71.89558428885094
+ - type: euclidean_recall
+ value: 77.75725593667546
+ - type: manhattan_accuracy
+ value: 88.47231328604637
+ - type: manhattan_ap
+ value: 81.22907439267321
+ - type: manhattan_f1
+ value: 74.3351571446749
+ - type: manhattan_precision
+ value: 71.78667977390022
+ - type: manhattan_recall
+ value: 77.0712401055409
+ - type: max_accuracy
+ value: 88.71073493473207
+ - type: max_ap
+ value: 81.52394754041109
+ - type: max_f1
+ value: 74.71162377994676
+ - task:
+ type: PairClassification
+ dataset:
+ type: mteb/twitterurlcorpus-pairclassification
+ name: MTEB TwitterURLCorpus
+ config: default
+ split: test
+ revision: 8b6510b0b1fa4e4c4f879467980e9be563ec1cdf
+ metrics:
+ - type: cos_sim_accuracy
+ value: 89.85136026700819
+ - type: cos_sim_ap
+ value: 87.7768002924216
+ - type: cos_sim_f1
+ value: 80.358908624794
+ - type: cos_sim_precision
+ value: 76.62918209122023
+ - type: cos_sim_recall
+ value: 84.47028025870034
+ - type: dot_accuracy
+ value: 89.85136026700819
+ - type: dot_ap
+ value: 87.77680027889778
+ - type: dot_f1
+ value: 80.358908624794
+ - type: dot_precision
+ value: 76.62918209122023
+ - type: dot_recall
+ value: 84.47028025870034
+ - type: euclidean_accuracy
+ value: 89.85136026700819
+ - type: euclidean_ap
+ value: 87.77680174697751
+ - type: euclidean_f1
+ value: 80.358908624794
+ - type: euclidean_precision
+ value: 76.62918209122023
+ - type: euclidean_recall
+ value: 84.47028025870034
+ - type: manhattan_accuracy
+ value: 89.86300306593705
+ - type: manhattan_ap
+ value: 87.78613271895861
+ - type: manhattan_f1
+ value: 80.31831016905645
+ - type: manhattan_precision
+ value: 76.68230516070304
+ - type: manhattan_recall
+ value: 84.3162919618109
+ - type: max_accuracy
+ value: 89.86300306593705
+ - type: max_ap
+ value: 87.78613271895861
+ - type: max_f1
+ value: 80.358908624794
+language:
+- en
+license: cc-by-nc-4.0
+---
+
+# Salesforce/SFR-Embedding-Mistral
+
+**SFR-Embedding by Salesforce Research.**
+
+The model is trained on top of [E5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct) and [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1).
+
+This project is for research purposes only. Third-party datasets may be subject to additional terms and conditions under their associated licenses. Please refer to specific papers for more details:
+- [MTEB benchmark](https://arxiv.org/abs/2210.07316)
+- [Mistral](https://arxiv.org/abs/2310.06825)
+- [E5-mistral-7b-instruct](https://arxiv.org/pdf/2401.00368.pdf)
+
+More technical details will be shared later.
+
+## How to run
+
+### Transformers
+The model can be used as follows:
+```python
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from transformers import AutoTokenizer, AutoModel
+
+def last_token_pool(last_hidden_states: Tensor,
+ attention_mask: Tensor) -> Tensor:
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+ if left_padding:
+ return last_hidden_states[:, -1]
+ else:
+ sequence_lengths = attention_mask.sum(dim=1) - 1
+ batch_size = last_hidden_states.shape[0]
+ return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+def get_detailed_instruct(task_description: str, query: str) -> str:
+ return f'Instruct: {task_description}\nQuery: {query}'
+
+# Each query must come with a one-sentence instruction that describes the task
+task = 'Given a web search query, retrieve relevant passages that answer the query'
+queries = [
+ get_detailed_instruct(task, 'How to bake a chocolate cake'),
+ get_detailed_instruct(task, 'Symptoms of the flu')
+]
+# No need to add instruction for retrieval documents
+passages = [
+ "To bake a delicious chocolate cake, you'll need the following ingredients: all-purpose flour, sugar, cocoa powder, baking powder, baking soda, salt, eggs, milk, vegetable oil, and vanilla extract. Start by preheating your oven to 350°F (175°C). In a mixing bowl, combine the dry ingredients (flour, sugar, cocoa powder, baking powder, baking soda, and salt). In a separate bowl, whisk together the wet ingredients (eggs, milk, vegetable oil, and vanilla extract). Gradually add the wet mixture to the dry ingredients, stirring until well combined. Pour the batter into a greased cake pan and bake for 30-35 minutes. Let it cool before frosting with your favorite chocolate frosting. Enjoy your homemade chocolate cake!",
+ "The flu, or influenza, is an illness caused by influenza viruses. Common symptoms of the flu include a high fever, chills, cough, sore throat, runny or stuffy nose, body aches, headache, fatigue, and sometimes nausea and vomiting. These symptoms can come on suddenly and are usually more severe than the common cold. It's important to get plenty of rest, stay hydrated, and consult a healthcare professional if you suspect you have the flu. In some cases, antiviral medications can help alleviate symptoms and reduce the duration of the illness."
+]
+
+# load model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-Mistral')
+model = AutoModel.from_pretrained('Salesforce/SFR-Embedding-Mistral')
+
+# get the embeddings
+max_length = 4096
+input_texts = queries + passages
+batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors="pt")
+outputs = model(**batch_dict)
+embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+
+# normalize embeddings
+embeddings = F.normalize(embeddings, p=2, dim=1)
+scores = (embeddings[:2] @ embeddings[2:].T) * 100
+print(scores.tolist())
+# [[86.7153549194336, 36.64569091796875], [35.00493621826172, 82.0738525390625]]
+```
+
+### Sentence Transformers
+```python
+from sentence_transformers import SentenceTransformer, util
+
+model = SentenceTransformer("Salesforce/SFR-Embedding-Mistral")
+
+def get_detailed_instruct(task_description: str, query: str) -> str:
+ return f'Instruct: {task_description}\nQuery: {query}'
+
+# Each query must come with a one-sentence instruction that describes the task
+task = 'Given a web search query, retrieve relevant passages that answer the query'
+queries = [
+ get_detailed_instruct(task, 'How to bake a chocolate cake'),
+ get_detailed_instruct(task, 'Symptoms of the flu')
+]
+# No need to add instruction for retrieval documents
+passages = [
+ "To bake a delicious chocolate cake, you'll need the following ingredients: all-purpose flour, sugar, cocoa powder, baking powder, baking soda, salt, eggs, milk, vegetable oil, and vanilla extract. Start by preheating your oven to 350°F (175°C). In a mixing bowl, combine the dry ingredients (flour, sugar, cocoa powder, baking powder, baking soda, and salt). In a separate bowl, whisk together the wet ingredients (eggs, milk, vegetable oil, and vanilla extract). Gradually add the wet mixture to the dry ingredients, stirring until well combined. Pour the batter into a greased cake pan and bake for 30-35 minutes. Let it cool before frosting with your favorite chocolate frosting. Enjoy your homemade chocolate cake!",
+ "The flu, or influenza, is an illness caused by influenza viruses. Common symptoms of the flu include a high fever, chills, cough, sore throat, runny or stuffy nose, body aches, headache, fatigue, and sometimes nausea and vomiting. These symptoms can come on suddenly and are usually more severe than the common cold. It's important to get plenty of rest, stay hydrated, and consult a healthcare professional if you suspect you have the flu. In some cases, antiviral medications can help alleviate symptoms and reduce the duration of the illness."
+]
+
+embeddings = model.encode(queries + passages)
+scores = util.cos_sim(embeddings[:2], embeddings[2:]) * 100
+print(scores.tolist())
+# [[86.71537780761719, 36.645721435546875], [35.00497055053711, 82.07388305664062]]
+```
+
+### MTEB Benchmark Evaluation
+Check out [unilm/e5](https://github.com/microsoft/unilm/tree/master/e5) to reproduce evaluation results on the [BEIR](https://arxiv.org/abs/2104.08663) and [MTEB](https://arxiv.org/abs/2210.07316) benchmarks.
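+
+As a quick sanity check, the [mteb](https://github.com/embeddings-benchmark/mteb) package can also be run directly. The snippet below is only a minimal sketch: it evaluates a single small task for illustration and omits the task-specific query instructions used in the full evaluation.
+
+```python
+# Minimal MTEB sketch (illustrative, not the official evaluation setup).
+from mteb import MTEB
+from sentence_transformers import SentenceTransformer
+
+model = SentenceTransformer("Salesforce/SFR-Embedding-Mistral")
+
+# One small task keeps the run short; the full benchmark covers many more tasks.
+evaluation = MTEB(tasks=["Banking77Classification"])
+evaluation.run(model, output_folder="results/SFR-Embedding-Mistral")
+```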
+
+
+SFR-Embedding Team (∗ indicates lead contributors).
+* Rui Meng*
+* Ye Liu*
+* Shafiq Rayhan Joty
+* Caiming Xiong
+* Yingbo Zhou
+* Semih Yavuz
+
+### Citation
+```bibtex
+@misc{SFRAIResearch2024,
+  title={SFR-Embedding-Mistral: Enhance Text Retrieval with Transfer Learning},
+  author={Rui Meng and Ye Liu and Shafiq Rayhan Joty and Caiming Xiong and Yingbo Zhou and Semih Yavuz},
+ howpublished={Salesforce AI Research Blog},
+ year={2024},
+ url={https://blog.salesforceairesearch.com/sfr-embedded-mistral/}
+}
+```
+
+
+
+
+
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/refs/main b/cache/models--Salesforce--SFR-Embedding-Mistral/refs/main
new file mode 100644
index 0000000000000000000000000000000000000000..52fcbebdecce31962b9fe3baa327336ef71fc69b
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/refs/main
@@ -0,0 +1 @@
+938c560d1c236aa563b2dbdf084f28ab28bccb11
\ No newline at end of file
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/README.md b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/README.md
new file mode 120000
index 0000000000000000000000000000000000000000..fdb2d58dce333110e81424de5633918cf0801afb
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/README.md
@@ -0,0 +1 @@
+../../blobs/feb95adc7e79e878999ba5a1d3ddfe9f16eff0f1
\ No newline at end of file
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/config.json b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/config.json
new file mode 120000
index 0000000000000000000000000000000000000000..6f4d4ba5e01d216fb8fd9958867ac1c3f1a8b3b9
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/config.json
@@ -0,0 +1 @@
+../../blobs/c19160bba3c1267f959caf6d13fb07f9ea232e04
\ No newline at end of file
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/config_sentence_transformers.json b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/config_sentence_transformers.json
new file mode 120000
index 0000000000000000000000000000000000000000..91ea6252de64e9edbd9258e62332899b10eba704
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/config_sentence_transformers.json
@@ -0,0 +1 @@
+../../blobs/ef62bf21fb2396937098b86ae80c68813b229c18
\ No newline at end of file
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/model.safetensors.index.json b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/model.safetensors.index.json
new file mode 120000
index 0000000000000000000000000000000000000000..def9cfea579ec9eadefdbbe0ef6f506769ec5e3f
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/model.safetensors.index.json
@@ -0,0 +1 @@
+../../blobs/f8194e4e9432d287bf257d4a7d4a0f2446c32da8
\ No newline at end of file
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/modules.json b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/modules.json
new file mode 120000
index 0000000000000000000000000000000000000000..140f6daae8a93f042bd8ed47d9117211abc67e04
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/modules.json
@@ -0,0 +1 @@
+../../blobs/f7640f94e81bb7f4f04daf1668850b38763a13d9
\ No newline at end of file
diff --git a/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/sentence_bert_config.json b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/sentence_bert_config.json
new file mode 120000
index 0000000000000000000000000000000000000000..df3e74d8e228ef9d5556845776fa74cf1f08febf
--- /dev/null
+++ b/cache/models--Salesforce--SFR-Embedding-Mistral/snapshots/938c560d1c236aa563b2dbdf084f28ab28bccb11/sentence_bert_config.json
@@ -0,0 +1 @@
+../../blobs/42dcdfcaf9e42a488d4be06500dd771d7aa11e83
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c2ba553dc0960a21fb41e66e565d041aad506b04
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,61 @@
+version: "3.5"
+
+networks:
+ metavoice-net:
+ driver: bridge
+
+volumes:
+ hf-cache:
+ driver: local
+
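+# Settings shared by both services, merged into each via the YAML anchor below ("<<: *common-settings").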
+x-common-settings: &common-settings
+ volumes:
+ - hf-cache:/.hf-cache
+ - ./assets:/app/assets
+ deploy:
+ replicas: 1
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [ gpu ]
+ runtime: nvidia
+ ipc: host
+ tty: true # enable colorized logs
+ build:
+ context: .
+ image: metavoice-server:latest
+ networks:
+ - metavoice-net
+ environment:
+ - NVIDIA_VISIBLE_DEVICES=all
+ - HF_HOME=/.hf-cache
+ logging:
+ options:
+ max-size: "100m"
+ max-file: "10"
+
+services:
+ server:
+ <<: *common-settings
+ container_name: metavoice-server
+ command: [ "--port=58004" ]
+ ports:
+ - 58004:58004
+ healthcheck:
+ test: [ "CMD", "curl", "http://metavoice-server:58004/health" ]
+ interval: 1m
+ timeout: 10s
+ retries: 20
+ ui:
+ <<: *common-settings
+ container_name: metavoice-ui
+ entrypoint: [ "python3.10", "app.py" ]
+ ports:
+ - 7861:7861
+ healthcheck:
+ test: [ "CMD", "curl", "http://localhost:7861" ]
+ interval: 1m
+ timeout: 10s
+ retries: 1
diff --git a/emo-knob-teaser-1.svg b/emo-knob-teaser-1.svg
new file mode 100644
index 0000000000000000000000000000000000000000..8407055221c536b84942af026485a62db1debb76
--- /dev/null
+++ b/emo-knob-teaser-1.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/fam/__init__.py b/fam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fam/__pycache__/__init__.cpython-310.pyc b/fam/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..47b56c305d53c44bca4b2a9ae5082661b6151fdf
Binary files /dev/null and b/fam/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fam/__pycache__/__init__.cpython-39.pyc b/fam/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..804fd8405bbba4163ac11d41631e4c7a5dc2f617
Binary files /dev/null and b/fam/__pycache__/__init__.cpython-39.pyc differ
diff --git a/fam/llm/__init__.py b/fam/llm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fam/llm/__pycache__/__init__.cpython-310.pyc b/fam/llm/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3cb397dee1d46e23417fb47cb588abba00f8d92
Binary files /dev/null and b/fam/llm/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fam/llm/__pycache__/__init__.cpython-39.pyc b/fam/llm/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c09c93ef3e66dbf1de1afe77082981d1b3a0b38
Binary files /dev/null and b/fam/llm/__pycache__/__init__.cpython-39.pyc differ
diff --git a/fam/llm/__pycache__/decoders.cpython-310.pyc b/fam/llm/__pycache__/decoders.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8143ce36981abf61f4a55e9be4cb7de6300be2b
Binary files /dev/null and b/fam/llm/__pycache__/decoders.cpython-310.pyc differ
diff --git a/fam/llm/__pycache__/decoders.cpython-39.pyc b/fam/llm/__pycache__/decoders.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24bcacd46e2d8ba84f238f7b02437c86ee26554f
Binary files /dev/null and b/fam/llm/__pycache__/decoders.cpython-39.pyc differ
diff --git a/fam/llm/__pycache__/enhancers.cpython-310.pyc b/fam/llm/__pycache__/enhancers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..912b17cae5699d5c74cd09a29ca36ced1c6fcb92
Binary files /dev/null and b/fam/llm/__pycache__/enhancers.cpython-310.pyc differ
diff --git a/fam/llm/__pycache__/enhancers.cpython-39.pyc b/fam/llm/__pycache__/enhancers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..323135822d6e4da97ffa7714cb9527392f08546d
Binary files /dev/null and b/fam/llm/__pycache__/enhancers.cpython-39.pyc differ
diff --git a/fam/llm/__pycache__/fast_inference.cpython-310.pyc b/fam/llm/__pycache__/fast_inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81a5e745f3fe74dbeccce5cb4d758546d9fc50e3
Binary files /dev/null and b/fam/llm/__pycache__/fast_inference.cpython-310.pyc differ
diff --git a/fam/llm/__pycache__/fast_inference.cpython-39.pyc b/fam/llm/__pycache__/fast_inference.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0296986d41503bb91c7e19269f4b534b3421465
Binary files /dev/null and b/fam/llm/__pycache__/fast_inference.cpython-39.pyc differ
diff --git a/fam/llm/__pycache__/fast_inference_utils.cpython-310.pyc b/fam/llm/__pycache__/fast_inference_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be53058bc7dc252803e2d1440de0c88951725e1c
Binary files /dev/null and b/fam/llm/__pycache__/fast_inference_utils.cpython-310.pyc differ
diff --git a/fam/llm/__pycache__/fast_inference_utils.cpython-39.pyc b/fam/llm/__pycache__/fast_inference_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..64b9f983435e9d7c552361e74070100820398a25
Binary files /dev/null and b/fam/llm/__pycache__/fast_inference_utils.cpython-39.pyc differ
diff --git a/fam/llm/__pycache__/fast_model.cpython-310.pyc b/fam/llm/__pycache__/fast_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45499ec002d70e23a114cf9c97fc9f4d60ba6af1
Binary files /dev/null and b/fam/llm/__pycache__/fast_model.cpython-310.pyc differ
diff --git a/fam/llm/__pycache__/fast_model.cpython-39.pyc b/fam/llm/__pycache__/fast_model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28ecd66556ad07f42e41762bdb334837d195a868
Binary files /dev/null and b/fam/llm/__pycache__/fast_model.cpython-39.pyc differ
diff --git a/fam/llm/__pycache__/inference.cpython-310.pyc b/fam/llm/__pycache__/inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4445ed489d65110af90ced18c538710b6c98cfe
Binary files /dev/null and b/fam/llm/__pycache__/inference.cpython-310.pyc differ
diff --git a/fam/llm/__pycache__/inference.cpython-39.pyc b/fam/llm/__pycache__/inference.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc9a112198caccff5021ac248c287adf8111ad0d
Binary files /dev/null and b/fam/llm/__pycache__/inference.cpython-39.pyc differ
diff --git a/fam/llm/__pycache__/model.cpython-310.pyc b/fam/llm/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f22d748b91502e816635e6bcb5f63790c60919e
Binary files /dev/null and b/fam/llm/__pycache__/model.cpython-310.pyc differ
diff --git a/fam/llm/__pycache__/model.cpython-39.pyc b/fam/llm/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70104d33389902a9fb5e85bcd4d9ec638629eb9e
Binary files /dev/null and b/fam/llm/__pycache__/model.cpython-39.pyc differ
diff --git a/fam/llm/__pycache__/utils.cpython-310.pyc b/fam/llm/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..adf67b7f86022c3592b9746d72aae141152c325f
Binary files /dev/null and b/fam/llm/__pycache__/utils.cpython-310.pyc differ
diff --git a/fam/llm/__pycache__/utils.cpython-39.pyc b/fam/llm/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dd7a7c4f890fbfbe162a7efd5589017d4a0ea4e
Binary files /dev/null and b/fam/llm/__pycache__/utils.cpython-39.pyc differ
diff --git a/fam/llm/adapters/__init__.py b/fam/llm/adapters/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f31d5ee984ff00300d7eab9789d563e7809d8015
--- /dev/null
+++ b/fam/llm/adapters/__init__.py
@@ -0,0 +1,2 @@
+from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
+from fam.llm.adapters.tilted_encodec import TiltedEncodec
diff --git a/fam/llm/adapters/__pycache__/__init__.cpython-310.pyc b/fam/llm/adapters/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc0bf44f96feadda0749356c06055dc91a3dce2a
Binary files /dev/null and b/fam/llm/adapters/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fam/llm/adapters/__pycache__/__init__.cpython-39.pyc b/fam/llm/adapters/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53a980c51623c6f7758c5a5f1a24ae83cb07ded6
Binary files /dev/null and b/fam/llm/adapters/__pycache__/__init__.cpython-39.pyc differ
diff --git a/fam/llm/adapters/__pycache__/base.cpython-310.pyc b/fam/llm/adapters/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f526bcf77e5a3bc61b28b1c190cb5fd60152d04
Binary files /dev/null and b/fam/llm/adapters/__pycache__/base.cpython-310.pyc differ
diff --git a/fam/llm/adapters/__pycache__/base.cpython-39.pyc b/fam/llm/adapters/__pycache__/base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdb94acc76b186d86fa9386f9865bc1990dfe8db
Binary files /dev/null and b/fam/llm/adapters/__pycache__/base.cpython-39.pyc differ
diff --git a/fam/llm/adapters/__pycache__/flattened_encodec.cpython-310.pyc b/fam/llm/adapters/__pycache__/flattened_encodec.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef5669f54de4ff2d412a2b6cde0c9a590ffc2d70
Binary files /dev/null and b/fam/llm/adapters/__pycache__/flattened_encodec.cpython-310.pyc differ
diff --git a/fam/llm/adapters/__pycache__/flattened_encodec.cpython-39.pyc b/fam/llm/adapters/__pycache__/flattened_encodec.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c96413b57a50681db885a15f1832f4a3ab079cf9
Binary files /dev/null and b/fam/llm/adapters/__pycache__/flattened_encodec.cpython-39.pyc differ
diff --git a/fam/llm/adapters/__pycache__/tilted_encodec.cpython-310.pyc b/fam/llm/adapters/__pycache__/tilted_encodec.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..544d2b97168852a946459f2f98c1c3de75bab27e
Binary files /dev/null and b/fam/llm/adapters/__pycache__/tilted_encodec.cpython-310.pyc differ
diff --git a/fam/llm/adapters/__pycache__/tilted_encodec.cpython-39.pyc b/fam/llm/adapters/__pycache__/tilted_encodec.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..369dcaab556a21e8f2a5f3b4af7f49c4a444afb2
Binary files /dev/null and b/fam/llm/adapters/__pycache__/tilted_encodec.cpython-39.pyc differ
diff --git a/fam/llm/adapters/base.py b/fam/llm/adapters/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e61646c6f9bae6be719dd340e408b2706b12193a
--- /dev/null
+++ b/fam/llm/adapters/base.py
@@ -0,0 +1,5 @@
+from abc import ABC
+
+
+class BaseDataAdapter(ABC):
+ pass
diff --git a/fam/llm/adapters/flattened_encodec.py b/fam/llm/adapters/flattened_encodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd553f0cd80c92c1935ec3161dabcb2cb10600bb
--- /dev/null
+++ b/fam/llm/adapters/flattened_encodec.py
@@ -0,0 +1,38 @@
+from fam.llm.adapters.base import BaseDataAdapter
+
+
+class FlattenedInterleavedEncodec2Codebook(BaseDataAdapter):
+ def __init__(self, end_of_audio_token):
+ self._end_of_audio_token = end_of_audio_token
+
+ def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
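+        # The first stage emits a single flattened, interleaved stream:
+        #   * ids <  end_of_audio_token             -> codebook-0 audio tokens
+        #   * ids in [end_of_audio_token, 2 * eoa)  -> codebook-1 audio tokens (offset removed below)
+        #   * id  == 2 * end_of_audio_token         -> end-of-audio marker (ignored)
+        #   * ids >  2 * end_of_audio_token         -> text tokens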
+ assert len(tokens) == 1
+ tokens = tokens[0]
+
+ text_ids = []
+ extracted_audio_ids = [[], []]
+
+ for t in tokens:
+ if t < self._end_of_audio_token:
+ extracted_audio_ids[0].append(t)
+ elif t >= self._end_of_audio_token and t < 2 * self._end_of_audio_token:
+ extracted_audio_ids[1].append(t - self._end_of_audio_token)
+ # We ignore t = 2 * self._end_of_audio_token, as it is the end of audio token
+ elif t > 2 * self._end_of_audio_token:
+ text_ids.append(t)
+
+        if len(set([len(x) for x in extracted_audio_ids])) != 1:
+            min_len = min([len(x) for x in extracted_audio_ids])
+            max_len = max([len(x) for x in extracted_audio_ids])
+            print("WARNING: Number of tokens at each hierarchy must be of the same length!")
+            print(f"Truncating to min length of {min_len} tokens (longest hierarchy had {max_len}).")
+            print([len(x) for x in extracted_audio_ids])
+            extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]
+
+ return text_ids[:-1], extracted_audio_ids
+
+ def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
+ """
+ Performs the required combination and padding as needed.
+ """
+ raise NotImplementedError
diff --git a/fam/llm/adapters/tilted_encodec.py b/fam/llm/adapters/tilted_encodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a7e9b2c1216a489e4581dfd55684a11ebd473d9
--- /dev/null
+++ b/fam/llm/adapters/tilted_encodec.py
@@ -0,0 +1,45 @@
+from fam.llm.adapters.base import BaseDataAdapter
+
+
+class TiltedEncodec(BaseDataAdapter):
+ def __init__(self, end_of_audio_token):
+ self._end_of_audio_token = end_of_audio_token
+
+ def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
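+        # `tokens` holds one stream per codebook hierarchy. The first stream also
+        # carries text ids (> end_of_audio_token); every stream is filtered down to
+        # its audio ids (< end_of_audio_token) below.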
+ assert len(tokens) > 1
+
+ text_ids = []
+ extracted_audio_ids = []
+
+ extracted_audio_ids.append([])
+ # Handle first hierarchy as special case as it contains text tokens as well
+        # TODO: maybe it doesn't need special-casing, and can be handled on its own :)
+ for t in tokens[0]:
+ if t > self._end_of_audio_token:
+ text_ids.append(t)
+ elif t < self._end_of_audio_token:
+ extracted_audio_ids[0].append(t)
+
+ # Handle the rest of the hierarchies
+ for i in range(1, len(tokens)):
+ token_hierarchy_ids = tokens[i]
+ extracted_audio_ids.append([])
+ for t in token_hierarchy_ids:
+ if t < self._end_of_audio_token:
+ extracted_audio_ids[i].append(t)
+
+        if len(set([len(x) for x in extracted_audio_ids])) != 1:
+            min_len = min([len(x) for x in extracted_audio_ids])
+            max_len = max([len(x) for x in extracted_audio_ids])
+            print("WARNING: Number of tokens at each hierarchy must be of the same length!")
+            print(f"Truncating to min length of {min_len} tokens (longest hierarchy had {max_len}).")
+            print([len(x) for x in extracted_audio_ids])
+            extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]
+
+ return text_ids[:-1], extracted_audio_ids
+
+ def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
+ """
+ Performs the required combination and padding as needed.
+ """
+ raise NotImplementedError
diff --git a/fam/llm/decoders.py b/fam/llm/decoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..95615821bd9c721f8ff20dca524c2051367f9b95
--- /dev/null
+++ b/fam/llm/decoders.py
@@ -0,0 +1,101 @@
+import os
+import pathlib
+import uuid
+from abc import ABC, abstractmethod
+from typing import Callable, Optional, Union
+
+import julius
+import torch
+from audiocraft.data.audio import audio_read, audio_write
+from audiocraft.models import MultiBandDiffusion # type: ignore
+
+mbd = MultiBandDiffusion.get_mbd_24khz(bw=6) # 1.5
+
+
+class Decoder(ABC):
+ @abstractmethod
+ def decode(self, tokens: list[int], ref_audio_path: Optional[str] = None, causal: Optional[bool] = None):
+ raise NotImplementedError
+
+
+class EncodecDecoder(Decoder):
+ def __init__(
+ self,
+ tokeniser_decode_fn: Callable[[list[int]], str],
+ data_adapter_fn: Callable[[list[list[int]]], tuple[list[int], list[list[int]]]],
+ output_dir: str,
+ ):
+ self._mbd_sample_rate = 24_000
+ self._end_of_audio_token = 1024
+ self._num_codebooks = 8
+ self.mbd = mbd
+
+ self.tokeniser_decode_fn = tokeniser_decode_fn
+ self._data_adapter_fn = data_adapter_fn
+
+ self.output_dir = pathlib.Path(output_dir).resolve()
+ os.makedirs(self.output_dir, exist_ok=True)
+
+ def _save_audio(self, name: str, wav: torch.Tensor):
+ audio_write(
+ name,
+ wav.squeeze(0).cpu(),
+ self._mbd_sample_rate,
+ strategy="loudness",
+ loudness_compressor=True,
+ )
+
+ def get_tokens(self, audio_path: str) -> list[list[int]]:
+ """
+ Utility method to get tokens from audio. Useful when you want to test reconstruction in some form (e.g.
+ limited codebook reconstruction or sampling from second stage model only).
+ """
+ wav, sr = audio_read(audio_path)
+ if sr != self._mbd_sample_rate:
+ wav = julius.resample_frac(wav, sr, self._mbd_sample_rate)
+ if wav.ndim == 2:
+ wav = wav.unsqueeze(1)
+ wav = wav.to("cuda")
+ tokens = self.mbd.codec_model.encode(wav)
+ tokens = tokens[0][0]
+
+ return tokens.tolist()
+
+ def decode(
+ self, tokens: list[list[int]], causal: bool = True, ref_audio_path: Optional[str] = None
+ ) -> Union[str, torch.Tensor]:
+ # TODO: this has strange behaviour -- if causal is True, it returns tokens. if causal is False, it SAVES the audio file.
+ text_ids, extracted_audio_ids = self._data_adapter_fn(tokens)
+ text = self.tokeniser_decode_fn(text_ids)
+ # print(f"Text: {text}")
+
+ tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0)
+
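+        # The multi-band diffusion codec expects `_num_codebooks` codebooks; if the
+        # second stage produced fewer hierarchies, pad the missing ones with zeros.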
+ if tokens.shape[1] < self._num_codebooks:
+ tokens = torch.cat(
+ [tokens, *[torch.ones_like(tokens[0:1, 0:1]) * 0] * (self._num_codebooks - tokens.shape[1])], dim=1
+ )
+
+ if causal:
+ return tokens
+ else:
+ with torch.amp.autocast(device_type="cuda", dtype=torch.float32):
+ wav = self.mbd.tokens_to_wav(tokens)
+ # NOTE: we couldn't just return wav here as it goes through loudness compression etc :)
+
+ if wav.shape[-1] < 9600:
+ # this causes problem for the code below, and is also odd :)
+ # first happened for tokens (1, 8, 28) -> wav (1, 1, 8960) (~320x factor in time dimension!)
+ raise Exception("wav predicted is shorter than 400ms!")
+
+ try:
+ wav_file_name = self.output_dir / f"synth_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}"
+ self._save_audio(wav_file_name, wav)
+ return wav_file_name
+ except Exception as e:
+ print(f"Failed to save audio! Reason: {e}")
+
+ wav_file_name = self.output_dir / f"synth_{uuid.uuid4()}"
+ self._save_audio(wav_file_name, wav)
+ return wav_file_name
diff --git a/fam/llm/enhancers.py b/fam/llm/enhancers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4338c7f2efa2e3fada1a4c1ad3afb1f0f9287f3
--- /dev/null
+++ b/fam/llm/enhancers.py
@@ -0,0 +1,108 @@
+import os
+from abc import ABC
+from typing import Literal, Optional
+
+from df.enhance import enhance, init_df, load_audio, save_audio
+from pydub import AudioSegment
+
+
+def convert_to_wav(input_file: str, output_file: str):
+ """Convert an audio file to WAV format
+
+ Args:
+ input_file (str): path to input audio file
+ output_file (str): path to output WAV file
+
+ """
+ # Detect the format of the input file
+ format = input_file.split(".")[-1].lower()
+
+ # Read the audio file
+ audio = AudioSegment.from_file(input_file, format=format)
+
+ # Export as WAV
+ audio.export(output_file, format="wav")
+
+
+def make_output_file_path(audio_file: str, tag: str, ext: Optional[str] = None) -> str:
+ """Generate the output file path
+
+ Args:
+ audio_file (str): path to input audio file
+ tag (str): tag to append to the output file name
+ ext (str, optional): extension of the output file. Defaults to None.
+
+ Returns:
+ str: path to output file
+ """
+
+ directory = "./enhanced"
+ # Get the name of the input file
+ filename = os.path.basename(audio_file)
+
+ # Get the name of the input file without the extension
+ filename_without_extension = os.path.splitext(filename)[0]
+
+ # Get the extension of the input file
+ extension = ext or os.path.splitext(filename)[1]
+
+ # Generate the output file path
+ output_file = os.path.join(directory, filename_without_extension + tag + extension)
+
+ return output_file
+
+
+class BaseEnhancer(ABC):
+ """Base class for audio enhancers"""
+
+ def __init__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
+ raise NotImplementedError
+
+ def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str:
+ output_file = make_output_file_path(audio_file, tag, ext=ext)
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
+ return output_file
+
+
+class DFEnhancer(BaseEnhancer):
+ def __init__(self, *args, **kwargs):
+ self.model, self.df_state, _ = init_df()
+
+ def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
+ output_file = output_file or self.get_output_file(audio_file, "_df")
+
+ audio, _ = load_audio(audio_file, sr=self.df_state.sr())
+
+ enhanced = enhance(self.model, self.df_state, audio)
+
+ save_audio(output_file, enhanced, self.df_state.sr())
+
+ return output_file
+
+
+def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
+ """Get an audio enhancer
+
+ Args:
+ enhancer_name (Literal["df"]): name of the audio enhancer
+
+ Raises:
+ ValueError: if the enhancer name is not recognised
+
+ Returns:
+ BaseEnhancer: audio enhancer
+ """
+
+ if enhancer_name == "df":
+ import warnings
+
+ warnings.filterwarnings(
+ "ignore",
+ message='"sinc_interpolation" resampling method name is being deprecated and replaced by "sinc_interp_hann" in the next release. The default behavior remains unchanged.',
+ )
+ return DFEnhancer()
+ else:
+ raise ValueError(f"Unknown enhancer name: {enhancer_name}")
diff --git a/fam/llm/fast_inference.py b/fam/llm/fast_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..77bf00439ffcecc3e4fd926d373f1633092e8095
--- /dev/null
+++ b/fam/llm/fast_inference.py
@@ -0,0 +1,143 @@
+import os
+import shutil
+import tempfile
+import time
+from pathlib import Path
+
+import librosa
+import torch
+from huggingface_hub import snapshot_download
+
+from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
+from fam.llm.fast_inference_utils import build_model, main
+from fam.llm.inference import (
+ EncodecDecoder,
+ InferenceConfig,
+ Model,
+ TiltedEncodec,
+ TrainedBPETokeniser,
+ get_cached_embedding,
+ get_cached_file,
+ get_enhancer,
+)
+from fam.llm.utils import (
+ check_audio_file,
+ get_default_dtype,
+ get_device,
+ normalize_text,
+)
+
+
+class TTS:
+ END_OF_AUDIO_TOKEN = 1024
+
+ def __init__(
+ self, model_name: str = "metavoiceio/metavoice-1B-v0.1", *, seed: int = 1337, output_dir: str = "outputs"
+ ):
+ """
+ model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio)
+ """
+
+ # NOTE: this needs to come first so that we don't change global state when we want to use
+ # the torch.compiled-model.
+ self._dtype = get_default_dtype()
+ self._device = get_device()
+        self._model_dir = snapshot_download(repo_id=model_name, cache_dir="/proj/afosr/metavoice/cache")
+ self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
+ self.output_dir = output_dir
+ os.makedirs(self.output_dir, exist_ok=True)
+
+ second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
+ config_second_stage = InferenceConfig(
+ ckpt_path=second_stage_ckpt_path,
+ num_samples=1,
+ seed=seed,
+ device=self._device,
+ dtype=self._dtype,
+ compile=False,
+ init_from="resume",
+ output_dir=self.output_dir,
+ )
+ data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
+ self.llm_second_stage = Model(
+ config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
+ )
+ self.enhancer = get_enhancer("df")
+
+ self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
+ self.model, self.tokenizer, self.smodel, self.model_size = build_model(
+ precision=self.precision,
+ checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"),
+ spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
+ device=self._device,
+ compile=True,
+ compile_prefill=True,
+ )
+
+ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
+ """
+ text: Text to speak
+ spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3
+ top_p: Top p for sampling applied to first-stage model. Range [0.9, 1.0] are good. This is a measure of speech stability - improves text following for a challenging speaker
+ guidance_scale: Guidance scale [1.0, 3.0] for sampling. This is a measure of speaker similarity - how closely to match speaker identity and speech style.
+ temperature: Temperature for sampling applied to both LLMs (first & second stage)
+
+ returns: path to speech .wav file
+ """
+ text = normalize_text(text)
+ spk_ref_path = get_cached_file(spk_ref_path)
+ check_audio_file(spk_ref_path)
+ spk_emb = get_cached_embedding(
+ spk_ref_path,
+ self.smodel,
+ ).to(device=self._device, dtype=self.precision)
+
+ start = time.time()
+ # first stage LLM
+ tokens = main(
+ model=self.model,
+ tokenizer=self.tokenizer,
+ model_size=self.model_size,
+ prompt=text,
+ spk_emb=spk_emb,
+ top_p=torch.tensor(top_p, device=self._device, dtype=self.precision),
+ guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision),
+ temperature=torch.tensor(temperature, device=self._device, dtype=self.precision),
+ )
+ _, extracted_audio_ids = self.first_stage_adapter.decode([tokens])
+
+ b_speaker_embs = spk_emb.unsqueeze(0)
+
+ # second stage LLM + multi-band diffusion model
+ wav_files = self.llm_second_stage(
+ texts=[text],
+ encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)],
+ speaker_embs=b_speaker_embs,
+ batch_size=1,
+ guidance_scale=None,
+ top_p=None,
+ top_k=200,
+ temperature=1.0,
+ max_new_tokens=None,
+ )
+
+ # enhance using deepfilternet
+ wav_file = wav_files[0]
+ with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
+ self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
+ shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
+ print(f"\nSaved audio to {wav_file}.wav")
+
+ # calculating real-time factor (RTF)
+ time_to_synth_s = time.time() - start
+ audio, sr = librosa.load(str(wav_file) + ".wav")
+ duration_s = librosa.get_duration(y=audio, sr=sr)
+ print(f"\nTotal time to synth (s): {time_to_synth_s}")
+ print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}")
+
+ return str(wav_file) + ".wav"
+
+
+if __name__ == "__main__":
+ tts = TTS()
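+    # Example call (paths are illustrative; a ~30s speaker reference clip is required):
+    # tts.synthesise(text="Hello, world.", spk_ref_path="path/to/reference.wav")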
diff --git a/fam/llm/fast_inference_utils.py b/fam/llm/fast_inference_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ed7b9e5ce44e33896e452abb66d72da2cf40a5e
--- /dev/null
+++ b/fam/llm/fast_inference_utils.py
@@ -0,0 +1,432 @@
+# Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification, are permitted
+# provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this list of
+# conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice, this
+# list of conditions and the following disclaimer in the documentation and/or other
+# materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+# may be used to endorse or promote products derived from this software without
+# specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
+# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import itertools
+import time
+from pathlib import Path
+from typing import Optional, Tuple
+
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+import tqdm
+
+
+def device_sync(device):
+ if "cuda" in device:
+ torch.cuda.synchronize()
+ elif "cpu" in device:
+ pass
+ else:
+        print(f"device={device} is not yet supported")
+
+
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+torch._inductor.config.fx_graph_cache = (
+ True # Experimental feature to reduce compilation times, will be on by default in future
+)
+
+# imports need to happen after setting above flags
+from fam.llm.fast_model import Transformer
+from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
+from fam.quantiser.text.tokenise import TrainedBPETokeniser
+
+
+def multinomial_sample_one_no_sync(
+ probs_sort,
+): # Does multinomial sampling without a cuda synchronization
+ q = torch.empty_like(probs_sort).exponential_(1)
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def top_p_sample(logits: torch.Tensor, top_p: torch.Tensor):
+ # ref: huggingface/transformers
+
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+    # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+ # Keep at least min_tokens_to_keep
+ sorted_indices_to_remove[-1:] = 0
+
+ # scatter sorted tensors to original indexing
+ indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
+ scores = logits.masked_fill(indices_to_remove, -float("Inf"))
+ return scores
+
+
+def logits_to_probs(
+ logits,
+ *,
+ temperature: torch.Tensor,
+ top_p: Optional[torch.Tensor] = None,
+ top_k: Optional[torch.Tensor] = None,
+):
+ logits = logits / torch.max(temperature, 1e-5 * torch.ones_like(temperature))
+
+ if top_k is not None:
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+ pivot = v.select(-1, -1).unsqueeze(-1)
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
+
+ if top_p is not None:
+ logits = top_p_sample(logits, top_p)
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+
+ return probs
+
+
+def sample(
+ logits,
+ guidance_scale: torch.Tensor,
+ temperature: torch.Tensor,
+ top_p: Optional[torch.Tensor] = None,
+ top_k: Optional[torch.Tensor] = None,
+):
+ # (b, t, vocab_size)
+ logits = logits[:, -1]
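+    # Classifier-free guidance: the batch carries a speaker-conditioned half and a
+    # speaker-unconditioned half; blend their logits using the guidance scale.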
+ logits_cond, logits_uncond_spkemb = logits.split(logits.size(0) // 2, dim=0)
+ logits = guidance_scale * logits_cond + (1 - guidance_scale) * logits_uncond_spkemb
+ probs = logits_to_probs(logits[0], temperature=temperature, top_p=top_p, top_k=top_k)
+ idx_next = multinomial_sample_one_no_sync(probs)
+ return idx_next, probs
+
+
+def prefill(
+ model: Transformer,
+ x: torch.Tensor,
+ spk_emb: torch.Tensor,
+ input_pos: torch.Tensor,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ # input_pos: [B, S]
+ logits = model(x, spk_emb, input_pos)
+ return sample(logits, **sampling_kwargs)[0]
+
+
+def decode_one_token(
+ model: Transformer,
+ x: torch.Tensor,
+ spk_emb: torch.Tensor,
+ input_pos: torch.Tensor,
+ **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # input_pos: [B, 1]
+ assert input_pos.shape[-1] == 1
+ logits = model(x, spk_emb, input_pos)
+ return sample(logits, **sampling_kwargs)
+
+
+def decode_n_tokens(
+ model: Transformer,
+ cur_token: torch.Tensor,
+ spk_emb: torch.Tensor,
+ input_pos: torch.Tensor,
+ num_new_tokens: int,
+ callback=lambda _: _,
+ return_probs: bool = False,
+ end_of_audio_token: int = 2048,
+ **sampling_kwargs,
+):
+ new_tokens, new_probs = [], []
+ for i in tqdm.tqdm(range(num_new_tokens)):
+ if (cur_token == end_of_audio_token).any():
+ break
+ with torch.backends.cuda.sdp_kernel(
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
+ ): # Actually better for Inductor to codegen attention here
+ next_token, next_prob = decode_one_token(model, cur_token, spk_emb, input_pos, **sampling_kwargs)
+ input_pos += 1
+ new_tokens.append(next_token.clone())
+ callback(new_tokens[-1])
+ if return_probs:
+ new_probs.append(next_prob.clone())
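+            # Duplicate the sampled token across the two batch rows so the next step
+            # again has a conditioned and an unconditioned copy for guidance.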
+ cur_token = next_token.view(1, -1).repeat(2, 1)
+
+ return new_tokens, new_probs
+
+
+def model_forward(model, x, spk_emb, input_pos):
+ return model(x, spk_emb, input_pos)
+
+
+@torch.no_grad()
+def generate(
+ model: Transformer,
+ prompt: torch.Tensor,
+ spk_emb: torch.Tensor,
+ *,
+ max_new_tokens: Optional[int] = None,
+ callback=lambda x: x,
+ end_of_audio_token: int = 2048,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ """
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+ """
+ # create an empty tensor of the expected final shape and fill in the current tokens
+ T = prompt.size(0)
+ if max_new_tokens is None:
+ max_seq_length = model.config.block_size
+ else:
+ max_seq_length = T + max_new_tokens
+ max_seq_length = min(max_seq_length, model.config.block_size)
+ max_new_tokens = max_seq_length - T
+ if max_new_tokens <= 0:
+ raise ValueError("Prompt is too long to generate more tokens")
+
+ device, dtype = prompt.device, prompt.dtype
+
+ seq = torch.clone(prompt)
+ input_pos = torch.arange(0, T, device=device)
+
+ next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs)
+ seq = torch.cat([seq, next_token.view(1)])
+
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
+
+ generated_tokens, _ = decode_n_tokens(
+ model,
+ next_token.view(1, -1).repeat(2, 1),
+ spk_emb,
+ input_pos,
+ max_new_tokens - 1,
+ callback=callback,
+ end_of_audio_token=end_of_audio_token,
+ **sampling_kwargs,
+ )
+ seq = torch.cat([seq, torch.cat(generated_tokens)])
+
+ return seq
+
+
+def encode_tokens(tokenizer, string, device="cuda"):
+ tokens = tokenizer.encode(string)
+ return torch.tensor(tokens, dtype=torch.int, device=device)
+
+
+def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
+ ##### MODEL
+ with torch.device("meta"):
+ model = Transformer.from_name("metavoice-1B")
+
+ # TODO(quantization): enable
+ # if "int8" in str(checkpoint_path):
+ # print("Using int8 weight-only quantization!")
+ # from quantize import WeightOnlyInt8QuantHandler
+ # simple_quantizer = WeightOnlyInt8QuantHandler(model)
+ # model = simple_quantizer.convert_for_runtime()
+ # from quantize import WeightOnlyInt8QuantHandler
+
+ # if "int4" in str(checkpoint_path):
+ # print("Using int4 quantization!")
+ # path_comps = checkpoint_path.name.split(".")
+ # assert path_comps[-2].startswith("g")
+ # groupsize = int(path_comps[-2][1:])
+ # from quantize import WeightOnlyInt4QuantHandler
+ # simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+ # model = simple_quantizer.convert_for_runtime()
+
+ checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
+ state_dict = checkpoint["model"]
+ # convert MetaVoice-1B model weights naming to gptfast naming
+ unwanted_prefix = "_orig_mod."
+ for k, v in list(state_dict.items()):
+ if k.startswith(unwanted_prefix):
+ state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
+ state_dict["tok_embeddings.weight"] = state_dict.pop("transformer.wtes.0.weight")
+ state_dict["pos_embeddings.weight"] = state_dict.pop("transformer.wpe.weight")
+ state_dict["output.weight"] = state_dict.pop("lm_heads.0.weight")
+ state_dict["norm.weight"] = state_dict.pop("transformer.ln_f.weight")
+ for k, v in list(state_dict.items()):
+ if k.startswith("transformer.h."):
+ state_dict[k.replace("transformer.h.", "layers.")] = state_dict.pop(k)
+ k = k.replace("transformer.h.", "layers.")
+ if ".attn.c_attn." in k:
+ state_dict[k.replace(".attn.c_attn.", ".attention.wqkv.")] = state_dict.pop(k)
+ k = k.replace(".attn.c_attn.", ".attention.wqkv.")
+ if ".attn.c_proj." in k:
+ state_dict[k.replace(".attn.c_proj.", ".attention.wo.")] = state_dict.pop(k)
+ k = k.replace(".attn.c_proj.", ".attention.wo.")
+ if ".mlp.swiglu.w1." in k:
+ state_dict[k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")] = state_dict.pop(k)
+ k = k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")
+ if ".mlp.swiglu.w3." in k:
+ state_dict[k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")] = state_dict.pop(k)
+ k = k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")
+ if ".ln_1." in k:
+ state_dict[k.replace(".ln_1.", ".attention_norm.")] = state_dict.pop(k)
+ k = k.replace(".ln_1.", ".attention_norm.")
+ if ".ln_2." in k:
+ state_dict[k.replace(".ln_2.", ".ffn_norm.")] = state_dict.pop(k)
+ k = k.replace(".ln_2.", ".ffn_norm.")
+ if ".mlp.c_proj." in k:
+ state_dict[k.replace(".mlp.c_proj.", ".feed_forward.w2.")] = state_dict.pop(k)
+ k = k.replace(".mlp.c_proj.", ".feed_forward.w2.")
+
+ model.load_state_dict(state_dict, assign=True)
+ # simple_quantizer = WeightOnlyInt8QuantHandler(model)
+ # quantized_state_dict = simple_quantizer.create_quantized_state_dict()
+ # model = simple_quantizer.convert_for_runtime()
+ # model.load_state_dict(quantized_state_dict, assign=True)
+ model = model.to(device=device, dtype=precision)
+
+ ###### TOKENIZER
+ tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {})
+ tokenizer = TrainedBPETokeniser(**tokenizer_info)
+
+ ###### SPEAKER EMBEDDER
+ # TODO: fix!
+ smodel = SpeakerEncoder(
+ weights_fpath=spk_emb_ckpt_path,
+ device=device,
+ eval=True,
+ verbose=False,
+ )
+ return model.eval(), tokenizer, smodel
+
+
+def build_model(
+ *,
+ precision: torch.dtype,
+ checkpoint_path: Path = Path(""),
+ spk_emb_ckpt_path: Path = Path(""),
+ compile_prefill: bool = False,
+ compile: bool = True,
+ device: str = "cuda",
+):
+ assert checkpoint_path.is_file(), checkpoint_path
+
+ print(f"Using device={device}")
+
+ print("Loading model ...")
+ t0 = time.time()
+ model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision)
+
+ device_sync(device=device) # MKG
+ print(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+ torch.manual_seed(1234)
+ model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
+
+ with torch.device(device):
+ model.setup_spk_cond_mask()
+ model.setup_caches(max_batch_size=2, max_seq_length=model.config.block_size)
+
+ if compile:
+ print("Compiling...Can take up to 2 mins.")
+ global decode_one_token, prefill
+ decode_one_token = torch.compile(
+ decode_one_token,
+ mode="max-autotune",
+ fullgraph=True,
+ )
+
+ if compile_prefill:
+ prefill = torch.compile(
+ prefill,
+ fullgraph=True,
+ dynamic=True,
+ )
+
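+    # Warm-up generation: run one short synthetic prompt so torch.compile tracing and
+    # autotuning happen here rather than on the first real request.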
+ encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device)
+ spk_emb = torch.randn((1, 256), device=device, dtype=precision)
+
+ device_sync(device=device) # MKG
+ t0 = time.perf_counter()
+ y = generate(
+ model,
+ encoded,
+ spk_emb,
+ max_new_tokens=200,
+ callback=lambda x: x,
+ temperature=torch.tensor(1.0, device=device, dtype=precision),
+ top_k=None,
+ top_p=torch.tensor(0.95, device=device, dtype=precision),
+ guidance_scale=torch.tensor(3.0, device=device, dtype=precision),
+ end_of_audio_token=9999, # don't end early for compilation stage.
+ )
+
+ device_sync(device=device) # MKG
+
+ print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+
+ return model, tokenizer, smodel, model_size
+
+
+def main(
+ *,
+ model,
+ tokenizer,
+ model_size,
+ prompt: str,
+ guidance_scale: torch.Tensor,
+ temperature: torch.Tensor,
+ spk_emb: torch.Tensor,
+ top_k: Optional[torch.Tensor] = None,
+ top_p: Optional[torch.Tensor] = None,
+ device: str = "cuda",
+) -> list:
+ """Generates text samples based on a pre-trained Transformer model and tokenizer."""
+
+ encoded = encode_tokens(tokenizer, prompt, device=device)
+ prompt_length = encoded.size(0)
+
+ aggregate_metrics: dict = {
+ "tokens_per_sec": [],
+ }
+
+ device_sync(device=device) # MKG
+
+ if True:
+ callback = lambda x: x
+ t0 = time.perf_counter()
+
+ y = generate(
+ model,
+ encoded,
+ spk_emb,
+ callback=callback,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ guidance_scale=guidance_scale,
+ )
+
+ device_sync(device=device) # MKG
+ t = time.perf_counter() - t0
+
+ tokens_generated = y.size(0) - prompt_length
+ tokens_sec = tokens_generated / t
+ aggregate_metrics["tokens_per_sec"].append(tokens_sec)
+ print(f"Time for 1st stage LLM inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
+ print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+ # print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB\n")
+
+ return y.tolist()
diff --git a/fam/llm/fast_model.py b/fam/llm/fast_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d74bd915b8f1e0949b3ef0149225b4cae8bf346
--- /dev/null
+++ b/fam/llm/fast_model.py
@@ -0,0 +1,261 @@
+# Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification, are permitted
+# provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this list of
+# conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice, this
+# list of conditions and the following disclaimer in the documentation and/or other
+# materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+# may be used to endorse or promote products derived from this software without
+# specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
+# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+from dataclasses import dataclass
+from functools import reduce
+from math import gcd
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+
+from fam.llm.utils import get_default_dtype
+
+import logging
+
+# Adjust the logging level
+logger = logging.getLogger("torch")
+logger.setLevel(logging.ERROR)
+
+
+def find_multiple(n: int, *args: int) -> int:
+ k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
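+# Illustrative only (not part of the original module): find_multiple rounds `n` up to the
+# nearest common multiple of the given factors. With the metavoice-1B sizes used below,
+#   find_multiple(2048, 8)   == 2048  (already a multiple of 8, as in setup_caches)
+#   find_multiple(5461, 256) == 5632  (how intermediate_size is derived in ModelArgs.__post_init__)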
+
+@dataclass
+class ModelArgs:
+ block_size: int = 2048
+ vocab_size: int = 32000
+ n_layer: int = 32
+ n_head: int = 32
+ dim: int = 4096
+ speaker_emb_dim: int = 256
+    intermediate_size: Optional[int] = None
+ n_local_heads: int = -1
+ head_dim: int = 64
+ norm_eps: float = 1e-5
+ dtype: torch.dtype = torch.bfloat16
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ self.head_dim = self.dim // self.n_head
+
+ self.dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[get_default_dtype()]
+
+ @classmethod
+ def from_name(cls, name: str):
+ if name in transformer_configs:
+ return cls(**transformer_configs[name])
+ # fuzzy search
+ config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
+ assert len(config) == 1, name
+ return cls(**transformer_configs[config[0]])
+
+
+transformer_configs = {
+ "metavoice-1B": dict(
+ n_layer=24,
+ n_head=16,
+ dim=2048,
+ vocab_size=2562,
+ ),
+}
+
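+# A rough sketch of how the config resolves (illustrative, not executed by this module):
+#   args = ModelArgs.from_name("metavoice-1B")
+#   args.head_dim          -> 128   (dim 2048 // n_head 16)
+#   args.n_local_heads     -> 16    (defaults to n_head)
+#   args.intermediate_size -> 5632  (find_multiple(int(2 * 4 * 2048 / 3), 256))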
+
+class KVCache(nn.Module):
+ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
+ super().__init__()
+ cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+ def update(self, input_pos, k_val, v_val):
+ # input_pos: [S], k_val: [B, H, S, D]
+ assert input_pos.shape[0] == k_val.shape[2]
+
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:, :, input_pos] = k_val
+ v_out[:, :, input_pos] = v_val
+
+ return k_out, v_out
+
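+# Shape sketch for the cache above (illustrative; sizes assume the metavoice-1B config):
+#   cache = KVCache(max_batch_size=2, max_seq_length=2048, n_heads=16, head_dim=128, dtype=torch.bfloat16)
+#   k_new = v_new = torch.zeros(2, 16, 1, 128, dtype=torch.bfloat16)   # one decode step
+#   k_all, v_all = cache.update(torch.tensor([0]), k_new, v_new)       # both (2, 16, 2048, 128)
+# update() writes the new keys/values at `input_pos` and returns the full buffers for attention.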
+
+class Transformer(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.config = config
+
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+ self.pos_embeddings = nn.Embedding(config.block_size, config.dim)
+ self.speaker_cond_pos = nn.Linear(config.speaker_emb_dim, config.dim, bias=False)
+ self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+
+ self.mask_cache: Optional[Tensor] = None
+ self.max_batch_size = -1
+ self.max_seq_length = -1
+
+ def setup_spk_cond_mask(self):
+ self.spk_cond_mask = torch.zeros((2, 1, self.config.dim), dtype=torch.bool)
+ self.spk_cond_mask[0] = 1
+
+ def setup_caches(self, max_batch_size, max_seq_length):
+ if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+ return
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_length = find_multiple(max_seq_length, 8)
+ self.max_seq_length = max_seq_length
+ self.max_batch_size = max_batch_size
+ for b in self.layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=self.config.dtype
+ )
+
+ self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
+
+ def forward(self, idx: Tensor, spk_emb: Tensor, input_pos: Tensor) -> Tensor:
+ mask = self.causal_mask[None, None, input_pos]
+ x = (
+ self.tok_embeddings(idx)
+ + self.pos_embeddings(input_pos)
+ # masking for speaker condition free guidance
+ + self.speaker_cond_pos(spk_emb) * self.spk_cond_mask
+ )
+
+ for i, layer in enumerate(self.layers):
+ x = layer(x, input_pos, mask)
+ x = self.norm(x)
+ logits = self.output(x)
+ return logits
+
+ @classmethod
+ def from_name(cls, name: str):
+ return cls(ModelArgs.from_name(name))
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.attention = Attention(config)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+ self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+ def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
+ h = x + self.attention(self.attention_norm(x), mask, input_pos)
+ out = h + self.feed_forward(self.ffn_norm(h))
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
+ self.kv_cache = None
+
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+
+ def forward(
+ self,
+ x: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_local_heads * self.head_dim
+ q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+ y = self.wo(y)
+ return y
+
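+# Illustrative note on the fused projection above (assuming the metavoice-1B config):
+# wqkv maps dim=2048 -> (n_head + 2 * n_local_heads) * head_dim = (16 + 32) * 128 = 6144,
+# which forward() splits back into q (2048), k (2048) and v (2048) before attention.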
+
+class SwiGLU(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return F.silu(self.w1(x)) * self.w3(x)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.swiglu = SwiGLU(config)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(self.swiglu(x))
+
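+# The two modules above implement the usual SwiGLU feed-forward, i.e. (illustratively)
+#   FeedForward(x) == w2(silu(w1(x)) * w3(x))
+# with w1/w3 projecting dim -> intermediate_size and w2 projecting intermediate_size -> dim.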
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
diff --git a/fam/llm/inference.py b/fam/llm/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..13ddfac85c05840a04a4c4d42837511d647e7f51
--- /dev/null
+++ b/fam/llm/inference.py
@@ -0,0 +1,714 @@
+"""
+Command: python fam/llm/inference.py --spk_cond_path="assets/bria.mp3" --text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
+"""
+
+import dataclasses
+import hashlib
+import json
+import os
+import pathlib
+import shutil
+import subprocess
+import tempfile
+import time
+from contextlib import nullcontext
+from dataclasses import dataclass
+from typing import List, Literal, Optional, Tuple, Type, Union
+
+import torch
+import tqdm
+import tqdm.contrib.concurrent
+import tyro
+from huggingface_hub import snapshot_download
+
+from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook, TiltedEncodec
+from fam.llm.decoders import Decoder, EncodecDecoder
+from fam.llm.enhancers import BaseEnhancer, get_enhancer
+from fam.llm.model import GPT, GPTConfig
+from fam.llm.utils import check_audio_file, get_default_dtype, normalize_text
+from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
+from fam.quantiser.text.tokenise import TrainedBPETokeniser
+
+
+@dataclass
+class InferenceConfig:
+ ckpt_path: str # path to checkpoint
+ output_dir: str
+ num_samples: int = 10 # number of samples to draw
+ seed: int = 1337 # random seed
+ device: str = "cuda"
+ dtype: str = "bfloat16"
+ compile: bool = False
+ init_from: str = "resume" # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
+
+ def __str__(self):
+ field_strs = []
+ for field in dataclasses.fields(self):
+ value = getattr(self, field.name)
+ field_strs.append(f" {field.name}: {value}")
+
+ return "InferenceConfig:\n" + "\n".join(field_strs)
+
+
+class Model:
+ def __init__(
+ self,
+ config: InferenceConfig,
+ tokenizer_cls: Type[TrainedBPETokeniser],
+ decoder_cls: Type[Decoder],
+ data_adapter_fn,
+ use_kv_cache: Optional[Literal["vanilla"]] = None,
+ ):
+ # TODO: disentangle the encodec stuff and numbers etc with rest of this code (esp at encoder-only / second stage model inference)
+ # TODO: remove magic number
+ self._encodec_codes_pad_token = 1024
+ self._num_encodec_codebooks = 8
+ self.config = config
+ self.use_kv_cache = use_kv_cache
+
+ torch.manual_seed(config.seed)
+ torch.cuda.manual_seed(config.seed)
+ torch.backends.cuda.matmul.allow_tf32 = True if config.dtype != "float32" else False # allow tf32 on matmul
+ torch.backends.cudnn.allow_tf32 = True if config.dtype != "float32" else False # allow tf32 on cudnn
+ device_type = "cuda" if "cuda" in config.device else "cpu" # for later use in torch.autocast
+ self.ptdtype = {
+ "float32": torch.float32,
+ "tfloat32": torch.float32,
+ "bfloat16": torch.bfloat16,
+ "float16": torch.float16,
+ }[config.dtype]
+ self._ctx = (
+ nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=self.ptdtype)
+ )
+
+ self.use_bpe_tokenizer = False
+ self.load_meta = None
+ self.speaker_cond = None
+ self.meta = None
+ self.model = None
+ self.checkpoint_config = None
+ self.vocab_sizes = None
+ self.smodel = None
+
+ self._init_model()
+
+ self.tokenizer = tokenizer_cls(**self.meta["tokenizer"])
+ self.decoder = decoder_cls(
+ tokeniser_decode_fn=self.tokenizer.decode,
+ output_dir=self.config.output_dir,
+ data_adapter_fn=data_adapter_fn,
+ )
+
+ def _init_model(self):
+ if self.config.init_from == "resume":
+ # init from a model saved in a specific directory
+ checkpoint = torch.load(self.config.ckpt_path, map_location=self.config.device)
+ self.vocab_sizes = checkpoint["model_args"]["vocab_sizes"]
+
+ self.load_meta = False
+ self.speaker_cond = False
+
+ if "config" in checkpoint:
+ self.checkpoint_config = checkpoint["config"]
+
+ self.meta = checkpoint["meta"]
+                self.load_meta = True
+
+            if self.load_meta:
+ self.use_bpe_tokenizer = "stoi" not in self.meta or "itos" not in self.meta
+ self.speaker_cond = self.meta.get("speaker_cond")
+
+ if self.speaker_cond:
+ speaker_emb_size = self.meta["speaker_emb_size"]
+
+ model_args = checkpoint["model_args"]
+ if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
+ self._encodec_ctx_window = model_args["block_size"]
+
+ gptconf = GPTConfig(**model_args)
+
+ # TODO: rename `speaker_emb_dim` to `speaker_emb_size`.
+ self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size if self.speaker_cond else None)
+ state_dict = checkpoint["model"]
+ unwanted_prefix = "_orig_mod."
+ for k, v in list(state_dict.items()):
+ if k.startswith(unwanted_prefix):
+ state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
+ self.model.load_state_dict(state_dict)
+
+ # model
+ self.model.eval()
+ self.model.to(self.config.device)
+
+ if self.config.compile:
+ from einops._torch_specific import allow_ops_in_compiled_graph
+
+ allow_ops_in_compiled_graph()
+ self.model = torch.compile(self.model) # type: ignore
+
+ if self.use_kv_cache is not None:
+ if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
+ raise Exception("kv_cache not supported for non-causal models!")
+
+ if self.use_kv_cache == "vanilla":
+ self.model.enable_kv_cache()
+ else:
+ raise NotImplementedError(f"kv_cache type {self.use_kv_cache} not implemented!")
+
+ def causal_sample(
+ self,
+ *,
+ texts: list[str],
+ batch_size: int,
+ max_new_tokens: int,
+ temperature: Optional[float],
+ top_k: Optional[int],
+ top_p: Optional[float],
+ speaker_embs: Optional[torch.Tensor] = None,
+ guidance_scale: Optional[float] = None,
+ ) -> list[torch.Tensor]:
+ """
+ Returns list of torch.Tensors of tokens. Each tensor is of shape (1, c, t) where c is the number of codebooks.
+        Any flattening / interleaving / tilting gets reversed before the output is returned.
+ """
+ if speaker_embs is not None:
+ assert len(texts) == len(speaker_embs)
+
+ encoded_texts = [self.tokenizer.encode(text) for text in texts]
+
+ ## create multiple hierarchies and get seq_lens
+ seq_lens = []
+ xs = []
+ for i, encoded_text in enumerate(encoded_texts):
+ encoded_text = torch.tensor([encoded_text], dtype=torch.long, device=self.config.device)
+ # TODO: remove magic number
+ xs.append(
+ torch.cat(
+ # [1st hierarchy of text, *remaining hierarchies of padded tokens]
+ # TODO: self.vocab_sizes should be from the model config?
+ [encoded_text, *[torch.ones_like(encoded_text) * 1024] * (len(self.vocab_sizes) - 1)],
+ dim=0,
+ ).unsqueeze(0)
+ ) # b x [(b=1, c, t)]
+ seq_lens.append(xs[-1].shape[-1])
+ max_len = max(seq_lens)
+ assert len(xs) == len(seq_lens)
+
+ ## equalise the shapes in the batch. we can use torch.zeros as tokens > seq_lens will be masked out.
+ x = torch.zeros((len(encoded_texts), xs[0].shape[1], max_len), dtype=torch.long, device=self.config.device)
+ for i, _xs in enumerate(xs):
+ assert _xs.shape[-1] == seq_lens[i]
+ x[i, :, : seq_lens[i]] = _xs
+
+ ## check that the input is correct
+ for i in range(x.shape[0]):
+ assert x[i, 0, : seq_lens[i]].tolist() == encoded_texts[i]
+
+ # TODO: remove magic number
+ if x.shape[1] > 1:
+ assert set(x[i, 1, : seq_lens[i]].tolist()) == set([1024])
+
+ assert x.shape[0] == speaker_embs.shape[0] if speaker_embs is not None else True
+
+ if self.speaker_cond is False:
+ speaker_embs = None
+
+ # run sampling loop
+ with torch.no_grad():
+ with self._ctx: # type: ignore
+ to_return = []
+ for k in range(self.config.num_samples):
+ assert seq_lens is not None
+ assert batch_size is not None
+
+ if max(seq_lens) + max_new_tokens >= self.model.config.block_size:
+ raise Exception(
+ f"max_new_tokens {max_new_tokens} too large! Choose {self.model.config.block_size - max(seq_lens) - 1} instead."
+ )
+
+ y = self.model.generate(
+ x,
+ max_new_tokens,
+ seq_lens=seq_lens,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ speaker_embs=speaker_embs,
+ batch_size=batch_size,
+ guidance_scale=guidance_scale,
+ dtype=self.ptdtype,
+ end_of_audio_token=self.tokenizer.offset - 1,
+ end_of_text_token=self.tokenizer.eot_token,
+ )
+ for i in range(len(y)):
+ to_return.append(self.decoder.decode(tokens=y[i].tolist(), causal=True))
+
+ return to_return
+
+ def non_causal_sample(
+ self,
+ *,
+ texts: list[str],
+ encodec_tokens: list[torch.Tensor],
+ batch_size: int,
+ top_k: Optional[int],
+ temperature: Optional[float],
+ speaker_embs: Optional[torch.Tensor] = None,
+ ) -> list[str]:
+ """
+ Returns paths to saved audio files.
+ """
+ if speaker_embs is not None:
+ assert len(texts) == len(speaker_embs)
+
+ encoded_texts = [self.tokenizer.encode(text) for text in texts]
+
+ # setup input
+ # TODO: same code is used during data prep. refactor
+ padded_hierarchies_inputs = []
+ for encoded_text, encodec_token in zip(encoded_texts, encodec_tokens):
+ x = torch.tensor(encoded_text, dtype=torch.long, device=self.config.device)[
+ None, None, ...
+ ] # (b=1, c=1, t)
+
+            # TODO: this should only happen if the decoder is an EncodecDecoder?
+ assert encodec_token.shape[0] == 1
+ encodec_token = encodec_token[0].tolist() # (b=1, c, t) -> (c, t)
+ assert len(encodec_token) >= 1 and len(encodec_token) <= self._num_encodec_codebooks
+
+ ## setup hierarchies of tokens
+ # TODO: refactor and merge with code in processing.py
+ text_tokens = encoded_text # (t,)
+
+ hierarchies_in = []
+ hierarchies_in.append(text_tokens + encodec_token[0] + [self._encodec_codes_pad_token])
+ hierarchies_in.append(
+ [self._encodec_codes_pad_token] * len(text_tokens) + encodec_token[1] + [self._encodec_codes_pad_token]
+ )
+
+ ## adding padding / cutting to the right size as needed
+ # TODO: refactor and merge with code in processing.py
+ padded_hierarchies_input = []
+ for _, t_hierarchy in enumerate(hierarchies_in):
+ assert len(t_hierarchy) == len(hierarchies_in[0])
+ if len(t_hierarchy) < self._encodec_ctx_window:
+ padded_hierarchies_input.append(
+ t_hierarchy + [self._encodec_codes_pad_token] * (self._encodec_ctx_window - len(t_hierarchy))
+ )
+ elif len(t_hierarchy) > self._encodec_ctx_window:
+ padded_hierarchies_input.append(t_hierarchy[: self._encodec_ctx_window])
+ else:
+ padded_hierarchies_input.append(t_hierarchy)
+
+ padded_hierarchies_inputs.append(padded_hierarchies_input)
+
+ ## check that the input is correct
+ in_x = torch.tensor(padded_hierarchies_inputs, dtype=torch.long, device=self.config.device)
+ assert in_x.shape[0] == speaker_embs.shape[0] if speaker_embs is not None else True
+
+ if self.speaker_cond is False:
+ speaker_embs = None
+
+ # run sampling loop
+ with torch.no_grad():
+ with self._ctx: # type: ignore
+ to_return = []
+ for k in range(self.config.num_samples):
+ y = self.model.generate(
+ in_x,
+ None,
+ temperature=temperature,
+ top_k=top_k,
+ # TODO: handle separate top_p for this model explicitly
+ top_p=None,
+ speaker_embs=speaker_embs,
+ batch_size=batch_size,
+ guidance_scale=None,
+ )
+
+ b_tokens = torch.cat([in_x, y], dim=1)
+ for tokens in b_tokens:
+ try:
+ to_return.append(self.decoder.decode(tokens=tokens.tolist(), causal=False))
+ except Exception as e:
+ print("failed to run MBD.")
+ print(f"reason: {str(e)}")
+ to_return.append(None)
+
+ return to_return
+
+ def __call__(
+ self,
+ *,
+ texts: list[str],
+ batch_size: int,
+ max_new_tokens: Optional[int],
+ top_k: Optional[int],
+ top_p: Optional[float],
+ temperature: Optional[float],
+ encodec_tokens: Optional[list[torch.Tensor]] = None,
+ speaker_embs: Optional[torch.Tensor] = None,
+ guidance_scale: Optional[float] = None,
+ ):
+ if self.checkpoint_config.get("causal", True):
+ return self.causal_sample(
+ texts=texts,
+ batch_size=batch_size,
+ speaker_embs=speaker_embs,
+ guidance_scale=guidance_scale,
+ max_new_tokens=max_new_tokens,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ )
+ else:
+ assert encodec_tokens is not None
+ assert guidance_scale is None
+ assert max_new_tokens is None
+ assert top_p is None
+
+ return self.non_causal_sample(
+ texts=texts,
+ encodec_tokens=encodec_tokens,
+ batch_size=batch_size,
+ speaker_embs=speaker_embs,
+ top_k=top_k,
+ temperature=temperature,
+ )
+
+
+def save_result_metadata(wav_path, ref_path, text, first_stage_ckpt_path, second_stage_ckpt_path):
+ if first_stage_ckpt_path is None or second_stage_ckpt_path is None:
+ return
+ json.dump(
+ {
+ "speaker": ref_path,
+ "text": text,
+ },
+ pathlib.Path(str(wav_path) + ".json").open("w"),
+ )
+
+
+def get_cached_file(file_or_uri: str):
+ """
+    If it's a URL (http/https), download it to a local cache file and return that path.
+ Otherwise return the path as is.
+ """
+ is_uri = file_or_uri.startswith("http")
+
+ cache_path = None
+ if is_uri:
+ ext = pathlib.Path(file_or_uri).suffix
+ # hash the file path to get the cache name
+ _cache_name = "audio_" + hashlib.md5(file_or_uri.encode("utf-8")).hexdigest() + ext
+
+ os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
+ cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")
+
+ if not os.path.exists(cache_path):
+ command = f"curl -o {cache_path} {file_or_uri}"
+ subprocess.run(command, shell=True, check=True)
+ else:
+ if os.path.exists(file_or_uri):
+ cache_path = file_or_uri
+ else:
+ raise FileNotFoundError(f"File {file_or_uri} not found!")
+ return cache_path
+
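+# Illustrative usage (hypothetical URL; not part of the original code):
+#   path = get_cached_file("https://example.com/reference.mp3")
+#   # the first call downloads via curl to ~/.cache/fam/audio_<md5-of-url>.mp3; later calls reuse it
+#   path = get_cached_file("assets/bria.mp3")  # existing local files are returned unchanged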
+
+def get_cached_embedding(local_file_path: str, spkemb_model):
+ if not os.path.exists(local_file_path):
+ raise FileNotFoundError(f"File {local_file_path} not found!")
+
+ # hash the file path to get the cache name
+ _cache_name = "embedding_" + hashlib.md5(local_file_path.encode("utf-8")).hexdigest() + ".pt"
+
+ os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
+ cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")
+
+ if not os.path.exists(cache_path):
+ spk_emb = spkemb_model.embed_utterance_from_file(local_file_path, numpy=False).unsqueeze(0) # (b=1, c)
+ torch.save(spk_emb, cache_path)
+ else:
+ spk_emb = torch.load(cache_path)
+
+ return spk_emb
+
+
+def _sample_utterance_batch(
+ texts: list[str],
+ spk_cond_paths: list[Optional[str]],
+ spkemb_model,
+ first_stage_model,
+ second_stage_model,
+ enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
+ first_stage_ckpt_path: str,
+ second_stage_ckpt_path: str,
+ guidance_scale: Optional[Tuple[float, float]],
+ max_new_tokens: int,
+ top_k: Optional[int],
+ top_p: Optional[float],
+ temperature: Optional[float],
+ batch_size: int = 128,
+) -> List[str]:
+
+ speaker_embs = []
+ refs = spk_cond_paths.copy()
+
+ # multithreaded loop to cache all the files
+ spk_cond_paths = tqdm.contrib.concurrent.thread_map(
+ get_cached_file, spk_cond_paths, desc="getting cached speaker ref files"
+ )
+
+ for i, (text, spk_cond_path) in tqdm.tqdm(
+ enumerate(zip(texts, spk_cond_paths)), total=len(texts), desc="calculating speaker embeddings"
+ ):
+ texts[i] = normalize_text(text)
+ speaker_embs.append(get_cached_embedding(spk_cond_path, spkemb_model) if spk_cond_path else None)
+
+ b_speaker_embs = torch.cat(speaker_embs, dim=0)
+
+ start = time.time()
+ b_tokens = first_stage_model(
+ texts=texts,
+ speaker_embs=b_speaker_embs,
+ batch_size=batch_size,
+ guidance_scale=guidance_scale,
+ top_p=top_p,
+ top_k=top_k,
+ temperature=temperature,
+ max_new_tokens=max_new_tokens,
+ )
+
+ # TODO: set batch size for second stage model!
+ wav_files = second_stage_model(
+ texts=texts,
+ encodec_tokens=b_tokens,
+ speaker_embs=b_speaker_embs,
+ batch_size=batch_size,
+ guidance_scale=None,
+ top_p=None,
+ top_k=top_k,
+ temperature=temperature,
+ max_new_tokens=None,
+ )
+
+ for text, tokens, speaker_embs, ref_name, wav_file in zip(texts, b_tokens, b_speaker_embs, refs, wav_files):
+ if wav_file is None:
+ continue
+
+ with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
+ if enhancer is not None:
+ enhancer = get_enhancer(enhancer) if isinstance(enhancer, str) else enhancer
+ enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
+ # copy enhanced_tmp.name back to wav_file
+ print(f"copying enhanced file from {enhanced_tmp.name} to {str(wav_file) + '.wav'}.")
+ shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
+
+ save_result_metadata(
+ wav_file,
+ ref_name,
+ text,
+ first_stage_ckpt_path,
+ second_stage_ckpt_path,
+ )
+
+ print(f"time_to_synth_s: {time.time() - start}")
+ return [str(w) + ".wav" if not str(w).endswith(".wav") else str(w) for w in wav_files]
+
+
+def sample_utterance(
+ text: str,
+ spk_cond_path: Optional[str],
+ spkemb_model,
+ first_stage_model,
+ second_stage_model,
+ enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
+ first_stage_ckpt_path: str,
+ second_stage_ckpt_path: str,
+ guidance_scale: Optional[Tuple[float, float]],
+ max_new_tokens: int,
+ top_k: Optional[int],
+ top_p: Optional[float],
+ temperature: Optional[float],
+) -> str:
+ # NOTE: supports max. 220 characters atm.
+ # Long form synthesis coming soon...
+ MAX_CHARS = 220
+ if len(text) > MAX_CHARS:
+ print(
+ f"\n***WARNING: Max {MAX_CHARS} characters supported. Provided: {len(text)}. Truncating and generating speech...Can lead to unpredictable speech at the end.***"
+ )
+
+ return _sample_utterance_batch(
+ texts=[text],
+ spk_cond_paths=[spk_cond_path],
+ spkemb_model=spkemb_model,
+ first_stage_model=first_stage_model,
+ second_stage_model=second_stage_model,
+ enhancer=enhancer,
+ first_stage_ckpt_path=first_stage_ckpt_path,
+ second_stage_ckpt_path=second_stage_ckpt_path,
+ batch_size=1,
+ guidance_scale=guidance_scale,
+ max_new_tokens=max_new_tokens,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ )[0]
+
+
+def build_models(config_first_stage, config_second_stage, model_dir, device, use_kv_cache):
+ smodel = SpeakerEncoder(
+ weights_fpath=os.path.join(model_dir, "speaker_encoder.pt"), device=device, eval=True, verbose=False
+ )
+ data_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
+ llm_first_stage = Model(
+ config_first_stage,
+ TrainedBPETokeniser,
+ EncodecDecoder,
+ data_adapter_fn=data_adapter.decode,
+ use_kv_cache=use_kv_cache,
+ )
+ data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
+ llm_second_stage = Model(
+ config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
+ )
+ return smodel, llm_first_stage, llm_second_stage
+
+
+def get_first_stage_path(model_dir: str):
+ """Absolute path to checkpoint for the first stage model."""
+ return os.path.join(os.path.expanduser(model_dir), "first_stage.pt")
+
+
+def get_second_stage_path(model_dir: str):
+ """Absolute path to checkpoint for the second stage model."""
+ return os.path.join(os.path.expanduser(model_dir), "second_stage.pt")
+
+
+@dataclass
+class SamplingControllerConfig:
+ """
+ Sample from a trained model.
+ """
+
+ spk_cond_path: str
+ """Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3"""
+
+ huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1"
+ """Absolute path to the model directory."""
+
+ text: str = (
+ "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model by MetaVoice."
+ )
+ """Text to synthesise."""
+
+ num_samples: int = 1
+ """Number of samples to generate from each model."""
+
+ max_new_tokens: int = 864
+ """Maximum number of new tokens to generate from the first stage model."""
+
+ temperature: float = 1.0
+ """Temperature for sampling applied to both models."""
+
+ top_k: Optional[int] = None
+ """Top k for sampling applied to both models."""
+
+ top_p: Optional[float] = 0.95
+ """Top p for sampling applied to first-stage model."""
+
+ seed: int = 1337
+ """Random seed for sampling."""
+
+ device: Literal["cuda", "cpu"] = "cuda"
+ """Device to use for sampling."""
+
+ dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = get_default_dtype()
+ """Data type to use for sampling."""
+
+ compile: bool = False
+ """Whether to compile the model using PyTorch 2.0."""
+
+ enhancer: Optional[Literal["df"]] = "df"
+ """Enhancer to use for post-processing."""
+
+ init_from: str = "resume"
+ """Either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')."""
+
+ use_kv_cache: Optional[Literal["vanilla"]] = "vanilla"
+ """Type of kv caching to use for inference: 1) [none] no kv caching, 2) [vanilla] use torch attention with hand implemented kv-cache."""
+
+ output_dir: str = "samples/"
+ """Relative path to output directory"""
+
+ guidance_scale: Optional[Tuple[float, float]] = (3.0, 1.0)
+ """Guidance scale for sampling: (speaker conditioning guidance_scale, prompt conditioning guidance scale)."""
+
+ batch_size: int = 128
+ """Batch size to use for sampling. Note that the batch size gets doubled when guidance is used. For H100, and 1B model,
+ 1 w/ guidance and 1 w/o guidance work well (without kv-caching). With kv-caching, 128 (w/o guidance) and
+ 64 (w/ guidance) works well."""
+
+
+if __name__ == "__main__":
+ # TODO: add support for batch sampling via CLI. Function has been implemented above.
+ sampling_config = tyro.cli(SamplingControllerConfig, use_underscores=True)
+
+ check_audio_file(sampling_config.spk_cond_path)
+
+ model_dir = snapshot_download(repo_id=sampling_config.huggingface_repo_id)
+ first_stage_ckpt_path = get_first_stage_path(model_dir)
+ second_stage_ckpt_path = get_second_stage_path(model_dir)
+
+ config_first_stage = InferenceConfig(
+ ckpt_path=first_stage_ckpt_path,
+ num_samples=sampling_config.num_samples,
+ seed=sampling_config.seed,
+ device=sampling_config.device,
+ dtype=sampling_config.dtype,
+ compile=sampling_config.compile,
+ init_from=sampling_config.init_from,
+ output_dir=sampling_config.output_dir,
+ )
+
+ config_second_stage = InferenceConfig(
+ ckpt_path=second_stage_ckpt_path,
+ num_samples=sampling_config.num_samples,
+ seed=sampling_config.seed,
+ device=sampling_config.device,
+ dtype=sampling_config.dtype,
+ compile=sampling_config.compile,
+ init_from=sampling_config.init_from,
+ output_dir=sampling_config.output_dir,
+ )
+
+ sampling_config.max_new_tokens *= (
+ 2 # deal with max_new_tokens for flattened interleaving! (should scale with num_codebooks?)
+ )
+
+ # define models
+ smodel, llm_first_stage, llm_second_stage = build_models(
+ config_first_stage,
+ config_second_stage,
+ model_dir=model_dir,
+ device=sampling_config.device,
+ use_kv_cache=sampling_config.use_kv_cache,
+ )
+
+ sample_utterance(
+ sampling_config.text,
+ os.path.expanduser(sampling_config.spk_cond_path),
+ smodel,
+ llm_first_stage,
+ llm_second_stage,
+ sampling_config.enhancer,
+ first_stage_ckpt_path,
+ second_stage_ckpt_path,
+ sampling_config.guidance_scale,
+ max_new_tokens=sampling_config.max_new_tokens,
+ top_k=sampling_config.top_k,
+ top_p=sampling_config.top_p,
+ temperature=sampling_config.temperature,
+ )
diff --git a/fam/llm/layers/__init__.py b/fam/llm/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..47d40436f25d16014175282f6a4ae852e2902ff4
--- /dev/null
+++ b/fam/llm/layers/__init__.py
@@ -0,0 +1,3 @@
+from fam.llm.layers.attn import SelfAttention
+from fam.llm.layers.combined import Block
+from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm, SwiGLU
diff --git a/fam/llm/layers/__pycache__/__init__.cpython-310.pyc b/fam/llm/layers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1dbf6426e196d7595ad8af4bf2137eff639ad47
Binary files /dev/null and b/fam/llm/layers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fam/llm/layers/__pycache__/__init__.cpython-39.pyc b/fam/llm/layers/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4adf5e42f0b05c5391383d7cee37f4d1be8d4c8c
Binary files /dev/null and b/fam/llm/layers/__pycache__/__init__.cpython-39.pyc differ
diff --git a/fam/llm/layers/__pycache__/attn.cpython-310.pyc b/fam/llm/layers/__pycache__/attn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..138e2119411afcb626387d358a73b4bd813043b1
Binary files /dev/null and b/fam/llm/layers/__pycache__/attn.cpython-310.pyc differ
diff --git a/fam/llm/layers/__pycache__/attn.cpython-39.pyc b/fam/llm/layers/__pycache__/attn.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebb5893a7fd6a4fc63b8e9150058273adb620aae
Binary files /dev/null and b/fam/llm/layers/__pycache__/attn.cpython-39.pyc differ
diff --git a/fam/llm/layers/__pycache__/combined.cpython-310.pyc b/fam/llm/layers/__pycache__/combined.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aff69e155617c90ef47bf8aa5998eb783b2c6f81
Binary files /dev/null and b/fam/llm/layers/__pycache__/combined.cpython-310.pyc differ
diff --git a/fam/llm/layers/__pycache__/combined.cpython-39.pyc b/fam/llm/layers/__pycache__/combined.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cace4016cf1a87dd28e44aadc5de128b05b094e
Binary files /dev/null and b/fam/llm/layers/__pycache__/combined.cpython-39.pyc differ
diff --git a/fam/llm/layers/__pycache__/layers.cpython-310.pyc b/fam/llm/layers/__pycache__/layers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db5bc571a56233ad460e98836bc2f6834802f292
Binary files /dev/null and b/fam/llm/layers/__pycache__/layers.cpython-310.pyc differ
diff --git a/fam/llm/layers/__pycache__/layers.cpython-39.pyc b/fam/llm/layers/__pycache__/layers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4472598fb59a2bd7a61e0a2c48ff24565145c1eb
Binary files /dev/null and b/fam/llm/layers/__pycache__/layers.cpython-39.pyc differ
diff --git a/fam/llm/layers/attn.py b/fam/llm/layers/attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1053eecffd1c2d637c3da2affc8c5035a515d03
--- /dev/null
+++ b/fam/llm/layers/attn.py
@@ -0,0 +1,185 @@
+import warnings
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, config):
+ """
+ Initializes the SelfAttention module.
+
+ Args:
+ config: An object containing the configuration parameters for the SelfAttention module.
+ """
+ super().__init__()
+ self._validate_config(config)
+ self._initialize_parameters(config)
+
+ def empty_kv_cache(self, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype):
+ """
+ Empties the key-value cache.
+
+ Args:
+ batch_size: The batch size.
+ kv_cache_maxlen: The maximum length of the key-value cache.
+ dtype: The data type of the cache.
+
+ Raises:
+ Exception: If trying to empty the KV cache when it is disabled.
+ """
+ if self.kv_cache_enabled is False:
+ raise Exception("Trying to empty KV cache when it is disabled")
+
+ # register so that the cache moves devices along with the module
+ # TODO: get rid of re-allocation.
+ self.register_buffer(
+ "kv_cache",
+ torch.zeros(
+ 2,
+ batch_size,
+ kv_cache_maxlen,
+ self.n_head,
+ self.n_embd // self.n_head,
+ dtype=dtype,
+ device=self.c_attn.weight.device,
+ ),
+ persistent=False,
+ )
+
+ self.kv_cache_first_empty_index = 0
+
+ def _initialize_parameters(self, config):
+ """
+ Initializes the parameters of the SelfAttention module.
+
+ Args:
+ config: An object containing the configuration parameters for the SelfAttention module.
+ """
+ # key, query, value projections for all heads, but in a batch
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+
+ # output projection
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+
+ # regularization
+ self.resid_dropout = nn.Dropout(config.dropout)
+ self.n_head = config.n_head
+ self.n_embd = config.n_embd
+ self.dropout = config.dropout
+ self.causal = config.causal
+ self.attn_kernel_type = config.attn_kernel_type
+ self.attn_dropout = nn.Dropout(config.dropout)
+
+ self.kv_cache_enabled = False
+
+ def _validate_config(self, config):
+ """
+ Validates the configuration parameters.
+
+ Args:
+ config: An object containing the configuration parameters for the SelfAttention module.
+
+ Raises:
+ AssertionError: If the embedding dimension is not divisible by the number of heads.
+ """
+ assert config.n_embd % config.n_head == 0, "Embedding dimension must be divisible by number of heads"
+
+ def _update_kv_cache(self, q, k, v):
+ """
+ Updates the key-value cache.
+
+ Args:
+ q: The query tensor.
+ k: The key tensor.
+ v: The value tensor.
+
+ Returns:
+ The updated key and value tensors.
+
+ Raises:
+ AssertionError: If the dimensions of the query, key, and value tensors are not compatible.
+ """
+ q_time, k_time, v_time = q.shape[1], k.shape[1], v.shape[1]
+
+ if self.kv_cache_first_empty_index == 0:
+ assert q_time == k_time and q_time == v_time
+ else:
+ assert (
+ q_time == 1
+ ), f"Only one query at a time is supported, but got q_time={q_time} for kv_cache_first_empty_index={self.kv_cache_first_empty_index}"
+
+ self.kv_cache[0, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = k
+ self.kv_cache[1, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = v
+ self.kv_cache_first_empty_index += q_time
+
+ k = self.kv_cache[0, :, : self.kv_cache_first_empty_index]
+ v = self.kv_cache[1, :, : self.kv_cache_first_empty_index]
+
+ return k, v
+
+ def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs attention using the torch.nn.functional.scaled_dot_product_attention function.
+
+ Args:
+ c_x: The input tensor.
+
+ Returns:
+ The output tensor.
+ """
+ q, k, v = c_x.split(1, dim=2) # q, k, v of shape (B, T, 1, nh, hs)
+ q = q.squeeze(2) # (B, T, nh, hs)
+ k = k.squeeze(2) # (B, T, nh, hs)
+ v = v.squeeze(2) # (B, T, nh, hs)
+
+ # if kv-caching and causal, for the "prefill" stage, we need to use a causal mask, and
+ # use no mask for the "one time step" parts.
+ # calculate this before updating kv_caching so we have the right value for kv_cache_first_empty_index
+ is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0)
+
+ if self.kv_cache_enabled:
+ k, v = self._update_kv_cache(q, k, v)
+
+ q = q.transpose(1, 2) # (B, nh, T, hs)
+ k = k.transpose(1, 2) # (B, nh, T, hs)
+ v = v.transpose(1, 2) # (B, nh, T, hs)
+ y = torch.nn.functional.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=None,
+ dropout_p=self.dropout if self.training else 0,
+ is_causal=is_causal_attn_mask,
+ ).transpose(
+ 1, 2
+ ) # (B, nh, T, hs) -> (B, T, nh, hs)
+
+ return y
+
+ def forward(self, x):
+ """
+ Performs the forward pass of the SelfAttention module.
+
+ Args:
+ x: The input tensor.
+
+ Returns:
+ The output tensor.
+ """
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head) # (B, T, 3, nh, hs)
+
+ # causal self-attention;
+ if self.attn_kernel_type == "torch_attn":
+ y = self._torch_attn(c_x)
+ else:
+ raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}")
+
+ y = y.contiguous().view(B, T, C) # re-assemble all head outputs side by side: (B, T, nh, hs) -> (B, T, hs * nh)
+ # output projection
+ y = self.resid_dropout(self.c_proj(y))
+ return y
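+# Rough shape walk-through of the attention path above (illustrative only):
+#   c_attn: (B, T, n_embd) -> (B, T, 3 * n_embd), viewed as (B, T, 3, nh, hs)
+#   _torch_attn splits that into q, k, v of shape (B, T, nh, hs); with kv-caching enabled,
+#   the prefill pass uses a causal mask, while subsequent single-token steps (q_time == 1)
+#   attend over the cached keys/values with no mask.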
diff --git a/fam/llm/layers/combined.py b/fam/llm/layers/combined.py
new file mode 100644
index 0000000000000000000000000000000000000000..28991285f609e0d9b09f89847ad2822a56528a8b
--- /dev/null
+++ b/fam/llm/layers/combined.py
@@ -0,0 +1,52 @@
+import torch.nn as nn
+
+from fam.llm.layers.attn import SelfAttention
+from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm
+
+
+class Block(nn.Module):
+ """
+ Block class represents a single block in the model.
+
+ Args:
+ config (object): Configuration object containing parameters for the block.
+
+ Attributes:
+ ln_1 (object): Layer normalization for the attention layer.
+ ln_2 (object): Layer normalization for the feed-forward layer.
+ attn (object): Self-attention layer.
+ mlp (object): Multi-layer perceptron layer.
+
+ Methods:
+ forward(x): Performs forward pass through the block.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ if config.norm_type == "rmsnorm":
+ if config.rmsnorm_eps is None:
+ raise Exception("RMSNorm requires rmsnorm_eps to be set")
+ self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # attn norm
+ self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) # ffn norm
+ elif config.norm_type == "layernorm":
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) # attn norm
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) # ffn norm
+ else:
+ raise Exception(f"Unknown norm type: {config.norm_type}")
+ self.attn = SelfAttention(config)
+
+ self.mlp = MLP(config)
+
+ def forward(self, x):
+ """
+ Performs forward pass through the block.
+
+ Args:
+ x (tensor): Input tensor.
+
+ Returns:
+ tensor: Output tensor after passing through the block.
+ """
+ x = x + self.attn(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
diff --git a/fam/llm/layers/layers.py b/fam/llm/layers/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..17063ad3d01630a41c7002f0c81623fa63e200d5
--- /dev/null
+++ b/fam/llm/layers/layers.py
@@ -0,0 +1,72 @@
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class LayerNorm(nn.Module):
+ """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
+
+ def __init__(self, ndim, bias):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(ndim))
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+ def forward(self, input):
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, ndim: int, eps: float):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(ndim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ return self._norm(x) * self.weight
+
+
+class SwiGLU(nn.Module):
+ def __init__(self, in_dim, out_dim, bias) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(in_dim, out_dim, bias=bias)
+ self.w3 = nn.Linear(in_dim, out_dim, bias=bias)
+
+ def forward(self, x):
+ return F.silu(self.w1(x)) * self.w3(x)
+
+
+class MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.non_linearity = config.nonlinearity_type
+ hidden_dim = 4 * config.n_embd
+ if config.nonlinearity_type == "gelu":
+ self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
+ self.gelu = nn.GELU()
+ self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
+ elif config.nonlinearity_type == "swiglu":
+ if config.swiglu_multiple_of is None:
+ raise Exception("SwiGLU requires swiglu_multiple_of to be set")
+ hidden_dim = int(2 * hidden_dim / 3)
+ hidden_dim = config.swiglu_multiple_of * math.ceil(hidden_dim / config.swiglu_multiple_of)
+ # set name to `c_proj` so that the right initialisation gets applied to it in GPT.__init__()
+ self.swiglu = SwiGLU(config.n_embd, hidden_dim, bias=config.bias)
+ self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
+ else:
+ raise Exception(f"Unknown nonlinearity type: {config.nonlinearity_type}")
+ self.dropout = nn.Dropout(config.dropout)
+
+ def forward(self, x):
+ if self.non_linearity == "gelu":
+ x = self.c_fc(x)
+ x = self.gelu(x)
+ elif self.non_linearity == "swiglu":
+ x = self.swiglu(x)
+ x = self.c_proj(x)
+ x = self.dropout(x)
+ return x
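+# Illustrative sizing for the swiglu branch above, assuming n_embd=2048 and swiglu_multiple_of=256
+# (these particular values are an assumption here, not taken from a checkpoint config):
+#   hidden_dim = 4 * 2048 = 8192 -> int(2 * 8192 / 3) = 5461 -> rounded up to 256 * 22 = 5632,
+# so SwiGLU projects 2048 -> 5632 and c_proj maps 5632 -> 2048.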
diff --git a/fam/llm/mixins/__init__.py b/fam/llm/mixins/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc7f3def3a82f5d50ab39cfa3581ecfd1af0d619
--- /dev/null
+++ b/fam/llm/mixins/__init__.py
@@ -0,0 +1,2 @@
+from fam.llm.mixins.causal import CausalInferenceMixin
+from fam.llm.mixins.non_causal import NonCausalInferenceMixin
diff --git a/fam/llm/mixins/__pycache__/__init__.cpython-310.pyc b/fam/llm/mixins/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee0d43674dc7d819837250225fe344f9d5b540bc
Binary files /dev/null and b/fam/llm/mixins/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fam/llm/mixins/__pycache__/__init__.cpython-39.pyc b/fam/llm/mixins/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce02acca406d83dc0b3dc56837d771466b796f63
Binary files /dev/null and b/fam/llm/mixins/__pycache__/__init__.cpython-39.pyc differ
diff --git a/fam/llm/mixins/__pycache__/causal.cpython-310.pyc b/fam/llm/mixins/__pycache__/causal.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07d860a27de26b6d59ae83f254703c45e3a24e99
Binary files /dev/null and b/fam/llm/mixins/__pycache__/causal.cpython-310.pyc differ
diff --git a/fam/llm/mixins/__pycache__/causal.cpython-39.pyc b/fam/llm/mixins/__pycache__/causal.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c782846b7414ee3fe1ce6ec79da2dfe50b46d2e
Binary files /dev/null and b/fam/llm/mixins/__pycache__/causal.cpython-39.pyc differ
diff --git a/fam/llm/mixins/__pycache__/non_causal.cpython-310.pyc b/fam/llm/mixins/__pycache__/non_causal.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aef2affb6c98385cc7a30c6db2b6616fb7e3d00b
Binary files /dev/null and b/fam/llm/mixins/__pycache__/non_causal.cpython-310.pyc differ
diff --git a/fam/llm/mixins/__pycache__/non_causal.cpython-39.pyc b/fam/llm/mixins/__pycache__/non_causal.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88d3e175b7c0c34966f1cd17eab83a7c33eecfb9
Binary files /dev/null and b/fam/llm/mixins/__pycache__/non_causal.cpython-39.pyc differ
diff --git a/fam/llm/mixins/causal.py b/fam/llm/mixins/causal.py
new file mode 100644
index 0000000000000000000000000000000000000000..62ae9d53279422a996e100dfafbe63c8caf0e63e
--- /dev/null
+++ b/fam/llm/mixins/causal.py
@@ -0,0 +1,546 @@
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import tqdm
+from torch.nn import functional as F
+
+
+def top_p_sample(prob_dist: torch.Tensor, top_p: float):
+ sorted_probs, sorted_indices = torch.sort(prob_dist, descending=True, dim=-1)
+ cum_sum_probs = torch.cumsum(sorted_probs, dim=-1) # (b, vocab_size)
+
+ sorted_indices_to_remove = cum_sum_probs > top_p
+
+ # Shift the indices to the right to keep also the first token above the threshold
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+ sorted_indices_to_remove[:, 0] = 0
+ sorted_indices_to_remove = sorted_indices_to_remove.bool()
+
+ # replace probs to be removed with 0 in the sorted_probs
+ sorted_probs[sorted_indices_to_remove] = 0
+
+ # reverse the sorting process
+ reversed_indices = torch.argsort(sorted_indices)
+ prob_dist = torch.gather(sorted_probs, -1, reversed_indices)
+
+ # normalize
+ prob_dist = prob_dist / prob_dist.sum(dim=-1, keepdim=True)
+
+ return prob_dist
+
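+# Worked example for top_p_sample (illustrative only):
+#   probs = torch.tensor([[0.6, 0.25, 0.10, 0.05]]); top_p = 0.7
+#   the cumulative sums are [0.6, 0.85, 0.95, 1.0]; the shift keeps the first token that crosses
+#   0.7, so [0.6, 0.25] survive and are renormalised to ~[0.706, 0.294, 0, 0].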
+
+class CausalInferenceMixin:
+ """
+ Mixin class for performing inference in a causal language model.
+
+ This mixin provides methods for predicting the next token in a sequence, sampling from the model,
+ and applying token prediction masks.
+
+ Attributes:
+ None
+
+ Methods:
+ _sample_next_token: Predicts the next token in the sequence.
+ _create_token_pred_mask: Creates a token prediction mask based on sequence lengths.
+ _apply_token_pred_mask: Applies a token prediction mask to the next token predictions.
+ _sample_batch: Samples a batch of tokens from the model.
+ _sort_for_batching: Sorts the input sequences for efficient batching.
+ _causal_sample: Generates a sequence of tokens using causal sampling.
+
+ """
+
+ @torch.no_grad()
+ def _sample_next_token(
+ self,
+ *,
+ idx: torch.Tensor,
+ speaker_embs: Optional[torch.Tensor],
+ temperature: float,
+ top_k: Optional[int],
+ top_p: Optional[float],
+ guidance_scale: Optional[Tuple[float, float]],
+ ) -> torch.Tensor:
+ """
+ Predict the next token in the sequence.
+
+ Args:
+ idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
+ speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
+ temperature (float): Sampling temperature.
+ top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
+ top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
+            guidance_scale (Optional[Tuple[float, float]]): (speaker, prompt) classifier-free guidance scales. Set to `None` to disable guidance.
+
+ Returns:
+ torch.Tensor: Next index in the sequence after sampling. Shape: (batch, num_hierarchies).
+ """
+ if top_k is not None and top_p is not None:
+ raise ValueError("Only one of top_k and top_p can be set")
+
+ # if the sequence context is growing too long we must crop it at block_size
+ idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, :, -self.config.block_size :]
+
+ # forward the model to get the logits for the index in the sequence
+ list_logits, _ = self(
+ idx_cond, speaker_embs=speaker_embs
+ ) # list with len num_hierarchies of (b,1,vocab_size) tensors
+
+ if guidance_scale is not None:
+ spkemb_guidance_scale, prompt_guidance_scale = guidance_scale
+ assert spkemb_guidance_scale >= 1
+ assert prompt_guidance_scale >= 1
+ base_scale = spkemb_guidance_scale + prompt_guidance_scale - 1
+
+ for i, logits in enumerate(list_logits):
+ if prompt_guidance_scale > 1:
+ logits_cond, logits_uncond_spkemb, logits_uncond_prompt = logits.split(logits.shape[0] // 3, dim=0)
+ else:
+ logits_cond, logits_uncond_spkemb = logits.split(logits.shape[0] // 2, dim=0)
+ logits_uncond_prompt = 0
+ list_logits[i] = (
+ (base_scale) * logits_cond
+ + (1 - spkemb_guidance_scale) * logits_uncond_spkemb
+ + (1 - prompt_guidance_scale) * logits_uncond_prompt
+ )
+
+ # pluck the logits at the final step and scale by desired temperature
+ list_logits = [
+ logits[:, -1, :] / temperature for logits in list_logits
+ ] # list with len num_hierarchies of (b,vocab_size) tensors
+
+ # optionally crop the logits to only the top k options
+ if top_k is not None:
+ for i in range(len(list_logits)):
+ logits = list_logits[i]
+ v, _ = torch.topk(
+ logits, min(top_k, logits.size(-1))
+ ) # returns a descending sorted list of values and indices of top_k values
+ logits[logits < v[:, [-1]]] = -float("Inf") # set all logits below the smallest top_k value to -Inf
+ list_logits[i] = logits
+
+ # apply softmax to convert logits to (normalized) probabilities
+ probs = [
+ F.softmax(logits, dim=-1) for logits in list_logits
+ ] # list of len num_hierarchies of (b,vocab_size) tensors
+
+ if top_p is not None:
+ for i in range(len(probs)):
+ probs[i] = top_p_sample(probs[i], top_p)
+
+ # sample from the distribution
+ idx_next = [
+ torch.multinomial(prob, num_samples=1) for prob in probs
+ ] # list of len num_hierarchies of (b,1) tensors
+ idx_next = torch.cat(idx_next, dim=-1) # (b, num_hierarchies) tensor
+
+ return idx_next # (b, num_hierarchies) tensor
+
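+    # Note on the guidance combination above (illustrative): with guidance_scale = (3.0, 1.0),
+    # prompt_guidance_scale == 1, so the batch is only doubled and the combined logits reduce to
+    #   3.0 * logits_cond - 2.0 * logits_uncond_spkemb
+    # i.e. standard classifier-free guidance on the speaker conditioning alone.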
+ @torch.no_grad()
+ def _create_token_pred_mask(self, idx: torch.Tensor, seq_lens: list[int]) -> torch.Tensor:
+ """
+ Creates a token prediction mask based on sequence lengths.
+
+ Args:
+ idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
+ seq_lens (list[int]): List of sequence lengths for each sequence in idx.
+
+ Returns:
+ torch.Tensor: Token prediction mask of shape (batch, time).
+ """
+ token_pred_mask = torch.zeros((idx.shape[0], idx.shape[-1]), dtype=torch.bool, device=idx.device)
+ for i in range(len(seq_lens)):
+ token_pred_mask[i, : seq_lens[i]] = True
+
+ assert (token_pred_mask[:, : min(seq_lens)] == 1).all()
+
+ return token_pred_mask
+
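+    # Small example of the mask above (illustrative): for idx of shape (2, c, 5) and
+    # seq_lens [3, 5], the returned mask is
+    #   [[True, True, True, False, False],
+    #    [True, True, True, True,  True ]]
+    # i.e. the positions holding real prompt tokens that must not be overwritten during sampling.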
+ @torch.no_grad()
+ def _apply_token_pred_mask(
+ self, *, idx_next: torch.Tensor, orig_input_at_t: torch.Tensor, token_pred_mask_at_t: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Applies a token prediction mask to the next token predictions.
+
+ Args:
+ idx_next (torch.Tensor): Next token predictions of shape (batch, num_hierarchies).
+ orig_input_at_t (torch.Tensor): Original input at time step t of shape (batch, num_hierarchies).
+ token_pred_mask_at_t (torch.Tensor): Token prediction mask at time step t of shape (batch, 1).
+
+ Returns:
+ torch.Tensor: Updated next token predictions after applying the token prediction mask.
+ """
+ idx_next = idx_next * (~token_pred_mask_at_t) + orig_input_at_t * token_pred_mask_at_t
+
+ return idx_next
+
+ @torch.no_grad()
+ def _sample_batch(
+ self,
+ *,
+ idx: torch.Tensor,
+ max_new_tokens: int,
+ seq_lens: list[int],
+ temperature: float,
+ top_k: Optional[int],
+ top_p: Optional[float],
+ speaker_embs: Optional[torch.Tensor],
+ guidance_scale: Optional[Tuple[float, float]],
+ end_of_audio_token: int,
+ end_of_text_token: int,
+ ):
+ """
+ Samples a batch of tokens from the model.
+
+ Args:
+ idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
+ max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).
+ seq_lens (list[int]): List of sequence lengths for each sequence in idx.
+ temperature (float): Sampling temperature.
+ top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
+ top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
+ speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
+            guidance_scale (Optional[Tuple[float, float]]): (speaker, prompt) classifier-free guidance scales. Set to `None` to disable guidance.
+
+ Returns:
+ torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time).
+ """
+ assert max(seq_lens) <= idx.shape[-1]
+ token_pred_mask = self._create_token_pred_mask(idx, seq_lens)
+ input = torch.clone(idx)
+
+ min_seq_lens = min(seq_lens)
+ idx = idx[:, :, :min_seq_lens]
+ idx_out = torch.full(
+ (idx.shape[0], idx.shape[1], idx.shape[2] + max_new_tokens),
+ end_of_audio_token,
+ dtype=idx.dtype,
+ device=idx.device,
+ )
+ idx_out[:, :, :min_seq_lens] = idx
+ terminated = idx.new_zeros(idx.shape[0], dtype=torch.bool)
+
+ if guidance_scale is not None:
+ _, prompt_guidance_scale = guidance_scale
+ if speaker_embs is None:
+ raise Exception("Guidance is only supported for conditional models")
+
+            # Expand along the batch dimension for classifier-free guidance: the original conditional
+            # inputs, an unconditional copy (speaker embeddings set to None), and, if prompt guidance is
+            # enabled, a third conditional copy whose text tokens are masked out during sampling.
+ speaker_embs = (
+ list(speaker_embs)
+ + [None] * (speaker_embs.shape[0])
+ + (list(speaker_embs) if prompt_guidance_scale > 1 else [])
+ )
+
+ for timestep in tqdm.tqdm(range(min_seq_lens, min_seq_lens + max_new_tokens), desc="tokens: "):
+ if terminated.all():
+ break
+ if (self.kv_cache_enabled is True) and (timestep > min_seq_lens):
+ idx_input = idx_out[:, :, [timestep - 1]]
+ else:
+ idx_input = idx_out[:, :, :timestep]
+
+ if guidance_scale is not None:
+ _, prompt_guidance_scale = guidance_scale
+ # TODO: fix: will cause a problem with kv-caching as it's not expecting larger batch-size.
+ if timestep == min_seq_lens:
+ print("[hack!!!!] Guidance is on, so we're doubling/tripling batch size!")
+
+ # replicate idx in the batch dimension
+ idx_input = (
+ idx_input.unsqueeze(0)
+ .repeat(3 if prompt_guidance_scale > 1 else 2, 1, 1, 1)
+ .reshape(-1, idx_input.shape[1], idx_input.shape[2])
+ )
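+                # After replication the batch is laid out as [speaker-conditional | speaker-unconditional |
+                # text-masked speaker-conditional] chunks, matching the speaker embedding list built above.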
+
+ if prompt_guidance_scale > 1:
+ idx_input_uncond = idx_input[idx_input.shape[0] // 3 * 2 :]
+ idx_input_uncond = idx_input_uncond.view(-1)
+ # Replace all text tokens with endoftext token
+ idx_input_uncond[idx_input_uncond > end_of_audio_token] = end_of_text_token
+
+ idx_next = self._sample_next_token(
+ idx=idx_input,
+ speaker_embs=speaker_embs,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ guidance_scale=guidance_scale,
+ ) # (b, num_hierarchies)
+
+ assert idx_next.shape[0] == idx.shape[0]
+
+ if timestep < token_pred_mask.shape[-1]:
+ idx_next = self._apply_token_pred_mask(
+ idx_next=idx_next,
+ orig_input_at_t=input[:, :, timestep],
+ token_pred_mask_at_t=token_pred_mask[:, [timestep]],
+ )
+ is_endofaudio = (idx_next == end_of_audio_token).any(dim=-1) # shape: b
+ terminated = terminated | is_endofaudio
+ idx_next[terminated] = end_of_audio_token
+ # append sampled index to the running sequence and continue
+ idx_out[:, :, timestep] = idx_next
+
+ return idx_out
+
+ @torch.no_grad()
+ def _sort_for_batching(
+ self,
+ *,
+ idx: torch.Tensor,
+ seq_lens: list[int],
+ speaker_embs: Optional[torch.Tensor],
+ batch_size: int,
+ max_new_tokens: int,
+ ) -> Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]:
+ """
+ Sorts the input sequences for efficient batching.
+
+ Args:
+ idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
+ seq_lens (list[int]): List of sequence lengths for each sequence in idx.
+ speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
+ batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling.
+ max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).
+
+ Returns:
+ Tuple[list[int], list[int], torch.Tensor, list[int], Optional[torch.Tensor], int]:
+ - sorted_indices (list[int]): List of indices of the input sequences that transform it into sorted order.
+ - invert_sorted_indices (list[int]): List of indices to invert the sorted sequences back to the original order.
+ - idx (torch.Tensor): Input sequence indices in sorted order.
+ - seq_lens (list[int]): Sequence lengths in sorted order.
+ - speaker_embs (Optional[torch.Tensor]): speaker embeddings in sorted order.
+ - max_token_len (int): Effective maximum number of tokens to generate.
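+
+        Example (illustrative): seq_lens=[7, 3, 5] gives sorted_indices=[1, 2, 0] and
+        invert_sorted_indices=[2, 0, 1], so outputs can be restored to the caller's original order.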
+ """
+ assert len(seq_lens) == idx.shape[0]
+ assert max(seq_lens) <= idx.shape[-1]
+
+ sorted_indices = np.argsort(seq_lens)
+ inverted_sorted_indices = np.zeros(len(seq_lens), dtype=np.int32)
+ inverted_sorted_indices[sorted_indices] = np.arange(len(seq_lens), dtype=np.int32)
+
+ idx = idx[sorted_indices]
+ seq_lens = [seq_lens[i] for i in sorted_indices]
+ speaker_embs = speaker_embs[sorted_indices] if speaker_embs is not None else None
+ max_token_len = 0
+
+ # figure out effective max_tokens to generate
+ for start_index in range(0, len(seq_lens), batch_size):
+ end_index = min(start_index + batch_size, len(seq_lens))
+ batch_seq_lens = seq_lens[start_index:end_index]
+            # Heuristic: each batch may generate up to max_new_tokens beyond its shortest prompt.
+            # TODO: fix!
+ max_token_len = max(max_token_len, min(batch_seq_lens) + max_new_tokens)
+
+ return sorted_indices, inverted_sorted_indices, idx, seq_lens, speaker_embs, max_token_len
+
+ @torch.no_grad()
+ def _causal_sample(
+ self,
+ *,
+ idx: torch.Tensor,
+ max_new_tokens: int,
+ seq_lens: list[int],
+ temperature: float,
+ top_k: Optional[int],
+ top_p: Optional[float],
+ speaker_embs: Optional[torch.Tensor],
+ batch_size: int,
+ guidance_scale: Optional[Tuple[float, float]] = None,
+ dtype: torch.dtype = torch.bfloat16,
+ end_of_audio_token: int,
+ end_of_text_token: int,
+ ) -> torch.Tensor:
+ """
+ Generates a sequence of tokens using causal sampling.
+
+ Args:
+ idx (torch.Tensor): Initial sequence indices of shape (batch, num_hierarchies, time).
+ max_new_tokens (int): Maximum number of NEW tokens to generate (in addition to largest sequence in idx).
+ seq_lens (list[int]): List of sequence lengths for each sequence in idx.
+ temperature (float): Sampling temperature.
+ top_k (Optional[int]): Top-k filtering threshold. Set to `None` to disable top-k filtering.
+ top_p (Optional[float]): Nucleus sampling threshold. Set to `None` to disable it.
+ speaker_embs (Optional[torch.Tensor]): Speaker embeddings. Set to `None` if using an unconditional model.
+ batch_size (int): Batch size for sampling. idx is split into batches of this size for sampling.
+            guidance_scale (Optional[Tuple[float, float]]): Tuple of (guidance scale, prompt guidance scale) for
+                classifier-free guidance. Set to `None` to disable guidance.
+            dtype (torch.dtype): Data type used when (re)allocating the KV cache.
+            end_of_audio_token (int): Token marking the end of audio; used to terminate sequences early.
+            end_of_text_token (int): Token used to mask out text tokens for prompt guidance.
+
+ Returns:
+ torch.Tensor: Generated sequence indices of shape (batch, num_hierarchies, time).
+ """
+ (
+ _,
+ invert_sorted_indices,
+ idx,
+ seq_lens,
+ speaker_embs,
+ max_token_len,
+ ) = self._sort_for_batching(
+ idx=idx, seq_lens=seq_lens, speaker_embs=speaker_embs, batch_size=batch_size, max_new_tokens=max_new_tokens
+ )
+
+ return_idx = torch.zeros((len(seq_lens), idx.size(1), max_token_len), dtype=torch.long, device=idx.device)
+
+ for start_index in tqdm.tqdm(range(0, len(seq_lens), batch_size), desc="batch: "):
+ end_index = min(start_index + batch_size, len(seq_lens))
+
+ kv_batch_size = end_index - start_index
+ if guidance_scale is not None:
+ if guidance_scale[1] > 1:
+ kv_batch_size = 3 * kv_batch_size
+ else:
+ kv_batch_size = 2 * kv_batch_size
+
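+            # With guidance enabled the model actually sees 2x (or, with prompt guidance, 3x) the user batch,
+            # so the KV cache must be allocated for the expanded batch size.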
+ if self.kv_cache_enabled:
+ self.empty_kv_cache(
+ batch_size=kv_batch_size,
+ kv_cache_maxlen=self.config.block_size,
+ dtype=dtype,
+ )
+
+ batch_seq_lens = seq_lens[start_index:end_index]
+ batch_max_new_tokens = max_token_len - min(batch_seq_lens)
+
+ batch_idx = idx[start_index:end_index]
+ batch_speaker_embs = speaker_embs[start_index:end_index] if speaker_embs is not None else None
+
+ batch_idx = self._sample_batch(
+ idx=batch_idx,
+ max_new_tokens=batch_max_new_tokens,
+ seq_lens=batch_seq_lens,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ speaker_embs=batch_speaker_embs,
+ guidance_scale=guidance_scale,
+ end_of_audio_token=end_of_audio_token,
+ end_of_text_token=end_of_text_token,
+ )
+ return_idx[start_index:end_index] = batch_idx
+
+ return return_idx[invert_sorted_indices]
+
+ def empty_kv_cache(self, *, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype):
+ """
+ Empties key-value (KV) cache for causal attention.
+
+ Args:
+ batch_size (int): The batch size.
+ kv_cache_maxlen (int): The maximum length of the KV cache.
+ dtype (torch.dtype): The data type of the KV cache.
+
+ Raises:
+ Exception: If KV cache is enabled for non-causal attention.
+
+ """
+ if self.kv_cache_enabled is False:
+ raise Exception("KV cache is not enabled")
+ if self.config.causal is False:
+ raise Exception("KV cache is not supported for non-causal attention")
+
+ self.kv_pos = 0
+ for block in self.transformer.h:
+ block.attn.empty_kv_cache(batch_size=batch_size, kv_cache_maxlen=kv_cache_maxlen, dtype=dtype)
+
+ def enable_kv_cache(self):
+ """
+ Enables key-value (KV) cache for causal attention.
+
+ Raises:
+ Exception: If KV cache is enabled for non-causal attention.
+
+ """
+ if self.config.causal is False:
+ raise Exception("KV cache is not supported for non-causal attention")
+
+ self.kv_cache_enabled = True
+ for block in self.transformer.h:
+ block.attn.kv_cache_enabled = True
+
+ def disable_kv_cache(self):
+ """
+ Disables the key-value cache for the transformer and all its blocks.
+ """
+ self.kv_cache_enabled = False
+ for block in self.transformer.h:
+ block.attn.kv_cache_enabled = False
+ block.attn.kv_cache = None
+ block.attn.kv_cache_first_empty_index = 0
+
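+    # The three KV-cache helpers above are typically used as follows (illustrative sketch; `b` is a
+    # placeholder batch size, and _causal_sample above is one real call site of empty_kv_cache):
+    #   model.enable_kv_cache()
+    #   model.empty_kv_cache(batch_size=b, kv_cache_maxlen=model.config.block_size, dtype=torch.bfloat16)
+    #   ...sample tokens one step at a time...
+    #   model.disable_kv_cache()
+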
+ @torch.no_grad()
+ def _slow_causal_sampling_loop(
+ self,
+ idx: torch.Tensor,
+ max_new_tokens: int,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ speaker_embs: Optional[torch.Tensor] = None,
+ guidance_scale: Optional[float] = None,
+ ):
+ """
+ Old non-batched version of causal sampling. Kept for testing / reference.
+
+        Take a conditioning sequence of indices idx (LongTensor of shape (b, num_hierarchies, t)) and complete
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+ """
+ assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens"
+ assert idx.size(0) == 1, "can only do one sequence at a time for now"
+ assert top_p is None, "nucleus sampling not supported yet with _slow_causal_sampling_loop"
+
+ if self.config.causal is not True:
+ raise Exception("Causal sampling is only supported for causal models")
+
+ if self.kv_cache_enabled:
+ print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16")
+ self.empty_kv_cache(
+ batch_size=1,
+ kv_cache_maxlen=self.config.block_size,
+ dtype=torch.bfloat16,
+ )
+
+ for i in range(max_new_tokens):
+ # if the sequence context is growing too long we must crop it at block_size
+            idx_cond = idx if idx.size(-1) <= self.config.block_size else idx[:, :, -self.config.block_size :]
+
+ if self.kv_cache_enabled:
+ if i > 0:
+ idx_cond = idx_cond[:, :, -1:]
+
+ # forward the model to get the logits for the index in the sequence
+ list_logits, _ = self(idx_cond, speaker_embs=speaker_embs)
+
+ if guidance_scale is not None:
+ # we've already checked that kv-caching is not switched on
+ # so this should be ok.
+ list_logits_uncond, _ = self(idx_cond, speaker_embs=None)
+ list_logits = [
+ (guidance_scale) * logits + (1 - guidance_scale) * logits_uncond
+ for logits, logits_uncond in zip(list_logits, list_logits_uncond)
+ ]
+
+ # pluck the logits at the final step and scale by desired temperature
+ list_logits = [logits[:, -1, :] / temperature for logits in list_logits]
+
+ # optionally crop the logits to only the top k options
+ if top_k is not None:
+                # use a separate loop variable so we don't shadow the timestep index `i`
+                for hierarchy_idx in range(len(list_logits)):
+                    logits = list_logits[hierarchy_idx]
+                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                    logits[logits < v[:, [-1]]] = -float("Inf")
+                    list_logits[hierarchy_idx] = logits
+
+ # apply softmax to convert logits to (normalized) probabilities
+ probs = [F.softmax(logits, dim=-1) for logits in list_logits]
+ # sample from the distribution
+            idx_next = torch.tensor(
+                [torch.multinomial(prob, num_samples=1) for prob in probs], device=idx.device
+            )  # (num_hierarchies,)
+ # append sampled index to the running sequence and continue
+ idx = torch.cat((idx, idx_next.unsqueeze(0).unsqueeze(-1)), dim=2)
+
+ return idx
diff --git a/fam/llm/mixins/non_causal.py b/fam/llm/mixins/non_causal.py
new file mode 100644
index 0000000000000000000000000000000000000000..1817a56a957e4fa92c8ea74feb166f37e8cc8d06
--- /dev/null
+++ b/fam/llm/mixins/non_causal.py
@@ -0,0 +1,67 @@
+from typing import Optional
+
+import torch
+from torch.nn import functional as F
+
+
+class NonCausalInferenceMixin:
+ """
+ Mixin class for non-causal inference in a language model.
+
+ This class provides methods for performing non-causal sampling using a language model.
+ """
+
+ @torch.no_grad()
+ def _non_causal_sample(
+ self, *, idx: torch.Tensor, speaker_embs: Optional[torch.Tensor], temperature: float, top_k: int
+ ):
+ """
+ Perform non-causal sampling.
+
+ Args:
+ idx (torch.Tensor): Input tensor of shape (batch_size, num_in_hierarchies, sequence_length).
+ speaker_embs (Optional[torch.Tensor]): Speaker embeddings tensor of shape (batch_size, embedding_size).
+ temperature (float): Temperature parameter for scaling the logits.
+ top_k (int): Number of top options to consider.
+
+ Returns:
+ torch.Tensor: Sampled output tensor of shape (batch_size, num_out_hierarchies, sequence_length).
+ """
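+        # Non-causal sampling predicts every timestep in parallel from the full (fixed-length) input, then
+        # draws one token per hierarchy per timestep independently (no autoregressive feedback).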
+ b, c, t = idx.size()
+ assert t == self.config.block_size, f"input size {t} != config.block_size {self.config.block_size}"
+ # forward the model to get the logits for the index in the sequence
+ list_logits, _ = self(idx, speaker_embs=speaker_embs) # c x (b, t, vocab_size)
+
+ # scale by desired temperature
+ list_logits = [logits / temperature for logits in list_logits] # c x (b, t, vocab_size)
+
+ # optionally crop the logits to only the top k options
+ if top_k is not None:
+ for i in range(len(list_logits)):
+ logits = list_logits[i] # (b, t, vocab_size)
+
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1))) # (b, t, top_k)
+ logits[logits < v[:, :, [-1]]] = -float("Inf")
+ list_logits[i] = logits # (b, t, vocab_size)
+ assert logits.shape[0] == b and logits.shape[1] == t
+
+ # apply softmax to convert logits to (normalized) probabilities
+ # TODO: check shapes here!
+        probs = [F.softmax(logits, dim=-1) for logits in list_logits]  # c x (b, t, vocab_size)
+ assert probs[0].shape[0] == b and probs[0].shape[1] == t
+
+        # TODO: check that the output shape is as expected
+        outs = []
+        for b_prob in probs:  # iterate over hierarchies: each b_prob is (b, t, vocab_size)
+            out = [
+                torch.multinomial(prob, num_samples=1).transpose(0, 1).unsqueeze(0) for prob in b_prob
+            ]  # b x (t, vocab_size) -> b x (t, 1) -> b x (1, t) -> b x (1, 1, t)
+ assert len(out) == b and out[0].shape[0] == 1 and out[0].shape[1] == 1 and out[0].shape[2] == t
+ out = torch.cat(out, dim=0) # (b, 1, t)
+ assert out.shape[0] == b and out.shape[1] == 1 and out.shape[2] == t
+ outs.append(out)
+
+ out = torch.cat(outs, dim=1) # (b, c, t)
+ assert out.shape[0] == b and out.shape[2] == t
+
+ return out
diff --git a/fam/llm/model.py b/fam/llm/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f98bf7e4c2684dedde7d34c4a17e91835399bb8
--- /dev/null
+++ b/fam/llm/model.py
@@ -0,0 +1,408 @@
+import inspect
+import math
+from dataclasses import dataclass, field
+from typing import Literal, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import tqdm
+from einops import rearrange
+from torch.nn import functional as F
+
+from fam.llm.layers import Block, LayerNorm, RMSNorm
+from fam.llm.mixins import CausalInferenceMixin, NonCausalInferenceMixin
+
+END_OF_TEXT_TOKEN = 1537
+
+
+def _select_spkemb(spkemb, mask):
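+    # For each timestep, pick the speaker embedding of the utterance that timestep belongs to:
+    # `mask` holds a per-timestep utterance index, which is one-hot encoded and used to gather from `spkemb`.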
+ _, examples, _ = spkemb.shape
+ mask = torch.nn.functional.one_hot(mask.long(), num_classes=examples).to(spkemb) # shape: (batch, time, examples)
+ spkemb = spkemb.transpose(1, 2) # b ex c -> b c ex
+ mask = mask.transpose(1, 2) # b t ex -> b ex t
+ return torch.bmm(spkemb, mask).transpose(1, 2) # b c t -> b t c
+
+
+@dataclass
+class GPTConfig:
+ block_size: int = 1024
+ vocab_sizes: list = field(default_factory=list)
+ target_vocab_sizes: Optional[list] = None
+ n_layer: int = 12
+ n_head: int = 12
+ n_embd: int = 768
+ dropout: float = 0.0
+ spkemb_dropout: float = 0.0
+ bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+ causal: bool = (
+ True # auto-regressive or not, i.e. whether to have attention mask that prevents attending to future tokens
+ )
+ spk_emb_on_text: bool = True # whether to add speaker embedding conditioning to text tokens or not
+    norm_type: str = "layernorm"  # "rmsnorm" or "layernorm"
+ rmsnorm_eps: Optional[float] = None # only used for rmsnorm
+ nonlinearity_type: str = "gelu" # "gelu" or "swiglu"
+ swiglu_multiple_of: Optional[int] = None # MLP hidden layer (using SwiGLU) will be multiple of this
+ attn_kernel_type: Literal["torch_attn"] = "torch_attn"
+ kv_cache_enabled: bool = False # whether to use key-value cache for attention
+
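+# Illustrative construction only; the real block size, vocab sizes and model dimensions come from the
+# released checkpoint's config, and the values below are made up:
+#   config = GPTConfig(block_size=2048, vocab_sizes=[2562], n_layer=24, n_head=16, n_embd=2048, causal=True)
+#   model = GPT(config, speaker_emb_dim=256)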
+
+def _check_speaker_emb_dims(
+ speaker_embs: Union[list, torch.Tensor], expected_speaker_emb_dim: int, expected_batch_size: int
+) -> Union[torch.Tensor, list]:
+ """
+ Checks that the speaker embedding dimensions are correct, and reshapes them if necessary.
+ """
+ if type(speaker_embs) == list:
+ b_se = len(speaker_embs)
+ for i, s in enumerate(speaker_embs):
+ if s is not None:
+ emb_dim = s.shape[-1]
+ if s.ndim == 1:
+ speaker_embs[i] = speaker_embs[i].unsqueeze(0)
+ else:
+ if speaker_embs.ndim == 2:
+ # if we have a single speaker embedding for the whole sequence,
+ # add a dummy dimension for backwards compatibility
+ speaker_embs = speaker_embs[:, None, :]
+
+ # num_examples is the number of utterances packed into this sequence
+ b_se, num_examples, emb_dim = speaker_embs.size()
+
+ assert b_se == expected_batch_size, f"Batch size mismatch: {b_se} != {expected_batch_size}"
+ assert (
+ emb_dim == expected_speaker_emb_dim
+ ), f"Speaker embedding dimension mismatch: {emb_dim} != {expected_speaker_emb_dim}"
+
+ return speaker_embs
+
+
+class GPT(nn.Module, NonCausalInferenceMixin, CausalInferenceMixin):
+ def __init__(self, config: GPTConfig, speaker_emb_dim: Optional[int] = None):
+ """
+ Initialize the GPT model.
+
+ Args:
+ config (GPTConfig): Configuration object for the model.
+ speaker_emb_dim (Optional[int]): Dimension of the speaker embedding. Default is None.
+ """
+ super().__init__()
+ assert config.vocab_sizes is not None
+ assert config.block_size is not None
+ self.config = config
+
+ self.kv_cache_enabled = False # disabled by default
+ self.kv_pos = 0
+
+ self.speaker_emb_dim = speaker_emb_dim
+ self.spk_emb_on_text = config.spk_emb_on_text
+ if self.config.causal is True and self.spk_emb_on_text is False:
+ print("!!!!!!!!!!!!!!!!!!")
+ print(
+ f"!!!!!!!! Using DEFAULT of {END_OF_TEXT_TOKEN} as end of text token to find speaker cond masking!! You likely need to change this."
+ )
+ print("!!!!!!!!!!!!!!!!!!")
+ if self.config.causal is False and self.spk_emb_on_text is False:
+ raise Exception(
+ "Cannot use speaker embedding masking with non-causal model. This is unexpected. Check for relevant changes required in code before proceeding."
+ )
+
+ if config.norm_type == "rmsnorm":
+ if config.rmsnorm_eps is None:
+ raise Exception("RMSNorm requires rmsnorm_eps to be set")
+ ln_f = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
+ elif config.norm_type == "layernorm":
+ ln_f = LayerNorm(config.n_embd, bias=config.bias)
+ else:
+ raise Exception(f"Unknown norm type: {config.norm_type}")
+
+ self.transformer = nn.ModuleDict(
+ dict(
+ wtes=nn.ModuleList([nn.Embedding(vsize, config.n_embd) for vsize in config.vocab_sizes]),
+ wpe=nn.Embedding(config.block_size, config.n_embd),
+ drop=nn.Dropout(config.dropout),
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+ ln_f=ln_f,
+ )
+ )
+ if speaker_emb_dim is not None:
+ self.speaker_cond_pos = nn.Linear(speaker_emb_dim, config.n_embd, bias=False)
+
+ self.lm_heads = nn.ModuleList()
+ if config.target_vocab_sizes is not None:
+ assert config.causal is False
+ else:
+ assert config.causal is True
+
+ for vsize in config.vocab_sizes if config.target_vocab_sizes is None else config.target_vocab_sizes:
+ self.lm_heads.append(nn.Linear(config.n_embd, vsize, bias=False))
+
+ if config.target_vocab_sizes is None:
+ for i in range(len(config.vocab_sizes)):
+ # TODO: do we not need to take the transpose here?
+ # https://paperswithcode.com/method/weight-tying
+ self.lm_heads[i].weight = self.transformer.wtes[i].weight # type: ignore
+ assert len(self.lm_heads) == len(
+ self.transformer.wtes # type: ignore
+        ), f"Number of heads ({len(self.lm_heads)}) must match number of one-hot embedding matrices ({len(self.transformer.wtes)})."  # type: ignore
+
+ # init all weights
+ self.apply(self._init_weights)
+ # apply special scaled init to the residual projections, per GPT-2 paper
+ for pn, p in self.named_parameters():
+ if pn.endswith("c_proj.weight"):
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
+
+ # report number of parameters
+ print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
+
+ def get_num_params(self, non_embedding=True):
+ """
+ Return the number of parameters in the model.
+ For non-embedding count (default), the position embeddings get subtracted.
+ The token embeddings would too, except due to the parameter sharing these
+ params are actually used as weights in the final layer, so we include them.
+ """
+ n_params = sum(p.numel() for p in self.parameters())
+ if non_embedding:
+ n_params -= self.transformer.wpe.weight.numel()
+ return n_params
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+ if module.bias is not None:
+ torch.nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+ def _mask_spk_emb_on_text(self, idx: torch.Tensor, spk_emb: torch.Tensor) -> torch.Tensor:
+ """
+ This is in a separate function so we can test it easily.
+ """
+ # find index of end of text token in each sequence, then generate a binary mask
+ # of shape (b, 1, t) to mask out the speaker embedding for all tokens before the end of text token.
+ # Note: this does NOT mask the token. This is important so that the first audio token predicted
+ # has speaker information to use.
+
+ # Check in channel dimension 0 as this is usually the first hierarchy where we put the text tokens.
+ is_end_of_text = idx[:, 0, :] == END_OF_TEXT_TOKEN
+ # use > 0, in case end_of_text_token is repeated for any reason.
+ mask = (torch.cumsum(is_end_of_text, dim=-1) > 0).float()
+ spk_emb = spk_emb * mask[:, :, None]
+
+ return spk_emb
+
+ def forward(
+ self,
+ idx,
+ targets=None,
+ speaker_embs=None,
+ speaker_emb_mask=None,
+ loss_reduce: Literal["mean", "none"] = "mean",
+ ):
+ device = idx.device
+ b, num_hierarchies, t = idx.size()
+
+ if speaker_embs is not None:
+ speaker_embs = _check_speaker_emb_dims(
+ speaker_embs=speaker_embs, expected_speaker_emb_dim=self.speaker_emb_dim, expected_batch_size=b
+ )
+
+ assert (
+ t <= self.config.block_size
+ ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+
+ if self.kv_cache_enabled:
+ if self.kv_pos == 0:
+ pos = torch.arange(0, t, dtype=torch.long, device=device)
+ self.kv_pos += t
+ else:
+ assert t == 1, "KV cache is only supported for single token inputs"
+ pos = torch.tensor([self.kv_pos], dtype=torch.long, device=device) # shape (1)
+ self.kv_pos += 1
+ else:
+ pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
+
+ # forward the GPT model itself
+ assert num_hierarchies == len(
+ self.transformer.wtes
+ ), f"Input tensor has {num_hierarchies} hierarchies, but model has {len(self.transformer.wtes)} set of input embeddings."
+
+ # embed the tokens, positional encoding, and speaker embedding
+ tok_emb = torch.zeros((b, t, self.config.n_embd), device=device)
+ # ends up swapping (B, num_hierarchies, t) tokens -> (B, t, c) embeddings.
+ for i, wte in enumerate(self.transformer.wtes):
+ tok_emb += wte(idx[:, i, :])
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+
+ spk_emb = 0.0
+ if speaker_embs is not None:
+ if type(speaker_embs) == list:
+ assert speaker_emb_mask is None
+ assert self.training is False
+ assert self.spk_emb_on_text is True
+
+ spk_emb = []
+ for speaker_emb_row in speaker_embs:
+ if speaker_emb_row is not None:
+ spk_emb.append(self.speaker_cond_pos(speaker_emb_row.unsqueeze(0)))
+ assert spk_emb[-1].shape == (1, 1, self.config.n_embd), f"spk_emb[-1].shape={spk_emb[-1].shape}"
+ else:
+ spk_emb.append(torch.zeros((1, 1, self.config.n_embd), device=device, dtype=pos_emb.dtype))
+ spk_emb = torch.cat(spk_emb, dim=0)
+
+ assert (
+ spk_emb.ndim == 3 and spk_emb.shape[1] == 1 and spk_emb.shape[0] == b
+ ), f"spk_emb.ndim={spk_emb.ndim}, spk_emb.shape={spk_emb.shape}, len(speaker_embs)={len(speaker_embs)}"
+ else:
+ speakers_embedded = self.speaker_cond_pos(speaker_embs) # shape (b, num_examples, c)
+
+ if speaker_emb_mask is not None:
+ spk_emb = _select_spkemb(speakers_embedded, speaker_emb_mask)
+ assert spk_emb.shape == (b, t, self.config.n_embd)
+ else:
+ spk_emb = speakers_embedded
+ # if we don't have a mask, we assume that the speaker embedding is the same for all tokens
+ # then num_examples dimension just becomes the time dimension
+ assert spk_emb.ndim == 3 and spk_emb.shape[1] == 1
+
+ if self.training and self.config.spkemb_dropout > 0.0:
+ # Remove speaker conditioning at random.
+ dropout = torch.ones_like(speakers_embedded) * (
+ torch.rand(speakers_embedded.shape[0], 1, 1, device=device) >= self.config.spkemb_dropout
+ )
+ spk_emb = torch.where(dropout == 0, torch.zeros_like(speakers_embedded), speakers_embedded)
+
+ if self.spk_emb_on_text is False:
+ assert speaker_emb_mask is None, "Not implemented for spk_emb_on_text=False"
+ spk_emb = self._mask_spk_emb_on_text(idx, spk_emb)
+
+ x = self.transformer.drop(tok_emb + pos_emb + spk_emb)
+ for block in self.transformer.h:
+ x = block(x)
+ x = self.transformer.ln_f(x)
+
+ if targets is not None:
+ # if we are given some desired targets also calculate the loss
+ list_logits = [lm_head(x) for lm_head in self.lm_heads]
+
+ losses = [
+ F.cross_entropy(
+ logits.view(-1, logits.size(-1)),
+ targets[:, i, :].contiguous().view(-1),
+ ignore_index=-1,
+ reduction=loss_reduce,
+ )
+ for i, logits in enumerate(list_logits)
+ ]
+ # TODO: should we do this better without stack somehow?
+ losses = torch.stack(losses)
+ if loss_reduce == "mean":
+ losses = losses.mean()
+ else:
+ losses = rearrange(losses, "h (b t) -> b h t", h=len(self.lm_heads), b=b, t=t)
+ else:
+ # inference-time mini-optimization: only forward the lm_head on the very last position
+ if self.config.causal:
+ list_logits = [
+ lm_head(x[:, [-1], :]) for lm_head in self.lm_heads
+ ] # note: using list [-1] to preserve the time dim
+ else:
+ list_logits = [lm_head(x) for lm_head in self.lm_heads]
+ losses = None
+
+ return list_logits, losses
+
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
+ # start with all of the candidate parameters
+ param_dict = {pn: p for pn, p in self.named_parameters()}
+ # filter out those that do not require grad
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+ optim_groups = [
+ {"params": decay_params, "weight_decay": weight_decay},
+ {"params": nodecay_params, "weight_decay": 0.0},
+ ]
+ num_decay_params = sum(p.numel() for p in decay_params)
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+ # Create AdamW optimizer and use the fused version if it is available
+ fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
+ use_fused = fused_available and device_type == "cuda"
+ extra_args = dict(fused=True) if use_fused else dict()
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+ print(f"using fused AdamW: {use_fused}")
+
+ return optimizer
+
+ @torch.no_grad()
+ def generate(
+ self,
+ idx: torch.Tensor,
+ max_new_tokens: int,
+ seq_lens: Optional[list] = None,
+ temperature: float = 1.0,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ speaker_embs: Optional[torch.Tensor] = None,
+ batch_size: Optional[int] = None,
+ guidance_scale: Optional[Tuple[float, float]] = None,
+ dtype: torch.dtype = torch.bfloat16,
+ end_of_audio_token: int = 99999, # Dummy values will disable early termination / guidance features.
+ end_of_text_token: int = 99999,
+ ):
+ """
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,num_hierarchies,t)) and complete
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
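+
+        Illustrative call with made-up shapes and token ids (the real special-token ids come from the
+        tokeniser/checkpoint config):
+            tokens = model.generate(
+                idx=prompt_tokens,              # (b, num_hierarchies, t) LongTensor
+                max_new_tokens=256,
+                seq_lens=[t] * b,
+                temperature=1.0,
+                top_p=0.95,
+                speaker_embs=spk_emb,           # (b, speaker_emb_dim) or None
+                batch_size=b,
+                guidance_scale=(3.0, 1.0),
+                end_of_audio_token=1024,
+                end_of_text_token=1537,
+            )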
+ """
+ assert idx.dim() == 3, "idx must be a batch of sequences of hierarchical tokens"
+
+ if self.config.causal:
+ if seq_lens is None or batch_size is None:
+ raise Exception("seq_lens and batch_size must be provided for causal sampling")
+
+ return self._causal_sample(
+ idx=idx,
+ max_new_tokens=max_new_tokens,
+ seq_lens=seq_lens,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ speaker_embs=speaker_embs,
+ batch_size=batch_size,
+ guidance_scale=guidance_scale,
+ dtype=dtype,
+ end_of_audio_token=end_of_audio_token,
+ end_of_text_token=end_of_text_token,
+ )
+
+ else:
+ if seq_lens is not None:
+ raise Exception("seq_lens is not supported yet for non-causal sampling")
+
+ if batch_size is None:
+ raise Exception("batch_size must be provided for non-causal sampling")
+
+ if guidance_scale is not None:
+ raise Exception("guidance_scale is not supported for non-causal sampling")
+
+ if top_p is not None:
+ raise Exception("top_p is not supported for non-causal sampling")
+
+ out = []
+ for start_index in tqdm.tqdm(range(0, idx.shape[0], batch_size), desc="Non-causal batching"):
+ end_index = min(start_index + batch_size, idx.shape[0])
+ out.append(
+ self._non_causal_sample(
+ idx=idx[start_index:end_index],
+ speaker_embs=speaker_embs[start_index:end_index] if speaker_embs is not None else None,
+ temperature=temperature,
+ top_k=top_k,
+ )
+ )
+ return torch.cat(out, dim=0)
diff --git a/fam/llm/utils.py b/fam/llm/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbfcd40c933bb7d6dc5f2a8d1369889fe8aa2844
--- /dev/null
+++ b/fam/llm/utils.py
@@ -0,0 +1,89 @@
+import os
+import re
+import subprocess
+import tempfile
+
+import librosa
+import torch
+
+
+def normalize_text(text: str) -> str:
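+    # Map common unicode punctuation (curly quotes, long dashes, ellipses, ...) to ASCII equivalents and
+    # reject any remaining characters with code points >= 256.
+    # Illustrative: normalize_text('“Hello — world”') -> '"Hello - world"'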
+ unicode_conversion = {
+ 8175: "'",
+ 8189: "'",
+ 8190: "'",
+ 8208: "-",
+ 8209: "-",
+ 8210: "-",
+ 8211: "-",
+ 8212: "-",
+ 8213: "-",
+ 8214: "||",
+ 8216: "'",
+ 8217: "'",
+ 8218: ",",
+ 8219: "`",
+ 8220: '"',
+ 8221: '"',
+ 8222: ",,",
+ 8223: '"',
+ 8228: ".",
+ 8229: "..",
+ 8230: "...",
+ 8242: "'",
+ 8243: '"',
+ 8245: "'",
+ 8246: '"',
+ 180: "'",
+ 2122: "TM", # Trademark
+ }
+
+ text = text.translate(unicode_conversion)
+
+ non_bpe_chars = set([c for c in list(text) if ord(c) >= 256])
+ if len(non_bpe_chars) > 0:
+ non_bpe_points = [(c, ord(c)) for c in non_bpe_chars]
+ raise ValueError(f"Non-supported character found: {non_bpe_points}")
+
+ text = text.replace("\t", " ").replace("\n", " ").replace("\r", " ").replace("*", " ").strip()
+    text = re.sub(r"\s\s+", " ", text)  # collapse runs of whitespace into a single space
+ return text
+
+
+def check_audio_file(path_or_uri, threshold_s=30):
+ if "http" in path_or_uri:
+ temp_fd, filepath = tempfile.mkstemp()
+        os.close(temp_fd)  # close the file descriptor; curl writes to the path directly
+ curl_command = ["curl", "-L", path_or_uri, "-o", filepath]
+ subprocess.run(curl_command, check=True)
+
+ else:
+ filepath = path_or_uri
+
+ audio, sr = librosa.load(filepath)
+ duration_s = librosa.get_duration(y=audio, sr=sr)
+ if duration_s < threshold_s:
+ raise Exception(
+ f"The audio file is too short. Please provide an audio file that is at least {threshold_s} seconds long to proceed."
+ )
+
+ # Clean up the temporary file if it was created
+ if "http" in path_or_uri:
+ os.remove(filepath)
+
+
+def get_default_dtype() -> str:
+ """Compute default 'dtype' based on GPU architecture"""
+ if torch.cuda.is_available():
+ for i in range(torch.cuda.device_count()):
+ device_properties = torch.cuda.get_device_properties(i)
+            # bfloat16 requires compute capability >= 8.0 (Ampere or newer); older GPUs fall back to float16
+            dtype = "float16" if device_properties.major <= 7 else "bfloat16"
+ else:
+ dtype = "float16"
+
+ print(f"using dtype={dtype}")
+ return dtype
+
+
+def get_device() -> str:
+ return "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/fam/py.typed b/fam/py.typed
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fam/quantiser/__init__.py b/fam/quantiser/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fam/quantiser/__pycache__/__init__.cpython-310.pyc b/fam/quantiser/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19058e0bb9d7fa24c769ef471420a1b8ee68e7f6
Binary files /dev/null and b/fam/quantiser/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fam/quantiser/__pycache__/__init__.cpython-39.pyc b/fam/quantiser/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d7f2f2976af68f0e09886acd86e384a01bf7056
Binary files /dev/null and b/fam/quantiser/__pycache__/__init__.cpython-39.pyc differ
diff --git a/fam/quantiser/audio/__init__.py b/fam/quantiser/audio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fam/quantiser/audio/__pycache__/__init__.cpython-310.pyc b/fam/quantiser/audio/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d41f53639362d7951e10766dc5f3fad18b21048
Binary files /dev/null and b/fam/quantiser/audio/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fam/quantiser/audio/__pycache__/__init__.cpython-39.pyc b/fam/quantiser/audio/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..23c2509756c1be7ae802fc37d930b938602a972e
Binary files /dev/null and b/fam/quantiser/audio/__pycache__/__init__.cpython-39.pyc differ
diff --git a/fam/quantiser/audio/speaker_encoder/__init__.py b/fam/quantiser/audio/speaker_encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fam/quantiser/audio/speaker_encoder/__pycache__/__init__.cpython-310.pyc b/fam/quantiser/audio/speaker_encoder/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8020d822c38c66bd2001316b914c3c79b2134cb4
Binary files /dev/null and b/fam/quantiser/audio/speaker_encoder/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fam/quantiser/audio/speaker_encoder/__pycache__/__init__.cpython-39.pyc b/fam/quantiser/audio/speaker_encoder/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a491b86cbf60acec00b0981f4556c08bd8e461f3
Binary files /dev/null and b/fam/quantiser/audio/speaker_encoder/__pycache__/__init__.cpython-39.pyc differ
diff --git a/fam/quantiser/audio/speaker_encoder/__pycache__/audio.cpython-310.pyc b/fam/quantiser/audio/speaker_encoder/__pycache__/audio.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2fb614b79f330ad025f6485d8a6105457642cc2
Binary files /dev/null and b/fam/quantiser/audio/speaker_encoder/__pycache__/audio.cpython-310.pyc differ
diff --git a/fam/quantiser/audio/speaker_encoder/__pycache__/audio.cpython-39.pyc b/fam/quantiser/audio/speaker_encoder/__pycache__/audio.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0c2e652f67a046a1c6ff30c6e0537fca9f7e340
Binary files /dev/null and b/fam/quantiser/audio/speaker_encoder/__pycache__/audio.cpython-39.pyc differ
diff --git a/fam/quantiser/audio/speaker_encoder/__pycache__/model.cpython-310.pyc b/fam/quantiser/audio/speaker_encoder/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..240b19adb5e837c2a179e668a0e31af7e40a26ee
Binary files /dev/null and b/fam/quantiser/audio/speaker_encoder/__pycache__/model.cpython-310.pyc differ
diff --git a/fam/quantiser/audio/speaker_encoder/__pycache__/model.cpython-39.pyc b/fam/quantiser/audio/speaker_encoder/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..731ded1e51fbfc1185cf07e89740551375996183
Binary files /dev/null and b/fam/quantiser/audio/speaker_encoder/__pycache__/model.cpython-39.pyc differ
diff --git a/fam/quantiser/audio/speaker_encoder/audio.py b/fam/quantiser/audio/speaker_encoder/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..33729870d0adc944978f1902805e9bbd7e5af224
--- /dev/null
+++ b/fam/quantiser/audio/speaker_encoder/audio.py
@@ -0,0 +1,22 @@
+import librosa
+import numpy as np
+
+mel_window_length = 25
+mel_window_step = 10
+mel_n_channels = 40
+sampling_rate = 16000
+
+
+def wav_to_mel_spectrogram(wav):
+ """
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
+    Note: this is not a log-mel spectrogram.
+ """
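+    # 25 ms windows with a 10 ms hop at 16 kHz -> n_fft=400, hop_length=160, 40 mel channels;
+    # the output is transposed to (num_frames, mel_n_channels).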
+ frames = librosa.feature.melspectrogram(
+ y=wav,
+ sr=sampling_rate,
+ n_fft=int(sampling_rate * mel_window_length / 1000),
+ hop_length=int(sampling_rate * mel_window_step / 1000),
+ n_mels=mel_n_channels,
+ )
+ return frames.astype(np.float32).T
diff --git a/fam/quantiser/audio/speaker_encoder/model.py b/fam/quantiser/audio/speaker_encoder/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ad58593b4e1bc7cf2b56c4a94467bdec8aea95b
--- /dev/null
+++ b/fam/quantiser/audio/speaker_encoder/model.py
@@ -0,0 +1,117 @@
+import os
+from time import perf_counter as timer
+from typing import List, Optional, Union
+
+import librosa
+import numpy as np
+import torch
+from torch import nn
+
+from fam.quantiser.audio.speaker_encoder import audio
+
+mel_window_step = 10
+mel_n_channels = 40
+sampling_rate = 16000
+partials_n_frames = 160
+model_hidden_size = 256
+model_embedding_size = 256
+model_num_layers = 3
+
+
+class SpeakerEncoder(nn.Module):
+ def __init__(
+ self,
+ weights_fpath: Optional[str] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ verbose: bool = True,
+ eval: bool = False,
+ ):
+ super().__init__()
+
+ # Define the network
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
+ self.relu = nn.ReLU()
+
+ # Get the target device
+ if device is None:
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ elif isinstance(device, str):
+ device = torch.device(device)
+ self.device = device
+
+ start = timer()
+
+ checkpoint = torch.load(weights_fpath, map_location="cpu")
+ self.load_state_dict(checkpoint["model_state"], strict=False)
+ self.to(device)
+
+ if eval:
+ self.eval()
+
+ if verbose:
+ print("Loaded the speaker embedding model on %s in %.2f seconds." % (device.type, timer() - start))
+
+ def forward(self, mels: torch.FloatTensor):
+ _, (hidden, _) = self.lstm(mels)
+ embeds_raw = self.relu(self.linear(hidden[-1]))
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
+
+ @staticmethod
+ def compute_partial_slices(n_samples: int, rate, min_coverage):
+ # Compute how many frames separate two partial utterances
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+ frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
+
+ # Compute the slices
+ wav_slices, mel_slices = [], []
+ steps = max(1, n_frames - partials_n_frames + frame_step + 1)
+ for i in range(0, steps, frame_step):
+ mel_range = np.array([i, i + partials_n_frames])
+ wav_range = mel_range * samples_per_frame
+ mel_slices.append(slice(*mel_range))
+ wav_slices.append(slice(*wav_range))
+
+ # Evaluate whether extra padding is warranted or not
+ last_wav_range = wav_slices[-1]
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
+ if coverage < min_coverage and len(mel_slices) > 1:
+ mel_slices = mel_slices[:-1]
+ wav_slices = wav_slices[:-1]
+
+ return wav_slices, mel_slices
+
+ def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75, numpy: bool = True):
+ wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
+ max_wave_length = wav_slices[-1].stop
+ if max_wave_length >= len(wav):
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
+
+ mel = audio.wav_to_mel_spectrogram(wav)
+ mels = np.array([mel[s] for s in mel_slices])
+ mels = torch.from_numpy(mels).to(self.device) # type: ignore
+ with torch.no_grad():
+ partial_embeds = self(mels)
+
+ if numpy:
+ raw_embed = np.mean(partial_embeds.cpu().numpy(), axis=0)
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
+ else:
+ raw_embed = partial_embeds.mean(dim=0)
+ embed = raw_embed / torch.linalg.norm(raw_embed, 2)
+
+ if return_partials:
+ return embed, partial_embeds, wav_slices
+ return embed
+
+ def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
+ raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0)
+ return raw_embed / np.linalg.norm(raw_embed, 2)
+
+ def embed_utterance_from_file(self, fpath: str, numpy: bool) -> torch.Tensor:
+ wav_tgt, _ = librosa.load(fpath, sr=sampling_rate)
+ wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
+
+ embedding = self.embed_utterance(wav_tgt, numpy=numpy)
+ return embedding
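+
+
+# Illustrative usage (the checkpoint path is a placeholder; the real weights come from the model download step):
+#   encoder = SpeakerEncoder(weights_fpath="/path/to/speaker_encoder.pt", device="cuda", eval=True)
+#   spk_emb = encoder.embed_utterance_from_file("reference.wav", numpy=False)  # (256,) torch.Tensor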
diff --git a/fam/quantiser/text/__init__.py b/fam/quantiser/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/fam/quantiser/text/__init__.py
@@ -0,0 +1 @@
+
diff --git a/fam/quantiser/text/__pycache__/__init__.cpython-310.pyc b/fam/quantiser/text/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a715ba2c8a176c226a7c33991f7d541482fc2193
Binary files /dev/null and b/fam/quantiser/text/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fam/quantiser/text/__pycache__/__init__.cpython-39.pyc b/fam/quantiser/text/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87c01de8f80074959de8b4da6b7e4ebbf2411089
Binary files /dev/null and b/fam/quantiser/text/__pycache__/__init__.cpython-39.pyc differ
diff --git a/fam/quantiser/text/__pycache__/tokenise.cpython-310.pyc b/fam/quantiser/text/__pycache__/tokenise.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..231a0a94ae7efc955c4ae573071af9ed7aab719a
Binary files /dev/null and b/fam/quantiser/text/__pycache__/tokenise.cpython-310.pyc differ
diff --git a/fam/quantiser/text/__pycache__/tokenise.cpython-39.pyc b/fam/quantiser/text/__pycache__/tokenise.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e25f32f9359ac81549edd2f97f62a202975df71
Binary files /dev/null and b/fam/quantiser/text/__pycache__/tokenise.cpython-39.pyc differ
diff --git a/fam/quantiser/text/tokenise.py b/fam/quantiser/text/tokenise.py
new file mode 100644
index 0000000000000000000000000000000000000000..810c20beb3f62ec5e3b2c1299a04e9a687c446da
--- /dev/null
+++ b/fam/quantiser/text/tokenise.py
@@ -0,0 +1,32 @@
+import tiktoken
+
+
+class TrainedBPETokeniser:
+ def __init__(self, name, pat_str, mergeable_ranks, special_tokens, offset=None) -> None:
+ self.tokenizer = tiktoken.Encoding(
+ name=name,
+ pat_str=pat_str,
+ mergeable_ranks=mergeable_ranks,
+ special_tokens=special_tokens,
+ )
+ self.offset = offset
+
+ def encode(self, text: str) -> list[int]:
+        # note: we append an end-of-text token!
+ tokens = self.tokenizer.encode(text) + [self.tokenizer.eot_token]
+ if self.offset is not None:
+ tokens = [x + self.offset for x in tokens]
+
+ return tokens
+
+ def decode(self, tokens: list[int]):
+ if self.offset is not None:
+ tokens = [x - self.offset for x in tokens]
+ return self.tokenizer.decode(tokens)
+
+ @property
+ def eot_token(self):
+ if self.offset is not None:
+ return self.tokenizer.eot_token + self.offset
+ else:
+ return self.tokenizer.eot_token
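+
+
+# Illustrative usage (pat_str, mergeable_ranks and special_tokens are loaded from the released tokeniser
+# files; the names below are placeholders):
+#   tokeniser = TrainedBPETokeniser(name="bpe", pat_str=pat_str, mergeable_ranks=ranks, special_tokens=special)
+#   ids = tokeniser.encode("Hello world.")  # BPE ids followed by tokeniser.eot_token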
diff --git a/metavoice.egg-info/PKG-INFO b/metavoice.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..9420c404cca5f1547a6b00f6dce772ac44388bbb
--- /dev/null
+++ b/metavoice.egg-info/PKG-INFO
@@ -0,0 +1,6 @@
+Metadata-Version: 2.1
+Name: metavoice
+Version: 0.1.0
+Summary: Foundational model for text to speech
+Requires-Python: <3.12,>=3.10
+License-File: LICENSE
diff --git a/metavoice.egg-info/SOURCES.txt b/metavoice.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fac4a73619504712e00984bcbd29386a8eada8a9
--- /dev/null
+++ b/metavoice.egg-info/SOURCES.txt
@@ -0,0 +1,37 @@
+LICENSE
+README.md
+pyproject.toml
+setup.py
+fam/__init__.py
+fam/py.typed
+fam/llm/__init__.py
+fam/llm/decoders.py
+fam/llm/enhancers.py
+fam/llm/fast_inference.py
+fam/llm/fast_inference_utils.py
+fam/llm/fast_model.py
+fam/llm/inference.py
+fam/llm/model.py
+fam/llm/utils.py
+fam/llm/adapters/__init__.py
+fam/llm/adapters/base.py
+fam/llm/adapters/flattened_encodec.py
+fam/llm/adapters/tilted_encodec.py
+fam/llm/layers/__init__.py
+fam/llm/layers/attn.py
+fam/llm/layers/combined.py
+fam/llm/layers/layers.py
+fam/llm/mixins/__init__.py
+fam/llm/mixins/causal.py
+fam/llm/mixins/non_causal.py
+fam/quantiser/__init__.py
+fam/quantiser/audio/__init__.py
+fam/quantiser/audio/speaker_encoder/__init__.py
+fam/quantiser/audio/speaker_encoder/audio.py
+fam/quantiser/audio/speaker_encoder/model.py
+fam/quantiser/text/__init__.py
+fam/quantiser/text/tokenise.py
+metavoice.egg-info/PKG-INFO
+metavoice.egg-info/SOURCES.txt
+metavoice.egg-info/dependency_links.txt
+metavoice.egg-info/top_level.txt
\ No newline at end of file
diff --git a/metavoice.egg-info/dependency_links.txt b/metavoice.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/metavoice.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/metavoice.egg-info/top_level.txt b/metavoice.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2ae0f8da9ef41d1b196ea73ab7e48bf3f5c422b6
--- /dev/null
+++ b/metavoice.egg-info/top_level.txt
@@ -0,0 +1 @@
+fam
diff --git a/outputs/synth_This_is_a_demo_of_text_to_020d77d4-f3d9-480c-9b96-39afd772e805.wav b/outputs/synth_This_is_a_demo_of_text_to_020d77d4-f3d9-480c-9b96-39afd772e805.wav
new file mode 100644
index 0000000000000000000000000000000000000000..d52c407c0f3988d3a97136df22f5a429a9a4bf51
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_020d77d4-f3d9-480c-9b96-39afd772e805.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_0a3de179-bd18-49be-bcae-8d9d9e7f9850.wav b/outputs/synth_This_is_a_demo_of_text_to_0a3de179-bd18-49be-bcae-8d9d9e7f9850.wav
new file mode 100644
index 0000000000000000000000000000000000000000..8443a8841614b765b90188d51ea740377d2ff3de
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_0a3de179-bd18-49be-bcae-8d9d9e7f9850.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_1c0b382e-7bb7-4c25-a5fc-74c3d2366f50.wav b/outputs/synth_This_is_a_demo_of_text_to_1c0b382e-7bb7-4c25-a5fc-74c3d2366f50.wav
new file mode 100644
index 0000000000000000000000000000000000000000..70d9d41dfca6f57a89ab2f02c83ea7c08d757901
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_1c0b382e-7bb7-4c25-a5fc-74c3d2366f50.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_28bf92c7-ce82-4b6b-ba17-de3e4b48cf3c.wav b/outputs/synth_This_is_a_demo_of_text_to_28bf92c7-ce82-4b6b-ba17-de3e4b48cf3c.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2d3e4caa7a096ace24f99b3fd9b12114533b5cdd
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_28bf92c7-ce82-4b6b-ba17-de3e4b48cf3c.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_30cf7f84-349a-4f8f-bba3-610a299338e2.wav b/outputs/synth_This_is_a_demo_of_text_to_30cf7f84-349a-4f8f-bba3-610a299338e2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2dfc9e980e4bdb70d8f3d07c53b2cdf2329af49c
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_30cf7f84-349a-4f8f-bba3-610a299338e2.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_40a44025-d15c-4392-a039-f64432dd2126.wav b/outputs/synth_This_is_a_demo_of_text_to_40a44025-d15c-4392-a039-f64432dd2126.wav
new file mode 100644
index 0000000000000000000000000000000000000000..4427055b688e4975363833b334a7031278f89e78
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_40a44025-d15c-4392-a039-f64432dd2126.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_5f290957-0b43-473e-a192-508c4f40fd53.wav b/outputs/synth_This_is_a_demo_of_text_to_5f290957-0b43-473e-a192-508c4f40fd53.wav
new file mode 100644
index 0000000000000000000000000000000000000000..6227fb8d565fdbd09d5763f48477403f0de13210
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_5f290957-0b43-473e-a192-508c4f40fd53.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_6f3d8b4c-fa33-4a32-82c5-3dd84251cc8f.wav b/outputs/synth_This_is_a_demo_of_text_to_6f3d8b4c-fa33-4a32-82c5-3dd84251cc8f.wav
new file mode 100644
index 0000000000000000000000000000000000000000..1f0542ab6f7154cd621dad9a9657c55fcc27f171
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_6f3d8b4c-fa33-4a32-82c5-3dd84251cc8f.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_95652691-8fb0-43f5-acc5-caad33ac4016.wav b/outputs/synth_This_is_a_demo_of_text_to_95652691-8fb0-43f5-acc5-caad33ac4016.wav
new file mode 100644
index 0000000000000000000000000000000000000000..49e232f1ad26d9d8ee74bd0bd32ebdb63199080e
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_95652691-8fb0-43f5-acc5-caad33ac4016.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_98436b97-3a96-4481-82cd-a33b0287bd5e.wav b/outputs/synth_This_is_a_demo_of_text_to_98436b97-3a96-4481-82cd-a33b0287bd5e.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2b965e8df4a42cb66d5f31d9e5c2ef6851d6d5ac
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_98436b97-3a96-4481-82cd-a33b0287bd5e.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_994ee6b9-cf1e-40ac-a250-ff5e754bd6c7.wav b/outputs/synth_This_is_a_demo_of_text_to_994ee6b9-cf1e-40ac-a250-ff5e754bd6c7.wav
new file mode 100644
index 0000000000000000000000000000000000000000..c7a72c55ad913f8ad4129c73f563d2c54bac5069
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_994ee6b9-cf1e-40ac-a250-ff5e754bd6c7.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_ab7ba686-fa40-486b-b70a-4e3caaf8696e.wav b/outputs/synth_This_is_a_demo_of_text_to_ab7ba686-fa40-486b-b70a-4e3caaf8696e.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ce869a6fb6e80862740340aea0a960b0ad8fa8fb
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_ab7ba686-fa40-486b-b70a-4e3caaf8696e.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_bf0774d0-1716-4857-9b7c-9707b706641f.wav b/outputs/synth_This_is_a_demo_of_text_to_bf0774d0-1716-4857-9b7c-9707b706641f.wav
new file mode 100644
index 0000000000000000000000000000000000000000..f4ceb1671eebf3d646350fb97c9c753a4f987843
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_bf0774d0-1716-4857-9b7c-9707b706641f.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_c9aed45b-b6b6-4b29-a68f-52b57f6c6bb5.wav b/outputs/synth_This_is_a_demo_of_text_to_c9aed45b-b6b6-4b29-a68f-52b57f6c6bb5.wav
new file mode 100644
index 0000000000000000000000000000000000000000..3ed7715d2c4f6f4fc77e9899ceb69f31aa0ffd34
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_c9aed45b-b6b6-4b29-a68f-52b57f6c6bb5.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_d6bc9cbd-67f2-4842-ae0c-e339cd320cff.wav b/outputs/synth_This_is_a_demo_of_text_to_d6bc9cbd-67f2-4842-ae0c-e339cd320cff.wav
new file mode 100644
index 0000000000000000000000000000000000000000..29b23857f81f79424d53ba0d16626dfa41f45c8a
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_d6bc9cbd-67f2-4842-ae0c-e339cd320cff.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_e25a753e-a0f7-425c-83e2-7f5a87d404e9.wav b/outputs/synth_This_is_a_demo_of_text_to_e25a753e-a0f7-425c-83e2-7f5a87d404e9.wav
new file mode 100644
index 0000000000000000000000000000000000000000..d2c42836ab04faa718c1a8ac679b1f95c7a0bb50
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_e25a753e-a0f7-425c-83e2-7f5a87d404e9.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_e324e3f3-eec0-49be-8124-63a099fb19bb.wav b/outputs/synth_This_is_a_demo_of_text_to_e324e3f3-eec0-49be-8124-63a099fb19bb.wav
new file mode 100644
index 0000000000000000000000000000000000000000..35639f56d78128dc0469db98d9501d8ce3b19ab4
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_e324e3f3-eec0-49be-8124-63a099fb19bb.wav differ
diff --git a/outputs/synth_This_is_a_demo_of_text_to_eece1d71-71bf-4bb8-a34d-a7a720942a23.wav b/outputs/synth_This_is_a_demo_of_text_to_eece1d71-71bf-4bb8-a34d-a7a720942a23.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2ce59b89510db7beb9a6ac5745594cdee4457bfa
Binary files /dev/null and b/outputs/synth_This_is_a_demo_of_text_to_eece1d71-71bf-4bb8-a34d-a7a720942a23.wav differ
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..b87e85141d611bf645134d9db4504874d15ff8c8
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,21 @@
+[project]
+name = "metavoice"
+version = "0.1.0"
+description = "Foundational model for text to speech"
+requires-python = ">=3.10,<3.12"
+
+[tool.black]
+line-length = 120
+exclude = '''
+/(
+ \.git
+ | \.mypy_cache
+ | \.tox
+ | _build
+ | build
+ | dist
+)/
+'''
+
+[tool.isort]
+profile = "black"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d72a3f1a98aca7cbc674557b64829a87be586538
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,20 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch
+
+
+transformers==4.33.1
+librosa
+tqdm
+tiktoken==0.5.1
+audiocraft
+numpy<1.25
+ninja
+fastapi
+uvicorn
+tyro
+deepfilternet
+pydub
+soundfile
+huggingface_hub
+# note: `pickle` is part of the Python standard library; it is not a pip-installable package
+
diff --git a/serving.py b/serving.py
new file mode 100644
index 0000000000000000000000000000000000000000..48daf269d33fca773153834b751ed3792ab2d34c
--- /dev/null
+++ b/serving.py
@@ -0,0 +1,144 @@
+import json
+import logging
+import shlex
+import subprocess
+import tempfile
+import warnings
+from pathlib import Path
+from typing import Optional
+
+import fastapi
+import fastapi.middleware.cors
+import tyro
+import uvicorn
+from attr import dataclass
+from fastapi import Request
+from fastapi.responses import Response
+
+from fam.llm.fast_inference import TTS
+from fam.llm.utils import check_audio_file
+
+logger = logging.getLogger(__name__)
+
+
+## Setup FastAPI server.
+app = fastapi.FastAPI()
+
+
+@dataclass
+class ServingConfig:
+ huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1"
+    """Hugging Face repository ID of the model to serve."""
+
+ temperature: float = 1.0
+ """Temperature for sampling applied to both models."""
+
+ seed: int = 1337
+ """Random seed for sampling."""
+
+ port: int = 58003
+
+
+# Singleton
+class _GlobalState:
+ config: ServingConfig
+ tts: TTS
+
+
+GlobalState = _GlobalState()
+
+
+@dataclass(frozen=True)
+class TTSRequest:
+ text: str
+ speaker_ref_path: Optional[str] = None
+ guidance: float = 3.0
+ top_p: float = 0.95
+ top_k: Optional[int] = None
+
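+# Illustrative request (fields follow TTSRequest; reference audio, if any, is sent as the raw request body):
+#   curl -X POST http://localhost:58003/tts \
+#     -H 'X-Payload: {"text": "Hello world", "guidance": 3.0, "top_p": 0.95}' \
+#     --data-binary @reference.wav -o out.wav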
+
+@app.get("/health")
+async def health_check():
+ return {"status": "ok"}
+
+
+@app.post("/tts", response_class=Response)
+async def text_to_speech(req: Request):
+ audiodata = await req.body()
+ payload = None
+ wav_out_path = None
+
+ try:
+ headers = req.headers
+ payload = headers["X-Payload"]
+ payload = json.loads(payload)
+ tts_req = TTSRequest(**payload)
+ with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
+ if tts_req.speaker_ref_path is None:
+ wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
+ check_audio_file(wav_path)
+ else:
+ # TODO: fix
+ wav_path = tts_req.speaker_ref_path
+
+ if wav_path is None:
+ warnings.warn("Running without speaker reference")
+ assert tts_req.guidance is None
+
+ wav_out_path = GlobalState.tts.synthesise(
+ text=tts_req.text,
+ spk_ref_path=wav_path,
+ top_p=tts_req.top_p,
+ guidance_scale=tts_req.guidance,
+ )
+
+ with open(wav_out_path, "rb") as f:
+ return Response(content=f.read(), media_type="audio/wav")
+ except Exception as e:
+ # traceback_str = "".join(traceback.format_tb(e.__traceback__))
+ logger.exception(f"Error processing request {payload}")
+ return Response(
+ content="Something went wrong. Please try again in a few mins or contact us on Discord",
+ status_code=500,
+ )
+ finally:
+ if wav_out_path is not None:
+ Path(wav_out_path).unlink(missing_ok=True)
+
+
+def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
+ with tempfile.NamedTemporaryFile() as unknown_format_tmp:
+ if unknown_format_tmp.write(audiodata) == 0:
+ return None
+ unknown_format_tmp.flush()
+
+ subprocess.check_output(
+ # arbitrary 2 minute cutoff
+ shlex.split(f"ffmpeg -t 120 -y -i {unknown_format_tmp.name} -f wav {wav_tmp.name}")
+ )
+
+ return wav_tmp.name
+
+
+if __name__ == "__main__":
+ for name in logging.root.manager.loggerDict:
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.INFO)
+ logging.root.setLevel(logging.INFO)
+
+ GlobalState.config = tyro.cli(ServingConfig)
+ GlobalState.tts = TTS(seed=GlobalState.config.seed)
+
+ app.add_middleware(
+ fastapi.middleware.cors.CORSMiddleware,
+ allow_origins=["*", f"http://localhost:{GlobalState.config.port}", "http://localhost:3000"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+ uvicorn.run(
+ app,
+ host="0.0.0.0",
+ port=GlobalState.config.port,
+ log_level="info",
+ )
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9dc0608d9e2050446eb744a9ae557dd40f99264
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,6 @@
+from setuptools import find_packages, setup # type: ignore
+
+setup(
+ name="fam",
+ packages=find_packages(".", exclude=["tests"]),
+)