TomatoCocotree committed on
Commit
6a62ffb
1 Parent(s): 59656d8
This view is limited to 50 files because it contains too many changes; see the raw diff for the full commit.
Files changed (50)
  1. Dockerfile +21 -0
  2. LICENSE +24 -0
  3. api_key.txt +1 -0
  4. constants.py +49 -0
  5. data/models/coqui/.placeholder +2 -0
  6. data/models/rvc/.placeholder +3 -0
  7. data/tmp/.placeholder +2 -0
  8. docker/Dockerfile +35 -0
  9. docker/docker-compose.yml +23 -0
  10. docker/readme.md +10 -0
  11. modules/classify/classify_module.py +41 -0
  12. modules/speech_recognition/streaming_module.py +121 -0
  13. modules/speech_recognition/vosk_module.py +77 -0
  14. modules/speech_recognition/whisper_module.py +56 -0
  15. modules/text_to_speech/coqui/coqui_module.py +333 -0
  16. modules/utils.py +15 -0
  17. modules/voice_conversion/fairseq/LICENSE +21 -0
  18. modules/voice_conversion/fairseq/__init__.py +45 -0
  19. modules/voice_conversion/fairseq/binarizer.py +381 -0
  20. modules/voice_conversion/fairseq/checkpoint_utils.py +905 -0
  21. modules/voice_conversion/fairseq/data/__init__.py +130 -0
  22. modules/voice_conversion/fairseq/data/add_target_dataset.py +83 -0
  23. modules/voice_conversion/fairseq/data/append_token_dataset.py +41 -0
  24. modules/voice_conversion/fairseq/data/audio/__init__.py +93 -0
  25. modules/voice_conversion/fairseq/data/audio/audio_utils.py +389 -0
  26. modules/voice_conversion/fairseq/data/audio/data_cfg.py +387 -0
  27. modules/voice_conversion/fairseq/data/audio/dataset_transforms/__init__.py +53 -0
  28. modules/voice_conversion/fairseq/data/audio/dataset_transforms/concataugment.py +61 -0
  29. modules/voice_conversion/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py +105 -0
  30. modules/voice_conversion/fairseq/data/audio/feature_transforms/__init__.py +43 -0
  31. modules/voice_conversion/fairseq/data/audio/feature_transforms/delta_deltas.py +37 -0
  32. modules/voice_conversion/fairseq/data/audio/feature_transforms/global_cmvn.py +29 -0
  33. modules/voice_conversion/fairseq/data/audio/feature_transforms/specaugment.py +131 -0
  34. modules/voice_conversion/fairseq/data/audio/feature_transforms/utterance_cmvn.py +41 -0
  35. modules/voice_conversion/fairseq/data/audio/frm_text_to_speech_dataset.py +205 -0
  36. modules/voice_conversion/fairseq/data/audio/hubert_dataset.py +356 -0
  37. modules/voice_conversion/fairseq/data/audio/multi_modality_dataset.py +284 -0
  38. modules/voice_conversion/fairseq/data/audio/raw_audio_dataset.py +393 -0
  39. modules/voice_conversion/fairseq/data/audio/speech_to_speech_dataset.py +379 -0
  40. modules/voice_conversion/fairseq/data/audio/speech_to_text_dataset.py +733 -0
  41. modules/voice_conversion/fairseq/data/audio/speech_to_text_joint_dataset.py +359 -0
  42. modules/voice_conversion/fairseq/data/audio/text_to_speech_dataset.py +250 -0
  43. modules/voice_conversion/fairseq/data/audio/waveform_transforms/__init__.py +48 -0
  44. modules/voice_conversion/fairseq/data/audio/waveform_transforms/noiseaugment.py +201 -0
  45. modules/voice_conversion/fairseq/data/backtranslation_dataset.py +165 -0
  46. modules/voice_conversion/fairseq/data/base_wrapper_dataset.py +78 -0
  47. modules/voice_conversion/fairseq/data/bucket_pad_length_dataset.py +78 -0
  48. modules/voice_conversion/fairseq/data/codedataset.py +576 -0
  49. modules/voice_conversion/fairseq/data/colorize_dataset.py +25 -0
  50. modules/voice_conversion/fairseq/data/concat_dataset.py +124 -0
Dockerfile ADDED
@@ -0,0 +1,21 @@
1
+ FROM python:3.11
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements-complete.txt .
6
+ RUN pip install -r requirements-complete.txt
7
+
8
+ RUN mkdir /.cache && chmod -R 777 /.cache
9
+ RUN mkdir .chroma && chmod -R 777 .chroma
10
+
11
+ COPY . .
12
+
13
+
14
+ RUN chmod -R 777 /app
15
+
16
+ RUN --mount=type=secret,id=password,mode=0444,required=true \
17
+ cat /run/secrets/password > /test
18
+
19
+ EXPOSE 7860
20
+
21
+ CMD ["python", "server.py", "--cpu", "--enable-modules=caption,summarize,classify,silero-tts,edge-tts,chromadb"]
LICENSE ADDED
@@ -0,0 +1,24 @@
1
+ This is free and unencumbered software released into the public domain.
2
+
3
+ Anyone is free to copy, modify, publish, use, compile, sell, or
4
+ distribute this software, either in source code form or as a compiled
5
+ binary, for any purpose, commercial or non-commercial, and by any
6
+ means.
7
+
8
+ In jurisdictions that recognize copyright laws, the author or authors
9
+ of this software dedicate any and all copyright interest in the
10
+ software to the public domain. We make this dedication for the benefit
11
+ of the public at large and to the detriment of our heirs and
12
+ successors. We intend this dedication to be an overt act of
13
+ relinquishment in perpetuity of all present and future rights to this
14
+ software under copyright law.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19
+ IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20
+ OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21
+ ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22
+ OTHER DEALINGS IN THE SOFTWARE.
23
+
24
+ For more information, please refer to <https://unlicense.org>
api_key.txt ADDED
@@ -0,0 +1 @@
1
+ CHANGEME
constants.py ADDED
@@ -0,0 +1,49 @@
1
+ # Constants
2
+ DEFAULT_CUDA_DEVICE = "cuda:0"
3
+ # Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
4
+ DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
5
+ # Also try: 'joeddav/distilbert-base-uncased-go-emotions-student'
6
+ DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
7
+ # Also try: 'Salesforce/blip-image-captioning-base'
8
+ DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
9
+ DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
10
+ DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
11
+ DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
12
+ DEFAULT_REMOTE_SD_PORT = 7860
13
+ DEFAULT_CHROMA_PORT = 8000
14
+ SILERO_SAMPLES_PATH = "tts_samples"
15
+ SILERO_SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog"
16
+ DEFAULT_SUMMARIZE_PARAMS = {
17
+ "temperature": 1.0,
18
+ "repetition_penalty": 1.0,
19
+ "max_length": 500,
20
+ "min_length": 200,
21
+ "length_penalty": 1.5,
22
+ "bad_words": [
23
+ "\n",
24
+ '"',
25
+ "*",
26
+ "[",
27
+ "]",
28
+ "{",
29
+ "}",
30
+ ":",
31
+ "(",
32
+ ")",
33
+ "<",
34
+ ">",
35
+ "Â",
36
+ "The text ends",
37
+ "The story ends",
38
+ "The text is",
39
+ "The story is",
40
+ ],
41
+ }
42
+
43
+ PROMPT_PREFIX = "best quality, absurdres, "
44
+ NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
45
+ error hands, bad hands, error fingers, bad fingers, missing fingers
46
+ error legs, bad legs, multiple legs, missing legs, error lighting,
47
+ error shadow, error reflection, text, error, extra digit, fewer digits,
48
+ cropped, worst quality, low quality, normal quality, jpeg artifacts,
49
+ signature, watermark, username, blurry"""
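The server code that consumes these constants is not part of this 50-file view, so the following is only a hedged sketch of how DEFAULT_SUMMARIZATION_MODEL and DEFAULT_SUMMARIZE_PARAMS might be fed to a transformers summarization pipeline; the pipeline call and generation keywords are standard transformers API, but the wiring itself is an assumption.

    # Hedged sketch: the summarize module itself is not shown in this view.
    from transformers import pipeline

    from constants import DEFAULT_SUMMARIZATION_MODEL, DEFAULT_SUMMARIZE_PARAMS

    summarizer = pipeline("summarization", model=DEFAULT_SUMMARIZATION_MODEL)
    summary = summarizer(
        "Long chat log to compress ...",
        max_length=DEFAULT_SUMMARIZE_PARAMS["max_length"],
        min_length=DEFAULT_SUMMARIZE_PARAMS["min_length"],
        length_penalty=DEFAULT_SUMMARIZE_PARAMS["length_penalty"],
        repetition_penalty=DEFAULT_SUMMARIZE_PARAMS["repetition_penalty"],
    )[0]["summary_text"]
    print(summary)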
data/models/coqui/.placeholder ADDED
@@ -0,0 +1,2 @@
1
+ Put Coqui model folders here.
2
+ Must contain both a "model.pth" and a "config.json" file.
data/models/rvc/.placeholder ADDED
@@ -0,0 +1,3 @@
1
+ Put RVC model folders here.
2
+ Each folder must have a ".pth" file in it.
3
+ A ".index" file is optional but can help improve processing time/quality.
data/tmp/.placeholder ADDED
@@ -0,0 +1,2 @@
1
+ This is a temporary file folder.
2
+ It may contain RVC input/output files kept for research purposes.
docker/Dockerfile ADDED
@@ -0,0 +1,35 @@
1
+ FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
2
+
3
+ EXPOSE 5100
4
+
5
+ ENV PATH="/root/miniconda3/bin:${PATH}"
6
+ ARG PATH="/root/miniconda3/bin:${PATH}"
7
+
8
+ ENV DEBIAN_FRONTEND noninteractive
9
+ RUN apt-get update && apt-get install -y --no-install-recommends \
10
+ python3 python3-venv wget build-essential
11
+
12
+ RUN wget \
13
+ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
14
+ && mkdir /root/.conda \
15
+ && bash Miniconda3-latest-Linux-x86_64.sh -b \
16
+ && rm -f Miniconda3-latest-Linux-x86_64.sh
17
+
18
+ RUN conda --version
19
+
20
+ RUN conda init
21
+
22
+ RUN conda create -n extras
23
+
24
+ RUN /bin/bash -c "source activate extras"
25
+
26
+ RUN conda install pytorch torchvision torchaudio pytorch-cuda=11.7 git -c pytorch -c nvidia -c conda-forge
27
+
28
+ WORKDIR /sillytavern-extras/
29
+ COPY . .
30
+
31
+ ARG REQUIREMENTS
32
+ RUN pip install -r $REQUIREMENTS
33
+
34
+ ARG MODULES
35
+ CMD ["python","server.py","--enable-modules=$MODULES"]
docker/docker-compose.yml ADDED
@@ -0,0 +1,23 @@
1
+ version: "3"
2
+ services:
3
+ sillytavern-extras:
4
+ runtime: nvidia
5
+ image: cohee1207/sillytavern-extras
6
+ build:
7
+ context: ../
8
+ dockerfile: docker/Dockerfile
9
+ args:
10
+ REQUIREMENTS: requirements.txt
11
+ MODULES: caption,summarize,classify
12
+ # REQUIREMENTS: requirements-complete.txt
13
+ # MODULES: caption,summarize,classify,sd,silero-tts,edge-tts,chromadb
14
+ volumes:
15
+ #- "./chromadb:/chromadb"
16
+ - "./cache:/root/.cache"
17
+ - "./api_key.txt:/sillytavern-extras/api_key.txt:rw"
18
+ ports:
19
+ - "5100:5100"
20
+ environment:
21
+ - NVIDIA_VISIBLE_DEVICES=all
22
+ command: python server.py --enable-modules=caption,summarize,classify
23
+ # command: python server.py --enable-modules=caption,summarize,classify,sd,silero-tts,edge-tts,chromadb
docker/readme.md ADDED
@@ -0,0 +1,10 @@
1
+ # Docker Usage
2
+
3
+ ## Building the image
4
+
5
+ *This assumes you have Docker and Docker Compose installed and running.*
6
+
7
+ 1. Open a terminal and set your current directory to the "docker" directory in your clone of this repo.
8
+ 2. Adjust the "docker-compose.yml" file to match your needs. The default selection and the selection with all modules are provided as examples.
9
+ 3. Once ready, run the command "docker compose build" to build the "cohee1207/sillytavern-extras" docker image.
10
+
modules/classify/classify_module.py ADDED
@@ -0,0 +1,41 @@
1
+ """
2
+ Classify module for SillyTavern Extras
3
+
4
+ Authors:
5
+ - Tony Ribeiro (https://github.com/Tony-sama)
6
+ - Cohee (https://github.com/Cohee1207)
7
+
8
+ Provides classification features for text
9
+
10
+ References:
11
+ - https://huggingface.co/tasks/text-classification
12
+ """
13
+
14
+ from transformers import pipeline
15
+
16
+ DEBUG_PREFIX = "<Classify module>"
17
+
18
+ # Models init
19
+
20
+ text_emotion_pipe = None
21
+
22
+ def init_text_emotion_classifier(model_name: str, device: str, torch_dtype: str) -> None:
23
+ global text_emotion_pipe
24
+
25
+ print(DEBUG_PREFIX,"Initializing text classification pipeline with model",model_name)
26
+ text_emotion_pipe = pipeline(
27
+ "text-classification",
28
+ model=model_name,
29
+ top_k=None,
30
+ device=device,
31
+ torch_dtype=torch_dtype,
32
+ )
33
+
34
+
35
+ def classify_text_emotion(text: str) -> list:
36
+ output = text_emotion_pipe(
37
+ text,
38
+ truncation=True,
39
+ max_length=text_emotion_pipe.model.config.max_position_embeddings,
40
+ )[0]
41
+ return sorted(output, key=lambda x: x["score"], reverse=True)
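A minimal usage sketch for the module above. The import path and the cpu/float32 arguments are assumptions; the function names and the default model string come from the code and constants in this commit.

    import torch

    from modules.classify import classify_module

    # Initialize the pipeline once, then classify as many strings as needed.
    classify_module.init_text_emotion_classifier(
        "nateraw/bert-base-uncased-emotion",  # DEFAULT_CLASSIFICATION_MODEL in constants.py
        device="cpu",
        torch_dtype=torch.float32,
    )
    scores = classify_module.classify_text_emotion("I am so glad you are back!")
    print(scores[0])  # highest-scoring emotion first, e.g. {'label': 'joy', 'score': ...}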
modules/speech_recognition/streaming_module.py ADDED
@@ -0,0 +1,121 @@
1
+ """
2
+ Speech-to-text module based on Vosk and Whisper for SillyTavern Extras
3
+ - Vosk website: https://alphacephei.com/vosk/
4
+ - Vosk api: https://github.com/alphacep/vosk-api
5
+ - Whisper github: https://github.com/openai/whisper
6
+
7
+ Authors:
8
+ - Tony Ribeiro (https://github.com/Tony-sama)
9
+
10
+ Models are saved into user cache folder, example: C:/Users/toto/.cache/whisper and C:/Users/toto/.cache/vosk
11
+
12
+ References:
13
+ - Code adapted from:
14
+ - whisper github: https://github.com/openai/whisper
15
+ - oobabooga text-generation-webui github: https://github.com/oobabooga/text-generation-webui
16
+ - vosk github: https://github.com/alphacep/vosk-api/blob/master/python/example/test_microphone.py
17
+ """
18
+ from flask import jsonify, abort
19
+
20
+ import queue
21
+ import sys
22
+ import sounddevice as sd
23
+ import soundfile as sf
24
+ import io
25
+ import numpy as np
26
+ from scipy.io.wavfile import write
27
+
28
+ import vosk
29
+ import whisper
30
+
31
+ DEBUG_PREFIX = "<stt streaming module>"
32
+ RECORDING_FILE_PATH = "stt_test.wav"
33
+
34
+ whisper_model = None
35
+ vosk_model = None
36
+ device = None
37
+
38
+ def load_model(file_path=None):
39
+ """
40
+ Load the given Whisper model (or the default "base.en") together with the default en-us Vosk model.
41
+ Models are downloaded to the user cache folder, e.g. C:/Users/toto/.cache/whisper and C:/Users/toto/.cache/vosk
42
+ """
43
+
44
+ if file_path is None:
45
+ return (whisper.load_model("base.en"), vosk.Model(lang="en-us"))
46
+ else:
47
+ return (whisper.load_model(file_path), vosk.Model(lang="en-us"))
48
+
49
+ def convert_bytearray_to_wav_ndarray(input_bytearray: bytes, sampling_rate=16000):
50
+ """
51
+ Convert a bytearray to WAV format so it can be written to a file for quality-check debugging
52
+ """
53
+ bytes_wav = bytes()
54
+ byte_io = io.BytesIO(bytes_wav)
55
+ write(byte_io, sampling_rate, np.frombuffer(input_bytearray, dtype=np.int16))
56
+ output_wav = byte_io.read()
57
+ output, _ = sf.read(io.BytesIO(output_wav))
58
+ return output
59
+
60
+ def record_and_transcript():
61
+ """
62
+ Continuously record from the microphone and transcribe speech.
63
+ Return the transcript once no more voice is detected.
64
+ """
65
+ if whisper_model is None:
66
+ print(DEBUG_PREFIX,"Whisper model not initialized yet.")
67
+ return ""
68
+
69
+ q = queue.Queue()
70
+ stream_errors = list()
71
+
72
+ def callback(indata, frames, time, status):
73
+ """This is called (from a separate thread) for each audio block."""
74
+ if status:
75
+ print(status, file=sys.stderr)
76
+ stream_errors.append(status)
77
+ q.put(bytes(indata))
78
+
79
+ try:
80
+ device_info = sd.query_devices(device, "input")
81
+ # soundfile expects an int, sounddevice provides a float:
82
+ samplerate = int(device_info["default_samplerate"])
83
+
84
+ print(DEBUG_PREFIX, "Start recording from:", device_info["name"], "with samplerate", samplerate)
85
+
86
+ with sd.RawInputStream(samplerate=samplerate, blocksize = 8000, device=device, dtype="int16", channels=1, callback=callback):
87
+
88
+ rec = vosk.KaldiRecognizer(vosk_model, samplerate)
89
+ full_recording = bytearray()
90
+ while True:
91
+ data = q.get()
92
+ if len(stream_errors) > 0:
93
+ raise Exception(DEBUG_PREFIX+" Stream errors: "+str(stream_errors))
94
+
95
+ full_recording.extend(data)
96
+
97
+ if rec.AcceptWaveform(data):
98
+ # Extract transcript string
99
+ transcript = rec.Result()[14:-3]
100
+ print(DEBUG_PREFIX, "Transcripted from microphone stream (vosk):", transcript)
101
+
102
+ # ----------------------------------
103
+ # DEBUG: save recording to wav file
104
+ # ----------------------------------
105
+ output_file = convert_bytearray_to_wav_ndarray(input_bytearray=full_recording, sampling_rate=samplerate)
106
+ sf.write(file=RECORDING_FILE_PATH, data=output_file, samplerate=samplerate)
107
+ print(DEBUG_PREFIX, "Recorded message saved to", RECORDING_FILE_PATH)
108
+
109
+ # Whisper HACK
110
+ result = whisper_model.transcribe(RECORDING_FILE_PATH)
111
+ transcript = result["text"]
112
+ print(DEBUG_PREFIX, "Transcripted from audio file (whisper):", transcript)
113
+ # ----------------------------------
114
+
115
+ return jsonify({"transcript": transcript})
116
+ #else:
117
+ # print(rec.PartialResult())
118
+
119
+ except Exception as e: # No exception observed during test but we never know
120
+ print(e)
121
+ abort(500, DEBUG_PREFIX+" Exception occurs while recording")
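A hedged sketch of how the module globals above might be wired up; server.py is not part of this view, so the initialization below is an assumption, while the names whisper_model, vosk_model, device and load_model() come from the module itself.

    from modules.speech_recognition import streaming_module

    # load_model() returns a (whisper, vosk) pair; the module keeps both as globals.
    streaming_module.whisper_model, streaming_module.vosk_model = streaming_module.load_model()
    streaming_module.device = None  # None lets sounddevice pick the default input device

    # record_and_transcript() must then be called inside a Flask request context,
    # since it returns jsonify({"transcript": ...}).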
modules/speech_recognition/vosk_module.py ADDED
@@ -0,0 +1,77 @@
1
+ """
2
+ Speech-to-text module based on Vosk for SillyTavern Extras
3
+ - Vosk website: https://alphacephei.com/vosk/
4
+ - Vosk api: https://github.com/alphacep/vosk-api
5
+
6
+ Authors:
7
+ - Tony Ribeiro (https://github.com/Tony-sama)
8
+
9
+ Models are saved into user cache folder, example: C:/Users/toto/.cache/vosk
10
+
11
+ References:
12
+ - Code adapted from: https://github.com/alphacep/vosk-api/blob/master/python/example/test_simple.py
13
+ """
14
+ from flask import jsonify, abort, request
15
+
16
+ import wave
17
+ from vosk import Model, KaldiRecognizer, SetLogLevel
18
+ import soundfile
19
+
20
+ DEBUG_PREFIX = "<stt vosk module>"
21
+ RECORDING_FILE_PATH = "stt_test.wav"
22
+
23
+ model = None
24
+
25
+ SetLogLevel(-1)
26
+
27
+ def load_model(file_path=None):
28
+ """
29
+ Load given vosk model from file or default to en-us model.
30
+ Download model to user cache folder, example: C:/Users/toto/.cache/vosk
31
+ """
32
+
33
+ if file_path is None:
34
+ return Model(lang="en-us")
35
+ else:
36
+ return Model(file_path)
37
+
38
+ def process_audio():
39
+ """
40
+ Transcribe the request's audio file to text using Vosk
41
+ """
42
+
43
+ if model is None:
44
+ print(DEBUG_PREFIX,"Vosk model not initialized yet.")
45
+ return ""
46
+
47
+ try:
48
+ file = request.files.get('AudioFile')
49
+ file.save(RECORDING_FILE_PATH)
50
+
51
+ # Read and rewrite the file with soundfile
52
+ data, samplerate = soundfile.read(RECORDING_FILE_PATH)
53
+ soundfile.write(RECORDING_FILE_PATH, data, samplerate)
54
+
55
+ wf = wave.open(RECORDING_FILE_PATH, "rb")
56
+ if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
57
+ print("Audio file must be WAV format mono PCM.")
58
+ abort(500, DEBUG_PREFIX+" Audio file must be WAV format mono PCM.")
59
+
60
+ rec = KaldiRecognizer(model, wf.getframerate())
61
+ #rec.SetWords(True)
62
+ #rec.SetPartialWords(True)
63
+
64
+ while True:
65
+ data = wf.readframes(4000)
66
+ if len(data) == 0:
67
+ break
68
+ if rec.AcceptWaveform(data):
69
+ break
70
+
71
+ transcript = rec.Result()[14:-3]
72
+ print(DEBUG_PREFIX, "Transcripted from request audio file:", transcript)
73
+ return jsonify({"transcript": transcript})
74
+
75
+ except Exception as e: # No exception observed during test but we never know
76
+ print(e)
77
+ abort(500, DEBUG_PREFIX+" Exception occurs while processing audio")
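One detail worth noting in process_audio() above: rec.Result() returns a small JSON document, and the slice [14:-3] peels the "text" value out of it by hand. A sketch of the more explicit equivalent, assuming the usual Vosk result layout:

    import json

    from vosk import KaldiRecognizer

    def result_text(rec: KaldiRecognizer) -> str:
        """Parse the recognizer's JSON result instead of slicing fixed offsets."""
        return json.loads(rec.Result())["text"]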
modules/speech_recognition/whisper_module.py ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ Speech-to-text module based on Whisper for SillyTavern Extras
3
+ - Whisper github: https://github.com/openai/whisper
4
+
5
+ Authors:
6
+ - Tony Ribeiro (https://github.com/Tony-sama)
7
+
8
+ Models are saved into user cache folder, example: C:/Users/toto/.cache/whisper
9
+
10
+ References:
11
+ - Code adapted from:
12
+ - whisper github: https://github.com/openai/whisper
13
+ - oobabooga text-generation-webui github: https://github.com/oobabooga/text-generation-webui
14
+ """
15
+ from flask import jsonify, abort, request
16
+
17
+ import whisper
18
+
19
+ DEBUG_PREFIX = "<stt whisper module>"
20
+ RECORDING_FILE_PATH = "stt_test.wav"
21
+
22
+ model = None
23
+
24
+ def load_model(file_path=None):
25
+ """
26
+ Load the given Whisper model from file or default to the "base.en" model.
27
+ Models are downloaded to the user cache folder, example: C:/Users/toto/.cache/whisper
28
+ """
29
+
30
+ if file_path is None:
31
+ return whisper.load_model("base.en")
32
+ else:
33
+ return whisper.load_model(file_path)
34
+
35
+ def process_audio():
36
+ """
37
+ Transcribe the request's audio file to text using Whisper
38
+ """
39
+
40
+ if model is None:
41
+ print(DEBUG_PREFIX,"Whisper model not initialized yet.")
42
+ return ""
43
+
44
+ try:
45
+ file = request.files.get('AudioFile')
46
+ file.save(RECORDING_FILE_PATH)
47
+
48
+ result = model.transcribe(RECORDING_FILE_PATH)
49
+ transcript = result["text"]
50
+ print(DEBUG_PREFIX, "Transcripted from audio file (whisper):", transcript)
51
+
52
+ return jsonify({"transcript": transcript})
53
+
54
+ except Exception as e: # No exception observed during test but we never know
55
+ print(e)
56
+ abort(500, DEBUG_PREFIX+" Exception occurs while processing audio")
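For reference, the same two openai-whisper calls used by load_model() and process_audio() above can be exercised outside Flask on a local file; this is a sketch, and "stt_test.wav" (the module's RECORDING_FILE_PATH) is assumed to exist.

    import whisper

    model = whisper.load_model("base.en")      # same default as load_model()
    result = model.transcribe("stt_test.wav")  # RECORDING_FILE_PATH
    print(result["text"])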
modules/text_to_speech/coqui/coqui_module.py ADDED
@@ -0,0 +1,333 @@
1
+ """
2
+ Coqui module for SillyTavern Extras
3
+
4
+ Authors:
5
+ - Pyrater (https://github.com/pyrater)
6
+ - Tony Ribeiro (https://github.com/Tony-sama)
7
+
8
+ Models are saved into user cache folder: "C:/Users/<username>/AppData/Local/tts"
9
+
10
+ References:
11
+ - Code adapted from:
12
+ - Coqui TTS https://tts.readthedocs.io/en/latest/
13
+ - Audio-webui: https://github.com/gitmylo/audio-webui
14
+ """
15
+ import json
16
+ import os
17
+ import io
18
+ import shutil
19
+
20
+ from flask import abort, request, send_file, jsonify
21
+
22
+ from TTS.api import TTS
23
+ from TTS.utils.manage import ModelManager
24
+
25
+ from modules.utils import silence_log
26
+
27
+ DEBUG_PREFIX = "<Coqui-TTS module>"
28
+ COQUI_MODELS_PATH = "data/models/coqui/"
29
+ IGNORED_FILES = [".placeholder"]
30
+ COQUI_LOCAL_MODEL_FILE_NAME = "model.pth"
31
+ COQUI_LOCAL_CONFIG_FILE_NAME = "config.json"
32
+
33
+ gpu_mode = False
34
+ is_downloading = False
35
+
36
+ def install_model(model_id):
37
+ global gpu_mode
38
+ audio_buffer = io.BytesIO()
39
+ speaker_id = None
40
+ language_id = None
41
+
42
+ print(DEBUG_PREFIX,"Loading model",model_id)
43
+ try:
44
+ tts = TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
45
+
46
+ if tts.is_multi_lingual:
47
+ language_id = tts.languages[0]
48
+
49
+ if tts.is_multi_speaker:
50
+ speaker_id =tts.speakers[0]
51
+
52
+ tts.tts_to_file(text="this is a test message", file_path=audio_buffer, speaker=speaker_id, language=language_id)
53
+ except Exception as e:
54
+ print(DEBUG_PREFIX,"ERROR:", e)
55
+ print("Model", model_id, "cannot be loaded, maybe wrong model name? Must be one of")
56
+ for i in TTS.list_models():
57
+ print(i)
58
+ return False
59
+
60
+ print(DEBUG_PREFIX,"Success")
61
+ return True
62
+
63
+ def coqui_check_model_state():
64
+ """
65
+ Check if the requested model is installed on the server machine
66
+ """
67
+ try:
68
+ model_state = "absent"
69
+ request_json = request.get_json()
70
+ model_id = request_json["model_id"]
71
+
72
+ print(DEBUG_PREFIX,"Search for model", model_id)
73
+
74
+ coqui_models_folder = ModelManager().output_prefix # models location
75
+
76
+ # Check if tts folder exist
77
+ if os.path.isdir(coqui_models_folder):
78
+
79
+ installed_models = os.listdir(coqui_models_folder)
80
+
81
+ model_folder_exists = False
82
+ model_folder = None
83
+
84
+ for i in installed_models:
85
+ if model_id == i.replace("--","/",3): # installed folder names use "--" where the model id uses "/"
86
+ model_folder_exists = True
87
+ model_folder = i
88
+ print(DEBUG_PREFIX,"Folder found:",model_folder)
89
+
90
+ # Check failed download
91
+ if model_folder_exists:
92
+ content = os.listdir(os.path.join(coqui_models_folder,model_folder))
93
+ print(DEBUG_PREFIX,"Checking content:",content)
94
+ for i in content:
95
+ if i == model_folder+".zip":
96
+ print("Corrupt installed found, model download must have failed previously")
97
+ model_state = "corrupted"
98
+ break
99
+
100
+ if model_state != "corrupted":
101
+ model_state = "installed"
102
+
103
+ response = json.dumps({"model_state":model_state})
104
+ return response
105
+
106
+ except Exception as e:
107
+ print(e)
108
+ abort(500, DEBUG_PREFIX + " Exception occurs while trying to search for installed model")
109
+
110
+ def coqui_install_model():
111
+ """
112
+ Install the requested model on the server machine
113
+ """
114
+ global gpu_mode
115
+ global is_downloading
116
+
117
+ try:
118
+ model_installed = False
119
+ request_json = request.get_json()
120
+ model_id = request_json["model_id"]
121
+ action = request_json["action"]
122
+
123
+ print(DEBUG_PREFIX,"Received request",action,"for model",model_id)
124
+
125
+ if (is_downloading):
126
+ print(DEBUG_PREFIX,"Rejected, already downloading a model")
127
+ return json.dumps({"status":"downloading"})
128
+
129
+ coqui_models_folder = ModelManager().output_prefix # models location
130
+
131
+ # Check if tts folder exist
132
+ if os.path.isdir(coqui_models_folder):
133
+ installed_models = os.listdir(coqui_models_folder)
134
+ model_path = None
135
+
136
+ print(DEBUG_PREFIX,"Found",len(installed_models),"models in",coqui_models_folder)
137
+
138
+ for i in installed_models:
139
+ if model_id == i.replace("--","/"):
140
+ model_installed = True
141
+ model_path = os.path.join(coqui_models_folder,i)
142
+
143
+ if model_installed:
144
+ print(DEBUG_PREFIX,"model found:", model_id)
145
+ else:
146
+ print(DEBUG_PREFIX,"model not found")
147
+
148
+ if action == "download":
149
+ if model_installed:
150
+ abort(500, DEBUG_PREFIX + "Bad request, model already installed.")
151
+
152
+ is_downloading = True
153
+ TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
154
+ is_downloading = False
155
+
156
+ if action == "repare":
157
+ if not model_installed:
158
+ abort(500, DEBUG_PREFIX + " bad request: requesting repare of model not installed")
159
+
160
+
161
+ print(DEBUG_PREFIX,"Deleting corrupted model folder:",model_path)
162
+ shutil.rmtree(model_path, ignore_errors=True)
163
+
164
+ is_downloading = True
165
+ TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
166
+ is_downloading = False
167
+
168
+ response = json.dumps({"status":"done"})
169
+ return response
170
+
171
+ except Exception as e:
172
+ is_downloading = False
173
+ print(e)
174
+ abort(500, DEBUG_PREFIX + " Exception occurs while trying to search for installed model")
175
+
176
+ def coqui_get_local_models():
177
+ """
178
+ Return the list of local Coqui model folder names found in COQUI_MODELS_PATH
179
+ """
180
+ try:
181
+ print(DEBUG_PREFIX, "Received request for list of RVC models")
182
+
183
+ folder_names = os.listdir(COQUI_MODELS_PATH)
184
+
185
+ print(DEBUG_PREFIX,"Searching model in",COQUI_MODELS_PATH)
186
+
187
+ model_list = []
188
+ for folder_name in folder_names:
189
+ folder_path = COQUI_MODELS_PATH+folder_name
190
+
191
+ if folder_name in IGNORED_FILES:
192
+ continue
193
+
194
+ # Must be a folder
195
+ if not os.path.isdir(folder_path):
196
+ print("> WARNING:",folder_name,"is not a folder, it should not be there, ignored")
197
+ continue
198
+
199
+ print("> Found model folder",folder_name)
200
+
201
+ # Check pth
202
+ valid_folder = False
203
+ for file_name in os.listdir(folder_path):
204
+ if file_name.endswith(".pth"):
205
+ print(" > pth:",file_name)
206
+ valid_folder = True
207
+ if file_name.endswith(".config"):
208
+ print(" > config:",file_name)
209
+
210
+ if valid_folder:
211
+ print(" > Valid folder added to list")
212
+ model_list.append(folder_name)
213
+ else:
214
+ print(" > WARNING: Missing pth or config file, ignored folder")
215
+
216
+ # Return the list of valid folders
217
+ response = json.dumps({"models_list":model_list})
218
+ return response
219
+
220
+ except Exception as e:
221
+ print(e)
222
+ abort(500, DEBUG_PREFIX + " Exception occurs while searching for Coqui models.")
223
+
224
+
225
+
226
+ def coqui_generate_tts():
227
+ """
228
+ Process request text with the requested Coqui TTS model
229
+ - expected request: {
230
+ "text": text,
231
+ "model_id": voiceId,
232
+ "language_id": language,
233
+ "speaker_id": speaker
234
+ }
235
+
236
+ - model_id formats:
237
+ - model_type/language/dataset/model_name
238
+ - model_type/language/dataset/model_name[speaker_id]
239
+ - model_type/language/dataset/model_name[speaker_id][language_id]
240
+ - examples:
241
+ - tts_models/ja/kokoro/tacotron2-DDC
242
+ - tts_models/en/vctk/vits[0]
243
+ - tts_models/multilingual/multi-dataset/your_tts[2][1]
244
+ """
245
+ global gpu_mode
246
+ global is_downloading
247
+ audio_buffer = io.BytesIO()
248
+
249
+ try:
250
+ request_json = request.get_json()
251
+ #print(request_json)
252
+
253
+ print(DEBUG_PREFIX,"Received TTS request for ", request_json)
254
+
255
+ if (is_downloading):
256
+ print(DEBUG_PREFIX,"Rejected, currently downloading a model, cannot perform TTS")
257
+ abort(500, DEBUG_PREFIX + " Requested TTS while downloading a model")
258
+
259
+ text = request_json["text"]
260
+ model_name = request_json["model_id"]
261
+ language_id = None
262
+ speaker_id = None
263
+
264
+ # Local model
265
+ model_type = model_name.split("/")[0]
266
+ if model_type == "local":
267
+ return generate_tts_local(model_name.split("/")[1], text)
268
+
269
+
270
+ if request_json["language_id"] != "none":
271
+ language_id = request_json["language_id"]
272
+
273
+ if request_json["speaker_id"] != "none":
274
+ speaker_id = request_json["speaker_id"]
275
+
276
+ print(DEBUG_PREFIX,"Loading tts \n- model", model_name, "\n - speaker_id: ",speaker_id,"\n - language_id: ",language_id, "\n - using",("GPU" if gpu_mode else "CPU"))
277
+
278
+ is_downloading = True
279
+ tts = TTS(model_name=model_name, progress_bar=True, gpu=gpu_mode)
280
+ is_downloading = False
281
+
282
+ if tts.is_multi_lingual:
283
+ if language_id is None:
284
+ abort(400, DEBUG_PREFIX + " Requested model "+model_name+" is multi-lingual but no language id provided")
285
+ language_id = tts.languages[int(language_id)]
286
+
287
+ if tts.is_multi_speaker:
288
+ if speaker_id is None:
289
+ abort(400, DEBUG_PREFIX + " Requested model "+model_name+" is multi-speaker but no speaker id provided")
290
+ speaker_id =tts.speakers[int(speaker_id)]
291
+
292
+ tts.tts_to_file(text=text, file_path=audio_buffer, speaker=speaker_id, language=language_id)
293
+
294
+ print(DEBUG_PREFIX, "Success, saved to",audio_buffer)
295
+
296
+ # Return the output_audio_path object as a response
297
+ response = send_file(audio_buffer, mimetype="audio/x-wav")
298
+ audio_buffer = io.BytesIO()
299
+
300
+ return response
301
+
302
+ except Exception as e:
303
+ print(e)
304
+ abort(500, DEBUG_PREFIX + " Exception occurs while trying to process request "+str(request_json))
305
+
306
+ def generate_tts_local(model_folder, text):
307
+ """
308
+ Generate tts using local coqui model
309
+ """
310
+ audio_buffer = io.BytesIO()
311
+
312
+ print(DEBUG_PREFIX,"Request for tts from local coqui model",model_folder)
313
+
314
+ model_path = os.path.join(COQUI_MODELS_PATH,model_folder,COQUI_LOCAL_MODEL_FILE_NAME)
315
+ config_path = os.path.join(COQUI_MODELS_PATH,model_folder,COQUI_LOCAL_CONFIG_FILE_NAME)
316
+
317
+ if not os.path.exists(model_path):
318
+ raise ValueError("File does not exists:",model_path)
319
+
320
+ if not os.path.exists(config_path):
321
+ raise ValueError("File does not exists:",config_path)
322
+
323
+ print(DEBUG_PREFIX,"Loading local tts model", model_path,"using",("GPU" if gpu_mode else "CPU"))
324
+ tts = TTS(model_path=model_path, config_path=config_path, progress_bar=True, gpu=gpu_mode)
325
+ tts.tts_to_file(text=text, file_path=audio_buffer)
326
+
327
+ print(DEBUG_PREFIX, "Success, saved to",audio_buffer)
328
+
329
+ # Return the output_audio_path object as a response
330
+ response = send_file(audio_buffer, mimetype="audio/x-wav")
331
+ audio_buffer = io.BytesIO()
332
+
333
+ return response
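The model_id formats documented in coqui_generate_tts() pack optional speaker and language indices into the model name. The commit itself does not parse that string here (the indices arrive as separate request fields), so the helper below is a purely hypothetical illustration of the format:

    import re

    def parse_model_id(model_id: str):
        """Split e.g. 'tts_models/en/vctk/vits[0][1]' into (name, speaker_index, language_index)."""
        match = re.match(
            r"^(?P<name>[^\[]+)(?:\[(?P<speaker>\d+)\])?(?:\[(?P<language>\d+)\])?$",
            model_id,
        )
        speaker = int(match.group("speaker")) if match.group("speaker") else None
        language = int(match.group("language")) if match.group("language") else None
        return match.group("name"), speaker, language

    print(parse_model_id("tts_models/multilingual/multi-dataset/your_tts[2][1]"))
    # ('tts_models/multilingual/multi-dataset/your_tts', 2, 1)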
modules/utils.py ADDED
@@ -0,0 +1,15 @@
1
+
2
+ from contextlib import contextmanager
3
+ import os
+ import sys
4
+
5
+ @contextmanager
6
+ def silence_log():
7
+ old_stdout = sys.stdout
8
+ old_stderr = sys.stderr
9
+ try:
10
+ with open(os.devnull, "w") as new_target:
11
+ sys.stdout = new_target
+ sys.stderr = new_target  # stderr is saved and restored below, so silence it as well
12
+ yield new_target
13
+ finally:
14
+ sys.stdout = old_stdout
15
+ sys.stderr = old_stderr
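Usage sketch for silence_log() above; noisy_setup() is a hypothetical stand-in for any call whose console output should be suppressed.

    from modules.utils import silence_log

    def noisy_setup():
        print("lots of library chatter ...")

    with silence_log():
        noisy_setup()  # prints go to os.devnull while the context is active
    print("done")      # stdout is restored once the context exits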
modules/voice_conversion/fairseq/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Facebook, Inc. and its affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
modules/voice_conversion/fairseq/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """isort:skip_file"""
6
+
7
+ import os
8
+ import sys
9
+
10
+ try:
11
+ from .version import __version__ # noqa
12
+ except ImportError:
13
+ version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
14
+ with open(version_txt) as f:
15
+ __version__ = f.read().strip()
16
+
17
+ __all__ = ["pdb"]
18
+
19
+ # backwards compatibility to support `from fairseq.X import Y`
20
+ from fairseq.distributed import utils as distributed_utils
21
+ from fairseq.logging import meters, metrics, progress_bar # noqa
22
+
23
+ sys.modules["fairseq.distributed_utils"] = distributed_utils
24
+ sys.modules["fairseq.meters"] = meters
25
+ sys.modules["fairseq.metrics"] = metrics
26
+ sys.modules["fairseq.progress_bar"] = progress_bar
27
+
28
+ # initialize hydra
29
+ #from fairseq.dataclass.initialize import hydra_init
30
+
31
+ #hydra_init()
32
+
33
+ #import fairseq.criterions # noqa
34
+ #import fairseq.distributed # noqa
35
+ #import fairseq.models # noqa
36
+ #import fairseq.modules # noqa
37
+ #import fairseq.optim # noqa
38
+ #import fairseq.optim.lr_scheduler # noqa
39
+ #import fairseq.pdb # noqa
40
+ #import fairseq.scoring # noqa
41
+ #import fairseq.tasks # noqa
42
+ #import fairseq.token_generation_constraints # noqa
43
+
44
+ #import fairseq.benchmark # noqa
45
+ #import fairseq.model_parallel # noqa
modules/voice_conversion/fairseq/binarizer.py ADDED
@@ -0,0 +1,381 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import typing as tp
9
+ from abc import ABC, abstractmethod
10
+ from collections import Counter
11
+ from dataclasses import dataclass
12
+ from multiprocessing import Pool
13
+
14
+ import torch
15
+
16
+ from fairseq.data import Dictionary, indexed_dataset
17
+ from fairseq.file_chunker_utils import Chunker, find_offsets
18
+ from fairseq.file_io import PathManager
19
+ from fairseq.tokenizer import tokenize_line
20
+
21
+ logger = logging.getLogger("binarizer")
22
+
23
+
24
+ @dataclass
25
+ class BinarizeSummary:
26
+ """
27
+ Keep track of what's going on in the binarizer
28
+ """
29
+
30
+ num_seq: int = 0
31
+ replaced: tp.Optional[Counter] = None
32
+ num_tok: int = 0
33
+
34
+ @property
35
+ def num_replaced(self) -> int:
36
+ if self.replaced is None:
37
+ return 0
38
+ return sum(self.replaced.values())
39
+
40
+ @property
41
+ def replaced_percent(self) -> float:
42
+ return 100 * self.num_replaced / self.num_tok
43
+
44
+ def __str__(self) -> str:
45
+ base = f"{self.num_seq} sents, {self.num_tok} tokens"
46
+ if self.replaced is None:
47
+ return base
48
+
49
+ return f"{base}, {self.replaced_percent:.3}% replaced"
50
+
51
+ def merge(self, other: "BinarizeSummary"):
52
+ replaced = None
53
+ if self.replaced is not None:
54
+ replaced = self.replaced
55
+ if other.replaced is not None:
56
+ if replaced is None:
57
+ replaced = other.replaced
58
+ else:
59
+ replaced += other.replaced
60
+ self.replaced = replaced
61
+ self.num_seq += other.num_seq
62
+ self.num_tok += other.num_tok
63
+
64
+
65
+ class Binarizer(ABC):
66
+ """
67
+ a binarizer describes how to take a string and build a tensor out of it
68
+ """
69
+
70
+ @abstractmethod
71
+ def binarize_line(
72
+ self,
73
+ line: str,
74
+ summary: BinarizeSummary,
75
+ ) -> torch.IntTensor:
76
+ ...
77
+
78
+
79
+ def _worker_prefix(output_prefix: str, worker_id: int):
80
+ return f"{output_prefix}.pt{worker_id}"
81
+
82
+
83
+ class FileBinarizer:
84
+ """
85
+ A file binarizer can take a file, tokenize it, and binarize each line to a tensor
86
+ """
87
+
88
+ @classmethod
89
+ def multiprocess_dataset(
90
+ cls,
91
+ input_file: str,
92
+ dataset_impl: str,
93
+ binarizer: Binarizer,
94
+ output_prefix: str,
95
+ vocab_size=None,
96
+ num_workers=1,
97
+ ) -> BinarizeSummary:
98
+ final_summary = BinarizeSummary()
99
+
100
+ offsets = find_offsets(input_file, num_workers)
101
+ # find_offsets returns a list of position [pos1, pos2, pos3, pos4] but we would want pairs:
102
+ # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info
103
+ # we zip the list with itself shifted by one to get all the pairs.
104
+ (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
105
+ pool = None
106
+ if num_workers > 1:
107
+ pool = Pool(processes=num_workers - 1)
108
+ worker_results = [
109
+ pool.apply_async(
110
+ cls._binarize_chunk_and_finalize,
111
+ args=(
112
+ binarizer,
113
+ input_file,
114
+ start_offset,
115
+ end_offset,
116
+ _worker_prefix(
117
+ output_prefix,
118
+ worker_id,
119
+ ),
120
+ dataset_impl,
121
+ ),
122
+ kwds={
123
+ "vocab_size": vocab_size,
124
+ }
125
+ if vocab_size is not None
126
+ else {},
127
+ )
128
+ for worker_id, (start_offset, end_offset) in enumerate(
129
+ more_chunks, start=1
130
+ )
131
+ ]
132
+
133
+ pool.close()
134
+ pool.join()
135
+ for r in worker_results:
136
+ summ = r.get()
137
+ final_summary.merge(summ)
138
+
139
+ # do not close the bin file as we need to merge the worker results in
140
+ final_ds, summ = cls._binarize_file_chunk(
141
+ binarizer,
142
+ input_file,
143
+ offset_start=first_chunk[0],
144
+ offset_end=first_chunk[1],
145
+ output_prefix=output_prefix,
146
+ dataset_impl=dataset_impl,
147
+ vocab_size=vocab_size if vocab_size is not None else None,
148
+ )
149
+ final_summary.merge(summ)
150
+
151
+ if num_workers > 1:
152
+ for worker_id in range(1, num_workers):
153
+ # merge the worker outputs
154
+ worker_output_prefix = _worker_prefix(
155
+ output_prefix,
156
+ worker_id,
157
+ )
158
+ final_ds.merge_file_(worker_output_prefix)
159
+ try:
160
+ os.remove(indexed_dataset.data_file_path(worker_output_prefix))
161
+ os.remove(indexed_dataset.index_file_path(worker_output_prefix))
162
+ except Exception as e:
163
+ logger.error(
164
+ f"couldn't remove {worker_output_prefix}.*", exc_info=e
165
+ )
166
+
167
+ # now we can close the file
168
+ idx_file = indexed_dataset.index_file_path(output_prefix)
169
+ final_ds.finalize(idx_file)
170
+ return final_summary
171
+
172
+ @staticmethod
173
+ def _binarize_file_chunk(
174
+ binarizer: Binarizer,
175
+ filename: str,
176
+ offset_start: int,
177
+ offset_end: int,
178
+ output_prefix: str,
179
+ dataset_impl: str,
180
+ vocab_size=None,
181
+ ) -> tp.Tuple[tp.Any, BinarizeSummary]: # (dataset builder, BinarizeSummary)
182
+ """
183
+ creates a dataset builder and append binarized items to it. This function does not
184
+ finalize the builder, this is useful if you want to do other things with your bin file
185
+ like appending/merging other files
186
+ """
187
+ bin_file = indexed_dataset.data_file_path(output_prefix)
188
+ ds = indexed_dataset.make_builder(
189
+ bin_file,
190
+ impl=dataset_impl,
191
+ vocab_size=vocab_size,
192
+ )
193
+ summary = BinarizeSummary()
194
+
195
+ with Chunker(
196
+ PathManager.get_local_path(filename), offset_start, offset_end
197
+ ) as line_iterator:
198
+ for line in line_iterator:
199
+ ds.add_item(binarizer.binarize_line(line, summary))
200
+
201
+ return ds, summary
202
+
203
+ @classmethod
204
+ def _binarize_chunk_and_finalize(
205
+ cls,
206
+ binarizer: Binarizer,
207
+ filename: str,
208
+ offset_start: int,
209
+ offset_end: int,
210
+ output_prefix: str,
211
+ dataset_impl: str,
212
+ vocab_size=None,
213
+ ):
214
+ """
215
+ same as above, but also finalizes the builder
216
+ """
217
+ ds, summ = cls._binarize_file_chunk(
218
+ binarizer,
219
+ filename,
220
+ offset_start,
221
+ offset_end,
222
+ output_prefix,
223
+ dataset_impl,
224
+ vocab_size=vocab_size,
225
+ )
226
+
227
+ idx_file = indexed_dataset.index_file_path(output_prefix)
228
+ ds.finalize(idx_file)
229
+
230
+ return summ
231
+
232
+
233
+ class VocabularyDatasetBinarizer(Binarizer):
234
+ """
235
+ Takes a Dictionary/Vocabulary, assign ids to each
236
+ token using the dictionary encode_line function.
237
+ """
238
+
239
+ def __init__(
240
+ self,
241
+ dict: Dictionary,
242
+ tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
243
+ append_eos: bool = True,
244
+ reverse_order: bool = False,
245
+ already_numberized: bool = False,
246
+ ) -> None:
247
+ self.dict = dict
248
+ self.tokenize = tokenize
249
+ self.append_eos = append_eos
250
+ self.reverse_order = reverse_order
251
+ self.already_numberized = already_numberized
252
+ super().__init__()
253
+
254
+ def binarize_line(
255
+ self,
256
+ line: str,
257
+ summary: BinarizeSummary,
258
+ ):
259
+ if summary.replaced is None:
260
+ summary.replaced = Counter()
261
+
262
+ def replaced_consumer(word, idx):
263
+ if idx == self.dict.unk_index and word != self.dict.unk_word:
264
+ summary.replaced.update([word])
265
+
266
+ if self.already_numberized:
267
+ id_strings = line.strip().split()
268
+ id_list = [int(id_string) for id_string in id_strings]
269
+ if self.reverse_order:
270
+ id_list.reverse()
271
+ if self.append_eos:
272
+ id_list.append(self.dict.eos())
273
+ ids = torch.IntTensor(id_list)
274
+ else:
275
+ ids = self.dict.encode_line(
276
+ line=line,
277
+ line_tokenizer=self.tokenize,
278
+ add_if_not_exist=False,
279
+ consumer=replaced_consumer,
280
+ append_eos=self.append_eos,
281
+ reverse_order=self.reverse_order,
282
+ )
283
+
284
+ summary.num_seq += 1
285
+ summary.num_tok += len(ids)
286
+ return ids
287
+
288
+
289
+ class AlignmentDatasetBinarizer(Binarizer):
290
+ """
291
+ binarize by parsing a set of alignments and packing
292
+ them in a tensor (see utils.parse_alignment)
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ alignment_parser: tp.Callable[[str], torch.IntTensor],
298
+ ) -> None:
299
+ super().__init__()
300
+ self.alignment_parser = alignment_parser
301
+
302
+ def binarize_line(
303
+ self,
304
+ line: str,
305
+ summary: BinarizeSummary,
306
+ ):
307
+ ids = self.alignment_parser(line)
308
+ summary.num_seq += 1
309
+ summary.num_tok += len(ids)
310
+ return ids
311
+
312
+
313
+ class LegacyBinarizer:
314
+ @classmethod
315
+ def binarize(
316
+ cls,
317
+ filename: str,
318
+ dico: Dictionary,
319
+ consumer: tp.Callable[[torch.IntTensor], None],
320
+ tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
321
+ append_eos: bool = True,
322
+ reverse_order: bool = False,
323
+ offset: int = 0,
324
+ end: int = -1,
325
+ already_numberized: bool = False,
326
+ ) -> tp.Dict[str, int]:
327
+ binarizer = VocabularyDatasetBinarizer(
328
+ dict=dico,
329
+ tokenize=tokenize,
330
+ append_eos=append_eos,
331
+ reverse_order=reverse_order,
332
+ already_numberized=already_numberized,
333
+ )
334
+ return cls._consume_file(
335
+ filename,
336
+ binarizer,
337
+ consumer,
338
+ offset_start=offset,
339
+ offset_end=end,
340
+ )
341
+
342
+ @classmethod
343
+ def binarize_alignments(
344
+ cls,
345
+ filename: str,
346
+ alignment_parser: tp.Callable[[str], torch.IntTensor],
347
+ consumer: tp.Callable[[torch.IntTensor], None],
348
+ offset: int = 0,
349
+ end: int = -1,
350
+ ) -> tp.Dict[str, int]:
351
+ binarizer = AlignmentDatasetBinarizer(alignment_parser)
352
+ return cls._consume_file(
353
+ filename,
354
+ binarizer,
355
+ consumer,
356
+ offset_start=offset,
357
+ offset_end=end,
358
+ )
359
+
360
+ @staticmethod
361
+ def _consume_file(
362
+ filename: str,
363
+ binarizer: Binarizer,
364
+ consumer: tp.Callable[[torch.IntTensor], None],
365
+ offset_start: int,
366
+ offset_end: int,
367
+ ) -> tp.Dict[str, int]:
368
+ summary = BinarizeSummary()
369
+
370
+ with Chunker(
371
+ PathManager.get_local_path(filename), offset_start, offset_end
372
+ ) as line_iterator:
373
+ for line in line_iterator:
374
+ consumer(binarizer.binarize_line(line, summary))
375
+
376
+ return {
377
+ "nseq": summary.num_seq,
378
+ "nunk": summary.num_replaced,
379
+ "ntok": summary.num_tok,
380
+ "replaced": summary.replaced,
381
+ }
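A small illustration of the chunking scheme in FileBinarizer.multiprocess_dataset() above: find_offsets() yields N+1 byte positions, and zipping that list with itself shifted by one turns them into (start, end) pairs, the first handled in-process and the rest by the worker pool. The offsets below are made-up example values.

    offsets = [0, 1024, 2048, 3072]  # e.g. what find_offsets(path, 3) might return
    first_chunk, *more_chunks = zip(offsets, offsets[1:])
    print(first_chunk)   # (0, 1024)                     -> binarized in the main process
    print(more_chunks)   # [(1024, 2048), (2048, 3072)]  -> dispatched to the Pool workers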
modules/voice_conversion/fairseq/checkpoint_utils.py ADDED
@@ -0,0 +1,905 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import ast
7
+ import collections
8
+ import contextlib
9
+ import inspect
10
+ import logging
11
+ import os
12
+ import re
13
+ import time
14
+ import traceback
15
+ from collections import OrderedDict
16
+ from pathlib import Path
17
+ from typing import Any, Dict, Optional, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ from fairseq.data import data_utils
22
+ from fairseq.dataclass.configs import CheckpointConfig
23
+ from fairseq.dataclass.utils import (
24
+ convert_namespace_to_omegaconf,
25
+ overwrite_args_by_name,
26
+ )
27
+ from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
28
+ from fairseq.file_io import PathManager
29
+ from fairseq.models import FairseqDecoder, FairseqEncoder
30
+ from omegaconf import DictConfig, OmegaConf, open_dict
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
36
+ from fairseq import meters
37
+
38
+ # only one worker should attempt to create the required dir
39
+ if trainer.data_parallel_rank == 0:
40
+ os.makedirs(cfg.save_dir, exist_ok=True)
41
+
42
+ prev_best = getattr(save_checkpoint, "best", val_loss)
43
+ if val_loss is not None:
44
+ best_function = max if cfg.maximize_best_checkpoint_metric else min
45
+ save_checkpoint.best = best_function(val_loss, prev_best)
46
+
47
+ if cfg.no_save:
48
+ return
49
+
50
+ trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state
51
+
52
+ if not trainer.should_save_checkpoint_on_current_rank:
53
+ if trainer.always_call_state_dict_during_save_checkpoint:
54
+ trainer.state_dict()
55
+ return
56
+
57
+ write_timer = meters.StopwatchMeter()
58
+ write_timer.start()
59
+
60
+ epoch = epoch_itr.epoch
61
+ end_of_epoch = epoch_itr.end_of_epoch()
62
+ updates = trainer.get_num_updates()
63
+
64
+ logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
65
+
66
+ def is_better(a, b):
67
+ return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
68
+
69
+ suffix = trainer.checkpoint_suffix
70
+ checkpoint_conds = collections.OrderedDict()
71
+ checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
72
+ end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
73
+ )
74
+ checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
75
+ not end_of_epoch
76
+ and cfg.save_interval_updates > 0
77
+ and updates % cfg.save_interval_updates == 0
78
+ )
79
+ checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
80
+ not hasattr(save_checkpoint, "best")
81
+ or is_better(val_loss, save_checkpoint.best)
82
+ )
83
+ if val_loss is not None and cfg.keep_best_checkpoints > 0:
84
+ worst_best = getattr(save_checkpoint, "best", None)
85
+ chkpts = checkpoint_paths(
86
+ cfg.save_dir,
87
+ pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
88
+ cfg.best_checkpoint_metric, suffix
89
+ ),
90
+ )
91
+ if len(chkpts) > 0:
92
+ p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
93
+ worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
94
+ # add random digits to resolve ties
95
+ with data_utils.numpy_seed(epoch, updates, val_loss):
96
+ rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)
97
+
98
+ checkpoint_conds[
99
+ "checkpoint.best_{}_{:.3f}{}{}.pt".format(
100
+ cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
101
+ )
102
+ ] = worst_best is None or is_better(val_loss, worst_best)
103
+ checkpoint_conds[
104
+ "checkpoint_last{}.pt".format(suffix)
105
+ ] = not cfg.no_last_checkpoints
106
+
107
+ extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
108
+ if hasattr(save_checkpoint, "best"):
109
+ extra_state.update({"best": save_checkpoint.best})
110
+
111
+ checkpoints = [
112
+ os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
113
+ ]
114
+ if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
115
+ trainer.save_checkpoint(checkpoints[0], extra_state)
116
+ for cp in checkpoints[1:]:
117
+ if cfg.write_checkpoints_asynchronously:
118
+ # TODO[ioPath]: Need to implement a delayed asynchronous
119
+ # file copying/moving feature.
120
+ logger.warning(
121
+ f"ioPath is not copying {checkpoints[0]} to {cp} "
122
+ "since async write mode is on."
123
+ )
124
+ else:
125
+ assert PathManager.copy(
126
+ checkpoints[0], cp, overwrite=True
127
+ ), f"Failed to copy {checkpoints[0]} to {cp}"
128
+
129
+ write_timer.stop()
130
+ logger.info(
131
+ "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
132
+ checkpoints[0], epoch, updates, val_loss, write_timer.sum
133
+ )
134
+ )
135
+
136
+ if not end_of_epoch and cfg.keep_interval_updates > 0:
137
+ # remove old checkpoints; checkpoints are sorted in descending order
138
+ if cfg.keep_interval_updates_pattern == -1:
139
+ checkpoints = checkpoint_paths(
140
+ cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
141
+ )
142
+ else:
143
+ checkpoints = checkpoint_paths(
144
+ cfg.save_dir,
145
+ pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
146
+ keep_match=True,
147
+ )
148
+ checkpoints = [
149
+ x[0]
150
+ for x in checkpoints
151
+ if x[1] % cfg.keep_interval_updates_pattern != 0
152
+ ]
153
+
154
+ for old_chk in checkpoints[cfg.keep_interval_updates :]:
155
+ if os.path.lexists(old_chk):
156
+ os.remove(old_chk)
157
+ elif PathManager.exists(old_chk):
158
+ PathManager.rm(old_chk)
159
+
160
+ if cfg.keep_last_epochs > 0:
161
+ # remove old epoch checkpoints; checkpoints are sorted in descending order
162
+ checkpoints = checkpoint_paths(
163
+ cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
164
+ )
165
+ for old_chk in checkpoints[cfg.keep_last_epochs :]:
166
+ if os.path.lexists(old_chk):
167
+ os.remove(old_chk)
168
+ elif PathManager.exists(old_chk):
169
+ PathManager.rm(old_chk)
170
+
171
+ if cfg.keep_best_checkpoints > 0:
172
+ # only keep the best N checkpoints according to validation metric
173
+ checkpoints = checkpoint_paths(
174
+ cfg.save_dir,
175
+ pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
176
+ cfg.best_checkpoint_metric, suffix
177
+ ),
178
+ )
179
+ if not cfg.maximize_best_checkpoint_metric:
180
+ checkpoints = checkpoints[::-1]
181
+ for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
182
+ if os.path.lexists(old_chk):
183
+ os.remove(old_chk)
184
+ elif PathManager.exists(old_chk):
185
+ PathManager.rm(old_chk)
186
+
187
+
188
+ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
189
+ """
190
+ Load a checkpoint and restore the training iterator.
191
+
192
+ *passthrough_args* will be passed through to
193
+ ``trainer.get_train_iterator``.
194
+ """
195
+
196
+ reset_optimizer = cfg.reset_optimizer
197
+ reset_lr_scheduler = cfg.reset_lr_scheduler
198
+ optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
199
+ reset_meters = cfg.reset_meters
200
+ reset_dataloader = cfg.reset_dataloader
201
+
202
+ if cfg.finetune_from_model is not None and (
203
+ reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
204
+ ):
205
+ raise ValueError(
206
+ "--finetune-from-model can not be set together with either --reset-optimizer"
207
+ " or reset_lr_scheduler or reset_meters or reset_dataloader"
208
+ )
209
+
210
+ suffix = trainer.checkpoint_suffix
211
+ if (
212
+ cfg.restore_file == "checkpoint_last.pt"
213
+ ): # default value of restore_file is 'checkpoint_last.pt'
214
+ checkpoint_path = os.path.join(
215
+ cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
216
+ )
217
+ first_launch = not PathManager.exists(checkpoint_path)
218
+ if first_launch and getattr(cfg, "continue_once", None) is not None:
219
+ checkpoint_path = cfg.continue_once
220
+ elif cfg.finetune_from_model is not None and first_launch:
221
+ # if there is no last checkpoint to restore, start the finetune from pretrained model
222
+ # else just use the usual logic to load the checkpoint, e.g. restart from the last checkpoint, etc.
223
+ if PathManager.exists(cfg.finetune_from_model):
224
+ checkpoint_path = cfg.finetune_from_model
225
+ reset_optimizer = True
226
+ reset_lr_scheduler = True
227
+ reset_meters = True
228
+ reset_dataloader = True
229
+ logger.info(
230
+ f"loading pretrained model from {checkpoint_path}: "
231
+ "optimizer, lr scheduler, meters, dataloader will be reset"
232
+ )
233
+ else:
234
+ raise ValueError(
235
+ f"--finetune-from-model {cfg.finetune_from_model} does not exist"
236
+ )
237
+ elif suffix is not None:
238
+ checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
239
+ else:
240
+ checkpoint_path = cfg.restore_file
241
+
242
+ if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
243
+ raise ValueError(
244
+ "--finetune-from-model and --restore-file (non-default value) "
245
+ "can not be specified together: " + str(cfg)
246
+ )
247
+
248
+ extra_state = trainer.load_checkpoint(
249
+ checkpoint_path,
250
+ reset_optimizer,
251
+ reset_lr_scheduler,
252
+ optimizer_overrides,
253
+ reset_meters=reset_meters,
254
+ )
255
+
256
+ if (
257
+ extra_state is not None
258
+ and "best" in extra_state
259
+ and not reset_optimizer
260
+ and not reset_meters
261
+ ):
262
+ save_checkpoint.best = extra_state["best"]
263
+
264
+ if extra_state is not None and not reset_dataloader:
265
+ # restore iterator from checkpoint
266
+ itr_state = extra_state["train_iterator"]
267
+ epoch_itr = trainer.get_train_iterator(
268
+ epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
269
+ )
270
+ epoch_itr.load_state_dict(itr_state)
271
+ else:
272
+ epoch_itr = trainer.get_train_iterator(
273
+ epoch=1, load_dataset=True, **passthrough_args
274
+ )
275
+
276
+ trainer.lr_step(epoch_itr.epoch)
277
+
278
+ return extra_state, epoch_itr
279
+
280
+
281
+ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
282
+ """Loads a checkpoint to CPU (with upgrading for backward compatibility).
283
+
284
+ If doing single-GPU training or if the checkpoint is only being loaded by at
285
+ most one process on each node (current default behavior is for only rank 0
286
+ to read the checkpoint from disk), load_on_all_ranks should be False to
287
+ avoid errors from torch.distributed not having been initialized or
288
+ torch.distributed.barrier() hanging.
289
+
290
+ If all processes on each node may be loading the checkpoint
291
+ simultaneously, load_on_all_ranks should be set to True to avoid I/O
292
+ conflicts.
293
+
294
+ There's currently no support for > 1 but < all processes loading the
295
+ checkpoint on each node.
296
+ """
297
+ local_path = PathManager.get_local_path(path)
298
+ # The locally cached file returned by get_local_path() may be stale for
299
+ # remote files that are periodically updated/overwritten (ex:
300
+ # checkpoint_last.pt) - so we remove the local copy, sync across processes
301
+ # (if needed), and then download a fresh copy.
302
+ if local_path != path and PathManager.path_requires_pathmanager(path):
303
+ try:
304
+ os.remove(local_path)
305
+ except FileNotFoundError:
306
+ # With potentially multiple processes removing the same file, the
307
+ # file being missing is benign (missing_ok isn't available until
308
+ # Python 3.8).
309
+ pass
310
+ if load_on_all_ranks:
311
+ torch.distributed.barrier()
312
+ local_path = PathManager.get_local_path(path)
313
+
314
+ with open(local_path, "rb") as f:
315
+ state = torch.load(f, map_location=torch.device("cpu"))
316
+
317
+ if "args" in state and state["args"] is not None and arg_overrides is not None:
318
+ args = state["args"]
319
+ for arg_name, arg_val in arg_overrides.items():
320
+ setattr(args, arg_name, arg_val)
321
+
322
+ if "cfg" in state and state["cfg"] is not None:
323
+
324
+ # hack to be able to set Namespace in dict config. this should be removed when we update to newer
325
+ # omegaconf version that supports object flags, or when we migrate all existing models
326
+ from omegaconf import __version__ as oc_version
327
+ from omegaconf import _utils
328
+
329
+ if oc_version < "2.2":
330
+ old_primitive = _utils.is_primitive_type
331
+ _utils.is_primitive_type = lambda _: True
332
+
333
+ state["cfg"] = OmegaConf.create(state["cfg"])
334
+
335
+ _utils.is_primitive_type = old_primitive
336
+ OmegaConf.set_struct(state["cfg"], True)
337
+ else:
338
+ state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
339
+
340
+ if arg_overrides is not None:
341
+ overwrite_args_by_name(state["cfg"], arg_overrides)
342
+
343
+ state = _upgrade_state_dict(state)
344
+ return state
345
+
346
+
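The docstring above spells out when `load_on_all_ranks` matters; for offline inspection on a single process the default is fine. A minimal sketch, assuming the package is importable as `fairseq` (this repo vendors it under `modules/voice_conversion/fairseq`, so the import prefix may differ) and using a hypothetical checkpoint path:

```python
# Minimal sketch: load a checkpoint onto the CPU and inspect it (single process,
# so load_on_all_ranks stays at its default of False).
from fairseq import checkpoint_utils

ckpt_path = "checkpoints/checkpoint_last.pt"          # hypothetical path
state = checkpoint_utils.load_checkpoint_to_cpu(
    ckpt_path,
    arg_overrides={"data": "/path/to/data"},          # hypothetical override of the stored data arg
)
print(sorted(state.keys()))    # e.g. ['cfg', 'extra_state', 'last_optimizer_state', 'model', ...]
print(len(state["model"]))     # number of tensors in the model state dict
```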
347
+ def load_model_ensemble(
348
+ filenames,
349
+ arg_overrides: Optional[Dict[str, Any]] = None,
350
+ task=None,
351
+ strict=True,
352
+ suffix="",
353
+ num_shards=1,
354
+ state=None,
355
+ ):
356
+ """Loads an ensemble of models.
357
+
358
+ Args:
359
+ filenames (List[str]): checkpoint files to load
360
+ arg_overrides (Dict[str,Any], optional): override model args that
361
+ were used during model training
362
+ task (fairseq.tasks.FairseqTask, optional): task to use for loading
363
+ """
364
+ assert not (
365
+ strict and num_shards > 1
366
+ ), "Cannot load state dict with strict=True and checkpoint shards > 1"
367
+ ensemble, args, _task = load_model_ensemble_and_task(
368
+ filenames,
369
+ arg_overrides,
370
+ task,
371
+ strict,
372
+ suffix,
373
+ num_shards,
374
+ state,
375
+ )
376
+ return ensemble, args
377
+
378
+
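`load_model_ensemble` is the usual entry point for inference: it builds every model listed and returns them together with the upgraded config. A hedged sketch with hypothetical paths, assuming `fairseq` is importable:

```python
# Minimal sketch: load two checkpoints as an inference ensemble.
from fairseq import checkpoint_utils

filenames = ["checkpoints/checkpoint_best.pt",        # hypothetical paths
             "checkpoints/checkpoint_last.pt"]

models, cfg = checkpoint_utils.load_model_ensemble(
    filenames,
    arg_overrides={"data": "/path/to/data"},          # hypothetical
)
for model in models:
    model.eval()        # device placement / fp16 casting is left to the caller
```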
379
+ def get_maybe_sharded_checkpoint_filename(
380
+ filename: str, suffix: str, shard_idx: int, num_shards: int
381
+ ) -> str:
382
+ orig_filename = filename
383
+ filename = filename.replace(".pt", suffix + ".pt")
384
+ fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
385
+ model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
386
+ if PathManager.exists(fsdp_filename):
387
+ return fsdp_filename
388
+ elif num_shards > 1:
389
+ return model_parallel_filename
390
+ else:
391
+ return filename
392
+
393
+
394
+ def load_model_ensemble_and_task(
395
+ filenames,
396
+ arg_overrides: Optional[Dict[str, Any]] = None,
397
+ task=None,
398
+ strict=True,
399
+ suffix="",
400
+ num_shards=1,
401
+ state=None,
402
+ ):
403
+ assert state is None or len(filenames) == 1
404
+
405
+ from fairseq import tasks
406
+
407
+ assert not (
408
+ strict and num_shards > 1
409
+ ), "Cannot load state dict with strict=True and checkpoint shards > 1"
410
+ ensemble = []
411
+ cfg = None
412
+ for filename in filenames:
413
+ orig_filename = filename
414
+ model_shard_state = {"shard_weights": [], "shard_metadata": []}
415
+ assert num_shards > 0
416
+ st = time.time()
417
+ for shard_idx in range(num_shards):
418
+ filename = get_maybe_sharded_checkpoint_filename(
419
+ orig_filename, suffix, shard_idx, num_shards
420
+ )
421
+
422
+ if not PathManager.exists(filename):
423
+ raise IOError("Model file not found: {}".format(filename))
424
+ if state is None:
425
+ state = load_checkpoint_to_cpu(filename, arg_overrides)
426
+ if "args" in state and state["args"] is not None:
427
+ cfg = convert_namespace_to_omegaconf(state["args"])
428
+ elif "cfg" in state and state["cfg"] is not None:
429
+ cfg = state["cfg"]
430
+ else:
431
+ raise RuntimeError(
432
+ f"Neither args nor cfg exist in state keys = {state.keys()}"
433
+ )
434
+
435
+ if task is None:
436
+ task = tasks.setup_task(cfg.task)
437
+
438
+ if "task_state" in state:
439
+ task.load_state_dict(state["task_state"])
440
+
441
+ if "fsdp_metadata" in state and num_shards > 1:
442
+ model_shard_state["shard_weights"].append(state["model"])
443
+ model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
444
+ # check FSDP import before the code goes too far
445
+ if not has_FSDP:
446
+ raise ImportError(
447
+ "Cannot find FullyShardedDataParallel. "
448
+ "Please install fairscale with: pip install fairscale"
449
+ )
450
+ if shard_idx == num_shards - 1:
451
+ consolidated_model_state = FSDP.consolidate_shard_weights(
452
+ shard_weights=model_shard_state["shard_weights"],
453
+ shard_metadata=model_shard_state["shard_metadata"],
454
+ )
455
+ model = task.build_model(cfg.model)
456
+ if (
457
+ "optimizer_history" in state
458
+ and len(state["optimizer_history"]) > 0
459
+ and "num_updates" in state["optimizer_history"][-1]
460
+ ):
461
+ model.set_num_updates(
462
+ state["optimizer_history"][-1]["num_updates"]
463
+ )
464
+ model.load_state_dict(
465
+ consolidated_model_state, strict=strict, model_cfg=cfg.model
466
+ )
467
+ else:
468
+ # model parallel checkpoint or unsharded checkpoint
469
+ # support old external tasks
470
+
471
+ argspec = inspect.getfullargspec(task.build_model)
472
+ if "from_checkpoint" in argspec.args:
473
+ model = task.build_model(cfg.model, from_checkpoint=True)
474
+ else:
475
+ model = task.build_model(cfg.model)
476
+ if (
477
+ "optimizer_history" in state
478
+ and len(state["optimizer_history"]) > 0
479
+ and "num_updates" in state["optimizer_history"][-1]
480
+ ):
481
+ model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
482
+ model.load_state_dict(
483
+ state["model"], strict=strict, model_cfg=cfg.model
484
+ )
485
+
486
+ # reset state so it gets loaded for the next model in ensemble
487
+ state = None
488
+ if shard_idx % 10 == 0 and shard_idx > 0:
489
+ elapsed = time.time() - st
490
+ logger.info(
491
+ f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
492
+ )
493
+
494
+ # build model for ensemble
495
+ ensemble.append(model)
496
+ return ensemble, cfg, task
497
+
498
+
499
+ def load_model_ensemble_and_task_from_hf_hub(
500
+ model_id,
501
+ cache_dir: Optional[str] = None,
502
+ arg_overrides: Optional[Dict[str, Any]] = None,
503
+ **kwargs: Any,
504
+ ):
505
+ try:
506
+ from huggingface_hub import snapshot_download
507
+ except ImportError:
508
+ raise ImportError(
509
+ "You need to install huggingface_hub to use `load_from_hf_hub`. "
510
+ "See https://pypi.org/project/huggingface-hub/ for installation."
511
+ )
512
+
513
+ library_name = "fairseq"
514
+ cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
515
+ cache_dir = snapshot_download(
516
+ model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
517
+ )
518
+
519
+ _arg_overrides = arg_overrides or {}
520
+ _arg_overrides["data"] = cache_dir
521
+ return load_model_ensemble_and_task(
522
+ [p.as_posix() for p in Path(cache_dir).glob("*.pt")],
523
+ arg_overrides=_arg_overrides,
524
+ )
525
+
526
+
527
+ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
528
+ """Retrieves all checkpoints found in `path` directory.
529
+
530
+ Checkpoints are identified by matching filename to the specified pattern. If
531
+ the pattern contains groups, the result will be sorted by the first group in
532
+ descending order.
533
+ """
534
+ pt_regexp = re.compile(pattern)
535
+ files = PathManager.ls(path)
536
+
537
+ entries = []
538
+ for i, f in enumerate(files):
539
+ m = pt_regexp.fullmatch(f)
540
+ if m is not None:
541
+ idx = float(m.group(1)) if len(m.groups()) > 0 else i
542
+ entries.append((idx, m.group(0)))
543
+ if keep_match:
544
+ return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
545
+ else:
546
+ return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
547
+
548
+
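The sorting contract of `checkpoint_paths()`, descending by the first regex group, is what the pruning blocks in `save_checkpoint` rely on. A standalone sketch of that ordering with made-up file names and no filesystem access:

```python
# Sketch of the checkpoint_paths() ordering: sort by the first regex group, descending.
import re

files = ["checkpoint_3_1000.pt", "checkpoint_1_200.pt", "checkpoint_2_600.pt"]  # hypothetical
pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")

entries = []
for i, f in enumerate(files):
    m = pt_regexp.fullmatch(f)
    if m is not None:
        idx = float(m.group(1)) if len(m.groups()) > 0 else i
        entries.append((idx, m.group(0)))

print([name for _, name in sorted(entries, reverse=True)])
# ['checkpoint_3_1000.pt', 'checkpoint_2_600.pt', 'checkpoint_1_200.pt']
```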
549
+ def torch_persistent_save(obj, filename, async_write: bool = False):
550
+ if async_write:
551
+ with PathManager.opena(filename, "wb") as f:
552
+ _torch_persistent_save(obj, f)
553
+ else:
554
+ if PathManager.supports_rename(filename):
555
+ # do atomic save
556
+ with PathManager.open(filename + ".tmp", "wb") as f:
557
+ _torch_persistent_save(obj, f)
558
+ PathManager.rename(filename + ".tmp", filename)
559
+ else:
560
+ # fallback to non-atomic save
561
+ with PathManager.open(filename, "wb") as f:
562
+ _torch_persistent_save(obj, f)
563
+
564
+
565
+ def _torch_persistent_save(obj, f):
566
+ if isinstance(f, str):
567
+ with PathManager.open(f, "wb") as h:
568
+ torch_persistent_save(obj, h)
569
+ return
570
+ for i in range(3):
571
+ try:
572
+ return torch.save(obj, f)
573
+ except Exception:
574
+ if i == 2:
575
+ logger.error(traceback.format_exc())
576
+ raise
577
+
578
+
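`torch_persistent_save()` writes to `<filename>.tmp` and renames it into place when the backing store supports renames, so an interrupted write never clobbers an existing checkpoint. A standard-library sketch of the same write-then-rename idea (local filesystems only, not the PathManager-aware version above):

```python
# Sketch: atomic local checkpoint write via a temporary file plus rename.
import os
import torch

def atomic_torch_save(obj, filename: str) -> None:
    tmp = filename + ".tmp"
    torch.save(obj, tmp)
    os.replace(tmp, filename)  # atomic on POSIX; readers never observe a partial file

atomic_torch_save({"step": 0}, "demo_state.pt")  # hypothetical output path
```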
579
+ def _upgrade_state_dict(state):
580
+ """Helper for upgrading old model checkpoints."""
581
+
582
+ # add optimizer_history
583
+ if "optimizer_history" not in state:
584
+ state["optimizer_history"] = [
585
+ {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
586
+ ]
587
+ state["last_optimizer_state"] = state["optimizer"]
588
+ del state["optimizer"]
589
+ del state["best_loss"]
590
+ # move extra_state into sub-dictionary
591
+ if "epoch" in state and "extra_state" not in state:
592
+ state["extra_state"] = {
593
+ "epoch": state["epoch"],
594
+ "batch_offset": state["batch_offset"],
595
+ "val_loss": state["val_loss"],
596
+ }
597
+ del state["epoch"]
598
+ del state["batch_offset"]
599
+ del state["val_loss"]
600
+ # reduce optimizer history's memory usage (only keep the last state)
601
+ if "optimizer" in state["optimizer_history"][-1]:
602
+ state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
603
+ for optim_hist in state["optimizer_history"]:
604
+ del optim_hist["optimizer"]
605
+ # record the optimizer class name
606
+ if "optimizer_name" not in state["optimizer_history"][-1]:
607
+ state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
608
+ # move best_loss into lr_scheduler_state
609
+ if "lr_scheduler_state" not in state["optimizer_history"][-1]:
610
+ state["optimizer_history"][-1]["lr_scheduler_state"] = {
611
+ "best": state["optimizer_history"][-1]["best_loss"]
612
+ }
613
+ del state["optimizer_history"][-1]["best_loss"]
614
+ # keep track of number of updates
615
+ if "num_updates" not in state["optimizer_history"][-1]:
616
+ state["optimizer_history"][-1]["num_updates"] = 0
617
+ # use stateful training data iterator
618
+ if "train_iterator" not in state["extra_state"]:
619
+ state["extra_state"]["train_iterator"] = {
620
+ "epoch": state["extra_state"].get("epoch", 0),
621
+ "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
622
+ }
623
+
624
+ # backward compatibility, cfg updates
625
+ if "args" in state and state["args"] is not None:
626
+ # old model checkpoints may not have separate source/target positions
627
+ if hasattr(state["args"], "max_positions") and not hasattr(
628
+ state["args"], "max_source_positions"
629
+ ):
630
+ state["args"].max_source_positions = state["args"].max_positions
631
+ state["args"].max_target_positions = state["args"].max_positions
632
+ # default to translation task
633
+ if not hasattr(state["args"], "task"):
634
+ state["args"].task = "translation"
635
+ # --raw-text and --lazy-load are deprecated
636
+ if getattr(state["args"], "raw_text", False):
637
+ state["args"].dataset_impl = "raw"
638
+ elif getattr(state["args"], "lazy_load", False):
639
+ state["args"].dataset_impl = "lazy"
640
+ # epochs start at 1
641
+ if state["extra_state"]["train_iterator"] is not None:
642
+ state["extra_state"]["train_iterator"]["epoch"] = max(
643
+ state["extra_state"]["train_iterator"].get("epoch", 1), 1
644
+ )
645
+ # --remove-bpe ==> --postprocess
646
+ if hasattr(state["args"], "remove_bpe"):
647
+ state["args"].post_process = state["args"].remove_bpe
648
+ # --min-lr ==> --stop-min-lr
649
+ if hasattr(state["args"], "min_lr"):
650
+ state["args"].stop_min_lr = state["args"].min_lr
651
+ del state["args"].min_lr
652
+ # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
653
+ if hasattr(state["args"], "criterion") and state["args"].criterion in [
654
+ "binary_cross_entropy",
655
+ "kd_binary_cross_entropy",
656
+ ]:
657
+ state["args"].criterion = "wav2vec"
658
+ # remove log_keys if it's None (criteria will supply a default value of [])
659
+ if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
660
+ delattr(state["args"], "log_keys")
661
+ # speech_pretraining => audio pretraining
662
+ if (
663
+ hasattr(state["args"], "task")
664
+ and state["args"].task == "speech_pretraining"
665
+ ):
666
+ state["args"].task = "audio_pretraining"
667
+ # audio_cpc => wav2vec
668
+ if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
669
+ state["args"].arch = "wav2vec"
670
+ # convert legacy float learning rate to List[float]
671
+ if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
672
+ state["args"].lr = [state["args"].lr]
673
+ # convert task data arg to a string instead of List[string]
674
+ if (
675
+ hasattr(state["args"], "data")
676
+ and isinstance(state["args"].data, list)
677
+ and len(state["args"].data) > 0
678
+ ):
679
+ state["args"].data = state["args"].data[0]
680
+
681
+ state["cfg"] = convert_namespace_to_omegaconf(state["args"])
682
+
683
+ if "cfg" in state and state["cfg"] is not None:
684
+ cfg = state["cfg"]
685
+ with open_dict(cfg):
686
+ # any upgrades for Hydra-based configs
687
+ if (
688
+ "task" in cfg
689
+ and "eval_wer_config" in cfg.task
690
+ and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
691
+ ):
692
+ cfg.task.eval_wer_config.print_alignment = "hard"
693
+ if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
694
+ cfg.generation.print_alignment = (
695
+ "hard" if cfg.generation.print_alignment else None
696
+ )
697
+ if (
698
+ "model" in cfg
699
+ and "w2v_args" in cfg.model
700
+ and cfg.model.w2v_args is not None
701
+ and (
702
+ hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
703
+ )
704
+ and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
705
+ and cfg.model.w2v_args.task.eval_wer_config is not None
706
+ and isinstance(
707
+ cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
708
+ )
709
+ ):
710
+ cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
711
+
712
+ return state
713
+
714
+
715
+ def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
716
+ """Prune the given state_dict if desired for LayerDrop
717
+ (https://arxiv.org/abs/1909.11556).
718
+
719
+ Training with LayerDrop allows models to be robust to pruning at inference
720
+ time. This function prunes state_dict to allow smaller models to be loaded
721
+ from a larger model and re-maps the existing state_dict for this to occur.
722
+
723
+ It's called by functions that load models from checkpoints and does not
724
+ need to be called directly.
725
+ """
726
+ arch = None
727
+ if model_cfg is not None:
728
+ arch = (
729
+ model_cfg._name
730
+ if isinstance(model_cfg, DictConfig)
731
+ else getattr(model_cfg, "arch", None)
732
+ )
733
+
734
+ if not model_cfg or arch is None or arch == "ptt_transformer":
735
+ # args should not be none, but don't crash if it is.
736
+ return state_dict
737
+
738
+ encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
739
+ decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
740
+
741
+ if not encoder_layers_to_keep and not decoder_layers_to_keep:
742
+ return state_dict
743
+
744
+ # apply pruning
745
+ logger.info(
746
+ "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
747
+ )
748
+
749
+ def create_pruning_pass(layers_to_keep, layer_name):
750
+ keep_layers = sorted(
751
+ int(layer_string) for layer_string in layers_to_keep.split(",")
752
+ )
753
+ mapping_dict = {}
754
+ for i in range(len(keep_layers)):
755
+ mapping_dict[str(keep_layers[i])] = str(i)
756
+
757
+ regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
758
+ return {"substitution_regex": regex, "mapping_dict": mapping_dict}
759
+
760
+ pruning_passes = []
761
+ if encoder_layers_to_keep:
762
+ pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
763
+ if decoder_layers_to_keep:
764
+ pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
765
+
766
+ new_state_dict = {}
767
+ for layer_name in state_dict.keys():
768
+ match = re.search(r"\.layers\.(\d+)\.", layer_name)
769
+ # if layer has no number in it, it is a supporting layer, such as an
770
+ # embedding
771
+ if not match:
772
+ new_state_dict[layer_name] = state_dict[layer_name]
773
+ continue
774
+
775
+ # otherwise, layer should be pruned.
776
+ original_layer_number = match.group(1)
777
+ # figure out which mapping dict to replace from
778
+ for pruning_pass in pruning_passes:
779
+ if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
780
+ "substitution_regex"
781
+ ].search(layer_name):
782
+ new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
783
+ substitution_match = pruning_pass["substitution_regex"].search(
784
+ layer_name
785
+ )
786
+ new_state_key = (
787
+ layer_name[: substitution_match.start(1)]
788
+ + new_layer_number
789
+ + layer_name[substitution_match.end(1) :]
790
+ )
791
+ new_state_dict[new_state_key] = state_dict[layer_name]
792
+
793
+ # Since layers are now pruned, *_layers_to_keep are no longer needed.
794
+ # This is more of "It would make it work fix" rather than a proper fix.
795
+ if isinstance(model_cfg, DictConfig):
796
+ context = open_dict(model_cfg)
797
+ else:
798
+ context = contextlib.ExitStack()
799
+ with context:
800
+ if hasattr(model_cfg, "encoder_layers_to_keep"):
801
+ model_cfg.encoder_layers_to_keep = None
802
+ if hasattr(model_cfg, "decoder_layers_to_keep"):
803
+ model_cfg.decoder_layers_to_keep = None
804
+
805
+ return new_state_dict
806
+
807
+
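The key remapping done by `prune_state_dict()` is easiest to see on toy keys: with `encoder_layers_to_keep="0,2"`, layer 1 is dropped and layer 2 is renumbered to 1, while keys without a layer index (embeddings etc.) pass through untouched. A standalone sketch of that renumbering:

```python
# Sketch of prune_state_dict()'s renumbering for encoder_layers_to_keep="0,2".
import re

state_dict = {  # toy keys; a real state dict maps these to tensors
    "encoder.embed_tokens.weight": "...",
    "encoder.layers.0.fc1.weight": "...",
    "encoder.layers.1.fc1.weight": "...",
    "encoder.layers.2.fc1.weight": "...",
}

keep = sorted(int(x) for x in "0,2".split(","))
mapping = {str(old): str(new) for new, old in enumerate(keep)}   # {'0': '0', '2': '1'}
regex = re.compile(r"^encoder.*\.layers\.(\d+)")

pruned = {}
for key, value in state_dict.items():
    m = re.search(r"\.layers\.(\d+)\.", key)
    if not m:                                  # supporting layers are kept as-is
        pruned[key] = value
    elif m.group(1) in mapping and regex.search(key):
        sub = regex.search(key)
        pruned[key[: sub.start(1)] + mapping[m.group(1)] + key[sub.end(1):]] = value

print(sorted(pruned))
# ['encoder.embed_tokens.weight', 'encoder.layers.0.fc1.weight', 'encoder.layers.1.fc1.weight']
```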
808
+ def load_pretrained_component_from_model(
809
+ component: Union[FairseqEncoder, FairseqDecoder],
810
+ checkpoint: str,
811
+ strict: bool = True,
812
+ ):
813
+ """
814
+ Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
815
+ provided `component` object. If state_dict fails to load, there may be a
816
+ mismatch in the architecture of the corresponding `component` found in the
817
+ `checkpoint` file.
818
+ """
819
+ if not PathManager.exists(checkpoint):
820
+ raise IOError("Model file not found: {}".format(checkpoint))
821
+ state = load_checkpoint_to_cpu(checkpoint)
822
+ if isinstance(component, FairseqEncoder):
823
+ component_type = "encoder"
824
+ elif isinstance(component, FairseqDecoder):
825
+ component_type = "decoder"
826
+ else:
827
+ raise ValueError(
828
+ "component to load must be either a FairseqEncoder or "
829
+ "FairseqDecoder. Loading other component types are not supported."
830
+ )
831
+ component_state_dict = OrderedDict()
832
+ for key in state["model"].keys():
833
+ if key.startswith(component_type):
834
+ # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
835
+ component_subkey = key[len(component_type) + 1 :]
836
+ component_state_dict[component_subkey] = state["model"][key]
837
+ component.load_state_dict(component_state_dict, strict=strict)
838
+ return component
839
+
840
+
841
+ def verify_checkpoint_directory(save_dir: str) -> None:
842
+ if not os.path.exists(save_dir):
843
+ os.makedirs(save_dir, exist_ok=True)
844
+ temp_file_path = os.path.join(save_dir, "dummy")
845
+ try:
846
+ with open(temp_file_path, "w"):
847
+ pass
848
+ except OSError as e:
849
+ logger.warning(
850
+ "Unable to access checkpoint save directory: {}".format(save_dir)
851
+ )
852
+ raise e
853
+ else:
854
+ os.remove(temp_file_path)
855
+
856
+
857
+ def save_ema_as_checkpoint(src_path, dst_path):
858
+ state = load_ema_from_checkpoint(src_path)
859
+ torch_persistent_save(state, dst_path)
860
+
861
+
862
+ def load_ema_from_checkpoint(fpath):
863
+ """Loads exponential moving averaged (EMA) checkpoint from input and
864
+ returns a model with ema weights.
865
+
866
+ Args:
867
+ fpath: A string path of checkpoint to load from.
868
+
869
+ Returns:
870
+ A dict of string keys mapping to various values. The 'model' key
871
+ from the returned dict should correspond to an OrderedDict mapping
872
+ string parameter names to torch Tensors.
873
+ """
874
+ params_dict = collections.OrderedDict()
875
+ new_state = None
876
+
877
+ with PathManager.open(fpath, "rb") as f:
878
+ new_state = torch.load(
879
+ f,
880
+ map_location=(
881
+ lambda s, _: torch.serialization.default_restore_location(s, "cpu")
882
+ ),
883
+ )
884
+
885
+ # EMA model is stored in a separate "extra state"
886
+ model_params = new_state["extra_state"]["ema"]
887
+
888
+ for key in list(model_params.keys()):
889
+ p = model_params[key]
890
+ if isinstance(p, torch.HalfTensor):
891
+ p = p.float()
892
+ if key not in params_dict:
893
+ params_dict[key] = p.clone()
894
+ # NOTE: clone() is needed in case p is a shared parameter
895
+ else:
896
+ raise ValueError("Key {} is repeated in EMA model params.".format(key))
897
+
898
+ if len(params_dict) == 0:
899
+ raise ValueError(
900
+ f"Input checkpoint path '{fpath}' does not contain "
901
+ "ema model weights, is this model trained with EMA?"
902
+ )
903
+
904
+ new_state["model"] = params_dict
905
+ return new_state
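`load_model_ensemble_and_task_from_hf_hub` snapshots a Hub repo into `~/.cache/fairseq` and then loads every `*.pt` it finds. A hedged sketch with a hypothetical model id (any repo that ships fairseq checkpoints would do); it requires `huggingface_hub` to be installed:

```python
# Minimal sketch: fetch a fairseq model from the Hugging Face Hub and load it.
from fairseq import checkpoint_utils

models, cfg, task = checkpoint_utils.load_model_ensemble_and_task_from_hf_hub(
    "facebook/some-fairseq-model",        # hypothetical model id
)
model = models[0].eval()
print(type(model).__name__, type(task).__name__)
```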
modules/voice_conversion/fairseq/data/__init__.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """isort:skip_file"""
6
+
7
+ from .dictionary import Dictionary, TruncatedDictionary
8
+
9
+ from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
10
+
11
+ from .base_wrapper_dataset import BaseWrapperDataset
12
+
13
+ from .add_target_dataset import AddTargetDataset
14
+ from .append_token_dataset import AppendTokenDataset
15
+ from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset
16
+ from .audio.hubert_dataset import HubertDataset
17
+ from .backtranslation_dataset import BacktranslationDataset
18
+ from .bucket_pad_length_dataset import BucketPadLengthDataset
19
+ from .colorize_dataset import ColorizeDataset
20
+ from .concat_dataset import ConcatDataset
21
+ from .concat_sentences_dataset import ConcatSentencesDataset
22
+ from .denoising_dataset import DenoisingDataset
23
+ from .id_dataset import IdDataset
24
+ from .indexed_dataset import (
25
+ IndexedCachedDataset,
26
+ IndexedDataset,
27
+ IndexedRawTextDataset,
28
+ MMapIndexedDataset,
29
+ )
30
+ from .language_pair_dataset import LanguagePairDataset
31
+ from .list_dataset import ListDataset
32
+ from .lm_context_window_dataset import LMContextWindowDataset
33
+ from .lru_cache_dataset import LRUCacheDataset
34
+ from .mask_tokens_dataset import MaskTokensDataset
35
+ from .monolingual_dataset import MonolingualDataset
36
+ from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
37
+ from .nested_dictionary_dataset import NestedDictionaryDataset
38
+ from .noising import NoisingDataset
39
+ from .numel_dataset import NumelDataset
40
+ from .num_samples_dataset import NumSamplesDataset
41
+ from .offset_tokens_dataset import OffsetTokensDataset
42
+ from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
43
+ from .prepend_dataset import PrependDataset
44
+ from .prepend_token_dataset import PrependTokenDataset
45
+ from .raw_label_dataset import RawLabelDataset
46
+ from .replace_dataset import ReplaceDataset
47
+ from .resampling_dataset import ResamplingDataset
48
+ from .roll_dataset import RollDataset
49
+ from .round_robin_zip_datasets import RoundRobinZipDatasets
50
+ from .sort_dataset import SortDataset
51
+ from .strip_token_dataset import StripTokenDataset
52
+ from .subsample_dataset import SubsampleDataset
53
+ from .token_block_dataset import TokenBlockDataset
54
+ from .transform_eos_dataset import TransformEosDataset
55
+ from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
56
+ from .shorten_dataset import TruncateDataset, RandomCropDataset
57
+ from .multilingual.sampled_multi_dataset import SampledMultiDataset
58
+ from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
59
+ from .fasta_dataset import FastaDataset, EncodedFastaDataset
60
+ from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset
61
+
62
+ from .iterators import (
63
+ CountingIterator,
64
+ EpochBatchIterator,
65
+ GroupedIterator,
66
+ ShardedIterator,
67
+ )
68
+
69
+ __all__ = [
70
+ "AddTargetDataset",
71
+ "AppendTokenDataset",
72
+ "BacktranslationDataset",
73
+ "BaseWrapperDataset",
74
+ "BinarizedAudioDataset",
75
+ "BucketPadLengthDataset",
76
+ "ColorizeDataset",
77
+ "ConcatDataset",
78
+ "ConcatSentencesDataset",
79
+ "CountingIterator",
80
+ "DenoisingDataset",
81
+ "Dictionary",
82
+ "EncodedFastaDataset",
83
+ "EpochBatchIterator",
84
+ "FairseqDataset",
85
+ "FairseqIterableDataset",
86
+ "FastaDataset",
87
+ "FileAudioDataset",
88
+ "GroupedIterator",
89
+ "HubertDataset",
90
+ "IdDataset",
91
+ "IndexedCachedDataset",
92
+ "IndexedDataset",
93
+ "IndexedRawTextDataset",
94
+ "LanguagePairDataset",
95
+ "LeftPadDataset",
96
+ "ListDataset",
97
+ "LMContextWindowDataset",
98
+ "LRUCacheDataset",
99
+ "MaskTokensDataset",
100
+ "MMapIndexedDataset",
101
+ "MonolingualDataset",
102
+ "MultiCorpusSampledDataset",
103
+ "NestedDictionaryDataset",
104
+ "NoisingDataset",
105
+ "NumelDataset",
106
+ "NumSamplesDataset",
107
+ "OffsetTokensDataset",
108
+ "PadDataset",
109
+ "PrependDataset",
110
+ "PrependTokenDataset",
111
+ "RandomCropDataset",
112
+ "RawLabelDataset",
113
+ "ResamplingDataset",
114
+ "ReplaceDataset",
115
+ "RightPadDataset",
116
+ "RollDataset",
117
+ "RoundRobinZipDatasets",
118
+ "SampledMultiDataset",
119
+ "SampledMultiEpochDataset",
120
+ "ShardedIterator",
121
+ "SortDataset",
122
+ "StripTokenDataset",
123
+ "SubsampleDataset",
124
+ "TokenBlockDataset",
125
+ "TransformEosDataset",
126
+ "TransformEosLangPairDataset",
127
+ "TransformEosConcatLangPairDataset",
128
+ "TruncateDataset",
129
+ "TruncatedDictionary",
130
+ ]
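Because everything above is re-exported from the package root, user code can pull the common building blocks straight from `fairseq.data`. A small sketch using `Dictionary`, which most of the wrapper datasets listed here ultimately operate on (assuming the package is importable as `fairseq`):

```python
# Minimal sketch: build a vocabulary and encode a line with fairseq.data.Dictionary.
from fairseq.data import Dictionary

d = Dictionary()                                            # starts with the special symbols
ids = d.encode_line("hello world hello", add_if_not_exist=True)
print(len(d), ids.tolist())                                 # vocab size, ids (eos appended by default)
```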
modules/voice_conversion/fairseq/data/add_target_dataset.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import BaseWrapperDataset, data_utils
9
+ from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
10
+
11
+
12
+ class AddTargetDataset(BaseWrapperDataset):
13
+ def __init__(
14
+ self,
15
+ dataset,
16
+ labels,
17
+ pad,
18
+ eos,
19
+ batch_targets,
20
+ process_label=None,
21
+ label_len_fn=None,
22
+ add_to_input=False,
23
+ text_compression_level=TextCompressionLevel.none,
24
+ ):
25
+ super().__init__(dataset)
26
+ self.labels = labels
27
+ self.batch_targets = batch_targets
28
+ self.pad = pad
29
+ self.eos = eos
30
+ self.process_label = process_label
31
+ self.label_len_fn = label_len_fn
32
+ self.add_to_input = add_to_input
33
+ self.text_compressor = TextCompressor(level=text_compression_level)
34
+
35
+ def get_label(self, index, process_fn=None):
36
+ lbl = self.labels[index]
37
+ lbl = self.text_compressor.decompress(lbl)
38
+ return lbl if process_fn is None else process_fn(lbl)
39
+
40
+ def __getitem__(self, index):
41
+ item = self.dataset[index]
42
+ item["label"] = self.get_label(index, process_fn=self.process_label)
43
+ return item
44
+
45
+ def size(self, index):
46
+ sz = self.dataset.size(index)
47
+ own_sz = self.label_len_fn(self.get_label(index))
48
+ return sz, own_sz
49
+
50
+ def collater(self, samples):
51
+ collated = self.dataset.collater(samples)
52
+ if len(collated) == 0:
53
+ return collated
54
+ indices = set(collated["id"].tolist())
55
+ target = [s["label"] for s in samples if s["id"] in indices]
56
+
57
+ if self.add_to_input:
58
+ eos = torch.LongTensor([self.eos])
59
+ prev_output_tokens = [torch.cat([eos, t], axis=-1) for t in target]
60
+ target = [torch.cat([t, eos], axis=-1) for t in target]
61
+ collated["net_input"]["prev_output_tokens"] = prev_output_tokens
62
+
63
+ if self.batch_targets:
64
+ collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
65
+ target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
66
+ collated["ntokens"] = collated["target_lengths"].sum().item()
67
+ if getattr(collated["net_input"], "prev_output_tokens", None):
68
+ collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
69
+ collated["net_input"]["prev_output_tokens"],
70
+ pad_idx=self.pad,
71
+ left_pad=False,
72
+ )
73
+ else:
74
+ collated["ntokens"] = sum([len(t) for t in target])
75
+
76
+ collated["target"] = target
77
+ return collated
78
+
79
+ def filter_indices_by_size(self, indices, max_sizes):
80
+ indices, ignored = data_utils._filter_by_size_dynamic(
81
+ indices, self.size, max_sizes
82
+ )
83
+ return indices, ignored
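The `add_to_input` branch of `collater` builds the usual teacher-forcing pair: the decoder input is the label behind a leading EOS, and the target is the label followed by EOS. A standalone sketch of that layout with toy ids:

```python
# Sketch of the add_to_input layout: decoder input = <eos> + label, target = label + <eos>.
import torch

eos_idx = 2
label = torch.LongTensor([11, 12, 13])        # toy label token ids
eos = torch.LongTensor([eos_idx])

prev_output_tokens = torch.cat([eos, label])  # what the decoder is fed
target = torch.cat([label, eos])              # what the decoder output is scored against
print(prev_output_tokens.tolist(), target.tolist())
# [2, 11, 12, 13] [11, 12, 13, 2]
```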
modules/voice_conversion/fairseq/data/append_token_dataset.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from . import BaseWrapperDataset
10
+
11
+
12
+ class AppendTokenDataset(BaseWrapperDataset):
13
+ def __init__(self, dataset, token=None):
14
+ super().__init__(dataset)
15
+ self.token = token
16
+ if token is not None:
17
+ self._sizes = np.array(dataset.sizes) + 1
18
+ else:
19
+ self._sizes = dataset.sizes
20
+
21
+ def __getitem__(self, idx):
22
+ item = self.dataset[idx]
23
+ if self.token is not None:
24
+ item = torch.cat([item, item.new([self.token])])
25
+ return item
26
+
27
+ @property
28
+ def sizes(self):
29
+ return self._sizes
30
+
31
+ def num_tokens(self, index):
32
+ n = self.dataset.num_tokens(index)
33
+ if self.token is not None:
34
+ n += 1
35
+ return n
36
+
37
+ def size(self, index):
38
+ n = self.dataset.size(index)
39
+ if self.token is not None:
40
+ n += 1
41
+ return n
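A minimal sketch of the wrapper in action, assuming `fairseq` is importable; `DummyDataset` is a hypothetical stand-in exposing only the attributes `AppendTokenDataset` actually touches (`sizes`, `__getitem__`, `num_tokens`, `size`):

```python
# Sketch: AppendTokenDataset appends one token per item, so reported sizes grow by 1.
import numpy as np
import torch
from fairseq.data import AppendTokenDataset

class DummyDataset:                      # hypothetical stand-in for a real FairseqDataset
    def __init__(self, items):
        self.items = items
        self.sizes = np.array([len(t) for t in items])
    def __getitem__(self, idx):
        return self.items[idx]
    def __len__(self):
        return len(self.items)
    def num_tokens(self, index):
        return int(self.sizes[index])
    def size(self, index):
        return int(self.sizes[index])

base = DummyDataset([torch.LongTensor([5, 6]), torch.LongTensor([7])])
wrapped = AppendTokenDataset(base, token=2)   # e.g. append an eos id of 2
print(wrapped[0].tolist(), wrapped.sizes.tolist())
# [5, 6, 2] [3, 2]
```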
modules/voice_conversion/fairseq/data/audio/__init__.py ADDED
@@ -0,0 +1,93 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Optional
3
+ import importlib
4
+ import os
5
+ import numpy as np
6
+
7
+
8
+ class AudioTransform(ABC):
9
+ @classmethod
10
+ @abstractmethod
11
+ def from_config_dict(cls, config: Optional[Dict] = None):
12
+ pass
13
+
14
+
15
+ class CompositeAudioTransform(AudioTransform):
16
+ def _from_config_dict(
17
+ cls,
18
+ transform_type,
19
+ get_audio_transform,
20
+ composite_cls,
21
+ config=None,
22
+ return_empty=False,
23
+ ):
24
+ _config = {} if config is None else config
25
+ _transforms = _config.get(f"{transform_type}_transforms")
26
+
27
+ if _transforms is None:
28
+ if return_empty:
29
+ _transforms = []
30
+ else:
31
+ return None
32
+
33
+ transforms = [
34
+ get_audio_transform(_t).from_config_dict(_config.get(_t))
35
+ for _t in _transforms
36
+ ]
37
+ return composite_cls(transforms)
38
+
39
+ def __init__(self, transforms):
40
+ self.transforms = [t for t in transforms if t is not None]
41
+
42
+ def __call__(self, x):
43
+ for t in self.transforms:
44
+ x = t(x)
45
+ return x
46
+
47
+ def __repr__(self):
48
+ format_string = (
49
+ [self.__class__.__name__ + "("]
50
+ + [f" {t.__repr__()}" for t in self.transforms]
51
+ + [")"]
52
+ )
53
+ return "\n".join(format_string)
54
+
55
+
56
+ def register_audio_transform(name, cls_type, registry, class_names):
57
+ def register_audio_transform_cls(cls):
58
+ if name in registry:
59
+ raise ValueError(f"Cannot register duplicate transform ({name})")
60
+ if not issubclass(cls, cls_type):
61
+ raise ValueError(
62
+ f"Transform ({name}: {cls.__name__}) must extend "
63
+ f"{cls_type.__name__}"
64
+ )
65
+ if cls.__name__ in class_names:
66
+ raise ValueError(
67
+ f"Cannot register audio transform with duplicate "
68
+ f"class name ({cls.__name__})"
69
+ )
70
+ registry[name] = cls
71
+ class_names.add(cls.__name__)
72
+ return cls
73
+
74
+ return register_audio_transform_cls
75
+
76
+
77
+ def import_transforms(transforms_dir, transform_type):
78
+ for file in os.listdir(transforms_dir):
79
+ path = os.path.join(transforms_dir, file)
80
+ if (
81
+ not file.startswith("_")
82
+ and not file.startswith(".")
83
+ and (file.endswith(".py") or os.path.isdir(path))
84
+ ):
85
+ name = file[: file.find(".py")] if file.endswith(".py") else file
86
+ importlib.import_module(
87
+ f"fairseq.data.audio.{transform_type}_transforms." + name
88
+ )
89
+
90
+
91
+ # Utility fn for uniform numbers in transforms
92
+ def rand_uniform(a, b):
93
+ return np.random.uniform() * (b - a) + a
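The registry helpers above are meant to be specialised per transform type (feature, waveform, dataset). A standalone sketch of the pattern with a made-up registry, assuming the package is importable as `fairseq`:

```python
# Sketch of the registry pattern above, with a made-up registry for custom transforms.
from fairseq.data.audio import AudioTransform, register_audio_transform

MY_REGISTRY, MY_CLASS_NAMES = {}, set()

def register_my_transform(name):
    return register_audio_transform(name, AudioTransform, MY_REGISTRY, MY_CLASS_NAMES)

@register_my_transform("identity")
class IdentityTransform(AudioTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        return cls()

    def __call__(self, x):
        return x

print(MY_REGISTRY)   # {'identity': <class '...IdentityTransform'>}
```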
modules/voice_conversion/fairseq/data/audio/audio_utils.py ADDED
@@ -0,0 +1,389 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import mmap
8
+ from pathlib import Path
9
+ import io
10
+ from typing import BinaryIO, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
17
+
18
+ SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
19
+ FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
20
+
21
+
22
+ def convert_waveform(
23
+ waveform: Union[np.ndarray, torch.Tensor],
24
+ sample_rate: int,
25
+ normalize_volume: bool = False,
26
+ to_mono: bool = False,
27
+ to_sample_rate: Optional[int] = None,
28
+ ) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
29
+ """convert a waveform:
30
+ - to a target sample rate
31
+ - from multi-channel to mono channel
32
+ - volume normalization
33
+
34
+ Args:
35
+ waveform (numpy.ndarray or torch.Tensor): 2D original waveform
36
+ (channels x length)
37
+ sample_rate (int): original sample rate
38
+ normalize_volume (bool): perform volume normalization
39
+ to_mono (bool): convert to mono channel if having multiple channels
40
+ to_sample_rate (Optional[int]): target sample rate
41
+ Returns:
42
+ waveform (numpy.ndarray): converted 2D waveform (channels x length)
43
+ sample_rate (float): target sample rate
44
+ """
45
+ try:
46
+ import torchaudio.sox_effects as ta_sox
47
+ except ImportError:
48
+ raise ImportError("Please install torchaudio: pip install torchaudio")
49
+
50
+ effects = []
51
+ if normalize_volume:
52
+ effects.append(["gain", "-n"])
53
+ if to_sample_rate is not None and to_sample_rate != sample_rate:
54
+ effects.append(["rate", f"{to_sample_rate}"])
55
+ if to_mono and waveform.shape[0] > 1:
56
+ effects.append(["channels", "1"])
57
+ if len(effects) > 0:
58
+ is_np_input = isinstance(waveform, np.ndarray)
59
+ _waveform = torch.from_numpy(waveform) if is_np_input else waveform
60
+ converted, converted_sample_rate = ta_sox.apply_effects_tensor(
61
+ _waveform, sample_rate, effects
62
+ )
63
+ if is_np_input:
64
+ converted = converted.numpy()
65
+ return converted, converted_sample_rate
66
+ return waveform, sample_rate
67
+
68
+
69
+ def get_waveform(
70
+ path_or_fp: Union[str, BinaryIO],
71
+ normalization: bool = True,
72
+ mono: bool = True,
73
+ frames: int = -1,
74
+ start: int = 0,
75
+ always_2d: bool = True,
76
+ output_sample_rate: Optional[int] = None,
77
+ normalize_volume: bool = False,
78
+ waveform_transforms: Optional[CompositeAudioWaveformTransform] = None,
79
+ ) -> Tuple[np.ndarray, int]:
80
+ """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
81
+
82
+ Args:
83
+ path_or_fp (str or BinaryIO): the path or file-like object
84
+ normalization (bool): normalize values to [-1, 1] (Default: True)
85
+ mono (bool): convert multi-channel audio to mono-channel one
86
+ frames (int): the number of frames to read. (-1 for reading all)
87
+ start (int): Where to start reading. A negative value counts from the end.
88
+ always_2d (bool): always return 2D array even for mono-channel audios
89
+ output_sample_rate (Optional[int]): output sample rate
90
+ normalize_volume (bool): normalize volume
91
+ Returns:
92
+ waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
93
+ sample_rate (float): sample rate
94
+ """
95
+ if isinstance(path_or_fp, str):
96
+ ext = Path(path_or_fp).suffix
97
+ if ext not in SF_AUDIO_FILE_EXTENSIONS:
98
+ raise ValueError(f"Unsupported audio format: {ext}")
99
+
100
+ try:
101
+ import soundfile as sf
102
+ except ImportError:
103
+ raise ImportError("Please install soundfile: pip install soundfile")
104
+
105
+ waveform, sample_rate = sf.read(
106
+ path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
107
+ )
108
+ waveform = waveform.T # T x C -> C x T
109
+ waveform, sample_rate = convert_waveform(
110
+ waveform,
111
+ sample_rate,
112
+ normalize_volume=normalize_volume,
113
+ to_mono=mono,
114
+ to_sample_rate=output_sample_rate,
115
+ )
116
+
117
+ if not normalization:
118
+ waveform *= 2**15 # denormalized to 16-bit signed integers
119
+
120
+ if waveform_transforms is not None:
121
+ waveform, sample_rate = waveform_transforms(waveform, sample_rate)
122
+
123
+ if not always_2d:
124
+ waveform = waveform.squeeze(axis=0)
125
+
126
+ return waveform, sample_rate
127
+
128
+
129
+ def get_features_from_npy_or_audio(path, waveform_transforms=None):
130
+ ext = Path(path).suffix
131
+ if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
132
+ raise ValueError(f'Unsupported file format for "{path}"')
133
+ return (
134
+ np.load(path)
135
+ if ext == ".npy"
136
+ else get_fbank(path, waveform_transforms=waveform_transforms)
137
+ )
138
+
139
+
140
+ def get_features_or_waveform_from_stored_zip(
141
+ path,
142
+ byte_offset,
143
+ byte_size,
144
+ need_waveform=False,
145
+ use_sample_rate=None,
146
+ waveform_transforms=None,
147
+ ):
148
+ assert path.endswith(".zip")
149
+ data = read_from_stored_zip(path, byte_offset, byte_size)
150
+ f = io.BytesIO(data)
151
+ if is_npy_data(data):
152
+ features_or_waveform = np.load(f)
153
+ elif is_sf_audio_data(data):
154
+ features_or_waveform = (
155
+ get_waveform(
156
+ f,
157
+ always_2d=False,
158
+ output_sample_rate=use_sample_rate,
159
+ waveform_transforms=waveform_transforms,
160
+ )[0]
161
+ if need_waveform
162
+ else get_fbank(f, waveform_transforms=waveform_transforms)
163
+ )
164
+ else:
165
+ raise ValueError(f'Unknown file format for "{path}"')
166
+ return features_or_waveform
167
+
168
+
169
+ def get_features_or_waveform(
170
+ path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None
171
+ ):
172
+ """Get speech features from .npy file or waveform from .wav/.flac file.
173
+ The file may be inside an uncompressed ZIP file and is accessed via byte
174
+ offset and length.
175
+
176
+ Args:
177
+ path (str): File path in the format of "<.npy/.wav/.flac path>" or
178
+ "<zip path>:<byte offset>:<byte length>".
179
+ need_waveform (bool): return waveform instead of features.
180
+ use_sample_rate (int): change sample rate for the input wave file
181
+
182
+ Returns:
183
+ features_or_waveform (numpy.ndarray): speech features or waveform.
184
+ """
185
+ _path, slice_ptr = parse_path(path)
186
+ if len(slice_ptr) == 0:
187
+ if need_waveform:
188
+ return get_waveform(
189
+ _path,
190
+ always_2d=False,
191
+ output_sample_rate=use_sample_rate,
192
+ waveform_transforms=waveform_transforms,
193
+ )[0]
194
+ return get_features_from_npy_or_audio(
195
+ _path, waveform_transforms=waveform_transforms
196
+ )
197
+ elif len(slice_ptr) == 2:
198
+ features_or_waveform = get_features_or_waveform_from_stored_zip(
199
+ _path,
200
+ slice_ptr[0],
201
+ slice_ptr[1],
202
+ need_waveform=need_waveform,
203
+ use_sample_rate=use_sample_rate,
204
+ waveform_transforms=waveform_transforms,
205
+ )
206
+ else:
207
+ raise ValueError(f"Invalid path: {path}")
208
+
209
+ return features_or_waveform
210
+
211
+
212
+ def _get_kaldi_fbank(
213
+ waveform: np.ndarray, sample_rate: int, n_bins=80
214
+ ) -> Optional[np.ndarray]:
215
+ """Get mel-filter bank features via PyKaldi."""
216
+ try:
217
+ from kaldi.feat.fbank import Fbank, FbankOptions
218
+ from kaldi.feat.mel import MelBanksOptions
219
+ from kaldi.feat.window import FrameExtractionOptions
220
+ from kaldi.matrix import Vector
221
+
222
+ mel_opts = MelBanksOptions()
223
+ mel_opts.num_bins = n_bins
224
+ frame_opts = FrameExtractionOptions()
225
+ frame_opts.samp_freq = sample_rate
226
+ opts = FbankOptions()
227
+ opts.mel_opts = mel_opts
228
+ opts.frame_opts = frame_opts
229
+ fbank = Fbank(opts=opts)
230
+ features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
231
+ return features
232
+ except ImportError:
233
+ return None
234
+
235
+
236
+ def _get_torchaudio_fbank(
237
+ waveform: np.ndarray, sample_rate, n_bins=80
238
+ ) -> Optional[np.ndarray]:
239
+ """Get mel-filter bank features via TorchAudio."""
240
+ try:
241
+ import torchaudio.compliance.kaldi as ta_kaldi
242
+
243
+ waveform = torch.from_numpy(waveform)
244
+ features = ta_kaldi.fbank(
245
+ waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
246
+ )
247
+ return features.numpy()
248
+ except ImportError:
249
+ return None
250
+
251
+
252
+ def get_fbank(
253
+ path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None
254
+ ) -> np.ndarray:
255
+ """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
256
+ (faster CPP implementation) to TorchAudio (Python implementation). Note that
257
+ Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
258
+ waveform should not be normalized."""
259
+ waveform, sample_rate = get_waveform(
260
+ path_or_fp, normalization=False, waveform_transforms=waveform_transforms
261
+ )
262
+
263
+ features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
264
+ if features is None:
265
+ features = _get_torchaudio_fbank(waveform, sample_rate, n_bins)
266
+ if features is None:
267
+ raise ImportError(
268
+ "Please install pyKaldi or torchaudio to enable "
269
+ "online filterbank feature extraction"
270
+ )
271
+
272
+ return features
273
+
274
+
275
+ def is_npy_data(data: bytes) -> bool:
276
+ return data[0] == 147 and data[1] == 78
277
+
278
+
279
+ def is_sf_audio_data(data: bytes) -> bool:
280
+ is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70
281
+ is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97
282
+ is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103
283
+ return is_wav or is_flac or is_ogg
284
+
285
+
286
+ def mmap_read(path: str, offset: int, length: int) -> bytes:
287
+ with open(path, "rb") as f:
288
+ with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
289
+ data = mmap_o[offset : offset + length]
290
+ return data
291
+
292
+
293
+ def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes:
294
+ return mmap_read(zip_path, offset, length)
295
+
296
+
297
+ def parse_path(path: str) -> Tuple[str, List[int]]:
298
+ """Parse data path which is either a path to
299
+ 1. a .npy/.wav/.flac/.ogg file
300
+ 2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
301
+
302
+ Args:
303
+ path (str): the data path to parse
304
+
305
+ Returns:
306
+ file_path (str): the file path
307
+ slice_ptr (list of int): empty in case 1;
308
+ byte offset and length for the slice in case 2
309
+ """
310
+
311
+ if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
312
+ _path, slice_ptr = path, []
313
+ else:
314
+ _path, *slice_ptr = path.split(":")
315
+ if not Path(_path).is_file():
316
+ raise FileNotFoundError(f"File not found: {_path}")
317
+ assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
318
+ slice_ptr = [int(i) for i in slice_ptr]
319
+ return _path, slice_ptr
320
+
321
+
322
+ def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
323
+ padding = n_fft - win_length
324
+ assert padding >= 0
325
+ return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))
326
+
327
+
328
+ def get_fourier_basis(n_fft: int) -> torch.Tensor:
329
+ basis = np.fft.fft(np.eye(n_fft))
330
+ basis = np.vstack(
331
+ [np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
332
+ )
333
+ return torch.from_numpy(basis).float()
334
+
335
+
336
+ def get_mel_filters(
337
+ sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
338
+ ) -> torch.Tensor:
339
+ try:
340
+ import librosa
341
+ except ImportError:
342
+ raise ImportError("Please install librosa: pip install librosa")
343
+ basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max)
344
+ return torch.from_numpy(basis).float()
345
+
346
+
347
+ class TTSSpectrogram(torch.nn.Module):
348
+ def __init__(
349
+ self,
350
+ n_fft: int,
351
+ win_length: int,
352
+ hop_length: int,
353
+ window_fn: callable = torch.hann_window,
354
+ return_phase: bool = False,
355
+ ) -> None:
356
+ super(TTSSpectrogram, self).__init__()
357
+ self.n_fft = n_fft
358
+ self.hop_length = hop_length
359
+ self.return_phase = return_phase
360
+
361
+ basis = get_fourier_basis(n_fft).unsqueeze(1)
362
+ basis *= get_window(window_fn, n_fft, win_length)
363
+ self.register_buffer("basis", basis)
364
+
365
+ def forward(
366
+ self, waveform: torch.Tensor
367
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
368
+ padding = (self.n_fft // 2, self.n_fft // 2)
369
+ x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
370
+ x = F.conv1d(x, self.basis, stride=self.hop_length)
371
+ real_part = x[:, : self.n_fft // 2 + 1, :]
372
+ imag_part = x[:, self.n_fft // 2 + 1 :, :]
373
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
374
+ if self.return_phase:
375
+ phase = torch.atan2(imag_part, real_part)
376
+ return magnitude, phase
377
+ return magnitude
378
+
379
+
380
+ class TTSMelScale(torch.nn.Module):
381
+ def __init__(
382
+ self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
383
+ ) -> None:
384
+ super(TTSMelScale, self).__init__()
385
+ basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
386
+ self.register_buffer("basis", basis)
387
+
388
+ def forward(self, specgram: torch.Tensor) -> torch.Tensor:
389
+ return torch.matmul(self.basis, specgram)
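`parse_path` accepts either a plain `.npy`/audio path or the `"<zip path>:<byte offset>:<byte length>"` form used for features packed into an uncompressed ZIP. A small sketch of both forms (the ZIP file only has to exist; it is never read here), assuming `fairseq` is importable:

```python
# Sketch of the two path forms handled by parse_path().
import tempfile
from pathlib import Path
from fairseq.data.audio.audio_utils import parse_path

print(parse_path("utt1.wav"))                  # ('utt1.wav', [])  -- no slicing info

with tempfile.TemporaryDirectory() as d:       # the ZIP itself must exist on disk
    zip_path = Path(d) / "feats.zip"
    zip_path.touch()
    print(parse_path(f"{zip_path}:16:128"))    # (str(zip_path), [16, 128])
```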
modules/voice_conversion/fairseq/data/audio/data_cfg.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from argparse import Namespace
8
+ from copy import deepcopy
9
+ from pathlib import Path
10
+ from typing import Dict, Optional
11
+
12
+ from fairseq.data import Dictionary
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def get_config_from_yaml(yaml_path: Path):
18
+ try:
19
+ import yaml
20
+ except ImportError:
21
+ print("Please install PyYAML: pip install PyYAML")
22
+ config = {}
23
+ if yaml_path.is_file():
24
+ try:
25
+ with open(yaml_path) as f:
26
+ config = yaml.load(f, Loader=yaml.FullLoader)
27
+ except Exception as e:
28
+ raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
29
+ else:
30
+ raise FileNotFoundError(f"{yaml_path.as_posix()} not found")
31
+
32
+ return config
33
+
34
+
35
+ class S2TDataConfig(object):
36
+ """Wrapper class for data config YAML"""
37
+
38
+ def __init__(self, yaml_path: Path):
39
+ self.config = get_config_from_yaml(yaml_path)
40
+ self.root = yaml_path.parent
41
+
42
+ def _auto_convert_to_abs_path(self, x):
43
+ if isinstance(x, str):
44
+ if not Path(x).exists() and (self.root / x).exists():
45
+ return (self.root / x).as_posix()
46
+ elif isinstance(x, dict):
47
+ return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
48
+ return x
49
+
50
+ @property
51
+ def vocab_filename(self):
52
+ """fairseq vocabulary file under data root"""
53
+ return self.config.get("vocab_filename", "dict.txt")
54
+
55
+ @property
56
+ def speaker_set_filename(self):
57
+ """speaker set file under data root"""
58
+ return self.config.get("speaker_set_filename", None)
59
+
60
+ @property
61
+ def shuffle(self) -> bool:
62
+ """Shuffle dataset samples before batching"""
63
+ return self.config.get("shuffle", False)
64
+
65
+ @property
66
+ def pre_tokenizer(self) -> Dict:
67
+ """Pre-tokenizer to apply before subword tokenization. Returning
68
+ a dictionary with `tokenizer` providing the tokenizer name and
69
+ the other items providing the tokenizer-specific arguments.
70
+ Tokenizers are defined in `fairseq.data.encoders.*`"""
71
+ tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
72
+ return self._auto_convert_to_abs_path(tokenizer)
73
+
74
+ @property
75
+ def bpe_tokenizer(self) -> Dict:
76
+ """Subword tokenizer to apply after pre-tokenization. Returning
77
+ a dictionary with `bpe` providing the tokenizer name and
78
+ the other items providing the tokenizer-specific arguments.
79
+ Tokenizers are defined in `fairseq.data.encoders.*`"""
80
+ tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
81
+ return self._auto_convert_to_abs_path(tokenizer)
82
+
83
+ @property
84
+ def prepend_tgt_lang_tag(self) -> bool:
85
+ """Prepend target lang ID token as the target BOS (e.g. for to-many
86
+ multilingual setting). During inference, this requires `--prefix-size 1`
87
+ to force BOS to be lang ID token."""
88
+ return self.config.get("prepend_tgt_lang_tag", False)
89
+
90
+ @property
91
+ def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
92
+ """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
93
+ return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
94
+
95
+ @property
96
+ def input_feat_per_channel(self):
97
+ """The dimension of input features (per audio channel)"""
98
+ return self.config.get("input_feat_per_channel", 80)
99
+
100
+ @property
101
+ def input_channels(self):
102
+ """The number of channels in the input audio"""
103
+ return self.config.get("input_channels", 1)
104
+
105
+ @property
106
+ def sample_rate(self):
107
+ return self.config.get("sample_rate", 16_000)
108
+
109
+ @property
110
+ def sampling_alpha(self):
111
+ """Hyper-parameter alpha = 1/T for temperature-based resampling.
112
+ (alpha = 1 for no resampling)"""
113
+ return self.config.get("sampling_alpha", 1.0)
114
+
115
+ @property
116
+ def use_audio_input(self):
117
+ """Needed by the dataset loader to see if the model requires
118
+ raw audio as inputs."""
119
+ return self.config.get("use_audio_input", False)
120
+
121
+ def standardize_audio(self) -> bool:
122
+ return self.use_audio_input and self.config.get("standardize_audio", False)
123
+
124
+ @property
125
+ def use_sample_rate(self):
126
+ """Needed by the dataset loader to see if the model requires
127
+ raw audio with specific sample rate as inputs."""
128
+ return self.config.get("use_sample_rate", 16000)
129
+
130
+ @property
131
+ def audio_root(self):
132
+ """Audio paths in the manifest TSV can be relative and this provides
133
+ the root path. Set this to empty string when using absolute paths."""
134
+ return self.config.get("audio_root", "")
135
+
136
+ def get_transforms(self, transform_type, split, is_train):
137
+ """Split-specific feature transforms. Allowing train set
138
+ wildcard `_train`, evaluation set wildcard `_eval` and general
139
+ wildcard `*` for matching."""
140
+ from copy import deepcopy
141
+
142
+ cfg = deepcopy(self.config)
143
+ _cur = cfg.get(f"{transform_type}transforms", {})
144
+ cur = _cur.get(split)
145
+ cur = _cur.get("_train") if cur is None and is_train else cur
146
+ cur = _cur.get("_eval") if cur is None and not is_train else cur
147
+ cur = _cur.get("*") if cur is None else cur
148
+ return cur
149
+
150
+ def get_feature_transforms(self, split, is_train):
151
+ cfg = deepcopy(self.config)
152
+ # TODO: deprecate transforms
153
+ cur = self.get_transforms("", split, is_train)
154
+ if cur is not None:
155
+ logger.warning(
156
+ "Auto converting transforms into feature_transforms, "
157
+ "but transforms will be deprecated in the future. Please "
158
+ "update this in the config."
159
+ )
160
+ ft_transforms = self.get_transforms("feature_", split, is_train)
161
+ if ft_transforms:
162
+ cur.extend(ft_transforms)
163
+ else:
164
+ cur = self.get_transforms("feature_", split, is_train)
165
+ cfg["feature_transforms"] = cur
166
+ return cfg
167
+
168
+ def get_waveform_transforms(self, split, is_train):
169
+ cfg = deepcopy(self.config)
170
+ cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
171
+ return cfg
172
+
173
+ def get_dataset_transforms(self, split, is_train):
174
+ cfg = deepcopy(self.config)
175
+ cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
176
+ return cfg
177
+
178
+ @property
179
+ def global_cmvn_stats_npz(self) -> Optional[str]:
180
+ path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
181
+ return self._auto_convert_to_abs_path(path)
182
+
183
+ @property
184
+ def vocoder(self) -> Dict[str, str]:
185
+ vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
186
+ return self._auto_convert_to_abs_path(vocoder)
187
+
188
+ @property
189
+ def hub(self) -> Dict[str, str]:
190
+ return self.config.get("hub", {})
191
+
192
+
193
+ class S2SDataConfig(S2TDataConfig):
194
+ """Wrapper class for data config YAML"""
195
+
196
+ @property
197
+ def vocab_filename(self):
198
+ """fairseq vocabulary file under data root"""
199
+ return self.config.get("vocab_filename", None)
200
+
201
+ @property
202
+ def pre_tokenizer(self) -> Dict:
203
+ return None
204
+
205
+ @property
206
+ def bpe_tokenizer(self) -> Dict:
207
+ return None
208
+
209
+ @property
210
+ def input_transformed_channels(self):
211
+ """The number of channels in the audio after feature transforms"""
212
+ # TODO: move this into individual transforms
213
+ # TODO: deprecate transforms
214
+ _cur = self.config.get("transforms", {})
215
+ ft_transforms = self.config.get("feature_transforms", {})
216
+ if _cur and ft_transforms:
217
+ _cur.update(ft_transforms)
218
+ else:
219
+ _cur = self.config.get("feature_transforms", {})
220
+ cur = _cur.get("_train", [])
221
+
222
+ _channels = self.input_channels
223
+ if "delta_deltas" in cur:
224
+ _channels *= 3
225
+
226
+ return _channels
227
+
228
+ @property
229
+ def output_sample_rate(self):
230
+ """The audio sample rate of output target speech"""
231
+ return self.config.get("output_sample_rate", 22050)
232
+
233
+ @property
234
+ def target_speaker_embed(self):
235
+ """Target speaker embedding file (one line per target audio sample)"""
236
+ return self.config.get("target_speaker_embed", None)
237
+
238
+ @property
239
+ def prepend_tgt_lang_tag_as_bos(self) -> bool:
240
+ """Prepend target lang ID token as the target BOS."""
241
+ return self.config.get("prepend_tgt_lang_tag_as_bos", False)
242
+
243
+
244
+ class MultitaskConfig(object):
245
+ """Wrapper class for data config YAML"""
246
+
247
+ def __init__(self, yaml_path: Path):
248
+ config = get_config_from_yaml(yaml_path)
249
+ self.config = {}
250
+ for k, v in config.items():
251
+ self.config[k] = SingleTaskConfig(k, v)
252
+
253
+ def get_all_tasks(self):
254
+ return self.config
255
+
256
+ def get_single_task(self, name):
257
+ assert name in self.config, f"multitask '{name}' does not exist!"
258
+ return self.config[name]
259
+
260
+ @property
261
+ def first_pass_decoder_task_index(self):
262
+ """Return the task index of the first-pass text decoder.
263
+ If there are multiple 'is_first_pass_decoder: True' in the config file,
264
+ the last task is used for the first-pass decoder.
265
+ If there is no 'is_first_pass_decoder: True' in the config file,
266
+ the last task whose task_name includes 'target' and decoder_type is not ctc.
267
+ """
268
+ idx = -1
269
+ for i, (k, v) in enumerate(self.config.items()):
270
+ if v.is_first_pass_decoder:
271
+ idx = i
272
+ if idx < 0:
273
+ for i, (k, v) in enumerate(self.config.items()):
274
+ if k.startswith("target") and v.decoder_type == "transformer":
275
+ idx = i
276
+ return idx
277
+
278
+
279
+ class SingleTaskConfig(object):
280
+ def __init__(self, name, config):
281
+ self.task_name = name
282
+ self.config = config
283
+ dict_path = config.get("dict", "")
284
+ self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None
285
+
286
+ @property
287
+ def data(self):
288
+ return self.config.get("data", "")
289
+
290
+ @property
291
+ def decoder_type(self):
292
+ return self.config.get("decoder_type", "transformer")
293
+
294
+ @property
295
+ def decoder_args(self):
296
+ """Decoder arch related args"""
297
+ args = self.config.get("decoder_args", {})
298
+ return Namespace(**args)
299
+
300
+ @property
301
+ def criterion_cfg(self):
302
+ """cfg for the multitask criterion"""
303
+ if self.decoder_type == "ctc":
304
+ from fairseq.criterions.ctc import CtcCriterionConfig
305
+
306
+ cfg = CtcCriterionConfig
307
+ cfg.zero_infinity = self.config.get("zero_infinity", True)
308
+ else:
309
+ from fairseq.criterions.label_smoothed_cross_entropy import (
310
+ LabelSmoothedCrossEntropyCriterionConfig,
311
+ )
312
+
313
+ cfg = LabelSmoothedCrossEntropyCriterionConfig
314
+ cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
315
+ return cfg
316
+
317
+ @property
318
+ def input_from(self):
319
+ """Condition on encoder/decoder of the main model"""
320
+ return "decoder" if "decoder_layer" in self.config else "encoder"
321
+
322
+ @property
323
+ def input_layer(self):
324
+ if self.input_from == "decoder":
325
+ return self.config["decoder_layer"] - 1
326
+ else:
327
+ # default using the output from the last encoder layer (-1)
328
+ return self.config.get("encoder_layer", 0) - 1
329
+
330
+ @property
331
+ def loss_weight_schedule(self):
332
+ return (
333
+ "decay"
334
+ if "loss_weight_max" in self.config
335
+ and "loss_weight_decay_steps" in self.config
336
+ else "fixed"
337
+ )
338
+
339
+ def get_loss_weight(self, num_updates):
340
+ if self.loss_weight_schedule == "fixed":
341
+ weight = self.config.get("loss_weight", 1.0)
342
+ else: # "decay"
343
+ assert (
344
+ self.config.get("loss_weight_decay_steps", 0) > 0
345
+ ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
346
+ loss_weight_min = self.config.get("loss_weight_min", 0.0001)
347
+ loss_weight_decay_stepsize = (
348
+ self.config["loss_weight_max"] - loss_weight_min
349
+ ) / self.config["loss_weight_decay_steps"]
350
+ weight = max(
351
+ self.config["loss_weight_max"]
352
+ - loss_weight_decay_stepsize * num_updates,
353
+ loss_weight_min,
354
+ )
355
+ return weight
356
+
357
+ @property
358
+ def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
359
+ """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
360
+ return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
361
+
362
+ @property
363
+ def eos_token(self):
364
+ """EOS token during generation"""
365
+ return self.config.get("eos_token", "<eos>")
366
+
367
+ @property
368
+ def rdrop_alpha(self):
369
+ return self.config.get("rdrop_alpha", 0.0)
370
+
371
+ @property
372
+ def is_first_pass_decoder(self):
373
+ flag = self.config.get("is_first_pass_decoder", False)
374
+ if flag:
375
+ if self.decoder_type == "ctc":
376
+ raise ValueError(
377
+ "First-pass decoder in the multi-decoder model must not be CTC."
378
+ )
379
+ if "target" not in self.task_name:
380
+ raise Warning(
381
+ 'The name of the first-pass decoder does not include "target".'
382
+ )
383
+ return flag
384
+
385
+ @property
386
+ def get_lang_tag_mapping(self):
387
+ return self.config.get("lang_tag_mapping", {})
modules/voice_conversion/fairseq/data/audio/dataset_transforms/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ import os
2
+ from fairseq.data.audio import (
3
+ AudioTransform,
4
+ CompositeAudioTransform,
5
+ import_transforms,
6
+ register_audio_transform,
7
+ )
8
+
9
+
10
+ class AudioDatasetTransform(AudioTransform):
11
+ pass
12
+
13
+
14
+ AUDIO_DATASET_TRANSFORM_REGISTRY = {}
15
+ AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()
16
+
17
+
18
+ def get_audio_dataset_transform(name):
19
+ return AUDIO_DATASET_TRANSFORM_REGISTRY[name]
20
+
21
+
22
+ def register_audio_dataset_transform(name):
23
+ return register_audio_transform(
24
+ name,
25
+ AudioDatasetTransform,
26
+ AUDIO_DATASET_TRANSFORM_REGISTRY,
27
+ AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
28
+ )
29
+
30
+
31
+ import_transforms(os.path.dirname(__file__), "dataset")
32
+
33
+
34
+ class CompositeAudioDatasetTransform(CompositeAudioTransform):
35
+ @classmethod
36
+ def from_config_dict(cls, config=None):
37
+ return super()._from_config_dict(
38
+ cls,
39
+ "dataset",
40
+ get_audio_dataset_transform,
41
+ CompositeAudioDatasetTransform,
42
+ config,
43
+ return_empty=True,
44
+ )
45
+
46
+ def get_transform(self, cls):
47
+ for t in self.transforms:
48
+ if isinstance(t, cls):
49
+ return t
50
+ return None
51
+
52
+ def has_transform(self, cls):
53
+ return self.get_transform(cls) is not None
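New dataset transforms hook into the registry above through the decorator, exactly as concataugment and noisyoverlapaugment do below; a minimal, purely illustrative sketch (this class is not part of this commit):

from fairseq.data.audio.dataset_transforms import (
    AudioDatasetTransform,
    register_audio_dataset_transform,
)


@register_audio_dataset_transform("identity")
class IdentityDatasetTransform(AudioDatasetTransform):
    """Illustrative no-op transform: returns the incoming samples unchanged."""

    @classmethod
    def from_config_dict(cls, config=None):
        return cls()

    def __call__(self, samples):
        return samples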
modules/voice_conversion/fairseq/data/audio/dataset_transforms/concataugment.py ADDED
@@ -0,0 +1,61 @@
1
+ from typing import List
2
+ import numpy as np
3
+
4
+ from fairseq.data.audio.dataset_transforms import (
5
+ AudioDatasetTransform,
6
+ register_audio_dataset_transform,
7
+ )
8
+
9
+ _DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}
10
+
11
+
12
+ @register_audio_dataset_transform("concataugment")
13
+ class ConcatAugment(AudioDatasetTransform):
14
+ @classmethod
15
+ def from_config_dict(cls, config=None):
16
+ _config = {} if config is None else config
17
+ return ConcatAugment(
18
+ _config.get("rate", _DEFAULTS["rate"]),
19
+ _config.get("max_tokens", _DEFAULTS["max_tokens"]),
20
+ _config.get("attempts", _DEFAULTS["attempts"]),
21
+ )
22
+
23
+ def __init__(
24
+ self,
25
+ rate=_DEFAULTS["rate"],
26
+ max_tokens=_DEFAULTS["max_tokens"],
27
+ attempts=_DEFAULTS["attempts"],
28
+ ):
29
+ self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts
30
+
31
+ def __repr__(self):
32
+ return (
33
+ self.__class__.__name__
34
+ + "("
35
+ + ", ".join(
36
+ [
37
+ f"rate={self.rate}",
38
+ f"max_tokens={self.max_tokens}",
39
+ f"attempts={self.attempts}",
40
+ ]
41
+ )
42
+ + ")"
43
+ )
44
+
45
+ def find_indices(self, index: int, n_frames: List[int], n_samples: int):
46
+ # skip conditions: application rate, max_tokens limit exceeded
47
+ if np.random.random() > self.rate:
48
+ return [index]
49
+ if self.max_tokens and n_frames[index] > self.max_tokens:
50
+ return [index]
51
+
52
+ # pick second sample to concatenate
53
+ for _ in range(self.attempts):
54
+ index2 = np.random.randint(0, n_samples)
55
+ if index2 != index and (
56
+ not self.max_tokens
57
+ or n_frames[index] + n_frames[index2] < self.max_tokens
58
+ ):
59
+ return [index, index2]
60
+
61
+ return [index]
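A small sketch of what find_indices returns (the frame counts are invented for illustration); with rate=1.0 the transform always tries to find a compatible partner index:

import numpy as np

from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment

np.random.seed(0)
aug = ConcatAugment(rate=1.0, max_tokens=3000, attempts=5)
n_frames = [1200, 900, 2500, 400]  # made-up per-sample lengths

# Returns either [index] (skipped) or [index, index2] when the two samples
# together stay under max_tokens.
print(aug.find_indices(0, n_frames, n_samples=len(n_frames)))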
modules/voice_conversion/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py ADDED
@@ -0,0 +1,105 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from fairseq.data.audio import rand_uniform
5
+ from fairseq.data.audio.dataset_transforms import (
6
+ AudioDatasetTransform,
7
+ register_audio_dataset_transform,
8
+ )
9
+ from fairseq.data.audio.waveform_transforms.noiseaugment import (
10
+ NoiseAugmentTransform,
11
+ )
12
+
13
+ _DEFAULTS = {
14
+ "rate": 0.25,
15
+ "mixing_noise_rate": 0.1,
16
+ "noise_path": "",
17
+ "noise_snr_min": -5,
18
+ "noise_snr_max": 5,
19
+ "utterance_snr_min": -5,
20
+ "utterance_snr_max": 5,
21
+ }
22
+
23
+
24
+ @register_audio_dataset_transform("noisyoverlapaugment")
25
+ class NoisyOverlapAugment(AudioDatasetTransform):
26
+ @classmethod
27
+ def from_config_dict(cls, config=None):
28
+ _config = {} if config is None else config
29
+ return NoisyOverlapAugment(
30
+ _config.get("rate", _DEFAULTS["rate"]),
31
+ _config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]),
32
+ _config.get("noise_path", _DEFAULTS["noise_path"]),
33
+ _config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]),
34
+ _config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]),
35
+ _config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]),
36
+ _config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]),
37
+ )
38
+
39
+ def __init__(
40
+ self,
41
+ rate=_DEFAULTS["rate"],
42
+ mixing_noise_rate=_DEFAULTS["mixing_noise_rate"],
43
+ noise_path=_DEFAULTS["noise_path"],
44
+ noise_snr_min=_DEFAULTS["noise_snr_min"],
45
+ noise_snr_max=_DEFAULTS["noise_snr_max"],
46
+ utterance_snr_min=_DEFAULTS["utterance_snr_min"],
47
+ utterance_snr_max=_DEFAULTS["utterance_snr_max"],
48
+ ):
49
+ self.rate = rate
50
+ self.mixing_noise_rate = mixing_noise_rate
51
+ self.noise_shaper = NoiseAugmentTransform(noise_path)
52
+ self.noise_snr_min = noise_snr_min
53
+ self.noise_snr_max = noise_snr_max
54
+ self.utterance_snr_min = utterance_snr_min
55
+ self.utterance_snr_max = utterance_snr_max
56
+
57
+ def __repr__(self):
58
+ return (
59
+ self.__class__.__name__
60
+ + "("
61
+ + ", ".join(
62
+ [
63
+ f"rate={self.rate}",
64
+ f"mixing_noise_rate={self.mixing_noise_rate}",
65
+ f"noise_snr_min={self.noise_snr_min}",
66
+ f"noise_snr_max={self.noise_snr_max}",
67
+ f"utterance_snr_min={self.utterance_snr_min}",
68
+ f"utterance_snr_max={self.utterance_snr_max}",
69
+ ]
70
+ )
71
+ + ")"
72
+ )
73
+
74
+ def __call__(self, sources):
75
+ for i, source in enumerate(sources):
76
+ if np.random.random() > self.rate:
77
+ continue
78
+
79
+ pri = source.numpy()
80
+
81
+ if np.random.random() > self.mixing_noise_rate:
82
+ sec = sources[np.random.randint(0, len(sources))].numpy()
83
+ snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max)
84
+ else:
85
+ sec = self.noise_shaper.pick_sample(source.shape)
86
+ snr = rand_uniform(self.noise_snr_min, self.noise_snr_max)
87
+
88
+ L1 = pri.shape[-1]
89
+ L2 = sec.shape[-1]
90
+ l = np.random.randint(0, min(round(L1 / 2), L2)) # mix len
91
+ s_source = np.random.randint(0, L1 - l)
92
+ s_sec = np.random.randint(0, L2 - l)
93
+
94
+ get_power = lambda x: np.mean(x**2)
95
+ if get_power(sec) == 0:
96
+ continue
97
+
98
+ scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec)))
99
+
100
+ pri[s_source : s_source + l] = np.add(
101
+ pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
102
+ )
103
+ sources[i] = torch.from_numpy(pri).float()
104
+
105
+ return sources
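The secondary signal is scaled so that the primary-to-secondary power ratio matches the sampled SNR (the `scl` line above); a tiny self-contained check of that formula:

import numpy as np

pri = np.random.randn(16000).astype(np.float32)
sec = np.random.randn(16000).astype(np.float32)
snr = 5.0  # target SNR in dB

scl = np.sqrt(np.mean(pri**2) / (np.power(10, snr / 10) * np.mean(sec**2)))
achieved = 10 * np.log10(np.mean(pri**2) / np.mean((scl * sec) ** 2))
print(round(achieved, 2))  # ~5.0 dB by construction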
modules/voice_conversion/fairseq/data/audio/feature_transforms/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ import os
2
+ from fairseq.data.audio import (
3
+ AudioTransform,
4
+ CompositeAudioTransform,
5
+ import_transforms,
6
+ register_audio_transform,
7
+ )
8
+
9
+
10
+ class AudioFeatureTransform(AudioTransform):
11
+ pass
12
+
13
+
14
+ AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
15
+ AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
16
+
17
+
18
+ def get_audio_feature_transform(name):
19
+ return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
20
+
21
+
22
+ def register_audio_feature_transform(name):
23
+ return register_audio_transform(
24
+ name,
25
+ AudioFeatureTransform,
26
+ AUDIO_FEATURE_TRANSFORM_REGISTRY,
27
+ AUDIO_FEATURE_TRANSFORM_CLASS_NAMES,
28
+ )
29
+
30
+
31
+ import_transforms(os.path.dirname(__file__), "feature")
32
+
33
+
34
+ class CompositeAudioFeatureTransform(CompositeAudioTransform):
35
+ @classmethod
36
+ def from_config_dict(cls, config=None):
37
+ return super()._from_config_dict(
38
+ cls,
39
+ "feature",
40
+ get_audio_feature_transform,
41
+ CompositeAudioFeatureTransform,
42
+ config,
43
+ )
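This composite consumes the `feature_transforms` list that S2TDataConfig.get_feature_transforms prepares above. A hedged sketch, assuming the parent `_from_config_dict` (defined in fairseq.data.audio, not shown in this diff) resolves the transform names from the `feature_transforms` key and instantiates each one via its own from_config_dict:

from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform

cfg = {
    "feature_transforms": ["utterance_cmvn", "specaugment"],
    "specaugment": {"freq_mask_N": 2, "freq_mask_F": 27,
                    "time_mask_N": 2, "time_mask_T": 100, "time_mask_p": 1.0},
}
transforms = CompositeAudioFeatureTransform.from_config_dict(cfg)
# transforms(spectrogram) would then apply utterance CMVN followed by SpecAugment.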
modules/voice_conversion/fairseq/data/audio/feature_transforms/delta_deltas.py ADDED
@@ -0,0 +1,37 @@
1
+ import numpy as np
2
+ import torch
3
+ from fairseq.data.audio.feature_transforms import (
4
+ AudioFeatureTransform,
5
+ register_audio_feature_transform,
6
+ )
7
+
8
+
9
+ @register_audio_feature_transform("delta_deltas")
10
+ class DeltaDeltas(AudioFeatureTransform):
11
+ """Expand delta-deltas features from spectrum."""
12
+
13
+ @classmethod
14
+ def from_config_dict(cls, config=None):
15
+ _config = {} if config is None else config
16
+ return DeltaDeltas(_config.get("win_length", 5))
17
+
18
+ def __init__(self, win_length=5):
19
+ self.win_length = win_length
20
+
21
+ def __repr__(self):
22
+ return self.__class__.__name__
23
+
24
+ def __call__(self, spectrogram):
25
+ from torchaudio.functional import compute_deltas
26
+
27
+ assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
28
+ # spectrogram is T x F, while compute_deltas takes (…, F, T)
29
+ spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
30
+ delta = compute_deltas(spectrogram)
31
+ delta_delta = compute_deltas(delta)
32
+
33
+ out_feat = np.concatenate(
34
+ [spectrogram, delta.numpy(), delta_delta.numpy()], axis=0
35
+ )
36
+ out_feat = np.transpose(out_feat)
37
+ return out_feat
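Delta-deltas triple the feature dimension, which is what S2SDataConfig.input_transformed_channels accounts for earlier in this commit; a quick shape check (needs torchaudio installed):

import numpy as np

from fairseq.data.audio.feature_transforms.delta_deltas import DeltaDeltas

spec = np.random.rand(120, 80).astype(np.float32)  # T x F log-mel features
out = DeltaDeltas()(spec)
print(out.shape)  # (120, 240): base features, deltas, and delta-deltas stacked on F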
modules/voice_conversion/fairseq/data/audio/feature_transforms/global_cmvn.py ADDED
@@ -0,0 +1,29 @@
1
+ import numpy as np
2
+ from fairseq.data.audio.feature_transforms import (
3
+ AudioFeatureTransform,
4
+ register_audio_feature_transform,
5
+ )
6
+
7
+
8
+ @register_audio_feature_transform("global_cmvn")
9
+ class GlobalCMVN(AudioFeatureTransform):
10
+ """Global CMVN (cepstral mean and variance normalization). The global mean
11
+ and variance need to be pre-computed and stored in NumPy format (.npz)."""
12
+
13
+ @classmethod
14
+ def from_config_dict(cls, config=None):
15
+ _config = {} if config is None else config
16
+ return GlobalCMVN(_config.get("stats_npz_path"))
17
+
18
+ def __init__(self, stats_npz_path):
19
+ self.stats_npz_path = stats_npz_path
20
+ stats = np.load(stats_npz_path)
21
+ self.mean, self.std = stats["mean"], stats["std"]
22
+
23
+ def __repr__(self):
24
+ return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'
25
+
26
+ def __call__(self, x):
27
+ x = np.subtract(x, self.mean)
28
+ x = np.divide(x, self.std)
29
+ return x
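GlobalCMVN expects an .npz archive with `mean` and `std` arrays; a minimal sketch of producing and applying one (the /tmp path is just for illustration):

import numpy as np

from fairseq.data.audio.feature_transforms.global_cmvn import GlobalCMVN

feats = np.random.rand(1000, 80).astype(np.float32)
np.savez("/tmp/gcmvn_stats.npz", mean=feats.mean(axis=0), std=feats.std(axis=0))

transform = GlobalCMVN("/tmp/gcmvn_stats.npz")
normalized = transform(feats)
print(normalized.mean(), normalized.std())  # roughly 0 and 1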
modules/voice_conversion/fairseq/data/audio/feature_transforms/specaugment.py ADDED
@@ -0,0 +1,131 @@
1
+ import math
2
+ import numbers
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ from fairseq.data.audio.feature_transforms import (
7
+ AudioFeatureTransform,
8
+ register_audio_feature_transform,
9
+ )
10
+
11
+
12
+ @register_audio_feature_transform("specaugment")
13
+ class SpecAugmentTransform(AudioFeatureTransform):
14
+ """SpecAugment (https://arxiv.org/abs/1904.08779)"""
15
+
16
+ @classmethod
17
+ def from_config_dict(cls, config=None):
18
+ _config = {} if config is None else config
19
+ return SpecAugmentTransform(
20
+ _config.get("time_warp_W", 0),
21
+ _config.get("freq_mask_N", 0),
22
+ _config.get("freq_mask_F", 0),
23
+ _config.get("time_mask_N", 0),
24
+ _config.get("time_mask_T", 0),
25
+ _config.get("time_mask_p", 0.0),
26
+ _config.get("mask_value", None),
27
+ )
28
+
29
+ def __init__(
30
+ self,
31
+ time_warp_w: int = 0,
32
+ freq_mask_n: int = 0,
33
+ freq_mask_f: int = 0,
34
+ time_mask_n: int = 0,
35
+ time_mask_t: int = 0,
36
+ time_mask_p: float = 0.0,
37
+ mask_value: Optional[float] = 0.0,
38
+ ):
39
+ # Sanity checks
40
+ assert mask_value is None or isinstance(
41
+ mask_value, numbers.Number
42
+ ), f"mask_value (type: {type(mask_value)}) must be None or a number"
43
+ if freq_mask_n > 0:
44
+ assert freq_mask_f > 0, (
45
+ f"freq_mask_F ({freq_mask_f}) "
46
+ f"must be larger than 0 when doing freq masking."
47
+ )
48
+ if time_mask_n > 0:
49
+ assert time_mask_t > 0, (
50
+ f"time_mask_T ({time_mask_t}) must be larger than 0 when "
51
+ f"doing time masking."
52
+ )
53
+
54
+ self.time_warp_w = time_warp_w
55
+ self.freq_mask_n = freq_mask_n
56
+ self.freq_mask_f = freq_mask_f
57
+ self.time_mask_n = time_mask_n
58
+ self.time_mask_t = time_mask_t
59
+ self.time_mask_p = time_mask_p
60
+ self.mask_value = mask_value
61
+
62
+ def __repr__(self):
63
+ return (
64
+ self.__class__.__name__
65
+ + "("
66
+ + ", ".join(
67
+ [
68
+ f"time_warp_w={self.time_warp_w}",
69
+ f"freq_mask_n={self.freq_mask_n}",
70
+ f"freq_mask_f={self.freq_mask_f}",
71
+ f"time_mask_n={self.time_mask_n}",
72
+ f"time_mask_t={self.time_mask_t}",
73
+ f"time_mask_p={self.time_mask_p}",
74
+ ]
75
+ )
76
+ + ")"
77
+ )
78
+
79
+ def __call__(self, spectrogram):
80
+ assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
81
+
82
+ distorted = spectrogram.copy() # make a copy of input spectrogram.
83
+ num_frames = spectrogram.shape[0] # or 'tau' in the paper.
84
+ num_freqs = spectrogram.shape[1] # or 'miu' in the paper.
85
+ mask_value = self.mask_value
86
+
87
+ if mask_value is None: # if no value was specified, use local mean.
88
+ mask_value = spectrogram.mean()
89
+
90
+ if num_frames == 0:
91
+ return spectrogram
92
+
93
+ if num_freqs < self.freq_mask_f:
94
+ return spectrogram
95
+
96
+ if self.time_warp_w > 0:
97
+ if 2 * self.time_warp_w < num_frames:
98
+ import cv2
99
+
100
+ w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
101
+ w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
102
+ upper, lower = distorted[:w0, :], distorted[w0:, :]
103
+ upper = cv2.resize(
104
+ upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
105
+ )
106
+ lower = cv2.resize(
107
+ lower,
108
+ dsize=(num_freqs, num_frames - w0 - w),
109
+ interpolation=cv2.INTER_LINEAR,
110
+ )
111
+ distorted = np.concatenate((upper, lower), axis=0)
112
+
113
+ for _i in range(self.freq_mask_n):
114
+ f = np.random.randint(0, self.freq_mask_f)
115
+ f0 = np.random.randint(0, num_freqs - f)
116
+ if f != 0:
117
+ distorted[:, f0 : f0 + f] = mask_value
118
+
119
+ max_time_mask_t = min(
120
+ self.time_mask_t, math.floor(num_frames * self.time_mask_p)
121
+ )
122
+ if max_time_mask_t < 1:
123
+ return distorted
124
+
125
+ for _i in range(self.time_mask_n):
126
+ t = np.random.randint(0, max_time_mask_t)
127
+ t0 = np.random.randint(0, num_frames - t)
128
+ if t != 0:
129
+ distorted[t0 : t0 + t, :] = mask_value
130
+
131
+ return distorted
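Note that from_config_dict reads the upper-case YAML keys (time_warp_W, freq_mask_N, ...) while the constructor uses lower-case names, so configs should use the former. A hedged configuration sketch (values are illustrative; time warping additionally requires OpenCV):

import numpy as np

from fairseq.data.audio.feature_transforms.specaugment import SpecAugmentTransform

transform = SpecAugmentTransform.from_config_dict({
    "time_warp_W": 80,
    "freq_mask_N": 2,
    "freq_mask_F": 27,
    "time_mask_N": 2,
    "time_mask_T": 100,
    "time_mask_p": 1.0,
})
augmented = transform(np.random.rand(500, 80).astype(np.float32))
print(augmented.shape)  # (500, 80): warping and masking preserve the (T, F) shape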
modules/voice_conversion/fairseq/data/audio/feature_transforms/utterance_cmvn.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+
3
+ from fairseq.data.audio.feature_transforms import (
4
+ AudioFeatureTransform,
5
+ register_audio_feature_transform,
6
+ )
7
+
8
+
9
+ @register_audio_feature_transform("utterance_cmvn")
10
+ class UtteranceCMVN(AudioFeatureTransform):
11
+ """Utterance-level CMVN (cepstral mean and variance normalization)"""
12
+
13
+ @classmethod
14
+ def from_config_dict(cls, config=None):
15
+ _config = {} if config is None else config
16
+ return UtteranceCMVN(
17
+ _config.get("norm_means", True),
18
+ _config.get("norm_vars", True),
19
+ )
20
+
21
+ def __init__(self, norm_means=True, norm_vars=True):
22
+ self.norm_means, self.norm_vars = norm_means, norm_vars
23
+
24
+ def __repr__(self):
25
+ return (
26
+ self.__class__.__name__
27
+ + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
28
+ )
29
+
30
+ def __call__(self, x):
31
+ mean = x.mean(axis=0)
32
+ square_sums = (x**2).sum(axis=0)
33
+
34
+ if self.norm_means:
35
+ x = np.subtract(x, mean)
36
+ if self.norm_vars:
37
+ var = square_sums / x.shape[0] - mean**2
38
+ std = np.sqrt(np.maximum(var, 1e-10))
39
+ x = np.divide(x, std)
40
+
41
+ return x
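Unlike the global variant, utterance-level CMVN needs no precomputed statistics; a short sanity check:

import numpy as np

from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN

x = np.random.rand(200, 80).astype(np.float32)
y = UtteranceCMVN(norm_means=True, norm_vars=True)(x)
print(y.mean(axis=0).max(), y.std(axis=0).mean())  # close to 0 and 1 per feature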
modules/voice_conversion/fairseq/data/audio/frm_text_to_speech_dataset.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright (c) 2017-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the LICENSE file in
5
+ # the root directory of this source tree. An additional grant of patent rights
6
+ # can be found in the PATENTS file in the same directory.
7
+
8
+ import csv
9
+ import logging
10
+ import os.path as op
11
+ from typing import List, Optional
12
+
13
+ import numpy as np
14
+ import torch
15
+ from fairseq.data import Dictionary
16
+ from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
17
+ from fairseq.data.audio.text_to_speech_dataset import (
18
+ TextToSpeechDataset,
19
+ TextToSpeechDatasetCreator,
20
+ )
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class FrmTextToSpeechDataset(TextToSpeechDataset):
26
+ def __init__(
27
+ self,
28
+ split: str,
29
+ is_train_split: bool,
30
+ data_cfg: S2TDataConfig,
31
+ audio_paths: List[str],
32
+ n_frames: List[int],
33
+ src_texts: Optional[List[str]] = None,
34
+ tgt_texts: Optional[List[str]] = None,
35
+ speakers: Optional[List[str]] = None,
36
+ src_langs: Optional[List[str]] = None,
37
+ tgt_langs: Optional[List[str]] = None,
38
+ ids: Optional[List[str]] = None,
39
+ tgt_dict: Optional[Dictionary] = None,
40
+ pre_tokenizer=None,
41
+ bpe_tokenizer=None,
42
+ n_frames_per_step=1,
43
+ speaker_to_id=None,
44
+ do_chunk=False,
45
+ chunk_bound=-1,
46
+ chunk_init=50,
47
+ chunk_incr=5,
48
+ add_eos=True,
49
+ dedup=True,
50
+ ref_fpu=-1,
51
+ ):
52
+ # It assumes texts are encoded at a fixed frame-rate
53
+ super().__init__(
54
+ split=split,
55
+ is_train_split=is_train_split,
56
+ data_cfg=data_cfg,
57
+ audio_paths=audio_paths,
58
+ n_frames=n_frames,
59
+ src_texts=src_texts,
60
+ tgt_texts=tgt_texts,
61
+ speakers=speakers,
62
+ src_langs=src_langs,
63
+ tgt_langs=tgt_langs,
64
+ ids=ids,
65
+ tgt_dict=tgt_dict,
66
+ pre_tokenizer=pre_tokenizer,
67
+ bpe_tokenizer=bpe_tokenizer,
68
+ n_frames_per_step=n_frames_per_step,
69
+ speaker_to_id=speaker_to_id,
70
+ )
71
+
72
+ self.do_chunk = do_chunk
73
+ self.chunk_bound = chunk_bound
74
+ self.chunk_init = chunk_init
75
+ self.chunk_incr = chunk_incr
76
+ self.add_eos = add_eos
77
+ self.dedup = dedup
78
+ self.ref_fpu = ref_fpu
79
+
80
+ self.chunk_size = -1
81
+
82
+ if do_chunk:
83
+ assert self.chunk_incr >= 0
84
+ assert self.pre_tokenizer is None
85
+
86
+ def __getitem__(self, index):
87
+ index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
88
+ if target[-1].item() == self.tgt_dict.eos_index:
89
+ target = target[:-1]
90
+
91
+ fpu = source.size(0) / target.size(0) # frame-per-unit
92
+ fps = self.n_frames_per_step
93
+ assert (
94
+ self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
95
+ ), f"{fpu*fps} != {self.ref_fpu}"
96
+
97
+ # only chunk training split
98
+ if self.is_train_split and self.do_chunk and self.chunk_size > 0:
99
+ lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
100
+ text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
101
+ size = len(text)
102
+ chunk_size = min(self.chunk_size, size)
103
+ chunk_start = np.random.randint(size - chunk_size + 1)
104
+ text = text[chunk_start : chunk_start + chunk_size]
105
+ target = torch.cat((lang, text), 0)
106
+
107
+ f_size = int(np.floor(chunk_size * fpu))
108
+ f_start = int(np.floor(chunk_start * fpu))
109
+ assert f_size > 0
110
+ source = source[f_start : f_start + f_size, :]
111
+
112
+ if self.dedup:
113
+ target = torch.unique_consecutive(target)
114
+
115
+ if self.add_eos:
116
+ eos_idx = self.tgt_dict.eos_index
117
+ target = torch.cat((target, torch.LongTensor([eos_idx])), 0)
118
+
119
+ return index, source, target, speaker_id
120
+
121
+ def set_epoch(self, epoch):
122
+ if self.is_train_split and self.do_chunk:
123
+ old = self.chunk_size
124
+ self.chunk_size = self.chunk_init + epoch * self.chunk_incr
125
+ if self.chunk_bound > 0:
126
+ self.chunk_size = min(self.chunk_size, self.chunk_bound)
127
+ logger.info(
128
+ (
129
+ f"{self.split}: setting chunk size "
130
+ f"from {old} to {self.chunk_size}"
131
+ )
132
+ )
133
+
134
+
135
+ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
136
+ # inherit for key names
137
+ @classmethod
138
+ def from_tsv(
139
+ cls,
140
+ root: str,
141
+ data_cfg: S2TDataConfig,
142
+ split: str,
143
+ tgt_dict,
144
+ pre_tokenizer,
145
+ bpe_tokenizer,
146
+ is_train_split: bool,
147
+ n_frames_per_step: int,
148
+ speaker_to_id,
149
+ do_chunk: bool = False,
150
+ chunk_bound: int = -1,
151
+ chunk_init: int = 50,
152
+ chunk_incr: int = 5,
153
+ add_eos: bool = True,
154
+ dedup: bool = True,
155
+ ref_fpu: float = -1,
156
+ ) -> FrmTextToSpeechDataset:
157
+ tsv_path = op.join(root, f"{split}.tsv")
158
+ if not op.isfile(tsv_path):
159
+ raise FileNotFoundError(f"Dataset not found: {tsv_path}")
160
+ with open(tsv_path) as f:
161
+ reader = csv.DictReader(
162
+ f,
163
+ delimiter="\t",
164
+ quotechar=None,
165
+ doublequote=False,
166
+ lineterminator="\n",
167
+ quoting=csv.QUOTE_NONE,
168
+ )
169
+ s = [dict(e) for e in reader]
170
+ assert len(s) > 0
171
+
172
+ ids = [ss[cls.KEY_ID] for ss in s]
173
+ audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
174
+ n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
175
+ tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
176
+ src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
177
+ speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
178
+ src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
179
+ tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]
180
+
181
+ return FrmTextToSpeechDataset(
182
+ split=split,
183
+ is_train_split=is_train_split,
184
+ data_cfg=data_cfg,
185
+ audio_paths=audio_paths,
186
+ n_frames=n_frames,
187
+ src_texts=src_texts,
188
+ tgt_texts=tgt_texts,
189
+ speakers=speakers,
190
+ src_langs=src_langs,
191
+ tgt_langs=tgt_langs,
192
+ ids=ids,
193
+ tgt_dict=tgt_dict,
194
+ pre_tokenizer=pre_tokenizer,
195
+ bpe_tokenizer=bpe_tokenizer,
196
+ n_frames_per_step=n_frames_per_step,
197
+ speaker_to_id=speaker_to_id,
198
+ do_chunk=do_chunk,
199
+ chunk_bound=chunk_bound,
200
+ chunk_init=chunk_init,
201
+ chunk_incr=chunk_incr,
202
+ add_eos=add_eos,
203
+ dedup=dedup,
204
+ ref_fpu=ref_fpu,
205
+ )
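The chunking schedule in set_epoch grows linearly with the epoch and is capped by chunk_bound; a trivial illustration of the resulting sizes:

# chunk_size = chunk_init + epoch * chunk_incr, capped at chunk_bound (when > 0)
chunk_init, chunk_incr, chunk_bound = 50, 5, 100
for epoch in (0, 5, 10, 20):
    print(epoch, min(chunk_init + epoch * chunk_incr, chunk_bound))
# 0 -> 50, 5 -> 75, 10 -> 100, 20 -> 100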
modules/voice_conversion/fairseq/data/audio/hubert_dataset.py ADDED
@@ -0,0 +1,356 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import logging
8
+ import os
9
+ import sys
10
+ from typing import Any, List, Optional, Union
11
+
12
+ import numpy as np
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from fairseq.data import data_utils
17
+ from fairseq.data.fairseq_dataset import FairseqDataset
18
+ from fairseq.data.audio.audio_utils import (
19
+ parse_path,
20
+ read_from_stored_zip,
21
+ )
22
+ import io
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def load_audio(manifest_path, max_keep, min_keep):
28
+ n_long, n_short = 0, 0
29
+ names, inds, sizes = [], [], []
30
+ with open(manifest_path) as f:
31
+ root = f.readline().strip()
32
+ for ind, line in enumerate(f):
33
+ items = line.strip().split("\t")
34
+ assert len(items) == 2, line
35
+ sz = int(items[1])
36
+ if min_keep is not None and sz < min_keep:
37
+ n_short += 1
38
+ elif max_keep is not None and sz > max_keep:
39
+ n_long += 1
40
+ else:
41
+ names.append(items[0])
42
+ inds.append(ind)
43
+ sizes.append(sz)
44
+ tot = ind + 1
45
+ logger.info(
46
+ (
47
+ f"max_keep={max_keep}, min_keep={min_keep}, "
48
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
49
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
50
+ )
51
+ )
52
+ return root, names, inds, tot, sizes
53
+
54
+
55
+ def load_label(label_path, inds, tot):
56
+ with open(label_path) as f:
57
+ labels = [line.rstrip() for line in f]
58
+ assert (
59
+ len(labels) == tot
60
+ ), f"number of labels does not match ({len(labels)} != {tot})"
61
+ labels = [labels[i] for i in inds]
62
+ return labels
63
+
64
+
65
+ def load_label_offset(label_path, inds, tot):
66
+ with open(label_path) as f:
67
+ code_lengths = [len(line.encode("utf-8")) for line in f]
68
+ assert (
69
+ len(code_lengths) == tot
70
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
71
+ offsets = list(itertools.accumulate([0] + code_lengths))
72
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
73
+ return offsets
74
+
75
+
76
+ def verify_label_lengths(
77
+ audio_sizes,
78
+ audio_rate,
79
+ label_path,
80
+ label_rate,
81
+ inds,
82
+ tot,
83
+ tol=0.1, # tolerance in seconds
84
+ ):
85
+ if label_rate < 0:
86
+ logger.info(f"{label_path} is sequence label. skipped")
87
+ return
88
+
89
+ with open(label_path) as f:
90
+ lengths = [len(line.rstrip().split()) for line in f]
91
+ assert len(lengths) == tot
92
+ lengths = [lengths[i] for i in inds]
93
+ num_invalid = 0
94
+ for i, ind in enumerate(inds):
95
+ dur_from_audio = audio_sizes[i] / audio_rate
96
+ dur_from_label = lengths[i] / label_rate
97
+ if abs(dur_from_audio - dur_from_label) > tol:
98
+ logger.warning(
99
+ (
100
+ f"audio and label duration differ too much "
101
+ f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
102
+ f"in line {ind+1} of {label_path}. Check if `label_rate` "
103
+ f"is correctly set (currently {label_rate}). "
104
+ f"num. of samples = {audio_sizes[i]}; "
105
+ f"label length = {lengths[i]}"
106
+ )
107
+ )
108
+ num_invalid += 1
109
+ if num_invalid > 0:
110
+ logger.warning(
111
+ f"total {num_invalid} (audio, label) pairs with mismatched lengths"
112
+ )
113
+
114
+
115
+ class HubertDataset(FairseqDataset):
116
+ def __init__(
117
+ self,
118
+ manifest_path: str,
119
+ sample_rate: float,
120
+ label_paths: List[str],
121
+ label_rates: Union[List[float], float], # -1 for sequence labels
122
+ pad_list: List[str],
123
+ eos_list: List[str],
124
+ label_processors: Optional[List[Any]] = None,
125
+ max_keep_sample_size: Optional[int] = None,
126
+ min_keep_sample_size: Optional[int] = None,
127
+ max_sample_size: Optional[int] = None,
128
+ shuffle: bool = True,
129
+ pad_audio: bool = False,
130
+ normalize: bool = False,
131
+ store_labels: bool = True,
132
+ random_crop: bool = False,
133
+ single_target: bool = False,
134
+ ):
135
+ self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
136
+ manifest_path, max_keep_sample_size, min_keep_sample_size
137
+ )
138
+ self.sample_rate = sample_rate
139
+ self.shuffle = shuffle
140
+ self.random_crop = random_crop
141
+
142
+ self.num_labels = len(label_paths)
143
+ self.pad_list = pad_list
144
+ self.eos_list = eos_list
145
+ self.label_processors = label_processors
146
+ self.single_target = single_target
147
+ self.label_rates = (
148
+ [label_rates for _ in range(len(label_paths))]
149
+ if isinstance(label_rates, float)
150
+ else label_rates
151
+ )
152
+ self.store_labels = store_labels
153
+ if store_labels:
154
+ self.label_list = [load_label(p, inds, tot) for p in label_paths]
155
+ else:
156
+ self.label_paths = label_paths
157
+ self.label_offsets_list = [
158
+ load_label_offset(p, inds, tot) for p in label_paths
159
+ ]
160
+ assert label_processors is None or len(label_processors) == self.num_labels
161
+ for label_path, label_rate in zip(label_paths, self.label_rates):
162
+ verify_label_lengths(
163
+ self.sizes, sample_rate, label_path, label_rate, inds, tot
164
+ )
165
+
166
+ self.max_sample_size = (
167
+ max_sample_size if max_sample_size is not None else sys.maxsize
168
+ )
169
+ self.pad_audio = pad_audio
170
+ self.normalize = normalize
171
+ logger.info(
172
+ f"pad_audio={pad_audio}, random_crop={random_crop}, "
173
+ f"normalize={normalize}, max_sample_size={self.max_sample_size}"
174
+ )
175
+
176
+ def get_audio(self, index):
177
+ import soundfile as sf
178
+
179
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
180
+ _path, slice_ptr = parse_path(wav_path)
181
+ if len(slice_ptr) == 0:
182
+ wav, cur_sample_rate = sf.read(_path)
183
+ else:
184
+ assert _path.endswith(".zip")
185
+ data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
186
+ f = io.BytesIO(data)
187
+ wav, cur_sample_rate = sf.read(f)
188
+ wav = torch.from_numpy(wav).float()
189
+ wav = self.postprocess(wav, cur_sample_rate)
190
+ return wav
191
+
192
+ def get_label(self, index, label_idx):
193
+ if self.store_labels:
194
+ label = self.label_list[label_idx][index]
195
+ else:
196
+ with open(self.label_paths[label_idx]) as f:
197
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
198
+ f.seek(offset_s)
199
+ label = f.read(offset_e - offset_s)
200
+
201
+ if self.label_processors is not None:
202
+ label = self.label_processors[label_idx](label)
203
+ return label
204
+
205
+ def get_labels(self, index):
206
+ return [self.get_label(index, i) for i in range(self.num_labels)]
207
+
208
+ def __getitem__(self, index):
209
+ wav = self.get_audio(index)
210
+ labels = self.get_labels(index)
211
+ return {"id": index, "source": wav, "label_list": labels}
212
+
213
+ def __len__(self):
214
+ return len(self.sizes)
215
+
216
+ def crop_to_max_size(self, wav, target_size):
217
+ size = len(wav)
218
+ diff = size - target_size
219
+ if diff <= 0:
220
+ return wav, 0
221
+
222
+ start, end = 0, target_size
223
+ if self.random_crop:
224
+ start = np.random.randint(0, diff + 1)
225
+ end = size - diff + start
226
+ return wav[start:end], start
227
+
228
+ def collater(self, samples):
229
+ # target = max(sizes) -> random_crop not used
230
+ # target = max_sample_size -> random_crop used for long
231
+ samples = [s for s in samples if s["source"] is not None]
232
+ if len(samples) == 0:
233
+ return {}
234
+
235
+ audios = [s["source"] for s in samples]
236
+ audio_sizes = [len(s) for s in audios]
237
+ if self.pad_audio:
238
+ audio_size = min(max(audio_sizes), self.max_sample_size)
239
+ else:
240
+ audio_size = min(min(audio_sizes), self.max_sample_size)
241
+ collated_audios, padding_mask, audio_starts = self.collater_audio(
242
+ audios, audio_size
243
+ )
244
+
245
+ targets_by_label = [
246
+ [s["label_list"][i] for s in samples] for i in range(self.num_labels)
247
+ ]
248
+ targets_list, lengths_list, ntokens_list = self.collater_label(
249
+ targets_by_label, audio_size, audio_starts
250
+ )
251
+
252
+ net_input = {"source": collated_audios, "padding_mask": padding_mask}
253
+ batch = {
254
+ "id": torch.LongTensor([s["id"] for s in samples]),
255
+ "net_input": net_input,
256
+ }
257
+
258
+ if self.single_target:
259
+ batch["target_lengths"] = lengths_list[0]
260
+ batch["ntokens"] = ntokens_list[0]
261
+ batch["target"] = targets_list[0]
262
+ else:
263
+ batch["target_lengths_list"] = lengths_list
264
+ batch["ntokens_list"] = ntokens_list
265
+ batch["target_list"] = targets_list
266
+ return batch
267
+
268
+ def collater_audio(self, audios, audio_size):
269
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
270
+ padding_mask = (
271
+ torch.BoolTensor(collated_audios.shape).fill_(False)
272
+ # if self.pad_audio else None
273
+ )
274
+ audio_starts = [0 for _ in audios]
275
+ for i, audio in enumerate(audios):
276
+ diff = len(audio) - audio_size
277
+ if diff == 0:
278
+ collated_audios[i] = audio
279
+ elif diff < 0:
280
+ assert self.pad_audio
281
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
282
+ padding_mask[i, diff:] = True
283
+ else:
284
+ collated_audios[i], audio_starts[i] = self.crop_to_max_size(
285
+ audio, audio_size
286
+ )
287
+ return collated_audios, padding_mask, audio_starts
288
+
289
+ def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
290
+ assert label_rate > 0
291
+ s2f = label_rate / self.sample_rate
292
+ frm_starts = [int(round(s * s2f)) for s in audio_starts]
293
+ frm_size = int(round(audio_size * s2f))
294
+ if not self.pad_audio:
295
+ rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
296
+ frm_size = min(frm_size, *rem_size)
297
+ targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
298
+ logger.debug(f"audio_starts={audio_starts}")
299
+ logger.debug(f"frame_starts={frm_starts}")
300
+ logger.debug(f"frame_size={frm_size}")
301
+
302
+ lengths = torch.LongTensor([len(t) for t in targets])
303
+ ntokens = lengths.sum().item()
304
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
305
+ return targets, lengths, ntokens
306
+
307
+ def collater_seq_label(self, targets, pad):
308
+ lengths = torch.LongTensor([len(t) for t in targets])
309
+ ntokens = lengths.sum().item()
310
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
311
+ return targets, lengths, ntokens
312
+
313
+ def collater_label(self, targets_by_label, audio_size, audio_starts):
314
+ targets_list, lengths_list, ntokens_list = [], [], []
315
+ itr = zip(targets_by_label, self.label_rates, self.pad_list)
316
+ for targets, label_rate, pad in itr:
317
+ if label_rate == -1.0:
318
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
319
+ else:
320
+ targets, lengths, ntokens = self.collater_frm_label(
321
+ targets, audio_size, audio_starts, label_rate, pad
322
+ )
323
+ targets_list.append(targets)
324
+ lengths_list.append(lengths)
325
+ ntokens_list.append(ntokens)
326
+ return targets_list, lengths_list, ntokens_list
327
+
328
+ def num_tokens(self, index):
329
+ return self.size(index)
330
+
331
+ def size(self, index):
332
+ if self.pad_audio:
333
+ return self.sizes[index]
334
+ return min(self.sizes[index], self.max_sample_size)
335
+
336
+ def ordered_indices(self):
337
+ if self.shuffle:
338
+ order = [np.random.permutation(len(self))]
339
+ else:
340
+ order = [np.arange(len(self))]
341
+
342
+ order.append(self.sizes)
343
+ return np.lexsort(order)[::-1]
344
+
345
+ def postprocess(self, wav, cur_sample_rate):
346
+ if wav.dim() == 2:
347
+ wav = wav.mean(-1)
348
+ assert wav.dim() == 1, wav.dim()
349
+
350
+ if cur_sample_rate != self.sample_rate:
351
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
352
+
353
+ if self.normalize:
354
+ with torch.no_grad():
355
+ wav = F.layer_norm(wav, wav.shape)
356
+ return wav
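Frame-level labels are aligned to the (possibly cropped) audio through the s2f ratio in collater_frm_label; a small numeric illustration with 50 Hz labels over 16 kHz audio:

sample_rate, label_rate = 16000, 50.0
s2f = label_rate / sample_rate             # 0.003125 label frames per audio sample

audio_start, audio_size = 8000, 80000      # crop starts 0.5 s in and lasts 5 s
frm_start = int(round(audio_start * s2f))  # 25
frm_size = int(round(audio_size * s2f))    # 250
print(frm_start, frm_size)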
modules/voice_conversion/fairseq/data/audio/multi_modality_dataset.py ADDED
@@ -0,0 +1,284 @@
1
+ # Copyright (c) 2021-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the LICENSE file in
5
+ # the root directory of this source tree. An additional grant of patent rights
6
+ # can be found in the PATENTS file in the same directory.
7
+
8
+ import logging
9
+ import math
10
+ from typing import List, Optional, NamedTuple
11
+
12
+ import numpy as np
13
+ from fairseq.data.resampling_dataset import ResamplingDataset
14
+ import torch
15
+ from fairseq.data import (
16
+ ConcatDataset,
17
+ LanguagePairDataset,
18
+ FileAudioDataset,
19
+ data_utils,
20
+ )
21
+ from fairseq.data import FairseqDataset
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class ModalityDatasetItem(NamedTuple):
27
+ datasetname: str
28
+ dataset: any
29
+ max_positions: List[int]
30
+ max_tokens: Optional[int] = None
31
+ max_sentences: Optional[int] = None
32
+
33
+
34
+ def resampling_dataset_present(ds):
35
+ if isinstance(ds, ResamplingDataset):
36
+ return True
37
+ if isinstance(ds, ConcatDataset):
38
+ return any(resampling_dataset_present(d) for d in ds.datasets)
39
+ if hasattr(ds, "dataset"):
40
+ return resampling_dataset_present(ds.dataset)
41
+ return False
42
+
43
+
44
+ # MultiModalityDataset concatenates multiple datasets with different modalities.
45
+ # Compared with ConcatDataset, it can 1) sample data given per-dataset ratios and
46
+ # 2) add a "mode" field indicating which type of dataset each sample comes from.
47
+ # It is used together with GroupedEpochBatchIterator to generate mini-batches whose
48
+ # samples all come from the same type of dataset.
49
+ # If only one dataset is used, it behaves like that dataset, with the mode added.
50
+ class MultiModalityDataset(ConcatDataset):
51
+ def __init__(self, datasets: List[ModalityDatasetItem]):
52
+ id_to_mode = []
53
+ dsets = []
54
+ max_tokens = []
55
+ max_sentences = []
56
+ max_positions = []
57
+ for dset in datasets:
58
+ id_to_mode.append(dset.datasetname)
59
+ dsets.append(dset.dataset)
60
+ max_tokens.append(dset.max_tokens)
61
+ max_positions.append(dset.max_positions)
62
+ max_sentences.append(dset.max_sentences)
63
+ weights = [1.0 for s in dsets]
64
+ super().__init__(dsets, weights)
65
+ self.max_tokens = max_tokens
66
+ self.max_positions = max_positions
67
+ self.max_sentences = max_sentences
68
+ self.id_to_mode = id_to_mode
69
+ self.raw_sub_batch_samplers = []
70
+ self._cur_epoch = 0
71
+
72
+ def set_epoch(self, epoch):
73
+ super().set_epoch(epoch)
74
+ self._cur_epoch = epoch
75
+
76
+ def __getitem__(self, idx):
77
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
78
+ sample = self.datasets[dataset_idx][sample_idx]
79
+ return (dataset_idx, sample)
80
+
81
+ def collater(self, samples):
82
+ if len(samples) == 0:
83
+ return {}
84
+ dataset_idx = samples[0][0]
85
+ # make sure all samples in samples are from same dataset
86
+ assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
87
+ samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
88
+ # add mode
89
+ samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]
90
+
91
+ return samples
92
+
93
+ def size(self, index: int):
94
+ if len(self.datasets) == 1:
95
+ return self.datasets[0].size(index)
96
+ return super().size(index)
97
+
98
+ @property
99
+ def sizes(self):
100
+ if len(self.datasets) == 1:
101
+ return self.datasets[0].sizes
102
+ return super().sizes
103
+
104
+ def ordered_indices(self):
105
+ """
106
+ Returns indices sorted by length. So less padding is needed.
107
+ """
108
+ if len(self.datasets) == 1:
109
+ return self.datasets[0].ordered_indices()
110
+ indices_group = []
111
+ for d_idx, ds in enumerate(self.datasets):
112
+ sample_num = self.cumulative_sizes[d_idx]
113
+ if d_idx > 0:
114
+ sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
115
+ assert sample_num == len(ds)
116
+ indices_group.append(ds.ordered_indices())
117
+ return indices_group
118
+
119
+ def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
120
+ with data_utils.numpy_seed(seed):
121
+ indices = self.ordered_indices()
122
+ for i, ds in enumerate(self.datasets):
123
+ # If we have ResamplingDataset, the same id can correspond to a different
124
+ # sample in the next epoch, so we need to rebuild this at every epoch
125
+ if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
126
+ ds
127
+ ):
128
+ logger.info(f"dataset {i} is valid and it is not re-sampled")
129
+ continue
130
+ indices[i] = ds.filter_indices_by_size(
131
+ indices[i],
132
+ self.max_positions[i],
133
+ )[0]
134
+ sub_batch_sampler = ds.batch_by_size(
135
+ indices[i],
136
+ max_tokens=self.max_tokens[i],
137
+ max_sentences=self.max_sentences[i],
138
+ required_batch_size_multiple=required_batch_size_multiple,
139
+ )
140
+ if i < len(self.raw_sub_batch_samplers):
141
+ self.raw_sub_batch_samplers[i] = sub_batch_sampler
142
+ else:
143
+ self.raw_sub_batch_samplers.append(sub_batch_sampler)
144
+
145
+ def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
146
+ self.get_raw_batch_samplers(required_batch_size_multiple, seed)
147
+ batch_samplers = []
148
+ for i, _ in enumerate(self.datasets):
149
+ if i > 0:
150
+ sub_batch_sampler = [
151
+ [y + self.cumulative_sizes[i - 1] for y in x]
152
+ for x in self.raw_sub_batch_samplers[i]
153
+ ]
154
+ else:
155
+ sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
156
+ smp_r = mult_ratios[i]
157
+ if smp_r != 1:
158
+ is_increase = "increased" if smp_r > 1 else "decreased"
159
+ logger.info(
160
+ "number of batch for the dataset {} is {} from {} to {}".format(
161
+ self.id_to_mode[i],
162
+ is_increase,
163
+ len(sub_batch_sampler),
164
+ int(len(sub_batch_sampler) * smp_r),
165
+ )
166
+ )
167
+ mul_samplers = []
168
+ for _ in range(math.floor(smp_r)):
169
+ mul_samplers = mul_samplers + sub_batch_sampler
170
+ if math.floor(smp_r) != smp_r:
171
+ with data_utils.numpy_seed(seed + self._cur_epoch):
172
+ np.random.shuffle(sub_batch_sampler)
173
+ smp_num = int(
174
+ (smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
175
+ )
176
+ mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
177
+ sub_batch_sampler = mul_samplers
178
+ else:
179
+ logger.info(
180
+ "dataset {} batch number is {} ".format(
181
+ self.id_to_mode[i], len(sub_batch_sampler)
182
+ )
183
+ )
184
+ batch_samplers.append(sub_batch_sampler)
185
+
186
+ return batch_samplers
187
+
188
+
189
+ class LangPairMaskDataset(FairseqDataset):
190
+ def __init__(
191
+ self,
192
+ dataset: LanguagePairDataset,
193
+ src_eos: int,
194
+ src_bos: Optional[int] = None,
195
+ noise_id: Optional[int] = -1,
196
+ mask_ratio: Optional[float] = 0,
197
+ mask_type: Optional[str] = "random",
198
+ ):
199
+ self.dataset = dataset
200
+ self.src_eos = src_eos
201
+ self.src_bos = src_bos
202
+ self.noise_id = noise_id
203
+ self.mask_ratio = mask_ratio
204
+ self.mask_type = mask_type
205
+ assert mask_type in ("random", "tail")
206
+
207
+ @property
208
+ def src_sizes(self):
209
+ return self.dataset.src_sizes
210
+
211
+ @property
212
+ def tgt_sizes(self):
213
+ return self.dataset.tgt_sizes
214
+
215
+ @property
216
+ def sizes(self):
217
+ # dataset.sizes can be a dynamically computed sizes:
218
+ return self.dataset.sizes
219
+
220
+ def get_batch_shapes(self):
221
+ if hasattr(self.dataset, "get_batch_shapes"):
222
+ return self.dataset.get_batch_shapes()
223
+ return self.dataset.buckets
224
+
225
+ def num_tokens_vec(self, indices):
226
+ return self.dataset.num_tokens_vec(indices)
227
+
228
+ def __len__(self):
229
+ return len(self.dataset)
230
+
231
+ def num_tokens(self, index):
232
+ return self.dataset.num_tokens(index)
233
+
234
+ def size(self, index):
235
+ return self.dataset.size(index)
236
+
237
+ def ordered_indices(self):
238
+ return self.dataset.ordered_indices()
239
+
240
+ @property
241
+ def supports_prefetch(self):
242
+ return getattr(self.dataset, "supports_prefetch", False)
243
+
244
+ def prefetch(self, indices):
245
+ return self.dataset.prefetch(indices)
246
+
247
+ def mask_src_tokens(self, sample):
248
+ src_item = sample["source"]
249
+ mask = None
250
+ if self.mask_type == "random":
251
+ mask = torch.rand(len(src_item)).le(self.mask_ratio)
252
+ else:
253
+ mask = torch.ones(len(src_item))
254
+ mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
255
+ mask = mask.eq(1)
256
+ if src_item[0] == self.src_bos:
257
+ mask[0] = False
258
+ if src_item[-1] == self.src_eos:
259
+ mask[-1] = False
260
+ mask_src_item = src_item.masked_fill(mask, self.noise_id)
261
+ smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
262
+ return smp
263
+
264
+ def __getitem__(self, index):
265
+ sample = self.dataset[index]
266
+ if self.mask_ratio > 0:
267
+ sample = self.mask_src_tokens(sample)
268
+ return sample
269
+
270
+ def collater(self, samples, pad_to_length=None):
271
+ return self.dataset.collater(samples, pad_to_length)
272
+
273
+
274
+ class FileAudioDatasetWrapper(FileAudioDataset):
275
+ def collater(self, samples):
276
+ samples = super().collater(samples)
277
+ if len(samples) == 0:
278
+ return {}
279
+ samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
280
+ samples["net_input"]["prev_output_tokens"] = None
281
+ del samples["net_input"]["source"]
282
+ samples["net_input"]["src_lengths"] = None
283
+ samples["net_input"]["alignment"] = None
284
+ return samples
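The mult_ratios argument of get_batch_samplers up- or down-samples each modality by repeating (and, for the fractional part, randomly subsampling) its list of batches; a stand-in calculation showing the effect of a 1.5 ratio:

import math

sub_batch_sampler = list(range(100))  # stand-in for 100 pre-built batches of one modality
smp_r = 1.5

full_copies = math.floor(smp_r)                                    # 1 full pass
extra = int((smp_r - math.floor(smp_r)) * len(sub_batch_sampler))  # + 50 random batches
print(full_copies * len(sub_batch_sampler) + extra)                # 150 batches per epoch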
modules/voice_conversion/fairseq/data/audio/raw_audio_dataset.py ADDED
@@ -0,0 +1,393 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import logging
8
+ import os
9
+ import sys
10
+ import io
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from .. import FairseqDataset
17
+ from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
18
+ from fairseq.data.audio.audio_utils import (
19
+ parse_path,
20
+ read_from_stored_zip,
21
+ is_sf_audio_data,
22
+ )
23
+ from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ class RawAudioDataset(FairseqDataset):
30
+ def __init__(
31
+ self,
32
+ sample_rate,
33
+ max_sample_size=None,
34
+ min_sample_size=0,
35
+ shuffle=True,
36
+ pad=False,
37
+ normalize=False,
38
+ compute_mask_indices=False,
39
+ **mask_compute_kwargs,
40
+ ):
41
+ super().__init__()
42
+
43
+ self.sample_rate = sample_rate
44
+ self.sizes = []
45
+ self.max_sample_size = (
46
+ max_sample_size if max_sample_size is not None else sys.maxsize
47
+ )
48
+ self.min_sample_size = min_sample_size
49
+ self.pad = pad
50
+ self.shuffle = shuffle
51
+ self.normalize = normalize
52
+ self.compute_mask_indices = compute_mask_indices
53
+ if self.compute_mask_indices:
54
+ self.mask_compute_kwargs = mask_compute_kwargs
55
+ self._features_size_map = {}
56
+ self._C = mask_compute_kwargs["encoder_embed_dim"]
57
+ self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])
58
+
59
+ def __getitem__(self, index):
60
+ raise NotImplementedError()
61
+
62
+ def __len__(self):
63
+ return len(self.sizes)
64
+
65
+ def postprocess(self, feats, curr_sample_rate):
66
+ if feats.dim() == 2:
67
+ feats = feats.mean(-1)
68
+
69
+ if curr_sample_rate != self.sample_rate:
70
+ raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")
71
+
72
+ assert feats.dim() == 1, feats.dim()
73
+
74
+ if self.normalize:
75
+ with torch.no_grad():
76
+ feats = F.layer_norm(feats, feats.shape)
77
+ return feats
78
+
79
+ def crop_to_max_size(self, wav, target_size):
80
+ size = len(wav)
81
+ diff = size - target_size
82
+ if diff <= 0:
83
+ return wav
84
+
85
+ start = np.random.randint(0, diff + 1)
86
+ end = size - diff + start
87
+ return wav[start:end]
88
+
89
+ def _compute_mask_indices(self, dims, padding_mask):
90
+ B, T, C = dims
91
+ mask_indices, mask_channel_indices = None, None
92
+ if self.mask_compute_kwargs["mask_prob"] > 0:
93
+ mask_indices = compute_mask_indices(
94
+ (B, T),
95
+ padding_mask,
96
+ self.mask_compute_kwargs["mask_prob"],
97
+ self.mask_compute_kwargs["mask_length"],
98
+ self.mask_compute_kwargs["mask_selection"],
99
+ self.mask_compute_kwargs["mask_other"],
100
+ min_masks=2,
101
+ no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
102
+ min_space=self.mask_compute_kwargs["mask_min_space"],
103
+ )
104
+ mask_indices = torch.from_numpy(mask_indices)
105
+ if self.mask_compute_kwargs["mask_channel_prob"] > 0:
106
+ mask_channel_indices = compute_mask_indices(
107
+ (B, C),
108
+ None,
109
+ self.mask_compute_kwargs["mask_channel_prob"],
110
+ self.mask_compute_kwargs["mask_channel_length"],
111
+ self.mask_compute_kwargs["mask_channel_selection"],
112
+ self.mask_compute_kwargs["mask_channel_other"],
113
+ no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
114
+ min_space=self.mask_compute_kwargs["mask_channel_min_space"],
115
+ )
116
+ mask_channel_indices = (
117
+ torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
118
+ )
119
+
120
+ return mask_indices, mask_channel_indices
121
+
122
+ @staticmethod
123
+ def _bucket_tensor(tensor, num_pad, value):
124
+ return F.pad(tensor, (0, num_pad), value=value)
125
+
126
+ def collater(self, samples):
127
+ samples = [s for s in samples if s["source"] is not None]
128
+ if len(samples) == 0:
129
+ return {}
130
+
131
+ sources = [s["source"] for s in samples]
132
+ sizes = [len(s) for s in sources]
133
+
134
+ if self.pad:
135
+ target_size = min(max(sizes), self.max_sample_size)
136
+ else:
137
+ target_size = min(min(sizes), self.max_sample_size)
138
+
139
+ collated_sources = sources[0].new_zeros(len(sources), target_size)
140
+ padding_mask = (
141
+ torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
142
+ )
143
+ for i, (source, size) in enumerate(zip(sources, sizes)):
144
+ diff = size - target_size
145
+ if diff == 0:
146
+ collated_sources[i] = source
147
+ elif diff < 0:
148
+ assert self.pad
149
+ collated_sources[i] = torch.cat(
150
+ [source, source.new_full((-diff,), 0.0)]
151
+ )
152
+ padding_mask[i, diff:] = True
153
+ else:
154
+ collated_sources[i] = self.crop_to_max_size(source, target_size)
155
+
156
+ input = {"source": collated_sources}
157
+ out = {"id": torch.LongTensor([s["id"] for s in samples])}
158
+ if self.pad:
159
+ input["padding_mask"] = padding_mask
160
+
161
+ if hasattr(self, "num_buckets") and self.num_buckets > 0:
162
+ assert self.pad, "Cannot bucket without padding first."
163
+ bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
164
+ num_pad = bucket - collated_sources.size(-1)
165
+ if num_pad:
166
+ input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
167
+ input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)
168
+
169
+ if self.compute_mask_indices:
170
+ B = input["source"].size(0)
171
+ T = self._get_mask_indices_dims(input["source"].size(-1))
172
+ padding_mask_reshaped = input["padding_mask"].clone()
173
+ extra = padding_mask_reshaped.size(1) % T
174
+ if extra > 0:
175
+ padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
176
+ padding_mask_reshaped = padding_mask_reshaped.view(
177
+ padding_mask_reshaped.size(0), T, -1
178
+ )
179
+ padding_mask_reshaped = padding_mask_reshaped.all(-1)
180
+ input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
181
+ mask_indices, mask_channel_indices = self._compute_mask_indices(
182
+ (B, T, self._C),
183
+ padding_mask_reshaped,
184
+ )
185
+ input["mask_indices"] = mask_indices
186
+ input["mask_channel_indices"] = mask_channel_indices
187
+ out["sample_size"] = mask_indices.sum().item()
188
+
189
+ out["net_input"] = input
190
+ return out
191
+
192
+ def _get_mask_indices_dims(self, size, padding=0, dilation=1):
193
+ if size not in self._features_size_map:
194
+ L_in = size
195
+ for (_, kernel_size, stride) in self._conv_feature_layers:
196
+ L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
197
+ L_out = 1 + L_out // stride
198
+ L_in = L_out
199
+ self._features_size_map[size] = L_out
200
+ return self._features_size_map[size]
201
+
202
+ def num_tokens(self, index):
203
+ return self.size(index)
204
+
205
+ def size(self, index):
206
+ """Return an example's size as a float or tuple. This value is used when
207
+ filtering a dataset with ``--max-positions``."""
208
+ if self.pad:
209
+ return self.sizes[index]
210
+ return min(self.sizes[index], self.max_sample_size)
211
+
212
+ def ordered_indices(self):
213
+ """Return an ordered list of indices. Batches will be constructed based
214
+ on this order."""
215
+
216
+ if self.shuffle:
217
+ order = [np.random.permutation(len(self))]
218
+ order.append(
219
+ np.minimum(
220
+ np.array(self.sizes),
221
+ self.max_sample_size,
222
+ )
223
+ )
224
+ return np.lexsort(order)[::-1]
225
+ else:
226
+ return np.arange(len(self))
227
+
228
+ def set_bucket_info(self, num_buckets):
229
+ self.num_buckets = num_buckets
230
+ if self.num_buckets > 0:
231
+ self._collated_sizes = np.minimum(
232
+ np.array(self.sizes),
233
+ self.max_sample_size,
234
+ )
235
+ self.buckets = get_buckets(
236
+ self._collated_sizes,
237
+ self.num_buckets,
238
+ )
239
+ self._bucketed_sizes = get_bucketed_sizes(
240
+ self._collated_sizes, self.buckets
241
+ )
242
+ logger.info(
243
+ f"{len(self.buckets)} bucket(s) for the audio dataset: "
244
+ f"{self.buckets}"
245
+ )
246
+
247
+
248
+ class FileAudioDataset(RawAudioDataset):
249
+ def __init__(
250
+ self,
251
+ manifest_path,
252
+ sample_rate,
253
+ max_sample_size=None,
254
+ min_sample_size=0,
255
+ shuffle=True,
256
+ pad=False,
257
+ normalize=False,
258
+ num_buckets=0,
259
+ compute_mask_indices=False,
260
+ text_compression_level=TextCompressionLevel.none,
261
+ **mask_compute_kwargs,
262
+ ):
263
+ super().__init__(
264
+ sample_rate=sample_rate,
265
+ max_sample_size=max_sample_size,
266
+ min_sample_size=min_sample_size,
267
+ shuffle=shuffle,
268
+ pad=pad,
269
+ normalize=normalize,
270
+ compute_mask_indices=compute_mask_indices,
271
+ **mask_compute_kwargs,
272
+ )
273
+
274
+ self.text_compressor = TextCompressor(level=text_compression_level)
275
+
276
+ skipped = 0
277
+ self.fnames = []
278
+ sizes = []
279
+ self.skipped_indices = set()
280
+
281
+ with open(manifest_path, "r") as f:
282
+ self.root_dir = f.readline().strip()
283
+ for i, line in enumerate(f):
284
+ items = line.strip().split("\t")
285
+ assert len(items) == 2, line
286
+ sz = int(items[1])
287
+ if min_sample_size is not None and sz < min_sample_size:
288
+ skipped += 1
289
+ self.skipped_indices.add(i)
290
+ continue
291
+ self.fnames.append(self.text_compressor.compress(items[0]))
292
+ sizes.append(sz)
293
+ logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
294
+
295
+ self.sizes = np.array(sizes, dtype=np.int64)
296
+
297
+ try:
298
+ import pyarrow
299
+
300
+ self.fnames = pyarrow.array(self.fnames)
301
+ except:
302
+ logger.debug(
303
+ "Could not create a pyarrow array. Please install pyarrow for better performance"
304
+ )
305
+ pass
306
+
307
+ self.set_bucket_info(num_buckets)
308
+
309
+ def __getitem__(self, index):
310
+ import soundfile as sf
311
+
312
+ fn = self.fnames[index]
313
+ fn = fn if isinstance(self.fnames, list) else fn.as_py()
314
+ fn = self.text_compressor.decompress(fn)
315
+ path_or_fp = os.path.join(self.root_dir, fn)
316
+ _path, slice_ptr = parse_path(path_or_fp)
317
+ if len(slice_ptr) == 2:
318
+ byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
319
+ assert is_sf_audio_data(byte_data)
320
+ path_or_fp = io.BytesIO(byte_data)
321
+
322
+ wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
323
+
324
+ feats = torch.from_numpy(wav).float()
325
+ feats = self.postprocess(feats, curr_sample_rate)
326
+ return {"id": index, "source": feats}
327
+
328
+
329
+ class BinarizedAudioDataset(RawAudioDataset):
330
+ def __init__(
331
+ self,
332
+ data_dir,
333
+ split,
334
+ sample_rate,
335
+ max_sample_size=None,
336
+ min_sample_size=0,
337
+ shuffle=True,
338
+ pad=False,
339
+ normalize=False,
340
+ num_buckets=0,
341
+ compute_mask_indices=False,
342
+ **mask_compute_kwargs,
343
+ ):
344
+ super().__init__(
345
+ sample_rate=sample_rate,
346
+ max_sample_size=max_sample_size,
347
+ min_sample_size=min_sample_size,
348
+ shuffle=shuffle,
349
+ pad=pad,
350
+ normalize=normalize,
351
+ compute_mask_indices=compute_mask_indices,
352
+ **mask_compute_kwargs,
353
+ )
354
+
355
+ from fairseq.data import data_utils, Dictionary
356
+
357
+ self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))
358
+
359
+ root_path = os.path.join(data_dir, f"{split}.root")
360
+ if os.path.exists(root_path):
361
+ with open(root_path, "r") as f:
362
+ self.root_dir = next(f).strip()
363
+ else:
364
+ self.root_dir = None
365
+
366
+ fnames_path = os.path.join(data_dir, split)
367
+ self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)
368
+ lengths_path = os.path.join(data_dir, f"{split}.lengths")
369
+
370
+ with open(lengths_path, "r") as f:
371
+ for line in f:
372
+ sz = int(line.rstrip())
373
+ assert (
374
+ sz >= min_sample_size
375
+ ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
376
+ self.sizes.append(sz)
377
+
378
+ self.sizes = np.array(self.sizes, dtype=np.int64)
379
+
380
+ self.set_bucket_info(num_buckets)
381
+ logger.info(f"loaded {len(self.fnames)} samples")
382
+
383
+ def __getitem__(self, index):
384
+ import soundfile as sf
385
+
386
+ fname = self.fnames_dict.string(self.fnames[index], separator="")
387
+ if self.root_dir:
388
+ fname = os.path.join(self.root_dir, fname)
389
+
390
+ wav, curr_sample_rate = sf.read(fname)
391
+ feats = torch.from_numpy(wav).float()
392
+ feats = self.postprocess(feats, curr_sample_rate)
393
+ return {"id": index, "source": feats}
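RawAudioDataset.collater above has two regimes: with pad=True every waveform in the batch is zero-padded up to the longest clip (capped at max_sample_size) and a boolean padding mask records the padded tail; with pad=False every waveform is randomly cropped down to the shortest clip and no mask is produced. A shape-only sketch of that behaviour, with no manifest or file IO (the example clip lengths are made up):

# Minimal sketch of RawAudioDataset.collater's crop/pad behaviour for raw waveforms.
import sys
import numpy as np
import torch

def collate_waveforms(sources, pad=True, max_sample_size=sys.maxsize):
    sizes = [len(s) for s in sources]
    # pad=True: pad up to the longest clip (capped); pad=False: crop down to the shortest
    target_size = min(max(sizes), max_sample_size) if pad else min(min(sizes), max_sample_size)
    collated = sources[0].new_zeros(len(sources), target_size)
    padding_mask = torch.zeros(len(sources), target_size, dtype=torch.bool) if pad else None
    for i, (source, size) in enumerate(zip(sources, sizes)):
        diff = size - target_size
        if diff == 0:
            collated[i] = source
        elif diff < 0:  # shorter than target: right-pad with zeros and mark the padding
            collated[i] = torch.cat([source, source.new_full((-diff,), 0.0)])
            padding_mask[i, diff:] = True
        else:  # longer than target: take a random crop, as crop_to_max_size does
            start = np.random.randint(0, diff + 1)
            collated[i] = source[start:start + target_size]
    return collated, padding_mask

waves = [torch.randn(16000), torch.randn(12000), torch.randn(20000)]
batch, mask = collate_waveforms(waves, pad=True)
print(batch.shape, mask.sum(dim=1))  # torch.Size([3, 20000]) plus per-clip padding counts

When bucketing is enabled (set_bucket_info), the collated batch is then padded further to the nearest bucket width so that only a small set of batch shapes occurs.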
modules/voice_conversion/fairseq/data/audio/speech_to_speech_dataset.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional, Tuple
10
+
11
+ import torch
12
+
13
+ from fairseq.data import ConcatDataset, Dictionary
14
+ from fairseq.data import data_utils as fairseq_data_utils
15
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
16
+ from fairseq.data.audio.data_cfg import S2SDataConfig
17
+ from fairseq.data.audio.speech_to_text_dataset import (
18
+ SpeechToTextDataset,
19
+ SpeechToTextDatasetCreator,
20
+ TextTargetMultitaskData,
21
+ _collate_frames,
22
+ )
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ @dataclass
28
+ class SpeechToSpeechDatasetItem(object):
29
+ index: int
30
+ source: torch.Tensor
31
+ target: Optional[torch.Tensor] = None
32
+ target_speaker: Optional[torch.Tensor] = None
33
+ tgt_lang_tag: Optional[int] = None
34
+
35
+
36
+ class SpeechToSpeechDataset(SpeechToTextDataset):
37
+ def __init__(
38
+ self,
39
+ split: str,
40
+ is_train_split: bool,
41
+ data_cfg: S2SDataConfig,
42
+ src_audio_paths: List[str],
43
+ src_n_frames: List[int],
44
+ tgt_audio_paths: List[str],
45
+ tgt_n_frames: List[int],
46
+ src_langs: Optional[List[str]] = None,
47
+ tgt_langs: Optional[List[str]] = None,
48
+ ids: Optional[List[str]] = None,
49
+ target_is_code: bool = False,
50
+ tgt_dict: Dictionary = None,
51
+ n_frames_per_step: int = 1,
52
+ ):
53
+ tgt_texts = tgt_audio_paths if target_is_code else None
54
+ super().__init__(
55
+ split=split,
56
+ is_train_split=is_train_split,
57
+ cfg=data_cfg,
58
+ audio_paths=src_audio_paths,
59
+ n_frames=src_n_frames,
60
+ ids=ids,
61
+ tgt_dict=tgt_dict,
62
+ tgt_texts=tgt_texts,
63
+ src_langs=src_langs,
64
+ tgt_langs=tgt_langs,
65
+ n_frames_per_step=n_frames_per_step,
66
+ )
67
+
68
+ self.tgt_audio_paths = tgt_audio_paths
69
+ self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames]
70
+
71
+ assert not target_is_code or tgt_dict is not None
72
+ self.target_is_code = target_is_code
73
+
74
+ assert len(tgt_audio_paths) == self.n_samples
75
+ assert len(tgt_n_frames) == self.n_samples
76
+
77
+ self.tgt_speakers = None
78
+ if self.cfg.target_speaker_embed:
79
+ samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
80
+ self.cfg.target_speaker_embed, split
81
+ )
82
+ spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples}
83
+ self.tgt_speakers = [spk_emb_dict[id] for id in self.ids]
84
+ assert len(self.tgt_speakers) == self.n_samples
85
+
86
+ logger.info(self.__repr__())
87
+
88
+ def pack_units(self, input: torch.Tensor) -> torch.Tensor:
89
+ if self.n_frames_per_step <= 1:
90
+ return input
91
+
92
+ offset = 4
93
+ vocab_size = (
94
+ len(self.tgt_dict) - offset
95
+ ) # remove offset from <bos>, <pad>, <eos>, <unk>, which is specific to fairseq dictionary
96
+
97
+ assert input.dim() == 1
98
+ stacked_input = (
99
+ input[:-1].view(-1, self.n_frames_per_step) - offset
100
+ ) # remove <eos>
101
+ scale = [
102
+ pow(vocab_size, self.n_frames_per_step - 1 - i)
103
+ for i in range(self.n_frames_per_step)
104
+ ]
105
+ scale = torch.LongTensor(scale).squeeze(0)
106
+ res = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(input[-1])
107
+ res[:-1] = (stacked_input * scale).sum(dim=1) + offset
108
+
109
+ return res
110
+
111
+ def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
112
+ source = self._get_source_audio(index)
113
+
114
+ tgt_lang_tag = None
115
+ if self.cfg.prepend_tgt_lang_tag_as_bos:
116
+ # prepend_tgt_lang_tag_as_bos: put tgt_lang_tag as bos of target
117
+ tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
118
+
119
+ if not self.target_is_code:
120
+ target = get_features_or_waveform(self.tgt_audio_paths[index])
121
+ target = torch.from_numpy(target).float()
122
+ target = self.pack_frames(target)
123
+ else:
124
+ target = self.tgt_dict.encode_line(
125
+ self.tgt_audio_paths[index],
126
+ add_if_not_exist=False,
127
+ append_eos=True,
128
+ ).long()
129
+ if self.n_frames_per_step > 1:
130
+ n_tgt_frame = target.size(0) - 1 # exclude <eos>
131
+ keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
132
+ target = torch.cat(
133
+ (
134
+ target[:keep_n_tgt_frame],
135
+ target.new_full((1,), self.tgt_dict.eos()),
136
+ ),
137
+ dim=0,
138
+ )
139
+
140
+ if self.tgt_speakers:
141
+ tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
142
+ tgt_spk = torch.from_numpy(tgt_spk).float()
143
+ else:
144
+ tgt_spk = torch.FloatTensor([])
145
+
146
+ return SpeechToSpeechDatasetItem(
147
+ index=index,
148
+ source=source,
149
+ target=target,
150
+ target_speaker=tgt_spk,
151
+ tgt_lang_tag=tgt_lang_tag,
152
+ )
153
+
154
+ def _collate_target(self, samples: List[SpeechToSpeechDatasetItem]) -> torch.Tensor:
155
+ if self.target_is_code:
156
+ target = fairseq_data_utils.collate_tokens(
157
+ [x.target for x in samples],
158
+ self.tgt_dict.pad(),
159
+ self.tgt_dict.eos(),
160
+ left_pad=False,
161
+ move_eos_to_beginning=False,
162
+ )
163
+ # convert stacked units to a single id
164
+ pack_targets = [self.pack_units(x.target) for x in samples]
165
+ prev_output_tokens = fairseq_data_utils.collate_tokens(
166
+ pack_targets,
167
+ self.tgt_dict.pad(),
168
+ self.tgt_dict.eos(),
169
+ left_pad=False,
170
+ move_eos_to_beginning=True,
171
+ )
172
+ target_lengths = torch.tensor(
173
+ [x.size(0) for x in pack_targets], dtype=torch.long
174
+ )
175
+ else:
176
+ target = _collate_frames([x.target for x in samples], is_audio_input=False)
177
+ bsz, _, d = target.size()
178
+ prev_output_tokens = torch.cat(
179
+ (target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1
180
+ )
181
+ target_lengths = torch.tensor(
182
+ [x.target.size(0) for x in samples], dtype=torch.long
183
+ )
184
+
185
+ return target, prev_output_tokens, target_lengths
186
+
187
+ def collater(
188
+ self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
189
+ ) -> Dict:
190
+ if len(samples) == 0:
191
+ return {}
192
+ indices = torch.tensor([x.index for x in samples], dtype=torch.long)
193
+ frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
194
+ # sort samples by descending number of frames
195
+ n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
196
+ n_frames, order = n_frames.sort(descending=True)
197
+ indices = indices.index_select(0, order)
198
+ frames = frames.index_select(0, order)
199
+
200
+ target, prev_output_tokens, target_lengths = self._collate_target(samples)
201
+ target = target.index_select(0, order)
202
+ target_lengths = target_lengths.index_select(0, order)
203
+ prev_output_tokens = prev_output_tokens.index_select(0, order)
204
+ ntokens = sum(x.target.size(0) for x in samples)
205
+
206
+ tgt_speakers = None
207
+ if self.cfg.target_speaker_embed:
208
+ tgt_speakers = _collate_frames(
209
+ [x.target_speaker for x in samples], is_audio_input=True
210
+ ).index_select(0, order)
211
+
212
+ net_input = {
213
+ "src_tokens": frames,
214
+ "src_lengths": n_frames,
215
+ "prev_output_tokens": prev_output_tokens,
216
+ "tgt_speaker": tgt_speakers, # TODO: unify "speaker" and "tgt_speaker"
217
+ }
218
+ if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
219
+ for i in range(len(samples)):
220
+ net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
221
+ out = {
222
+ "id": indices,
223
+ "net_input": net_input,
224
+ "speaker": tgt_speakers, # to support Tacotron2 loss for speech-to-spectrogram model
225
+ "target": target,
226
+ "target_lengths": target_lengths,
227
+ "ntokens": ntokens,
228
+ "nsentences": len(samples),
229
+ }
230
+ if return_order:
231
+ out["order"] = order
232
+ return out
233
+
234
+
235
+ class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
236
+ def __init__(self, **kwargs):
237
+ super().__init__(**kwargs)
238
+ self.multitask_data = {}
239
+
240
+ def add_multitask_dataset(self, task_name, task_data):
241
+ self.multitask_data[task_name] = task_data
242
+
243
+ def __getitem__(
244
+ self, index: int
245
+ ) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]:
246
+ s2s_data = super().__getitem__(index)
247
+
248
+ multitask_target = {}
249
+ sample_id = self.ids[index]
250
+ tgt_lang = self.tgt_langs[index]
251
+ for task_name, task_dataset in self.multitask_data.items():
252
+ multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
253
+
254
+ return s2s_data, multitask_target
255
+
256
+ def collater(
257
+ self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]]
258
+ ) -> Dict:
259
+ if len(samples) == 0:
260
+ return {}
261
+
262
+ out = super().collater([s for s, _ in samples], return_order=True)
263
+ order = out["order"]
264
+ del out["order"]
265
+
266
+ for task_name, task_dataset in self.multitask_data.items():
267
+ if "multitask" not in out:
268
+ out["multitask"] = {}
269
+ d = [s[task_name] for _, s in samples]
270
+ task_target = task_dataset.collater(d)
271
+ out["multitask"][task_name] = {
272
+ "target": task_target["target"].index_select(0, order),
273
+ "target_lengths": task_target["target_lengths"].index_select(0, order),
274
+ "ntokens": task_target["ntokens"],
275
+ }
276
+ out["multitask"][task_name]["net_input"] = {
277
+ "prev_output_tokens": task_target["prev_output_tokens"].index_select(
278
+ 0, order
279
+ ),
280
+ }
281
+
282
+ return out
283
+
284
+
285
+ class SpeechToSpeechDatasetCreator(object):
286
+ # mandatory columns
287
+ KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames"
288
+ KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames"
289
+ # optional columns
290
+ KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
291
+ # default values
292
+ DEFAULT_LANG = ""
293
+
294
+ @classmethod
295
+ def _from_list(
296
+ cls,
297
+ split_name: str,
298
+ is_train_split,
299
+ samples: List[Dict],
300
+ data_cfg: S2SDataConfig,
301
+ target_is_code: bool = False,
302
+ tgt_dict: Dictionary = None,
303
+ n_frames_per_step: int = 1,
304
+ multitask: Optional[Dict] = None,
305
+ ) -> SpeechToSpeechDataset:
306
+ audio_root = Path(data_cfg.audio_root)
307
+ ids = [s[cls.KEY_ID] for s in samples]
308
+ src_audio_paths = [
309
+ (audio_root / s[cls.KEY_SRC_AUDIO]).as_posix() for s in samples
310
+ ]
311
+ tgt_audio_paths = [
312
+ s[cls.KEY_TGT_AUDIO]
313
+ if target_is_code
314
+ else (audio_root / s[cls.KEY_TGT_AUDIO]).as_posix()
315
+ for s in samples
316
+ ]
317
+ src_n_frames = [int(s[cls.KEY_SRC_N_FRAMES]) for s in samples]
318
+ tgt_n_frames = [int(s[cls.KEY_TGT_N_FRAMES]) for s in samples]
319
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
320
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
321
+
322
+ has_multitask = multitask is not None and len(multitask.keys()) > 0
323
+ dataset_cls = (
324
+ SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
325
+ )
326
+
327
+ ds = dataset_cls(
328
+ split=split_name,
329
+ is_train_split=is_train_split,
330
+ data_cfg=data_cfg,
331
+ src_audio_paths=src_audio_paths,
332
+ src_n_frames=src_n_frames,
333
+ tgt_audio_paths=tgt_audio_paths,
334
+ tgt_n_frames=tgt_n_frames,
335
+ src_langs=src_langs,
336
+ tgt_langs=tgt_langs,
337
+ ids=ids,
338
+ target_is_code=target_is_code,
339
+ tgt_dict=tgt_dict,
340
+ n_frames_per_step=n_frames_per_step,
341
+ )
342
+
343
+ if has_multitask:
344
+ for task_name, task_obj in multitask.items():
345
+ task_data = TextTargetMultitaskData(
346
+ task_obj.args, split_name, task_obj.target_dictionary
347
+ )
348
+ ds.add_multitask_dataset(task_name, task_data)
349
+ return ds
350
+
351
+ @classmethod
352
+ def from_tsv(
353
+ cls,
354
+ root: str,
355
+ data_cfg: S2SDataConfig,
356
+ splits: str,
357
+ is_train_split: bool,
358
+ epoch: int,
359
+ seed: int,
360
+ target_is_code: bool = False,
361
+ tgt_dict: Dictionary = None,
362
+ n_frames_per_step: int = 1,
363
+ multitask: Optional[Dict] = None,
364
+ ) -> SpeechToSpeechDataset:
365
+ datasets = []
366
+ for split in splits.split(","):
367
+ samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
368
+ ds = cls._from_list(
369
+ split_name=split,
370
+ is_train_split=is_train_split,
371
+ samples=samples,
372
+ data_cfg=data_cfg,
373
+ target_is_code=target_is_code,
374
+ tgt_dict=tgt_dict,
375
+ n_frames_per_step=n_frames_per_step,
376
+ multitask=multitask,
377
+ )
378
+ datasets.append(ds)
379
+ return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
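When the target is discrete units and n_frames_per_step > 1, pack_units above folds each group of n_frames_per_step unit ids into a single id using base-(vocab_size) arithmetic, after subtracting the 4 special symbols (<bos>, <pad>, <eos>, <unk>) and re-adding that offset afterwards. A worked sketch with a made-up dictionary size of 104 (so vocab_size = 100):

# Minimal sketch of SpeechToSpeechDataset.pack_units. Dictionary size and unit
# values below are made up for illustration.
import torch

def pack_units(inp, n_frames_per_step, dict_size):
    if n_frames_per_step <= 1:
        return inp
    offset = 4  # <bos>, <pad>, <eos>, <unk> occupy the first four ids
    vocab_size = dict_size - offset
    stacked = inp[:-1].view(-1, n_frames_per_step) - offset  # drop trailing <eos>
    scale = torch.LongTensor(
        [pow(vocab_size, n_frames_per_step - 1 - i) for i in range(n_frames_per_step)]
    )
    res = inp.new_full(((len(inp) - 1) // n_frames_per_step + 1,), int(inp[-1]))
    res[:-1] = (stacked * scale).sum(dim=1) + offset
    return res

eos = 2
units = torch.LongTensor([4 + 7, 4 + 3, 4 + 0, 4 + 9, eos])  # units 7,3,0,9 then <eos>
print(pack_units(units, n_frames_per_step=2, dict_size=104))
# tensor([707, 13, 2]): 7*100 + 3 + 4 = 707, 0*100 + 9 + 4 = 13, then <eos>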
modules/voice_conversion/fairseq/data/audio/speech_to_text_dataset.py ADDED
@@ -0,0 +1,733 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import csv
7
+ import logging
8
+ import re
9
+ from argparse import Namespace
10
+ from collections import defaultdict
11
+ from dataclasses import dataclass
12
+ from pathlib import Path
13
+ from typing import Dict, List, Optional, Tuple, Union
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+ from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
20
+ from fairseq.data import data_utils as fairseq_data_utils
21
+ from fairseq.data import encoders
22
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
23
+ from fairseq.data.audio.data_cfg import S2TDataConfig
24
+ from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
25
+ from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
26
+ from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
27
+ NoisyOverlapAugment,
28
+ )
29
+ from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
30
+ from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def _collate_frames(
36
+ frames: List[torch.Tensor], is_audio_input: bool = False
37
+ ) -> torch.Tensor:
38
+ """
39
+ Convert a list of 2D frames into a padded 3D tensor
40
+ Args:
41
+ frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
42
+ length of i-th frame and f_dim is static dimension of features
43
+ Returns:
44
+ 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
45
+ """
46
+ max_len = max(frame.size(0) for frame in frames)
47
+ if is_audio_input:
48
+ out = frames[0].new_zeros((len(frames), max_len))
49
+ else:
50
+ out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
51
+ for i, v in enumerate(frames):
52
+ out[i, : v.size(0)] = v
53
+ return out
54
+
55
+
56
+ def _is_int_or_np_int(n):
57
+ return isinstance(n, int) or (
58
+ isinstance(n, np.generic) and isinstance(n.item(), int)
59
+ )
60
+
61
+
62
+ @dataclass
63
+ class SpeechToTextDatasetItem(object):
64
+ index: int
65
+ source: torch.Tensor
66
+ target: Optional[torch.Tensor] = None
67
+ speaker_id: Optional[int] = None
68
+
69
+
70
+ class SpeechToTextDataset(FairseqDataset):
71
+ LANG_TAG_TEMPLATE = "<lang:{}>"
72
+
73
+ def __init__(
74
+ self,
75
+ split: str,
76
+ is_train_split: bool,
77
+ cfg: S2TDataConfig,
78
+ audio_paths: List[str],
79
+ n_frames: List[int],
80
+ src_texts: Optional[List[str]] = None,
81
+ tgt_texts: Optional[List[str]] = None,
82
+ speakers: Optional[List[str]] = None,
83
+ src_langs: Optional[List[str]] = None,
84
+ tgt_langs: Optional[List[str]] = None,
85
+ ids: Optional[List[str]] = None,
86
+ tgt_dict: Optional[Dictionary] = None,
87
+ pre_tokenizer=None,
88
+ bpe_tokenizer=None,
89
+ n_frames_per_step=1,
90
+ speaker_to_id=None,
91
+ append_eos=True,
92
+ ):
93
+ self.split, self.is_train_split = split, is_train_split
94
+ self.cfg = cfg
95
+ self.audio_paths, self.n_frames = audio_paths, n_frames
96
+ self.n_samples = len(audio_paths)
97
+ assert len(n_frames) == self.n_samples > 0
98
+ assert src_texts is None or len(src_texts) == self.n_samples
99
+ assert tgt_texts is None or len(tgt_texts) == self.n_samples
100
+ assert speakers is None or len(speakers) == self.n_samples
101
+ assert src_langs is None or len(src_langs) == self.n_samples
102
+ assert tgt_langs is None or len(tgt_langs) == self.n_samples
103
+ assert ids is None or len(ids) == self.n_samples
104
+ assert (tgt_dict is None and tgt_texts is None) or (
105
+ tgt_dict is not None and tgt_texts is not None
106
+ )
107
+ self.src_texts, self.tgt_texts = src_texts, tgt_texts
108
+ self.src_langs, self.tgt_langs = src_langs, tgt_langs
109
+ self.speakers = speakers
110
+ self.tgt_dict = tgt_dict
111
+ self.check_tgt_lang_tag()
112
+ self.ids = ids
113
+ self.shuffle = cfg.shuffle if is_train_split else False
114
+
115
+ self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
116
+ self.cfg.get_feature_transforms(split, is_train_split)
117
+ )
118
+ self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
119
+ self.cfg.get_waveform_transforms(split, is_train_split)
120
+ )
121
+ # TODO: add these to data_cfg.py
122
+ self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
123
+ self.cfg.get_dataset_transforms(split, is_train_split)
124
+ )
125
+
126
+ # check proper usage of transforms
127
+ if self.feature_transforms and self.cfg.use_audio_input:
128
+ logger.warning(
129
+ "Feature transforms will not be applied. To use feature transforms, "
130
+ "set use_audio_input as False in config."
131
+ )
132
+
133
+ self.pre_tokenizer = pre_tokenizer
134
+ self.bpe_tokenizer = bpe_tokenizer
135
+ self.n_frames_per_step = n_frames_per_step
136
+ self.speaker_to_id = speaker_to_id
137
+
138
+ self.tgt_lens = self.get_tgt_lens_and_check_oov()
139
+ self.append_eos = append_eos
140
+
141
+ logger.info(self.__repr__())
142
+
143
+ def get_tgt_lens_and_check_oov(self):
144
+ if self.tgt_texts is None:
145
+ return [0 for _ in range(self.n_samples)]
146
+ tgt_lens = []
147
+ n_tokens, n_oov_tokens = 0, 0
148
+ for i in range(self.n_samples):
149
+ tokenized = self.get_tokenized_tgt_text(i).split(" ")
150
+ oov_tokens = [
151
+ t
152
+ for t in tokenized
153
+ if self.tgt_dict.index(t) == self.tgt_dict.unk_index
154
+ ]
155
+ n_tokens += len(tokenized)
156
+ n_oov_tokens += len(oov_tokens)
157
+ tgt_lens.append(len(tokenized))
158
+ logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
159
+ return tgt_lens
160
+
161
+ def __repr__(self):
162
+ return (
163
+ self.__class__.__name__
164
+ + f'(split="{self.split}", n_samples={self.n_samples:_}, '
165
+ f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
166
+ f"n_frames_per_step={self.n_frames_per_step}, "
167
+ f"shuffle={self.shuffle}, "
168
+ f"feature_transforms={self.feature_transforms}, "
169
+ f"waveform_transforms={self.waveform_transforms}, "
170
+ f"dataset_transforms={self.dataset_transforms})"
171
+ )
172
+
173
+ @classmethod
174
+ def is_lang_tag(cls, token):
175
+ pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
176
+ return re.match(pattern, token)
177
+
178
+ def check_tgt_lang_tag(self):
179
+ if self.cfg.prepend_tgt_lang_tag:
180
+ assert self.tgt_langs is not None and self.tgt_dict is not None
181
+ tgt_lang_tags = [
182
+ self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
183
+ ]
184
+ assert all(t in self.tgt_dict for t in tgt_lang_tags)
185
+
186
+ @classmethod
187
+ def tokenize(cls, tokenizer, text: str):
188
+ return text if tokenizer is None else tokenizer.encode(text)
189
+
190
+ def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
191
+ if _is_int_or_np_int(index):
192
+ text = self.tgt_texts[index]
193
+ else:
194
+ text = " ".join([self.tgt_texts[i] for i in index])
195
+
196
+ text = self.tokenize(self.pre_tokenizer, text)
197
+ text = self.tokenize(self.bpe_tokenizer, text)
198
+ return text
199
+
200
+ def pack_frames(self, feature: torch.Tensor):
201
+ if self.n_frames_per_step == 1:
202
+ return feature
203
+ n_packed_frames = feature.shape[0] // self.n_frames_per_step
204
+ feature = feature[: self.n_frames_per_step * n_packed_frames]
205
+ return feature.reshape(n_packed_frames, -1)
206
+
207
+ @classmethod
208
+ def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
209
+ lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
210
+ assert lang_tag_idx != dictionary.unk()
211
+ return lang_tag_idx
212
+
213
+ def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
214
+ """
215
+ Gives source audio for given index with any relevant transforms
216
+ applied. For ConcatAug, source audios for given indices are
217
+ concatenated in given order.
218
+ Args:
219
+ index (int or List[int]): index—or in the case of ConcatAug,
220
+ indices—to pull the source audio for
221
+ Returns:
222
+ source audios concatenated for given indices with
223
+ relevant transforms applied
224
+ """
225
+ if _is_int_or_np_int(index):
226
+ source = get_features_or_waveform(
227
+ self.audio_paths[index],
228
+ need_waveform=self.cfg.use_audio_input,
229
+ use_sample_rate=self.cfg.use_sample_rate,
230
+ waveform_transforms=self.waveform_transforms,
231
+ )
232
+ else:
233
+ source = np.concatenate(
234
+ [
235
+ get_features_or_waveform(
236
+ self.audio_paths[i],
237
+ need_waveform=self.cfg.use_audio_input,
238
+ use_sample_rate=self.cfg.use_sample_rate,
239
+ waveform_transforms=self.waveform_transforms,
240
+ )
241
+ for i in index
242
+ ]
243
+ )
244
+ if self.cfg.use_audio_input:
245
+ source = torch.from_numpy(source).float()
246
+ if self.cfg.standardize_audio:
247
+ with torch.no_grad():
248
+ source = F.layer_norm(source, source.shape)
249
+ else:
250
+ if self.feature_transforms is not None:
251
+ source = self.feature_transforms(source)
252
+ source = torch.from_numpy(source).float()
253
+ return source
254
+
255
+ def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
256
+ has_concat = self.dataset_transforms.has_transform(ConcatAugment)
257
+ if has_concat:
258
+ concat = self.dataset_transforms.get_transform(ConcatAugment)
259
+ indices = concat.find_indices(index, self.n_frames, self.n_samples)
260
+
261
+ source = self._get_source_audio(indices if has_concat else index)
262
+ source = self.pack_frames(source)
263
+
264
+ target = None
265
+ if self.tgt_texts is not None:
266
+ tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
267
+ target = self.tgt_dict.encode_line(
268
+ tokenized, add_if_not_exist=False, append_eos=self.append_eos
269
+ ).long()
270
+ if self.cfg.prepend_tgt_lang_tag:
271
+ lang_tag_idx = self.get_lang_tag_idx(
272
+ self.tgt_langs[index], self.tgt_dict
273
+ )
274
+ target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
275
+
276
+ if self.cfg.prepend_bos_and_append_tgt_lang_tag:
277
+ bos = torch.LongTensor([self.tgt_dict.bos()])
278
+ lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
279
+ assert lang_tag_idx != self.tgt_dict.unk()
280
+ lang_tag_idx = torch.LongTensor([lang_tag_idx])
281
+ target = torch.cat((bos, target, lang_tag_idx), 0)
282
+
283
+ speaker_id = None
284
+ if self.speaker_to_id is not None:
285
+ speaker_id = self.speaker_to_id[self.speakers[index]]
286
+ return SpeechToTextDatasetItem(
287
+ index=index, source=source, target=target, speaker_id=speaker_id
288
+ )
289
+
290
+ def __len__(self):
291
+ return self.n_samples
292
+
293
+ def collater(
294
+ self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
295
+ ) -> Dict:
296
+ if len(samples) == 0:
297
+ return {}
298
+ indices = torch.tensor([x.index for x in samples], dtype=torch.long)
299
+
300
+ sources = [x.source for x in samples]
301
+ has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
302
+ if has_NOAug and self.cfg.use_audio_input:
303
+ NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
304
+ sources = NOAug(sources)
305
+
306
+ frames = _collate_frames(sources, self.cfg.use_audio_input)
307
+ # sort samples by descending number of frames
308
+ n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
309
+ n_frames, order = n_frames.sort(descending=True)
310
+ indices = indices.index_select(0, order)
311
+ frames = frames.index_select(0, order)
312
+
313
+ target, target_lengths = None, None
314
+ prev_output_tokens = None
315
+ ntokens = None
316
+ if self.tgt_texts is not None:
317
+ target = fairseq_data_utils.collate_tokens(
318
+ [x.target for x in samples],
319
+ self.tgt_dict.pad(),
320
+ self.tgt_dict.eos(),
321
+ left_pad=False,
322
+ move_eos_to_beginning=False,
323
+ )
324
+ target = target.index_select(0, order)
325
+ target_lengths = torch.tensor(
326
+ [x.target.size(0) for x in samples], dtype=torch.long
327
+ ).index_select(0, order)
328
+ prev_output_tokens = fairseq_data_utils.collate_tokens(
329
+ [x.target for x in samples],
330
+ self.tgt_dict.pad(),
331
+ eos_idx=None,
332
+ left_pad=False,
333
+ move_eos_to_beginning=True,
334
+ )
335
+ prev_output_tokens = prev_output_tokens.index_select(0, order)
336
+ ntokens = sum(x.target.size(0) for x in samples)
337
+
338
+ speaker = None
339
+ if self.speaker_to_id is not None:
340
+ speaker = (
341
+ torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
342
+ .index_select(0, order)
343
+ .view(-1, 1)
344
+ )
345
+
346
+ net_input = {
347
+ "src_tokens": frames,
348
+ "src_lengths": n_frames,
349
+ "prev_output_tokens": prev_output_tokens,
350
+ }
351
+ out = {
352
+ "id": indices,
353
+ "net_input": net_input,
354
+ "speaker": speaker,
355
+ "target": target,
356
+ "target_lengths": target_lengths,
357
+ "ntokens": ntokens,
358
+ "nsentences": len(samples),
359
+ }
360
+ if return_order:
361
+ out["order"] = order
362
+ return out
363
+
364
+ def num_tokens(self, index):
365
+ return self.n_frames[index]
366
+
367
+ def size(self, index):
368
+ return self.n_frames[index], self.tgt_lens[index]
369
+
370
+ @property
371
+ def sizes(self):
372
+ return np.array(self.n_frames)
373
+
374
+ @property
375
+ def can_reuse_epoch_itr_across_epochs(self):
376
+ return True
377
+
378
+ def ordered_indices(self):
379
+ if self.shuffle:
380
+ order = [np.random.permutation(len(self))]
381
+ else:
382
+ order = [np.arange(len(self))]
383
+ # first by descending order of # of frames then by original/random order
384
+ order.append([-n for n in self.n_frames])
385
+ return np.lexsort(order)
386
+
387
+ def prefetch(self, indices):
388
+ raise NotImplementedError
389
+
390
+
391
+ class TextTargetMultitaskData(object):
392
+ # mandatory columns
393
+ KEY_ID, KEY_TEXT = "id", "tgt_text"
394
+ LANG_TAG_TEMPLATE = "<lang:{}>"
395
+
396
+ def __init__(self, args, split, tgt_dict):
397
+ samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
398
+ self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
399
+ self.dict = tgt_dict
400
+ self.append_eos = args.decoder_type != "ctc"
401
+ self.pre_tokenizer = self.build_tokenizer(args)
402
+ self.bpe_tokenizer = self.build_bpe(args)
403
+ self.prepend_bos_and_append_tgt_lang_tag = (
404
+ args.prepend_bos_and_append_tgt_lang_tag
405
+ )
406
+ self.eos_token = args.eos_token
407
+ self.lang_tag_mapping = args.get_lang_tag_mapping
408
+
409
+ @classmethod
410
+ def is_lang_tag(cls, token):
411
+ pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
412
+ return re.match(pattern, token)
413
+
414
+ @classmethod
415
+ def tokenize(cls, tokenizer, text: str):
416
+ return text if tokenizer is None else tokenizer.encode(text)
417
+
418
+ def get_tokenized_tgt_text(self, index: int):
419
+ text = self.tokenize(self.pre_tokenizer, self.data[index])
420
+ text = self.tokenize(self.bpe_tokenizer, text)
421
+ return text
422
+
423
+ def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
424
+ lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
425
+ lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
426
+ lang_tag_idx = dictionary.index(lang_tag)
427
+ assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
428
+ return lang_tag_idx
429
+
430
+ def build_tokenizer(self, args):
431
+ pre_tokenizer = args.config.get("pre_tokenizer")
432
+ if pre_tokenizer is not None:
433
+ logger.info(f"pre-tokenizer: {pre_tokenizer}")
434
+ return encoders.build_tokenizer(Namespace(**pre_tokenizer))
435
+ else:
436
+ return None
437
+
438
+ def build_bpe(self, args):
439
+ bpe_tokenizer = args.config.get("bpe_tokenizer")
440
+ if bpe_tokenizer is not None:
441
+ logger.info(f"tokenizer: {bpe_tokenizer}")
442
+ return encoders.build_bpe(Namespace(**bpe_tokenizer))
443
+ else:
444
+ return None
445
+
446
+ def get(self, sample_id, tgt_lang=None):
447
+ if sample_id in self.data:
448
+ tokenized = self.get_tokenized_tgt_text(sample_id)
449
+ target = self.dict.encode_line(
450
+ tokenized,
451
+ add_if_not_exist=False,
452
+ append_eos=self.append_eos,
453
+ )
454
+ if self.prepend_bos_and_append_tgt_lang_tag:
455
+ bos = torch.LongTensor([self.dict.bos()])
456
+ lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
457
+ assert lang_tag_idx != self.dict.unk()
458
+ lang_tag_idx = torch.LongTensor([lang_tag_idx])
459
+ target = torch.cat((bos, target, lang_tag_idx), 0)
460
+ return target
461
+ else:
462
+ logger.warning(f"no target for {sample_id}")
463
+ return torch.IntTensor([])
464
+
465
+ def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
466
+ out = fairseq_data_utils.collate_tokens(
467
+ samples,
468
+ self.dict.pad(),
469
+ eos_idx=None,
470
+ left_pad=False,
471
+ move_eos_to_beginning=False,
472
+ ).long()
473
+
474
+ prev_out = fairseq_data_utils.collate_tokens(
475
+ samples,
476
+ self.dict.pad(),
477
+ eos_idx=None,
478
+ left_pad=False,
479
+ move_eos_to_beginning=True,
480
+ ).long()
481
+
482
+ target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
483
+ ntokens = sum(t.size(0) for t in samples)
484
+
485
+ output = {
486
+ "prev_output_tokens": prev_out,
487
+ "target": out,
488
+ "target_lengths": target_lengths,
489
+ "ntokens": ntokens,
490
+ }
491
+
492
+ return output
493
+
494
+
495
+ class SpeechToTextMultitaskDataset(SpeechToTextDataset):
496
+ def __init__(self, **kwargs):
497
+ super().__init__(**kwargs)
498
+ self.multitask_data = {}
499
+
500
+ def add_multitask_dataset(self, task_name, task_data):
501
+ self.multitask_data[task_name] = task_data
502
+
503
+ def __getitem__(
504
+ self, index: int
505
+ ) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
506
+ s2t_data = super().__getitem__(index)
507
+
508
+ multitask_target = {}
509
+ sample_id = self.ids[index]
510
+ tgt_lang = self.tgt_langs[index]
511
+ for task_name, task_dataset in self.multitask_data.items():
512
+ multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
513
+
514
+ return s2t_data, multitask_target
515
+
516
+ def collater(
517
+ self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
518
+ ) -> Dict:
519
+ if len(samples) == 0:
520
+ return {}
521
+
522
+ out = super().collater([s for s, _ in samples], return_order=True)
523
+ order = out["order"]
524
+ del out["order"]
525
+
526
+ for task_name, task_dataset in self.multitask_data.items():
527
+ if "multitask" not in out:
528
+ out["multitask"] = {}
529
+ d = [s[task_name] for _, s in samples]
530
+ task_target = task_dataset.collater(d)
531
+ out["multitask"][task_name] = {
532
+ "target": task_target["target"].index_select(0, order),
533
+ "target_lengths": task_target["target_lengths"].index_select(0, order),
534
+ "ntokens": task_target["ntokens"],
535
+ }
536
+ out["multitask"][task_name]["net_input"] = {
537
+ "prev_output_tokens": task_target["prev_output_tokens"].index_select(
538
+ 0, order
539
+ ),
540
+ }
541
+
542
+ return out
543
+
544
+
545
+ class SpeechToTextDatasetCreator(object):
546
+ # mandatory columns
547
+ KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
548
+ KEY_TGT_TEXT = "tgt_text"
549
+ # optional columns
550
+ KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
551
+ KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
552
+ # default values
553
+ DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
554
+
555
+ @classmethod
556
+ def _from_list(
557
+ cls,
558
+ split_name: str,
559
+ is_train_split,
560
+ samples: List[Dict],
561
+ cfg: S2TDataConfig,
562
+ tgt_dict,
563
+ pre_tokenizer,
564
+ bpe_tokenizer,
565
+ n_frames_per_step,
566
+ speaker_to_id,
567
+ multitask: Optional[Dict] = None,
568
+ ) -> SpeechToTextDataset:
569
+ audio_root = Path(cfg.audio_root)
570
+ ids = [s[cls.KEY_ID] for s in samples]
571
+ audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
572
+ n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
573
+ tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
574
+ src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
575
+ speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
576
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
577
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
578
+
579
+ has_multitask = multitask is not None and len(multitask.keys()) > 0
580
+ dataset_cls = (
581
+ SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
582
+ )
583
+
584
+ ds = dataset_cls(
585
+ split=split_name,
586
+ is_train_split=is_train_split,
587
+ cfg=cfg,
588
+ audio_paths=audio_paths,
589
+ n_frames=n_frames,
590
+ src_texts=src_texts,
591
+ tgt_texts=tgt_texts,
592
+ speakers=speakers,
593
+ src_langs=src_langs,
594
+ tgt_langs=tgt_langs,
595
+ ids=ids,
596
+ tgt_dict=tgt_dict,
597
+ pre_tokenizer=pre_tokenizer,
598
+ bpe_tokenizer=bpe_tokenizer,
599
+ n_frames_per_step=n_frames_per_step,
600
+ speaker_to_id=speaker_to_id,
601
+ )
602
+
603
+ if has_multitask:
604
+ for task_name, task_obj in multitask.items():
605
+ task_data = TextTargetMultitaskData(
606
+ task_obj.args, split_name, task_obj.target_dictionary
607
+ )
608
+ ds.add_multitask_dataset(task_name, task_data)
609
+ return ds
610
+
611
+ @classmethod
612
+ def get_size_ratios(
613
+ cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
614
+ ) -> List[float]:
615
+ """Size ratios for temperature-based sampling
616
+ (https://arxiv.org/abs/1907.05019)"""
617
+
618
+ id_to_lp, lp_to_sz = {}, defaultdict(int)
619
+ for ds in datasets:
620
+ lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
621
+ assert len(lang_pairs) == 1
622
+ lang_pair = list(lang_pairs)[0]
623
+ id_to_lp[ds.split] = lang_pair
624
+ lp_to_sz[lang_pair] += sum(ds.n_frames)
625
+
626
+ sz_sum = sum(v for v in lp_to_sz.values())
627
+ lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
628
+ lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
629
+ prob_sum = sum(v for v in lp_to_tgt_prob.values())
630
+ lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
631
+ lp_to_sz_ratio = {
632
+ k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
633
+ }
634
+ size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]
635
+
636
+ p_formatted = {
637
+ k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
638
+ }
639
+ logger.info(f"sampling probability balancing: {p_formatted}")
640
+ sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
641
+ logger.info(f"balanced sampling size ratio: {sr_formatted}")
642
+ return size_ratio
643
+
644
+ @classmethod
645
+ def _load_samples_from_tsv(cls, root: str, split: str):
646
+ tsv_path = Path(root) / f"{split}.tsv"
647
+ if not tsv_path.is_file():
648
+ raise FileNotFoundError(f"Dataset not found: {tsv_path}")
649
+ with open(tsv_path) as f:
650
+ reader = csv.DictReader(
651
+ f,
652
+ delimiter="\t",
653
+ quotechar=None,
654
+ doublequote=False,
655
+ lineterminator="\n",
656
+ quoting=csv.QUOTE_NONE,
657
+ )
658
+ samples = [dict(e) for e in reader]
659
+ if len(samples) == 0:
660
+ raise ValueError(f"Empty manifest: {tsv_path}")
661
+ return samples
662
+
663
+ @classmethod
664
+ def _from_tsv(
665
+ cls,
666
+ root: str,
667
+ cfg: S2TDataConfig,
668
+ split: str,
669
+ tgt_dict,
670
+ is_train_split: bool,
671
+ pre_tokenizer,
672
+ bpe_tokenizer,
673
+ n_frames_per_step,
674
+ speaker_to_id,
675
+ multitask: Optional[Dict] = None,
676
+ ) -> SpeechToTextDataset:
677
+ samples = cls._load_samples_from_tsv(root, split)
678
+ return cls._from_list(
679
+ split,
680
+ is_train_split,
681
+ samples,
682
+ cfg,
683
+ tgt_dict,
684
+ pre_tokenizer,
685
+ bpe_tokenizer,
686
+ n_frames_per_step,
687
+ speaker_to_id,
688
+ multitask,
689
+ )
690
+
691
+ @classmethod
692
+ def from_tsv(
693
+ cls,
694
+ root: str,
695
+ cfg: S2TDataConfig,
696
+ splits: str,
697
+ tgt_dict,
698
+ pre_tokenizer,
699
+ bpe_tokenizer,
700
+ is_train_split: bool,
701
+ epoch: int,
702
+ seed: int,
703
+ n_frames_per_step: int = 1,
704
+ speaker_to_id=None,
705
+ multitask: Optional[Dict] = None,
706
+ ) -> SpeechToTextDataset:
707
+ datasets = [
708
+ cls._from_tsv(
709
+ root=root,
710
+ cfg=cfg,
711
+ split=split,
712
+ tgt_dict=tgt_dict,
713
+ is_train_split=is_train_split,
714
+ pre_tokenizer=pre_tokenizer,
715
+ bpe_tokenizer=bpe_tokenizer,
716
+ n_frames_per_step=n_frames_per_step,
717
+ speaker_to_id=speaker_to_id,
718
+ multitask=multitask,
719
+ )
720
+ for split in splits.split(",")
721
+ ]
722
+
723
+ if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
724
+ # temperature-based sampling
725
+ size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
726
+ datasets = [
727
+ ResamplingDataset(
728
+ d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
729
+ )
730
+ for r, d in zip(size_ratios, datasets)
731
+ ]
732
+
733
+ return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
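get_size_ratios above implements the temperature-based sampling of https://arxiv.org/abs/1907.05019: per-language-pair sizes are turned into probabilities, flattened with the exponent alpha, renormalized, and converted back into upsampling ratios. A worked sketch with made-up frame counts:

# Minimal sketch of SpeechToTextDatasetCreator.get_size_ratios. Frame counts
# and language pairs below are made up for illustration.
lp_to_sz = {"en->de": 1_000_000, "en->fr": 250_000, "en->nl": 50_000}
alpha = 0.5  # alpha = 1/T; alpha = 1.0 keeps the original proportions

sz_sum = sum(lp_to_sz.values())
lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
lp_to_tgt_prob = {k: v ** alpha for k, v in lp_to_prob.items()}
prob_sum = sum(lp_to_tgt_prob.values())
lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
# ratio > 1 means the pair is upsampled relative to its natural frequency
lp_to_sz_ratio = {k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()}
for k in lp_to_sz:
    print(f"{k}: p={lp_to_prob[k]:.3f} -> {lp_to_tgt_prob[k]:.3f}, ratio={lp_to_sz_ratio[k]:.2f}")

With alpha = 0.5 the smallest pair (en->nl) is upsampled roughly 3.4x while the largest pair is downsampled; these ratios are what from_tsv passes to ResamplingDataset as size_ratio.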
modules/voice_conversion/fairseq/data/audio/speech_to_text_joint_dataset.py ADDED
@@ -0,0 +1,359 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import Dict, List, NamedTuple, Optional
9
+
10
+ import torch
11
+
12
+ from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
13
+ from fairseq.data import data_utils as fairseq_data_utils
14
+ from fairseq.data.audio.speech_to_text_dataset import (
15
+ S2TDataConfig,
16
+ SpeechToTextDataset,
17
+ SpeechToTextDatasetCreator,
18
+ )
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class S2TJointDataConfig(S2TDataConfig):
24
+ """Wrapper class for data config YAML"""
25
+
26
+ @property
27
+ def src_vocab_filename(self):
28
+ """fairseq vocabulary file under data root"""
29
+ return self.config.get("src_vocab_filename", "src_dict.txt")
30
+
31
+ @property
32
+ def src_pre_tokenizer(self) -> Dict:
33
+ """Pre-tokenizer to apply before subword tokenization. Returning
34
+ a dictionary with `tokenizer` providing the tokenizer name and
35
+ the other items providing the tokenizer-specific arguments.
36
+ Tokenizers are defined in `fairseq.data.encoders.*`"""
37
+ return self.config.get("src_pre_tokenizer", {"tokenizer": None})
38
+
39
+ @property
40
+ def src_bpe_tokenizer(self) -> Dict:
41
+ """Subword tokenizer to apply on source text after pre-tokenization.
42
+ Returning a dictionary with `bpe` providing the tokenizer name and
43
+ the other items providing the tokenizer-specific arguments.
44
+ Tokenizers are defined in `fairseq.data.encoders.*`"""
45
+ return self.config.get("src_bpe_tokenizer", {"bpe": None})
46
+
47
+ @property
48
+ def prepend_tgt_lang_tag_no_change(self) -> bool:
49
+ """Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
50
+ to-many multilingual setting). No change needed during inference.
51
+ This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
52
+ """
53
+ value = self.config.get("prepend_tgt_lang_tag_no_change", None)
54
+ if value is None:
55
+ return self.config.get("prepend_tgt_lang_tag_as_bos", False)
56
+ return value
57
+
58
+ @property
59
+ def sampling_text_alpha(self):
60
+ """Hyper-parameter alpha = 1/T for temperature-based resampling. (text
61
+ input only) (alpha = 1 for no resampling)"""
62
+ return self.config.get("sampling_text_alpha", 1.0)
63
+
64
+
65
+ class SpeechToTextJointDatasetItem(NamedTuple):
66
+ index: int
67
+ source: torch.Tensor
68
+ target: Optional[torch.Tensor] = None
69
+ src_txt_tokens: Optional[torch.Tensor] = None
70
+ tgt_lang_tag: Optional[int] = None
71
+ src_lang_tag: Optional[int] = None
72
+ tgt_alignment: Optional[torch.Tensor] = None
73
+
74
+
75
+ # use_src_lang_id:
76
+ # 0: don't use src_lang_id
77
+ # 1: attach src_lang_id to the src_txt_tokens as eos
78
+ class SpeechToTextJointDataset(SpeechToTextDataset):
79
+ def __init__(
80
+ self,
81
+ split: str,
82
+ is_train_split: bool,
83
+ cfg: S2TJointDataConfig,
84
+ audio_paths: List[str],
85
+ n_frames: List[int],
86
+ src_texts: Optional[List[str]] = None,
87
+ tgt_texts: Optional[List[str]] = None,
88
+ speakers: Optional[List[str]] = None,
89
+ src_langs: Optional[List[str]] = None,
90
+ tgt_langs: Optional[List[str]] = None,
91
+ ids: Optional[List[str]] = None,
92
+ tgt_dict: Optional[Dictionary] = None,
93
+ src_dict: Optional[Dictionary] = None,
94
+ pre_tokenizer=None,
95
+ bpe_tokenizer=None,
96
+ src_pre_tokenizer=None,
97
+ src_bpe_tokenizer=None,
98
+ append_eos: Optional[bool] = True,
99
+ alignment: Optional[List[str]] = None,
100
+ use_src_lang_id: Optional[int] = 0,
101
+ ):
102
+ super().__init__(
103
+ split,
104
+ is_train_split,
105
+ cfg,
106
+ audio_paths,
107
+ n_frames,
108
+ src_texts=src_texts,
109
+ tgt_texts=tgt_texts,
110
+ speakers=speakers,
111
+ src_langs=src_langs,
112
+ tgt_langs=tgt_langs,
113
+ ids=ids,
114
+ tgt_dict=tgt_dict,
115
+ pre_tokenizer=pre_tokenizer,
116
+ bpe_tokenizer=bpe_tokenizer,
117
+ append_eos=append_eos,
118
+ )
119
+
120
+ self.src_dict = src_dict
121
+ self.src_pre_tokenizer = src_pre_tokenizer
122
+ self.src_bpe_tokenizer = src_bpe_tokenizer
123
+ self.alignment = None
124
+ self.use_src_lang_id = use_src_lang_id
125
+ if alignment is not None:
126
+ self.alignment = [
127
+ [float(s) for s in sample.split()] for sample in alignment
128
+ ]
129
+
130
+ def get_tokenized_src_text(self, index: int):
131
+ text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
132
+ text = self.tokenize(self.src_bpe_tokenizer, text)
133
+ return text
134
+
135
+ def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
136
+ s2t_dataset_item = super().__getitem__(index)
137
+ src_tokens = None
138
+ src_lang_tag = None
139
+ if self.src_texts is not None and self.src_dict is not None:
140
+ src_tokens = self.get_tokenized_src_text(index)
141
+ src_tokens = self.src_dict.encode_line(
142
+ src_tokens, add_if_not_exist=False, append_eos=True
143
+ ).long()
144
+ if self.use_src_lang_id > 0:
145
+ src_lang_tag = self.get_lang_tag_idx(
146
+ self.src_langs[index], self.src_dict
147
+ )
148
+ tgt_lang_tag = None
149
+ if self.cfg.prepend_tgt_lang_tag_no_change:
150
+ # prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead
151
+ tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
152
+ ali = None
153
+ if self.alignment is not None:
154
+ ali = torch.Tensor(self.alignment[index]).float()
155
+
156
+ return SpeechToTextJointDatasetItem(
157
+ index=index,
158
+ source=s2t_dataset_item.source,
159
+ target=s2t_dataset_item.target,
160
+ src_txt_tokens=src_tokens,
161
+ tgt_lang_tag=tgt_lang_tag,
162
+ src_lang_tag=src_lang_tag,
163
+ tgt_alignment=ali,
164
+ )
165
+
166
+ def __len__(self):
167
+ return self.n_samples
168
+
169
+ def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
170
+ s2t_out = super().collater(samples, return_order=True)
171
+ if s2t_out == {}:
172
+ return s2t_out
173
+ net_input, order = s2t_out["net_input"], s2t_out["order"]
174
+
175
+ if self.src_texts is not None and self.src_dict is not None:
176
+ src_txt_tokens = fairseq_data_utils.collate_tokens(
177
+ [x.src_txt_tokens for x in samples],
178
+ self.src_dict.pad(),
179
+ self.src_dict.eos(),
180
+ left_pad=False,
181
+ move_eos_to_beginning=False,
182
+ )
183
+ src_txt_lengths = torch.tensor(
184
+ [x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long
185
+ )
186
+ if self.use_src_lang_id > 0:
187
+ src_lang_idxs = torch.tensor(
188
+ [s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype
189
+ )
190
+ if self.use_src_lang_id == 1: # replace eos with lang_id
191
+ eos_idx = src_txt_lengths - 1
192
+ src_txt_tokens.scatter_(
193
+ 1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1)
194
+ )
195
+ else:
196
+ raise NotImplementedError("Implementation is required")
197
+
198
+ src_txt_tokens = src_txt_tokens.index_select(0, order)
199
+ src_txt_lengths = src_txt_lengths.index_select(0, order)
200
+ net_input["src_txt_tokens"] = src_txt_tokens
201
+ net_input["src_txt_lengths"] = src_txt_lengths
202
+
203
+ net_input["alignment"] = None
204
+ if self.alignment is not None:
205
+ max_len = max([s.tgt_alignment.size(0) for s in samples])
206
+ alignment = torch.ones(len(samples), max_len).float()
207
+ for i, s in enumerate(samples):
208
+ cur_len = s.tgt_alignment.size(0)
209
+ alignment[i][:cur_len].copy_(s.tgt_alignment)
210
+ net_input["alignment"] = alignment.index_select(0, order)
211
+
212
+ if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
213
+ for i in range(len(samples)):
214
+ net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
215
+
216
+ out = {
217
+ "id": s2t_out["id"],
218
+ "net_input": net_input,
219
+ "target": s2t_out["target"],
220
+ "target_lengths": s2t_out["target_lengths"],
221
+ "ntokens": s2t_out["ntokens"],
222
+ "nsentences": len(samples),
223
+ }
224
+ return out
225
+
226
+
227
+ class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
228
+ KEY_ALIGN = "align"
229
+
230
+ @classmethod
231
+ def _from_list(
232
+ cls,
233
+ split_name: str,
234
+ is_train_split,
235
+ samples: List[Dict],
236
+ cfg: S2TJointDataConfig,
237
+ tgt_dict,
238
+ src_dict,
239
+ pre_tokenizer,
240
+ bpe_tokenizer,
241
+ src_pre_tokenizer,
242
+ src_bpe_tokenizer,
243
+ append_eos,
244
+ use_src_lang_id,
245
+ ) -> SpeechToTextJointDataset:
246
+ audio_root = Path(cfg.audio_root)
247
+ ids = [s[cls.KEY_ID] for s in samples]
248
+ audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
249
+ n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
250
+ tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
251
+ src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
252
+ speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
253
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
254
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
255
+ tgt_alignment = None
256
+ if cls.KEY_ALIGN in samples[0].keys():
257
+ tgt_alignment = [s[cls.KEY_ALIGN] for s in samples]
258
+ return SpeechToTextJointDataset(
259
+ split_name,
260
+ is_train_split,
261
+ cfg,
262
+ audio_paths,
263
+ n_frames,
264
+ src_texts=src_texts,
265
+ tgt_texts=tgt_texts,
266
+ speakers=speakers,
267
+ src_langs=src_langs,
268
+ tgt_langs=tgt_langs,
269
+ ids=ids,
270
+ tgt_dict=tgt_dict,
271
+ src_dict=src_dict,
272
+ pre_tokenizer=pre_tokenizer,
273
+ bpe_tokenizer=bpe_tokenizer,
274
+ src_pre_tokenizer=src_pre_tokenizer,
275
+ src_bpe_tokenizer=src_bpe_tokenizer,
276
+ append_eos=append_eos,
277
+ alignment=tgt_alignment,
278
+ use_src_lang_id=use_src_lang_id,
279
+ )
280
+
281
+ @classmethod
282
+ def _from_tsv(
283
+ cls,
284
+ root: str,
285
+ cfg: S2TJointDataConfig,
286
+ split: str,
287
+ tgt_dict,
288
+ src_dict,
289
+ is_train_split: bool,
290
+ pre_tokenizer,
291
+ bpe_tokenizer,
292
+ src_pre_tokenizer,
293
+ src_bpe_tokenizer,
294
+ append_eos: bool,
295
+ use_src_lang_id: int,
296
+ ) -> SpeechToTextJointDataset:
297
+ samples = cls._load_samples_from_tsv(root, split)
298
+ return cls._from_list(
299
+ split,
300
+ is_train_split,
301
+ samples,
302
+ cfg,
303
+ tgt_dict,
304
+ src_dict,
305
+ pre_tokenizer,
306
+ bpe_tokenizer,
307
+ src_pre_tokenizer,
308
+ src_bpe_tokenizer,
309
+ append_eos,
310
+ use_src_lang_id,
311
+ )
312
+
313
+ @classmethod
314
+ def from_tsv(
315
+ cls,
316
+ root: str,
317
+ cfg: S2TJointDataConfig,
318
+ splits: str,
319
+ tgt_dict,
320
+ src_dict,
321
+ pre_tokenizer,
322
+ bpe_tokenizer,
323
+ src_pre_tokenizer,
324
+ src_bpe_tokenizer,
325
+ is_train_split: bool,
326
+ epoch: int,
327
+ seed: int,
328
+ append_eos: Optional[bool] = True,
329
+ use_src_lang_id: Optional[int] = 0,
330
+ ) -> SpeechToTextJointDataset:
331
+ datasets = [
332
+ cls._from_tsv(
333
+ root,
334
+ cfg,
335
+ split,
336
+ tgt_dict,
337
+ src_dict,
338
+ is_train_split,
339
+ pre_tokenizer,
340
+ bpe_tokenizer,
341
+ src_pre_tokenizer,
342
+ src_bpe_tokenizer,
343
+ append_eos=append_eos,
344
+ use_src_lang_id=use_src_lang_id,
345
+ )
346
+ for split in splits.split(",")
347
+ ]
348
+
349
+ if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
350
+ # temperature-based sampling
351
+ size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
352
+ datasets = [
353
+ ResamplingDataset(
354
+ d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
355
+ )
356
+ for r, d in zip(size_ratios, datasets)
357
+ ]
358
+
359
+ return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
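
A minimal usage sketch for the creator above, assuming a prepared data root containing config.yaml, the two dictionary files, and TSV manifests; every path, split name, and file name below is a placeholder, not part of this commit.

    from pathlib import Path

    from fairseq.data import Dictionary
    from fairseq.data.audio.speech_to_text_joint_dataset import (
        S2TJointDataConfig,
        SpeechToTextJointDatasetCreator,
    )

    # hypothetical data layout; replace with real paths
    data_root = "examples/joint_s2t_data"
    cfg = S2TJointDataConfig(Path(data_root) / "config.yaml")
    tgt_dict = Dictionary.load(f"{data_root}/dict.txt")
    src_dict = Dictionary.load(f"{data_root}/src_dict.txt")

    dataset = SpeechToTextJointDatasetCreator.from_tsv(
        root=data_root,
        cfg=cfg,
        splits="train_st",          # comma-separated when combining several splits
        tgt_dict=tgt_dict,
        src_dict=src_dict,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        src_pre_tokenizer=None,
        src_bpe_tokenizer=None,
        is_train_split=True,
        epoch=1,
        seed=1,
    )
    batch = dataset.collater([dataset[0], dataset[1]])
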
modules/voice_conversion/fairseq/data/audio/text_to_speech_dataset.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright (c) 2017-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the LICENSE file in
5
+ # the root directory of this source tree. An additional grant of patent rights
6
+ # can be found in the PATENTS file in the same directory.
7
+
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ from fairseq.data import Dictionary
16
+ from fairseq.data import data_utils as fairseq_data_utils
17
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
18
+ from fairseq.data.audio.speech_to_text_dataset import (
19
+ S2TDataConfig,
20
+ SpeechToTextDataset,
21
+ SpeechToTextDatasetCreator,
22
+ _collate_frames,
23
+ )
24
+
25
+
26
+ @dataclass
27
+ class TextToSpeechDatasetItem(object):
28
+ index: int
29
+ source: torch.Tensor
30
+ target: Optional[torch.Tensor] = None
31
+ speaker_id: Optional[int] = None
32
+ duration: Optional[torch.Tensor] = None
33
+ pitch: Optional[torch.Tensor] = None
34
+ energy: Optional[torch.Tensor] = None
35
+
36
+
37
+ class TextToSpeechDataset(SpeechToTextDataset):
38
+ def __init__(
39
+ self,
40
+ split: str,
41
+ is_train_split: bool,
42
+ cfg: S2TDataConfig,
43
+ audio_paths: List[str],
44
+ n_frames: List[int],
45
+ src_texts: Optional[List[str]] = None,
46
+ tgt_texts: Optional[List[str]] = None,
47
+ speakers: Optional[List[str]] = None,
48
+ src_langs: Optional[List[str]] = None,
49
+ tgt_langs: Optional[List[str]] = None,
50
+ ids: Optional[List[str]] = None,
51
+ tgt_dict: Optional[Dictionary] = None,
52
+ pre_tokenizer=None,
53
+ bpe_tokenizer=None,
54
+ n_frames_per_step=1,
55
+ speaker_to_id=None,
56
+ durations: Optional[List[List[int]]] = None,
57
+ pitches: Optional[List[str]] = None,
58
+ energies: Optional[List[str]] = None,
59
+ ):
60
+ super(TextToSpeechDataset, self).__init__(
61
+ split,
62
+ is_train_split,
63
+ cfg,
64
+ audio_paths,
65
+ n_frames,
66
+ src_texts=src_texts,
67
+ tgt_texts=tgt_texts,
68
+ speakers=speakers,
69
+ src_langs=src_langs,
70
+ tgt_langs=tgt_langs,
71
+ ids=ids,
72
+ tgt_dict=tgt_dict,
73
+ pre_tokenizer=pre_tokenizer,
74
+ bpe_tokenizer=bpe_tokenizer,
75
+ n_frames_per_step=n_frames_per_step,
76
+ speaker_to_id=speaker_to_id,
77
+ )
78
+ self.durations = durations
79
+ self.pitches = pitches
80
+ self.energies = energies
81
+
82
+ def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
83
+ s2t_item = super().__getitem__(index)
84
+
85
+ duration, pitch, energy = None, None, None
86
+ if self.durations is not None:
87
+ duration = torch.tensor(
88
+ self.durations[index] + [0], dtype=torch.long # pad 0 for EOS
89
+ )
90
+ if self.pitches is not None:
91
+ pitch = get_features_or_waveform(self.pitches[index])
92
+ pitch = torch.from_numpy(
93
+ np.concatenate((pitch, [0])) # pad 0 for EOS
94
+ ).float()
95
+ if self.energies is not None:
96
+ energy = get_features_or_waveform(self.energies[index])
97
+ energy = torch.from_numpy(
98
+ np.concatenate((energy, [0])) # pad 0 for EOS
99
+ ).float()
100
+ return TextToSpeechDatasetItem(
101
+ index=index,
102
+ source=s2t_item.source,
103
+ target=s2t_item.target,
104
+ speaker_id=s2t_item.speaker_id,
105
+ duration=duration,
106
+ pitch=pitch,
107
+ energy=energy,
108
+ )
109
+
110
+ def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
111
+ if len(samples) == 0:
112
+ return {}
113
+
114
+ src_lengths, order = torch.tensor(
115
+ [s.target.shape[0] for s in samples], dtype=torch.long
116
+ ).sort(descending=True)
117
+ id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
118
+ 0, order
119
+ )
120
+ feat = _collate_frames(
121
+ [s.source for s in samples], self.cfg.use_audio_input
122
+ ).index_select(0, order)
123
+ target_lengths = torch.tensor(
124
+ [s.source.shape[0] for s in samples], dtype=torch.long
125
+ ).index_select(0, order)
126
+
127
+ src_tokens = fairseq_data_utils.collate_tokens(
128
+ [s.target for s in samples],
129
+ self.tgt_dict.pad(),
130
+ self.tgt_dict.eos(),
131
+ left_pad=False,
132
+ move_eos_to_beginning=False,
133
+ ).index_select(0, order)
134
+
135
+ speaker = None
136
+ if self.speaker_to_id is not None:
137
+ speaker = (
138
+ torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
139
+ .index_select(0, order)
140
+ .view(-1, 1)
141
+ )
142
+
143
+ bsz, _, d = feat.size()
144
+ prev_output_tokens = torch.cat(
145
+ (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
146
+ )
147
+
148
+ durations, pitches, energies = None, None, None
149
+ if self.durations is not None:
150
+ durations = fairseq_data_utils.collate_tokens(
151
+ [s.duration for s in samples], 0
152
+ ).index_select(0, order)
153
+ assert src_tokens.shape[1] == durations.shape[1]
154
+ if self.pitches is not None:
155
+ pitches = _collate_frames([s.pitch for s in samples], True)
156
+ pitches = pitches.index_select(0, order)
157
+ assert src_tokens.shape[1] == pitches.shape[1]
158
+ if self.energies is not None:
159
+ energies = _collate_frames([s.energy for s in samples], True)
160
+ energies = energies.index_select(0, order)
161
+ assert src_tokens.shape[1] == energies.shape[1]
162
+ src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
163
+
164
+ return {
165
+ "id": id_,
166
+ "net_input": {
167
+ "src_tokens": src_tokens,
168
+ "src_lengths": src_lengths,
169
+ "prev_output_tokens": prev_output_tokens,
170
+ },
171
+ "speaker": speaker,
172
+ "target": feat,
173
+ "durations": durations,
174
+ "pitches": pitches,
175
+ "energies": energies,
176
+ "target_lengths": target_lengths,
177
+ "ntokens": sum(target_lengths).item(),
178
+ "nsentences": len(samples),
179
+ "src_texts": src_texts,
180
+ }
181
+
182
+
183
+ class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
184
+ KEY_DURATION = "duration"
185
+ KEY_PITCH = "pitch"
186
+ KEY_ENERGY = "energy"
187
+
188
+ @classmethod
189
+ def _from_list(
190
+ cls,
191
+ split_name: str,
192
+ is_train_split,
193
+ samples: List[Dict],
194
+ cfg: S2TDataConfig,
195
+ tgt_dict,
196
+ pre_tokenizer,
197
+ bpe_tokenizer,
198
+ n_frames_per_step,
199
+ speaker_to_id,
200
+ multitask=None,
201
+ ) -> TextToSpeechDataset:
202
+ audio_root = Path(cfg.audio_root)
203
+ ids = [s[cls.KEY_ID] for s in samples]
204
+ audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
205
+ n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
206
+ tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
207
+ src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
208
+ speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
209
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
210
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
211
+
212
+ durations = [s.get(cls.KEY_DURATION, None) for s in samples]
213
+ durations = [
214
+ None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
215
+ ]
216
+ durations = None if any(dd is None for dd in durations) else durations
217
+
218
+ pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
219
+ pitches = [
220
+ None if pp is None else (audio_root / pp).as_posix() for pp in pitches
221
+ ]
222
+ pitches = None if any(pp is None for pp in pitches) else pitches
223
+
224
+ energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
225
+ energies = [
226
+ None if ee is None else (audio_root / ee).as_posix() for ee in energies
227
+ ]
228
+ energies = None if any(ee is None for ee in energies) else energies
229
+
230
+ return TextToSpeechDataset(
231
+ split_name,
232
+ is_train_split,
233
+ cfg,
234
+ audio_paths,
235
+ n_frames,
236
+ src_texts,
237
+ tgt_texts,
238
+ speakers,
239
+ src_langs,
240
+ tgt_langs,
241
+ ids,
242
+ tgt_dict,
243
+ pre_tokenizer,
244
+ bpe_tokenizer,
245
+ n_frames_per_step,
246
+ speaker_to_id,
247
+ durations,
248
+ pitches,
249
+ energies,
250
+ )
modules/voice_conversion/fairseq/data/audio/waveform_transforms/__init__.py ADDED
@@ -0,0 +1,48 @@
1
+ import os
2
+ from fairseq.data.audio import (
3
+ AudioTransform,
4
+ CompositeAudioTransform,
5
+ import_transforms,
6
+ register_audio_transform,
7
+ )
8
+
9
+
10
+ class AudioWaveformTransform(AudioTransform):
11
+ pass
12
+
13
+
14
+ AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {}
15
+ AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set()
16
+
17
+
18
+ def get_audio_waveform_transform(name):
19
+ return AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name]
20
+
21
+
22
+ def register_audio_waveform_transform(name):
23
+ return register_audio_transform(
24
+ name,
25
+ AudioWaveformTransform,
26
+ AUDIO_WAVEFORM_TRANSFORM_REGISTRY,
27
+ AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES,
28
+ )
29
+
30
+
31
+ import_transforms(os.path.dirname(__file__), "waveform")
32
+
33
+
34
+ class CompositeAudioWaveformTransform(CompositeAudioTransform):
35
+ @classmethod
36
+ def from_config_dict(cls, config=None):
37
+ return super()._from_config_dict(
38
+ cls,
39
+ "waveform",
40
+ get_audio_waveform_transform,
41
+ CompositeAudioWaveformTransform,
42
+ config,
43
+ )
44
+
45
+ def __call__(self, x, sample_rate):
46
+ for t in self.transforms:
47
+ x, sample_rate = t(x, sample_rate)
48
+ return x, sample_rate
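
The registry above makes it straightforward to plug in additional waveform transforms. A minimal sketch of a custom transform follows; the "gainaugment" name and the scaling behaviour are illustrative, not part of fairseq.

    import numpy as np

    from fairseq.data.audio.waveform_transforms import (
        AudioWaveformTransform,
        register_audio_waveform_transform,
    )


    @register_audio_waveform_transform("gainaugment")  # hypothetical transform name
    class GainAugmentTransform(AudioWaveformTransform):
        @classmethod
        def from_config_dict(cls, config=None):
            _config = {} if config is None else config
            return cls(_config.get("gain", 0.5))

        def __init__(self, gain: float = 0.5):
            self.gain = gain

        def __call__(self, x, sample_rate):
            # scale the waveform; the sample rate is returned unchanged
            return x * self.gain, sample_rate


    transform = GainAugmentTransform.from_config_dict({"gain": 0.25})
    quieter, sr = transform(np.ones(16000, dtype=np.float32), 16000)
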
modules/voice_conversion/fairseq/data/audio/waveform_transforms/noiseaugment.py ADDED
@@ -0,0 +1,201 @@
1
+ from pathlib import Path
2
+ import numpy as np
3
+ from math import ceil
4
+
5
+ from fairseq.data.audio import rand_uniform
6
+ from fairseq.data.audio.waveform_transforms import (
7
+ AudioWaveformTransform,
8
+ register_audio_waveform_transform,
9
+ )
10
+
11
+ SNR_MIN = 5.0
12
+ SNR_MAX = 15.0
13
+ RATE = 0.25
14
+
15
+ NOISE_RATE = 1.0
16
+ NOISE_LEN_MEAN = 0.2
17
+ NOISE_LEN_STD = 0.05
18
+
19
+
20
+ class NoiseAugmentTransform(AudioWaveformTransform):
21
+ @classmethod
22
+ def from_config_dict(cls, config=None):
23
+ _config = {} if config is None else config
24
+ return cls(
25
+ _config.get("samples_path", None),
26
+ _config.get("snr_min", SNR_MIN),
27
+ _config.get("snr_max", SNR_MAX),
28
+ _config.get("rate", RATE),
29
+ )
30
+
31
+ def __init__(
32
+ self,
33
+ samples_path: str,
34
+ snr_min: float = SNR_MIN,
35
+ snr_max: float = SNR_MAX,
36
+ rate: float = RATE,
37
+ ):
38
+ # Sanity checks
39
+ assert (
40
+ samples_path
41
+ ), "need to provide path to audio samples for noise augmentation"
42
+ assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})"
43
+ assert rate >= 0 and rate <= 1, "rate should be a float between 0 to 1"
44
+
45
+ self.paths = list(Path(samples_path).glob("**/*.wav")) # load music
46
+ self.n_samples = len(self.paths)
47
+ assert self.n_samples > 0, f"no audio files found in {samples_path}"
48
+
49
+ self.snr_min = snr_min
50
+ self.snr_max = snr_max
51
+ self.rate = rate
52
+
53
+ def __repr__(self):
54
+ return (
55
+ self.__class__.__name__
56
+ + "("
57
+ + ", ".join(
58
+ [
59
+ f"n_samples={self.n_samples}",
60
+ f"snr={self.snr_min}-{self.snr_max}dB",
61
+ f"rate={self.rate}",
62
+ ]
63
+ )
64
+ + ")"
65
+ )
66
+
67
+ def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None):
68
+ from fairseq.data.audio.audio_utils import get_waveform
69
+
70
+ path = self.paths[np.random.randint(0, self.n_samples)]
71
+ sample = get_waveform(
72
+ path, always_2d=always_2d, output_sample_rate=use_sample_rate
73
+ )[0]
74
+
75
+ # Check dimensions match, else silently skip adding noise to sample
76
+ # NOTE: SHOULD THIS QUIT WITH AN ERROR?
77
+ is_2d = len(goal_shape) == 2
78
+ if len(goal_shape) != sample.ndim or (
79
+ is_2d and goal_shape[0] != sample.shape[0]
80
+ ):
81
+ return np.zeros(goal_shape)
82
+
83
+ # Cut/repeat sample to size
84
+ len_dim = len(goal_shape) - 1
85
+ n_repeat = ceil(goal_shape[len_dim] / sample.shape[len_dim])
86
+ repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat)
87
+ start = np.random.randint(0, repeated.shape[len_dim] - goal_shape[len_dim] + 1)
88
+ return (
89
+ repeated[:, start : start + goal_shape[len_dim]]
90
+ if is_2d
91
+ else repeated[start : start + goal_shape[len_dim]]
92
+ )
93
+
94
+ def _mix(self, source, noise, snr):
95
+ get_power = lambda x: np.mean(x**2)
96
+ if get_power(noise):
97
+ scl = np.sqrt(
98
+ get_power(source) / (np.power(10, snr / 10) * get_power(noise))
99
+ )
100
+ else:
101
+ scl = 0
102
+ return 1 * source + scl * noise
103
+
104
+ def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
105
+ return self.pick_sample(goal_shape, always_2d, use_sample_rate)
106
+
107
+ def __call__(self, source, sample_rate):
108
+ if np.random.random() > self.rate:
109
+ return source, sample_rate
110
+
111
+ noise = self._get_noise(
112
+ source.shape, always_2d=True, use_sample_rate=sample_rate
113
+ )
114
+
115
+ return (
116
+ self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)),
117
+ sample_rate,
118
+ )
119
+
120
+
121
+ @register_audio_waveform_transform("musicaugment")
122
+ class MusicAugmentTransform(NoiseAugmentTransform):
123
+ pass
124
+
125
+
126
+ @register_audio_waveform_transform("backgroundnoiseaugment")
127
+ class BackgroundNoiseAugmentTransform(NoiseAugmentTransform):
128
+ pass
129
+
130
+
131
+ @register_audio_waveform_transform("babbleaugment")
132
+ class BabbleAugmentTransform(NoiseAugmentTransform):
133
+ def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
134
+ for i in range(np.random.randint(3, 8)):
135
+ speech = self.pick_sample(goal_shape, always_2d, use_sample_rate)
136
+ if i == 0:
137
+ agg_noise = speech
138
+ else: # SNR scaled by i (how many noise signals already in agg_noise)
139
+ agg_noise = self._mix(agg_noise, speech, i)
140
+ return agg_noise
141
+
142
+
143
+ @register_audio_waveform_transform("sporadicnoiseaugment")
144
+ class SporadicNoiseAugmentTransform(NoiseAugmentTransform):
145
+ @classmethod
146
+ def from_config_dict(cls, config=None):
147
+ _config = {} if config is None else config
148
+ return cls(
149
+ _config.get("samples_path", None),
150
+ _config.get("snr_min", SNR_MIN),
151
+ _config.get("snr_max", SNR_MAX),
152
+ _config.get("rate", RATE),
153
+ _config.get("noise_rate", NOISE_RATE),
154
+ _config.get("noise_len_mean", NOISE_LEN_MEAN),
155
+ _config.get("noise_len_std", NOISE_LEN_STD),
156
+ )
157
+
158
+ def __init__(
159
+ self,
160
+ samples_path: str,
161
+ snr_min: float = SNR_MIN,
162
+ snr_max: float = SNR_MAX,
163
+ rate: float = RATE,
164
+ noise_rate: float = NOISE_RATE, # noises per second
165
+ noise_len_mean: float = NOISE_LEN_MEAN, # length of noises in seconds
166
+ noise_len_std: float = NOISE_LEN_STD,
167
+ ):
168
+ super().__init__(samples_path, snr_min, snr_max, rate)
169
+ self.noise_rate = noise_rate
170
+ self.noise_len_mean = noise_len_mean
171
+ self.noise_len_std = noise_len_std
172
+
173
+ def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
174
+ agg_noise = np.zeros(goal_shape)
175
+ len_dim = len(goal_shape) - 1
176
+ is_2d = len(goal_shape) == 2
177
+
178
+ n_noises = round(self.noise_rate * goal_shape[len_dim] / use_sample_rate)
179
+ start_pointers = [
180
+ round(rand_uniform(0, goal_shape[len_dim])) for _ in range(n_noises)
181
+ ]
182
+
183
+ for start_pointer in start_pointers:
184
+ noise_shape = list(goal_shape)
185
+ len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std)
186
+ noise_shape[len_dim] = round(max(0, len_seconds) * use_sample_rate)
187
+ end_pointer = start_pointer + noise_shape[len_dim]
188
+ if end_pointer >= goal_shape[len_dim]:
189
+ continue
190
+
191
+ noise = self.pick_sample(noise_shape, always_2d, use_sample_rate)
192
+ if is_2d:
193
+ agg_noise[:, start_pointer:end_pointer] = (
194
+ agg_noise[:, start_pointer:end_pointer] + noise
195
+ )
196
+ else:
197
+ agg_noise[start_pointer:end_pointer] = (
198
+ agg_noise[start_pointer:end_pointer] + noise
199
+ )
200
+
201
+ return agg_noise
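
A hedged example of building one of the registered transforms directly from a config dictionary; the samples_path value is a placeholder and must point at an existing directory of .wav noise clips for the constructor's sanity checks to pass.

    import numpy as np

    from fairseq.data.audio.waveform_transforms.noiseaugment import MusicAugmentTransform

    # "samples_path" is a placeholder for a real directory of .wav files
    transform = MusicAugmentTransform.from_config_dict(
        {"samples_path": "/data/noise/music", "snr_min": 5.0, "snr_max": 20.0, "rate": 0.5}
    )
    waveform = np.zeros((1, 16000), dtype=np.float32)  # 1 s of silence at 16 kHz
    augmented, sr = transform(waveform, 16000)
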
modules/voice_conversion/fairseq/data/backtranslation_dataset.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from fairseq import utils
8
+
9
+ from . import FairseqDataset
10
+
11
+
12
+ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
13
+ """Backtranslate a list of samples.
14
+
15
+ Given an input (*samples*) of the form:
16
+
17
+ [{'id': 1, 'source': 'hallo welt'}]
18
+
19
+ this will return:
20
+
21
+ [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]
22
+
23
+ Args:
24
+ samples (List[dict]): samples to backtranslate. Individual samples are
25
+ expected to have a 'source' key, which will become the 'target'
26
+ after backtranslation.
27
+ collate_fn (callable): function to collate samples into a mini-batch
28
+ generate_fn (callable): function to generate backtranslations
29
+ cuda (bool): use GPU for generation (default: ``True``)
30
+
31
+ Returns:
32
+ List[dict]: an updated list of samples with a backtranslated source
33
+ """
34
+ collated_samples = collate_fn(samples)
35
+ s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
36
+ generated_sources = generate_fn(s)
37
+
38
+ id_to_src = {sample["id"]: sample["source"] for sample in samples}
39
+
40
+ # Go through each tgt sentence in batch and its corresponding best
41
+ # generated hypothesis and create a backtranslation data pair
42
+ # {id: id, source: generated backtranslation, target: original tgt}
43
+ return [
44
+ {
45
+ "id": id.item(),
46
+ "target": id_to_src[id.item()],
47
+ "source": hypos[0]["tokens"].cpu(),
48
+ }
49
+ for id, hypos in zip(collated_samples["id"], generated_sources)
50
+ ]
51
+
52
+
53
+ class BacktranslationDataset(FairseqDataset):
54
+ """
55
+ Sets up a backtranslation dataset which takes a tgt batch, generates
56
+ a src using a tgt-src backtranslation function (*backtranslation_fn*),
57
+ and returns the corresponding `{generated src, input tgt}` batch.
58
+
59
+ Args:
60
+ tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
61
+ backtranslated. Only the source side of this dataset will be used.
62
+ After backtranslation, the source sentences in this dataset will be
63
+ returned as the targets.
64
+ src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
65
+ sentences.
66
+ tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
67
+ sentences to be backtranslated.
68
+ backtranslation_fn (callable, optional): function to call to generate
69
+ backtranslations. This is typically the `generate` method of a
70
+ :class:`~fairseq.sequence_generator.SequenceGenerator` object.
71
+ Pass in None when it is not available at initialization time, and
72
+ use set_backtranslation_fn function to set it when available.
73
+ output_collater (callable, optional): function to call on the
74
+ backtranslated samples to create the final batch
75
+ (default: ``tgt_dataset.collater``).
76
+ cuda: use GPU for generation
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ tgt_dataset,
82
+ src_dict,
83
+ tgt_dict=None,
84
+ backtranslation_fn=None,
85
+ output_collater=None,
86
+ cuda=True,
87
+ **kwargs
88
+ ):
89
+ self.tgt_dataset = tgt_dataset
90
+ self.backtranslation_fn = backtranslation_fn
91
+ self.output_collater = (
92
+ output_collater if output_collater is not None else tgt_dataset.collater
93
+ )
94
+ self.cuda = cuda if torch.cuda.is_available() else False
95
+ self.src_dict = src_dict
96
+ self.tgt_dict = tgt_dict
97
+
98
+ def __getitem__(self, index):
99
+ """
100
+ Returns a single sample from *tgt_dataset*. Note that backtranslation is
101
+ not applied in this step; use :func:`collater` instead to backtranslate
102
+ a batch of samples.
103
+ """
104
+ return self.tgt_dataset[index]
105
+
106
+ def __len__(self):
107
+ return len(self.tgt_dataset)
108
+
109
+ def set_backtranslation_fn(self, backtranslation_fn):
110
+ self.backtranslation_fn = backtranslation_fn
111
+
112
+ def collater(self, samples):
113
+ """Merge and backtranslate a list of samples to form a mini-batch.
114
+
115
+ Using the samples from *tgt_dataset*, load a collated target sample to
116
+ feed to the backtranslation model. Then take the backtranslation with
117
+ the best score as the source and the original input as the target.
118
+
119
+ Note: we expect *tgt_dataset* to provide a function `collater()` that
120
+ will collate samples into the format expected by *backtranslation_fn*.
121
+ After backtranslation, we will feed the new list of samples (i.e., the
122
+ `(backtranslated source, original source)` pairs) to *output_collater*
123
+ and return the result.
124
+
125
+ Args:
126
+ samples (List[dict]): samples to backtranslate and collate
127
+
128
+ Returns:
129
+ dict: a mini-batch with keys coming from *output_collater*
130
+ """
131
+ if samples[0].get("is_dummy", False):
132
+ return samples
133
+ samples = backtranslate_samples(
134
+ samples=samples,
135
+ collate_fn=self.tgt_dataset.collater,
136
+ generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
137
+ cuda=self.cuda,
138
+ )
139
+ return self.output_collater(samples)
140
+
141
+ def num_tokens(self, index):
142
+ """Just use the tgt dataset num_tokens"""
143
+ return self.tgt_dataset.num_tokens(index)
144
+
145
+ def ordered_indices(self):
146
+ """Just use the tgt dataset ordered_indices"""
147
+ return self.tgt_dataset.ordered_indices()
148
+
149
+ def size(self, index):
150
+ """Return an example's size as a float or tuple. This value is used
151
+ when filtering a dataset with ``--max-positions``.
152
+
153
+ Note: we use *tgt_dataset* to approximate the length of the source
154
+ sentence, since we do not know the actual length until after
155
+ backtranslation.
156
+ """
157
+ tgt_size = self.tgt_dataset.size(index)[0]
158
+ return (tgt_size, tgt_size)
159
+
160
+ @property
161
+ def supports_prefetch(self):
162
+ return getattr(self.tgt_dataset, "supports_prefetch", False)
163
+
164
+ def prefetch(self, indices):
165
+ return self.tgt_dataset.prefetch(indices)
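
A rough wiring sketch for the dataset above; tgt_dataset, the dictionaries, sequence_generator, and backward_model stand in for objects that a real training setup builds elsewhere.

    from fairseq.data.backtranslation_dataset import BacktranslationDataset

    # tgt_dataset, src_dict, tgt_dict, sequence_generator and backward_model are
    # placeholders for objects constructed by the surrounding training code
    bt_dataset = BacktranslationDataset(
        tgt_dataset=tgt_dataset,
        src_dict=src_dict,
        tgt_dict=tgt_dict,
        backtranslation_fn=None,  # the backward model may not exist yet at build time
    )

    # once the backward model is ready, plug in its generator
    bt_dataset.set_backtranslation_fn(
        lambda sample: sequence_generator.generate([backward_model], sample)
    )

    batch = bt_dataset.collater([bt_dataset[i] for i in range(8)])
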
modules/voice_conversion/fairseq/data/base_wrapper_dataset.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch.utils.data.dataloader import default_collate
7
+
8
+ from . import FairseqDataset
9
+
10
+
11
+ class BaseWrapperDataset(FairseqDataset):
12
+ def __init__(self, dataset):
13
+ super().__init__()
14
+ self.dataset = dataset
15
+
16
+ def __getitem__(self, index):
17
+ return self.dataset[index]
18
+
19
+ def __len__(self):
20
+ return len(self.dataset)
21
+
22
+ def collater(self, samples):
23
+ if hasattr(self.dataset, "collater"):
24
+ return self.dataset.collater(samples)
25
+ else:
26
+ return default_collate(samples)
27
+
28
+ @property
29
+ def sizes(self):
30
+ return self.dataset.sizes
31
+
32
+ def num_tokens(self, index):
33
+ return self.dataset.num_tokens(index)
34
+
35
+ def size(self, index):
36
+ return self.dataset.size(index)
37
+
38
+ def ordered_indices(self):
39
+ return self.dataset.ordered_indices()
40
+
41
+ @property
42
+ def supports_prefetch(self):
43
+ return getattr(self.dataset, "supports_prefetch", False)
44
+
45
+ def attr(self, attr: str, index: int):
46
+ return self.dataset.attr(attr, index)
47
+
48
+ def prefetch(self, indices):
49
+ self.dataset.prefetch(indices)
50
+
51
+ def get_batch_shapes(self):
52
+ return self.dataset.get_batch_shapes()
53
+
54
+ def batch_by_size(
55
+ self,
56
+ indices,
57
+ max_tokens=None,
58
+ max_sentences=None,
59
+ required_batch_size_multiple=1,
60
+ ):
61
+ return self.dataset.batch_by_size(
62
+ indices,
63
+ max_tokens=max_tokens,
64
+ max_sentences=max_sentences,
65
+ required_batch_size_multiple=required_batch_size_multiple,
66
+ )
67
+
68
+ def filter_indices_by_size(self, indices, max_sizes):
69
+ return self.dataset.filter_indices_by_size(indices, max_sizes)
70
+
71
+ @property
72
+ def can_reuse_epoch_itr_across_epochs(self):
73
+ return self.dataset.can_reuse_epoch_itr_across_epochs
74
+
75
+ def set_epoch(self, epoch):
76
+ super().set_epoch(epoch)
77
+ if hasattr(self.dataset, "set_epoch"):
78
+ self.dataset.set_epoch(epoch)
modules/voice_conversion/fairseq/data/bucket_pad_length_dataset.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch.nn.functional as F
8
+ from fairseq.data import BaseWrapperDataset
9
+ from fairseq.data.data_utils import get_buckets, get_bucketed_sizes
10
+
11
+
12
+ class BucketPadLengthDataset(BaseWrapperDataset):
13
+ """
14
+ Bucket and pad item lengths to the nearest bucket size. This can be used to
15
+ reduce the number of unique batch shapes, which is important on TPUs since
16
+ each new batch shape requires a recompilation.
17
+
18
+ Args:
19
+ dataset (FairseqDatset): dataset to bucket
20
+ sizes (List[int]): all item sizes
21
+ num_buckets (int): number of buckets to create
22
+ pad_idx (int): padding symbol
23
+ left_pad (bool): if True, pad on the left; otherwise right pad
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dataset,
29
+ sizes,
30
+ num_buckets,
31
+ pad_idx,
32
+ left_pad,
33
+ tensor_key=None,
34
+ ):
35
+ super().__init__(dataset)
36
+ self.pad_idx = pad_idx
37
+ self.left_pad = left_pad
38
+
39
+ assert num_buckets > 0
40
+ self.buckets = get_buckets(sizes, num_buckets)
41
+ self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
42
+ self._tensor_key = tensor_key
43
+
44
+ def _set_tensor(self, item, val):
45
+ if self._tensor_key is None:
46
+ return val
47
+ item[self._tensor_key] = val
48
+ return item
49
+
50
+ def _get_tensor(self, item):
51
+ if self._tensor_key is None:
52
+ return item
53
+ return item[self._tensor_key]
54
+
55
+ def _pad(self, tensor, bucket_size, dim=-1):
56
+ num_pad = bucket_size - tensor.size(dim)
57
+ return F.pad(
58
+ tensor,
59
+ (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
60
+ value=self.pad_idx,
61
+ )
62
+
63
+ def __getitem__(self, index):
64
+ item = self.dataset[index]
65
+ bucket_size = self._bucketed_sizes[index]
66
+ tensor = self._get_tensor(item)
67
+ padded = self._pad(tensor, bucket_size)
68
+ return self._set_tensor(item, padded)
69
+
70
+ @property
71
+ def sizes(self):
72
+ return self._bucketed_sizes
73
+
74
+ def num_tokens(self, index):
75
+ return self._bucketed_sizes[index]
76
+
77
+ def size(self, index):
78
+ return self._bucketed_sizes[index]
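
A small self-contained sketch of the bucketing behaviour, wrapping a toy in-memory dataset so that item lengths snap up to one of two bucket sizes; the ToyDataset class exists only for this illustration.

    import torch

    from fairseq.data import FairseqDataset
    from fairseq.data.bucket_pad_length_dataset import BucketPadLengthDataset


    class ToyDataset(FairseqDataset):
        """Minimal in-memory dataset used only for this illustration."""

        def __init__(self, items):
            self.items = items

        def __getitem__(self, index):
            return self.items[index]

        def __len__(self):
            return len(self.items)


    items = [torch.full((n,), 7, dtype=torch.long) for n in (3, 5, 9, 17)]
    sizes = [len(t) for t in items]
    bucketed = BucketPadLengthDataset(
        ToyDataset(items), sizes=sizes, num_buckets=2, pad_idx=1, left_pad=False
    )
    # every item is right-padded with pad_idx up to its bucket size
    print([bucketed[i].size(0) for i in range(len(bucketed))])
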
modules/voice_conversion/fairseq/data/codedataset.py ADDED
@@ -0,0 +1,576 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ import random
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.utils.data
16
+
17
+ from . import data_utils
18
+ from fairseq.data.fairseq_dataset import FairseqDataset
19
+
20
+ F0_FRAME_SPACE = 0.005 # sec
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class ExpressiveCodeDataConfig(object):
27
+ def __init__(self, json_path):
28
+ with open(json_path, "r") as f:
29
+ self.config = json.load(f)
30
+ self._manifests = self.config["manifests"]
31
+
32
+ @property
33
+ def manifests(self):
34
+ return self._manifests
35
+
36
+ @property
37
+ def n_units(self):
38
+ return self.config["n_units"]
39
+
40
+ @property
41
+ def sampling_rate(self):
42
+ return self.config["sampling_rate"]
43
+
44
+ @property
45
+ def code_hop_size(self):
46
+ return self.config["code_hop_size"]
47
+
48
+ @property
49
+ def f0_stats(self):
50
+ """pre-computed f0 statistics path"""
51
+ return self.config.get("f0_stats", None)
52
+
53
+ @property
54
+ def f0_vq_type(self):
55
+ """naive or precomp"""
56
+ return self.config["f0_vq_type"]
57
+
58
+ @property
59
+ def f0_vq_name(self):
60
+ return self.config["f0_vq_name"]
61
+
62
+ def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
63
+ key = "log" if log else "linear"
64
+ if norm_mean and norm_std:
65
+ key += "_mean_std_norm"
66
+ elif norm_mean:
67
+ key += "_mean_norm"
68
+ else:
69
+ key += "_none_norm"
70
+ return self.config["f0_vq_naive_quantizer"][key]
71
+
72
+ @property
73
+ def f0_vq_n_units(self):
74
+ return self.config["f0_vq_n_units"]
75
+
76
+ @property
77
+ def multispkr(self):
78
+ """how to parse speaker label from audio path"""
79
+ return self.config.get("multispkr", None)
80
+
81
+
82
+ def get_f0(audio, rate=16000):
83
+ try:
84
+ import amfm_decompy.basic_tools as basic
85
+ import amfm_decompy.pYAAPT as pYAAPT
86
+ from librosa.util import normalize
87
+ except ImportError:
88
+ raise ImportError("Please install amfm_decompy (`pip install AMFM-decompy`) and librosa (`pip install librosa`).")
89
+
90
+ assert audio.ndim == 1
91
+ frame_length = 20.0 # ms
92
+ to_pad = int(frame_length / 1000 * rate) // 2
93
+
94
+ audio = normalize(audio) * 0.95
95
+ audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
96
+ audio = basic.SignalObj(audio, rate)
97
+ pitch = pYAAPT.yaapt(
98
+ audio,
99
+ frame_length=frame_length,
100
+ frame_space=F0_FRAME_SPACE * 1000,
101
+ nccf_thresh1=0.25,
102
+ tda_frame_length=25.0,
103
+ )
104
+ f0 = pitch.samp_values
105
+ return f0
106
+
107
+
108
+ def interpolate_f0(f0):
109
+ try:
110
+ from scipy.interpolate import interp1d
111
+ except ImportError:
112
+ raise ImportError("Please install scipy (`pip install scipy`)")
113
+
114
+ orig_t = np.arange(f0.shape[0])
115
+ f0_interp = f0[:]
116
+ ii = f0_interp != 0
117
+ if ii.sum() > 1:
118
+ f0_interp = interp1d(
119
+ orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
120
+ )(orig_t)
121
+ f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
122
+ return f0_interp
123
+
124
+
125
+ def naive_quantize(x, edges):
126
+ bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
127
+ return bin_idx
128
+
129
+
130
+ def load_wav(full_path):
131
+ try:
132
+ import soundfile as sf
133
+ except ImportError:
134
+ raise ImportError("Please install soundfile (`pip install SoundFile`)")
135
+ data, sampling_rate = sf.read(full_path)
136
+ return data, sampling_rate
137
+
138
+
139
+ def parse_code(code_str, dictionary, append_eos):
140
+ code, duration = torch.unique_consecutive(
141
+ torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
142
+ )
143
+ code = " ".join(map(str, code.tolist()))
144
+ code = dictionary.encode_line(code, append_eos).short()
145
+
146
+ if append_eos:
147
+ duration = torch.cat((duration, duration.new_zeros((1,))), dim=0) # eos
148
+ duration = duration.short()
149
+ return code, duration
150
+
151
+
152
+ def parse_manifest(manifest, dictionary):
153
+ audio_files = []
154
+ codes = []
155
+ durations = []
156
+ speakers = []
157
+
158
+ with open(manifest) as info:
159
+ for line in info.readlines():
160
+ sample = eval(line.strip())
161
+ if "cpc_km100" in sample:
162
+ k = "cpc_km100"
163
+ elif "hubert_km100" in sample:
164
+ k = "hubert_km100"
165
+ elif "phone" in sample:
166
+ k = "phone"
167
+ else:
168
+ assert False, "unknown format"
169
+ code = sample[k]
170
+ code, duration = parse_code(code, dictionary, append_eos=True)
171
+
172
+ codes.append(code)
173
+ durations.append(duration)
174
+ audio_files.append(sample["audio"])
175
+ speakers.append(sample.get("speaker", None))
176
+
177
+ return audio_files, codes, durations, speakers
178
+
179
+
180
+ def parse_speaker(path, method):
181
+ if type(path) == str:
182
+ path = Path(path)
183
+
184
+ if method == "parent_name":
185
+ return path.parent.name
186
+ elif method == "parent_parent_name":
187
+ return path.parent.parent.name
188
+ elif method == "_":
189
+ return path.name.split("_")[0]
190
+ elif method == "single":
191
+ return "A"
192
+ elif callable(method):
193
+ return method(path)
194
+ else:
195
+ raise NotImplementedError()
196
+
197
+
198
+ def get_f0_by_filename(filename, tgt_sampling_rate):
199
+ audio, sampling_rate = load_wav(filename)
200
+ if sampling_rate != tgt_sampling_rate:
201
+ raise ValueError(
202
+ "{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
203
+ )
204
+
205
+ # compute un-interpolated f0, and use Ann's interp in __getitem__ if set
206
+ f0 = get_f0(audio, rate=tgt_sampling_rate)
207
+ f0 = torch.from_numpy(f0.astype(np.float32))
208
+ return f0
209
+
210
+
211
+ def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
212
+ code_len = durations.sum()
213
+ targ_len = int(f0_code_ratio * code_len)
214
+ diff = f0.size(0) - targ_len
215
+ assert abs(diff) <= tol, (
216
+ f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
217
+ f" > {tol} (dur=\n{durations})"
218
+ )
219
+ if diff > 0:
220
+ f0 = f0[:targ_len]
221
+ elif diff < 0:
222
+ f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0)
223
+
224
+ f0_offset = 0.0
225
+ seg_f0s = []
226
+ for dur in durations:
227
+ f0_dur = dur.item() * f0_code_ratio
228
+ seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)]
229
+ seg_f0 = seg_f0[seg_f0 != 0]
230
+ if len(seg_f0) == 0:
231
+ seg_f0 = torch.tensor(0).type(seg_f0.type())
232
+ else:
233
+ seg_f0 = seg_f0.mean()
234
+ seg_f0s.append(seg_f0)
235
+ f0_offset += f0_dur
236
+
237
+ assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
238
+ return torch.tensor(seg_f0s)
239
+
240
+
241
+ class Paddings(object):
242
+ def __init__(self, code_val, dur_val=0, f0_val=-2.0):
243
+ self.code = code_val
244
+ self.dur = dur_val
245
+ self.f0 = f0_val
246
+
247
+
248
+ class Shifts(object):
249
+ def __init__(self, shifts_str, pads):
250
+ self._shifts = list(map(int, shifts_str.split(",")))
251
+ assert len(self._shifts) == 2, self._shifts
252
+ assert all(s >= 0 for s in self._shifts)
253
+ self.extra_length = max(s for s in self._shifts)
254
+ self.pads = pads
255
+
256
+ @property
257
+ def dur(self):
258
+ return self._shifts[0]
259
+
260
+ @property
261
+ def f0(self):
262
+ return self._shifts[1]
263
+
264
+ @staticmethod
265
+ def shift_one(seq, left_pad_num, right_pad_num, pad):
266
+ assert seq.ndim == 1
267
+ bos = seq.new_full((left_pad_num,), pad)
268
+ eos = seq.new_full((right_pad_num,), pad)
269
+ seq = torch.cat([bos, seq, eos])
270
+ mask = torch.ones_like(seq).bool()
271
+ mask[left_pad_num : len(seq) - right_pad_num] = 0
272
+ return seq, mask
273
+
274
+ def __call__(self, code, dur, f0):
275
+ if self.extra_length == 0:
276
+ code_mask = torch.zeros_like(code).bool()
277
+ dur_mask = torch.zeros_like(dur).bool()
278
+ f0_mask = torch.zeros_like(f0).bool()
279
+ return code, code_mask, dur, dur_mask, f0, f0_mask
280
+
281
+ code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code)
282
+ dur, dur_mask = self.shift_one(
283
+ dur, self.dur, self.extra_length - self.dur, self.pads.dur
284
+ )
285
+ f0, f0_mask = self.shift_one(
286
+ f0, self.f0, self.extra_length - self.f0, self.pads.f0
287
+ )
288
+ return code, code_mask, dur, dur_mask, f0, f0_mask
289
+
290
+
291
+ class CodeDataset(FairseqDataset):
292
+ def __init__(
293
+ self,
294
+ manifest,
295
+ dictionary,
296
+ dur_dictionary,
297
+ f0_dictionary,
298
+ config,
299
+ discrete_dur,
300
+ discrete_f0,
301
+ log_f0,
302
+ normalize_f0_mean,
303
+ normalize_f0_std,
304
+ interpolate_f0,
305
+ return_filename=False,
306
+ strip_filename=True,
307
+ shifts="0,0",
308
+ return_continuous_f0=False,
309
+ ):
310
+ random.seed(1234)
311
+ self.dictionary = dictionary
312
+ self.dur_dictionary = dur_dictionary
313
+ self.f0_dictionary = f0_dictionary
314
+ self.config = config
315
+
316
+ # duration config
317
+ self.discrete_dur = discrete_dur
318
+
319
+ # pitch config
320
+ self.discrete_f0 = discrete_f0
321
+ self.log_f0 = log_f0
322
+ self.normalize_f0_mean = normalize_f0_mean
323
+ self.normalize_f0_std = normalize_f0_std
324
+ self.interpolate_f0 = interpolate_f0
325
+
326
+ self.return_filename = return_filename
327
+ self.strip_filename = strip_filename
328
+ self.f0_code_ratio = config.code_hop_size / (
329
+ config.sampling_rate * F0_FRAME_SPACE
330
+ )
331
+
332
+ # use lazy loading to avoid sharing file handlers across workers
333
+ self.manifest = manifest
334
+ self._codes = None
335
+ self._durs = None
336
+ self._f0s = None
337
+ with open(f"{manifest}.leng.txt", "r") as f:
338
+ lengs = [int(line.rstrip()) for line in f]
339
+ edges = np.cumsum([0] + lengs)
340
+ self.starts, self.ends = edges[:-1], edges[1:]
341
+ with open(f"{manifest}.path.txt", "r") as f:
342
+ self.file_names = [line.rstrip() for line in f]
343
+ logger.info(f"num entries: {len(self.starts)}")
344
+
345
+ if os.path.exists(f"{manifest}.f0_stat.pt"):
346
+ self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
347
+ elif config.f0_stats:
348
+ self.f0_stats = torch.load(config.f0_stats)
349
+
350
+ self.multispkr = config.multispkr
351
+ if config.multispkr:
352
+ with open(f"{manifest}.speaker.txt", "r") as f:
353
+ self.spkrs = [line.rstrip() for line in f]
354
+ self.id_to_spkr = sorted(self.spkrs)
355
+ self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}
356
+
357
+ self.pads = Paddings(
358
+ dictionary.pad(),
359
+ 0, # use 0 for duration padding
360
+ f0_dictionary.pad() if discrete_f0 else -5.0,
361
+ )
362
+ self.shifts = Shifts(shifts, pads=self.pads)
363
+ self.return_continuous_f0 = return_continuous_f0
364
+
365
+ def get_data_handlers(self):
366
+ logging.info(f"loading data for {self.manifest}")
367
+ self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
368
+ self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")
369
+
370
+ if self.discrete_f0:
371
+ if self.config.f0_vq_type == "precomp":
372
+ self._f0s = np.load(
373
+ f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
374
+ )
375
+ elif self.config.f0_vq_type == "naive":
376
+ self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
377
+ quantizers_path = self.config.get_f0_vq_naive_quantizer(
378
+ self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
379
+ )
380
+ quantizers = torch.load(quantizers_path)
381
+ n_units = self.config.f0_vq_n_units
382
+ self._f0_quantizer = torch.from_numpy(quantizers[n_units])
383
+ else:
384
+ raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
385
+ else:
386
+ self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
387
+
388
+ def preprocess_f0(self, f0, stats):
389
+ """
390
+ 1. interpolate
391
+ 2. log transform (keep unvoiced frame 0)
392
+ """
393
+ # TODO: change this to be dependent on config for naive quantizer
394
+ f0 = f0.clone()
395
+ if self.interpolate_f0:
396
+ f0 = interpolate_f0(f0)
397
+
398
+ mask = f0 != 0 # only process voiced frames
399
+ if self.log_f0:
400
+ f0[mask] = f0[mask].log()
401
+ if self.normalize_f0_mean:
402
+ mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
403
+ f0[mask] = f0[mask] - mean
404
+ if self.normalize_f0_std:
405
+ std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
406
+ f0[mask] = f0[mask] / std
407
+ return f0
408
+
409
+ def _get_raw_item(self, index):
410
+ start, end = self.starts[index], self.ends[index]
411
+ if self._codes is None:
412
+ self.get_data_handlers()
413
+ code = torch.from_numpy(np.array(self._codes[start:end])).long()
414
+ dur = torch.from_numpy(np.array(self._durs[start:end]))
415
+ f0 = torch.from_numpy(np.array(self._f0s[start:end]))
416
+ return code, dur, f0
417
+
418
+ def __getitem__(self, index):
419
+ code, dur, f0 = self._get_raw_item(index)
420
+ code = torch.cat([code.new([self.dictionary.bos()]), code])
421
+
422
+ # use 0 for eos and bos
423
+ dur = torch.cat([dur.new([0]), dur])
424
+ if self.discrete_dur:
425
+ dur = self.dur_dictionary.encode_line(
426
+ " ".join(map(str, dur.tolist())), append_eos=False
427
+ ).long()
428
+ else:
429
+ dur = dur.float()
430
+
431
+ # TODO: find a more elegant approach
432
+ raw_f0 = None
433
+ if self.discrete_f0:
434
+ if self.config.f0_vq_type == "precomp":
435
+ f0 = self.f0_dictionary.encode_line(
436
+ " ".join(map(str, f0.tolist())), append_eos=False
437
+ ).long()
438
+ else:
439
+ f0 = f0.float()
440
+ f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
441
+ if self.return_continuous_f0:
442
+ raw_f0 = f0
443
+ raw_f0 = torch.cat([raw_f0.new([self.f0_dictionary.bos()]), raw_f0])
444
+ f0 = naive_quantize(f0, self._f0_quantizer)
445
+ f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
446
+ else:
447
+ f0 = f0.float()
448
+ if self.multispkr:
449
+ f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
450
+ else:
451
+ f0 = self.preprocess_f0(f0, self.f0_stats)
452
+ f0 = torch.cat([f0.new([0]), f0])
453
+
454
+ if raw_f0 is not None:
455
+ *_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
456
+ else:
457
+ raw_f0_mask = None
458
+
459
+ code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
460
+ if raw_f0_mask is not None:
461
+ assert (raw_f0_mask == f0_mask).all()
462
+
463
+ # is a padded frame if either input or output is padded
464
+ feats = {
465
+ "source": code[:-1],
466
+ "target": code[1:],
467
+ "mask": code_mask[1:].logical_or(code_mask[:-1]),
468
+ "dur_source": dur[:-1],
469
+ "dur_target": dur[1:],
470
+ "dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
471
+ "f0_source": f0[:-1],
472
+ "f0_target": f0[1:],
473
+ "f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
474
+ }
475
+
476
+ if raw_f0 is not None:
477
+ feats["raw_f0"] = raw_f0[1:]
478
+
479
+ if self.return_filename:
480
+ fname = self.file_names[index]
481
+ feats["filename"] = (
482
+ fname if not self.strip_filename else Path(fname).with_suffix("").name
483
+ )
484
+ return feats
485
+
486
+ def __len__(self):
487
+ return len(self.starts)
488
+
489
+ def size(self, index):
490
+ return self.ends[index] - self.starts[index] + self.shifts.extra_length
491
+
492
+ def num_tokens(self, index):
493
+ return self.size(index)
494
+
495
+ def collater(self, samples):
496
+ pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
497
+ if len(samples) == 0:
498
+ return {}
499
+
500
+ src_tokens = data_utils.collate_tokens(
501
+ [s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
502
+ )
503
+
504
+ tgt_tokens = data_utils.collate_tokens(
505
+ [s["target"] for s in samples],
506
+ pad_idx=pad_idx,
507
+ eos_idx=pad_idx, # appending padding, eos is there already
508
+ left_pad=False,
509
+ )
510
+
511
+ src_durs, tgt_durs = [
512
+ data_utils.collate_tokens(
513
+ [s[k] for s in samples],
514
+ pad_idx=self.pads.dur,
515
+ eos_idx=self.pads.dur,
516
+ left_pad=False,
517
+ )
518
+ for k in ["dur_source", "dur_target"]
519
+ ]
520
+
521
+ src_f0s, tgt_f0s = [
522
+ data_utils.collate_tokens(
523
+ [s[k] for s in samples],
524
+ pad_idx=self.pads.f0,
525
+ eos_idx=self.pads.f0,
526
+ left_pad=False,
527
+ )
528
+ for k in ["f0_source", "f0_target"]
529
+ ]
530
+
531
+ mask, dur_mask, f0_mask = [
532
+ data_utils.collate_tokens(
533
+ [s[k] for s in samples],
534
+ pad_idx=1,
535
+ eos_idx=1,
536
+ left_pad=False,
537
+ )
538
+ for k in ["mask", "dur_mask", "f0_mask"]
539
+ ]
540
+
541
+ src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
542
+ n_tokens = sum(len(s["source"]) for s in samples)
543
+
544
+ result = {
545
+ "nsentences": len(samples),
546
+ "ntokens": n_tokens,
547
+ "net_input": {
548
+ "src_tokens": src_tokens,
549
+ "src_lengths": src_lengths,
550
+ "dur_src": src_durs,
551
+ "f0_src": src_f0s,
552
+ },
553
+ "target": tgt_tokens,
554
+ "dur_target": tgt_durs,
555
+ "f0_target": tgt_f0s,
556
+ "mask": mask,
557
+ "dur_mask": dur_mask,
558
+ "f0_mask": f0_mask,
559
+ }
560
+
561
+ if "filename" in samples[0]:
562
+ result["filename"] = [s["filename"] for s in samples]
563
+
564
+ # TODO: remove this hack into the inference dataset
565
+ if "prefix" in samples[0]:
566
+ result["prefix"] = [s["prefix"] for s in samples]
567
+
568
+ if "raw_f0" in samples[0]:
569
+ raw_f0s = data_utils.collate_tokens(
570
+ [s["raw_f0"] for s in samples],
571
+ pad_idx=self.pads.f0,
572
+ eos_idx=self.pads.f0,
573
+ left_pad=False,
574
+ )
575
+ result["raw_f0"] = raw_f0s
576
+ return result
modules/voice_conversion/fairseq/data/colorize_dataset.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from . import BaseWrapperDataset
9
+
10
+
11
+ class ColorizeDataset(BaseWrapperDataset):
12
+ """Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""
13
+
14
+ def __init__(self, dataset, color_getter):
15
+ super().__init__(dataset)
16
+ self.color_getter = color_getter
17
+
18
+ def collater(self, samples):
19
+ base_collate = super().collater(samples)
20
+ if len(base_collate) > 0:
21
+ base_collate["net_input"]["colors"] = torch.tensor(
22
+ list(self.color_getter(self.dataset, s["id"]) for s in samples),
23
+ dtype=torch.long,
24
+ )
25
+ return base_collate
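ColorizeDataset only post-processes the wrapped dataset's collated batch. A hypothetical usage sketch, assuming `base_dataset` is a FairseqDataset whose samples carry an "id" field and whose collater returns a "net_input" dict; both of those, and the speaker mapping, are assumptions rather than anything shown in this diff:

# map each sample id to an integer "color" (for example, a speaker id); illustrative only
speaker_by_id = {0: 3, 1: 7}

colorized = ColorizeDataset(
    base_dataset,
    color_getter=lambda ds, sample_id: speaker_by_id[sample_id],
)
batch = colorized.collater([base_dataset[0], base_dataset[1]])
# batch["net_input"]["colors"] is now a LongTensor with one entry per sample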
modules/voice_conversion/fairseq/data/concat_dataset.py ADDED
@@ -0,0 +1,124 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import bisect
+
+ import numpy as np
+ from torch.utils.data.dataloader import default_collate
+
+ from . import FairseqDataset
+
+
+ class ConcatDataset(FairseqDataset):
+     @staticmethod
+     def cumsum(sequence, sample_ratios):
+         r, s = [], 0
+         for e, ratio in zip(sequence, sample_ratios):
+             curr_len = int(ratio * len(e))
+             r.append(curr_len + s)
+             s += curr_len
+         return r
+
+     def __init__(self, datasets, sample_ratios=1):
+         super(ConcatDataset, self).__init__()
+         assert len(datasets) > 0, "datasets should not be an empty iterable"
+         self.datasets = list(datasets)
+         if isinstance(sample_ratios, int):
+             sample_ratios = [sample_ratios] * len(self.datasets)
+         self.sample_ratios = sample_ratios
+         self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
+         self.real_sizes = [len(d) for d in self.datasets]
+
+     def __len__(self):
+         return self.cumulative_sizes[-1]
+
+     def __getitem__(self, idx):
+         dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
+         return self.datasets[dataset_idx][sample_idx]
+
+     def _get_dataset_and_sample_index(self, idx: int):
+         dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+         if dataset_idx == 0:
+             sample_idx = idx
+         else:
+             sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+         sample_idx = sample_idx % self.real_sizes[dataset_idx]
+         return dataset_idx, sample_idx
+
+     def collater(self, samples, **extra_args):
+         # For now only supports datasets with same underlying collater implementations
+         if hasattr(self.datasets[0], "collater"):
+             return self.datasets[0].collater(samples, **extra_args)
+         else:
+             return default_collate(samples, **extra_args)
+
+     def size(self, idx: int):
+         """
+         Return an example's size as a float or tuple.
+         """
+         dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
+         return self.datasets[dataset_idx].size(sample_idx)
+
+     def num_tokens(self, index: int):
+         return np.max(self.size(index))
+
+     def attr(self, attr: str, index: int):
+         dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
+         return getattr(self.datasets[dataset_idx], attr, None)
+
+     @property
+     def sizes(self):
+         _dataset_sizes = []
+         for ds, sr in zip(self.datasets, self.sample_ratios):
+             if isinstance(ds.sizes, np.ndarray):
+                 _dataset_sizes.append(np.tile(ds.sizes, sr))
+             else:
+                 # Only support underlying dataset with single size array.
+                 assert isinstance(ds.sizes, list)
+                 _dataset_sizes.append(np.tile(ds.sizes[0], sr))
+         return np.concatenate(_dataset_sizes)
+
+     @property
+     def supports_prefetch(self):
+         return all(d.supports_prefetch for d in self.datasets)
+
+     def ordered_indices(self):
+         """
+         Returns indices sorted by length. So less padding is needed.
+         """
+         if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
+             # special handling for concatenating lang_pair_datasets
+             indices = np.arange(len(self))
+             sizes = self.sizes
+             tgt_sizes = (
+                 sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
+             )
+             src_sizes = (
+                 sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
+             )
+             # sort by target length, then source length
+             if tgt_sizes is not None:
+                 indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
+             return indices[np.argsort(src_sizes[indices], kind="mergesort")]
+         else:
+             return np.argsort(self.sizes)
+
+     def prefetch(self, indices):
+         frm = 0
+         for to, ds in zip(self.cumulative_sizes, self.datasets):
+             real_size = len(ds)
+             if getattr(ds, "supports_prefetch", False):
+                 ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
+             frm = to
+
+     @property
+     def can_reuse_epoch_itr_across_epochs(self):
+         return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
+
+     def set_epoch(self, epoch):
+         super().set_epoch(epoch)
+         for ds in self.datasets:
+             if hasattr(ds, "set_epoch"):
+                 ds.set_epoch(epoch)
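The `sample_ratios` argument lets a smaller dataset be virtually upsampled before concatenation, with out-of-range indices wrapping around through the modulo in `_get_dataset_and_sample_index`. A hypothetical sketch under the assumption that `large_ds` and `small_ds` are existing FairseqDataset instances with compatible collaters (the names are illustrative):

# concatenate, repeating small_ds four times per epoch
concat = ConcatDataset([large_ds, small_ds], sample_ratios=[1, 4])
assert len(concat) == len(large_ds) + 4 * len(small_ds)

# length-sorted indices reduce padding when batching
indices = concat.ordered_indices()
batch = concat.collater([concat[i] for i in indices[:8]])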