Spaces:
Running
Running
TomatoCocotree
commited on
Commit
•
6a62ffb
1
Parent(s):
59656d8
上传
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Dockerfile +21 -0
- LICENSE +24 -0
- api_key.txt +1 -0
- constants.py +49 -0
- data/models/coqui/.placeholder +2 -0
- data/models/rvc/.placeholder +3 -0
- data/tmp/.placeholder +2 -0
- docker/Dockerfile +35 -0
- docker/docker-compose.yml +23 -0
- docker/readme.md +10 -0
- modules/classify/classify_module.py +41 -0
- modules/speech_recognition/streaming_module.py +121 -0
- modules/speech_recognition/vosk_module.py +77 -0
- modules/speech_recognition/whisper_module.py +56 -0
- modules/text_to_speech/coqui/coqui_module.py +333 -0
- modules/utils.py +15 -0
- modules/voice_conversion/fairseq/LICENSE +21 -0
- modules/voice_conversion/fairseq/__init__.py +45 -0
- modules/voice_conversion/fairseq/binarizer.py +381 -0
- modules/voice_conversion/fairseq/checkpoint_utils.py +905 -0
- modules/voice_conversion/fairseq/data/__init__.py +130 -0
- modules/voice_conversion/fairseq/data/add_target_dataset.py +83 -0
- modules/voice_conversion/fairseq/data/append_token_dataset.py +41 -0
- modules/voice_conversion/fairseq/data/audio/__init__.py +93 -0
- modules/voice_conversion/fairseq/data/audio/audio_utils.py +389 -0
- modules/voice_conversion/fairseq/data/audio/data_cfg.py +387 -0
- modules/voice_conversion/fairseq/data/audio/dataset_transforms/__init__.py +53 -0
- modules/voice_conversion/fairseq/data/audio/dataset_transforms/concataugment.py +61 -0
- modules/voice_conversion/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py +105 -0
- modules/voice_conversion/fairseq/data/audio/feature_transforms/__init__.py +43 -0
- modules/voice_conversion/fairseq/data/audio/feature_transforms/delta_deltas.py +37 -0
- modules/voice_conversion/fairseq/data/audio/feature_transforms/global_cmvn.py +29 -0
- modules/voice_conversion/fairseq/data/audio/feature_transforms/specaugment.py +131 -0
- modules/voice_conversion/fairseq/data/audio/feature_transforms/utterance_cmvn.py +41 -0
- modules/voice_conversion/fairseq/data/audio/frm_text_to_speech_dataset.py +205 -0
- modules/voice_conversion/fairseq/data/audio/hubert_dataset.py +356 -0
- modules/voice_conversion/fairseq/data/audio/multi_modality_dataset.py +284 -0
- modules/voice_conversion/fairseq/data/audio/raw_audio_dataset.py +393 -0
- modules/voice_conversion/fairseq/data/audio/speech_to_speech_dataset.py +379 -0
- modules/voice_conversion/fairseq/data/audio/speech_to_text_dataset.py +733 -0
- modules/voice_conversion/fairseq/data/audio/speech_to_text_joint_dataset.py +359 -0
- modules/voice_conversion/fairseq/data/audio/text_to_speech_dataset.py +250 -0
- modules/voice_conversion/fairseq/data/audio/waveform_transforms/__init__.py +48 -0
- modules/voice_conversion/fairseq/data/audio/waveform_transforms/noiseaugment.py +201 -0
- modules/voice_conversion/fairseq/data/backtranslation_dataset.py +165 -0
- modules/voice_conversion/fairseq/data/base_wrapper_dataset.py +78 -0
- modules/voice_conversion/fairseq/data/bucket_pad_length_dataset.py +78 -0
- modules/voice_conversion/fairseq/data/codedataset.py +576 -0
- modules/voice_conversion/fairseq/data/colorize_dataset.py +25 -0
- modules/voice_conversion/fairseq/data/concat_dataset.py +124 -0
Dockerfile
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.11
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
COPY requirements-complete.txt .
|
6 |
+
RUN pip install -r requirements-complete.txt
|
7 |
+
|
8 |
+
RUN mkdir /.cache && chmod -R 777 /.cache
|
9 |
+
RUN mkdir .chroma && chmod -R 777 .chroma
|
10 |
+
|
11 |
+
COPY . .
|
12 |
+
|
13 |
+
|
14 |
+
RUN chmod -R 777 /app
|
15 |
+
|
16 |
+
RUN --mount=type=secret,id=password,mode=0444,required=true \
|
17 |
+
cat /run/secrets/password > /test
|
18 |
+
|
19 |
+
EXPOSE 7860
|
20 |
+
|
21 |
+
CMD ["python", "server.py", "--cpu", "--enable-modules=caption,summarize,classify,silero-tts,edge-tts,chromadb"]
|
LICENSE
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
This is free and unencumbered software released into the public domain.
|
2 |
+
|
3 |
+
Anyone is free to copy, modify, publish, use, compile, sell, or
|
4 |
+
distribute this software, either in source code form or as a compiled
|
5 |
+
binary, for any purpose, commercial or non-commercial, and by any
|
6 |
+
means.
|
7 |
+
|
8 |
+
In jurisdictions that recognize copyright laws, the author or authors
|
9 |
+
of this software dedicate any and all copyright interest in the
|
10 |
+
software to the public domain. We make this dedication for the benefit
|
11 |
+
of the public at large and to the detriment of our heirs and
|
12 |
+
successors. We intend this dedication to be an overt act of
|
13 |
+
relinquishment in perpetuity of all present and future rights to this
|
14 |
+
software under copyright law.
|
15 |
+
|
16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
17 |
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
18 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
19 |
+
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
20 |
+
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
21 |
+
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
22 |
+
OTHER DEALINGS IN THE SOFTWARE.
|
23 |
+
|
24 |
+
For more information, please refer to <https://unlicense.org>
|
api_key.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
CHANGEME
|
constants.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Constants
|
2 |
+
DEFAULT_CUDA_DEVICE = "cuda:0"
|
3 |
+
# Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
|
4 |
+
DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
|
5 |
+
# Also try: 'joeddav/distilbert-base-uncased-go-emotions-student'
|
6 |
+
DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
|
7 |
+
# Also try: 'Salesforce/blip-image-captioning-base'
|
8 |
+
DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
|
9 |
+
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
|
10 |
+
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
11 |
+
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
|
12 |
+
DEFAULT_REMOTE_SD_PORT = 7860
|
13 |
+
DEFAULT_CHROMA_PORT = 8000
|
14 |
+
SILERO_SAMPLES_PATH = "tts_samples"
|
15 |
+
SILERO_SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog"
|
16 |
+
DEFAULT_SUMMARIZE_PARAMS = {
|
17 |
+
"temperature": 1.0,
|
18 |
+
"repetition_penalty": 1.0,
|
19 |
+
"max_length": 500,
|
20 |
+
"min_length": 200,
|
21 |
+
"length_penalty": 1.5,
|
22 |
+
"bad_words": [
|
23 |
+
"\n",
|
24 |
+
'"',
|
25 |
+
"*",
|
26 |
+
"[",
|
27 |
+
"]",
|
28 |
+
"{",
|
29 |
+
"}",
|
30 |
+
":",
|
31 |
+
"(",
|
32 |
+
")",
|
33 |
+
"<",
|
34 |
+
">",
|
35 |
+
"Â",
|
36 |
+
"The text ends",
|
37 |
+
"The story ends",
|
38 |
+
"The text is",
|
39 |
+
"The story is",
|
40 |
+
],
|
41 |
+
}
|
42 |
+
|
43 |
+
PROMPT_PREFIX = "best quality, absurdres, "
|
44 |
+
NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
|
45 |
+
error hands, bad hands, error fingers, bad fingers, missing fingers
|
46 |
+
error legs, bad legs, multiple legs, missing legs, error lighting,
|
47 |
+
error shadow, error reflection, text, error, extra digit, fewer digits,
|
48 |
+
cropped, worst quality, low quality, normal quality, jpeg artifacts,
|
49 |
+
signature, watermark, username, blurry"""
|
data/models/coqui/.placeholder
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
Put Coqui models folders here.
|
2 |
+
Must contains both a "model.pth" and "config.json" file.
|
data/models/rvc/.placeholder
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
Put RVC models folder here.
|
2 |
+
Must have ".pth" file in it
|
3 |
+
.index file is optional but could help improve the processing time/quality.
|
data/tmp/.placeholder
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
This is a temporary file folder.
|
2 |
+
May contain RVC input/output file for research purpose.
|
docker/Dockerfile
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
|
2 |
+
|
3 |
+
EXPOSE 5100
|
4 |
+
|
5 |
+
ENV PATH="/root/miniconda3/bin:${PATH}"
|
6 |
+
ARG PATH="/root/miniconda3/bin:${PATH}"
|
7 |
+
|
8 |
+
ENV DEBIAN_FRONTEND noninteractive
|
9 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
10 |
+
python3 python3-venv wget build-essential
|
11 |
+
|
12 |
+
RUN wget \
|
13 |
+
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
14 |
+
&& mkdir /root/.conda \
|
15 |
+
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
16 |
+
&& rm -f Miniconda3-latest-Linux-x86_64.sh
|
17 |
+
|
18 |
+
RUN conda --version
|
19 |
+
|
20 |
+
RUN conda init
|
21 |
+
|
22 |
+
RUN conda create -n extras
|
23 |
+
|
24 |
+
RUN /bin/bash -c "source activate extras"
|
25 |
+
|
26 |
+
RUN conda install pytorch torchvision torchaudio pytorch-cuda=11.7 git -c pytorch -c nvidia -c conda-forge
|
27 |
+
|
28 |
+
WORKDIR /sillytavern-extras/
|
29 |
+
COPY . .
|
30 |
+
|
31 |
+
ARG REQUIREMENTS
|
32 |
+
RUN pip install -r $REQUIREMENTS
|
33 |
+
|
34 |
+
ARG MODULES
|
35 |
+
CMD ["python","server.py","--enable-modules=$MODULES"]
|
docker/docker-compose.yml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: "3"
|
2 |
+
services:
|
3 |
+
sillytavern-extras:
|
4 |
+
runtime: nvidia
|
5 |
+
image: cohee1207/sillytavern-extras
|
6 |
+
build:
|
7 |
+
context: ../
|
8 |
+
dockerfile: docker/Dockerfile
|
9 |
+
args:
|
10 |
+
REQUIREMENTS: requirements.txt
|
11 |
+
MODULES: caption,summarize,classify
|
12 |
+
# REQUIREMENTS: requirements-complete.txt
|
13 |
+
# MODULES: caption,summarize,classify,sd,silero-tts,edge-tts,chromadb
|
14 |
+
volumes:
|
15 |
+
#- "./chromadb:/chromadb"
|
16 |
+
- "./cache:/root/.cache"
|
17 |
+
- "./api_key.txt:/sillytavern-extras/api_key.txt:rw"
|
18 |
+
ports:
|
19 |
+
- "5100:5100"
|
20 |
+
environment:
|
21 |
+
- NVIDIA_VISIBLE_DEVICES=all
|
22 |
+
command: python server.py --enable-modules=caption,summarize,classify
|
23 |
+
# command: python server.py --enable-modules=caption,summarize,classify,sd,silero-tts,edge-tts,chromadb
|
docker/readme.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Docker Usage
|
2 |
+
|
3 |
+
## Building the image
|
4 |
+
|
5 |
+
*This is assuming you have docker and docker compose installed and running.*
|
6 |
+
|
7 |
+
1. Open a terminal and set your current directory to the "docker" directory in your clone of this repo.
|
8 |
+
2. Adjust the "docker-compose.yml" file to match your needs. The default selection and the selection with all modules are provided as examples.
|
9 |
+
3. Once ready, run the command "docker compose build" to build the "cohee1207/sillytavern-extras" docker image.
|
10 |
+
|
modules/classify/classify_module.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Classify module for SillyTavern Extras
|
3 |
+
|
4 |
+
Authors:
|
5 |
+
- Tony Ribeiro (https://github.com/Tony-sama)
|
6 |
+
- Cohee (https://github.com/Cohee1207)
|
7 |
+
|
8 |
+
Provides classification features for text
|
9 |
+
|
10 |
+
References:
|
11 |
+
- https://huggingface.co/tasks/text-classification
|
12 |
+
"""
|
13 |
+
|
14 |
+
from transformers import pipeline
|
15 |
+
|
16 |
+
DEBUG_PREFIX = "<Classify module>"
|
17 |
+
|
18 |
+
# Models init
|
19 |
+
|
20 |
+
text_emotion_pipe = None
|
21 |
+
|
22 |
+
def init_text_emotion_classifier(model_name: str, device: str, torch_dtype: str) -> None:
|
23 |
+
global text_emotion_pipe
|
24 |
+
|
25 |
+
print(DEBUG_PREFIX,"Initializing text classification pipeline with model",model_name)
|
26 |
+
text_emotion_pipe = pipeline(
|
27 |
+
"text-classification",
|
28 |
+
model=model_name,
|
29 |
+
top_k=None,
|
30 |
+
device=device,
|
31 |
+
torch_dtype=torch_dtype,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def classify_text_emotion(text: str) -> list:
|
36 |
+
output = text_emotion_pipe(
|
37 |
+
text,
|
38 |
+
truncation=True,
|
39 |
+
max_length=text_emotion_pipe.model.config.max_position_embeddings,
|
40 |
+
)[0]
|
41 |
+
return sorted(output, key=lambda x: x["score"], reverse=True)
|
modules/speech_recognition/streaming_module.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Speech-to-text module based on Vosk and Whisper for SillyTavern Extras
|
3 |
+
- Vosk website: https://alphacephei.com/vosk/
|
4 |
+
- Vosk api: https://github.com/alphacep/vosk-api
|
5 |
+
- Whisper github: https://github.com/openai/whisper
|
6 |
+
|
7 |
+
Authors:
|
8 |
+
- Tony Ribeiro (https://github.com/Tony-sama)
|
9 |
+
|
10 |
+
Models are saved into user cache folder, example: C:/Users/toto/.cache/whisper and C:/Users/toto/.cache/vosk
|
11 |
+
|
12 |
+
References:
|
13 |
+
- Code adapted from:
|
14 |
+
- whisper github: https://github.com/openai/whisper
|
15 |
+
- oobabooga text-generation-webui github: https://github.com/oobabooga/text-generation-webui
|
16 |
+
- vosk github: https://github.com/alphacep/vosk-api/blob/master/python/example/test_microphone.py
|
17 |
+
"""
|
18 |
+
from flask import jsonify, abort
|
19 |
+
|
20 |
+
import queue
|
21 |
+
import sys
|
22 |
+
import sounddevice as sd
|
23 |
+
import soundfile as sf
|
24 |
+
import io
|
25 |
+
import numpy as np
|
26 |
+
from scipy.io.wavfile import write
|
27 |
+
|
28 |
+
import vosk
|
29 |
+
import whisper
|
30 |
+
|
31 |
+
DEBUG_PREFIX = "<stt streaming module>"
|
32 |
+
RECORDING_FILE_PATH = "stt_test.wav"
|
33 |
+
|
34 |
+
whisper_model = None
|
35 |
+
vosk_model = None
|
36 |
+
device = None
|
37 |
+
|
38 |
+
def load_model(file_path=None):
|
39 |
+
"""
|
40 |
+
Load given vosk model from file or default to en-us model.
|
41 |
+
Download model to user cache folder, example: C:/Users/toto/.cache/vosk
|
42 |
+
"""
|
43 |
+
|
44 |
+
if file_path is None:
|
45 |
+
return (whisper.load_model("base.en"), vosk.Model(lang="en-us"))
|
46 |
+
else:
|
47 |
+
return (whisper.load_model(file_path), vosk.Model(lang="en-us"))
|
48 |
+
|
49 |
+
def convert_bytearray_to_wav_ndarray(input_bytearray: bytes, sampling_rate=16000):
|
50 |
+
"""
|
51 |
+
Convert a bytearray to wav format to output in a file for quality check debuging
|
52 |
+
"""
|
53 |
+
bytes_wav = bytes()
|
54 |
+
byte_io = io.BytesIO(bytes_wav)
|
55 |
+
write(byte_io, sampling_rate, np.frombuffer(input_bytearray, dtype=np.int16))
|
56 |
+
output_wav = byte_io.read()
|
57 |
+
output, _ = sf.read(io.BytesIO(output_wav))
|
58 |
+
return output
|
59 |
+
|
60 |
+
def record_and_transcript():
|
61 |
+
"""
|
62 |
+
Continuously record from mic and transcript voice.
|
63 |
+
Return the transcript once no more voice is detected.
|
64 |
+
"""
|
65 |
+
if whisper_model is None:
|
66 |
+
print(DEBUG_PREFIX,"Whisper model not initialized yet.")
|
67 |
+
return ""
|
68 |
+
|
69 |
+
q = queue.Queue()
|
70 |
+
stream_errors = list()
|
71 |
+
|
72 |
+
def callback(indata, frames, time, status):
|
73 |
+
"""This is called (from a separate thread) for each audio block."""
|
74 |
+
if status:
|
75 |
+
print(status, file=sys.stderr)
|
76 |
+
stream_errors.append(status)
|
77 |
+
q.put(bytes(indata))
|
78 |
+
|
79 |
+
try:
|
80 |
+
device_info = sd.query_devices(device, "input")
|
81 |
+
# soundfile expects an int, sounddevice provides a float:
|
82 |
+
samplerate = int(device_info["default_samplerate"])
|
83 |
+
|
84 |
+
print(DEBUG_PREFIX, "Start recording from:", device_info["name"], "with samplerate", samplerate)
|
85 |
+
|
86 |
+
with sd.RawInputStream(samplerate=samplerate, blocksize = 8000, device=device, dtype="int16", channels=1, callback=callback):
|
87 |
+
|
88 |
+
rec = vosk.KaldiRecognizer(vosk_model, samplerate)
|
89 |
+
full_recording = bytearray()
|
90 |
+
while True:
|
91 |
+
data = q.get()
|
92 |
+
if len(stream_errors) > 0:
|
93 |
+
raise Exception(DEBUG_PREFIX+" Stream errors: "+str(stream_errors))
|
94 |
+
|
95 |
+
full_recording.extend(data)
|
96 |
+
|
97 |
+
if rec.AcceptWaveform(data):
|
98 |
+
# Extract transcript string
|
99 |
+
transcript = rec.Result()[14:-3]
|
100 |
+
print(DEBUG_PREFIX, "Transcripted from microphone stream (vosk):", transcript)
|
101 |
+
|
102 |
+
# ----------------------------------
|
103 |
+
# DEBUG: save recording to wav file
|
104 |
+
# ----------------------------------
|
105 |
+
output_file = convert_bytearray_to_wav_ndarray(input_bytearray=full_recording, sampling_rate=samplerate)
|
106 |
+
sf.write(file=RECORDING_FILE_PATH, data=output_file, samplerate=samplerate)
|
107 |
+
print(DEBUG_PREFIX, "Recorded message saved to", RECORDING_FILE_PATH)
|
108 |
+
|
109 |
+
# Whisper HACK
|
110 |
+
result = whisper_model.transcribe(RECORDING_FILE_PATH)
|
111 |
+
transcript = result["text"]
|
112 |
+
print(DEBUG_PREFIX, "Transcripted from audio file (whisper):", transcript)
|
113 |
+
# ----------------------------------
|
114 |
+
|
115 |
+
return jsonify({"transcript": transcript})
|
116 |
+
#else:
|
117 |
+
# print(rec.PartialResult())
|
118 |
+
|
119 |
+
except Exception as e: # No exception observed during test but we never know
|
120 |
+
print(e)
|
121 |
+
abort(500, DEBUG_PREFIX+" Exception occurs while recording")
|
modules/speech_recognition/vosk_module.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Speech-to-text module based on Vosk for SillyTavern Extras
|
3 |
+
- Vosk website: https://alphacephei.com/vosk/
|
4 |
+
- Vosk api: https://github.com/alphacep/vosk-api
|
5 |
+
|
6 |
+
Authors:
|
7 |
+
- Tony Ribeiro (https://github.com/Tony-sama)
|
8 |
+
|
9 |
+
Models are saved into user cache folder, example: C:/Users/toto/.cache/vosk
|
10 |
+
|
11 |
+
References:
|
12 |
+
- Code adapted from: https://github.com/alphacep/vosk-api/blob/master/python/example/test_simple.py
|
13 |
+
"""
|
14 |
+
from flask import jsonify, abort, request
|
15 |
+
|
16 |
+
import wave
|
17 |
+
from vosk import Model, KaldiRecognizer, SetLogLevel
|
18 |
+
import soundfile
|
19 |
+
|
20 |
+
DEBUG_PREFIX = "<stt vosk module>"
|
21 |
+
RECORDING_FILE_PATH = "stt_test.wav"
|
22 |
+
|
23 |
+
model = None
|
24 |
+
|
25 |
+
SetLogLevel(-1)
|
26 |
+
|
27 |
+
def load_model(file_path=None):
|
28 |
+
"""
|
29 |
+
Load given vosk model from file or default to en-us model.
|
30 |
+
Download model to user cache folder, example: C:/Users/toto/.cache/vosk
|
31 |
+
"""
|
32 |
+
|
33 |
+
if file_path is None:
|
34 |
+
return Model(lang="en-us")
|
35 |
+
else:
|
36 |
+
return Model(file_path)
|
37 |
+
|
38 |
+
def process_audio():
|
39 |
+
"""
|
40 |
+
Transcript request audio file to text using Whisper
|
41 |
+
"""
|
42 |
+
|
43 |
+
if model is None:
|
44 |
+
print(DEBUG_PREFIX,"Vosk model not initialized yet.")
|
45 |
+
return ""
|
46 |
+
|
47 |
+
try:
|
48 |
+
file = request.files.get('AudioFile')
|
49 |
+
file.save(RECORDING_FILE_PATH)
|
50 |
+
|
51 |
+
# Read and rewrite the file with soundfile
|
52 |
+
data, samplerate = soundfile.read(RECORDING_FILE_PATH)
|
53 |
+
soundfile.write(RECORDING_FILE_PATH, data, samplerate)
|
54 |
+
|
55 |
+
wf = wave.open(RECORDING_FILE_PATH, "rb")
|
56 |
+
if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
|
57 |
+
print("Audio file must be WAV format mono PCM.")
|
58 |
+
abort(500, DEBUG_PREFIX+" Audio file must be WAV format mono PCM.")
|
59 |
+
|
60 |
+
rec = KaldiRecognizer(model, wf.getframerate())
|
61 |
+
#rec.SetWords(True)
|
62 |
+
#rec.SetPartialWords(True)
|
63 |
+
|
64 |
+
while True:
|
65 |
+
data = wf.readframes(4000)
|
66 |
+
if len(data) == 0:
|
67 |
+
break
|
68 |
+
if rec.AcceptWaveform(data):
|
69 |
+
break
|
70 |
+
|
71 |
+
transcript = rec.Result()[14:-3]
|
72 |
+
print(DEBUG_PREFIX, "Transcripted from request audio file:", transcript)
|
73 |
+
return jsonify({"transcript": transcript})
|
74 |
+
|
75 |
+
except Exception as e: # No exception observed during test but we never know
|
76 |
+
print(e)
|
77 |
+
abort(500, DEBUG_PREFIX+" Exception occurs while processing audio")
|
modules/speech_recognition/whisper_module.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Speech-to-text module based on Whisper for SillyTavern Extras
|
3 |
+
- Whisper github: https://github.com/openai/whisper
|
4 |
+
|
5 |
+
Authors:
|
6 |
+
- Tony Ribeiro (https://github.com/Tony-sama)
|
7 |
+
|
8 |
+
Models are saved into user cache folder, example: C:/Users/toto/.cache/whisper
|
9 |
+
|
10 |
+
References:
|
11 |
+
- Code adapted from:
|
12 |
+
- whisper github: https://github.com/openai/whisper
|
13 |
+
- oobabooga text-generation-webui github: https://github.com/oobabooga/text-generation-webui
|
14 |
+
"""
|
15 |
+
from flask import jsonify, abort, request
|
16 |
+
|
17 |
+
import whisper
|
18 |
+
|
19 |
+
DEBUG_PREFIX = "<stt whisper module>"
|
20 |
+
RECORDING_FILE_PATH = "stt_test.wav"
|
21 |
+
|
22 |
+
model = None
|
23 |
+
|
24 |
+
def load_model(file_path=None):
|
25 |
+
"""
|
26 |
+
Load given vosk model from file or default to en-us model.
|
27 |
+
Download model to user cache folder, example: C:/Users/toto/.cache/vosk
|
28 |
+
"""
|
29 |
+
|
30 |
+
if file_path is None:
|
31 |
+
return whisper.load_model("base.en")
|
32 |
+
else:
|
33 |
+
return whisper.load_model(file_path)
|
34 |
+
|
35 |
+
def process_audio():
|
36 |
+
"""
|
37 |
+
Transcript request audio file to text using Whisper
|
38 |
+
"""
|
39 |
+
|
40 |
+
if model is None:
|
41 |
+
print(DEBUG_PREFIX,"Whisper model not initialized yet.")
|
42 |
+
return ""
|
43 |
+
|
44 |
+
try:
|
45 |
+
file = request.files.get('AudioFile')
|
46 |
+
file.save(RECORDING_FILE_PATH)
|
47 |
+
|
48 |
+
result = model.transcribe(RECORDING_FILE_PATH)
|
49 |
+
transcript = result["text"]
|
50 |
+
print(DEBUG_PREFIX, "Transcripted from audio file (whisper):", transcript)
|
51 |
+
|
52 |
+
return jsonify({"transcript": transcript})
|
53 |
+
|
54 |
+
except Exception as e: # No exception observed during test but we never know
|
55 |
+
print(e)
|
56 |
+
abort(500, DEBUG_PREFIX+" Exception occurs while processing audio")
|
modules/text_to_speech/coqui/coqui_module.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Coqui module for SillyTavern Extras
|
3 |
+
|
4 |
+
Authors:
|
5 |
+
- Pyrater (https://github.com/pyrater)
|
6 |
+
- Tony Ribeiro (https://github.com/Tony-sama)
|
7 |
+
|
8 |
+
Models are saved into user cache folder: "C:/Users/<username>/AppData/Local/tts"
|
9 |
+
|
10 |
+
References:
|
11 |
+
- Code adapted from:
|
12 |
+
- Coqui TTS https://tts.readthedocs.io/en/latest/
|
13 |
+
- Audio-webui: https://github.com/gitmylo/audio-webui
|
14 |
+
"""
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
import io
|
18 |
+
import shutil
|
19 |
+
|
20 |
+
from flask import abort, request, send_file, jsonify
|
21 |
+
|
22 |
+
from TTS.api import TTS
|
23 |
+
from TTS.utils.manage import ModelManager
|
24 |
+
|
25 |
+
from modules.utils import silence_log
|
26 |
+
|
27 |
+
DEBUG_PREFIX = "<Coqui-TTS module>"
|
28 |
+
COQUI_MODELS_PATH = "data/models/coqui/"
|
29 |
+
IGNORED_FILES = [".placeholder"]
|
30 |
+
COQUI_LOCAL_MODEL_FILE_NAME = "model.pth"
|
31 |
+
COQUI_LOCAL_CONFIG_FILE_NAME = "config.json"
|
32 |
+
|
33 |
+
gpu_mode = False
|
34 |
+
is_downloading = False
|
35 |
+
|
36 |
+
def install_model(model_id):
|
37 |
+
global gpu_mode
|
38 |
+
audio_buffer = io.BytesIO()
|
39 |
+
speaker_id = None
|
40 |
+
language_id = None
|
41 |
+
|
42 |
+
print(DEBUG_PREFIX,"Loading model",model_id)
|
43 |
+
try:
|
44 |
+
tts = TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
|
45 |
+
|
46 |
+
if tts.is_multi_lingual:
|
47 |
+
language_id = tts.languages[0]
|
48 |
+
|
49 |
+
if tts.is_multi_speaker:
|
50 |
+
speaker_id =tts.speakers[0]
|
51 |
+
|
52 |
+
tts.tts_to_file(text="this is a test message", file_path=audio_buffer, speaker=speaker_id, language=language_id)
|
53 |
+
except Exception as e:
|
54 |
+
print(DEBUG_PREFIX,"ERROR:", e)
|
55 |
+
print("Model", model_id, "cannot be loaded, maybe wrong model name? Must be one of")
|
56 |
+
for i in TTS.list_models():
|
57 |
+
print(i)
|
58 |
+
return False
|
59 |
+
|
60 |
+
print(DEBUG_PREFIX,"Success")
|
61 |
+
return True
|
62 |
+
|
63 |
+
def coqui_check_model_state():
|
64 |
+
"""
|
65 |
+
Check if the requested model is installed on the server machine
|
66 |
+
"""
|
67 |
+
try:
|
68 |
+
model_state = "absent"
|
69 |
+
request_json = request.get_json()
|
70 |
+
model_id = request_json["model_id"]
|
71 |
+
|
72 |
+
print(DEBUG_PREFIX,"Search for model", model_id)
|
73 |
+
|
74 |
+
coqui_models_folder = ModelManager().output_prefix # models location
|
75 |
+
|
76 |
+
# Check if tts folder exist
|
77 |
+
if os.path.isdir(coqui_models_folder):
|
78 |
+
|
79 |
+
installed_models = os.listdir(coqui_models_folder)
|
80 |
+
|
81 |
+
model_folder_exists = False
|
82 |
+
model_folder = None
|
83 |
+
|
84 |
+
for i in installed_models:
|
85 |
+
if model_id == i.replace("--","/",3): # Error with model wrong name
|
86 |
+
model_folder_exists = True
|
87 |
+
model_folder = i
|
88 |
+
print(DEBUG_PREFIX,"Folder found:",model_folder)
|
89 |
+
|
90 |
+
# Check failed download
|
91 |
+
if model_folder_exists:
|
92 |
+
content = os.listdir(os.path.join(coqui_models_folder,model_folder))
|
93 |
+
print(DEBUG_PREFIX,"Checking content:",content)
|
94 |
+
for i in content:
|
95 |
+
if i == model_folder+".zip":
|
96 |
+
print("Corrupt installed found, model download must have failed previously")
|
97 |
+
model_state = "corrupted"
|
98 |
+
break
|
99 |
+
|
100 |
+
if model_state != "corrupted":
|
101 |
+
model_state = "installed"
|
102 |
+
|
103 |
+
response = json.dumps({"model_state":model_state})
|
104 |
+
return response
|
105 |
+
|
106 |
+
except Exception as e:
|
107 |
+
print(e)
|
108 |
+
abort(500, DEBUG_PREFIX + " Exception occurs while trying to search for installed model")
|
109 |
+
|
110 |
+
def coqui_install_model():
|
111 |
+
"""
|
112 |
+
Install requested model is installed on the server machine
|
113 |
+
"""
|
114 |
+
global gpu_mode
|
115 |
+
global is_downloading
|
116 |
+
|
117 |
+
try:
|
118 |
+
model_installed = False
|
119 |
+
request_json = request.get_json()
|
120 |
+
model_id = request_json["model_id"]
|
121 |
+
action = request_json["action"]
|
122 |
+
|
123 |
+
print(DEBUG_PREFIX,"Received request",action,"for model",model_id)
|
124 |
+
|
125 |
+
if (is_downloading):
|
126 |
+
print(DEBUG_PREFIX,"Rejected, already downloading a model")
|
127 |
+
return json.dumps({"status":"downloading"})
|
128 |
+
|
129 |
+
coqui_models_folder = ModelManager().output_prefix # models location
|
130 |
+
|
131 |
+
# Check if tts folder exist
|
132 |
+
if os.path.isdir(coqui_models_folder):
|
133 |
+
installed_models = os.listdir(coqui_models_folder)
|
134 |
+
model_path = None
|
135 |
+
|
136 |
+
print(DEBUG_PREFIX,"Found",len(installed_models),"models in",coqui_models_folder)
|
137 |
+
|
138 |
+
for i in installed_models:
|
139 |
+
if model_id == i.replace("--","/"):
|
140 |
+
model_installed = True
|
141 |
+
model_path = os.path.join(coqui_models_folder,i)
|
142 |
+
|
143 |
+
if model_installed:
|
144 |
+
print(DEBUG_PREFIX,"model found:", model_id)
|
145 |
+
else:
|
146 |
+
print(DEBUG_PREFIX,"model not found")
|
147 |
+
|
148 |
+
if action == "download":
|
149 |
+
if model_installed:
|
150 |
+
abort(500, DEBUG_PREFIX + "Bad request, model already installed.")
|
151 |
+
|
152 |
+
is_downloading = True
|
153 |
+
TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
|
154 |
+
is_downloading = False
|
155 |
+
|
156 |
+
if action == "repare":
|
157 |
+
if not model_installed:
|
158 |
+
abort(500, DEBUG_PREFIX + " bad request: requesting repare of model not installed")
|
159 |
+
|
160 |
+
|
161 |
+
print(DEBUG_PREFIX,"Deleting corrupted model folder:",model_path)
|
162 |
+
shutil.rmtree(model_path, ignore_errors=True)
|
163 |
+
|
164 |
+
is_downloading = True
|
165 |
+
TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
|
166 |
+
is_downloading = False
|
167 |
+
|
168 |
+
response = json.dumps({"status":"done"})
|
169 |
+
return response
|
170 |
+
|
171 |
+
except Exception as e:
|
172 |
+
is_downloading = False
|
173 |
+
print(e)
|
174 |
+
abort(500, DEBUG_PREFIX + " Exception occurs while trying to search for installed model")
|
175 |
+
|
176 |
+
def coqui_get_local_models():
|
177 |
+
"""
|
178 |
+
Return user local models list in the following format: [language][dataset][name] = TTS_string_id
|
179 |
+
"""
|
180 |
+
try:
|
181 |
+
print(DEBUG_PREFIX, "Received request for list of RVC models")
|
182 |
+
|
183 |
+
folder_names = os.listdir(COQUI_MODELS_PATH)
|
184 |
+
|
185 |
+
print(DEBUG_PREFIX,"Searching model in",COQUI_MODELS_PATH)
|
186 |
+
|
187 |
+
model_list = []
|
188 |
+
for folder_name in folder_names:
|
189 |
+
folder_path = COQUI_MODELS_PATH+folder_name
|
190 |
+
|
191 |
+
if folder_name in IGNORED_FILES:
|
192 |
+
continue
|
193 |
+
|
194 |
+
# Must be a folder
|
195 |
+
if not os.path.isdir(folder_path):
|
196 |
+
print("> WARNING:",folder_name,"is not a folder, it should not be there, ignored")
|
197 |
+
continue
|
198 |
+
|
199 |
+
print("> Found model folder",folder_name)
|
200 |
+
|
201 |
+
# Check pth
|
202 |
+
valid_folder = False
|
203 |
+
for file_name in os.listdir(folder_path):
|
204 |
+
if file_name.endswith(".pth"):
|
205 |
+
print(" > pth:",file_name)
|
206 |
+
valid_folder = True
|
207 |
+
if file_name.endswith(".config"):
|
208 |
+
print(" > config:",file_name)
|
209 |
+
|
210 |
+
if valid_folder:
|
211 |
+
print(" > Valid folder added to list")
|
212 |
+
model_list.append(folder_name)
|
213 |
+
else:
|
214 |
+
print(" > WARNING: Missing pth or config file, ignored folder")
|
215 |
+
|
216 |
+
# Return the list of valid folders
|
217 |
+
response = json.dumps({"models_list":model_list})
|
218 |
+
return response
|
219 |
+
|
220 |
+
except Exception as e:
|
221 |
+
print(e)
|
222 |
+
abort(500, DEBUG_PREFIX + " Exception occurs while searching for Coqui models.")
|
223 |
+
|
224 |
+
|
225 |
+
|
226 |
+
def coqui_generate_tts():
|
227 |
+
"""
|
228 |
+
Process request text with the loaded RVC model
|
229 |
+
- expected request: {
|
230 |
+
"text": text,
|
231 |
+
"model_id": voiceId,
|
232 |
+
"language_id": language,
|
233 |
+
"speaker_id": speaker
|
234 |
+
}
|
235 |
+
|
236 |
+
- model_id formats:
|
237 |
+
- model_type/language/dataset/model_name
|
238 |
+
- model_type/language/dataset/model_name[spearker_id]
|
239 |
+
- model_type/language/dataset/model_name[spearker_id][language_id]
|
240 |
+
- examples:
|
241 |
+
- tts_models/ja/kokoro/tacotron2-DDC
|
242 |
+
- tts_models/en/vctk/vits[0]
|
243 |
+
- tts_models/multilingual/multi-dataset/your_tts[2][1]
|
244 |
+
"""
|
245 |
+
global gpu_mode
|
246 |
+
global is_downloading
|
247 |
+
audio_buffer = io.BytesIO()
|
248 |
+
|
249 |
+
try:
|
250 |
+
request_json = request.get_json()
|
251 |
+
#print(request_json)
|
252 |
+
|
253 |
+
print(DEBUG_PREFIX,"Received TTS request for ", request_json)
|
254 |
+
|
255 |
+
if (is_downloading):
|
256 |
+
print(DEBUG_PREFIX,"Rejected, currently downloading a model, cannot perform TTS")
|
257 |
+
abort(500, DEBUG_PREFIX + " Requested TTS while downloading a model")
|
258 |
+
|
259 |
+
text = request_json["text"]
|
260 |
+
model_name = request_json["model_id"]
|
261 |
+
language_id = None
|
262 |
+
speaker_id = None
|
263 |
+
|
264 |
+
# Local model
|
265 |
+
model_type = model_name.split("/")[0]
|
266 |
+
if model_type == "local":
|
267 |
+
return generate_tts_local(model_name.split("/")[1], text)
|
268 |
+
|
269 |
+
|
270 |
+
if request_json["language_id"] != "none":
|
271 |
+
language_id = request_json["language_id"]
|
272 |
+
|
273 |
+
if request_json["speaker_id"] != "none":
|
274 |
+
speaker_id = request_json["speaker_id"]
|
275 |
+
|
276 |
+
print(DEBUG_PREFIX,"Loading tts \n- model", model_name, "\n - speaker_id: ",speaker_id,"\n - language_id: ",language_id, "\n - using",("GPU" if gpu_mode else "CPU"))
|
277 |
+
|
278 |
+
is_downloading = True
|
279 |
+
tts = TTS(model_name=model_name, progress_bar=True, gpu=gpu_mode)
|
280 |
+
is_downloading = False
|
281 |
+
|
282 |
+
if tts.is_multi_lingual:
|
283 |
+
if language_id is None:
|
284 |
+
abort(400, DEBUG_PREFIX + " Requested model "+model_name+" is multi-lingual but no language id provided")
|
285 |
+
language_id = tts.languages[int(language_id)]
|
286 |
+
|
287 |
+
if tts.is_multi_speaker:
|
288 |
+
if speaker_id is None:
|
289 |
+
abort(400, DEBUG_PREFIX + " Requested model "+model_name+" is multi-speaker but no speaker id provided")
|
290 |
+
speaker_id =tts.speakers[int(speaker_id)]
|
291 |
+
|
292 |
+
tts.tts_to_file(text=text, file_path=audio_buffer, speaker=speaker_id, language=language_id)
|
293 |
+
|
294 |
+
print(DEBUG_PREFIX, "Success, saved to",audio_buffer)
|
295 |
+
|
296 |
+
# Return the output_audio_path object as a response
|
297 |
+
response = send_file(audio_buffer, mimetype="audio/x-wav")
|
298 |
+
audio_buffer = io.BytesIO()
|
299 |
+
|
300 |
+
return response
|
301 |
+
|
302 |
+
except Exception as e:
|
303 |
+
print(e)
|
304 |
+
abort(500, DEBUG_PREFIX + " Exception occurs while trying to process request "+str(request_json))
|
305 |
+
|
306 |
+
def generate_tts_local(model_folder, text):
|
307 |
+
"""
|
308 |
+
Generate tts using local coqui model
|
309 |
+
"""
|
310 |
+
audio_buffer = io.BytesIO()
|
311 |
+
|
312 |
+
print(DEBUG_PREFIX,"Request for tts from local coqui model",model_folder)
|
313 |
+
|
314 |
+
model_path = os.path.join(COQUI_MODELS_PATH,model_folder,COQUI_LOCAL_MODEL_FILE_NAME)
|
315 |
+
config_path = os.path.join(COQUI_MODELS_PATH,model_folder,COQUI_LOCAL_CONFIG_FILE_NAME)
|
316 |
+
|
317 |
+
if not os.path.exists(model_path):
|
318 |
+
raise ValueError("File does not exists:",model_path)
|
319 |
+
|
320 |
+
if not os.path.exists(config_path):
|
321 |
+
raise ValueError("File does not exists:",config_path)
|
322 |
+
|
323 |
+
print(DEBUG_PREFIX,"Loading local tts model", model_path,"using",("GPU" if gpu_mode else "CPU"))
|
324 |
+
tts = TTS(model_path=model_path, config_path=config_path, progress_bar=True, gpu=gpu_mode)
|
325 |
+
tts.tts_to_file(text=text, file_path=audio_buffer)
|
326 |
+
|
327 |
+
print(DEBUG_PREFIX, "Success, saved to",audio_buffer)
|
328 |
+
|
329 |
+
# Return the output_audio_path object as a response
|
330 |
+
response = send_file(audio_buffer, mimetype="audio/x-wav")
|
331 |
+
audio_buffer = io.BytesIO()
|
332 |
+
|
333 |
+
return response
|
modules/utils.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from contextlib import contextmanager
|
3 |
+
import sys
|
4 |
+
|
5 |
+
@contextmanager
|
6 |
+
def silence_log():
|
7 |
+
old_stdout = sys.stdout
|
8 |
+
old_stderr = sys.stderr
|
9 |
+
try:
|
10 |
+
with open(os.devnull, "w") as new_target:
|
11 |
+
sys.stdout = new_target
|
12 |
+
yield new_target
|
13 |
+
finally:
|
14 |
+
sys.stdout = old_stdout
|
15 |
+
sys.stderr = old_stderr
|
modules/voice_conversion/fairseq/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
modules/voice_conversion/fairseq/__init__.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
"""isort:skip_file"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
|
10 |
+
try:
|
11 |
+
from .version import __version__ # noqa
|
12 |
+
except ImportError:
|
13 |
+
version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
|
14 |
+
with open(version_txt) as f:
|
15 |
+
__version__ = f.read().strip()
|
16 |
+
|
17 |
+
__all__ = ["pdb"]
|
18 |
+
|
19 |
+
# backwards compatibility to support `from fairseq.X import Y`
|
20 |
+
from fairseq.distributed import utils as distributed_utils
|
21 |
+
from fairseq.logging import meters, metrics, progress_bar # noqa
|
22 |
+
|
23 |
+
sys.modules["fairseq.distributed_utils"] = distributed_utils
|
24 |
+
sys.modules["fairseq.meters"] = meters
|
25 |
+
sys.modules["fairseq.metrics"] = metrics
|
26 |
+
sys.modules["fairseq.progress_bar"] = progress_bar
|
27 |
+
|
28 |
+
# initialize hydra
|
29 |
+
#from fairseq.dataclass.initialize import hydra_init
|
30 |
+
|
31 |
+
#hydra_init()
|
32 |
+
|
33 |
+
#import fairseq.criterions # noqa
|
34 |
+
#import fairseq.distributed # noqa
|
35 |
+
#import fairseq.models # noqa
|
36 |
+
#import fairseq.modules # noqa
|
37 |
+
#import fairseq.optim # noqa
|
38 |
+
#import fairseq.optim.lr_scheduler # noqa
|
39 |
+
#import fairseq.pdb # noqa
|
40 |
+
#import fairseq.scoring # noqa
|
41 |
+
#import fairseq.tasks # noqa
|
42 |
+
#import fairseq.token_generation_constraints # noqa
|
43 |
+
|
44 |
+
#import fairseq.benchmark # noqa
|
45 |
+
#import fairseq.model_parallel # noqa
|
modules/voice_conversion/fairseq/binarizer.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
import typing as tp
|
9 |
+
from abc import ABC, abstractmethod
|
10 |
+
from collections import Counter
|
11 |
+
from dataclasses import dataclass
|
12 |
+
from multiprocessing import Pool
|
13 |
+
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from fairseq.data import Dictionary, indexed_dataset
|
17 |
+
from fairseq.file_chunker_utils import Chunker, find_offsets
|
18 |
+
from fairseq.file_io import PathManager
|
19 |
+
from fairseq.tokenizer import tokenize_line
|
20 |
+
|
21 |
+
logger = logging.getLogger("binarizer")
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class BinarizeSummary:
|
26 |
+
"""
|
27 |
+
Keep track of what's going on in the binarizer
|
28 |
+
"""
|
29 |
+
|
30 |
+
num_seq: int = 0
|
31 |
+
replaced: tp.Optional[Counter] = None
|
32 |
+
num_tok: int = 0
|
33 |
+
|
34 |
+
@property
|
35 |
+
def num_replaced(self) -> int:
|
36 |
+
if self.replaced is None:
|
37 |
+
return 0
|
38 |
+
return sum(self.replaced.values())
|
39 |
+
|
40 |
+
@property
|
41 |
+
def replaced_percent(self) -> float:
|
42 |
+
return 100 * self.num_replaced / self.num_tok
|
43 |
+
|
44 |
+
def __str__(self) -> str:
|
45 |
+
base = f"{self.num_seq} sents, {self.num_tok} tokens"
|
46 |
+
if self.replaced is None:
|
47 |
+
return base
|
48 |
+
|
49 |
+
return f"{base}, {self.replaced_percent:.3}% replaced"
|
50 |
+
|
51 |
+
def merge(self, other: "BinarizeSummary"):
|
52 |
+
replaced = None
|
53 |
+
if self.replaced is not None:
|
54 |
+
replaced = self.replaced
|
55 |
+
if other.replaced is not None:
|
56 |
+
if replaced is None:
|
57 |
+
replaced = other.replaced
|
58 |
+
else:
|
59 |
+
replaced += other.replaced
|
60 |
+
self.replaced = replaced
|
61 |
+
self.num_seq += other.num_seq
|
62 |
+
self.num_tok += other.num_tok
|
63 |
+
|
64 |
+
|
65 |
+
class Binarizer(ABC):
|
66 |
+
"""
|
67 |
+
a binarizer describes how to take a string and build a tensor out of it
|
68 |
+
"""
|
69 |
+
|
70 |
+
@abstractmethod
|
71 |
+
def binarize_line(
|
72 |
+
self,
|
73 |
+
line: str,
|
74 |
+
summary: BinarizeSummary,
|
75 |
+
) -> torch.IntTensor:
|
76 |
+
...
|
77 |
+
|
78 |
+
|
79 |
+
def _worker_prefix(output_prefix: str, worker_id: int):
|
80 |
+
return f"{output_prefix}.pt{worker_id}"
|
81 |
+
|
82 |
+
|
83 |
+
class FileBinarizer:
|
84 |
+
"""
|
85 |
+
An file binarizer can take a file, tokenize it, and binarize each line to a tensor
|
86 |
+
"""
|
87 |
+
|
88 |
+
@classmethod
|
89 |
+
def multiprocess_dataset(
|
90 |
+
cls,
|
91 |
+
input_file: str,
|
92 |
+
dataset_impl: str,
|
93 |
+
binarizer: Binarizer,
|
94 |
+
output_prefix: str,
|
95 |
+
vocab_size=None,
|
96 |
+
num_workers=1,
|
97 |
+
) -> BinarizeSummary:
|
98 |
+
final_summary = BinarizeSummary()
|
99 |
+
|
100 |
+
offsets = find_offsets(input_file, num_workers)
|
101 |
+
# find_offsets returns a list of position [pos1, pos2, pos3, pos4] but we would want pairs:
|
102 |
+
# [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info
|
103 |
+
# we zip the list with itself shifted by one to get all the pairs.
|
104 |
+
(first_chunk, *more_chunks) = zip(offsets, offsets[1:])
|
105 |
+
pool = None
|
106 |
+
if num_workers > 1:
|
107 |
+
pool = Pool(processes=num_workers - 1)
|
108 |
+
worker_results = [
|
109 |
+
pool.apply_async(
|
110 |
+
cls._binarize_chunk_and_finalize,
|
111 |
+
args=(
|
112 |
+
binarizer,
|
113 |
+
input_file,
|
114 |
+
start_offset,
|
115 |
+
end_offset,
|
116 |
+
_worker_prefix(
|
117 |
+
output_prefix,
|
118 |
+
worker_id,
|
119 |
+
),
|
120 |
+
dataset_impl,
|
121 |
+
),
|
122 |
+
kwds={
|
123 |
+
"vocab_size": vocab_size,
|
124 |
+
}
|
125 |
+
if vocab_size is not None
|
126 |
+
else {},
|
127 |
+
)
|
128 |
+
for worker_id, (start_offset, end_offset) in enumerate(
|
129 |
+
more_chunks, start=1
|
130 |
+
)
|
131 |
+
]
|
132 |
+
|
133 |
+
pool.close()
|
134 |
+
pool.join()
|
135 |
+
for r in worker_results:
|
136 |
+
summ = r.get()
|
137 |
+
final_summary.merge(summ)
|
138 |
+
|
139 |
+
# do not close the bin file as we need to merge the worker results in
|
140 |
+
final_ds, summ = cls._binarize_file_chunk(
|
141 |
+
binarizer,
|
142 |
+
input_file,
|
143 |
+
offset_start=first_chunk[0],
|
144 |
+
offset_end=first_chunk[1],
|
145 |
+
output_prefix=output_prefix,
|
146 |
+
dataset_impl=dataset_impl,
|
147 |
+
vocab_size=vocab_size if vocab_size is not None else None,
|
148 |
+
)
|
149 |
+
final_summary.merge(summ)
|
150 |
+
|
151 |
+
if num_workers > 1:
|
152 |
+
for worker_id in range(1, num_workers):
|
153 |
+
# merge the worker outputs
|
154 |
+
worker_output_prefix = _worker_prefix(
|
155 |
+
output_prefix,
|
156 |
+
worker_id,
|
157 |
+
)
|
158 |
+
final_ds.merge_file_(worker_output_prefix)
|
159 |
+
try:
|
160 |
+
os.remove(indexed_dataset.data_file_path(worker_output_prefix))
|
161 |
+
os.remove(indexed_dataset.index_file_path(worker_output_prefix))
|
162 |
+
except Exception as e:
|
163 |
+
logger.error(
|
164 |
+
f"couldn't remove {worker_output_prefix}.*", exc_info=e
|
165 |
+
)
|
166 |
+
|
167 |
+
# now we can close the file
|
168 |
+
idx_file = indexed_dataset.index_file_path(output_prefix)
|
169 |
+
final_ds.finalize(idx_file)
|
170 |
+
return final_summary
|
171 |
+
|
172 |
+
@staticmethod
|
173 |
+
def _binarize_file_chunk(
|
174 |
+
binarizer: Binarizer,
|
175 |
+
filename: str,
|
176 |
+
offset_start: int,
|
177 |
+
offset_end: int,
|
178 |
+
output_prefix: str,
|
179 |
+
dataset_impl: str,
|
180 |
+
vocab_size=None,
|
181 |
+
) -> tp.Tuple[tp.Any, BinarizeSummary]: # (dataset builder, BinarizeSummary)
|
182 |
+
"""
|
183 |
+
creates a dataset builder and append binarized items to it. This function does not
|
184 |
+
finalize the builder, this is useful if you want to do other things with your bin file
|
185 |
+
like appending/merging other files
|
186 |
+
"""
|
187 |
+
bin_file = indexed_dataset.data_file_path(output_prefix)
|
188 |
+
ds = indexed_dataset.make_builder(
|
189 |
+
bin_file,
|
190 |
+
impl=dataset_impl,
|
191 |
+
vocab_size=vocab_size,
|
192 |
+
)
|
193 |
+
summary = BinarizeSummary()
|
194 |
+
|
195 |
+
with Chunker(
|
196 |
+
PathManager.get_local_path(filename), offset_start, offset_end
|
197 |
+
) as line_iterator:
|
198 |
+
for line in line_iterator:
|
199 |
+
ds.add_item(binarizer.binarize_line(line, summary))
|
200 |
+
|
201 |
+
return ds, summary
|
202 |
+
|
203 |
+
@classmethod
|
204 |
+
def _binarize_chunk_and_finalize(
|
205 |
+
cls,
|
206 |
+
binarizer: Binarizer,
|
207 |
+
filename: str,
|
208 |
+
offset_start: int,
|
209 |
+
offset_end: int,
|
210 |
+
output_prefix: str,
|
211 |
+
dataset_impl: str,
|
212 |
+
vocab_size=None,
|
213 |
+
):
|
214 |
+
"""
|
215 |
+
same as above, but also finalizes the builder
|
216 |
+
"""
|
217 |
+
ds, summ = cls._binarize_file_chunk(
|
218 |
+
binarizer,
|
219 |
+
filename,
|
220 |
+
offset_start,
|
221 |
+
offset_end,
|
222 |
+
output_prefix,
|
223 |
+
dataset_impl,
|
224 |
+
vocab_size=vocab_size,
|
225 |
+
)
|
226 |
+
|
227 |
+
idx_file = indexed_dataset.index_file_path(output_prefix)
|
228 |
+
ds.finalize(idx_file)
|
229 |
+
|
230 |
+
return summ
|
231 |
+
|
232 |
+
|
233 |
+
class VocabularyDatasetBinarizer(Binarizer):
|
234 |
+
"""
|
235 |
+
Takes a Dictionary/Vocabulary, assign ids to each
|
236 |
+
token using the dictionary encode_line function.
|
237 |
+
"""
|
238 |
+
|
239 |
+
def __init__(
|
240 |
+
self,
|
241 |
+
dict: Dictionary,
|
242 |
+
tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
|
243 |
+
append_eos: bool = True,
|
244 |
+
reverse_order: bool = False,
|
245 |
+
already_numberized: bool = False,
|
246 |
+
) -> None:
|
247 |
+
self.dict = dict
|
248 |
+
self.tokenize = tokenize
|
249 |
+
self.append_eos = append_eos
|
250 |
+
self.reverse_order = reverse_order
|
251 |
+
self.already_numberized = already_numberized
|
252 |
+
super().__init__()
|
253 |
+
|
254 |
+
def binarize_line(
|
255 |
+
self,
|
256 |
+
line: str,
|
257 |
+
summary: BinarizeSummary,
|
258 |
+
):
|
259 |
+
if summary.replaced is None:
|
260 |
+
summary.replaced = Counter()
|
261 |
+
|
262 |
+
def replaced_consumer(word, idx):
|
263 |
+
if idx == self.dict.unk_index and word != self.dict.unk_word:
|
264 |
+
summary.replaced.update([word])
|
265 |
+
|
266 |
+
if self.already_numberized:
|
267 |
+
id_strings = line.strip().split()
|
268 |
+
id_list = [int(id_string) for id_string in id_strings]
|
269 |
+
if self.reverse_order:
|
270 |
+
id_list.reverse()
|
271 |
+
if self.append_eos:
|
272 |
+
id_list.append(self.dict.eos())
|
273 |
+
ids = torch.IntTensor(id_list)
|
274 |
+
else:
|
275 |
+
ids = self.dict.encode_line(
|
276 |
+
line=line,
|
277 |
+
line_tokenizer=self.tokenize,
|
278 |
+
add_if_not_exist=False,
|
279 |
+
consumer=replaced_consumer,
|
280 |
+
append_eos=self.append_eos,
|
281 |
+
reverse_order=self.reverse_order,
|
282 |
+
)
|
283 |
+
|
284 |
+
summary.num_seq += 1
|
285 |
+
summary.num_tok += len(ids)
|
286 |
+
return ids
|
287 |
+
|
288 |
+
|
289 |
+
class AlignmentDatasetBinarizer(Binarizer):
|
290 |
+
"""
|
291 |
+
binarize by parsing a set of alignments and packing
|
292 |
+
them in a tensor (see utils.parse_alignment)
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(
|
296 |
+
self,
|
297 |
+
alignment_parser: tp.Callable[[str], torch.IntTensor],
|
298 |
+
) -> None:
|
299 |
+
super().__init__()
|
300 |
+
self.alignment_parser = alignment_parser
|
301 |
+
|
302 |
+
def binarize_line(
|
303 |
+
self,
|
304 |
+
line: str,
|
305 |
+
summary: BinarizeSummary,
|
306 |
+
):
|
307 |
+
ids = self.alignment_parser(line)
|
308 |
+
summary.num_seq += 1
|
309 |
+
summary.num_tok += len(ids)
|
310 |
+
return ids
|
311 |
+
|
312 |
+
|
313 |
+
class LegacyBinarizer:
|
314 |
+
@classmethod
|
315 |
+
def binarize(
|
316 |
+
cls,
|
317 |
+
filename: str,
|
318 |
+
dico: Dictionary,
|
319 |
+
consumer: tp.Callable[[torch.IntTensor], None],
|
320 |
+
tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
|
321 |
+
append_eos: bool = True,
|
322 |
+
reverse_order: bool = False,
|
323 |
+
offset: int = 0,
|
324 |
+
end: int = -1,
|
325 |
+
already_numberized: bool = False,
|
326 |
+
) -> tp.Dict[str, int]:
|
327 |
+
binarizer = VocabularyDatasetBinarizer(
|
328 |
+
dict=dico,
|
329 |
+
tokenize=tokenize,
|
330 |
+
append_eos=append_eos,
|
331 |
+
reverse_order=reverse_order,
|
332 |
+
already_numberized=already_numberized,
|
333 |
+
)
|
334 |
+
return cls._consume_file(
|
335 |
+
filename,
|
336 |
+
binarizer,
|
337 |
+
consumer,
|
338 |
+
offset_start=offset,
|
339 |
+
offset_end=end,
|
340 |
+
)
|
341 |
+
|
342 |
+
@classmethod
|
343 |
+
def binarize_alignments(
|
344 |
+
cls,
|
345 |
+
filename: str,
|
346 |
+
alignment_parser: tp.Callable[[str], torch.IntTensor],
|
347 |
+
consumer: tp.Callable[[torch.IntTensor], None],
|
348 |
+
offset: int = 0,
|
349 |
+
end: int = -1,
|
350 |
+
) -> tp.Dict[str, int]:
|
351 |
+
binarizer = AlignmentDatasetBinarizer(alignment_parser)
|
352 |
+
return cls._consume_file(
|
353 |
+
filename,
|
354 |
+
binarizer,
|
355 |
+
consumer,
|
356 |
+
offset_start=offset,
|
357 |
+
offset_end=end,
|
358 |
+
)
|
359 |
+
|
360 |
+
@staticmethod
|
361 |
+
def _consume_file(
|
362 |
+
filename: str,
|
363 |
+
binarizer: Binarizer,
|
364 |
+
consumer: tp.Callable[[torch.IntTensor], None],
|
365 |
+
offset_start: int,
|
366 |
+
offset_end: int,
|
367 |
+
) -> tp.Dict[str, int]:
|
368 |
+
summary = BinarizeSummary()
|
369 |
+
|
370 |
+
with Chunker(
|
371 |
+
PathManager.get_local_path(filename), offset_start, offset_end
|
372 |
+
) as line_iterator:
|
373 |
+
for line in line_iterator:
|
374 |
+
consumer(binarizer.binarize_line(line, summary))
|
375 |
+
|
376 |
+
return {
|
377 |
+
"nseq": summary.num_seq,
|
378 |
+
"nunk": summary.num_replaced,
|
379 |
+
"ntok": summary.num_tok,
|
380 |
+
"replaced": summary.replaced,
|
381 |
+
}
|
modules/voice_conversion/fairseq/checkpoint_utils.py
ADDED
@@ -0,0 +1,905 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import ast
|
7 |
+
import collections
|
8 |
+
import contextlib
|
9 |
+
import inspect
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import re
|
13 |
+
import time
|
14 |
+
import traceback
|
15 |
+
from collections import OrderedDict
|
16 |
+
from pathlib import Path
|
17 |
+
from typing import Any, Dict, Optional, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
from fairseq.data import data_utils
|
22 |
+
from fairseq.dataclass.configs import CheckpointConfig
|
23 |
+
from fairseq.dataclass.utils import (
|
24 |
+
convert_namespace_to_omegaconf,
|
25 |
+
overwrite_args_by_name,
|
26 |
+
)
|
27 |
+
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
|
28 |
+
from fairseq.file_io import PathManager
|
29 |
+
from fairseq.models import FairseqDecoder, FairseqEncoder
|
30 |
+
from omegaconf import DictConfig, OmegaConf, open_dict
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
|
36 |
+
from fairseq import meters
|
37 |
+
|
38 |
+
# only one worker should attempt to create the required dir
|
39 |
+
if trainer.data_parallel_rank == 0:
|
40 |
+
os.makedirs(cfg.save_dir, exist_ok=True)
|
41 |
+
|
42 |
+
prev_best = getattr(save_checkpoint, "best", val_loss)
|
43 |
+
if val_loss is not None:
|
44 |
+
best_function = max if cfg.maximize_best_checkpoint_metric else min
|
45 |
+
save_checkpoint.best = best_function(val_loss, prev_best)
|
46 |
+
|
47 |
+
if cfg.no_save:
|
48 |
+
return
|
49 |
+
|
50 |
+
trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state
|
51 |
+
|
52 |
+
if not trainer.should_save_checkpoint_on_current_rank:
|
53 |
+
if trainer.always_call_state_dict_during_save_checkpoint:
|
54 |
+
trainer.state_dict()
|
55 |
+
return
|
56 |
+
|
57 |
+
write_timer = meters.StopwatchMeter()
|
58 |
+
write_timer.start()
|
59 |
+
|
60 |
+
epoch = epoch_itr.epoch
|
61 |
+
end_of_epoch = epoch_itr.end_of_epoch()
|
62 |
+
updates = trainer.get_num_updates()
|
63 |
+
|
64 |
+
logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
|
65 |
+
|
66 |
+
def is_better(a, b):
|
67 |
+
return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
|
68 |
+
|
69 |
+
suffix = trainer.checkpoint_suffix
|
70 |
+
checkpoint_conds = collections.OrderedDict()
|
71 |
+
checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
|
72 |
+
end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
|
73 |
+
)
|
74 |
+
checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
|
75 |
+
not end_of_epoch
|
76 |
+
and cfg.save_interval_updates > 0
|
77 |
+
and updates % cfg.save_interval_updates == 0
|
78 |
+
)
|
79 |
+
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
|
80 |
+
not hasattr(save_checkpoint, "best")
|
81 |
+
or is_better(val_loss, save_checkpoint.best)
|
82 |
+
)
|
83 |
+
if val_loss is not None and cfg.keep_best_checkpoints > 0:
|
84 |
+
worst_best = getattr(save_checkpoint, "best", None)
|
85 |
+
chkpts = checkpoint_paths(
|
86 |
+
cfg.save_dir,
|
87 |
+
pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
|
88 |
+
cfg.best_checkpoint_metric, suffix
|
89 |
+
),
|
90 |
+
)
|
91 |
+
if len(chkpts) > 0:
|
92 |
+
p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
|
93 |
+
worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
|
94 |
+
# add random digits to resolve ties
|
95 |
+
with data_utils.numpy_seed(epoch, updates, val_loss):
|
96 |
+
rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)
|
97 |
+
|
98 |
+
checkpoint_conds[
|
99 |
+
"checkpoint.best_{}_{:.3f}{}{}.pt".format(
|
100 |
+
cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
|
101 |
+
)
|
102 |
+
] = worst_best is None or is_better(val_loss, worst_best)
|
103 |
+
checkpoint_conds[
|
104 |
+
"checkpoint_last{}.pt".format(suffix)
|
105 |
+
] = not cfg.no_last_checkpoints
|
106 |
+
|
107 |
+
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
|
108 |
+
if hasattr(save_checkpoint, "best"):
|
109 |
+
extra_state.update({"best": save_checkpoint.best})
|
110 |
+
|
111 |
+
checkpoints = [
|
112 |
+
os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
|
113 |
+
]
|
114 |
+
if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
|
115 |
+
trainer.save_checkpoint(checkpoints[0], extra_state)
|
116 |
+
for cp in checkpoints[1:]:
|
117 |
+
if cfg.write_checkpoints_asynchronously:
|
118 |
+
# TODO[ioPath]: Need to implement a delayed asynchronous
|
119 |
+
# file copying/moving feature.
|
120 |
+
logger.warning(
|
121 |
+
f"ioPath is not copying {checkpoints[0]} to {cp} "
|
122 |
+
"since async write mode is on."
|
123 |
+
)
|
124 |
+
else:
|
125 |
+
assert PathManager.copy(
|
126 |
+
checkpoints[0], cp, overwrite=True
|
127 |
+
), f"Failed to copy {checkpoints[0]} to {cp}"
|
128 |
+
|
129 |
+
write_timer.stop()
|
130 |
+
logger.info(
|
131 |
+
"Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
|
132 |
+
checkpoints[0], epoch, updates, val_loss, write_timer.sum
|
133 |
+
)
|
134 |
+
)
|
135 |
+
|
136 |
+
if not end_of_epoch and cfg.keep_interval_updates > 0:
|
137 |
+
# remove old checkpoints; checkpoints are sorted in descending order
|
138 |
+
if cfg.keep_interval_updates_pattern == -1:
|
139 |
+
checkpoints = checkpoint_paths(
|
140 |
+
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
|
141 |
+
)
|
142 |
+
else:
|
143 |
+
checkpoints = checkpoint_paths(
|
144 |
+
cfg.save_dir,
|
145 |
+
pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
|
146 |
+
keep_match=True,
|
147 |
+
)
|
148 |
+
checkpoints = [
|
149 |
+
x[0]
|
150 |
+
for x in checkpoints
|
151 |
+
if x[1] % cfg.keep_interval_updates_pattern != 0
|
152 |
+
]
|
153 |
+
|
154 |
+
for old_chk in checkpoints[cfg.keep_interval_updates :]:
|
155 |
+
if os.path.lexists(old_chk):
|
156 |
+
os.remove(old_chk)
|
157 |
+
elif PathManager.exists(old_chk):
|
158 |
+
PathManager.rm(old_chk)
|
159 |
+
|
160 |
+
if cfg.keep_last_epochs > 0:
|
161 |
+
# remove old epoch checkpoints; checkpoints are sorted in descending order
|
162 |
+
checkpoints = checkpoint_paths(
|
163 |
+
cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
|
164 |
+
)
|
165 |
+
for old_chk in checkpoints[cfg.keep_last_epochs :]:
|
166 |
+
if os.path.lexists(old_chk):
|
167 |
+
os.remove(old_chk)
|
168 |
+
elif PathManager.exists(old_chk):
|
169 |
+
PathManager.rm(old_chk)
|
170 |
+
|
171 |
+
if cfg.keep_best_checkpoints > 0:
|
172 |
+
# only keep the best N checkpoints according to validation metric
|
173 |
+
checkpoints = checkpoint_paths(
|
174 |
+
cfg.save_dir,
|
175 |
+
pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
|
176 |
+
cfg.best_checkpoint_metric, suffix
|
177 |
+
),
|
178 |
+
)
|
179 |
+
if not cfg.maximize_best_checkpoint_metric:
|
180 |
+
checkpoints = checkpoints[::-1]
|
181 |
+
for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
|
182 |
+
if os.path.lexists(old_chk):
|
183 |
+
os.remove(old_chk)
|
184 |
+
elif PathManager.exists(old_chk):
|
185 |
+
PathManager.rm(old_chk)
|
186 |
+
|
187 |
+
|
188 |
+
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
|
189 |
+
"""
|
190 |
+
Load a checkpoint and restore the training iterator.
|
191 |
+
|
192 |
+
*passthrough_args* will be passed through to
|
193 |
+
``trainer.get_train_iterator``.
|
194 |
+
"""
|
195 |
+
|
196 |
+
reset_optimizer = cfg.reset_optimizer
|
197 |
+
reset_lr_scheduler = cfg.reset_lr_scheduler
|
198 |
+
optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
|
199 |
+
reset_meters = cfg.reset_meters
|
200 |
+
reset_dataloader = cfg.reset_dataloader
|
201 |
+
|
202 |
+
if cfg.finetune_from_model is not None and (
|
203 |
+
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
|
204 |
+
):
|
205 |
+
raise ValueError(
|
206 |
+
"--finetune-from-model can not be set together with either --reset-optimizer"
|
207 |
+
" or reset_lr_scheduler or reset_meters or reset_dataloader"
|
208 |
+
)
|
209 |
+
|
210 |
+
suffix = trainer.checkpoint_suffix
|
211 |
+
if (
|
212 |
+
cfg.restore_file == "checkpoint_last.pt"
|
213 |
+
): # default value of restore_file is 'checkpoint_last.pt'
|
214 |
+
checkpoint_path = os.path.join(
|
215 |
+
cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
|
216 |
+
)
|
217 |
+
first_launch = not PathManager.exists(checkpoint_path)
|
218 |
+
if first_launch and getattr(cfg, "continue_once", None) is not None:
|
219 |
+
checkpoint_path = cfg.continue_once
|
220 |
+
elif cfg.finetune_from_model is not None and first_launch:
|
221 |
+
# if there is no last checkpoint to restore, start the finetune from pretrained model
|
222 |
+
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
|
223 |
+
if PathManager.exists(cfg.finetune_from_model):
|
224 |
+
checkpoint_path = cfg.finetune_from_model
|
225 |
+
reset_optimizer = True
|
226 |
+
reset_lr_scheduler = True
|
227 |
+
reset_meters = True
|
228 |
+
reset_dataloader = True
|
229 |
+
logger.info(
|
230 |
+
f"loading pretrained model from {checkpoint_path}: "
|
231 |
+
"optimizer, lr scheduler, meters, dataloader will be reset"
|
232 |
+
)
|
233 |
+
else:
|
234 |
+
raise ValueError(
|
235 |
+
f"--finetune-from-model {cfg.finetune_from_model} does not exist"
|
236 |
+
)
|
237 |
+
elif suffix is not None:
|
238 |
+
checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
|
239 |
+
else:
|
240 |
+
checkpoint_path = cfg.restore_file
|
241 |
+
|
242 |
+
if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
|
243 |
+
raise ValueError(
|
244 |
+
"--finetune-from-model and --restore-file (non-default value) "
|
245 |
+
"can not be specified together: " + str(cfg)
|
246 |
+
)
|
247 |
+
|
248 |
+
extra_state = trainer.load_checkpoint(
|
249 |
+
checkpoint_path,
|
250 |
+
reset_optimizer,
|
251 |
+
reset_lr_scheduler,
|
252 |
+
optimizer_overrides,
|
253 |
+
reset_meters=reset_meters,
|
254 |
+
)
|
255 |
+
|
256 |
+
if (
|
257 |
+
extra_state is not None
|
258 |
+
and "best" in extra_state
|
259 |
+
and not reset_optimizer
|
260 |
+
and not reset_meters
|
261 |
+
):
|
262 |
+
save_checkpoint.best = extra_state["best"]
|
263 |
+
|
264 |
+
if extra_state is not None and not reset_dataloader:
|
265 |
+
# restore iterator from checkpoint
|
266 |
+
itr_state = extra_state["train_iterator"]
|
267 |
+
epoch_itr = trainer.get_train_iterator(
|
268 |
+
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
|
269 |
+
)
|
270 |
+
epoch_itr.load_state_dict(itr_state)
|
271 |
+
else:
|
272 |
+
epoch_itr = trainer.get_train_iterator(
|
273 |
+
epoch=1, load_dataset=True, **passthrough_args
|
274 |
+
)
|
275 |
+
|
276 |
+
trainer.lr_step(epoch_itr.epoch)
|
277 |
+
|
278 |
+
return extra_state, epoch_itr
|
279 |
+
|
280 |
+
|
281 |
+
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
|
282 |
+
"""Loads a checkpoint to CPU (with upgrading for backward compatibility).
|
283 |
+
|
284 |
+
If doing single-GPU training or if the checkpoint is only being loaded by at
|
285 |
+
most one process on each node (current default behavior is for only rank 0
|
286 |
+
to read the checkpoint from disk), load_on_all_ranks should be False to
|
287 |
+
avoid errors from torch.distributed not having been initialized or
|
288 |
+
torch.distributed.barrier() hanging.
|
289 |
+
|
290 |
+
If all processes on each node may be loading the checkpoint
|
291 |
+
simultaneously, load_on_all_ranks should be set to True to avoid I/O
|
292 |
+
conflicts.
|
293 |
+
|
294 |
+
There's currently no support for > 1 but < all processes loading the
|
295 |
+
checkpoint on each node.
|
296 |
+
"""
|
297 |
+
local_path = PathManager.get_local_path(path)
|
298 |
+
# The locally cached file returned by get_local_path() may be stale for
|
299 |
+
# remote files that are periodically updated/overwritten (ex:
|
300 |
+
# checkpoint_last.pt) - so we remove the local copy, sync across processes
|
301 |
+
# (if needed), and then download a fresh copy.
|
302 |
+
if local_path != path and PathManager.path_requires_pathmanager(path):
|
303 |
+
try:
|
304 |
+
os.remove(local_path)
|
305 |
+
except FileNotFoundError:
|
306 |
+
# With potentially multiple processes removing the same file, the
|
307 |
+
# file being missing is benign (missing_ok isn't available until
|
308 |
+
# Python 3.8).
|
309 |
+
pass
|
310 |
+
if load_on_all_ranks:
|
311 |
+
torch.distributed.barrier()
|
312 |
+
local_path = PathManager.get_local_path(path)
|
313 |
+
|
314 |
+
with open(local_path, "rb") as f:
|
315 |
+
state = torch.load(f, map_location=torch.device("cpu"))
|
316 |
+
|
317 |
+
if "args" in state and state["args"] is not None and arg_overrides is not None:
|
318 |
+
args = state["args"]
|
319 |
+
for arg_name, arg_val in arg_overrides.items():
|
320 |
+
setattr(args, arg_name, arg_val)
|
321 |
+
|
322 |
+
if "cfg" in state and state["cfg"] is not None:
|
323 |
+
|
324 |
+
# hack to be able to set Namespace in dict config. this should be removed when we update to newer
|
325 |
+
# omegaconf version that supports object flags, or when we migrate all existing models
|
326 |
+
from omegaconf import __version__ as oc_version
|
327 |
+
from omegaconf import _utils
|
328 |
+
|
329 |
+
if oc_version < "2.2":
|
330 |
+
old_primitive = _utils.is_primitive_type
|
331 |
+
_utils.is_primitive_type = lambda _: True
|
332 |
+
|
333 |
+
state["cfg"] = OmegaConf.create(state["cfg"])
|
334 |
+
|
335 |
+
_utils.is_primitive_type = old_primitive
|
336 |
+
OmegaConf.set_struct(state["cfg"], True)
|
337 |
+
else:
|
338 |
+
state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
|
339 |
+
|
340 |
+
if arg_overrides is not None:
|
341 |
+
overwrite_args_by_name(state["cfg"], arg_overrides)
|
342 |
+
|
343 |
+
state = _upgrade_state_dict(state)
|
344 |
+
return state
|
345 |
+
|
346 |
+
|
347 |
+
def load_model_ensemble(
|
348 |
+
filenames,
|
349 |
+
arg_overrides: Optional[Dict[str, Any]] = None,
|
350 |
+
task=None,
|
351 |
+
strict=True,
|
352 |
+
suffix="",
|
353 |
+
num_shards=1,
|
354 |
+
state=None,
|
355 |
+
):
|
356 |
+
"""Loads an ensemble of models.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
filenames (List[str]): checkpoint files to load
|
360 |
+
arg_overrides (Dict[str,Any], optional): override model args that
|
361 |
+
were used during model training
|
362 |
+
task (fairseq.tasks.FairseqTask, optional): task to use for loading
|
363 |
+
"""
|
364 |
+
assert not (
|
365 |
+
strict and num_shards > 1
|
366 |
+
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
367 |
+
ensemble, args, _task = load_model_ensemble_and_task(
|
368 |
+
filenames,
|
369 |
+
arg_overrides,
|
370 |
+
task,
|
371 |
+
strict,
|
372 |
+
suffix,
|
373 |
+
num_shards,
|
374 |
+
state,
|
375 |
+
)
|
376 |
+
return ensemble, args
|
377 |
+
|
378 |
+
|
379 |
+
def get_maybe_sharded_checkpoint_filename(
|
380 |
+
filename: str, suffix: str, shard_idx: int, num_shards: int
|
381 |
+
) -> str:
|
382 |
+
orig_filename = filename
|
383 |
+
filename = filename.replace(".pt", suffix + ".pt")
|
384 |
+
fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
|
385 |
+
model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
|
386 |
+
if PathManager.exists(fsdp_filename):
|
387 |
+
return fsdp_filename
|
388 |
+
elif num_shards > 1:
|
389 |
+
return model_parallel_filename
|
390 |
+
else:
|
391 |
+
return filename
|
392 |
+
|
393 |
+
|
394 |
+
def load_model_ensemble_and_task(
|
395 |
+
filenames,
|
396 |
+
arg_overrides: Optional[Dict[str, Any]] = None,
|
397 |
+
task=None,
|
398 |
+
strict=True,
|
399 |
+
suffix="",
|
400 |
+
num_shards=1,
|
401 |
+
state=None,
|
402 |
+
):
|
403 |
+
assert state is None or len(filenames) == 1
|
404 |
+
|
405 |
+
from fairseq import tasks
|
406 |
+
|
407 |
+
assert not (
|
408 |
+
strict and num_shards > 1
|
409 |
+
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
410 |
+
ensemble = []
|
411 |
+
cfg = None
|
412 |
+
for filename in filenames:
|
413 |
+
orig_filename = filename
|
414 |
+
model_shard_state = {"shard_weights": [], "shard_metadata": []}
|
415 |
+
assert num_shards > 0
|
416 |
+
st = time.time()
|
417 |
+
for shard_idx in range(num_shards):
|
418 |
+
filename = get_maybe_sharded_checkpoint_filename(
|
419 |
+
orig_filename, suffix, shard_idx, num_shards
|
420 |
+
)
|
421 |
+
|
422 |
+
if not PathManager.exists(filename):
|
423 |
+
raise IOError("Model file not found: {}".format(filename))
|
424 |
+
if state is None:
|
425 |
+
state = load_checkpoint_to_cpu(filename, arg_overrides)
|
426 |
+
if "args" in state and state["args"] is not None:
|
427 |
+
cfg = convert_namespace_to_omegaconf(state["args"])
|
428 |
+
elif "cfg" in state and state["cfg"] is not None:
|
429 |
+
cfg = state["cfg"]
|
430 |
+
else:
|
431 |
+
raise RuntimeError(
|
432 |
+
f"Neither args nor cfg exist in state keys = {state.keys()}"
|
433 |
+
)
|
434 |
+
|
435 |
+
if task is None:
|
436 |
+
task = tasks.setup_task(cfg.task)
|
437 |
+
|
438 |
+
if "task_state" in state:
|
439 |
+
task.load_state_dict(state["task_state"])
|
440 |
+
|
441 |
+
if "fsdp_metadata" in state and num_shards > 1:
|
442 |
+
model_shard_state["shard_weights"].append(state["model"])
|
443 |
+
model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
|
444 |
+
# check FSDP import before the code goes too far
|
445 |
+
if not has_FSDP:
|
446 |
+
raise ImportError(
|
447 |
+
"Cannot find FullyShardedDataParallel. "
|
448 |
+
"Please install fairscale with: pip install fairscale"
|
449 |
+
)
|
450 |
+
if shard_idx == num_shards - 1:
|
451 |
+
consolidated_model_state = FSDP.consolidate_shard_weights(
|
452 |
+
shard_weights=model_shard_state["shard_weights"],
|
453 |
+
shard_metadata=model_shard_state["shard_metadata"],
|
454 |
+
)
|
455 |
+
model = task.build_model(cfg.model)
|
456 |
+
if (
|
457 |
+
"optimizer_history" in state
|
458 |
+
and len(state["optimizer_history"]) > 0
|
459 |
+
and "num_updates" in state["optimizer_history"][-1]
|
460 |
+
):
|
461 |
+
model.set_num_updates(
|
462 |
+
state["optimizer_history"][-1]["num_updates"]
|
463 |
+
)
|
464 |
+
model.load_state_dict(
|
465 |
+
consolidated_model_state, strict=strict, model_cfg=cfg.model
|
466 |
+
)
|
467 |
+
else:
|
468 |
+
# model parallel checkpoint or unsharded checkpoint
|
469 |
+
# support old external tasks
|
470 |
+
|
471 |
+
argspec = inspect.getfullargspec(task.build_model)
|
472 |
+
if "from_checkpoint" in argspec.args:
|
473 |
+
model = task.build_model(cfg.model, from_checkpoint=True)
|
474 |
+
else:
|
475 |
+
model = task.build_model(cfg.model)
|
476 |
+
if (
|
477 |
+
"optimizer_history" in state
|
478 |
+
and len(state["optimizer_history"]) > 0
|
479 |
+
and "num_updates" in state["optimizer_history"][-1]
|
480 |
+
):
|
481 |
+
model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
|
482 |
+
model.load_state_dict(
|
483 |
+
state["model"], strict=strict, model_cfg=cfg.model
|
484 |
+
)
|
485 |
+
|
486 |
+
# reset state so it gets loaded for the next model in ensemble
|
487 |
+
state = None
|
488 |
+
if shard_idx % 10 == 0 and shard_idx > 0:
|
489 |
+
elapsed = time.time() - st
|
490 |
+
logger.info(
|
491 |
+
f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
|
492 |
+
)
|
493 |
+
|
494 |
+
# build model for ensemble
|
495 |
+
ensemble.append(model)
|
496 |
+
return ensemble, cfg, task
|
497 |
+
|
498 |
+
|
499 |
+
def load_model_ensemble_and_task_from_hf_hub(
|
500 |
+
model_id,
|
501 |
+
cache_dir: Optional[str] = None,
|
502 |
+
arg_overrides: Optional[Dict[str, Any]] = None,
|
503 |
+
**kwargs: Any,
|
504 |
+
):
|
505 |
+
try:
|
506 |
+
from huggingface_hub import snapshot_download
|
507 |
+
except ImportError:
|
508 |
+
raise ImportError(
|
509 |
+
"You need to install huggingface_hub to use `load_from_hf_hub`. "
|
510 |
+
"See https://pypi.org/project/huggingface-hub/ for installation."
|
511 |
+
)
|
512 |
+
|
513 |
+
library_name = "fairseq"
|
514 |
+
cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
|
515 |
+
cache_dir = snapshot_download(
|
516 |
+
model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
|
517 |
+
)
|
518 |
+
|
519 |
+
_arg_overrides = arg_overrides or {}
|
520 |
+
_arg_overrides["data"] = cache_dir
|
521 |
+
return load_model_ensemble_and_task(
|
522 |
+
[p.as_posix() for p in Path(cache_dir).glob("*.pt")],
|
523 |
+
arg_overrides=_arg_overrides,
|
524 |
+
)
|
525 |
+
|
526 |
+
|
527 |
+
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
|
528 |
+
"""Retrieves all checkpoints found in `path` directory.
|
529 |
+
|
530 |
+
Checkpoints are identified by matching filename to the specified pattern. If
|
531 |
+
the pattern contains groups, the result will be sorted by the first group in
|
532 |
+
descending order.
|
533 |
+
"""
|
534 |
+
pt_regexp = re.compile(pattern)
|
535 |
+
files = PathManager.ls(path)
|
536 |
+
|
537 |
+
entries = []
|
538 |
+
for i, f in enumerate(files):
|
539 |
+
m = pt_regexp.fullmatch(f)
|
540 |
+
if m is not None:
|
541 |
+
idx = float(m.group(1)) if len(m.groups()) > 0 else i
|
542 |
+
entries.append((idx, m.group(0)))
|
543 |
+
if keep_match:
|
544 |
+
return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
|
545 |
+
else:
|
546 |
+
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
|
547 |
+
|
548 |
+
|
549 |
+
def torch_persistent_save(obj, filename, async_write: bool = False):
|
550 |
+
if async_write:
|
551 |
+
with PathManager.opena(filename, "wb") as f:
|
552 |
+
_torch_persistent_save(obj, f)
|
553 |
+
else:
|
554 |
+
if PathManager.supports_rename(filename):
|
555 |
+
# do atomic save
|
556 |
+
with PathManager.open(filename + ".tmp", "wb") as f:
|
557 |
+
_torch_persistent_save(obj, f)
|
558 |
+
PathManager.rename(filename + ".tmp", filename)
|
559 |
+
else:
|
560 |
+
# fallback to non-atomic save
|
561 |
+
with PathManager.open(filename, "wb") as f:
|
562 |
+
_torch_persistent_save(obj, f)
|
563 |
+
|
564 |
+
|
565 |
+
def _torch_persistent_save(obj, f):
|
566 |
+
if isinstance(f, str):
|
567 |
+
with PathManager.open(f, "wb") as h:
|
568 |
+
torch_persistent_save(obj, h)
|
569 |
+
return
|
570 |
+
for i in range(3):
|
571 |
+
try:
|
572 |
+
return torch.save(obj, f)
|
573 |
+
except Exception:
|
574 |
+
if i == 2:
|
575 |
+
logger.error(traceback.format_exc())
|
576 |
+
raise
|
577 |
+
|
578 |
+
|
579 |
+
def _upgrade_state_dict(state):
|
580 |
+
"""Helper for upgrading old model checkpoints."""
|
581 |
+
|
582 |
+
# add optimizer_history
|
583 |
+
if "optimizer_history" not in state:
|
584 |
+
state["optimizer_history"] = [
|
585 |
+
{"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
|
586 |
+
]
|
587 |
+
state["last_optimizer_state"] = state["optimizer"]
|
588 |
+
del state["optimizer"]
|
589 |
+
del state["best_loss"]
|
590 |
+
# move extra_state into sub-dictionary
|
591 |
+
if "epoch" in state and "extra_state" not in state:
|
592 |
+
state["extra_state"] = {
|
593 |
+
"epoch": state["epoch"],
|
594 |
+
"batch_offset": state["batch_offset"],
|
595 |
+
"val_loss": state["val_loss"],
|
596 |
+
}
|
597 |
+
del state["epoch"]
|
598 |
+
del state["batch_offset"]
|
599 |
+
del state["val_loss"]
|
600 |
+
# reduce optimizer history's memory usage (only keep the last state)
|
601 |
+
if "optimizer" in state["optimizer_history"][-1]:
|
602 |
+
state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
|
603 |
+
for optim_hist in state["optimizer_history"]:
|
604 |
+
del optim_hist["optimizer"]
|
605 |
+
# record the optimizer class name
|
606 |
+
if "optimizer_name" not in state["optimizer_history"][-1]:
|
607 |
+
state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
|
608 |
+
# move best_loss into lr_scheduler_state
|
609 |
+
if "lr_scheduler_state" not in state["optimizer_history"][-1]:
|
610 |
+
state["optimizer_history"][-1]["lr_scheduler_state"] = {
|
611 |
+
"best": state["optimizer_history"][-1]["best_loss"]
|
612 |
+
}
|
613 |
+
del state["optimizer_history"][-1]["best_loss"]
|
614 |
+
# keep track of number of updates
|
615 |
+
if "num_updates" not in state["optimizer_history"][-1]:
|
616 |
+
state["optimizer_history"][-1]["num_updates"] = 0
|
617 |
+
# use stateful training data iterator
|
618 |
+
if "train_iterator" not in state["extra_state"]:
|
619 |
+
state["extra_state"]["train_iterator"] = {
|
620 |
+
"epoch": state["extra_state"].get("epoch", 0),
|
621 |
+
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
|
622 |
+
}
|
623 |
+
|
624 |
+
# backward compatibility, cfg updates
|
625 |
+
if "args" in state and state["args"] is not None:
|
626 |
+
# old model checkpoints may not have separate source/target positions
|
627 |
+
if hasattr(state["args"], "max_positions") and not hasattr(
|
628 |
+
state["args"], "max_source_positions"
|
629 |
+
):
|
630 |
+
state["args"].max_source_positions = state["args"].max_positions
|
631 |
+
state["args"].max_target_positions = state["args"].max_positions
|
632 |
+
# default to translation task
|
633 |
+
if not hasattr(state["args"], "task"):
|
634 |
+
state["args"].task = "translation"
|
635 |
+
# --raw-text and --lazy-load are deprecated
|
636 |
+
if getattr(state["args"], "raw_text", False):
|
637 |
+
state["args"].dataset_impl = "raw"
|
638 |
+
elif getattr(state["args"], "lazy_load", False):
|
639 |
+
state["args"].dataset_impl = "lazy"
|
640 |
+
# epochs start at 1
|
641 |
+
if state["extra_state"]["train_iterator"] is not None:
|
642 |
+
state["extra_state"]["train_iterator"]["epoch"] = max(
|
643 |
+
state["extra_state"]["train_iterator"].get("epoch", 1), 1
|
644 |
+
)
|
645 |
+
# --remove-bpe ==> --postprocess
|
646 |
+
if hasattr(state["args"], "remove_bpe"):
|
647 |
+
state["args"].post_process = state["args"].remove_bpe
|
648 |
+
# --min-lr ==> --stop-min-lr
|
649 |
+
if hasattr(state["args"], "min_lr"):
|
650 |
+
state["args"].stop_min_lr = state["args"].min_lr
|
651 |
+
del state["args"].min_lr
|
652 |
+
# binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
|
653 |
+
if hasattr(state["args"], "criterion") and state["args"].criterion in [
|
654 |
+
"binary_cross_entropy",
|
655 |
+
"kd_binary_cross_entropy",
|
656 |
+
]:
|
657 |
+
state["args"].criterion = "wav2vec"
|
658 |
+
# remove log_keys if it's None (criteria will supply a default value of [])
|
659 |
+
if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
|
660 |
+
delattr(state["args"], "log_keys")
|
661 |
+
# speech_pretraining => audio pretraining
|
662 |
+
if (
|
663 |
+
hasattr(state["args"], "task")
|
664 |
+
and state["args"].task == "speech_pretraining"
|
665 |
+
):
|
666 |
+
state["args"].task = "audio_pretraining"
|
667 |
+
# audio_cpc => wav2vec
|
668 |
+
if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
|
669 |
+
state["args"].arch = "wav2vec"
|
670 |
+
# convert legacy float learning rate to List[float]
|
671 |
+
if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
|
672 |
+
state["args"].lr = [state["args"].lr]
|
673 |
+
# convert task data arg to a string instead of List[string]
|
674 |
+
if (
|
675 |
+
hasattr(state["args"], "data")
|
676 |
+
and isinstance(state["args"].data, list)
|
677 |
+
and len(state["args"].data) > 0
|
678 |
+
):
|
679 |
+
state["args"].data = state["args"].data[0]
|
680 |
+
|
681 |
+
state["cfg"] = convert_namespace_to_omegaconf(state["args"])
|
682 |
+
|
683 |
+
if "cfg" in state and state["cfg"] is not None:
|
684 |
+
cfg = state["cfg"]
|
685 |
+
with open_dict(cfg):
|
686 |
+
# any upgrades for Hydra-based configs
|
687 |
+
if (
|
688 |
+
"task" in cfg
|
689 |
+
and "eval_wer_config" in cfg.task
|
690 |
+
and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
|
691 |
+
):
|
692 |
+
cfg.task.eval_wer_config.print_alignment = "hard"
|
693 |
+
if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
|
694 |
+
cfg.generation.print_alignment = (
|
695 |
+
"hard" if cfg.generation.print_alignment else None
|
696 |
+
)
|
697 |
+
if (
|
698 |
+
"model" in cfg
|
699 |
+
and "w2v_args" in cfg.model
|
700 |
+
and cfg.model.w2v_args is not None
|
701 |
+
and (
|
702 |
+
hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
|
703 |
+
)
|
704 |
+
and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
|
705 |
+
and cfg.model.w2v_args.task.eval_wer_config is not None
|
706 |
+
and isinstance(
|
707 |
+
cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
|
708 |
+
)
|
709 |
+
):
|
710 |
+
cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
|
711 |
+
|
712 |
+
return state
|
713 |
+
|
714 |
+
|
715 |
+
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
|
716 |
+
"""Prune the given state_dict if desired for LayerDrop
|
717 |
+
(https://arxiv.org/abs/1909.11556).
|
718 |
+
|
719 |
+
Training with LayerDrop allows models to be robust to pruning at inference
|
720 |
+
time. This function prunes state_dict to allow smaller models to be loaded
|
721 |
+
from a larger model and re-maps the existing state_dict for this to occur.
|
722 |
+
|
723 |
+
It's called by functions that load models from checkpoints and does not
|
724 |
+
need to be called directly.
|
725 |
+
"""
|
726 |
+
arch = None
|
727 |
+
if model_cfg is not None:
|
728 |
+
arch = (
|
729 |
+
model_cfg._name
|
730 |
+
if isinstance(model_cfg, DictConfig)
|
731 |
+
else getattr(model_cfg, "arch", None)
|
732 |
+
)
|
733 |
+
|
734 |
+
if not model_cfg or arch is None or arch == "ptt_transformer":
|
735 |
+
# args should not be none, but don't crash if it is.
|
736 |
+
return state_dict
|
737 |
+
|
738 |
+
encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
|
739 |
+
decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
|
740 |
+
|
741 |
+
if not encoder_layers_to_keep and not decoder_layers_to_keep:
|
742 |
+
return state_dict
|
743 |
+
|
744 |
+
# apply pruning
|
745 |
+
logger.info(
|
746 |
+
"Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
|
747 |
+
)
|
748 |
+
|
749 |
+
def create_pruning_pass(layers_to_keep, layer_name):
|
750 |
+
keep_layers = sorted(
|
751 |
+
int(layer_string) for layer_string in layers_to_keep.split(",")
|
752 |
+
)
|
753 |
+
mapping_dict = {}
|
754 |
+
for i in range(len(keep_layers)):
|
755 |
+
mapping_dict[str(keep_layers[i])] = str(i)
|
756 |
+
|
757 |
+
regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
|
758 |
+
return {"substitution_regex": regex, "mapping_dict": mapping_dict}
|
759 |
+
|
760 |
+
pruning_passes = []
|
761 |
+
if encoder_layers_to_keep:
|
762 |
+
pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
|
763 |
+
if decoder_layers_to_keep:
|
764 |
+
pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
|
765 |
+
|
766 |
+
new_state_dict = {}
|
767 |
+
for layer_name in state_dict.keys():
|
768 |
+
match = re.search(r"\.layers\.(\d+)\.", layer_name)
|
769 |
+
# if layer has no number in it, it is a supporting layer, such as an
|
770 |
+
# embedding
|
771 |
+
if not match:
|
772 |
+
new_state_dict[layer_name] = state_dict[layer_name]
|
773 |
+
continue
|
774 |
+
|
775 |
+
# otherwise, layer should be pruned.
|
776 |
+
original_layer_number = match.group(1)
|
777 |
+
# figure out which mapping dict to replace from
|
778 |
+
for pruning_pass in pruning_passes:
|
779 |
+
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
|
780 |
+
"substitution_regex"
|
781 |
+
].search(layer_name):
|
782 |
+
new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
|
783 |
+
substitution_match = pruning_pass["substitution_regex"].search(
|
784 |
+
layer_name
|
785 |
+
)
|
786 |
+
new_state_key = (
|
787 |
+
layer_name[: substitution_match.start(1)]
|
788 |
+
+ new_layer_number
|
789 |
+
+ layer_name[substitution_match.end(1) :]
|
790 |
+
)
|
791 |
+
new_state_dict[new_state_key] = state_dict[layer_name]
|
792 |
+
|
793 |
+
# Since layers are now pruned, *_layers_to_keep are no longer needed.
|
794 |
+
# This is more of "It would make it work fix" rather than a proper fix.
|
795 |
+
if isinstance(model_cfg, DictConfig):
|
796 |
+
context = open_dict(model_cfg)
|
797 |
+
else:
|
798 |
+
context = contextlib.ExitStack()
|
799 |
+
with context:
|
800 |
+
if hasattr(model_cfg, "encoder_layers_to_keep"):
|
801 |
+
model_cfg.encoder_layers_to_keep = None
|
802 |
+
if hasattr(model_cfg, "decoder_layers_to_keep"):
|
803 |
+
model_cfg.decoder_layers_to_keep = None
|
804 |
+
|
805 |
+
return new_state_dict
|
806 |
+
|
807 |
+
|
808 |
+
def load_pretrained_component_from_model(
|
809 |
+
component: Union[FairseqEncoder, FairseqDecoder],
|
810 |
+
checkpoint: str,
|
811 |
+
strict: bool = True,
|
812 |
+
):
|
813 |
+
"""
|
814 |
+
Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
|
815 |
+
provided `component` object. If state_dict fails to load, there may be a
|
816 |
+
mismatch in the architecture of the corresponding `component` found in the
|
817 |
+
`checkpoint` file.
|
818 |
+
"""
|
819 |
+
if not PathManager.exists(checkpoint):
|
820 |
+
raise IOError("Model file not found: {}".format(checkpoint))
|
821 |
+
state = load_checkpoint_to_cpu(checkpoint)
|
822 |
+
if isinstance(component, FairseqEncoder):
|
823 |
+
component_type = "encoder"
|
824 |
+
elif isinstance(component, FairseqDecoder):
|
825 |
+
component_type = "decoder"
|
826 |
+
else:
|
827 |
+
raise ValueError(
|
828 |
+
"component to load must be either a FairseqEncoder or "
|
829 |
+
"FairseqDecoder. Loading other component types are not supported."
|
830 |
+
)
|
831 |
+
component_state_dict = OrderedDict()
|
832 |
+
for key in state["model"].keys():
|
833 |
+
if key.startswith(component_type):
|
834 |
+
# encoder.input_layers.0.0.weight --> input_layers.0.0.weight
|
835 |
+
component_subkey = key[len(component_type) + 1 :]
|
836 |
+
component_state_dict[component_subkey] = state["model"][key]
|
837 |
+
component.load_state_dict(component_state_dict, strict=strict)
|
838 |
+
return component
|
839 |
+
|
840 |
+
|
841 |
+
def verify_checkpoint_directory(save_dir: str) -> None:
|
842 |
+
if not os.path.exists(save_dir):
|
843 |
+
os.makedirs(save_dir, exist_ok=True)
|
844 |
+
temp_file_path = os.path.join(save_dir, "dummy")
|
845 |
+
try:
|
846 |
+
with open(temp_file_path, "w"):
|
847 |
+
pass
|
848 |
+
except OSError as e:
|
849 |
+
logger.warning(
|
850 |
+
"Unable to access checkpoint save directory: {}".format(save_dir)
|
851 |
+
)
|
852 |
+
raise e
|
853 |
+
else:
|
854 |
+
os.remove(temp_file_path)
|
855 |
+
|
856 |
+
|
857 |
+
def save_ema_as_checkpoint(src_path, dst_path):
|
858 |
+
state = load_ema_from_checkpoint(src_path)
|
859 |
+
torch_persistent_save(state, dst_path)
|
860 |
+
|
861 |
+
|
862 |
+
def load_ema_from_checkpoint(fpath):
|
863 |
+
"""Loads exponential moving averaged (EMA) checkpoint from input and
|
864 |
+
returns a model with ema weights.
|
865 |
+
|
866 |
+
Args:
|
867 |
+
fpath: A string path of checkpoint to load from.
|
868 |
+
|
869 |
+
Returns:
|
870 |
+
A dict of string keys mapping to various values. The 'model' key
|
871 |
+
from the returned dict should correspond to an OrderedDict mapping
|
872 |
+
string parameter names to torch Tensors.
|
873 |
+
"""
|
874 |
+
params_dict = collections.OrderedDict()
|
875 |
+
new_state = None
|
876 |
+
|
877 |
+
with PathManager.open(fpath, "rb") as f:
|
878 |
+
new_state = torch.load(
|
879 |
+
f,
|
880 |
+
map_location=(
|
881 |
+
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
|
882 |
+
),
|
883 |
+
)
|
884 |
+
|
885 |
+
# EMA model is stored in a separate "extra state"
|
886 |
+
model_params = new_state["extra_state"]["ema"]
|
887 |
+
|
888 |
+
for key in list(model_params.keys()):
|
889 |
+
p = model_params[key]
|
890 |
+
if isinstance(p, torch.HalfTensor):
|
891 |
+
p = p.float()
|
892 |
+
if key not in params_dict:
|
893 |
+
params_dict[key] = p.clone()
|
894 |
+
# NOTE: clone() is needed in case of p is a shared parameter
|
895 |
+
else:
|
896 |
+
raise ValueError("Key {} is repeated in EMA model params.".format(key))
|
897 |
+
|
898 |
+
if len(params_dict) == 0:
|
899 |
+
raise ValueError(
|
900 |
+
f"Input checkpoint path '{fpath}' does not contain "
|
901 |
+
"ema model weights, is this model trained with EMA?"
|
902 |
+
)
|
903 |
+
|
904 |
+
new_state["model"] = params_dict
|
905 |
+
return new_state
|
modules/voice_conversion/fairseq/data/__init__.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
"""isort:skip_file"""
|
6 |
+
|
7 |
+
from .dictionary import Dictionary, TruncatedDictionary
|
8 |
+
|
9 |
+
from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
|
10 |
+
|
11 |
+
from .base_wrapper_dataset import BaseWrapperDataset
|
12 |
+
|
13 |
+
from .add_target_dataset import AddTargetDataset
|
14 |
+
from .append_token_dataset import AppendTokenDataset
|
15 |
+
from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset
|
16 |
+
from .audio.hubert_dataset import HubertDataset
|
17 |
+
from .backtranslation_dataset import BacktranslationDataset
|
18 |
+
from .bucket_pad_length_dataset import BucketPadLengthDataset
|
19 |
+
from .colorize_dataset import ColorizeDataset
|
20 |
+
from .concat_dataset import ConcatDataset
|
21 |
+
from .concat_sentences_dataset import ConcatSentencesDataset
|
22 |
+
from .denoising_dataset import DenoisingDataset
|
23 |
+
from .id_dataset import IdDataset
|
24 |
+
from .indexed_dataset import (
|
25 |
+
IndexedCachedDataset,
|
26 |
+
IndexedDataset,
|
27 |
+
IndexedRawTextDataset,
|
28 |
+
MMapIndexedDataset,
|
29 |
+
)
|
30 |
+
from .language_pair_dataset import LanguagePairDataset
|
31 |
+
from .list_dataset import ListDataset
|
32 |
+
from .lm_context_window_dataset import LMContextWindowDataset
|
33 |
+
from .lru_cache_dataset import LRUCacheDataset
|
34 |
+
from .mask_tokens_dataset import MaskTokensDataset
|
35 |
+
from .monolingual_dataset import MonolingualDataset
|
36 |
+
from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
|
37 |
+
from .nested_dictionary_dataset import NestedDictionaryDataset
|
38 |
+
from .noising import NoisingDataset
|
39 |
+
from .numel_dataset import NumelDataset
|
40 |
+
from .num_samples_dataset import NumSamplesDataset
|
41 |
+
from .offset_tokens_dataset import OffsetTokensDataset
|
42 |
+
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
|
43 |
+
from .prepend_dataset import PrependDataset
|
44 |
+
from .prepend_token_dataset import PrependTokenDataset
|
45 |
+
from .raw_label_dataset import RawLabelDataset
|
46 |
+
from .replace_dataset import ReplaceDataset
|
47 |
+
from .resampling_dataset import ResamplingDataset
|
48 |
+
from .roll_dataset import RollDataset
|
49 |
+
from .round_robin_zip_datasets import RoundRobinZipDatasets
|
50 |
+
from .sort_dataset import SortDataset
|
51 |
+
from .strip_token_dataset import StripTokenDataset
|
52 |
+
from .subsample_dataset import SubsampleDataset
|
53 |
+
from .token_block_dataset import TokenBlockDataset
|
54 |
+
from .transform_eos_dataset import TransformEosDataset
|
55 |
+
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
|
56 |
+
from .shorten_dataset import TruncateDataset, RandomCropDataset
|
57 |
+
from .multilingual.sampled_multi_dataset import SampledMultiDataset
|
58 |
+
from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
|
59 |
+
from .fasta_dataset import FastaDataset, EncodedFastaDataset
|
60 |
+
from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset
|
61 |
+
|
62 |
+
from .iterators import (
|
63 |
+
CountingIterator,
|
64 |
+
EpochBatchIterator,
|
65 |
+
GroupedIterator,
|
66 |
+
ShardedIterator,
|
67 |
+
)
|
68 |
+
|
69 |
+
__all__ = [
|
70 |
+
"AddTargetDataset",
|
71 |
+
"AppendTokenDataset",
|
72 |
+
"BacktranslationDataset",
|
73 |
+
"BaseWrapperDataset",
|
74 |
+
"BinarizedAudioDataset",
|
75 |
+
"BucketPadLengthDataset",
|
76 |
+
"ColorizeDataset",
|
77 |
+
"ConcatDataset",
|
78 |
+
"ConcatSentencesDataset",
|
79 |
+
"CountingIterator",
|
80 |
+
"DenoisingDataset",
|
81 |
+
"Dictionary",
|
82 |
+
"EncodedFastaDataset",
|
83 |
+
"EpochBatchIterator",
|
84 |
+
"FairseqDataset",
|
85 |
+
"FairseqIterableDataset",
|
86 |
+
"FastaDataset",
|
87 |
+
"FileAudioDataset",
|
88 |
+
"GroupedIterator",
|
89 |
+
"HubertDataset",
|
90 |
+
"IdDataset",
|
91 |
+
"IndexedCachedDataset",
|
92 |
+
"IndexedDataset",
|
93 |
+
"IndexedRawTextDataset",
|
94 |
+
"LanguagePairDataset",
|
95 |
+
"LeftPadDataset",
|
96 |
+
"ListDataset",
|
97 |
+
"LMContextWindowDataset",
|
98 |
+
"LRUCacheDataset",
|
99 |
+
"MaskTokensDataset",
|
100 |
+
"MMapIndexedDataset",
|
101 |
+
"MonolingualDataset",
|
102 |
+
"MultiCorpusSampledDataset",
|
103 |
+
"NestedDictionaryDataset",
|
104 |
+
"NoisingDataset",
|
105 |
+
"NumelDataset",
|
106 |
+
"NumSamplesDataset",
|
107 |
+
"OffsetTokensDataset",
|
108 |
+
"PadDataset",
|
109 |
+
"PrependDataset",
|
110 |
+
"PrependTokenDataset",
|
111 |
+
"RandomCropDataset",
|
112 |
+
"RawLabelDataset",
|
113 |
+
"ResamplingDataset",
|
114 |
+
"ReplaceDataset",
|
115 |
+
"RightPadDataset",
|
116 |
+
"RollDataset",
|
117 |
+
"RoundRobinZipDatasets",
|
118 |
+
"SampledMultiDataset",
|
119 |
+
"SampledMultiEpochDataset",
|
120 |
+
"ShardedIterator",
|
121 |
+
"SortDataset",
|
122 |
+
"StripTokenDataset",
|
123 |
+
"SubsampleDataset",
|
124 |
+
"TokenBlockDataset",
|
125 |
+
"TransformEosDataset",
|
126 |
+
"TransformEosLangPairDataset",
|
127 |
+
"TransformEosConcatLangPairDataset",
|
128 |
+
"TruncateDataset",
|
129 |
+
"TruncatedDictionary",
|
130 |
+
]
|
modules/voice_conversion/fairseq/data/add_target_dataset.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from . import BaseWrapperDataset, data_utils
|
9 |
+
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
|
10 |
+
|
11 |
+
|
12 |
+
class AddTargetDataset(BaseWrapperDataset):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
dataset,
|
16 |
+
labels,
|
17 |
+
pad,
|
18 |
+
eos,
|
19 |
+
batch_targets,
|
20 |
+
process_label=None,
|
21 |
+
label_len_fn=None,
|
22 |
+
add_to_input=False,
|
23 |
+
text_compression_level=TextCompressionLevel.none,
|
24 |
+
):
|
25 |
+
super().__init__(dataset)
|
26 |
+
self.labels = labels
|
27 |
+
self.batch_targets = batch_targets
|
28 |
+
self.pad = pad
|
29 |
+
self.eos = eos
|
30 |
+
self.process_label = process_label
|
31 |
+
self.label_len_fn = label_len_fn
|
32 |
+
self.add_to_input = add_to_input
|
33 |
+
self.text_compressor = TextCompressor(level=text_compression_level)
|
34 |
+
|
35 |
+
def get_label(self, index, process_fn=None):
|
36 |
+
lbl = self.labels[index]
|
37 |
+
lbl = self.text_compressor.decompress(lbl)
|
38 |
+
return lbl if process_fn is None else process_fn(lbl)
|
39 |
+
|
40 |
+
def __getitem__(self, index):
|
41 |
+
item = self.dataset[index]
|
42 |
+
item["label"] = self.get_label(index, process_fn=self.process_label)
|
43 |
+
return item
|
44 |
+
|
45 |
+
def size(self, index):
|
46 |
+
sz = self.dataset.size(index)
|
47 |
+
own_sz = self.label_len_fn(self.get_label(index))
|
48 |
+
return sz, own_sz
|
49 |
+
|
50 |
+
def collater(self, samples):
|
51 |
+
collated = self.dataset.collater(samples)
|
52 |
+
if len(collated) == 0:
|
53 |
+
return collated
|
54 |
+
indices = set(collated["id"].tolist())
|
55 |
+
target = [s["label"] for s in samples if s["id"] in indices]
|
56 |
+
|
57 |
+
if self.add_to_input:
|
58 |
+
eos = torch.LongTensor([self.eos])
|
59 |
+
prev_output_tokens = [torch.cat([eos, t], axis=-1) for t in target]
|
60 |
+
target = [torch.cat([t, eos], axis=-1) for t in target]
|
61 |
+
collated["net_input"]["prev_output_tokens"] = prev_output_tokens
|
62 |
+
|
63 |
+
if self.batch_targets:
|
64 |
+
collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
|
65 |
+
target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
|
66 |
+
collated["ntokens"] = collated["target_lengths"].sum().item()
|
67 |
+
if getattr(collated["net_input"], "prev_output_tokens", None):
|
68 |
+
collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
|
69 |
+
collated["net_input"]["prev_output_tokens"],
|
70 |
+
pad_idx=self.pad,
|
71 |
+
left_pad=False,
|
72 |
+
)
|
73 |
+
else:
|
74 |
+
collated["ntokens"] = sum([len(t) for t in target])
|
75 |
+
|
76 |
+
collated["target"] = target
|
77 |
+
return collated
|
78 |
+
|
79 |
+
def filter_indices_by_size(self, indices, max_sizes):
|
80 |
+
indices, ignored = data_utils._filter_by_size_dynamic(
|
81 |
+
indices, self.size, max_sizes
|
82 |
+
)
|
83 |
+
return indices, ignored
|
modules/voice_conversion/fairseq/data/append_token_dataset.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from . import BaseWrapperDataset
|
10 |
+
|
11 |
+
|
12 |
+
class AppendTokenDataset(BaseWrapperDataset):
|
13 |
+
def __init__(self, dataset, token=None):
|
14 |
+
super().__init__(dataset)
|
15 |
+
self.token = token
|
16 |
+
if token is not None:
|
17 |
+
self._sizes = np.array(dataset.sizes) + 1
|
18 |
+
else:
|
19 |
+
self._sizes = dataset.sizes
|
20 |
+
|
21 |
+
def __getitem__(self, idx):
|
22 |
+
item = self.dataset[idx]
|
23 |
+
if self.token is not None:
|
24 |
+
item = torch.cat([item, item.new([self.token])])
|
25 |
+
return item
|
26 |
+
|
27 |
+
@property
|
28 |
+
def sizes(self):
|
29 |
+
return self._sizes
|
30 |
+
|
31 |
+
def num_tokens(self, index):
|
32 |
+
n = self.dataset.num_tokens(index)
|
33 |
+
if self.token is not None:
|
34 |
+
n += 1
|
35 |
+
return n
|
36 |
+
|
37 |
+
def size(self, index):
|
38 |
+
n = self.dataset.size(index)
|
39 |
+
if self.token is not None:
|
40 |
+
n += 1
|
41 |
+
return n
|
modules/voice_conversion/fairseq/data/audio/__init__.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Dict, Optional
|
3 |
+
import importlib
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class AudioTransform(ABC):
|
9 |
+
@classmethod
|
10 |
+
@abstractmethod
|
11 |
+
def from_config_dict(cls, config: Optional[Dict] = None):
|
12 |
+
pass
|
13 |
+
|
14 |
+
|
15 |
+
class CompositeAudioTransform(AudioTransform):
|
16 |
+
def _from_config_dict(
|
17 |
+
cls,
|
18 |
+
transform_type,
|
19 |
+
get_audio_transform,
|
20 |
+
composite_cls,
|
21 |
+
config=None,
|
22 |
+
return_empty=False,
|
23 |
+
):
|
24 |
+
_config = {} if config is None else config
|
25 |
+
_transforms = _config.get(f"{transform_type}_transforms")
|
26 |
+
|
27 |
+
if _transforms is None:
|
28 |
+
if return_empty:
|
29 |
+
_transforms = []
|
30 |
+
else:
|
31 |
+
return None
|
32 |
+
|
33 |
+
transforms = [
|
34 |
+
get_audio_transform(_t).from_config_dict(_config.get(_t))
|
35 |
+
for _t in _transforms
|
36 |
+
]
|
37 |
+
return composite_cls(transforms)
|
38 |
+
|
39 |
+
def __init__(self, transforms):
|
40 |
+
self.transforms = [t for t in transforms if t is not None]
|
41 |
+
|
42 |
+
def __call__(self, x):
|
43 |
+
for t in self.transforms:
|
44 |
+
x = t(x)
|
45 |
+
return x
|
46 |
+
|
47 |
+
def __repr__(self):
|
48 |
+
format_string = (
|
49 |
+
[self.__class__.__name__ + "("]
|
50 |
+
+ [f" {t.__repr__()}" for t in self.transforms]
|
51 |
+
+ [")"]
|
52 |
+
)
|
53 |
+
return "\n".join(format_string)
|
54 |
+
|
55 |
+
|
56 |
+
def register_audio_transform(name, cls_type, registry, class_names):
|
57 |
+
def register_audio_transform_cls(cls):
|
58 |
+
if name in registry:
|
59 |
+
raise ValueError(f"Cannot register duplicate transform ({name})")
|
60 |
+
if not issubclass(cls, cls_type):
|
61 |
+
raise ValueError(
|
62 |
+
f"Transform ({name}: {cls.__name__}) must extend "
|
63 |
+
f"{cls_type.__name__}"
|
64 |
+
)
|
65 |
+
if cls.__name__ in class_names:
|
66 |
+
raise ValueError(
|
67 |
+
f"Cannot register audio transform with duplicate "
|
68 |
+
f"class name ({cls.__name__})"
|
69 |
+
)
|
70 |
+
registry[name] = cls
|
71 |
+
class_names.add(cls.__name__)
|
72 |
+
return cls
|
73 |
+
|
74 |
+
return register_audio_transform_cls
|
75 |
+
|
76 |
+
|
77 |
+
def import_transforms(transforms_dir, transform_type):
|
78 |
+
for file in os.listdir(transforms_dir):
|
79 |
+
path = os.path.join(transforms_dir, file)
|
80 |
+
if (
|
81 |
+
not file.startswith("_")
|
82 |
+
and not file.startswith(".")
|
83 |
+
and (file.endswith(".py") or os.path.isdir(path))
|
84 |
+
):
|
85 |
+
name = file[: file.find(".py")] if file.endswith(".py") else file
|
86 |
+
importlib.import_module(
|
87 |
+
f"fairseq.data.audio.{transform_type}_transforms." + name
|
88 |
+
)
|
89 |
+
|
90 |
+
|
91 |
+
# Utility fn for uniform numbers in transforms
|
92 |
+
def rand_uniform(a, b):
|
93 |
+
return np.random.uniform() * (b - a) + a
|
modules/voice_conversion/fairseq/data/audio/audio_utils.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
|
7 |
+
import mmap
|
8 |
+
from pathlib import Path
|
9 |
+
import io
|
10 |
+
from typing import BinaryIO, List, Optional, Tuple, Union
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
|
17 |
+
|
18 |
+
SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
|
19 |
+
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
|
20 |
+
|
21 |
+
|
22 |
+
def convert_waveform(
|
23 |
+
waveform: Union[np.ndarray, torch.Tensor],
|
24 |
+
sample_rate: int,
|
25 |
+
normalize_volume: bool = False,
|
26 |
+
to_mono: bool = False,
|
27 |
+
to_sample_rate: Optional[int] = None,
|
28 |
+
) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
|
29 |
+
"""convert a waveform:
|
30 |
+
- to a target sample rate
|
31 |
+
- from multi-channel to mono channel
|
32 |
+
- volume normalization
|
33 |
+
|
34 |
+
Args:
|
35 |
+
waveform (numpy.ndarray or torch.Tensor): 2D original waveform
|
36 |
+
(channels x length)
|
37 |
+
sample_rate (int): original sample rate
|
38 |
+
normalize_volume (bool): perform volume normalization
|
39 |
+
to_mono (bool): convert to mono channel if having multiple channels
|
40 |
+
to_sample_rate (Optional[int]): target sample rate
|
41 |
+
Returns:
|
42 |
+
waveform (numpy.ndarray): converted 2D waveform (channels x length)
|
43 |
+
sample_rate (float): target sample rate
|
44 |
+
"""
|
45 |
+
try:
|
46 |
+
import torchaudio.sox_effects as ta_sox
|
47 |
+
except ImportError:
|
48 |
+
raise ImportError("Please install torchaudio: pip install torchaudio")
|
49 |
+
|
50 |
+
effects = []
|
51 |
+
if normalize_volume:
|
52 |
+
effects.append(["gain", "-n"])
|
53 |
+
if to_sample_rate is not None and to_sample_rate != sample_rate:
|
54 |
+
effects.append(["rate", f"{to_sample_rate}"])
|
55 |
+
if to_mono and waveform.shape[0] > 1:
|
56 |
+
effects.append(["channels", "1"])
|
57 |
+
if len(effects) > 0:
|
58 |
+
is_np_input = isinstance(waveform, np.ndarray)
|
59 |
+
_waveform = torch.from_numpy(waveform) if is_np_input else waveform
|
60 |
+
converted, converted_sample_rate = ta_sox.apply_effects_tensor(
|
61 |
+
_waveform, sample_rate, effects
|
62 |
+
)
|
63 |
+
if is_np_input:
|
64 |
+
converted = converted.numpy()
|
65 |
+
return converted, converted_sample_rate
|
66 |
+
return waveform, sample_rate
|
67 |
+
|
68 |
+
|
69 |
+
def get_waveform(
|
70 |
+
path_or_fp: Union[str, BinaryIO],
|
71 |
+
normalization: bool = True,
|
72 |
+
mono: bool = True,
|
73 |
+
frames: int = -1,
|
74 |
+
start: int = 0,
|
75 |
+
always_2d: bool = True,
|
76 |
+
output_sample_rate: Optional[int] = None,
|
77 |
+
normalize_volume: bool = False,
|
78 |
+
waveform_transforms: Optional[CompositeAudioWaveformTransform] = None,
|
79 |
+
) -> Tuple[np.ndarray, int]:
|
80 |
+
"""Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
path_or_fp (str or BinaryIO): the path or file-like object
|
84 |
+
normalization (bool): normalize values to [-1, 1] (Default: True)
|
85 |
+
mono (bool): convert multi-channel audio to mono-channel one
|
86 |
+
frames (int): the number of frames to read. (-1 for reading all)
|
87 |
+
start (int): Where to start reading. A negative value counts from the end.
|
88 |
+
always_2d (bool): always return 2D array even for mono-channel audios
|
89 |
+
output_sample_rate (Optional[int]): output sample rate
|
90 |
+
normalize_volume (bool): normalize volume
|
91 |
+
Returns:
|
92 |
+
waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
|
93 |
+
sample_rate (float): sample rate
|
94 |
+
"""
|
95 |
+
if isinstance(path_or_fp, str):
|
96 |
+
ext = Path(path_or_fp).suffix
|
97 |
+
if ext not in SF_AUDIO_FILE_EXTENSIONS:
|
98 |
+
raise ValueError(f"Unsupported audio format: {ext}")
|
99 |
+
|
100 |
+
try:
|
101 |
+
import soundfile as sf
|
102 |
+
except ImportError:
|
103 |
+
raise ImportError("Please install soundfile: pip install soundfile")
|
104 |
+
|
105 |
+
waveform, sample_rate = sf.read(
|
106 |
+
path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
|
107 |
+
)
|
108 |
+
waveform = waveform.T # T x C -> C x T
|
109 |
+
waveform, sample_rate = convert_waveform(
|
110 |
+
waveform,
|
111 |
+
sample_rate,
|
112 |
+
normalize_volume=normalize_volume,
|
113 |
+
to_mono=mono,
|
114 |
+
to_sample_rate=output_sample_rate,
|
115 |
+
)
|
116 |
+
|
117 |
+
if not normalization:
|
118 |
+
waveform *= 2**15 # denormalized to 16-bit signed integers
|
119 |
+
|
120 |
+
if waveform_transforms is not None:
|
121 |
+
waveform, sample_rate = waveform_transforms(waveform, sample_rate)
|
122 |
+
|
123 |
+
if not always_2d:
|
124 |
+
waveform = waveform.squeeze(axis=0)
|
125 |
+
|
126 |
+
return waveform, sample_rate
|
127 |
+
|
128 |
+
|
129 |
+
def get_features_from_npy_or_audio(path, waveform_transforms=None):
|
130 |
+
ext = Path(path).suffix
|
131 |
+
if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
|
132 |
+
raise ValueError(f'Unsupported file format for "{path}"')
|
133 |
+
return (
|
134 |
+
np.load(path)
|
135 |
+
if ext == ".npy"
|
136 |
+
else get_fbank(path, waveform_transforms=waveform_transforms)
|
137 |
+
)
|
138 |
+
|
139 |
+
|
140 |
+
def get_features_or_waveform_from_stored_zip(
|
141 |
+
path,
|
142 |
+
byte_offset,
|
143 |
+
byte_size,
|
144 |
+
need_waveform=False,
|
145 |
+
use_sample_rate=None,
|
146 |
+
waveform_transforms=None,
|
147 |
+
):
|
148 |
+
assert path.endswith(".zip")
|
149 |
+
data = read_from_stored_zip(path, byte_offset, byte_size)
|
150 |
+
f = io.BytesIO(data)
|
151 |
+
if is_npy_data(data):
|
152 |
+
features_or_waveform = np.load(f)
|
153 |
+
elif is_sf_audio_data(data):
|
154 |
+
features_or_waveform = (
|
155 |
+
get_waveform(
|
156 |
+
f,
|
157 |
+
always_2d=False,
|
158 |
+
output_sample_rate=use_sample_rate,
|
159 |
+
waveform_transforms=waveform_transforms,
|
160 |
+
)[0]
|
161 |
+
if need_waveform
|
162 |
+
else get_fbank(f, waveform_transforms=waveform_transforms)
|
163 |
+
)
|
164 |
+
else:
|
165 |
+
raise ValueError(f'Unknown file format for "{path}"')
|
166 |
+
return features_or_waveform
|
167 |
+
|
168 |
+
|
169 |
+
def get_features_or_waveform(
|
170 |
+
path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None
|
171 |
+
):
|
172 |
+
"""Get speech features from .npy file or waveform from .wav/.flac file.
|
173 |
+
The file may be inside an uncompressed ZIP file and is accessed via byte
|
174 |
+
offset and length.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
path (str): File path in the format of "<.npy/.wav/.flac path>" or
|
178 |
+
"<zip path>:<byte offset>:<byte length>".
|
179 |
+
need_waveform (bool): return waveform instead of features.
|
180 |
+
use_sample_rate (int): change sample rate for the input wave file
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
features_or_waveform (numpy.ndarray): speech features or waveform.
|
184 |
+
"""
|
185 |
+
_path, slice_ptr = parse_path(path)
|
186 |
+
if len(slice_ptr) == 0:
|
187 |
+
if need_waveform:
|
188 |
+
return get_waveform(
|
189 |
+
_path,
|
190 |
+
always_2d=False,
|
191 |
+
output_sample_rate=use_sample_rate,
|
192 |
+
waveform_transforms=waveform_transforms,
|
193 |
+
)[0]
|
194 |
+
return get_features_from_npy_or_audio(
|
195 |
+
_path, waveform_transforms=waveform_transforms
|
196 |
+
)
|
197 |
+
elif len(slice_ptr) == 2:
|
198 |
+
features_or_waveform = get_features_or_waveform_from_stored_zip(
|
199 |
+
_path,
|
200 |
+
slice_ptr[0],
|
201 |
+
slice_ptr[1],
|
202 |
+
need_waveform=need_waveform,
|
203 |
+
use_sample_rate=use_sample_rate,
|
204 |
+
waveform_transforms=waveform_transforms,
|
205 |
+
)
|
206 |
+
else:
|
207 |
+
raise ValueError(f"Invalid path: {path}")
|
208 |
+
|
209 |
+
return features_or_waveform
|
210 |
+
|
211 |
+
|
212 |
+
def _get_kaldi_fbank(
|
213 |
+
waveform: np.ndarray, sample_rate: int, n_bins=80
|
214 |
+
) -> Optional[np.ndarray]:
|
215 |
+
"""Get mel-filter bank features via PyKaldi."""
|
216 |
+
try:
|
217 |
+
from kaldi.feat.fbank import Fbank, FbankOptions
|
218 |
+
from kaldi.feat.mel import MelBanksOptions
|
219 |
+
from kaldi.feat.window import FrameExtractionOptions
|
220 |
+
from kaldi.matrix import Vector
|
221 |
+
|
222 |
+
mel_opts = MelBanksOptions()
|
223 |
+
mel_opts.num_bins = n_bins
|
224 |
+
frame_opts = FrameExtractionOptions()
|
225 |
+
frame_opts.samp_freq = sample_rate
|
226 |
+
opts = FbankOptions()
|
227 |
+
opts.mel_opts = mel_opts
|
228 |
+
opts.frame_opts = frame_opts
|
229 |
+
fbank = Fbank(opts=opts)
|
230 |
+
features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
|
231 |
+
return features
|
232 |
+
except ImportError:
|
233 |
+
return None
|
234 |
+
|
235 |
+
|
236 |
+
def _get_torchaudio_fbank(
|
237 |
+
waveform: np.ndarray, sample_rate, n_bins=80
|
238 |
+
) -> Optional[np.ndarray]:
|
239 |
+
"""Get mel-filter bank features via TorchAudio."""
|
240 |
+
try:
|
241 |
+
import torchaudio.compliance.kaldi as ta_kaldi
|
242 |
+
|
243 |
+
waveform = torch.from_numpy(waveform)
|
244 |
+
features = ta_kaldi.fbank(
|
245 |
+
waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
|
246 |
+
)
|
247 |
+
return features.numpy()
|
248 |
+
except ImportError:
|
249 |
+
return None
|
250 |
+
|
251 |
+
|
252 |
+
def get_fbank(
|
253 |
+
path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None
|
254 |
+
) -> np.ndarray:
|
255 |
+
"""Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
|
256 |
+
(faster CPP implementation) to TorchAudio (Python implementation). Note that
|
257 |
+
Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
|
258 |
+
waveform should not be normalized."""
|
259 |
+
waveform, sample_rate = get_waveform(
|
260 |
+
path_or_fp, normalization=False, waveform_transforms=waveform_transforms
|
261 |
+
)
|
262 |
+
|
263 |
+
features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
|
264 |
+
if features is None:
|
265 |
+
features = _get_torchaudio_fbank(waveform, sample_rate, n_bins)
|
266 |
+
if features is None:
|
267 |
+
raise ImportError(
|
268 |
+
"Please install pyKaldi or torchaudio to enable "
|
269 |
+
"online filterbank feature extraction"
|
270 |
+
)
|
271 |
+
|
272 |
+
return features
|
273 |
+
|
274 |
+
|
275 |
+
def is_npy_data(data: bytes) -> bool:
|
276 |
+
return data[0] == 147 and data[1] == 78
|
277 |
+
|
278 |
+
|
279 |
+
def is_sf_audio_data(data: bytes) -> bool:
|
280 |
+
is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70
|
281 |
+
is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97
|
282 |
+
is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103
|
283 |
+
return is_wav or is_flac or is_ogg
|
284 |
+
|
285 |
+
|
286 |
+
def mmap_read(path: str, offset: int, length: int) -> bytes:
|
287 |
+
with open(path, "rb") as f:
|
288 |
+
with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
|
289 |
+
data = mmap_o[offset : offset + length]
|
290 |
+
return data
|
291 |
+
|
292 |
+
|
293 |
+
def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes:
|
294 |
+
return mmap_read(zip_path, offset, length)
|
295 |
+
|
296 |
+
|
297 |
+
def parse_path(path: str) -> Tuple[str, List[int]]:
|
298 |
+
"""Parse data path which is either a path to
|
299 |
+
1. a .npy/.wav/.flac/.ogg file
|
300 |
+
2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
|
301 |
+
|
302 |
+
Args:
|
303 |
+
path (str): the data path to parse
|
304 |
+
|
305 |
+
Returns:
|
306 |
+
file_path (str): the file path
|
307 |
+
slice_ptr (list of int): empty in case 1;
|
308 |
+
byte offset and length for the slice in case 2
|
309 |
+
"""
|
310 |
+
|
311 |
+
if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
|
312 |
+
_path, slice_ptr = path, []
|
313 |
+
else:
|
314 |
+
_path, *slice_ptr = path.split(":")
|
315 |
+
if not Path(_path).is_file():
|
316 |
+
raise FileNotFoundError(f"File not found: {_path}")
|
317 |
+
assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
|
318 |
+
slice_ptr = [int(i) for i in slice_ptr]
|
319 |
+
return _path, slice_ptr
|
320 |
+
|
321 |
+
|
322 |
+
def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
|
323 |
+
padding = n_fft - win_length
|
324 |
+
assert padding >= 0
|
325 |
+
return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))
|
326 |
+
|
327 |
+
|
328 |
+
def get_fourier_basis(n_fft: int) -> torch.Tensor:
|
329 |
+
basis = np.fft.fft(np.eye(n_fft))
|
330 |
+
basis = np.vstack(
|
331 |
+
[np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
|
332 |
+
)
|
333 |
+
return torch.from_numpy(basis).float()
|
334 |
+
|
335 |
+
|
336 |
+
def get_mel_filters(
|
337 |
+
sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
|
338 |
+
) -> torch.Tensor:
|
339 |
+
try:
|
340 |
+
import librosa
|
341 |
+
except ImportError:
|
342 |
+
raise ImportError("Please install librosa: pip install librosa")
|
343 |
+
basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max)
|
344 |
+
return torch.from_numpy(basis).float()
|
345 |
+
|
346 |
+
|
347 |
+
class TTSSpectrogram(torch.nn.Module):
|
348 |
+
def __init__(
|
349 |
+
self,
|
350 |
+
n_fft: int,
|
351 |
+
win_length: int,
|
352 |
+
hop_length: int,
|
353 |
+
window_fn: callable = torch.hann_window,
|
354 |
+
return_phase: bool = False,
|
355 |
+
) -> None:
|
356 |
+
super(TTSSpectrogram, self).__init__()
|
357 |
+
self.n_fft = n_fft
|
358 |
+
self.hop_length = hop_length
|
359 |
+
self.return_phase = return_phase
|
360 |
+
|
361 |
+
basis = get_fourier_basis(n_fft).unsqueeze(1)
|
362 |
+
basis *= get_window(window_fn, n_fft, win_length)
|
363 |
+
self.register_buffer("basis", basis)
|
364 |
+
|
365 |
+
def forward(
|
366 |
+
self, waveform: torch.Tensor
|
367 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
368 |
+
padding = (self.n_fft // 2, self.n_fft // 2)
|
369 |
+
x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
|
370 |
+
x = F.conv1d(x, self.basis, stride=self.hop_length)
|
371 |
+
real_part = x[:, : self.n_fft // 2 + 1, :]
|
372 |
+
imag_part = x[:, self.n_fft // 2 + 1 :, :]
|
373 |
+
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
374 |
+
if self.return_phase:
|
375 |
+
phase = torch.atan2(imag_part, real_part)
|
376 |
+
return magnitude, phase
|
377 |
+
return magnitude
|
378 |
+
|
379 |
+
|
380 |
+
class TTSMelScale(torch.nn.Module):
|
381 |
+
def __init__(
|
382 |
+
self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
|
383 |
+
) -> None:
|
384 |
+
super(TTSMelScale, self).__init__()
|
385 |
+
basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
|
386 |
+
self.register_buffer("basis", basis)
|
387 |
+
|
388 |
+
def forward(self, specgram: torch.Tensor) -> torch.Tensor:
|
389 |
+
return torch.matmul(self.basis, specgram)
|
modules/voice_conversion/fairseq/data/audio/data_cfg.py
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from argparse import Namespace
|
8 |
+
from copy import deepcopy
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Dict, Optional
|
11 |
+
|
12 |
+
from fairseq.data import Dictionary
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
def get_config_from_yaml(yaml_path: Path):
|
18 |
+
try:
|
19 |
+
import yaml
|
20 |
+
except ImportError:
|
21 |
+
print("Please install PyYAML: pip install PyYAML")
|
22 |
+
config = {}
|
23 |
+
if yaml_path.is_file():
|
24 |
+
try:
|
25 |
+
with open(yaml_path) as f:
|
26 |
+
config = yaml.load(f, Loader=yaml.FullLoader)
|
27 |
+
except Exception as e:
|
28 |
+
raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
|
29 |
+
else:
|
30 |
+
raise FileNotFoundError(f"{yaml_path.as_posix()} not found")
|
31 |
+
|
32 |
+
return config
|
33 |
+
|
34 |
+
|
35 |
+
class S2TDataConfig(object):
|
36 |
+
"""Wrapper class for data config YAML"""
|
37 |
+
|
38 |
+
def __init__(self, yaml_path: Path):
|
39 |
+
self.config = get_config_from_yaml(yaml_path)
|
40 |
+
self.root = yaml_path.parent
|
41 |
+
|
42 |
+
def _auto_convert_to_abs_path(self, x):
|
43 |
+
if isinstance(x, str):
|
44 |
+
if not Path(x).exists() and (self.root / x).exists():
|
45 |
+
return (self.root / x).as_posix()
|
46 |
+
elif isinstance(x, dict):
|
47 |
+
return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
|
48 |
+
return x
|
49 |
+
|
50 |
+
@property
|
51 |
+
def vocab_filename(self):
|
52 |
+
"""fairseq vocabulary file under data root"""
|
53 |
+
return self.config.get("vocab_filename", "dict.txt")
|
54 |
+
|
55 |
+
@property
|
56 |
+
def speaker_set_filename(self):
|
57 |
+
"""speaker set file under data root"""
|
58 |
+
return self.config.get("speaker_set_filename", None)
|
59 |
+
|
60 |
+
@property
|
61 |
+
def shuffle(self) -> bool:
|
62 |
+
"""Shuffle dataset samples before batching"""
|
63 |
+
return self.config.get("shuffle", False)
|
64 |
+
|
65 |
+
@property
|
66 |
+
def pre_tokenizer(self) -> Dict:
|
67 |
+
"""Pre-tokenizer to apply before subword tokenization. Returning
|
68 |
+
a dictionary with `tokenizer` providing the tokenizer name and
|
69 |
+
the other items providing the tokenizer-specific arguments.
|
70 |
+
Tokenizers are defined in `fairseq.data.encoders.*`"""
|
71 |
+
tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
|
72 |
+
return self._auto_convert_to_abs_path(tokenizer)
|
73 |
+
|
74 |
+
@property
|
75 |
+
def bpe_tokenizer(self) -> Dict:
|
76 |
+
"""Subword tokenizer to apply after pre-tokenization. Returning
|
77 |
+
a dictionary with `bpe` providing the tokenizer name and
|
78 |
+
the other items providing the tokenizer-specific arguments.
|
79 |
+
Tokenizers are defined in `fairseq.data.encoders.*`"""
|
80 |
+
tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
|
81 |
+
return self._auto_convert_to_abs_path(tokenizer)
|
82 |
+
|
83 |
+
@property
|
84 |
+
def prepend_tgt_lang_tag(self) -> bool:
|
85 |
+
"""Prepend target lang ID token as the target BOS (e.g. for to-many
|
86 |
+
multilingual setting). During inference, this requires `--prefix-size 1`
|
87 |
+
to force BOS to be lang ID token."""
|
88 |
+
return self.config.get("prepend_tgt_lang_tag", False)
|
89 |
+
|
90 |
+
@property
|
91 |
+
def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
|
92 |
+
"""Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
|
93 |
+
return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
|
94 |
+
|
95 |
+
@property
|
96 |
+
def input_feat_per_channel(self):
|
97 |
+
"""The dimension of input features (per audio channel)"""
|
98 |
+
return self.config.get("input_feat_per_channel", 80)
|
99 |
+
|
100 |
+
@property
|
101 |
+
def input_channels(self):
|
102 |
+
"""The number of channels in the input audio"""
|
103 |
+
return self.config.get("input_channels", 1)
|
104 |
+
|
105 |
+
@property
|
106 |
+
def sample_rate(self):
|
107 |
+
return self.config.get("sample_rate", 16_000)
|
108 |
+
|
109 |
+
@property
|
110 |
+
def sampling_alpha(self):
|
111 |
+
"""Hyper-parameter alpha = 1/T for temperature-based resampling.
|
112 |
+
(alpha = 1 for no resampling)"""
|
113 |
+
return self.config.get("sampling_alpha", 1.0)
|
114 |
+
|
115 |
+
@property
|
116 |
+
def use_audio_input(self):
|
117 |
+
"""Needed by the dataset loader to see if the model requires
|
118 |
+
raw audio as inputs."""
|
119 |
+
return self.config.get("use_audio_input", False)
|
120 |
+
|
121 |
+
def standardize_audio(self) -> bool:
|
122 |
+
return self.use_audio_input and self.config.get("standardize_audio", False)
|
123 |
+
|
124 |
+
@property
|
125 |
+
def use_sample_rate(self):
|
126 |
+
"""Needed by the dataset loader to see if the model requires
|
127 |
+
raw audio with specific sample rate as inputs."""
|
128 |
+
return self.config.get("use_sample_rate", 16000)
|
129 |
+
|
130 |
+
@property
|
131 |
+
def audio_root(self):
|
132 |
+
"""Audio paths in the manifest TSV can be relative and this provides
|
133 |
+
the root path. Set this to empty string when using absolute paths."""
|
134 |
+
return self.config.get("audio_root", "")
|
135 |
+
|
136 |
+
def get_transforms(self, transform_type, split, is_train):
|
137 |
+
"""Split-specific feature transforms. Allowing train set
|
138 |
+
wildcard `_train`, evaluation set wildcard `_eval` and general
|
139 |
+
wildcard `*` for matching."""
|
140 |
+
from copy import deepcopy
|
141 |
+
|
142 |
+
cfg = deepcopy(self.config)
|
143 |
+
_cur = cfg.get(f"{transform_type}transforms", {})
|
144 |
+
cur = _cur.get(split)
|
145 |
+
cur = _cur.get("_train") if cur is None and is_train else cur
|
146 |
+
cur = _cur.get("_eval") if cur is None and not is_train else cur
|
147 |
+
cur = _cur.get("*") if cur is None else cur
|
148 |
+
return cur
|
149 |
+
|
150 |
+
def get_feature_transforms(self, split, is_train):
|
151 |
+
cfg = deepcopy(self.config)
|
152 |
+
# TODO: deprecate transforms
|
153 |
+
cur = self.get_transforms("", split, is_train)
|
154 |
+
if cur is not None:
|
155 |
+
logger.warning(
|
156 |
+
"Auto converting transforms into feature_transforms, "
|
157 |
+
"but transforms will be deprecated in the future. Please "
|
158 |
+
"update this in the config."
|
159 |
+
)
|
160 |
+
ft_transforms = self.get_transforms("feature_", split, is_train)
|
161 |
+
if ft_transforms:
|
162 |
+
cur.extend(ft_transforms)
|
163 |
+
else:
|
164 |
+
cur = self.get_transforms("feature_", split, is_train)
|
165 |
+
cfg["feature_transforms"] = cur
|
166 |
+
return cfg
|
167 |
+
|
168 |
+
def get_waveform_transforms(self, split, is_train):
|
169 |
+
cfg = deepcopy(self.config)
|
170 |
+
cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
|
171 |
+
return cfg
|
172 |
+
|
173 |
+
def get_dataset_transforms(self, split, is_train):
|
174 |
+
cfg = deepcopy(self.config)
|
175 |
+
cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
|
176 |
+
return cfg
|
177 |
+
|
178 |
+
@property
|
179 |
+
def global_cmvn_stats_npz(self) -> Optional[str]:
|
180 |
+
path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
|
181 |
+
return self._auto_convert_to_abs_path(path)
|
182 |
+
|
183 |
+
@property
|
184 |
+
def vocoder(self) -> Dict[str, str]:
|
185 |
+
vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
|
186 |
+
return self._auto_convert_to_abs_path(vocoder)
|
187 |
+
|
188 |
+
@property
|
189 |
+
def hub(self) -> Dict[str, str]:
|
190 |
+
return self.config.get("hub", {})
|
191 |
+
|
192 |
+
|
193 |
+
class S2SDataConfig(S2TDataConfig):
|
194 |
+
"""Wrapper class for data config YAML"""
|
195 |
+
|
196 |
+
@property
|
197 |
+
def vocab_filename(self):
|
198 |
+
"""fairseq vocabulary file under data root"""
|
199 |
+
return self.config.get("vocab_filename", None)
|
200 |
+
|
201 |
+
@property
|
202 |
+
def pre_tokenizer(self) -> Dict:
|
203 |
+
return None
|
204 |
+
|
205 |
+
@property
|
206 |
+
def bpe_tokenizer(self) -> Dict:
|
207 |
+
return None
|
208 |
+
|
209 |
+
@property
|
210 |
+
def input_transformed_channels(self):
|
211 |
+
"""The number of channels in the audio after feature transforms"""
|
212 |
+
# TODO: move this into individual transforms
|
213 |
+
# TODO: deprecate transforms
|
214 |
+
_cur = self.config.get("transforms", {})
|
215 |
+
ft_transforms = self.config.get("feature_transforms", {})
|
216 |
+
if _cur and ft_transforms:
|
217 |
+
_cur.update(ft_transforms)
|
218 |
+
else:
|
219 |
+
_cur = self.config.get("feature_transforms", {})
|
220 |
+
cur = _cur.get("_train", [])
|
221 |
+
|
222 |
+
_channels = self.input_channels
|
223 |
+
if "delta_deltas" in cur:
|
224 |
+
_channels *= 3
|
225 |
+
|
226 |
+
return _channels
|
227 |
+
|
228 |
+
@property
|
229 |
+
def output_sample_rate(self):
|
230 |
+
"""The audio sample rate of output target speech"""
|
231 |
+
return self.config.get("output_sample_rate", 22050)
|
232 |
+
|
233 |
+
@property
|
234 |
+
def target_speaker_embed(self):
|
235 |
+
"""Target speaker embedding file (one line per target audio sample)"""
|
236 |
+
return self.config.get("target_speaker_embed", None)
|
237 |
+
|
238 |
+
@property
|
239 |
+
def prepend_tgt_lang_tag_as_bos(self) -> bool:
|
240 |
+
"""Prepend target lang ID token as the target BOS."""
|
241 |
+
return self.config.get("prepend_tgt_lang_tag_as_bos", False)
|
242 |
+
|
243 |
+
|
244 |
+
class MultitaskConfig(object):
|
245 |
+
"""Wrapper class for data config YAML"""
|
246 |
+
|
247 |
+
def __init__(self, yaml_path: Path):
|
248 |
+
config = get_config_from_yaml(yaml_path)
|
249 |
+
self.config = {}
|
250 |
+
for k, v in config.items():
|
251 |
+
self.config[k] = SingleTaskConfig(k, v)
|
252 |
+
|
253 |
+
def get_all_tasks(self):
|
254 |
+
return self.config
|
255 |
+
|
256 |
+
def get_single_task(self, name):
|
257 |
+
assert name in self.config, f"multitask '{name}' does not exist!"
|
258 |
+
return self.config[name]
|
259 |
+
|
260 |
+
@property
|
261 |
+
def first_pass_decoder_task_index(self):
|
262 |
+
"""Return the task index of the first-pass text decoder.
|
263 |
+
If there are multiple 'is_first_pass_decoder: True' in the config file,
|
264 |
+
the last task is used for the first-pass decoder.
|
265 |
+
If there is no 'is_first_pass_decoder: True' in the config file,
|
266 |
+
the last task whose task_name includes 'target' and decoder_type is not ctc.
|
267 |
+
"""
|
268 |
+
idx = -1
|
269 |
+
for i, (k, v) in enumerate(self.config.items()):
|
270 |
+
if v.is_first_pass_decoder:
|
271 |
+
idx = i
|
272 |
+
if idx < 0:
|
273 |
+
for i, (k, v) in enumerate(self.config.items()):
|
274 |
+
if k.startswith("target") and v.decoder_type == "transformer":
|
275 |
+
idx = i
|
276 |
+
return idx
|
277 |
+
|
278 |
+
|
279 |
+
class SingleTaskConfig(object):
|
280 |
+
def __init__(self, name, config):
|
281 |
+
self.task_name = name
|
282 |
+
self.config = config
|
283 |
+
dict_path = config.get("dict", "")
|
284 |
+
self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None
|
285 |
+
|
286 |
+
@property
|
287 |
+
def data(self):
|
288 |
+
return self.config.get("data", "")
|
289 |
+
|
290 |
+
@property
|
291 |
+
def decoder_type(self):
|
292 |
+
return self.config.get("decoder_type", "transformer")
|
293 |
+
|
294 |
+
@property
|
295 |
+
def decoder_args(self):
|
296 |
+
"""Decoder arch related args"""
|
297 |
+
args = self.config.get("decoder_args", {})
|
298 |
+
return Namespace(**args)
|
299 |
+
|
300 |
+
@property
|
301 |
+
def criterion_cfg(self):
|
302 |
+
"""cfg for the multitask criterion"""
|
303 |
+
if self.decoder_type == "ctc":
|
304 |
+
from fairseq.criterions.ctc import CtcCriterionConfig
|
305 |
+
|
306 |
+
cfg = CtcCriterionConfig
|
307 |
+
cfg.zero_infinity = self.config.get("zero_infinity", True)
|
308 |
+
else:
|
309 |
+
from fairseq.criterions.label_smoothed_cross_entropy import (
|
310 |
+
LabelSmoothedCrossEntropyCriterionConfig,
|
311 |
+
)
|
312 |
+
|
313 |
+
cfg = LabelSmoothedCrossEntropyCriterionConfig
|
314 |
+
cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
|
315 |
+
return cfg
|
316 |
+
|
317 |
+
@property
|
318 |
+
def input_from(self):
|
319 |
+
"""Condition on encoder/decoder of the main model"""
|
320 |
+
return "decoder" if "decoder_layer" in self.config else "encoder"
|
321 |
+
|
322 |
+
@property
|
323 |
+
def input_layer(self):
|
324 |
+
if self.input_from == "decoder":
|
325 |
+
return self.config["decoder_layer"] - 1
|
326 |
+
else:
|
327 |
+
# default using the output from the last encoder layer (-1)
|
328 |
+
return self.config.get("encoder_layer", 0) - 1
|
329 |
+
|
330 |
+
@property
|
331 |
+
def loss_weight_schedule(self):
|
332 |
+
return (
|
333 |
+
"decay"
|
334 |
+
if "loss_weight_max" in self.config
|
335 |
+
and "loss_weight_decay_steps" in self.config
|
336 |
+
else "fixed"
|
337 |
+
)
|
338 |
+
|
339 |
+
def get_loss_weight(self, num_updates):
|
340 |
+
if self.loss_weight_schedule == "fixed":
|
341 |
+
weight = self.config.get("loss_weight", 1.0)
|
342 |
+
else: # "decay"
|
343 |
+
assert (
|
344 |
+
self.config.get("loss_weight_decay_steps", 0) > 0
|
345 |
+
), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
|
346 |
+
loss_weight_min = self.config.get("loss_weight_min", 0.0001)
|
347 |
+
loss_weight_decay_stepsize = (
|
348 |
+
self.config["loss_weight_max"] - loss_weight_min
|
349 |
+
) / self.config["loss_weight_decay_steps"]
|
350 |
+
weight = max(
|
351 |
+
self.config["loss_weight_max"]
|
352 |
+
- loss_weight_decay_stepsize * num_updates,
|
353 |
+
loss_weight_min,
|
354 |
+
)
|
355 |
+
return weight
|
356 |
+
|
357 |
+
@property
|
358 |
+
def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
|
359 |
+
"""Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
|
360 |
+
return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
|
361 |
+
|
362 |
+
@property
|
363 |
+
def eos_token(self):
|
364 |
+
"""EOS token during generation"""
|
365 |
+
return self.config.get("eos_token", "<eos>")
|
366 |
+
|
367 |
+
@property
|
368 |
+
def rdrop_alpha(self):
|
369 |
+
return self.config.get("rdrop_alpha", 0.0)
|
370 |
+
|
371 |
+
@property
|
372 |
+
def is_first_pass_decoder(self):
|
373 |
+
flag = self.config.get("is_first_pass_decoder", False)
|
374 |
+
if flag:
|
375 |
+
if self.decoder_type == "ctc":
|
376 |
+
raise ValueError(
|
377 |
+
"First-pass decoder in the multi-decoder model must not be CTC."
|
378 |
+
)
|
379 |
+
if "target" not in self.task_name:
|
380 |
+
raise Warning(
|
381 |
+
'The name of the first-pass decoder does not include "target".'
|
382 |
+
)
|
383 |
+
return flag
|
384 |
+
|
385 |
+
@property
|
386 |
+
def get_lang_tag_mapping(self):
|
387 |
+
return self.config.get("lang_tag_mapping", {})
|
modules/voice_conversion/fairseq/data/audio/dataset_transforms/__init__.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from fairseq.data.audio import (
|
3 |
+
AudioTransform,
|
4 |
+
CompositeAudioTransform,
|
5 |
+
import_transforms,
|
6 |
+
register_audio_transform,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
class AudioDatasetTransform(AudioTransform):
|
11 |
+
pass
|
12 |
+
|
13 |
+
|
14 |
+
AUDIO_DATASET_TRANSFORM_REGISTRY = {}
|
15 |
+
AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()
|
16 |
+
|
17 |
+
|
18 |
+
def get_audio_dataset_transform(name):
|
19 |
+
return AUDIO_DATASET_TRANSFORM_REGISTRY[name]
|
20 |
+
|
21 |
+
|
22 |
+
def register_audio_dataset_transform(name):
|
23 |
+
return register_audio_transform(
|
24 |
+
name,
|
25 |
+
AudioDatasetTransform,
|
26 |
+
AUDIO_DATASET_TRANSFORM_REGISTRY,
|
27 |
+
AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
import_transforms(os.path.dirname(__file__), "dataset")
|
32 |
+
|
33 |
+
|
34 |
+
class CompositeAudioDatasetTransform(CompositeAudioTransform):
|
35 |
+
@classmethod
|
36 |
+
def from_config_dict(cls, config=None):
|
37 |
+
return super()._from_config_dict(
|
38 |
+
cls,
|
39 |
+
"dataset",
|
40 |
+
get_audio_dataset_transform,
|
41 |
+
CompositeAudioDatasetTransform,
|
42 |
+
config,
|
43 |
+
return_empty=True,
|
44 |
+
)
|
45 |
+
|
46 |
+
def get_transform(self, cls):
|
47 |
+
for t in self.transforms:
|
48 |
+
if isinstance(t, cls):
|
49 |
+
return t
|
50 |
+
return None
|
51 |
+
|
52 |
+
def has_transform(self, cls):
|
53 |
+
return self.get_transform(cls) is not None
|
modules/voice_conversion/fairseq/data/audio/dataset_transforms/concataugment.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from fairseq.data.audio.dataset_transforms import (
|
5 |
+
AudioDatasetTransform,
|
6 |
+
register_audio_dataset_transform,
|
7 |
+
)
|
8 |
+
|
9 |
+
_DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}
|
10 |
+
|
11 |
+
|
12 |
+
@register_audio_dataset_transform("concataugment")
|
13 |
+
class ConcatAugment(AudioDatasetTransform):
|
14 |
+
@classmethod
|
15 |
+
def from_config_dict(cls, config=None):
|
16 |
+
_config = {} if config is None else config
|
17 |
+
return ConcatAugment(
|
18 |
+
_config.get("rate", _DEFAULTS["rate"]),
|
19 |
+
_config.get("max_tokens", _DEFAULTS["max_tokens"]),
|
20 |
+
_config.get("attempts", _DEFAULTS["attempts"]),
|
21 |
+
)
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
rate=_DEFAULTS["rate"],
|
26 |
+
max_tokens=_DEFAULTS["max_tokens"],
|
27 |
+
attempts=_DEFAULTS["attempts"],
|
28 |
+
):
|
29 |
+
self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts
|
30 |
+
|
31 |
+
def __repr__(self):
|
32 |
+
return (
|
33 |
+
self.__class__.__name__
|
34 |
+
+ "("
|
35 |
+
+ ", ".join(
|
36 |
+
[
|
37 |
+
f"rate={self.rate}",
|
38 |
+
f"max_tokens={self.max_tokens}",
|
39 |
+
f"attempts={self.attempts}",
|
40 |
+
]
|
41 |
+
)
|
42 |
+
+ ")"
|
43 |
+
)
|
44 |
+
|
45 |
+
def find_indices(self, index: int, n_frames: List[int], n_samples: int):
|
46 |
+
# skip conditions: application rate, max_tokens limit exceeded
|
47 |
+
if np.random.random() > self.rate:
|
48 |
+
return [index]
|
49 |
+
if self.max_tokens and n_frames[index] > self.max_tokens:
|
50 |
+
return [index]
|
51 |
+
|
52 |
+
# pick second sample to concatenate
|
53 |
+
for _ in range(self.attempts):
|
54 |
+
index2 = np.random.randint(0, n_samples)
|
55 |
+
if index2 != index and (
|
56 |
+
not self.max_tokens
|
57 |
+
or n_frames[index] + n_frames[index2] < self.max_tokens
|
58 |
+
):
|
59 |
+
return [index, index2]
|
60 |
+
|
61 |
+
return [index]
|
modules/voice_conversion/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from fairseq.data.audio import rand_uniform
|
5 |
+
from fairseq.data.audio.dataset_transforms import (
|
6 |
+
AudioDatasetTransform,
|
7 |
+
register_audio_dataset_transform,
|
8 |
+
)
|
9 |
+
from fairseq.data.audio.waveform_transforms.noiseaugment import (
|
10 |
+
NoiseAugmentTransform,
|
11 |
+
)
|
12 |
+
|
13 |
+
_DEFAULTS = {
|
14 |
+
"rate": 0.25,
|
15 |
+
"mixing_noise_rate": 0.1,
|
16 |
+
"noise_path": "",
|
17 |
+
"noise_snr_min": -5,
|
18 |
+
"noise_snr_max": 5,
|
19 |
+
"utterance_snr_min": -5,
|
20 |
+
"utterance_snr_max": 5,
|
21 |
+
}
|
22 |
+
|
23 |
+
|
24 |
+
@register_audio_dataset_transform("noisyoverlapaugment")
|
25 |
+
class NoisyOverlapAugment(AudioDatasetTransform):
|
26 |
+
@classmethod
|
27 |
+
def from_config_dict(cls, config=None):
|
28 |
+
_config = {} if config is None else config
|
29 |
+
return NoisyOverlapAugment(
|
30 |
+
_config.get("rate", _DEFAULTS["rate"]),
|
31 |
+
_config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]),
|
32 |
+
_config.get("noise_path", _DEFAULTS["noise_path"]),
|
33 |
+
_config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]),
|
34 |
+
_config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]),
|
35 |
+
_config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]),
|
36 |
+
_config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]),
|
37 |
+
)
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
rate=_DEFAULTS["rate"],
|
42 |
+
mixing_noise_rate=_DEFAULTS["mixing_noise_rate"],
|
43 |
+
noise_path=_DEFAULTS["noise_path"],
|
44 |
+
noise_snr_min=_DEFAULTS["noise_snr_min"],
|
45 |
+
noise_snr_max=_DEFAULTS["noise_snr_max"],
|
46 |
+
utterance_snr_min=_DEFAULTS["utterance_snr_min"],
|
47 |
+
utterance_snr_max=_DEFAULTS["utterance_snr_max"],
|
48 |
+
):
|
49 |
+
self.rate = rate
|
50 |
+
self.mixing_noise_rate = mixing_noise_rate
|
51 |
+
self.noise_shaper = NoiseAugmentTransform(noise_path)
|
52 |
+
self.noise_snr_min = noise_snr_min
|
53 |
+
self.noise_snr_max = noise_snr_max
|
54 |
+
self.utterance_snr_min = utterance_snr_min
|
55 |
+
self.utterance_snr_max = utterance_snr_max
|
56 |
+
|
57 |
+
def __repr__(self):
|
58 |
+
return (
|
59 |
+
self.__class__.__name__
|
60 |
+
+ "("
|
61 |
+
+ ", ".join(
|
62 |
+
[
|
63 |
+
f"rate={self.rate}",
|
64 |
+
f"mixing_noise_rate={self.mixing_noise_rate}",
|
65 |
+
f"noise_snr_min={self.noise_snr_min}",
|
66 |
+
f"noise_snr_max={self.noise_snr_max}",
|
67 |
+
f"utterance_snr_min={self.utterance_snr_min}",
|
68 |
+
f"utterance_snr_max={self.utterance_snr_max}",
|
69 |
+
]
|
70 |
+
)
|
71 |
+
+ ")"
|
72 |
+
)
|
73 |
+
|
74 |
+
def __call__(self, sources):
|
75 |
+
for i, source in enumerate(sources):
|
76 |
+
if np.random.random() > self.rate:
|
77 |
+
continue
|
78 |
+
|
79 |
+
pri = source.numpy()
|
80 |
+
|
81 |
+
if np.random.random() > self.mixing_noise_rate:
|
82 |
+
sec = sources[np.random.randint(0, len(sources))].numpy()
|
83 |
+
snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max)
|
84 |
+
else:
|
85 |
+
sec = self.noise_shaper.pick_sample(source.shape)
|
86 |
+
snr = rand_uniform(self.noise_snr_min, self.noise_snr_max)
|
87 |
+
|
88 |
+
L1 = pri.shape[-1]
|
89 |
+
L2 = sec.shape[-1]
|
90 |
+
l = np.random.randint(0, min(round(L1 / 2), L2)) # mix len
|
91 |
+
s_source = np.random.randint(0, L1 - l)
|
92 |
+
s_sec = np.random.randint(0, L2 - l)
|
93 |
+
|
94 |
+
get_power = lambda x: np.mean(x**2)
|
95 |
+
if get_power(sec) == 0:
|
96 |
+
continue
|
97 |
+
|
98 |
+
scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec)))
|
99 |
+
|
100 |
+
pri[s_source : s_source + l] = np.add(
|
101 |
+
pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
|
102 |
+
)
|
103 |
+
sources[i] = torch.from_numpy(pri).float()
|
104 |
+
|
105 |
+
return sources
|
modules/voice_conversion/fairseq/data/audio/feature_transforms/__init__.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from fairseq.data.audio import (
|
3 |
+
AudioTransform,
|
4 |
+
CompositeAudioTransform,
|
5 |
+
import_transforms,
|
6 |
+
register_audio_transform,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
class AudioFeatureTransform(AudioTransform):
|
11 |
+
pass
|
12 |
+
|
13 |
+
|
14 |
+
AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
|
15 |
+
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
|
16 |
+
|
17 |
+
|
18 |
+
def get_audio_feature_transform(name):
|
19 |
+
return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
|
20 |
+
|
21 |
+
|
22 |
+
def register_audio_feature_transform(name):
|
23 |
+
return register_audio_transform(
|
24 |
+
name,
|
25 |
+
AudioFeatureTransform,
|
26 |
+
AUDIO_FEATURE_TRANSFORM_REGISTRY,
|
27 |
+
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES,
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
import_transforms(os.path.dirname(__file__), "feature")
|
32 |
+
|
33 |
+
|
34 |
+
class CompositeAudioFeatureTransform(CompositeAudioTransform):
|
35 |
+
@classmethod
|
36 |
+
def from_config_dict(cls, config=None):
|
37 |
+
return super()._from_config_dict(
|
38 |
+
cls,
|
39 |
+
"feature",
|
40 |
+
get_audio_feature_transform,
|
41 |
+
CompositeAudioFeatureTransform,
|
42 |
+
config,
|
43 |
+
)
|
modules/voice_conversion/fairseq/data/audio/feature_transforms/delta_deltas.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from fairseq.data.audio.feature_transforms import (
|
4 |
+
AudioFeatureTransform,
|
5 |
+
register_audio_feature_transform,
|
6 |
+
)
|
7 |
+
|
8 |
+
|
9 |
+
@register_audio_feature_transform("delta_deltas")
|
10 |
+
class DeltaDeltas(AudioFeatureTransform):
|
11 |
+
"""Expand delta-deltas features from spectrum."""
|
12 |
+
|
13 |
+
@classmethod
|
14 |
+
def from_config_dict(cls, config=None):
|
15 |
+
_config = {} if config is None else config
|
16 |
+
return DeltaDeltas(_config.get("win_length", 5))
|
17 |
+
|
18 |
+
def __init__(self, win_length=5):
|
19 |
+
self.win_length = win_length
|
20 |
+
|
21 |
+
def __repr__(self):
|
22 |
+
return self.__class__.__name__
|
23 |
+
|
24 |
+
def __call__(self, spectrogram):
|
25 |
+
from torchaudio.functional import compute_deltas
|
26 |
+
|
27 |
+
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
|
28 |
+
# spectrogram is T x F, while compute_deltas takes (…, F, T)
|
29 |
+
spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
|
30 |
+
delta = compute_deltas(spectrogram)
|
31 |
+
delta_delta = compute_deltas(delta)
|
32 |
+
|
33 |
+
out_feat = np.concatenate(
|
34 |
+
[spectrogram, delta.numpy(), delta_delta.numpy()], axis=0
|
35 |
+
)
|
36 |
+
out_feat = np.transpose(out_feat)
|
37 |
+
return out_feat
|
modules/voice_conversion/fairseq/data/audio/feature_transforms/global_cmvn.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from fairseq.data.audio.feature_transforms import (
|
3 |
+
AudioFeatureTransform,
|
4 |
+
register_audio_feature_transform,
|
5 |
+
)
|
6 |
+
|
7 |
+
|
8 |
+
@register_audio_feature_transform("global_cmvn")
|
9 |
+
class GlobalCMVN(AudioFeatureTransform):
|
10 |
+
"""Global CMVN (cepstral mean and variance normalization). The global mean
|
11 |
+
and variance need to be pre-computed and stored in NumPy format (.npz)."""
|
12 |
+
|
13 |
+
@classmethod
|
14 |
+
def from_config_dict(cls, config=None):
|
15 |
+
_config = {} if config is None else config
|
16 |
+
return GlobalCMVN(_config.get("stats_npz_path"))
|
17 |
+
|
18 |
+
def __init__(self, stats_npz_path):
|
19 |
+
self.stats_npz_path = stats_npz_path
|
20 |
+
stats = np.load(stats_npz_path)
|
21 |
+
self.mean, self.std = stats["mean"], stats["std"]
|
22 |
+
|
23 |
+
def __repr__(self):
|
24 |
+
return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'
|
25 |
+
|
26 |
+
def __call__(self, x):
|
27 |
+
x = np.subtract(x, self.mean)
|
28 |
+
x = np.divide(x, self.std)
|
29 |
+
return x
|
modules/voice_conversion/fairseq/data/audio/feature_transforms/specaugment.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numbers
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from fairseq.data.audio.feature_transforms import (
|
7 |
+
AudioFeatureTransform,
|
8 |
+
register_audio_feature_transform,
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
@register_audio_feature_transform("specaugment")
|
13 |
+
class SpecAugmentTransform(AudioFeatureTransform):
|
14 |
+
"""SpecAugment (https://arxiv.org/abs/1904.08779)"""
|
15 |
+
|
16 |
+
@classmethod
|
17 |
+
def from_config_dict(cls, config=None):
|
18 |
+
_config = {} if config is None else config
|
19 |
+
return SpecAugmentTransform(
|
20 |
+
_config.get("time_warp_W", 0),
|
21 |
+
_config.get("freq_mask_N", 0),
|
22 |
+
_config.get("freq_mask_F", 0),
|
23 |
+
_config.get("time_mask_N", 0),
|
24 |
+
_config.get("time_mask_T", 0),
|
25 |
+
_config.get("time_mask_p", 0.0),
|
26 |
+
_config.get("mask_value", None),
|
27 |
+
)
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
time_warp_w: int = 0,
|
32 |
+
freq_mask_n: int = 0,
|
33 |
+
freq_mask_f: int = 0,
|
34 |
+
time_mask_n: int = 0,
|
35 |
+
time_mask_t: int = 0,
|
36 |
+
time_mask_p: float = 0.0,
|
37 |
+
mask_value: Optional[float] = 0.0,
|
38 |
+
):
|
39 |
+
# Sanity checks
|
40 |
+
assert mask_value is None or isinstance(
|
41 |
+
mask_value, numbers.Number
|
42 |
+
), f"mask_value (type: {type(mask_value)}) must be None or a number"
|
43 |
+
if freq_mask_n > 0:
|
44 |
+
assert freq_mask_f > 0, (
|
45 |
+
f"freq_mask_F ({freq_mask_f}) "
|
46 |
+
f"must be larger than 0 when doing freq masking."
|
47 |
+
)
|
48 |
+
if time_mask_n > 0:
|
49 |
+
assert time_mask_t > 0, (
|
50 |
+
f"time_mask_T ({time_mask_t}) must be larger than 0 when "
|
51 |
+
f"doing time masking."
|
52 |
+
)
|
53 |
+
|
54 |
+
self.time_warp_w = time_warp_w
|
55 |
+
self.freq_mask_n = freq_mask_n
|
56 |
+
self.freq_mask_f = freq_mask_f
|
57 |
+
self.time_mask_n = time_mask_n
|
58 |
+
self.time_mask_t = time_mask_t
|
59 |
+
self.time_mask_p = time_mask_p
|
60 |
+
self.mask_value = mask_value
|
61 |
+
|
62 |
+
def __repr__(self):
|
63 |
+
return (
|
64 |
+
self.__class__.__name__
|
65 |
+
+ "("
|
66 |
+
+ ", ".join(
|
67 |
+
[
|
68 |
+
f"time_warp_w={self.time_warp_w}",
|
69 |
+
f"freq_mask_n={self.freq_mask_n}",
|
70 |
+
f"freq_mask_f={self.freq_mask_f}",
|
71 |
+
f"time_mask_n={self.time_mask_n}",
|
72 |
+
f"time_mask_t={self.time_mask_t}",
|
73 |
+
f"time_mask_p={self.time_mask_p}",
|
74 |
+
]
|
75 |
+
)
|
76 |
+
+ ")"
|
77 |
+
)
|
78 |
+
|
79 |
+
def __call__(self, spectrogram):
|
80 |
+
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
|
81 |
+
|
82 |
+
distorted = spectrogram.copy() # make a copy of input spectrogram.
|
83 |
+
num_frames = spectrogram.shape[0] # or 'tau' in the paper.
|
84 |
+
num_freqs = spectrogram.shape[1] # or 'miu' in the paper.
|
85 |
+
mask_value = self.mask_value
|
86 |
+
|
87 |
+
if mask_value is None: # if no value was specified, use local mean.
|
88 |
+
mask_value = spectrogram.mean()
|
89 |
+
|
90 |
+
if num_frames == 0:
|
91 |
+
return spectrogram
|
92 |
+
|
93 |
+
if num_freqs < self.freq_mask_f:
|
94 |
+
return spectrogram
|
95 |
+
|
96 |
+
if self.time_warp_w > 0:
|
97 |
+
if 2 * self.time_warp_w < num_frames:
|
98 |
+
import cv2
|
99 |
+
|
100 |
+
w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
|
101 |
+
w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
|
102 |
+
upper, lower = distorted[:w0, :], distorted[w0:, :]
|
103 |
+
upper = cv2.resize(
|
104 |
+
upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
|
105 |
+
)
|
106 |
+
lower = cv2.resize(
|
107 |
+
lower,
|
108 |
+
dsize=(num_freqs, num_frames - w0 - w),
|
109 |
+
interpolation=cv2.INTER_LINEAR,
|
110 |
+
)
|
111 |
+
distorted = np.concatenate((upper, lower), axis=0)
|
112 |
+
|
113 |
+
for _i in range(self.freq_mask_n):
|
114 |
+
f = np.random.randint(0, self.freq_mask_f)
|
115 |
+
f0 = np.random.randint(0, num_freqs - f)
|
116 |
+
if f != 0:
|
117 |
+
distorted[:, f0 : f0 + f] = mask_value
|
118 |
+
|
119 |
+
max_time_mask_t = min(
|
120 |
+
self.time_mask_t, math.floor(num_frames * self.time_mask_p)
|
121 |
+
)
|
122 |
+
if max_time_mask_t < 1:
|
123 |
+
return distorted
|
124 |
+
|
125 |
+
for _i in range(self.time_mask_n):
|
126 |
+
t = np.random.randint(0, max_time_mask_t)
|
127 |
+
t0 = np.random.randint(0, num_frames - t)
|
128 |
+
if t != 0:
|
129 |
+
distorted[t0 : t0 + t, :] = mask_value
|
130 |
+
|
131 |
+
return distorted
|
modules/voice_conversion/fairseq/data/audio/feature_transforms/utterance_cmvn.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
from fairseq.data.audio.feature_transforms import (
|
4 |
+
AudioFeatureTransform,
|
5 |
+
register_audio_feature_transform,
|
6 |
+
)
|
7 |
+
|
8 |
+
|
9 |
+
@register_audio_feature_transform("utterance_cmvn")
|
10 |
+
class UtteranceCMVN(AudioFeatureTransform):
|
11 |
+
"""Utterance-level CMVN (cepstral mean and variance normalization)"""
|
12 |
+
|
13 |
+
@classmethod
|
14 |
+
def from_config_dict(cls, config=None):
|
15 |
+
_config = {} if config is None else config
|
16 |
+
return UtteranceCMVN(
|
17 |
+
_config.get("norm_means", True),
|
18 |
+
_config.get("norm_vars", True),
|
19 |
+
)
|
20 |
+
|
21 |
+
def __init__(self, norm_means=True, norm_vars=True):
|
22 |
+
self.norm_means, self.norm_vars = norm_means, norm_vars
|
23 |
+
|
24 |
+
def __repr__(self):
|
25 |
+
return (
|
26 |
+
self.__class__.__name__
|
27 |
+
+ f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
|
28 |
+
)
|
29 |
+
|
30 |
+
def __call__(self, x):
|
31 |
+
mean = x.mean(axis=0)
|
32 |
+
square_sums = (x**2).sum(axis=0)
|
33 |
+
|
34 |
+
if self.norm_means:
|
35 |
+
x = np.subtract(x, mean)
|
36 |
+
if self.norm_vars:
|
37 |
+
var = square_sums / x.shape[0] - mean**2
|
38 |
+
std = np.sqrt(np.maximum(var, 1e-10))
|
39 |
+
x = np.divide(x, std)
|
40 |
+
|
41 |
+
return x
|
modules/voice_conversion/fairseq/data/audio/frm_text_to_speech_dataset.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the LICENSE file in
|
5 |
+
# the root directory of this source tree. An additional grant of patent rights
|
6 |
+
# can be found in the PATENTS file in the same directory.abs
|
7 |
+
|
8 |
+
import csv
|
9 |
+
import logging
|
10 |
+
import os.path as op
|
11 |
+
from typing import List, Optional
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
from fairseq.data import Dictionary
|
16 |
+
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
|
17 |
+
from fairseq.data.audio.text_to_speech_dataset import (
|
18 |
+
TextToSpeechDataset,
|
19 |
+
TextToSpeechDatasetCreator,
|
20 |
+
)
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class FrmTextToSpeechDataset(TextToSpeechDataset):
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
split: str,
|
29 |
+
is_train_split: bool,
|
30 |
+
data_cfg: S2TDataConfig,
|
31 |
+
audio_paths: List[str],
|
32 |
+
n_frames: List[int],
|
33 |
+
src_texts: Optional[List[str]] = None,
|
34 |
+
tgt_texts: Optional[List[str]] = None,
|
35 |
+
speakers: Optional[List[str]] = None,
|
36 |
+
src_langs: Optional[List[str]] = None,
|
37 |
+
tgt_langs: Optional[List[str]] = None,
|
38 |
+
ids: Optional[List[str]] = None,
|
39 |
+
tgt_dict: Optional[Dictionary] = None,
|
40 |
+
pre_tokenizer=None,
|
41 |
+
bpe_tokenizer=None,
|
42 |
+
n_frames_per_step=1,
|
43 |
+
speaker_to_id=None,
|
44 |
+
do_chunk=False,
|
45 |
+
chunk_bound=-1,
|
46 |
+
chunk_init=50,
|
47 |
+
chunk_incr=5,
|
48 |
+
add_eos=True,
|
49 |
+
dedup=True,
|
50 |
+
ref_fpu=-1,
|
51 |
+
):
|
52 |
+
# It assumes texts are encoded at a fixed frame-rate
|
53 |
+
super().__init__(
|
54 |
+
split=split,
|
55 |
+
is_train_split=is_train_split,
|
56 |
+
data_cfg=data_cfg,
|
57 |
+
audio_paths=audio_paths,
|
58 |
+
n_frames=n_frames,
|
59 |
+
src_texts=src_texts,
|
60 |
+
tgt_texts=tgt_texts,
|
61 |
+
speakers=speakers,
|
62 |
+
src_langs=src_langs,
|
63 |
+
tgt_langs=tgt_langs,
|
64 |
+
ids=ids,
|
65 |
+
tgt_dict=tgt_dict,
|
66 |
+
pre_tokenizer=pre_tokenizer,
|
67 |
+
bpe_tokenizer=bpe_tokenizer,
|
68 |
+
n_frames_per_step=n_frames_per_step,
|
69 |
+
speaker_to_id=speaker_to_id,
|
70 |
+
)
|
71 |
+
|
72 |
+
self.do_chunk = do_chunk
|
73 |
+
self.chunk_bound = chunk_bound
|
74 |
+
self.chunk_init = chunk_init
|
75 |
+
self.chunk_incr = chunk_incr
|
76 |
+
self.add_eos = add_eos
|
77 |
+
self.dedup = dedup
|
78 |
+
self.ref_fpu = ref_fpu
|
79 |
+
|
80 |
+
self.chunk_size = -1
|
81 |
+
|
82 |
+
if do_chunk:
|
83 |
+
assert self.chunk_incr >= 0
|
84 |
+
assert self.pre_tokenizer is None
|
85 |
+
|
86 |
+
def __getitem__(self, index):
|
87 |
+
index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
|
88 |
+
if target[-1].item() == self.tgt_dict.eos_index:
|
89 |
+
target = target[:-1]
|
90 |
+
|
91 |
+
fpu = source.size(0) / target.size(0) # frame-per-unit
|
92 |
+
fps = self.n_frames_per_step
|
93 |
+
assert (
|
94 |
+
self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
|
95 |
+
), f"{fpu*fps} != {self.ref_fpu}"
|
96 |
+
|
97 |
+
# only chunk training split
|
98 |
+
if self.is_train_split and self.do_chunk and self.chunk_size > 0:
|
99 |
+
lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
|
100 |
+
text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
|
101 |
+
size = len(text)
|
102 |
+
chunk_size = min(self.chunk_size, size)
|
103 |
+
chunk_start = np.random.randint(size - chunk_size + 1)
|
104 |
+
text = text[chunk_start : chunk_start + chunk_size]
|
105 |
+
target = torch.cat((lang, text), 0)
|
106 |
+
|
107 |
+
f_size = int(np.floor(chunk_size * fpu))
|
108 |
+
f_start = int(np.floor(chunk_start * fpu))
|
109 |
+
assert f_size > 0
|
110 |
+
source = source[f_start : f_start + f_size, :]
|
111 |
+
|
112 |
+
if self.dedup:
|
113 |
+
target = torch.unique_consecutive(target)
|
114 |
+
|
115 |
+
if self.add_eos:
|
116 |
+
eos_idx = self.tgt_dict.eos_index
|
117 |
+
target = torch.cat((target, torch.LongTensor([eos_idx])), 0)
|
118 |
+
|
119 |
+
return index, source, target, speaker_id
|
120 |
+
|
121 |
+
def set_epoch(self, epoch):
|
122 |
+
if self.is_train_split and self.do_chunk:
|
123 |
+
old = self.chunk_size
|
124 |
+
self.chunk_size = self.chunk_init + epoch * self.chunk_incr
|
125 |
+
if self.chunk_bound > 0:
|
126 |
+
self.chunk_size = min(self.chunk_size, self.chunk_bound)
|
127 |
+
logger.info(
|
128 |
+
(
|
129 |
+
f"{self.split}: setting chunk size "
|
130 |
+
f"from {old} to {self.chunk_size}"
|
131 |
+
)
|
132 |
+
)
|
133 |
+
|
134 |
+
|
135 |
+
class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
|
136 |
+
# inherit for key names
|
137 |
+
@classmethod
|
138 |
+
def from_tsv(
|
139 |
+
cls,
|
140 |
+
root: str,
|
141 |
+
data_cfg: S2TDataConfig,
|
142 |
+
split: str,
|
143 |
+
tgt_dict,
|
144 |
+
pre_tokenizer,
|
145 |
+
bpe_tokenizer,
|
146 |
+
is_train_split: bool,
|
147 |
+
n_frames_per_step: int,
|
148 |
+
speaker_to_id,
|
149 |
+
do_chunk: bool = False,
|
150 |
+
chunk_bound: int = -1,
|
151 |
+
chunk_init: int = 50,
|
152 |
+
chunk_incr: int = 5,
|
153 |
+
add_eos: bool = True,
|
154 |
+
dedup: bool = True,
|
155 |
+
ref_fpu: float = -1,
|
156 |
+
) -> FrmTextToSpeechDataset:
|
157 |
+
tsv_path = op.join(root, f"{split}.tsv")
|
158 |
+
if not op.isfile(tsv_path):
|
159 |
+
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
|
160 |
+
with open(tsv_path) as f:
|
161 |
+
reader = csv.DictReader(
|
162 |
+
f,
|
163 |
+
delimiter="\t",
|
164 |
+
quotechar=None,
|
165 |
+
doublequote=False,
|
166 |
+
lineterminator="\n",
|
167 |
+
quoting=csv.QUOTE_NONE,
|
168 |
+
)
|
169 |
+
s = [dict(e) for e in reader]
|
170 |
+
assert len(s) > 0
|
171 |
+
|
172 |
+
ids = [ss[cls.KEY_ID] for ss in s]
|
173 |
+
audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
|
174 |
+
n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
|
175 |
+
tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
|
176 |
+
src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
|
177 |
+
speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
|
178 |
+
src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
|
179 |
+
tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]
|
180 |
+
|
181 |
+
return FrmTextToSpeechDataset(
|
182 |
+
split=split,
|
183 |
+
is_train_split=is_train_split,
|
184 |
+
data_cfg=data_cfg,
|
185 |
+
audio_paths=audio_paths,
|
186 |
+
n_frames=n_frames,
|
187 |
+
src_texts=src_texts,
|
188 |
+
tgt_texts=tgt_texts,
|
189 |
+
speakers=speakers,
|
190 |
+
src_langs=src_langs,
|
191 |
+
tgt_langs=tgt_langs,
|
192 |
+
ids=ids,
|
193 |
+
tgt_dict=tgt_dict,
|
194 |
+
pre_tokenizer=pre_tokenizer,
|
195 |
+
bpe_tokenizer=bpe_tokenizer,
|
196 |
+
n_frames_per_step=n_frames_per_step,
|
197 |
+
speaker_to_id=speaker_to_id,
|
198 |
+
do_chunk=do_chunk,
|
199 |
+
chunk_bound=chunk_bound,
|
200 |
+
chunk_init=chunk_init,
|
201 |
+
chunk_incr=chunk_incr,
|
202 |
+
add_eos=add_eos,
|
203 |
+
dedup=dedup,
|
204 |
+
ref_fpu=ref_fpu,
|
205 |
+
)
|
modules/voice_conversion/fairseq/data/audio/hubert_dataset.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import itertools
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
from typing import Any, List, Optional, Union
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from fairseq.data import data_utils
|
17 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
18 |
+
from fairseq.data.audio.audio_utils import (
|
19 |
+
parse_path,
|
20 |
+
read_from_stored_zip,
|
21 |
+
)
|
22 |
+
import io
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
def load_audio(manifest_path, max_keep, min_keep):
|
28 |
+
n_long, n_short = 0, 0
|
29 |
+
names, inds, sizes = [], [], []
|
30 |
+
with open(manifest_path) as f:
|
31 |
+
root = f.readline().strip()
|
32 |
+
for ind, line in enumerate(f):
|
33 |
+
items = line.strip().split("\t")
|
34 |
+
assert len(items) == 2, line
|
35 |
+
sz = int(items[1])
|
36 |
+
if min_keep is not None and sz < min_keep:
|
37 |
+
n_short += 1
|
38 |
+
elif max_keep is not None and sz > max_keep:
|
39 |
+
n_long += 1
|
40 |
+
else:
|
41 |
+
names.append(items[0])
|
42 |
+
inds.append(ind)
|
43 |
+
sizes.append(sz)
|
44 |
+
tot = ind + 1
|
45 |
+
logger.info(
|
46 |
+
(
|
47 |
+
f"max_keep={max_keep}, min_keep={min_keep}, "
|
48 |
+
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
|
49 |
+
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
|
50 |
+
)
|
51 |
+
)
|
52 |
+
return root, names, inds, tot, sizes
|
53 |
+
|
54 |
+
|
55 |
+
def load_label(label_path, inds, tot):
|
56 |
+
with open(label_path) as f:
|
57 |
+
labels = [line.rstrip() for line in f]
|
58 |
+
assert (
|
59 |
+
len(labels) == tot
|
60 |
+
), f"number of labels does not match ({len(labels)} != {tot})"
|
61 |
+
labels = [labels[i] for i in inds]
|
62 |
+
return labels
|
63 |
+
|
64 |
+
|
65 |
+
def load_label_offset(label_path, inds, tot):
|
66 |
+
with open(label_path) as f:
|
67 |
+
code_lengths = [len(line.encode("utf-8")) for line in f]
|
68 |
+
assert (
|
69 |
+
len(code_lengths) == tot
|
70 |
+
), f"number of labels does not match ({len(code_lengths)} != {tot})"
|
71 |
+
offsets = list(itertools.accumulate([0] + code_lengths))
|
72 |
+
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
|
73 |
+
return offsets
|
74 |
+
|
75 |
+
|
76 |
+
def verify_label_lengths(
|
77 |
+
audio_sizes,
|
78 |
+
audio_rate,
|
79 |
+
label_path,
|
80 |
+
label_rate,
|
81 |
+
inds,
|
82 |
+
tot,
|
83 |
+
tol=0.1, # tolerance in seconds
|
84 |
+
):
|
85 |
+
if label_rate < 0:
|
86 |
+
logger.info(f"{label_path} is sequence label. skipped")
|
87 |
+
return
|
88 |
+
|
89 |
+
with open(label_path) as f:
|
90 |
+
lengths = [len(line.rstrip().split()) for line in f]
|
91 |
+
assert len(lengths) == tot
|
92 |
+
lengths = [lengths[i] for i in inds]
|
93 |
+
num_invalid = 0
|
94 |
+
for i, ind in enumerate(inds):
|
95 |
+
dur_from_audio = audio_sizes[i] / audio_rate
|
96 |
+
dur_from_label = lengths[i] / label_rate
|
97 |
+
if abs(dur_from_audio - dur_from_label) > tol:
|
98 |
+
logger.warning(
|
99 |
+
(
|
100 |
+
f"audio and label duration differ too much "
|
101 |
+
f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
|
102 |
+
f"in line {ind+1} of {label_path}. Check if `label_rate` "
|
103 |
+
f"is correctly set (currently {label_rate}). "
|
104 |
+
f"num. of samples = {audio_sizes[i]}; "
|
105 |
+
f"label length = {lengths[i]}"
|
106 |
+
)
|
107 |
+
)
|
108 |
+
num_invalid += 1
|
109 |
+
if num_invalid > 0:
|
110 |
+
logger.warning(
|
111 |
+
f"total {num_invalid} (audio, label) pairs with mismatched lengths"
|
112 |
+
)
|
113 |
+
|
114 |
+
|
115 |
+
class HubertDataset(FairseqDataset):
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
manifest_path: str,
|
119 |
+
sample_rate: float,
|
120 |
+
label_paths: List[str],
|
121 |
+
label_rates: Union[List[float], float], # -1 for sequence labels
|
122 |
+
pad_list: List[str],
|
123 |
+
eos_list: List[str],
|
124 |
+
label_processors: Optional[List[Any]] = None,
|
125 |
+
max_keep_sample_size: Optional[int] = None,
|
126 |
+
min_keep_sample_size: Optional[int] = None,
|
127 |
+
max_sample_size: Optional[int] = None,
|
128 |
+
shuffle: bool = True,
|
129 |
+
pad_audio: bool = False,
|
130 |
+
normalize: bool = False,
|
131 |
+
store_labels: bool = True,
|
132 |
+
random_crop: bool = False,
|
133 |
+
single_target: bool = False,
|
134 |
+
):
|
135 |
+
self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
|
136 |
+
manifest_path, max_keep_sample_size, min_keep_sample_size
|
137 |
+
)
|
138 |
+
self.sample_rate = sample_rate
|
139 |
+
self.shuffle = shuffle
|
140 |
+
self.random_crop = random_crop
|
141 |
+
|
142 |
+
self.num_labels = len(label_paths)
|
143 |
+
self.pad_list = pad_list
|
144 |
+
self.eos_list = eos_list
|
145 |
+
self.label_processors = label_processors
|
146 |
+
self.single_target = single_target
|
147 |
+
self.label_rates = (
|
148 |
+
[label_rates for _ in range(len(label_paths))]
|
149 |
+
if isinstance(label_rates, float)
|
150 |
+
else label_rates
|
151 |
+
)
|
152 |
+
self.store_labels = store_labels
|
153 |
+
if store_labels:
|
154 |
+
self.label_list = [load_label(p, inds, tot) for p in label_paths]
|
155 |
+
else:
|
156 |
+
self.label_paths = label_paths
|
157 |
+
self.label_offsets_list = [
|
158 |
+
load_label_offset(p, inds, tot) for p in label_paths
|
159 |
+
]
|
160 |
+
assert label_processors is None or len(label_processors) == self.num_labels
|
161 |
+
for label_path, label_rate in zip(label_paths, self.label_rates):
|
162 |
+
verify_label_lengths(
|
163 |
+
self.sizes, sample_rate, label_path, label_rate, inds, tot
|
164 |
+
)
|
165 |
+
|
166 |
+
self.max_sample_size = (
|
167 |
+
max_sample_size if max_sample_size is not None else sys.maxsize
|
168 |
+
)
|
169 |
+
self.pad_audio = pad_audio
|
170 |
+
self.normalize = normalize
|
171 |
+
logger.info(
|
172 |
+
f"pad_audio={pad_audio}, random_crop={random_crop}, "
|
173 |
+
f"normalize={normalize}, max_sample_size={self.max_sample_size}"
|
174 |
+
)
|
175 |
+
|
176 |
+
def get_audio(self, index):
|
177 |
+
import soundfile as sf
|
178 |
+
|
179 |
+
wav_path = os.path.join(self.audio_root, self.audio_names[index])
|
180 |
+
_path, slice_ptr = parse_path(wav_path)
|
181 |
+
if len(slice_ptr) == 0:
|
182 |
+
wav, cur_sample_rate = sf.read(_path)
|
183 |
+
else:
|
184 |
+
assert _path.endswith(".zip")
|
185 |
+
data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
|
186 |
+
f = io.BytesIO(data)
|
187 |
+
wav, cur_sample_rate = sf.read(f)
|
188 |
+
wav = torch.from_numpy(wav).float()
|
189 |
+
wav = self.postprocess(wav, cur_sample_rate)
|
190 |
+
return wav
|
191 |
+
|
192 |
+
def get_label(self, index, label_idx):
|
193 |
+
if self.store_labels:
|
194 |
+
label = self.label_list[label_idx][index]
|
195 |
+
else:
|
196 |
+
with open(self.label_paths[label_idx]) as f:
|
197 |
+
offset_s, offset_e = self.label_offsets_list[label_idx][index]
|
198 |
+
f.seek(offset_s)
|
199 |
+
label = f.read(offset_e - offset_s)
|
200 |
+
|
201 |
+
if self.label_processors is not None:
|
202 |
+
label = self.label_processors[label_idx](label)
|
203 |
+
return label
|
204 |
+
|
205 |
+
def get_labels(self, index):
|
206 |
+
return [self.get_label(index, i) for i in range(self.num_labels)]
|
207 |
+
|
208 |
+
def __getitem__(self, index):
|
209 |
+
wav = self.get_audio(index)
|
210 |
+
labels = self.get_labels(index)
|
211 |
+
return {"id": index, "source": wav, "label_list": labels}
|
212 |
+
|
213 |
+
def __len__(self):
|
214 |
+
return len(self.sizes)
|
215 |
+
|
216 |
+
def crop_to_max_size(self, wav, target_size):
|
217 |
+
size = len(wav)
|
218 |
+
diff = size - target_size
|
219 |
+
if diff <= 0:
|
220 |
+
return wav, 0
|
221 |
+
|
222 |
+
start, end = 0, target_size
|
223 |
+
if self.random_crop:
|
224 |
+
start = np.random.randint(0, diff + 1)
|
225 |
+
end = size - diff + start
|
226 |
+
return wav[start:end], start
|
227 |
+
|
228 |
+
def collater(self, samples):
|
229 |
+
# target = max(sizes) -> random_crop not used
|
230 |
+
# target = max_sample_size -> random_crop used for long
|
231 |
+
samples = [s for s in samples if s["source"] is not None]
|
232 |
+
if len(samples) == 0:
|
233 |
+
return {}
|
234 |
+
|
235 |
+
audios = [s["source"] for s in samples]
|
236 |
+
audio_sizes = [len(s) for s in audios]
|
237 |
+
if self.pad_audio:
|
238 |
+
audio_size = min(max(audio_sizes), self.max_sample_size)
|
239 |
+
else:
|
240 |
+
audio_size = min(min(audio_sizes), self.max_sample_size)
|
241 |
+
collated_audios, padding_mask, audio_starts = self.collater_audio(
|
242 |
+
audios, audio_size
|
243 |
+
)
|
244 |
+
|
245 |
+
targets_by_label = [
|
246 |
+
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
|
247 |
+
]
|
248 |
+
targets_list, lengths_list, ntokens_list = self.collater_label(
|
249 |
+
targets_by_label, audio_size, audio_starts
|
250 |
+
)
|
251 |
+
|
252 |
+
net_input = {"source": collated_audios, "padding_mask": padding_mask}
|
253 |
+
batch = {
|
254 |
+
"id": torch.LongTensor([s["id"] for s in samples]),
|
255 |
+
"net_input": net_input,
|
256 |
+
}
|
257 |
+
|
258 |
+
if self.single_target:
|
259 |
+
batch["target_lengths"] = lengths_list[0]
|
260 |
+
batch["ntokens"] = ntokens_list[0]
|
261 |
+
batch["target"] = targets_list[0]
|
262 |
+
else:
|
263 |
+
batch["target_lengths_list"] = lengths_list
|
264 |
+
batch["ntokens_list"] = ntokens_list
|
265 |
+
batch["target_list"] = targets_list
|
266 |
+
return batch
|
267 |
+
|
268 |
+
def collater_audio(self, audios, audio_size):
|
269 |
+
collated_audios = audios[0].new_zeros(len(audios), audio_size)
|
270 |
+
padding_mask = (
|
271 |
+
torch.BoolTensor(collated_audios.shape).fill_(False)
|
272 |
+
# if self.pad_audio else None
|
273 |
+
)
|
274 |
+
audio_starts = [0 for _ in audios]
|
275 |
+
for i, audio in enumerate(audios):
|
276 |
+
diff = len(audio) - audio_size
|
277 |
+
if diff == 0:
|
278 |
+
collated_audios[i] = audio
|
279 |
+
elif diff < 0:
|
280 |
+
assert self.pad_audio
|
281 |
+
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
|
282 |
+
padding_mask[i, diff:] = True
|
283 |
+
else:
|
284 |
+
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
|
285 |
+
audio, audio_size
|
286 |
+
)
|
287 |
+
return collated_audios, padding_mask, audio_starts
|
288 |
+
|
289 |
+
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
|
290 |
+
assert label_rate > 0
|
291 |
+
s2f = label_rate / self.sample_rate
|
292 |
+
frm_starts = [int(round(s * s2f)) for s in audio_starts]
|
293 |
+
frm_size = int(round(audio_size * s2f))
|
294 |
+
if not self.pad_audio:
|
295 |
+
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
|
296 |
+
frm_size = min(frm_size, *rem_size)
|
297 |
+
targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
|
298 |
+
logger.debug(f"audio_starts={audio_starts}")
|
299 |
+
logger.debug(f"frame_starts={frm_starts}")
|
300 |
+
logger.debug(f"frame_size={frm_size}")
|
301 |
+
|
302 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
303 |
+
ntokens = lengths.sum().item()
|
304 |
+
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
305 |
+
return targets, lengths, ntokens
|
306 |
+
|
307 |
+
def collater_seq_label(self, targets, pad):
|
308 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
309 |
+
ntokens = lengths.sum().item()
|
310 |
+
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
311 |
+
return targets, lengths, ntokens
|
312 |
+
|
313 |
+
def collater_label(self, targets_by_label, audio_size, audio_starts):
|
314 |
+
targets_list, lengths_list, ntokens_list = [], [], []
|
315 |
+
itr = zip(targets_by_label, self.label_rates, self.pad_list)
|
316 |
+
for targets, label_rate, pad in itr:
|
317 |
+
if label_rate == -1.0:
|
318 |
+
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
|
319 |
+
else:
|
320 |
+
targets, lengths, ntokens = self.collater_frm_label(
|
321 |
+
targets, audio_size, audio_starts, label_rate, pad
|
322 |
+
)
|
323 |
+
targets_list.append(targets)
|
324 |
+
lengths_list.append(lengths)
|
325 |
+
ntokens_list.append(ntokens)
|
326 |
+
return targets_list, lengths_list, ntokens_list
|
327 |
+
|
328 |
+
def num_tokens(self, index):
|
329 |
+
return self.size(index)
|
330 |
+
|
331 |
+
def size(self, index):
|
332 |
+
if self.pad_audio:
|
333 |
+
return self.sizes[index]
|
334 |
+
return min(self.sizes[index], self.max_sample_size)
|
335 |
+
|
336 |
+
def ordered_indices(self):
|
337 |
+
if self.shuffle:
|
338 |
+
order = [np.random.permutation(len(self))]
|
339 |
+
else:
|
340 |
+
order = [np.arange(len(self))]
|
341 |
+
|
342 |
+
order.append(self.sizes)
|
343 |
+
return np.lexsort(order)[::-1]
|
344 |
+
|
345 |
+
def postprocess(self, wav, cur_sample_rate):
|
346 |
+
if wav.dim() == 2:
|
347 |
+
wav = wav.mean(-1)
|
348 |
+
assert wav.dim() == 1, wav.dim()
|
349 |
+
|
350 |
+
if cur_sample_rate != self.sample_rate:
|
351 |
+
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
352 |
+
|
353 |
+
if self.normalize:
|
354 |
+
with torch.no_grad():
|
355 |
+
wav = F.layer_norm(wav, wav.shape)
|
356 |
+
return wav
|
modules/voice_conversion/fairseq/data/audio/multi_modality_dataset.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the LICENSE file in
|
5 |
+
# the root directory of this source tree. An additional grant of patent rights
|
6 |
+
# can be found in the PATENTS file in the same directory.
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import math
|
10 |
+
from typing import List, Optional, NamedTuple
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
from fairseq.data.resampling_dataset import ResamplingDataset
|
14 |
+
import torch
|
15 |
+
from fairseq.data import (
|
16 |
+
ConcatDataset,
|
17 |
+
LanguagePairDataset,
|
18 |
+
FileAudioDataset,
|
19 |
+
data_utils,
|
20 |
+
)
|
21 |
+
from fairseq.data import FairseqDataset
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
class ModalityDatasetItem(NamedTuple):
|
27 |
+
datasetname: str
|
28 |
+
dataset: any
|
29 |
+
max_positions: List[int]
|
30 |
+
max_tokens: Optional[int] = None
|
31 |
+
max_sentences: Optional[int] = None
|
32 |
+
|
33 |
+
|
34 |
+
def resampling_dataset_present(ds):
|
35 |
+
if isinstance(ds, ResamplingDataset):
|
36 |
+
return True
|
37 |
+
if isinstance(ds, ConcatDataset):
|
38 |
+
return any(resampling_dataset_present(d) for d in ds.datasets)
|
39 |
+
if hasattr(ds, "dataset"):
|
40 |
+
return resampling_dataset_present(ds.dataset)
|
41 |
+
return False
|
42 |
+
|
43 |
+
|
44 |
+
# MultiModalityDataset: it concate multiple datasets with different modalities.
|
45 |
+
# Compared with ConcatDataset it can 1) sample data given the ratios for different datasets
|
46 |
+
# 2) it adds mode to indicate what type of the data samples come from.
|
47 |
+
# It will be used with GroupedEpochBatchIterator together to generate mini-batch with samples
|
48 |
+
# from the same type of dataset
|
49 |
+
# If only one dataset is used, it will perform like the original dataset with mode added
|
50 |
+
class MultiModalityDataset(ConcatDataset):
|
51 |
+
def __init__(self, datasets: List[ModalityDatasetItem]):
|
52 |
+
id_to_mode = []
|
53 |
+
dsets = []
|
54 |
+
max_tokens = []
|
55 |
+
max_sentences = []
|
56 |
+
max_positions = []
|
57 |
+
for dset in datasets:
|
58 |
+
id_to_mode.append(dset.datasetname)
|
59 |
+
dsets.append(dset.dataset)
|
60 |
+
max_tokens.append(dset.max_tokens)
|
61 |
+
max_positions.append(dset.max_positions)
|
62 |
+
max_sentences.append(dset.max_sentences)
|
63 |
+
weights = [1.0 for s in dsets]
|
64 |
+
super().__init__(dsets, weights)
|
65 |
+
self.max_tokens = max_tokens
|
66 |
+
self.max_positions = max_positions
|
67 |
+
self.max_sentences = max_sentences
|
68 |
+
self.id_to_mode = id_to_mode
|
69 |
+
self.raw_sub_batch_samplers = []
|
70 |
+
self._cur_epoch = 0
|
71 |
+
|
72 |
+
def set_epoch(self, epoch):
|
73 |
+
super().set_epoch(epoch)
|
74 |
+
self._cur_epoch = epoch
|
75 |
+
|
76 |
+
def __getitem__(self, idx):
|
77 |
+
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
78 |
+
sample = self.datasets[dataset_idx][sample_idx]
|
79 |
+
return (dataset_idx, sample)
|
80 |
+
|
81 |
+
def collater(self, samples):
|
82 |
+
if len(samples) == 0:
|
83 |
+
return {}
|
84 |
+
dataset_idx = samples[0][0]
|
85 |
+
# make sure all samples in samples are from same dataset
|
86 |
+
assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
|
87 |
+
samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
|
88 |
+
# add mode
|
89 |
+
samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]
|
90 |
+
|
91 |
+
return samples
|
92 |
+
|
93 |
+
def size(self, index: int):
|
94 |
+
if len(self.datasets) == 1:
|
95 |
+
return self.datasets[0].size(index)
|
96 |
+
return super().size(index)
|
97 |
+
|
98 |
+
@property
|
99 |
+
def sizes(self):
|
100 |
+
if len(self.datasets) == 1:
|
101 |
+
return self.datasets[0].sizes
|
102 |
+
return super().sizes
|
103 |
+
|
104 |
+
def ordered_indices(self):
|
105 |
+
"""
|
106 |
+
Returns indices sorted by length. So less padding is needed.
|
107 |
+
"""
|
108 |
+
if len(self.datasets) == 1:
|
109 |
+
return self.datasets[0].ordered_indices()
|
110 |
+
indices_group = []
|
111 |
+
for d_idx, ds in enumerate(self.datasets):
|
112 |
+
sample_num = self.cumulative_sizes[d_idx]
|
113 |
+
if d_idx > 0:
|
114 |
+
sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
|
115 |
+
assert sample_num == len(ds)
|
116 |
+
indices_group.append(ds.ordered_indices())
|
117 |
+
return indices_group
|
118 |
+
|
119 |
+
def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
|
120 |
+
with data_utils.numpy_seed(seed):
|
121 |
+
indices = self.ordered_indices()
|
122 |
+
for i, ds in enumerate(self.datasets):
|
123 |
+
# If we have ResamplingDataset, the same id can correpond to a different
|
124 |
+
# sample in the next epoch, so we need to rebuild this at every epoch
|
125 |
+
if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
|
126 |
+
ds
|
127 |
+
):
|
128 |
+
logger.info(f"dataset {i} is valid and it is not re-sampled")
|
129 |
+
continue
|
130 |
+
indices[i] = ds.filter_indices_by_size(
|
131 |
+
indices[i],
|
132 |
+
self.max_positions[i],
|
133 |
+
)[0]
|
134 |
+
sub_batch_sampler = ds.batch_by_size(
|
135 |
+
indices[i],
|
136 |
+
max_tokens=self.max_tokens[i],
|
137 |
+
max_sentences=self.max_sentences[i],
|
138 |
+
required_batch_size_multiple=required_batch_size_multiple,
|
139 |
+
)
|
140 |
+
if i < len(self.raw_sub_batch_samplers):
|
141 |
+
self.raw_sub_batch_samplers[i] = sub_batch_sampler
|
142 |
+
else:
|
143 |
+
self.raw_sub_batch_samplers.append(sub_batch_sampler)
|
144 |
+
|
145 |
+
def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
|
146 |
+
self.get_raw_batch_samplers(required_batch_size_multiple, seed)
|
147 |
+
batch_samplers = []
|
148 |
+
for i, _ in enumerate(self.datasets):
|
149 |
+
if i > 0:
|
150 |
+
sub_batch_sampler = [
|
151 |
+
[y + self.cumulative_sizes[i - 1] for y in x]
|
152 |
+
for x in self.raw_sub_batch_samplers[i]
|
153 |
+
]
|
154 |
+
else:
|
155 |
+
sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
|
156 |
+
smp_r = mult_ratios[i]
|
157 |
+
if smp_r != 1:
|
158 |
+
is_increase = "increased" if smp_r > 1 else "decreased"
|
159 |
+
logger.info(
|
160 |
+
"number of batch for the dataset {} is {} from {} to {}".format(
|
161 |
+
self.id_to_mode[i],
|
162 |
+
is_increase,
|
163 |
+
len(sub_batch_sampler),
|
164 |
+
int(len(sub_batch_sampler) * smp_r),
|
165 |
+
)
|
166 |
+
)
|
167 |
+
mul_samplers = []
|
168 |
+
for _ in range(math.floor(smp_r)):
|
169 |
+
mul_samplers = mul_samplers + sub_batch_sampler
|
170 |
+
if math.floor(smp_r) != smp_r:
|
171 |
+
with data_utils.numpy_seed(seed + self._cur_epoch):
|
172 |
+
np.random.shuffle(sub_batch_sampler)
|
173 |
+
smp_num = int(
|
174 |
+
(smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
|
175 |
+
)
|
176 |
+
mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
|
177 |
+
sub_batch_sampler = mul_samplers
|
178 |
+
else:
|
179 |
+
logger.info(
|
180 |
+
"dataset {} batch number is {} ".format(
|
181 |
+
self.id_to_mode[i], len(sub_batch_sampler)
|
182 |
+
)
|
183 |
+
)
|
184 |
+
batch_samplers.append(sub_batch_sampler)
|
185 |
+
|
186 |
+
return batch_samplers
|
187 |
+
|
188 |
+
|
189 |
+
class LangPairMaskDataset(FairseqDataset):
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
dataset: LanguagePairDataset,
|
193 |
+
src_eos: int,
|
194 |
+
src_bos: Optional[int] = None,
|
195 |
+
noise_id: Optional[int] = -1,
|
196 |
+
mask_ratio: Optional[float] = 0,
|
197 |
+
mask_type: Optional[str] = "random",
|
198 |
+
):
|
199 |
+
self.dataset = dataset
|
200 |
+
self.src_eos = src_eos
|
201 |
+
self.src_bos = src_bos
|
202 |
+
self.noise_id = noise_id
|
203 |
+
self.mask_ratio = mask_ratio
|
204 |
+
self.mask_type = mask_type
|
205 |
+
assert mask_type in ("random", "tail")
|
206 |
+
|
207 |
+
@property
|
208 |
+
def src_sizes(self):
|
209 |
+
return self.dataset.src_sizes
|
210 |
+
|
211 |
+
@property
|
212 |
+
def tgt_sizes(self):
|
213 |
+
return self.dataset.tgt_sizes
|
214 |
+
|
215 |
+
@property
|
216 |
+
def sizes(self):
|
217 |
+
# dataset.sizes can be a dynamically computed sizes:
|
218 |
+
return self.dataset.sizes
|
219 |
+
|
220 |
+
def get_batch_shapes(self):
|
221 |
+
if hasattr(self.dataset, "get_batch_shapes"):
|
222 |
+
return self.dataset.get_batch_shapes()
|
223 |
+
return self.dataset.buckets
|
224 |
+
|
225 |
+
def num_tokens_vec(self, indices):
|
226 |
+
return self.dataset.num_tokens_vec(indices)
|
227 |
+
|
228 |
+
def __len__(self):
|
229 |
+
return len(self.dataset)
|
230 |
+
|
231 |
+
def num_tokens(self, index):
|
232 |
+
return self.dataset.num_tokens(index)
|
233 |
+
|
234 |
+
def size(self, index):
|
235 |
+
return self.dataset.size(index)
|
236 |
+
|
237 |
+
def ordered_indices(self):
|
238 |
+
return self.dataset.ordered_indices()
|
239 |
+
|
240 |
+
@property
|
241 |
+
def supports_prefetch(self):
|
242 |
+
return getattr(self.dataset, "supports_prefetch", False)
|
243 |
+
|
244 |
+
def prefetch(self, indices):
|
245 |
+
return self.dataset.prefetch(indices)
|
246 |
+
|
247 |
+
def mask_src_tokens(self, sample):
|
248 |
+
src_item = sample["source"]
|
249 |
+
mask = None
|
250 |
+
if self.mask_type == "random":
|
251 |
+
mask = torch.rand(len(src_item)).le(self.mask_ratio)
|
252 |
+
else:
|
253 |
+
mask = torch.ones(len(src_item))
|
254 |
+
mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
|
255 |
+
mask = mask.eq(1)
|
256 |
+
if src_item[0] == self.src_bos:
|
257 |
+
mask[0] = False
|
258 |
+
if src_item[-1] == self.src_eos:
|
259 |
+
mask[-1] = False
|
260 |
+
mask_src_item = src_item.masked_fill(mask, self.noise_id)
|
261 |
+
smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
|
262 |
+
return smp
|
263 |
+
|
264 |
+
def __getitem__(self, index):
|
265 |
+
sample = self.dataset[index]
|
266 |
+
if self.mask_ratio > 0:
|
267 |
+
sample = self.mask_src_tokens(sample)
|
268 |
+
return sample
|
269 |
+
|
270 |
+
def collater(self, samples, pad_to_length=None):
|
271 |
+
return self.dataset.collater(samples, pad_to_length)
|
272 |
+
|
273 |
+
|
274 |
+
class FileAudioDatasetWrapper(FileAudioDataset):
|
275 |
+
def collater(self, samples):
|
276 |
+
samples = super().collater(samples)
|
277 |
+
if len(samples) == 0:
|
278 |
+
return {}
|
279 |
+
samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
|
280 |
+
samples["net_input"]["prev_output_tokens"] = None
|
281 |
+
del samples["net_input"]["source"]
|
282 |
+
samples["net_input"]["src_lengths"] = None
|
283 |
+
samples["net_input"]["alignment"] = None
|
284 |
+
return samples
|
modules/voice_conversion/fairseq/data/audio/raw_audio_dataset.py
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import io
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from .. import FairseqDataset
|
17 |
+
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
|
18 |
+
from fairseq.data.audio.audio_utils import (
|
19 |
+
parse_path,
|
20 |
+
read_from_stored_zip,
|
21 |
+
is_sf_audio_data,
|
22 |
+
)
|
23 |
+
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
|
24 |
+
|
25 |
+
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
class RawAudioDataset(FairseqDataset):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
sample_rate,
|
33 |
+
max_sample_size=None,
|
34 |
+
min_sample_size=0,
|
35 |
+
shuffle=True,
|
36 |
+
pad=False,
|
37 |
+
normalize=False,
|
38 |
+
compute_mask_indices=False,
|
39 |
+
**mask_compute_kwargs,
|
40 |
+
):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
self.sample_rate = sample_rate
|
44 |
+
self.sizes = []
|
45 |
+
self.max_sample_size = (
|
46 |
+
max_sample_size if max_sample_size is not None else sys.maxsize
|
47 |
+
)
|
48 |
+
self.min_sample_size = min_sample_size
|
49 |
+
self.pad = pad
|
50 |
+
self.shuffle = shuffle
|
51 |
+
self.normalize = normalize
|
52 |
+
self.compute_mask_indices = compute_mask_indices
|
53 |
+
if self.compute_mask_indices:
|
54 |
+
self.mask_compute_kwargs = mask_compute_kwargs
|
55 |
+
self._features_size_map = {}
|
56 |
+
self._C = mask_compute_kwargs["encoder_embed_dim"]
|
57 |
+
self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])
|
58 |
+
|
59 |
+
def __getitem__(self, index):
|
60 |
+
raise NotImplementedError()
|
61 |
+
|
62 |
+
def __len__(self):
|
63 |
+
return len(self.sizes)
|
64 |
+
|
65 |
+
def postprocess(self, feats, curr_sample_rate):
|
66 |
+
if feats.dim() == 2:
|
67 |
+
feats = feats.mean(-1)
|
68 |
+
|
69 |
+
if curr_sample_rate != self.sample_rate:
|
70 |
+
raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")
|
71 |
+
|
72 |
+
assert feats.dim() == 1, feats.dim()
|
73 |
+
|
74 |
+
if self.normalize:
|
75 |
+
with torch.no_grad():
|
76 |
+
feats = F.layer_norm(feats, feats.shape)
|
77 |
+
return feats
|
78 |
+
|
79 |
+
def crop_to_max_size(self, wav, target_size):
|
80 |
+
size = len(wav)
|
81 |
+
diff = size - target_size
|
82 |
+
if diff <= 0:
|
83 |
+
return wav
|
84 |
+
|
85 |
+
start = np.random.randint(0, diff + 1)
|
86 |
+
end = size - diff + start
|
87 |
+
return wav[start:end]
|
88 |
+
|
89 |
+
def _compute_mask_indices(self, dims, padding_mask):
|
90 |
+
B, T, C = dims
|
91 |
+
mask_indices, mask_channel_indices = None, None
|
92 |
+
if self.mask_compute_kwargs["mask_prob"] > 0:
|
93 |
+
mask_indices = compute_mask_indices(
|
94 |
+
(B, T),
|
95 |
+
padding_mask,
|
96 |
+
self.mask_compute_kwargs["mask_prob"],
|
97 |
+
self.mask_compute_kwargs["mask_length"],
|
98 |
+
self.mask_compute_kwargs["mask_selection"],
|
99 |
+
self.mask_compute_kwargs["mask_other"],
|
100 |
+
min_masks=2,
|
101 |
+
no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
|
102 |
+
min_space=self.mask_compute_kwargs["mask_min_space"],
|
103 |
+
)
|
104 |
+
mask_indices = torch.from_numpy(mask_indices)
|
105 |
+
if self.mask_compute_kwargs["mask_channel_prob"] > 0:
|
106 |
+
mask_channel_indices = compute_mask_indices(
|
107 |
+
(B, C),
|
108 |
+
None,
|
109 |
+
self.mask_compute_kwargs["mask_channel_prob"],
|
110 |
+
self.mask_compute_kwargs["mask_channel_length"],
|
111 |
+
self.mask_compute_kwargs["mask_channel_selection"],
|
112 |
+
self.mask_compute_kwargs["mask_channel_other"],
|
113 |
+
no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
|
114 |
+
min_space=self.mask_compute_kwargs["mask_channel_min_space"],
|
115 |
+
)
|
116 |
+
mask_channel_indices = (
|
117 |
+
torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
|
118 |
+
)
|
119 |
+
|
120 |
+
return mask_indices, mask_channel_indices
|
121 |
+
|
122 |
+
@staticmethod
|
123 |
+
def _bucket_tensor(tensor, num_pad, value):
|
124 |
+
return F.pad(tensor, (0, num_pad), value=value)
|
125 |
+
|
126 |
+
def collater(self, samples):
|
127 |
+
samples = [s for s in samples if s["source"] is not None]
|
128 |
+
if len(samples) == 0:
|
129 |
+
return {}
|
130 |
+
|
131 |
+
sources = [s["source"] for s in samples]
|
132 |
+
sizes = [len(s) for s in sources]
|
133 |
+
|
134 |
+
if self.pad:
|
135 |
+
target_size = min(max(sizes), self.max_sample_size)
|
136 |
+
else:
|
137 |
+
target_size = min(min(sizes), self.max_sample_size)
|
138 |
+
|
139 |
+
collated_sources = sources[0].new_zeros(len(sources), target_size)
|
140 |
+
padding_mask = (
|
141 |
+
torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
|
142 |
+
)
|
143 |
+
for i, (source, size) in enumerate(zip(sources, sizes)):
|
144 |
+
diff = size - target_size
|
145 |
+
if diff == 0:
|
146 |
+
collated_sources[i] = source
|
147 |
+
elif diff < 0:
|
148 |
+
assert self.pad
|
149 |
+
collated_sources[i] = torch.cat(
|
150 |
+
[source, source.new_full((-diff,), 0.0)]
|
151 |
+
)
|
152 |
+
padding_mask[i, diff:] = True
|
153 |
+
else:
|
154 |
+
collated_sources[i] = self.crop_to_max_size(source, target_size)
|
155 |
+
|
156 |
+
input = {"source": collated_sources}
|
157 |
+
out = {"id": torch.LongTensor([s["id"] for s in samples])}
|
158 |
+
if self.pad:
|
159 |
+
input["padding_mask"] = padding_mask
|
160 |
+
|
161 |
+
if hasattr(self, "num_buckets") and self.num_buckets > 0:
|
162 |
+
assert self.pad, "Cannot bucket without padding first."
|
163 |
+
bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
|
164 |
+
num_pad = bucket - collated_sources.size(-1)
|
165 |
+
if num_pad:
|
166 |
+
input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
|
167 |
+
input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)
|
168 |
+
|
169 |
+
if self.compute_mask_indices:
|
170 |
+
B = input["source"].size(0)
|
171 |
+
T = self._get_mask_indices_dims(input["source"].size(-1))
|
172 |
+
padding_mask_reshaped = input["padding_mask"].clone()
|
173 |
+
extra = padding_mask_reshaped.size(1) % T
|
174 |
+
if extra > 0:
|
175 |
+
padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
|
176 |
+
padding_mask_reshaped = padding_mask_reshaped.view(
|
177 |
+
padding_mask_reshaped.size(0), T, -1
|
178 |
+
)
|
179 |
+
padding_mask_reshaped = padding_mask_reshaped.all(-1)
|
180 |
+
input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
|
181 |
+
mask_indices, mask_channel_indices = self._compute_mask_indices(
|
182 |
+
(B, T, self._C),
|
183 |
+
padding_mask_reshaped,
|
184 |
+
)
|
185 |
+
input["mask_indices"] = mask_indices
|
186 |
+
input["mask_channel_indices"] = mask_channel_indices
|
187 |
+
out["sample_size"] = mask_indices.sum().item()
|
188 |
+
|
189 |
+
out["net_input"] = input
|
190 |
+
return out
|
191 |
+
|
192 |
+
def _get_mask_indices_dims(self, size, padding=0, dilation=1):
|
193 |
+
if size not in self._features_size_map:
|
194 |
+
L_in = size
|
195 |
+
for (_, kernel_size, stride) in self._conv_feature_layers:
|
196 |
+
L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
|
197 |
+
L_out = 1 + L_out // stride
|
198 |
+
L_in = L_out
|
199 |
+
self._features_size_map[size] = L_out
|
200 |
+
return self._features_size_map[size]
|
201 |
+
|
202 |
+
def num_tokens(self, index):
|
203 |
+
return self.size(index)
|
204 |
+
|
205 |
+
def size(self, index):
|
206 |
+
"""Return an example's size as a float or tuple. This value is used when
|
207 |
+
filtering a dataset with ``--max-positions``."""
|
208 |
+
if self.pad:
|
209 |
+
return self.sizes[index]
|
210 |
+
return min(self.sizes[index], self.max_sample_size)
|
211 |
+
|
212 |
+
def ordered_indices(self):
|
213 |
+
"""Return an ordered list of indices. Batches will be constructed based
|
214 |
+
on this order."""
|
215 |
+
|
216 |
+
if self.shuffle:
|
217 |
+
order = [np.random.permutation(len(self))]
|
218 |
+
order.append(
|
219 |
+
np.minimum(
|
220 |
+
np.array(self.sizes),
|
221 |
+
self.max_sample_size,
|
222 |
+
)
|
223 |
+
)
|
224 |
+
return np.lexsort(order)[::-1]
|
225 |
+
else:
|
226 |
+
return np.arange(len(self))
|
227 |
+
|
228 |
+
def set_bucket_info(self, num_buckets):
|
229 |
+
self.num_buckets = num_buckets
|
230 |
+
if self.num_buckets > 0:
|
231 |
+
self._collated_sizes = np.minimum(
|
232 |
+
np.array(self.sizes),
|
233 |
+
self.max_sample_size,
|
234 |
+
)
|
235 |
+
self.buckets = get_buckets(
|
236 |
+
self._collated_sizes,
|
237 |
+
self.num_buckets,
|
238 |
+
)
|
239 |
+
self._bucketed_sizes = get_bucketed_sizes(
|
240 |
+
self._collated_sizes, self.buckets
|
241 |
+
)
|
242 |
+
logger.info(
|
243 |
+
f"{len(self.buckets)} bucket(s) for the audio dataset: "
|
244 |
+
f"{self.buckets}"
|
245 |
+
)
|
246 |
+
|
247 |
+
|
248 |
+
class FileAudioDataset(RawAudioDataset):
|
249 |
+
def __init__(
|
250 |
+
self,
|
251 |
+
manifest_path,
|
252 |
+
sample_rate,
|
253 |
+
max_sample_size=None,
|
254 |
+
min_sample_size=0,
|
255 |
+
shuffle=True,
|
256 |
+
pad=False,
|
257 |
+
normalize=False,
|
258 |
+
num_buckets=0,
|
259 |
+
compute_mask_indices=False,
|
260 |
+
text_compression_level=TextCompressionLevel.none,
|
261 |
+
**mask_compute_kwargs,
|
262 |
+
):
|
263 |
+
super().__init__(
|
264 |
+
sample_rate=sample_rate,
|
265 |
+
max_sample_size=max_sample_size,
|
266 |
+
min_sample_size=min_sample_size,
|
267 |
+
shuffle=shuffle,
|
268 |
+
pad=pad,
|
269 |
+
normalize=normalize,
|
270 |
+
compute_mask_indices=compute_mask_indices,
|
271 |
+
**mask_compute_kwargs,
|
272 |
+
)
|
273 |
+
|
274 |
+
self.text_compressor = TextCompressor(level=text_compression_level)
|
275 |
+
|
276 |
+
skipped = 0
|
277 |
+
self.fnames = []
|
278 |
+
sizes = []
|
279 |
+
self.skipped_indices = set()
|
280 |
+
|
281 |
+
with open(manifest_path, "r") as f:
|
282 |
+
self.root_dir = f.readline().strip()
|
283 |
+
for i, line in enumerate(f):
|
284 |
+
items = line.strip().split("\t")
|
285 |
+
assert len(items) == 2, line
|
286 |
+
sz = int(items[1])
|
287 |
+
if min_sample_size is not None and sz < min_sample_size:
|
288 |
+
skipped += 1
|
289 |
+
self.skipped_indices.add(i)
|
290 |
+
continue
|
291 |
+
self.fnames.append(self.text_compressor.compress(items[0]))
|
292 |
+
sizes.append(sz)
|
293 |
+
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
|
294 |
+
|
295 |
+
self.sizes = np.array(sizes, dtype=np.int64)
|
296 |
+
|
297 |
+
try:
|
298 |
+
import pyarrow
|
299 |
+
|
300 |
+
self.fnames = pyarrow.array(self.fnames)
|
301 |
+
except:
|
302 |
+
logger.debug(
|
303 |
+
"Could not create a pyarrow array. Please install pyarrow for better performance"
|
304 |
+
)
|
305 |
+
pass
|
306 |
+
|
307 |
+
self.set_bucket_info(num_buckets)
|
308 |
+
|
309 |
+
def __getitem__(self, index):
|
310 |
+
import soundfile as sf
|
311 |
+
|
312 |
+
fn = self.fnames[index]
|
313 |
+
fn = fn if isinstance(self.fnames, list) else fn.as_py()
|
314 |
+
fn = self.text_compressor.decompress(fn)
|
315 |
+
path_or_fp = os.path.join(self.root_dir, fn)
|
316 |
+
_path, slice_ptr = parse_path(path_or_fp)
|
317 |
+
if len(slice_ptr) == 2:
|
318 |
+
byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
|
319 |
+
assert is_sf_audio_data(byte_data)
|
320 |
+
path_or_fp = io.BytesIO(byte_data)
|
321 |
+
|
322 |
+
wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
|
323 |
+
|
324 |
+
feats = torch.from_numpy(wav).float()
|
325 |
+
feats = self.postprocess(feats, curr_sample_rate)
|
326 |
+
return {"id": index, "source": feats}
|
327 |
+
|
328 |
+
|
329 |
+
class BinarizedAudioDataset(RawAudioDataset):
|
330 |
+
def __init__(
|
331 |
+
self,
|
332 |
+
data_dir,
|
333 |
+
split,
|
334 |
+
sample_rate,
|
335 |
+
max_sample_size=None,
|
336 |
+
min_sample_size=0,
|
337 |
+
shuffle=True,
|
338 |
+
pad=False,
|
339 |
+
normalize=False,
|
340 |
+
num_buckets=0,
|
341 |
+
compute_mask_indices=False,
|
342 |
+
**mask_compute_kwargs,
|
343 |
+
):
|
344 |
+
super().__init__(
|
345 |
+
sample_rate=sample_rate,
|
346 |
+
max_sample_size=max_sample_size,
|
347 |
+
min_sample_size=min_sample_size,
|
348 |
+
shuffle=shuffle,
|
349 |
+
pad=pad,
|
350 |
+
normalize=normalize,
|
351 |
+
compute_mask_indices=compute_mask_indices,
|
352 |
+
**mask_compute_kwargs,
|
353 |
+
)
|
354 |
+
|
355 |
+
from fairseq.data import data_utils, Dictionary
|
356 |
+
|
357 |
+
self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))
|
358 |
+
|
359 |
+
root_path = os.path.join(data_dir, f"{split}.root")
|
360 |
+
if os.path.exists(root_path):
|
361 |
+
with open(root_path, "r") as f:
|
362 |
+
self.root_dir = next(f).strip()
|
363 |
+
else:
|
364 |
+
self.root_dir = None
|
365 |
+
|
366 |
+
fnames_path = os.path.join(data_dir, split)
|
367 |
+
self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)
|
368 |
+
lengths_path = os.path.join(data_dir, f"{split}.lengths")
|
369 |
+
|
370 |
+
with open(lengths_path, "r") as f:
|
371 |
+
for line in f:
|
372 |
+
sz = int(line.rstrip())
|
373 |
+
assert (
|
374 |
+
sz >= min_sample_size
|
375 |
+
), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
|
376 |
+
self.sizes.append(sz)
|
377 |
+
|
378 |
+
self.sizes = np.array(self.sizes, dtype=np.int64)
|
379 |
+
|
380 |
+
self.set_bucket_info(num_buckets)
|
381 |
+
logger.info(f"loaded {len(self.fnames)} samples")
|
382 |
+
|
383 |
+
def __getitem__(self, index):
|
384 |
+
import soundfile as sf
|
385 |
+
|
386 |
+
fname = self.fnames_dict.string(self.fnames[index], separator="")
|
387 |
+
if self.root_dir:
|
388 |
+
fname = os.path.join(self.root_dir, fname)
|
389 |
+
|
390 |
+
wav, curr_sample_rate = sf.read(fname)
|
391 |
+
feats = torch.from_numpy(wav).float()
|
392 |
+
feats = self.postprocess(feats, curr_sample_rate)
|
393 |
+
return {"id": index, "source": feats}
|
modules/voice_conversion/fairseq/data/audio/speech_to_speech_dataset.py
ADDED
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Dict, List, Optional, Tuple
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from fairseq.data import ConcatDataset, Dictionary
|
14 |
+
from fairseq.data import data_utils as fairseq_data_utils
|
15 |
+
from fairseq.data.audio.audio_utils import get_features_or_waveform
|
16 |
+
from fairseq.data.audio.data_cfg import S2SDataConfig
|
17 |
+
from fairseq.data.audio.speech_to_text_dataset import (
|
18 |
+
SpeechToTextDataset,
|
19 |
+
SpeechToTextDatasetCreator,
|
20 |
+
TextTargetMultitaskData,
|
21 |
+
_collate_frames,
|
22 |
+
)
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class SpeechToSpeechDatasetItem(object):
|
29 |
+
index: int
|
30 |
+
source: torch.Tensor
|
31 |
+
target: Optional[torch.Tensor] = None
|
32 |
+
target_speaker: Optional[torch.Tensor] = None
|
33 |
+
tgt_lang_tag: Optional[int] = None
|
34 |
+
|
35 |
+
|
36 |
+
class SpeechToSpeechDataset(SpeechToTextDataset):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
split: str,
|
40 |
+
is_train_split: bool,
|
41 |
+
data_cfg: S2SDataConfig,
|
42 |
+
src_audio_paths: List[str],
|
43 |
+
src_n_frames: List[int],
|
44 |
+
tgt_audio_paths: List[str],
|
45 |
+
tgt_n_frames: List[int],
|
46 |
+
src_langs: Optional[List[str]] = None,
|
47 |
+
tgt_langs: Optional[List[str]] = None,
|
48 |
+
ids: Optional[List[str]] = None,
|
49 |
+
target_is_code: bool = False,
|
50 |
+
tgt_dict: Dictionary = None,
|
51 |
+
n_frames_per_step: int = 1,
|
52 |
+
):
|
53 |
+
tgt_texts = tgt_audio_paths if target_is_code else None
|
54 |
+
super().__init__(
|
55 |
+
split=split,
|
56 |
+
is_train_split=is_train_split,
|
57 |
+
cfg=data_cfg,
|
58 |
+
audio_paths=src_audio_paths,
|
59 |
+
n_frames=src_n_frames,
|
60 |
+
ids=ids,
|
61 |
+
tgt_dict=tgt_dict,
|
62 |
+
tgt_texts=tgt_texts,
|
63 |
+
src_langs=src_langs,
|
64 |
+
tgt_langs=tgt_langs,
|
65 |
+
n_frames_per_step=n_frames_per_step,
|
66 |
+
)
|
67 |
+
|
68 |
+
self.tgt_audio_paths = tgt_audio_paths
|
69 |
+
self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames]
|
70 |
+
|
71 |
+
assert not target_is_code or tgt_dict is not None
|
72 |
+
self.target_is_code = target_is_code
|
73 |
+
|
74 |
+
assert len(tgt_audio_paths) == self.n_samples
|
75 |
+
assert len(tgt_n_frames) == self.n_samples
|
76 |
+
|
77 |
+
self.tgt_speakers = None
|
78 |
+
if self.cfg.target_speaker_embed:
|
79 |
+
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
|
80 |
+
self.cfg.target_speaker_embed, split
|
81 |
+
)
|
82 |
+
spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples}
|
83 |
+
self.tgt_speakers = [spk_emb_dict[id] for id in self.ids]
|
84 |
+
assert len(self.tgt_speakers) == self.n_samples
|
85 |
+
|
86 |
+
logger.info(self.__repr__())
|
87 |
+
|
88 |
+
def pack_units(self, input: torch.Tensor) -> torch.Tensor:
|
89 |
+
if self.n_frames_per_step <= 1:
|
90 |
+
return input
|
91 |
+
|
92 |
+
offset = 4
|
93 |
+
vocab_size = (
|
94 |
+
len(self.tgt_dict) - offset
|
95 |
+
) # remove offset from <bos>, <pad>, <eos>, <unk>, which is specific to fairseq dictionary
|
96 |
+
|
97 |
+
assert input.dim() == 1
|
98 |
+
stacked_input = (
|
99 |
+
input[:-1].view(-1, self.n_frames_per_step) - offset
|
100 |
+
) # remove <eos>
|
101 |
+
scale = [
|
102 |
+
pow(vocab_size, self.n_frames_per_step - 1 - i)
|
103 |
+
for i in range(self.n_frames_per_step)
|
104 |
+
]
|
105 |
+
scale = torch.LongTensor(scale).squeeze(0)
|
106 |
+
res = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(input[-1])
|
107 |
+
res[:-1] = (stacked_input * scale).sum(dim=1) + offset
|
108 |
+
|
109 |
+
return res
|
110 |
+
|
111 |
+
def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
|
112 |
+
source = self._get_source_audio(index)
|
113 |
+
|
114 |
+
tgt_lang_tag = None
|
115 |
+
if self.cfg.prepend_tgt_lang_tag_as_bos:
|
116 |
+
# prepend_tgt_lang_tag_as_bos: put tgt_lang_tag as bos of target
|
117 |
+
tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
|
118 |
+
|
119 |
+
if not self.target_is_code:
|
120 |
+
target = get_features_or_waveform(self.tgt_audio_paths[index])
|
121 |
+
target = torch.from_numpy(target).float()
|
122 |
+
target = self.pack_frames(target)
|
123 |
+
else:
|
124 |
+
target = self.tgt_dict.encode_line(
|
125 |
+
self.tgt_audio_paths[index],
|
126 |
+
add_if_not_exist=False,
|
127 |
+
append_eos=True,
|
128 |
+
).long()
|
129 |
+
if self.n_frames_per_step > 1:
|
130 |
+
n_tgt_frame = target.size(0) - 1 # exclude <eos>
|
131 |
+
keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
|
132 |
+
target = torch.cat(
|
133 |
+
(
|
134 |
+
target[:keep_n_tgt_frame],
|
135 |
+
target.new_full((1,), self.tgt_dict.eos()),
|
136 |
+
),
|
137 |
+
dim=0,
|
138 |
+
)
|
139 |
+
|
140 |
+
if self.tgt_speakers:
|
141 |
+
tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
|
142 |
+
tgt_spk = torch.from_numpy(tgt_spk).float()
|
143 |
+
else:
|
144 |
+
tgt_spk = torch.FloatTensor([])
|
145 |
+
|
146 |
+
return SpeechToSpeechDatasetItem(
|
147 |
+
index=index,
|
148 |
+
source=source,
|
149 |
+
target=target,
|
150 |
+
target_speaker=tgt_spk,
|
151 |
+
tgt_lang_tag=tgt_lang_tag,
|
152 |
+
)
|
153 |
+
|
154 |
+
def _collate_target(self, samples: List[SpeechToSpeechDatasetItem]) -> torch.Tensor:
|
155 |
+
if self.target_is_code:
|
156 |
+
target = fairseq_data_utils.collate_tokens(
|
157 |
+
[x.target for x in samples],
|
158 |
+
self.tgt_dict.pad(),
|
159 |
+
self.tgt_dict.eos(),
|
160 |
+
left_pad=False,
|
161 |
+
move_eos_to_beginning=False,
|
162 |
+
)
|
163 |
+
# convert stacked units to a single id
|
164 |
+
pack_targets = [self.pack_units(x.target) for x in samples]
|
165 |
+
prev_output_tokens = fairseq_data_utils.collate_tokens(
|
166 |
+
pack_targets,
|
167 |
+
self.tgt_dict.pad(),
|
168 |
+
self.tgt_dict.eos(),
|
169 |
+
left_pad=False,
|
170 |
+
move_eos_to_beginning=True,
|
171 |
+
)
|
172 |
+
target_lengths = torch.tensor(
|
173 |
+
[x.size(0) for x in pack_targets], dtype=torch.long
|
174 |
+
)
|
175 |
+
else:
|
176 |
+
target = _collate_frames([x.target for x in samples], is_audio_input=False)
|
177 |
+
bsz, _, d = target.size()
|
178 |
+
prev_output_tokens = torch.cat(
|
179 |
+
(target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1
|
180 |
+
)
|
181 |
+
target_lengths = torch.tensor(
|
182 |
+
[x.target.size(0) for x in samples], dtype=torch.long
|
183 |
+
)
|
184 |
+
|
185 |
+
return target, prev_output_tokens, target_lengths
|
186 |
+
|
187 |
+
def collater(
|
188 |
+
self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
|
189 |
+
) -> Dict:
|
190 |
+
if len(samples) == 0:
|
191 |
+
return {}
|
192 |
+
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
|
193 |
+
frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
|
194 |
+
# sort samples by descending number of frames
|
195 |
+
n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
|
196 |
+
n_frames, order = n_frames.sort(descending=True)
|
197 |
+
indices = indices.index_select(0, order)
|
198 |
+
frames = frames.index_select(0, order)
|
199 |
+
|
200 |
+
target, prev_output_tokens, target_lengths = self._collate_target(samples)
|
201 |
+
target = target.index_select(0, order)
|
202 |
+
target_lengths = target_lengths.index_select(0, order)
|
203 |
+
prev_output_tokens = prev_output_tokens.index_select(0, order)
|
204 |
+
ntokens = sum(x.target.size(0) for x in samples)
|
205 |
+
|
206 |
+
tgt_speakers = None
|
207 |
+
if self.cfg.target_speaker_embed:
|
208 |
+
tgt_speakers = _collate_frames(
|
209 |
+
[x.target_speaker for x in samples], is_audio_input=True
|
210 |
+
).index_select(0, order)
|
211 |
+
|
212 |
+
net_input = {
|
213 |
+
"src_tokens": frames,
|
214 |
+
"src_lengths": n_frames,
|
215 |
+
"prev_output_tokens": prev_output_tokens,
|
216 |
+
"tgt_speaker": tgt_speakers, # TODO: unify "speaker" and "tgt_speaker"
|
217 |
+
}
|
218 |
+
if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
|
219 |
+
for i in range(len(samples)):
|
220 |
+
net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
|
221 |
+
out = {
|
222 |
+
"id": indices,
|
223 |
+
"net_input": net_input,
|
224 |
+
"speaker": tgt_speakers, # to support Tacotron2 loss for speech-to-spectrogram model
|
225 |
+
"target": target,
|
226 |
+
"target_lengths": target_lengths,
|
227 |
+
"ntokens": ntokens,
|
228 |
+
"nsentences": len(samples),
|
229 |
+
}
|
230 |
+
if return_order:
|
231 |
+
out["order"] = order
|
232 |
+
return out
|
233 |
+
|
234 |
+
|
235 |
+
class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
|
236 |
+
def __init__(self, **kwargs):
|
237 |
+
super().__init__(**kwargs)
|
238 |
+
self.multitask_data = {}
|
239 |
+
|
240 |
+
def add_multitask_dataset(self, task_name, task_data):
|
241 |
+
self.multitask_data[task_name] = task_data
|
242 |
+
|
243 |
+
def __getitem__(
|
244 |
+
self, index: int
|
245 |
+
) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]:
|
246 |
+
s2s_data = super().__getitem__(index)
|
247 |
+
|
248 |
+
multitask_target = {}
|
249 |
+
sample_id = self.ids[index]
|
250 |
+
tgt_lang = self.tgt_langs[index]
|
251 |
+
for task_name, task_dataset in self.multitask_data.items():
|
252 |
+
multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
|
253 |
+
|
254 |
+
return s2s_data, multitask_target
|
255 |
+
|
256 |
+
def collater(
|
257 |
+
self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]]
|
258 |
+
) -> Dict:
|
259 |
+
if len(samples) == 0:
|
260 |
+
return {}
|
261 |
+
|
262 |
+
out = super().collater([s for s, _ in samples], return_order=True)
|
263 |
+
order = out["order"]
|
264 |
+
del out["order"]
|
265 |
+
|
266 |
+
for task_name, task_dataset in self.multitask_data.items():
|
267 |
+
if "multitask" not in out:
|
268 |
+
out["multitask"] = {}
|
269 |
+
d = [s[task_name] for _, s in samples]
|
270 |
+
task_target = task_dataset.collater(d)
|
271 |
+
out["multitask"][task_name] = {
|
272 |
+
"target": task_target["target"].index_select(0, order),
|
273 |
+
"target_lengths": task_target["target_lengths"].index_select(0, order),
|
274 |
+
"ntokens": task_target["ntokens"],
|
275 |
+
}
|
276 |
+
out["multitask"][task_name]["net_input"] = {
|
277 |
+
"prev_output_tokens": task_target["prev_output_tokens"].index_select(
|
278 |
+
0, order
|
279 |
+
),
|
280 |
+
}
|
281 |
+
|
282 |
+
return out
|
283 |
+
|
284 |
+
|
285 |
+
class SpeechToSpeechDatasetCreator(object):
|
286 |
+
# mandatory columns
|
287 |
+
KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames"
|
288 |
+
KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames"
|
289 |
+
# optional columns
|
290 |
+
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
|
291 |
+
# default values
|
292 |
+
DEFAULT_LANG = ""
|
293 |
+
|
294 |
+
@classmethod
|
295 |
+
def _from_list(
|
296 |
+
cls,
|
297 |
+
split_name: str,
|
298 |
+
is_train_split,
|
299 |
+
samples: List[Dict],
|
300 |
+
data_cfg: S2SDataConfig,
|
301 |
+
target_is_code: bool = False,
|
302 |
+
tgt_dict: Dictionary = None,
|
303 |
+
n_frames_per_step: int = 1,
|
304 |
+
multitask: Optional[Dict] = None,
|
305 |
+
) -> SpeechToSpeechDataset:
|
306 |
+
audio_root = Path(data_cfg.audio_root)
|
307 |
+
ids = [s[cls.KEY_ID] for s in samples]
|
308 |
+
src_audio_paths = [
|
309 |
+
(audio_root / s[cls.KEY_SRC_AUDIO]).as_posix() for s in samples
|
310 |
+
]
|
311 |
+
tgt_audio_paths = [
|
312 |
+
s[cls.KEY_TGT_AUDIO]
|
313 |
+
if target_is_code
|
314 |
+
else (audio_root / s[cls.KEY_TGT_AUDIO]).as_posix()
|
315 |
+
for s in samples
|
316 |
+
]
|
317 |
+
src_n_frames = [int(s[cls.KEY_SRC_N_FRAMES]) for s in samples]
|
318 |
+
tgt_n_frames = [int(s[cls.KEY_TGT_N_FRAMES]) for s in samples]
|
319 |
+
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
|
320 |
+
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
|
321 |
+
|
322 |
+
has_multitask = multitask is not None and len(multitask.keys()) > 0
|
323 |
+
dataset_cls = (
|
324 |
+
SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
|
325 |
+
)
|
326 |
+
|
327 |
+
ds = dataset_cls(
|
328 |
+
split=split_name,
|
329 |
+
is_train_split=is_train_split,
|
330 |
+
data_cfg=data_cfg,
|
331 |
+
src_audio_paths=src_audio_paths,
|
332 |
+
src_n_frames=src_n_frames,
|
333 |
+
tgt_audio_paths=tgt_audio_paths,
|
334 |
+
tgt_n_frames=tgt_n_frames,
|
335 |
+
src_langs=src_langs,
|
336 |
+
tgt_langs=tgt_langs,
|
337 |
+
ids=ids,
|
338 |
+
target_is_code=target_is_code,
|
339 |
+
tgt_dict=tgt_dict,
|
340 |
+
n_frames_per_step=n_frames_per_step,
|
341 |
+
)
|
342 |
+
|
343 |
+
if has_multitask:
|
344 |
+
for task_name, task_obj in multitask.items():
|
345 |
+
task_data = TextTargetMultitaskData(
|
346 |
+
task_obj.args, split_name, task_obj.target_dictionary
|
347 |
+
)
|
348 |
+
ds.add_multitask_dataset(task_name, task_data)
|
349 |
+
return ds
|
350 |
+
|
351 |
+
@classmethod
|
352 |
+
def from_tsv(
|
353 |
+
cls,
|
354 |
+
root: str,
|
355 |
+
data_cfg: S2SDataConfig,
|
356 |
+
splits: str,
|
357 |
+
is_train_split: bool,
|
358 |
+
epoch: int,
|
359 |
+
seed: int,
|
360 |
+
target_is_code: bool = False,
|
361 |
+
tgt_dict: Dictionary = None,
|
362 |
+
n_frames_per_step: int = 1,
|
363 |
+
multitask: Optional[Dict] = None,
|
364 |
+
) -> SpeechToSpeechDataset:
|
365 |
+
datasets = []
|
366 |
+
for split in splits.split(","):
|
367 |
+
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
|
368 |
+
ds = cls._from_list(
|
369 |
+
split_name=split,
|
370 |
+
is_train_split=is_train_split,
|
371 |
+
samples=samples,
|
372 |
+
data_cfg=data_cfg,
|
373 |
+
target_is_code=target_is_code,
|
374 |
+
tgt_dict=tgt_dict,
|
375 |
+
n_frames_per_step=n_frames_per_step,
|
376 |
+
multitask=multitask,
|
377 |
+
)
|
378 |
+
datasets.append(ds)
|
379 |
+
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
|
modules/voice_conversion/fairseq/data/audio/speech_to_text_dataset.py
ADDED
@@ -0,0 +1,733 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import csv
|
7 |
+
import logging
|
8 |
+
import re
|
9 |
+
from argparse import Namespace
|
10 |
+
from collections import defaultdict
|
11 |
+
from dataclasses import dataclass
|
12 |
+
from pathlib import Path
|
13 |
+
from typing import Dict, List, Optional, Tuple, Union
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
|
20 |
+
from fairseq.data import data_utils as fairseq_data_utils
|
21 |
+
from fairseq.data import encoders
|
22 |
+
from fairseq.data.audio.audio_utils import get_features_or_waveform
|
23 |
+
from fairseq.data.audio.data_cfg import S2TDataConfig
|
24 |
+
from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
|
25 |
+
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
|
26 |
+
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
|
27 |
+
NoisyOverlapAugment,
|
28 |
+
)
|
29 |
+
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
|
30 |
+
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
def _collate_frames(
|
36 |
+
frames: List[torch.Tensor], is_audio_input: bool = False
|
37 |
+
) -> torch.Tensor:
|
38 |
+
"""
|
39 |
+
Convert a list of 2D frames into a padded 3D tensor
|
40 |
+
Args:
|
41 |
+
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
|
42 |
+
length of i-th frame and f_dim is static dimension of features
|
43 |
+
Returns:
|
44 |
+
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
|
45 |
+
"""
|
46 |
+
max_len = max(frame.size(0) for frame in frames)
|
47 |
+
if is_audio_input:
|
48 |
+
out = frames[0].new_zeros((len(frames), max_len))
|
49 |
+
else:
|
50 |
+
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
|
51 |
+
for i, v in enumerate(frames):
|
52 |
+
out[i, : v.size(0)] = v
|
53 |
+
return out
|
54 |
+
|
55 |
+
|
56 |
+
def _is_int_or_np_int(n):
|
57 |
+
return isinstance(n, int) or (
|
58 |
+
isinstance(n, np.generic) and isinstance(n.item(), int)
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
@dataclass
|
63 |
+
class SpeechToTextDatasetItem(object):
|
64 |
+
index: int
|
65 |
+
source: torch.Tensor
|
66 |
+
target: Optional[torch.Tensor] = None
|
67 |
+
speaker_id: Optional[int] = None
|
68 |
+
|
69 |
+
|
70 |
+
class SpeechToTextDataset(FairseqDataset):
|
71 |
+
LANG_TAG_TEMPLATE = "<lang:{}>"
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
split: str,
|
76 |
+
is_train_split: bool,
|
77 |
+
cfg: S2TDataConfig,
|
78 |
+
audio_paths: List[str],
|
79 |
+
n_frames: List[int],
|
80 |
+
src_texts: Optional[List[str]] = None,
|
81 |
+
tgt_texts: Optional[List[str]] = None,
|
82 |
+
speakers: Optional[List[str]] = None,
|
83 |
+
src_langs: Optional[List[str]] = None,
|
84 |
+
tgt_langs: Optional[List[str]] = None,
|
85 |
+
ids: Optional[List[str]] = None,
|
86 |
+
tgt_dict: Optional[Dictionary] = None,
|
87 |
+
pre_tokenizer=None,
|
88 |
+
bpe_tokenizer=None,
|
89 |
+
n_frames_per_step=1,
|
90 |
+
speaker_to_id=None,
|
91 |
+
append_eos=True,
|
92 |
+
):
|
93 |
+
self.split, self.is_train_split = split, is_train_split
|
94 |
+
self.cfg = cfg
|
95 |
+
self.audio_paths, self.n_frames = audio_paths, n_frames
|
96 |
+
self.n_samples = len(audio_paths)
|
97 |
+
assert len(n_frames) == self.n_samples > 0
|
98 |
+
assert src_texts is None or len(src_texts) == self.n_samples
|
99 |
+
assert tgt_texts is None or len(tgt_texts) == self.n_samples
|
100 |
+
assert speakers is None or len(speakers) == self.n_samples
|
101 |
+
assert src_langs is None or len(src_langs) == self.n_samples
|
102 |
+
assert tgt_langs is None or len(tgt_langs) == self.n_samples
|
103 |
+
assert ids is None or len(ids) == self.n_samples
|
104 |
+
assert (tgt_dict is None and tgt_texts is None) or (
|
105 |
+
tgt_dict is not None and tgt_texts is not None
|
106 |
+
)
|
107 |
+
self.src_texts, self.tgt_texts = src_texts, tgt_texts
|
108 |
+
self.src_langs, self.tgt_langs = src_langs, tgt_langs
|
109 |
+
self.speakers = speakers
|
110 |
+
self.tgt_dict = tgt_dict
|
111 |
+
self.check_tgt_lang_tag()
|
112 |
+
self.ids = ids
|
113 |
+
self.shuffle = cfg.shuffle if is_train_split else False
|
114 |
+
|
115 |
+
self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
|
116 |
+
self.cfg.get_feature_transforms(split, is_train_split)
|
117 |
+
)
|
118 |
+
self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
|
119 |
+
self.cfg.get_waveform_transforms(split, is_train_split)
|
120 |
+
)
|
121 |
+
# TODO: add these to data_cfg.py
|
122 |
+
self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
|
123 |
+
self.cfg.get_dataset_transforms(split, is_train_split)
|
124 |
+
)
|
125 |
+
|
126 |
+
# check proper usage of transforms
|
127 |
+
if self.feature_transforms and self.cfg.use_audio_input:
|
128 |
+
logger.warning(
|
129 |
+
"Feature transforms will not be applied. To use feature transforms, "
|
130 |
+
"set use_audio_input as False in config."
|
131 |
+
)
|
132 |
+
|
133 |
+
self.pre_tokenizer = pre_tokenizer
|
134 |
+
self.bpe_tokenizer = bpe_tokenizer
|
135 |
+
self.n_frames_per_step = n_frames_per_step
|
136 |
+
self.speaker_to_id = speaker_to_id
|
137 |
+
|
138 |
+
self.tgt_lens = self.get_tgt_lens_and_check_oov()
|
139 |
+
self.append_eos = append_eos
|
140 |
+
|
141 |
+
logger.info(self.__repr__())
|
142 |
+
|
143 |
+
def get_tgt_lens_and_check_oov(self):
|
144 |
+
if self.tgt_texts is None:
|
145 |
+
return [0 for _ in range(self.n_samples)]
|
146 |
+
tgt_lens = []
|
147 |
+
n_tokens, n_oov_tokens = 0, 0
|
148 |
+
for i in range(self.n_samples):
|
149 |
+
tokenized = self.get_tokenized_tgt_text(i).split(" ")
|
150 |
+
oov_tokens = [
|
151 |
+
t
|
152 |
+
for t in tokenized
|
153 |
+
if self.tgt_dict.index(t) == self.tgt_dict.unk_index
|
154 |
+
]
|
155 |
+
n_tokens += len(tokenized)
|
156 |
+
n_oov_tokens += len(oov_tokens)
|
157 |
+
tgt_lens.append(len(tokenized))
|
158 |
+
logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
|
159 |
+
return tgt_lens
|
160 |
+
|
161 |
+
def __repr__(self):
|
162 |
+
return (
|
163 |
+
self.__class__.__name__
|
164 |
+
+ f'(split="{self.split}", n_samples={self.n_samples:_}, '
|
165 |
+
f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
|
166 |
+
f"n_frames_per_step={self.n_frames_per_step}, "
|
167 |
+
f"shuffle={self.shuffle}, "
|
168 |
+
f"feature_transforms={self.feature_transforms}, "
|
169 |
+
f"waveform_transforms={self.waveform_transforms}, "
|
170 |
+
f"dataset_transforms={self.dataset_transforms})"
|
171 |
+
)
|
172 |
+
|
173 |
+
@classmethod
|
174 |
+
def is_lang_tag(cls, token):
|
175 |
+
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
|
176 |
+
return re.match(pattern, token)
|
177 |
+
|
178 |
+
def check_tgt_lang_tag(self):
|
179 |
+
if self.cfg.prepend_tgt_lang_tag:
|
180 |
+
assert self.tgt_langs is not None and self.tgt_dict is not None
|
181 |
+
tgt_lang_tags = [
|
182 |
+
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
|
183 |
+
]
|
184 |
+
assert all(t in self.tgt_dict for t in tgt_lang_tags)
|
185 |
+
|
186 |
+
@classmethod
|
187 |
+
def tokenize(cls, tokenizer, text: str):
|
188 |
+
return text if tokenizer is None else tokenizer.encode(text)
|
189 |
+
|
190 |
+
def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
|
191 |
+
if _is_int_or_np_int(index):
|
192 |
+
text = self.tgt_texts[index]
|
193 |
+
else:
|
194 |
+
text = " ".join([self.tgt_texts[i] for i in index])
|
195 |
+
|
196 |
+
text = self.tokenize(self.pre_tokenizer, text)
|
197 |
+
text = self.tokenize(self.bpe_tokenizer, text)
|
198 |
+
return text
|
199 |
+
|
200 |
+
def pack_frames(self, feature: torch.Tensor):
|
201 |
+
if self.n_frames_per_step == 1:
|
202 |
+
return feature
|
203 |
+
n_packed_frames = feature.shape[0] // self.n_frames_per_step
|
204 |
+
feature = feature[: self.n_frames_per_step * n_packed_frames]
|
205 |
+
return feature.reshape(n_packed_frames, -1)
|
206 |
+
|
207 |
+
@classmethod
|
208 |
+
def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
|
209 |
+
lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
|
210 |
+
assert lang_tag_idx != dictionary.unk()
|
211 |
+
return lang_tag_idx
|
212 |
+
|
213 |
+
def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
|
214 |
+
"""
|
215 |
+
Gives source audio for given index with any relevant transforms
|
216 |
+
applied. For ConcatAug, source audios for given indices are
|
217 |
+
concatenated in given order.
|
218 |
+
Args:
|
219 |
+
index (int or List[int]): index—or in the case of ConcatAug,
|
220 |
+
indices—to pull the source audio for
|
221 |
+
Returns:
|
222 |
+
source audios concatenated for given indices with
|
223 |
+
relevant transforms appplied
|
224 |
+
"""
|
225 |
+
if _is_int_or_np_int(index):
|
226 |
+
source = get_features_or_waveform(
|
227 |
+
self.audio_paths[index],
|
228 |
+
need_waveform=self.cfg.use_audio_input,
|
229 |
+
use_sample_rate=self.cfg.use_sample_rate,
|
230 |
+
waveform_transforms=self.waveform_transforms,
|
231 |
+
)
|
232 |
+
else:
|
233 |
+
source = np.concatenate(
|
234 |
+
[
|
235 |
+
get_features_or_waveform(
|
236 |
+
self.audio_paths[i],
|
237 |
+
need_waveform=self.cfg.use_audio_input,
|
238 |
+
use_sample_rate=self.cfg.use_sample_rate,
|
239 |
+
waveform_transforms=self.waveform_transforms,
|
240 |
+
)
|
241 |
+
for i in index
|
242 |
+
]
|
243 |
+
)
|
244 |
+
if self.cfg.use_audio_input:
|
245 |
+
source = torch.from_numpy(source).float()
|
246 |
+
if self.cfg.standardize_audio:
|
247 |
+
with torch.no_grad():
|
248 |
+
source = F.layer_norm(source, source.shape)
|
249 |
+
else:
|
250 |
+
if self.feature_transforms is not None:
|
251 |
+
source = self.feature_transforms(source)
|
252 |
+
source = torch.from_numpy(source).float()
|
253 |
+
return source
|
254 |
+
|
255 |
+
def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
|
256 |
+
has_concat = self.dataset_transforms.has_transform(ConcatAugment)
|
257 |
+
if has_concat:
|
258 |
+
concat = self.dataset_transforms.get_transform(ConcatAugment)
|
259 |
+
indices = concat.find_indices(index, self.n_frames, self.n_samples)
|
260 |
+
|
261 |
+
source = self._get_source_audio(indices if has_concat else index)
|
262 |
+
source = self.pack_frames(source)
|
263 |
+
|
264 |
+
target = None
|
265 |
+
if self.tgt_texts is not None:
|
266 |
+
tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
|
267 |
+
target = self.tgt_dict.encode_line(
|
268 |
+
tokenized, add_if_not_exist=False, append_eos=self.append_eos
|
269 |
+
).long()
|
270 |
+
if self.cfg.prepend_tgt_lang_tag:
|
271 |
+
lang_tag_idx = self.get_lang_tag_idx(
|
272 |
+
self.tgt_langs[index], self.tgt_dict
|
273 |
+
)
|
274 |
+
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
|
275 |
+
|
276 |
+
if self.cfg.prepend_bos_and_append_tgt_lang_tag:
|
277 |
+
bos = torch.LongTensor([self.tgt_dict.bos()])
|
278 |
+
lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
|
279 |
+
assert lang_tag_idx != self.tgt_dict.unk()
|
280 |
+
lang_tag_idx = torch.LongTensor([lang_tag_idx])
|
281 |
+
target = torch.cat((bos, target, lang_tag_idx), 0)
|
282 |
+
|
283 |
+
speaker_id = None
|
284 |
+
if self.speaker_to_id is not None:
|
285 |
+
speaker_id = self.speaker_to_id[self.speakers[index]]
|
286 |
+
return SpeechToTextDatasetItem(
|
287 |
+
index=index, source=source, target=target, speaker_id=speaker_id
|
288 |
+
)
|
289 |
+
|
290 |
+
def __len__(self):
|
291 |
+
return self.n_samples
|
292 |
+
|
293 |
+
def collater(
|
294 |
+
self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
|
295 |
+
) -> Dict:
|
296 |
+
if len(samples) == 0:
|
297 |
+
return {}
|
298 |
+
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
|
299 |
+
|
300 |
+
sources = [x.source for x in samples]
|
301 |
+
has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
|
302 |
+
if has_NOAug and self.cfg.use_audio_input:
|
303 |
+
NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
|
304 |
+
sources = NOAug(sources)
|
305 |
+
|
306 |
+
frames = _collate_frames(sources, self.cfg.use_audio_input)
|
307 |
+
# sort samples by descending number of frames
|
308 |
+
n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
|
309 |
+
n_frames, order = n_frames.sort(descending=True)
|
310 |
+
indices = indices.index_select(0, order)
|
311 |
+
frames = frames.index_select(0, order)
|
312 |
+
|
313 |
+
target, target_lengths = None, None
|
314 |
+
prev_output_tokens = None
|
315 |
+
ntokens = None
|
316 |
+
if self.tgt_texts is not None:
|
317 |
+
target = fairseq_data_utils.collate_tokens(
|
318 |
+
[x.target for x in samples],
|
319 |
+
self.tgt_dict.pad(),
|
320 |
+
self.tgt_dict.eos(),
|
321 |
+
left_pad=False,
|
322 |
+
move_eos_to_beginning=False,
|
323 |
+
)
|
324 |
+
target = target.index_select(0, order)
|
325 |
+
target_lengths = torch.tensor(
|
326 |
+
[x.target.size(0) for x in samples], dtype=torch.long
|
327 |
+
).index_select(0, order)
|
328 |
+
prev_output_tokens = fairseq_data_utils.collate_tokens(
|
329 |
+
[x.target for x in samples],
|
330 |
+
self.tgt_dict.pad(),
|
331 |
+
eos_idx=None,
|
332 |
+
left_pad=False,
|
333 |
+
move_eos_to_beginning=True,
|
334 |
+
)
|
335 |
+
prev_output_tokens = prev_output_tokens.index_select(0, order)
|
336 |
+
ntokens = sum(x.target.size(0) for x in samples)
|
337 |
+
|
338 |
+
speaker = None
|
339 |
+
if self.speaker_to_id is not None:
|
340 |
+
speaker = (
|
341 |
+
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
|
342 |
+
.index_select(0, order)
|
343 |
+
.view(-1, 1)
|
344 |
+
)
|
345 |
+
|
346 |
+
net_input = {
|
347 |
+
"src_tokens": frames,
|
348 |
+
"src_lengths": n_frames,
|
349 |
+
"prev_output_tokens": prev_output_tokens,
|
350 |
+
}
|
351 |
+
out = {
|
352 |
+
"id": indices,
|
353 |
+
"net_input": net_input,
|
354 |
+
"speaker": speaker,
|
355 |
+
"target": target,
|
356 |
+
"target_lengths": target_lengths,
|
357 |
+
"ntokens": ntokens,
|
358 |
+
"nsentences": len(samples),
|
359 |
+
}
|
360 |
+
if return_order:
|
361 |
+
out["order"] = order
|
362 |
+
return out
|
363 |
+
|
364 |
+
def num_tokens(self, index):
|
365 |
+
return self.n_frames[index]
|
366 |
+
|
367 |
+
def size(self, index):
|
368 |
+
return self.n_frames[index], self.tgt_lens[index]
|
369 |
+
|
370 |
+
@property
|
371 |
+
def sizes(self):
|
372 |
+
return np.array(self.n_frames)
|
373 |
+
|
374 |
+
@property
|
375 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
376 |
+
return True
|
377 |
+
|
378 |
+
def ordered_indices(self):
|
379 |
+
if self.shuffle:
|
380 |
+
order = [np.random.permutation(len(self))]
|
381 |
+
else:
|
382 |
+
order = [np.arange(len(self))]
|
383 |
+
# first by descending order of # of frames then by original/random order
|
384 |
+
order.append([-n for n in self.n_frames])
|
385 |
+
return np.lexsort(order)
|
386 |
+
|
387 |
+
def prefetch(self, indices):
|
388 |
+
raise False
|
389 |
+
|
390 |
+
|
391 |
+
class TextTargetMultitaskData(object):
|
392 |
+
# mandatory columns
|
393 |
+
KEY_ID, KEY_TEXT = "id", "tgt_text"
|
394 |
+
LANG_TAG_TEMPLATE = "<lang:{}>"
|
395 |
+
|
396 |
+
def __init__(self, args, split, tgt_dict):
|
397 |
+
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
|
398 |
+
self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
|
399 |
+
self.dict = tgt_dict
|
400 |
+
self.append_eos = args.decoder_type != "ctc"
|
401 |
+
self.pre_tokenizer = self.build_tokenizer(args)
|
402 |
+
self.bpe_tokenizer = self.build_bpe(args)
|
403 |
+
self.prepend_bos_and_append_tgt_lang_tag = (
|
404 |
+
args.prepend_bos_and_append_tgt_lang_tag
|
405 |
+
)
|
406 |
+
self.eos_token = args.eos_token
|
407 |
+
self.lang_tag_mapping = args.get_lang_tag_mapping
|
408 |
+
|
409 |
+
@classmethod
|
410 |
+
def is_lang_tag(cls, token):
|
411 |
+
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
|
412 |
+
return re.match(pattern, token)
|
413 |
+
|
414 |
+
@classmethod
|
415 |
+
def tokenize(cls, tokenizer, text: str):
|
416 |
+
return text if tokenizer is None else tokenizer.encode(text)
|
417 |
+
|
418 |
+
def get_tokenized_tgt_text(self, index: int):
|
419 |
+
text = self.tokenize(self.pre_tokenizer, self.data[index])
|
420 |
+
text = self.tokenize(self.bpe_tokenizer, text)
|
421 |
+
return text
|
422 |
+
|
423 |
+
def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
|
424 |
+
lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
|
425 |
+
lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
|
426 |
+
lang_tag_idx = dictionary.index(lang_tag)
|
427 |
+
assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
|
428 |
+
return lang_tag_idx
|
429 |
+
|
430 |
+
def build_tokenizer(self, args):
|
431 |
+
pre_tokenizer = args.config.get("pre_tokenizer")
|
432 |
+
if pre_tokenizer is not None:
|
433 |
+
logger.info(f"pre-tokenizer: {pre_tokenizer}")
|
434 |
+
return encoders.build_tokenizer(Namespace(**pre_tokenizer))
|
435 |
+
else:
|
436 |
+
return None
|
437 |
+
|
438 |
+
def build_bpe(self, args):
|
439 |
+
bpe_tokenizer = args.config.get("bpe_tokenizer")
|
440 |
+
if bpe_tokenizer is not None:
|
441 |
+
logger.info(f"tokenizer: {bpe_tokenizer}")
|
442 |
+
return encoders.build_bpe(Namespace(**bpe_tokenizer))
|
443 |
+
else:
|
444 |
+
return None
|
445 |
+
|
446 |
+
def get(self, sample_id, tgt_lang=None):
|
447 |
+
if sample_id in self.data:
|
448 |
+
tokenized = self.get_tokenized_tgt_text(sample_id)
|
449 |
+
target = self.dict.encode_line(
|
450 |
+
tokenized,
|
451 |
+
add_if_not_exist=False,
|
452 |
+
append_eos=self.append_eos,
|
453 |
+
)
|
454 |
+
if self.prepend_bos_and_append_tgt_lang_tag:
|
455 |
+
bos = torch.LongTensor([self.dict.bos()])
|
456 |
+
lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
|
457 |
+
assert lang_tag_idx != self.dict.unk()
|
458 |
+
lang_tag_idx = torch.LongTensor([lang_tag_idx])
|
459 |
+
target = torch.cat((bos, target, lang_tag_idx), 0)
|
460 |
+
return target
|
461 |
+
else:
|
462 |
+
logger.warning(f"no target for {sample_id}")
|
463 |
+
return torch.IntTensor([])
|
464 |
+
|
465 |
+
def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
|
466 |
+
out = fairseq_data_utils.collate_tokens(
|
467 |
+
samples,
|
468 |
+
self.dict.pad(),
|
469 |
+
eos_idx=None,
|
470 |
+
left_pad=False,
|
471 |
+
move_eos_to_beginning=False,
|
472 |
+
).long()
|
473 |
+
|
474 |
+
prev_out = fairseq_data_utils.collate_tokens(
|
475 |
+
samples,
|
476 |
+
self.dict.pad(),
|
477 |
+
eos_idx=None,
|
478 |
+
left_pad=False,
|
479 |
+
move_eos_to_beginning=True,
|
480 |
+
).long()
|
481 |
+
|
482 |
+
target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
|
483 |
+
ntokens = sum(t.size(0) for t in samples)
|
484 |
+
|
485 |
+
output = {
|
486 |
+
"prev_output_tokens": prev_out,
|
487 |
+
"target": out,
|
488 |
+
"target_lengths": target_lengths,
|
489 |
+
"ntokens": ntokens,
|
490 |
+
}
|
491 |
+
|
492 |
+
return output
|
493 |
+
|
494 |
+
|
495 |
+
class SpeechToTextMultitaskDataset(SpeechToTextDataset):
|
496 |
+
def __init__(self, **kwargs):
|
497 |
+
super().__init__(**kwargs)
|
498 |
+
self.multitask_data = {}
|
499 |
+
|
500 |
+
def add_multitask_dataset(self, task_name, task_data):
|
501 |
+
self.multitask_data[task_name] = task_data
|
502 |
+
|
503 |
+
def __getitem__(
|
504 |
+
self, index: int
|
505 |
+
) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
|
506 |
+
s2t_data = super().__getitem__(index)
|
507 |
+
|
508 |
+
multitask_target = {}
|
509 |
+
sample_id = self.ids[index]
|
510 |
+
tgt_lang = self.tgt_langs[index]
|
511 |
+
for task_name, task_dataset in self.multitask_data.items():
|
512 |
+
multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
|
513 |
+
|
514 |
+
return s2t_data, multitask_target
|
515 |
+
|
516 |
+
def collater(
|
517 |
+
self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
|
518 |
+
) -> Dict:
|
519 |
+
if len(samples) == 0:
|
520 |
+
return {}
|
521 |
+
|
522 |
+
out = super().collater([s for s, _ in samples], return_order=True)
|
523 |
+
order = out["order"]
|
524 |
+
del out["order"]
|
525 |
+
|
526 |
+
for task_name, task_dataset in self.multitask_data.items():
|
527 |
+
if "multitask" not in out:
|
528 |
+
out["multitask"] = {}
|
529 |
+
d = [s[task_name] for _, s in samples]
|
530 |
+
task_target = task_dataset.collater(d)
|
531 |
+
out["multitask"][task_name] = {
|
532 |
+
"target": task_target["target"].index_select(0, order),
|
533 |
+
"target_lengths": task_target["target_lengths"].index_select(0, order),
|
534 |
+
"ntokens": task_target["ntokens"],
|
535 |
+
}
|
536 |
+
out["multitask"][task_name]["net_input"] = {
|
537 |
+
"prev_output_tokens": task_target["prev_output_tokens"].index_select(
|
538 |
+
0, order
|
539 |
+
),
|
540 |
+
}
|
541 |
+
|
542 |
+
return out
|
543 |
+
|
544 |
+
|
545 |
+
class SpeechToTextDatasetCreator(object):
|
546 |
+
# mandatory columns
|
547 |
+
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
|
548 |
+
KEY_TGT_TEXT = "tgt_text"
|
549 |
+
# optional columns
|
550 |
+
KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
|
551 |
+
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
|
552 |
+
# default values
|
553 |
+
DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
|
554 |
+
|
555 |
+
@classmethod
|
556 |
+
def _from_list(
|
557 |
+
cls,
|
558 |
+
split_name: str,
|
559 |
+
is_train_split,
|
560 |
+
samples: List[Dict],
|
561 |
+
cfg: S2TDataConfig,
|
562 |
+
tgt_dict,
|
563 |
+
pre_tokenizer,
|
564 |
+
bpe_tokenizer,
|
565 |
+
n_frames_per_step,
|
566 |
+
speaker_to_id,
|
567 |
+
multitask: Optional[Dict] = None,
|
568 |
+
) -> SpeechToTextDataset:
|
569 |
+
audio_root = Path(cfg.audio_root)
|
570 |
+
ids = [s[cls.KEY_ID] for s in samples]
|
571 |
+
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
|
572 |
+
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
|
573 |
+
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
|
574 |
+
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
|
575 |
+
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
|
576 |
+
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
|
577 |
+
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
|
578 |
+
|
579 |
+
has_multitask = multitask is not None and len(multitask.keys()) > 0
|
580 |
+
dataset_cls = (
|
581 |
+
SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
|
582 |
+
)
|
583 |
+
|
584 |
+
ds = dataset_cls(
|
585 |
+
split=split_name,
|
586 |
+
is_train_split=is_train_split,
|
587 |
+
cfg=cfg,
|
588 |
+
audio_paths=audio_paths,
|
589 |
+
n_frames=n_frames,
|
590 |
+
src_texts=src_texts,
|
591 |
+
tgt_texts=tgt_texts,
|
592 |
+
speakers=speakers,
|
593 |
+
src_langs=src_langs,
|
594 |
+
tgt_langs=tgt_langs,
|
595 |
+
ids=ids,
|
596 |
+
tgt_dict=tgt_dict,
|
597 |
+
pre_tokenizer=pre_tokenizer,
|
598 |
+
bpe_tokenizer=bpe_tokenizer,
|
599 |
+
n_frames_per_step=n_frames_per_step,
|
600 |
+
speaker_to_id=speaker_to_id,
|
601 |
+
)
|
602 |
+
|
603 |
+
if has_multitask:
|
604 |
+
for task_name, task_obj in multitask.items():
|
605 |
+
task_data = TextTargetMultitaskData(
|
606 |
+
task_obj.args, split_name, task_obj.target_dictionary
|
607 |
+
)
|
608 |
+
ds.add_multitask_dataset(task_name, task_data)
|
609 |
+
return ds
|
610 |
+
|
611 |
+
@classmethod
|
612 |
+
def get_size_ratios(
|
613 |
+
cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
|
614 |
+
) -> List[float]:
|
615 |
+
"""Size ratios for temperature-based sampling
|
616 |
+
(https://arxiv.org/abs/1907.05019)"""
|
617 |
+
|
618 |
+
id_to_lp, lp_to_sz = {}, defaultdict(int)
|
619 |
+
for ds in datasets:
|
620 |
+
lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
|
621 |
+
assert len(lang_pairs) == 1
|
622 |
+
lang_pair = list(lang_pairs)[0]
|
623 |
+
id_to_lp[ds.split] = lang_pair
|
624 |
+
lp_to_sz[lang_pair] += sum(ds.n_frames)
|
625 |
+
|
626 |
+
sz_sum = sum(v for v in lp_to_sz.values())
|
627 |
+
lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
|
628 |
+
lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
|
629 |
+
prob_sum = sum(v for v in lp_to_tgt_prob.values())
|
630 |
+
lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
|
631 |
+
lp_to_sz_ratio = {
|
632 |
+
k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
|
633 |
+
}
|
634 |
+
size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]
|
635 |
+
|
636 |
+
p_formatted = {
|
637 |
+
k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
|
638 |
+
}
|
639 |
+
logger.info(f"sampling probability balancing: {p_formatted}")
|
640 |
+
sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
|
641 |
+
logger.info(f"balanced sampling size ratio: {sr_formatted}")
|
642 |
+
return size_ratio
|
643 |
+
|
644 |
+
@classmethod
|
645 |
+
def _load_samples_from_tsv(cls, root: str, split: str):
|
646 |
+
tsv_path = Path(root) / f"{split}.tsv"
|
647 |
+
if not tsv_path.is_file():
|
648 |
+
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
|
649 |
+
with open(tsv_path) as f:
|
650 |
+
reader = csv.DictReader(
|
651 |
+
f,
|
652 |
+
delimiter="\t",
|
653 |
+
quotechar=None,
|
654 |
+
doublequote=False,
|
655 |
+
lineterminator="\n",
|
656 |
+
quoting=csv.QUOTE_NONE,
|
657 |
+
)
|
658 |
+
samples = [dict(e) for e in reader]
|
659 |
+
if len(samples) == 0:
|
660 |
+
raise ValueError(f"Empty manifest: {tsv_path}")
|
661 |
+
return samples
|
662 |
+
|
663 |
+
@classmethod
|
664 |
+
def _from_tsv(
|
665 |
+
cls,
|
666 |
+
root: str,
|
667 |
+
cfg: S2TDataConfig,
|
668 |
+
split: str,
|
669 |
+
tgt_dict,
|
670 |
+
is_train_split: bool,
|
671 |
+
pre_tokenizer,
|
672 |
+
bpe_tokenizer,
|
673 |
+
n_frames_per_step,
|
674 |
+
speaker_to_id,
|
675 |
+
multitask: Optional[Dict] = None,
|
676 |
+
) -> SpeechToTextDataset:
|
677 |
+
samples = cls._load_samples_from_tsv(root, split)
|
678 |
+
return cls._from_list(
|
679 |
+
split,
|
680 |
+
is_train_split,
|
681 |
+
samples,
|
682 |
+
cfg,
|
683 |
+
tgt_dict,
|
684 |
+
pre_tokenizer,
|
685 |
+
bpe_tokenizer,
|
686 |
+
n_frames_per_step,
|
687 |
+
speaker_to_id,
|
688 |
+
multitask,
|
689 |
+
)
|
690 |
+
|
691 |
+
@classmethod
|
692 |
+
def from_tsv(
|
693 |
+
cls,
|
694 |
+
root: str,
|
695 |
+
cfg: S2TDataConfig,
|
696 |
+
splits: str,
|
697 |
+
tgt_dict,
|
698 |
+
pre_tokenizer,
|
699 |
+
bpe_tokenizer,
|
700 |
+
is_train_split: bool,
|
701 |
+
epoch: int,
|
702 |
+
seed: int,
|
703 |
+
n_frames_per_step: int = 1,
|
704 |
+
speaker_to_id=None,
|
705 |
+
multitask: Optional[Dict] = None,
|
706 |
+
) -> SpeechToTextDataset:
|
707 |
+
datasets = [
|
708 |
+
cls._from_tsv(
|
709 |
+
root=root,
|
710 |
+
cfg=cfg,
|
711 |
+
split=split,
|
712 |
+
tgt_dict=tgt_dict,
|
713 |
+
is_train_split=is_train_split,
|
714 |
+
pre_tokenizer=pre_tokenizer,
|
715 |
+
bpe_tokenizer=bpe_tokenizer,
|
716 |
+
n_frames_per_step=n_frames_per_step,
|
717 |
+
speaker_to_id=speaker_to_id,
|
718 |
+
multitask=multitask,
|
719 |
+
)
|
720 |
+
for split in splits.split(",")
|
721 |
+
]
|
722 |
+
|
723 |
+
if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
|
724 |
+
# temperature-based sampling
|
725 |
+
size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
|
726 |
+
datasets = [
|
727 |
+
ResamplingDataset(
|
728 |
+
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
|
729 |
+
)
|
730 |
+
for r, d in zip(size_ratios, datasets)
|
731 |
+
]
|
732 |
+
|
733 |
+
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
|
modules/voice_conversion/fairseq/data/audio/speech_to_text_joint_dataset.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Dict, List, NamedTuple, Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
|
13 |
+
from fairseq.data import data_utils as fairseq_data_utils
|
14 |
+
from fairseq.data.audio.speech_to_text_dataset import (
|
15 |
+
S2TDataConfig,
|
16 |
+
SpeechToTextDataset,
|
17 |
+
SpeechToTextDatasetCreator,
|
18 |
+
)
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
class S2TJointDataConfig(S2TDataConfig):
|
24 |
+
"""Wrapper class for data config YAML"""
|
25 |
+
|
26 |
+
@property
|
27 |
+
def src_vocab_filename(self):
|
28 |
+
"""fairseq vocabulary file under data root"""
|
29 |
+
return self.config.get("src_vocab_filename", "src_dict.txt")
|
30 |
+
|
31 |
+
@property
|
32 |
+
def src_pre_tokenizer(self) -> Dict:
|
33 |
+
"""Pre-tokenizer to apply before subword tokenization. Returning
|
34 |
+
a dictionary with `tokenizer` providing the tokenizer name and
|
35 |
+
the other items providing the tokenizer-specific arguments.
|
36 |
+
Tokenizers are defined in `fairseq.data.encoders.*`"""
|
37 |
+
return self.config.get("src_pre_tokenizer", {"tokenizer": None})
|
38 |
+
|
39 |
+
@property
|
40 |
+
def src_bpe_tokenizer(self) -> Dict:
|
41 |
+
"""Subword tokenizer to apply on source text after pre-tokenization.
|
42 |
+
Returning a dictionary with `bpe` providing the tokenizer name and
|
43 |
+
the other items providing the tokenizer-specific arguments.
|
44 |
+
Tokenizers are defined in `fairseq.data.encoders.*`"""
|
45 |
+
return self.config.get("src_bpe_tokenizer", {"bpe": None})
|
46 |
+
|
47 |
+
@property
|
48 |
+
def prepend_tgt_lang_tag_no_change(self) -> bool:
|
49 |
+
"""Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
|
50 |
+
to-many multilingual setting). No change needed during inference.
|
51 |
+
This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
|
52 |
+
"""
|
53 |
+
value = self.config.get("prepend_tgt_lang_tag_no_change", None)
|
54 |
+
if value is None:
|
55 |
+
return self.config.get("prepend_tgt_lang_tag_as_bos", False)
|
56 |
+
return value
|
57 |
+
|
58 |
+
@property
|
59 |
+
def sampling_text_alpha(self):
|
60 |
+
"""Hyper-parameter alpha = 1/T for temperature-based resampling. (text
|
61 |
+
input only) (alpha = 1 for no resampling)"""
|
62 |
+
return self.config.get("sampling_text_alpha", 1.0)
|
63 |
+
|
64 |
+
|
65 |
+
class SpeechToTextJointDatasetItem(NamedTuple):
|
66 |
+
index: int
|
67 |
+
source: torch.Tensor
|
68 |
+
target: Optional[torch.Tensor] = None
|
69 |
+
src_txt_tokens: Optional[torch.Tensor] = None
|
70 |
+
tgt_lang_tag: Optional[int] = None
|
71 |
+
src_lang_tag: Optional[int] = None
|
72 |
+
tgt_alignment: Optional[torch.Tensor] = None
|
73 |
+
|
74 |
+
|
75 |
+
# use_src_lang_id:
|
76 |
+
# 0: don't use src_lang_id
|
77 |
+
# 1: attach src_lang_id to the src_txt_tokens as eos
|
78 |
+
class SpeechToTextJointDataset(SpeechToTextDataset):
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
split: str,
|
82 |
+
is_train_split: bool,
|
83 |
+
cfg: S2TJointDataConfig,
|
84 |
+
audio_paths: List[str],
|
85 |
+
n_frames: List[int],
|
86 |
+
src_texts: Optional[List[str]] = None,
|
87 |
+
tgt_texts: Optional[List[str]] = None,
|
88 |
+
speakers: Optional[List[str]] = None,
|
89 |
+
src_langs: Optional[List[str]] = None,
|
90 |
+
tgt_langs: Optional[List[str]] = None,
|
91 |
+
ids: Optional[List[str]] = None,
|
92 |
+
tgt_dict: Optional[Dictionary] = None,
|
93 |
+
src_dict: Optional[Dictionary] = None,
|
94 |
+
pre_tokenizer=None,
|
95 |
+
bpe_tokenizer=None,
|
96 |
+
src_pre_tokenizer=None,
|
97 |
+
src_bpe_tokenizer=None,
|
98 |
+
append_eos: Optional[bool] = True,
|
99 |
+
alignment: Optional[List[str]] = None,
|
100 |
+
use_src_lang_id: Optional[int] = 0,
|
101 |
+
):
|
102 |
+
super().__init__(
|
103 |
+
split,
|
104 |
+
is_train_split,
|
105 |
+
cfg,
|
106 |
+
audio_paths,
|
107 |
+
n_frames,
|
108 |
+
src_texts=src_texts,
|
109 |
+
tgt_texts=tgt_texts,
|
110 |
+
speakers=speakers,
|
111 |
+
src_langs=src_langs,
|
112 |
+
tgt_langs=tgt_langs,
|
113 |
+
ids=ids,
|
114 |
+
tgt_dict=tgt_dict,
|
115 |
+
pre_tokenizer=pre_tokenizer,
|
116 |
+
bpe_tokenizer=bpe_tokenizer,
|
117 |
+
append_eos=append_eos,
|
118 |
+
)
|
119 |
+
|
120 |
+
self.src_dict = src_dict
|
121 |
+
self.src_pre_tokenizer = src_pre_tokenizer
|
122 |
+
self.src_bpe_tokenizer = src_bpe_tokenizer
|
123 |
+
self.alignment = None
|
124 |
+
self.use_src_lang_id = use_src_lang_id
|
125 |
+
if alignment is not None:
|
126 |
+
self.alignment = [
|
127 |
+
[float(s) for s in sample.split()] for sample in alignment
|
128 |
+
]
|
129 |
+
|
130 |
+
def get_tokenized_src_text(self, index: int):
|
131 |
+
text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
|
132 |
+
text = self.tokenize(self.src_bpe_tokenizer, text)
|
133 |
+
return text
|
134 |
+
|
135 |
+
def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
|
136 |
+
s2t_dataset_item = super().__getitem__(index)
|
137 |
+
src_tokens = None
|
138 |
+
src_lang_tag = None
|
139 |
+
if self.src_texts is not None and self.src_dict is not None:
|
140 |
+
src_tokens = self.get_tokenized_src_text(index)
|
141 |
+
src_tokens = self.src_dict.encode_line(
|
142 |
+
src_tokens, add_if_not_exist=False, append_eos=True
|
143 |
+
).long()
|
144 |
+
if self.use_src_lang_id > 0:
|
145 |
+
src_lang_tag = self.get_lang_tag_idx(
|
146 |
+
self.src_langs[index], self.src_dict
|
147 |
+
)
|
148 |
+
tgt_lang_tag = None
|
149 |
+
if self.cfg.prepend_tgt_lang_tag_no_change:
|
150 |
+
# prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead
|
151 |
+
tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
|
152 |
+
ali = None
|
153 |
+
if self.alignment is not None:
|
154 |
+
ali = torch.Tensor(self.alignment[index]).float()
|
155 |
+
|
156 |
+
return SpeechToTextJointDatasetItem(
|
157 |
+
index=index,
|
158 |
+
source=s2t_dataset_item.source,
|
159 |
+
target=s2t_dataset_item.target,
|
160 |
+
src_txt_tokens=src_tokens,
|
161 |
+
tgt_lang_tag=tgt_lang_tag,
|
162 |
+
src_lang_tag=src_lang_tag,
|
163 |
+
tgt_alignment=ali,
|
164 |
+
)
|
165 |
+
|
166 |
+
def __len__(self):
|
167 |
+
return self.n_samples
|
168 |
+
|
169 |
+
def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
|
170 |
+
s2t_out = super().collater(samples, return_order=True)
|
171 |
+
if s2t_out == {}:
|
172 |
+
return s2t_out
|
173 |
+
net_input, order = s2t_out["net_input"], s2t_out["order"]
|
174 |
+
|
175 |
+
if self.src_texts is not None and self.src_dict is not None:
|
176 |
+
src_txt_tokens = fairseq_data_utils.collate_tokens(
|
177 |
+
[x.src_txt_tokens for x in samples],
|
178 |
+
self.src_dict.pad(),
|
179 |
+
self.src_dict.eos(),
|
180 |
+
left_pad=False,
|
181 |
+
move_eos_to_beginning=False,
|
182 |
+
)
|
183 |
+
src_txt_lengths = torch.tensor(
|
184 |
+
[x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long
|
185 |
+
)
|
186 |
+
if self.use_src_lang_id > 0:
|
187 |
+
src_lang_idxs = torch.tensor(
|
188 |
+
[s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype
|
189 |
+
)
|
190 |
+
if self.use_src_lang_id == 1: # replace eos with lang_id
|
191 |
+
eos_idx = src_txt_lengths - 1
|
192 |
+
src_txt_tokens.scatter_(
|
193 |
+
1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1)
|
194 |
+
)
|
195 |
+
else:
|
196 |
+
raise NotImplementedError("Implementation is required")
|
197 |
+
|
198 |
+
src_txt_tokens = src_txt_tokens.index_select(0, order)
|
199 |
+
src_txt_lengths = src_txt_lengths.index_select(0, order)
|
200 |
+
net_input["src_txt_tokens"] = src_txt_tokens
|
201 |
+
net_input["src_txt_lengths"] = src_txt_lengths
|
202 |
+
|
203 |
+
net_input["alignment"] = None
|
204 |
+
if self.alignment is not None:
|
205 |
+
max_len = max([s.tgt_alignment.size(0) for s in samples])
|
206 |
+
alignment = torch.ones(len(samples), max_len).float()
|
207 |
+
for i, s in enumerate(samples):
|
208 |
+
cur_len = s.tgt_alignment.size(0)
|
209 |
+
alignment[i][:cur_len].copy_(s.tgt_alignment)
|
210 |
+
net_input["alignment"] = alignment.index_select(0, order)
|
211 |
+
|
212 |
+
if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
|
213 |
+
for i in range(len(samples)):
|
214 |
+
net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
|
215 |
+
|
216 |
+
out = {
|
217 |
+
"id": s2t_out["id"],
|
218 |
+
"net_input": net_input,
|
219 |
+
"target": s2t_out["target"],
|
220 |
+
"target_lengths": s2t_out["target_lengths"],
|
221 |
+
"ntokens": s2t_out["ntokens"],
|
222 |
+
"nsentences": len(samples),
|
223 |
+
}
|
224 |
+
return out
|
225 |
+
|
226 |
+
|
227 |
+
class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
|
228 |
+
KEY_ALIGN = "align"
|
229 |
+
|
230 |
+
@classmethod
|
231 |
+
def _from_list(
|
232 |
+
cls,
|
233 |
+
split_name: str,
|
234 |
+
is_train_split,
|
235 |
+
samples: List[Dict],
|
236 |
+
cfg: S2TJointDataConfig,
|
237 |
+
tgt_dict,
|
238 |
+
src_dict,
|
239 |
+
pre_tokenizer,
|
240 |
+
bpe_tokenizer,
|
241 |
+
src_pre_tokenizer,
|
242 |
+
src_bpe_tokenizer,
|
243 |
+
append_eos,
|
244 |
+
use_src_lang_id,
|
245 |
+
) -> SpeechToTextJointDataset:
|
246 |
+
audio_root = Path(cfg.audio_root)
|
247 |
+
ids = [s[cls.KEY_ID] for s in samples]
|
248 |
+
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
|
249 |
+
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
|
250 |
+
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
|
251 |
+
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
|
252 |
+
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
|
253 |
+
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
|
254 |
+
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
|
255 |
+
tgt_alignment = None
|
256 |
+
if cls.KEY_ALIGN in samples[0].keys():
|
257 |
+
tgt_alignment = [s[cls.KEY_ALIGN] for s in samples]
|
258 |
+
return SpeechToTextJointDataset(
|
259 |
+
split_name,
|
260 |
+
is_train_split,
|
261 |
+
cfg,
|
262 |
+
audio_paths,
|
263 |
+
n_frames,
|
264 |
+
src_texts=src_texts,
|
265 |
+
tgt_texts=tgt_texts,
|
266 |
+
speakers=speakers,
|
267 |
+
src_langs=src_langs,
|
268 |
+
tgt_langs=tgt_langs,
|
269 |
+
ids=ids,
|
270 |
+
tgt_dict=tgt_dict,
|
271 |
+
src_dict=src_dict,
|
272 |
+
pre_tokenizer=pre_tokenizer,
|
273 |
+
bpe_tokenizer=bpe_tokenizer,
|
274 |
+
src_pre_tokenizer=src_pre_tokenizer,
|
275 |
+
src_bpe_tokenizer=src_bpe_tokenizer,
|
276 |
+
append_eos=append_eos,
|
277 |
+
alignment=tgt_alignment,
|
278 |
+
use_src_lang_id=use_src_lang_id,
|
279 |
+
)
|
280 |
+
|
281 |
+
@classmethod
|
282 |
+
def _from_tsv(
|
283 |
+
cls,
|
284 |
+
root: str,
|
285 |
+
cfg: S2TJointDataConfig,
|
286 |
+
split: str,
|
287 |
+
tgt_dict,
|
288 |
+
src_dict,
|
289 |
+
is_train_split: bool,
|
290 |
+
pre_tokenizer,
|
291 |
+
bpe_tokenizer,
|
292 |
+
src_pre_tokenizer,
|
293 |
+
src_bpe_tokenizer,
|
294 |
+
append_eos: bool,
|
295 |
+
use_src_lang_id: int,
|
296 |
+
) -> SpeechToTextJointDataset:
|
297 |
+
samples = cls._load_samples_from_tsv(root, split)
|
298 |
+
return cls._from_list(
|
299 |
+
split,
|
300 |
+
is_train_split,
|
301 |
+
samples,
|
302 |
+
cfg,
|
303 |
+
tgt_dict,
|
304 |
+
src_dict,
|
305 |
+
pre_tokenizer,
|
306 |
+
bpe_tokenizer,
|
307 |
+
src_pre_tokenizer,
|
308 |
+
src_bpe_tokenizer,
|
309 |
+
append_eos,
|
310 |
+
use_src_lang_id,
|
311 |
+
)
|
312 |
+
|
313 |
+
@classmethod
|
314 |
+
def from_tsv(
|
315 |
+
cls,
|
316 |
+
root: str,
|
317 |
+
cfg: S2TJointDataConfig,
|
318 |
+
splits: str,
|
319 |
+
tgt_dict,
|
320 |
+
src_dict,
|
321 |
+
pre_tokenizer,
|
322 |
+
bpe_tokenizer,
|
323 |
+
src_pre_tokenizer,
|
324 |
+
src_bpe_tokenizer,
|
325 |
+
is_train_split: bool,
|
326 |
+
epoch: int,
|
327 |
+
seed: int,
|
328 |
+
append_eos: Optional[bool] = True,
|
329 |
+
use_src_lang_id: Optional[int] = 0,
|
330 |
+
) -> SpeechToTextJointDataset:
|
331 |
+
datasets = [
|
332 |
+
cls._from_tsv(
|
333 |
+
root,
|
334 |
+
cfg,
|
335 |
+
split,
|
336 |
+
tgt_dict,
|
337 |
+
src_dict,
|
338 |
+
is_train_split,
|
339 |
+
pre_tokenizer,
|
340 |
+
bpe_tokenizer,
|
341 |
+
src_pre_tokenizer,
|
342 |
+
src_bpe_tokenizer,
|
343 |
+
append_eos=append_eos,
|
344 |
+
use_src_lang_id=use_src_lang_id,
|
345 |
+
)
|
346 |
+
for split in splits.split(",")
|
347 |
+
]
|
348 |
+
|
349 |
+
if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
|
350 |
+
# temperature-based sampling
|
351 |
+
size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
|
352 |
+
datasets = [
|
353 |
+
ResamplingDataset(
|
354 |
+
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
|
355 |
+
)
|
356 |
+
for r, d in zip(size_ratios, datasets)
|
357 |
+
]
|
358 |
+
|
359 |
+
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
|
modules/voice_conversion/fairseq/data/audio/text_to_speech_dataset.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the LICENSE file in
|
5 |
+
# the root directory of this source tree. An additional grant of patent rights
|
6 |
+
# can be found in the PATENTS file in the same directory.abs
|
7 |
+
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Any, Dict, List, Optional
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
|
15 |
+
from fairseq.data import Dictionary
|
16 |
+
from fairseq.data import data_utils as fairseq_data_utils
|
17 |
+
from fairseq.data.audio.audio_utils import get_features_or_waveform
|
18 |
+
from fairseq.data.audio.speech_to_text_dataset import (
|
19 |
+
S2TDataConfig,
|
20 |
+
SpeechToTextDataset,
|
21 |
+
SpeechToTextDatasetCreator,
|
22 |
+
_collate_frames,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class TextToSpeechDatasetItem(object):
|
28 |
+
index: int
|
29 |
+
source: torch.Tensor
|
30 |
+
target: Optional[torch.Tensor] = None
|
31 |
+
speaker_id: Optional[int] = None
|
32 |
+
duration: Optional[torch.Tensor] = None
|
33 |
+
pitch: Optional[torch.Tensor] = None
|
34 |
+
energy: Optional[torch.Tensor] = None
|
35 |
+
|
36 |
+
|
37 |
+
class TextToSpeechDataset(SpeechToTextDataset):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
split: str,
|
41 |
+
is_train_split: bool,
|
42 |
+
cfg: S2TDataConfig,
|
43 |
+
audio_paths: List[str],
|
44 |
+
n_frames: List[int],
|
45 |
+
src_texts: Optional[List[str]] = None,
|
46 |
+
tgt_texts: Optional[List[str]] = None,
|
47 |
+
speakers: Optional[List[str]] = None,
|
48 |
+
src_langs: Optional[List[str]] = None,
|
49 |
+
tgt_langs: Optional[List[str]] = None,
|
50 |
+
ids: Optional[List[str]] = None,
|
51 |
+
tgt_dict: Optional[Dictionary] = None,
|
52 |
+
pre_tokenizer=None,
|
53 |
+
bpe_tokenizer=None,
|
54 |
+
n_frames_per_step=1,
|
55 |
+
speaker_to_id=None,
|
56 |
+
durations: Optional[List[List[int]]] = None,
|
57 |
+
pitches: Optional[List[str]] = None,
|
58 |
+
energies: Optional[List[str]] = None,
|
59 |
+
):
|
60 |
+
super(TextToSpeechDataset, self).__init__(
|
61 |
+
split,
|
62 |
+
is_train_split,
|
63 |
+
cfg,
|
64 |
+
audio_paths,
|
65 |
+
n_frames,
|
66 |
+
src_texts=src_texts,
|
67 |
+
tgt_texts=tgt_texts,
|
68 |
+
speakers=speakers,
|
69 |
+
src_langs=src_langs,
|
70 |
+
tgt_langs=tgt_langs,
|
71 |
+
ids=ids,
|
72 |
+
tgt_dict=tgt_dict,
|
73 |
+
pre_tokenizer=pre_tokenizer,
|
74 |
+
bpe_tokenizer=bpe_tokenizer,
|
75 |
+
n_frames_per_step=n_frames_per_step,
|
76 |
+
speaker_to_id=speaker_to_id,
|
77 |
+
)
|
78 |
+
self.durations = durations
|
79 |
+
self.pitches = pitches
|
80 |
+
self.energies = energies
|
81 |
+
|
82 |
+
def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
|
83 |
+
s2t_item = super().__getitem__(index)
|
84 |
+
|
85 |
+
duration, pitch, energy = None, None, None
|
86 |
+
if self.durations is not None:
|
87 |
+
duration = torch.tensor(
|
88 |
+
self.durations[index] + [0], dtype=torch.long # pad 0 for EOS
|
89 |
+
)
|
90 |
+
if self.pitches is not None:
|
91 |
+
pitch = get_features_or_waveform(self.pitches[index])
|
92 |
+
pitch = torch.from_numpy(
|
93 |
+
np.concatenate((pitch, [0])) # pad 0 for EOS
|
94 |
+
).float()
|
95 |
+
if self.energies is not None:
|
96 |
+
energy = get_features_or_waveform(self.energies[index])
|
97 |
+
energy = torch.from_numpy(
|
98 |
+
np.concatenate((energy, [0])) # pad 0 for EOS
|
99 |
+
).float()
|
100 |
+
return TextToSpeechDatasetItem(
|
101 |
+
index=index,
|
102 |
+
source=s2t_item.source,
|
103 |
+
target=s2t_item.target,
|
104 |
+
speaker_id=s2t_item.speaker_id,
|
105 |
+
duration=duration,
|
106 |
+
pitch=pitch,
|
107 |
+
energy=energy,
|
108 |
+
)
|
109 |
+
|
110 |
+
def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
|
111 |
+
if len(samples) == 0:
|
112 |
+
return {}
|
113 |
+
|
114 |
+
src_lengths, order = torch.tensor(
|
115 |
+
[s.target.shape[0] for s in samples], dtype=torch.long
|
116 |
+
).sort(descending=True)
|
117 |
+
id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
|
118 |
+
0, order
|
119 |
+
)
|
120 |
+
feat = _collate_frames(
|
121 |
+
[s.source for s in samples], self.cfg.use_audio_input
|
122 |
+
).index_select(0, order)
|
123 |
+
target_lengths = torch.tensor(
|
124 |
+
[s.source.shape[0] for s in samples], dtype=torch.long
|
125 |
+
).index_select(0, order)
|
126 |
+
|
127 |
+
src_tokens = fairseq_data_utils.collate_tokens(
|
128 |
+
[s.target for s in samples],
|
129 |
+
self.tgt_dict.pad(),
|
130 |
+
self.tgt_dict.eos(),
|
131 |
+
left_pad=False,
|
132 |
+
move_eos_to_beginning=False,
|
133 |
+
).index_select(0, order)
|
134 |
+
|
135 |
+
speaker = None
|
136 |
+
if self.speaker_to_id is not None:
|
137 |
+
speaker = (
|
138 |
+
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
|
139 |
+
.index_select(0, order)
|
140 |
+
.view(-1, 1)
|
141 |
+
)
|
142 |
+
|
143 |
+
bsz, _, d = feat.size()
|
144 |
+
prev_output_tokens = torch.cat(
|
145 |
+
(feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
|
146 |
+
)
|
147 |
+
|
148 |
+
durations, pitches, energies = None, None, None
|
149 |
+
if self.durations is not None:
|
150 |
+
durations = fairseq_data_utils.collate_tokens(
|
151 |
+
[s.duration for s in samples], 0
|
152 |
+
).index_select(0, order)
|
153 |
+
assert src_tokens.shape[1] == durations.shape[1]
|
154 |
+
if self.pitches is not None:
|
155 |
+
pitches = _collate_frames([s.pitch for s in samples], True)
|
156 |
+
pitches = pitches.index_select(0, order)
|
157 |
+
assert src_tokens.shape[1] == pitches.shape[1]
|
158 |
+
if self.energies is not None:
|
159 |
+
energies = _collate_frames([s.energy for s in samples], True)
|
160 |
+
energies = energies.index_select(0, order)
|
161 |
+
assert src_tokens.shape[1] == energies.shape[1]
|
162 |
+
src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
|
163 |
+
|
164 |
+
return {
|
165 |
+
"id": id_,
|
166 |
+
"net_input": {
|
167 |
+
"src_tokens": src_tokens,
|
168 |
+
"src_lengths": src_lengths,
|
169 |
+
"prev_output_tokens": prev_output_tokens,
|
170 |
+
},
|
171 |
+
"speaker": speaker,
|
172 |
+
"target": feat,
|
173 |
+
"durations": durations,
|
174 |
+
"pitches": pitches,
|
175 |
+
"energies": energies,
|
176 |
+
"target_lengths": target_lengths,
|
177 |
+
"ntokens": sum(target_lengths).item(),
|
178 |
+
"nsentences": len(samples),
|
179 |
+
"src_texts": src_texts,
|
180 |
+
}
|
181 |
+
|
182 |
+
|
183 |
+
class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
|
184 |
+
KEY_DURATION = "duration"
|
185 |
+
KEY_PITCH = "pitch"
|
186 |
+
KEY_ENERGY = "energy"
|
187 |
+
|
188 |
+
@classmethod
|
189 |
+
def _from_list(
|
190 |
+
cls,
|
191 |
+
split_name: str,
|
192 |
+
is_train_split,
|
193 |
+
samples: List[Dict],
|
194 |
+
cfg: S2TDataConfig,
|
195 |
+
tgt_dict,
|
196 |
+
pre_tokenizer,
|
197 |
+
bpe_tokenizer,
|
198 |
+
n_frames_per_step,
|
199 |
+
speaker_to_id,
|
200 |
+
multitask=None,
|
201 |
+
) -> TextToSpeechDataset:
|
202 |
+
audio_root = Path(cfg.audio_root)
|
203 |
+
ids = [s[cls.KEY_ID] for s in samples]
|
204 |
+
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
|
205 |
+
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
|
206 |
+
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
|
207 |
+
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
|
208 |
+
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
|
209 |
+
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
|
210 |
+
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
|
211 |
+
|
212 |
+
durations = [s.get(cls.KEY_DURATION, None) for s in samples]
|
213 |
+
durations = [
|
214 |
+
None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
|
215 |
+
]
|
216 |
+
durations = None if any(dd is None for dd in durations) else durations
|
217 |
+
|
218 |
+
pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
|
219 |
+
pitches = [
|
220 |
+
None if pp is None else (audio_root / pp).as_posix() for pp in pitches
|
221 |
+
]
|
222 |
+
pitches = None if any(pp is None for pp in pitches) else pitches
|
223 |
+
|
224 |
+
energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
|
225 |
+
energies = [
|
226 |
+
None if ee is None else (audio_root / ee).as_posix() for ee in energies
|
227 |
+
]
|
228 |
+
energies = None if any(ee is None for ee in energies) else energies
|
229 |
+
|
230 |
+
return TextToSpeechDataset(
|
231 |
+
split_name,
|
232 |
+
is_train_split,
|
233 |
+
cfg,
|
234 |
+
audio_paths,
|
235 |
+
n_frames,
|
236 |
+
src_texts,
|
237 |
+
tgt_texts,
|
238 |
+
speakers,
|
239 |
+
src_langs,
|
240 |
+
tgt_langs,
|
241 |
+
ids,
|
242 |
+
tgt_dict,
|
243 |
+
pre_tokenizer,
|
244 |
+
bpe_tokenizer,
|
245 |
+
n_frames_per_step,
|
246 |
+
speaker_to_id,
|
247 |
+
durations,
|
248 |
+
pitches,
|
249 |
+
energies,
|
250 |
+
)
|
modules/voice_conversion/fairseq/data/audio/waveform_transforms/__init__.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from fairseq.data.audio import (
|
3 |
+
AudioTransform,
|
4 |
+
CompositeAudioTransform,
|
5 |
+
import_transforms,
|
6 |
+
register_audio_transform,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
class AudioWaveformTransform(AudioTransform):
|
11 |
+
pass
|
12 |
+
|
13 |
+
|
14 |
+
AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {}
|
15 |
+
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set()
|
16 |
+
|
17 |
+
|
18 |
+
def get_audio_waveform_transform(name):
|
19 |
+
return AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name]
|
20 |
+
|
21 |
+
|
22 |
+
def register_audio_waveform_transform(name):
|
23 |
+
return register_audio_transform(
|
24 |
+
name,
|
25 |
+
AudioWaveformTransform,
|
26 |
+
AUDIO_WAVEFORM_TRANSFORM_REGISTRY,
|
27 |
+
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES,
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
import_transforms(os.path.dirname(__file__), "waveform")
|
32 |
+
|
33 |
+
|
34 |
+
class CompositeAudioWaveformTransform(CompositeAudioTransform):
|
35 |
+
@classmethod
|
36 |
+
def from_config_dict(cls, config=None):
|
37 |
+
return super()._from_config_dict(
|
38 |
+
cls,
|
39 |
+
"waveform",
|
40 |
+
get_audio_waveform_transform,
|
41 |
+
CompositeAudioWaveformTransform,
|
42 |
+
config,
|
43 |
+
)
|
44 |
+
|
45 |
+
def __call__(self, x, sample_rate):
|
46 |
+
for t in self.transforms:
|
47 |
+
x, sample_rate = t(x, sample_rate)
|
48 |
+
return x, sample_rate
|
modules/voice_conversion/fairseq/data/audio/waveform_transforms/noiseaugment.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import numpy as np
|
3 |
+
from math import ceil
|
4 |
+
|
5 |
+
from fairseq.data.audio import rand_uniform
|
6 |
+
from fairseq.data.audio.waveform_transforms import (
|
7 |
+
AudioWaveformTransform,
|
8 |
+
register_audio_waveform_transform,
|
9 |
+
)
|
10 |
+
|
11 |
+
SNR_MIN = 5.0
|
12 |
+
SNR_MAX = 15.0
|
13 |
+
RATE = 0.25
|
14 |
+
|
15 |
+
NOISE_RATE = 1.0
|
16 |
+
NOISE_LEN_MEAN = 0.2
|
17 |
+
NOISE_LEN_STD = 0.05
|
18 |
+
|
19 |
+
|
20 |
+
class NoiseAugmentTransform(AudioWaveformTransform):
|
21 |
+
@classmethod
|
22 |
+
def from_config_dict(cls, config=None):
|
23 |
+
_config = {} if config is None else config
|
24 |
+
return cls(
|
25 |
+
_config.get("samples_path", None),
|
26 |
+
_config.get("snr_min", SNR_MIN),
|
27 |
+
_config.get("snr_max", SNR_MAX),
|
28 |
+
_config.get("rate", RATE),
|
29 |
+
)
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
samples_path: str,
|
34 |
+
snr_min: float = SNR_MIN,
|
35 |
+
snr_max: float = SNR_MAX,
|
36 |
+
rate: float = RATE,
|
37 |
+
):
|
38 |
+
# Sanity checks
|
39 |
+
assert (
|
40 |
+
samples_path
|
41 |
+
), "need to provide path to audio samples for noise augmentation"
|
42 |
+
assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})"
|
43 |
+
assert rate >= 0 and rate <= 1, "rate should be a float between 0 to 1"
|
44 |
+
|
45 |
+
self.paths = list(Path(samples_path).glob("**/*.wav")) # load music
|
46 |
+
self.n_samples = len(self.paths)
|
47 |
+
assert self.n_samples > 0, f"no audio files found in {samples_path}"
|
48 |
+
|
49 |
+
self.snr_min = snr_min
|
50 |
+
self.snr_max = snr_max
|
51 |
+
self.rate = rate
|
52 |
+
|
53 |
+
def __repr__(self):
|
54 |
+
return (
|
55 |
+
self.__class__.__name__
|
56 |
+
+ "("
|
57 |
+
+ ", ".join(
|
58 |
+
[
|
59 |
+
f"n_samples={self.n_samples}",
|
60 |
+
f"snr={self.snr_min}-{self.snr_max}dB",
|
61 |
+
f"rate={self.rate}",
|
62 |
+
]
|
63 |
+
)
|
64 |
+
+ ")"
|
65 |
+
)
|
66 |
+
|
67 |
+
def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None):
|
68 |
+
from fairseq.data.audio.audio_utils import get_waveform
|
69 |
+
|
70 |
+
path = self.paths[np.random.randint(0, self.n_samples)]
|
71 |
+
sample = get_waveform(
|
72 |
+
path, always_2d=always_2d, output_sample_rate=use_sample_rate
|
73 |
+
)[0]
|
74 |
+
|
75 |
+
# Check dimensions match, else silently skip adding noise to sample
|
76 |
+
# NOTE: SHOULD THIS QUIT WITH AN ERROR?
|
77 |
+
is_2d = len(goal_shape) == 2
|
78 |
+
if len(goal_shape) != sample.ndim or (
|
79 |
+
is_2d and goal_shape[0] != sample.shape[0]
|
80 |
+
):
|
81 |
+
return np.zeros(goal_shape)
|
82 |
+
|
83 |
+
# Cut/repeat sample to size
|
84 |
+
len_dim = len(goal_shape) - 1
|
85 |
+
n_repeat = ceil(goal_shape[len_dim] / sample.shape[len_dim])
|
86 |
+
repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat)
|
87 |
+
start = np.random.randint(0, repeated.shape[len_dim] - goal_shape[len_dim] + 1)
|
88 |
+
return (
|
89 |
+
repeated[:, start : start + goal_shape[len_dim]]
|
90 |
+
if is_2d
|
91 |
+
else repeated[start : start + goal_shape[len_dim]]
|
92 |
+
)
|
93 |
+
|
94 |
+
def _mix(self, source, noise, snr):
|
95 |
+
get_power = lambda x: np.mean(x**2)
|
96 |
+
if get_power(noise):
|
97 |
+
scl = np.sqrt(
|
98 |
+
get_power(source) / (np.power(10, snr / 10) * get_power(noise))
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
scl = 0
|
102 |
+
return 1 * source + scl * noise
|
103 |
+
|
104 |
+
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
|
105 |
+
return self.pick_sample(goal_shape, always_2d, use_sample_rate)
|
106 |
+
|
107 |
+
def __call__(self, source, sample_rate):
|
108 |
+
if np.random.random() > self.rate:
|
109 |
+
return source, sample_rate
|
110 |
+
|
111 |
+
noise = self._get_noise(
|
112 |
+
source.shape, always_2d=True, use_sample_rate=sample_rate
|
113 |
+
)
|
114 |
+
|
115 |
+
return (
|
116 |
+
self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)),
|
117 |
+
sample_rate,
|
118 |
+
)
|
119 |
+
|
120 |
+
|
121 |
+
@register_audio_waveform_transform("musicaugment")
|
122 |
+
class MusicAugmentTransform(NoiseAugmentTransform):
|
123 |
+
pass
|
124 |
+
|
125 |
+
|
126 |
+
@register_audio_waveform_transform("backgroundnoiseaugment")
|
127 |
+
class BackgroundNoiseAugmentTransform(NoiseAugmentTransform):
|
128 |
+
pass
|
129 |
+
|
130 |
+
|
131 |
+
@register_audio_waveform_transform("babbleaugment")
|
132 |
+
class BabbleAugmentTransform(NoiseAugmentTransform):
|
133 |
+
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
|
134 |
+
for i in range(np.random.randint(3, 8)):
|
135 |
+
speech = self.pick_sample(goal_shape, always_2d, use_sample_rate)
|
136 |
+
if i == 0:
|
137 |
+
agg_noise = speech
|
138 |
+
else: # SNR scaled by i (how many noise signals already in agg_noise)
|
139 |
+
agg_noise = self._mix(agg_noise, speech, i)
|
140 |
+
return agg_noise
|
141 |
+
|
142 |
+
|
143 |
+
@register_audio_waveform_transform("sporadicnoiseaugment")
|
144 |
+
class SporadicNoiseAugmentTransform(NoiseAugmentTransform):
|
145 |
+
@classmethod
|
146 |
+
def from_config_dict(cls, config=None):
|
147 |
+
_config = {} if config is None else config
|
148 |
+
return cls(
|
149 |
+
_config.get("samples_path", None),
|
150 |
+
_config.get("snr_min", SNR_MIN),
|
151 |
+
_config.get("snr_max", SNR_MAX),
|
152 |
+
_config.get("rate", RATE),
|
153 |
+
_config.get("noise_rate", NOISE_RATE),
|
154 |
+
_config.get("noise_len_mean", NOISE_LEN_MEAN),
|
155 |
+
_config.get("noise_len_std", NOISE_LEN_STD),
|
156 |
+
)
|
157 |
+
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
samples_path: str,
|
161 |
+
snr_min: float = SNR_MIN,
|
162 |
+
snr_max: float = SNR_MAX,
|
163 |
+
rate: float = RATE,
|
164 |
+
noise_rate: float = NOISE_RATE, # noises per second
|
165 |
+
noise_len_mean: float = NOISE_LEN_MEAN, # length of noises in seconds
|
166 |
+
noise_len_std: float = NOISE_LEN_STD,
|
167 |
+
):
|
168 |
+
super().__init__(samples_path, snr_min, snr_max, rate)
|
169 |
+
self.noise_rate = noise_rate
|
170 |
+
self.noise_len_mean = noise_len_mean
|
171 |
+
self.noise_len_std = noise_len_std
|
172 |
+
|
173 |
+
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
|
174 |
+
agg_noise = np.zeros(goal_shape)
|
175 |
+
len_dim = len(goal_shape) - 1
|
176 |
+
is_2d = len(goal_shape) == 2
|
177 |
+
|
178 |
+
n_noises = round(self.noise_rate * goal_shape[len_dim] / use_sample_rate)
|
179 |
+
start_pointers = [
|
180 |
+
round(rand_uniform(0, goal_shape[len_dim])) for _ in range(n_noises)
|
181 |
+
]
|
182 |
+
|
183 |
+
for start_pointer in start_pointers:
|
184 |
+
noise_shape = list(goal_shape)
|
185 |
+
len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std)
|
186 |
+
noise_shape[len_dim] = round(max(0, len_seconds) * use_sample_rate)
|
187 |
+
end_pointer = start_pointer + noise_shape[len_dim]
|
188 |
+
if end_pointer >= goal_shape[len_dim]:
|
189 |
+
continue
|
190 |
+
|
191 |
+
noise = self.pick_sample(noise_shape, always_2d, use_sample_rate)
|
192 |
+
if is_2d:
|
193 |
+
agg_noise[:, start_pointer:end_pointer] = (
|
194 |
+
agg_noise[:, start_pointer:end_pointer] + noise
|
195 |
+
)
|
196 |
+
else:
|
197 |
+
agg_noise[start_pointer:end_pointer] = (
|
198 |
+
agg_noise[start_pointer:end_pointer] + noise
|
199 |
+
)
|
200 |
+
|
201 |
+
return agg_noise
|
modules/voice_conversion/fairseq/data/backtranslation_dataset.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from fairseq import utils
|
8 |
+
|
9 |
+
from . import FairseqDataset
|
10 |
+
|
11 |
+
|
12 |
+
def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
|
13 |
+
"""Backtranslate a list of samples.
|
14 |
+
|
15 |
+
Given an input (*samples*) of the form:
|
16 |
+
|
17 |
+
[{'id': 1, 'source': 'hallo welt'}]
|
18 |
+
|
19 |
+
this will return:
|
20 |
+
|
21 |
+
[{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]
|
22 |
+
|
23 |
+
Args:
|
24 |
+
samples (List[dict]): samples to backtranslate. Individual samples are
|
25 |
+
expected to have a 'source' key, which will become the 'target'
|
26 |
+
after backtranslation.
|
27 |
+
collate_fn (callable): function to collate samples into a mini-batch
|
28 |
+
generate_fn (callable): function to generate backtranslations
|
29 |
+
cuda (bool): use GPU for generation (default: ``True``)
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
List[dict]: an updated list of samples with a backtranslated source
|
33 |
+
"""
|
34 |
+
collated_samples = collate_fn(samples)
|
35 |
+
s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
|
36 |
+
generated_sources = generate_fn(s)
|
37 |
+
|
38 |
+
id_to_src = {sample["id"]: sample["source"] for sample in samples}
|
39 |
+
|
40 |
+
# Go through each tgt sentence in batch and its corresponding best
|
41 |
+
# generated hypothesis and create a backtranslation data pair
|
42 |
+
# {id: id, source: generated backtranslation, target: original tgt}
|
43 |
+
return [
|
44 |
+
{
|
45 |
+
"id": id.item(),
|
46 |
+
"target": id_to_src[id.item()],
|
47 |
+
"source": hypos[0]["tokens"].cpu(),
|
48 |
+
}
|
49 |
+
for id, hypos in zip(collated_samples["id"], generated_sources)
|
50 |
+
]
|
51 |
+
|
52 |
+
|
53 |
+
class BacktranslationDataset(FairseqDataset):
|
54 |
+
"""
|
55 |
+
Sets up a backtranslation dataset which takes a tgt batch, generates
|
56 |
+
a src using a tgt-src backtranslation function (*backtranslation_fn*),
|
57 |
+
and returns the corresponding `{generated src, input tgt}` batch.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
|
61 |
+
backtranslated. Only the source side of this dataset will be used.
|
62 |
+
After backtranslation, the source sentences in this dataset will be
|
63 |
+
returned as the targets.
|
64 |
+
src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
|
65 |
+
sentences.
|
66 |
+
tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
|
67 |
+
sentences to be backtranslated.
|
68 |
+
backtranslation_fn (callable, optional): function to call to generate
|
69 |
+
backtranslations. This is typically the `generate` method of a
|
70 |
+
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
|
71 |
+
Pass in None when it is not available at initialization time, and
|
72 |
+
use set_backtranslation_fn function to set it when available.
|
73 |
+
output_collater (callable, optional): function to call on the
|
74 |
+
backtranslated samples to create the final batch
|
75 |
+
(default: ``tgt_dataset.collater``).
|
76 |
+
cuda: use GPU for generation
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
tgt_dataset,
|
82 |
+
src_dict,
|
83 |
+
tgt_dict=None,
|
84 |
+
backtranslation_fn=None,
|
85 |
+
output_collater=None,
|
86 |
+
cuda=True,
|
87 |
+
**kwargs
|
88 |
+
):
|
89 |
+
self.tgt_dataset = tgt_dataset
|
90 |
+
self.backtranslation_fn = backtranslation_fn
|
91 |
+
self.output_collater = (
|
92 |
+
output_collater if output_collater is not None else tgt_dataset.collater
|
93 |
+
)
|
94 |
+
self.cuda = cuda if torch.cuda.is_available() else False
|
95 |
+
self.src_dict = src_dict
|
96 |
+
self.tgt_dict = tgt_dict
|
97 |
+
|
98 |
+
def __getitem__(self, index):
|
99 |
+
"""
|
100 |
+
Returns a single sample from *tgt_dataset*. Note that backtranslation is
|
101 |
+
not applied in this step; use :func:`collater` instead to backtranslate
|
102 |
+
a batch of samples.
|
103 |
+
"""
|
104 |
+
return self.tgt_dataset[index]
|
105 |
+
|
106 |
+
def __len__(self):
|
107 |
+
return len(self.tgt_dataset)
|
108 |
+
|
109 |
+
def set_backtranslation_fn(self, backtranslation_fn):
|
110 |
+
self.backtranslation_fn = backtranslation_fn
|
111 |
+
|
112 |
+
def collater(self, samples):
|
113 |
+
"""Merge and backtranslate a list of samples to form a mini-batch.
|
114 |
+
|
115 |
+
Using the samples from *tgt_dataset*, load a collated target sample to
|
116 |
+
feed to the backtranslation model. Then take the backtranslation with
|
117 |
+
the best score as the source and the original input as the target.
|
118 |
+
|
119 |
+
Note: we expect *tgt_dataset* to provide a function `collater()` that
|
120 |
+
will collate samples into the format expected by *backtranslation_fn*.
|
121 |
+
After backtranslation, we will feed the new list of samples (i.e., the
|
122 |
+
`(backtranslated source, original source)` pairs) to *output_collater*
|
123 |
+
and return the result.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
samples (List[dict]): samples to backtranslate and collate
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
dict: a mini-batch with keys coming from *output_collater*
|
130 |
+
"""
|
131 |
+
if samples[0].get("is_dummy", False):
|
132 |
+
return samples
|
133 |
+
samples = backtranslate_samples(
|
134 |
+
samples=samples,
|
135 |
+
collate_fn=self.tgt_dataset.collater,
|
136 |
+
generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
|
137 |
+
cuda=self.cuda,
|
138 |
+
)
|
139 |
+
return self.output_collater(samples)
|
140 |
+
|
141 |
+
def num_tokens(self, index):
|
142 |
+
"""Just use the tgt dataset num_tokens"""
|
143 |
+
return self.tgt_dataset.num_tokens(index)
|
144 |
+
|
145 |
+
def ordered_indices(self):
|
146 |
+
"""Just use the tgt dataset ordered_indices"""
|
147 |
+
return self.tgt_dataset.ordered_indices()
|
148 |
+
|
149 |
+
def size(self, index):
|
150 |
+
"""Return an example's size as a float or tuple. This value is used
|
151 |
+
when filtering a dataset with ``--max-positions``.
|
152 |
+
|
153 |
+
Note: we use *tgt_dataset* to approximate the length of the source
|
154 |
+
sentence, since we do not know the actual length until after
|
155 |
+
backtranslation.
|
156 |
+
"""
|
157 |
+
tgt_size = self.tgt_dataset.size(index)[0]
|
158 |
+
return (tgt_size, tgt_size)
|
159 |
+
|
160 |
+
@property
|
161 |
+
def supports_prefetch(self):
|
162 |
+
return getattr(self.tgt_dataset, "supports_prefetch", False)
|
163 |
+
|
164 |
+
def prefetch(self, indices):
|
165 |
+
return self.tgt_dataset.prefetch(indices)
|
modules/voice_conversion/fairseq/data/base_wrapper_dataset.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from torch.utils.data.dataloader import default_collate
|
7 |
+
|
8 |
+
from . import FairseqDataset
|
9 |
+
|
10 |
+
|
11 |
+
class BaseWrapperDataset(FairseqDataset):
|
12 |
+
def __init__(self, dataset):
|
13 |
+
super().__init__()
|
14 |
+
self.dataset = dataset
|
15 |
+
|
16 |
+
def __getitem__(self, index):
|
17 |
+
return self.dataset[index]
|
18 |
+
|
19 |
+
def __len__(self):
|
20 |
+
return len(self.dataset)
|
21 |
+
|
22 |
+
def collater(self, samples):
|
23 |
+
if hasattr(self.dataset, "collater"):
|
24 |
+
return self.dataset.collater(samples)
|
25 |
+
else:
|
26 |
+
return default_collate(samples)
|
27 |
+
|
28 |
+
@property
|
29 |
+
def sizes(self):
|
30 |
+
return self.dataset.sizes
|
31 |
+
|
32 |
+
def num_tokens(self, index):
|
33 |
+
return self.dataset.num_tokens(index)
|
34 |
+
|
35 |
+
def size(self, index):
|
36 |
+
return self.dataset.size(index)
|
37 |
+
|
38 |
+
def ordered_indices(self):
|
39 |
+
return self.dataset.ordered_indices()
|
40 |
+
|
41 |
+
@property
|
42 |
+
def supports_prefetch(self):
|
43 |
+
return getattr(self.dataset, "supports_prefetch", False)
|
44 |
+
|
45 |
+
def attr(self, attr: str, index: int):
|
46 |
+
return self.dataset.attr(attr, index)
|
47 |
+
|
48 |
+
def prefetch(self, indices):
|
49 |
+
self.dataset.prefetch(indices)
|
50 |
+
|
51 |
+
def get_batch_shapes(self):
|
52 |
+
return self.dataset.get_batch_shapes()
|
53 |
+
|
54 |
+
def batch_by_size(
|
55 |
+
self,
|
56 |
+
indices,
|
57 |
+
max_tokens=None,
|
58 |
+
max_sentences=None,
|
59 |
+
required_batch_size_multiple=1,
|
60 |
+
):
|
61 |
+
return self.dataset.batch_by_size(
|
62 |
+
indices,
|
63 |
+
max_tokens=max_tokens,
|
64 |
+
max_sentences=max_sentences,
|
65 |
+
required_batch_size_multiple=required_batch_size_multiple,
|
66 |
+
)
|
67 |
+
|
68 |
+
def filter_indices_by_size(self, indices, max_sizes):
|
69 |
+
return self.dataset.filter_indices_by_size(indices, max_sizes)
|
70 |
+
|
71 |
+
@property
|
72 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
73 |
+
return self.dataset.can_reuse_epoch_itr_across_epochs
|
74 |
+
|
75 |
+
def set_epoch(self, epoch):
|
76 |
+
super().set_epoch(epoch)
|
77 |
+
if hasattr(self.dataset, "set_epoch"):
|
78 |
+
self.dataset.set_epoch(epoch)
|
modules/voice_conversion/fairseq/data/bucket_pad_length_dataset.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from fairseq.data import BaseWrapperDataset
|
9 |
+
from fairseq.data.data_utils import get_buckets, get_bucketed_sizes
|
10 |
+
|
11 |
+
|
12 |
+
class BucketPadLengthDataset(BaseWrapperDataset):
|
13 |
+
"""
|
14 |
+
Bucket and pad item lengths to the nearest bucket size. This can be used to
|
15 |
+
reduce the number of unique batch shapes, which is important on TPUs since
|
16 |
+
each new batch shape requires a recompilation.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
dataset (FairseqDatset): dataset to bucket
|
20 |
+
sizes (List[int]): all item sizes
|
21 |
+
num_buckets (int): number of buckets to create
|
22 |
+
pad_idx (int): padding symbol
|
23 |
+
left_pad (bool): if True, pad on the left; otherwise right pad
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
dataset,
|
29 |
+
sizes,
|
30 |
+
num_buckets,
|
31 |
+
pad_idx,
|
32 |
+
left_pad,
|
33 |
+
tensor_key=None,
|
34 |
+
):
|
35 |
+
super().__init__(dataset)
|
36 |
+
self.pad_idx = pad_idx
|
37 |
+
self.left_pad = left_pad
|
38 |
+
|
39 |
+
assert num_buckets > 0
|
40 |
+
self.buckets = get_buckets(sizes, num_buckets)
|
41 |
+
self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
|
42 |
+
self._tensor_key = tensor_key
|
43 |
+
|
44 |
+
def _set_tensor(self, item, val):
|
45 |
+
if self._tensor_key is None:
|
46 |
+
return val
|
47 |
+
item[self._tensor_key] = val
|
48 |
+
return item
|
49 |
+
|
50 |
+
def _get_tensor(self, item):
|
51 |
+
if self._tensor_key is None:
|
52 |
+
return item
|
53 |
+
return item[self._tensor_key]
|
54 |
+
|
55 |
+
def _pad(self, tensor, bucket_size, dim=-1):
|
56 |
+
num_pad = bucket_size - tensor.size(dim)
|
57 |
+
return F.pad(
|
58 |
+
tensor,
|
59 |
+
(num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
|
60 |
+
value=self.pad_idx,
|
61 |
+
)
|
62 |
+
|
63 |
+
def __getitem__(self, index):
|
64 |
+
item = self.dataset[index]
|
65 |
+
bucket_size = self._bucketed_sizes[index]
|
66 |
+
tensor = self._get_tensor(item)
|
67 |
+
padded = self._pad(tensor, bucket_size)
|
68 |
+
return self._set_tensor(item, padded)
|
69 |
+
|
70 |
+
@property
|
71 |
+
def sizes(self):
|
72 |
+
return self._bucketed_sizes
|
73 |
+
|
74 |
+
def num_tokens(self, index):
|
75 |
+
return self._bucketed_sizes[index]
|
76 |
+
|
77 |
+
def size(self, index):
|
78 |
+
return self._bucketed_sizes[index]
|
modules/voice_conversion/fairseq/data/codedataset.py
ADDED
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch.utils.data
|
16 |
+
|
17 |
+
from . import data_utils
|
18 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
19 |
+
|
20 |
+
F0_FRAME_SPACE = 0.005 # sec
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
class ExpressiveCodeDataConfig(object):
|
27 |
+
def __init__(self, json_path):
|
28 |
+
with open(json_path, "r") as f:
|
29 |
+
self.config = json.load(f)
|
30 |
+
self._manifests = self.config["manifests"]
|
31 |
+
|
32 |
+
@property
|
33 |
+
def manifests(self):
|
34 |
+
return self._manifests
|
35 |
+
|
36 |
+
@property
|
37 |
+
def n_units(self):
|
38 |
+
return self.config["n_units"]
|
39 |
+
|
40 |
+
@property
|
41 |
+
def sampling_rate(self):
|
42 |
+
return self.config["sampling_rate"]
|
43 |
+
|
44 |
+
@property
|
45 |
+
def code_hop_size(self):
|
46 |
+
return self.config["code_hop_size"]
|
47 |
+
|
48 |
+
@property
|
49 |
+
def f0_stats(self):
|
50 |
+
"""pre-computed f0 statistics path"""
|
51 |
+
return self.config.get("f0_stats", None)
|
52 |
+
|
53 |
+
@property
|
54 |
+
def f0_vq_type(self):
|
55 |
+
"""naive or precomp"""
|
56 |
+
return self.config["f0_vq_type"]
|
57 |
+
|
58 |
+
@property
|
59 |
+
def f0_vq_name(self):
|
60 |
+
return self.config["f0_vq_name"]
|
61 |
+
|
62 |
+
def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
|
63 |
+
key = "log" if log else "linear"
|
64 |
+
if norm_mean and norm_std:
|
65 |
+
key += "_mean_std_norm"
|
66 |
+
elif norm_mean:
|
67 |
+
key += "_mean_norm"
|
68 |
+
else:
|
69 |
+
key += "_none_norm"
|
70 |
+
return self.config["f0_vq_naive_quantizer"][key]
|
71 |
+
|
72 |
+
@property
|
73 |
+
def f0_vq_n_units(self):
|
74 |
+
return self.config["f0_vq_n_units"]
|
75 |
+
|
76 |
+
@property
|
77 |
+
def multispkr(self):
|
78 |
+
"""how to parse speaker label from audio path"""
|
79 |
+
return self.config.get("multispkr", None)
|
80 |
+
|
81 |
+
|
82 |
+
def get_f0(audio, rate=16000):
|
83 |
+
try:
|
84 |
+
import amfm_decompy.basic_tools as basic
|
85 |
+
import amfm_decompy.pYAAPT as pYAAPT
|
86 |
+
from librosa.util import normalize
|
87 |
+
except ImportError:
|
88 |
+
raise "Please install amfm_decompy (`pip install AMFM-decompy`) and librosa (`pip install librosa`)."
|
89 |
+
|
90 |
+
assert audio.ndim == 1
|
91 |
+
frame_length = 20.0 # ms
|
92 |
+
to_pad = int(frame_length / 1000 * rate) // 2
|
93 |
+
|
94 |
+
audio = normalize(audio) * 0.95
|
95 |
+
audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
|
96 |
+
audio = basic.SignalObj(audio, rate)
|
97 |
+
pitch = pYAAPT.yaapt(
|
98 |
+
audio,
|
99 |
+
frame_length=frame_length,
|
100 |
+
frame_space=F0_FRAME_SPACE * 1000,
|
101 |
+
nccf_thresh1=0.25,
|
102 |
+
tda_frame_length=25.0,
|
103 |
+
)
|
104 |
+
f0 = pitch.samp_values
|
105 |
+
return f0
|
106 |
+
|
107 |
+
|
108 |
+
def interpolate_f0(f0):
|
109 |
+
try:
|
110 |
+
from scipy.interpolate import interp1d
|
111 |
+
except ImportError:
|
112 |
+
raise "Please install scipy (`pip install scipy`)"
|
113 |
+
|
114 |
+
orig_t = np.arange(f0.shape[0])
|
115 |
+
f0_interp = f0[:]
|
116 |
+
ii = f0_interp != 0
|
117 |
+
if ii.sum() > 1:
|
118 |
+
f0_interp = interp1d(
|
119 |
+
orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
|
120 |
+
)(orig_t)
|
121 |
+
f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
|
122 |
+
return f0_interp
|
123 |
+
|
124 |
+
|
125 |
+
def naive_quantize(x, edges):
|
126 |
+
bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
|
127 |
+
return bin_idx
|
128 |
+
|
129 |
+
|
130 |
+
def load_wav(full_path):
|
131 |
+
try:
|
132 |
+
import soundfile as sf
|
133 |
+
except ImportError:
|
134 |
+
raise "Please install soundfile (`pip install SoundFile`)"
|
135 |
+
data, sampling_rate = sf.read(full_path)
|
136 |
+
return data, sampling_rate
|
137 |
+
|
138 |
+
|
139 |
+
def parse_code(code_str, dictionary, append_eos):
|
140 |
+
code, duration = torch.unique_consecutive(
|
141 |
+
torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
|
142 |
+
)
|
143 |
+
code = " ".join(map(str, code.tolist()))
|
144 |
+
code = dictionary.encode_line(code, append_eos).short()
|
145 |
+
|
146 |
+
if append_eos:
|
147 |
+
duration = torch.cat((duration, duration.new_zeros((1,))), dim=0) # eos
|
148 |
+
duration = duration.short()
|
149 |
+
return code, duration
|
150 |
+
|
151 |
+
|
152 |
+
def parse_manifest(manifest, dictionary):
|
153 |
+
audio_files = []
|
154 |
+
codes = []
|
155 |
+
durations = []
|
156 |
+
speakers = []
|
157 |
+
|
158 |
+
with open(manifest) as info:
|
159 |
+
for line in info.readlines():
|
160 |
+
sample = eval(line.strip())
|
161 |
+
if "cpc_km100" in sample:
|
162 |
+
k = "cpc_km100"
|
163 |
+
elif "hubert_km100" in sample:
|
164 |
+
k = "hubert_km100"
|
165 |
+
elif "phone" in sample:
|
166 |
+
k = "phone"
|
167 |
+
else:
|
168 |
+
assert False, "unknown format"
|
169 |
+
code = sample[k]
|
170 |
+
code, duration = parse_code(code, dictionary, append_eos=True)
|
171 |
+
|
172 |
+
codes.append(code)
|
173 |
+
durations.append(duration)
|
174 |
+
audio_files.append(sample["audio"])
|
175 |
+
speakers.append(sample.get("speaker", None))
|
176 |
+
|
177 |
+
return audio_files, codes, durations, speakers
|
178 |
+
|
179 |
+
|
180 |
+
def parse_speaker(path, method):
|
181 |
+
if type(path) == str:
|
182 |
+
path = Path(path)
|
183 |
+
|
184 |
+
if method == "parent_name":
|
185 |
+
return path.parent.name
|
186 |
+
elif method == "parent_parent_name":
|
187 |
+
return path.parent.parent.name
|
188 |
+
elif method == "_":
|
189 |
+
return path.name.split("_")[0]
|
190 |
+
elif method == "single":
|
191 |
+
return "A"
|
192 |
+
elif callable(method):
|
193 |
+
return method(path)
|
194 |
+
else:
|
195 |
+
raise NotImplementedError()
|
196 |
+
|
197 |
+
|
198 |
+
def get_f0_by_filename(filename, tgt_sampling_rate):
|
199 |
+
audio, sampling_rate = load_wav(filename)
|
200 |
+
if sampling_rate != tgt_sampling_rate:
|
201 |
+
raise ValueError(
|
202 |
+
"{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
|
203 |
+
)
|
204 |
+
|
205 |
+
# compute un-interpolated f0, and use Ann's interp in __getitem__ if set
|
206 |
+
f0 = get_f0(audio, rate=tgt_sampling_rate)
|
207 |
+
f0 = torch.from_numpy(f0.astype(np.float32))
|
208 |
+
return f0
|
209 |
+
|
210 |
+
|
211 |
+
def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
|
212 |
+
code_len = durations.sum()
|
213 |
+
targ_len = int(f0_code_ratio * code_len)
|
214 |
+
diff = f0.size(0) - targ_len
|
215 |
+
assert abs(diff) <= tol, (
|
216 |
+
f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
|
217 |
+
f" > {tol} (dur=\n{durations})"
|
218 |
+
)
|
219 |
+
if diff > 0:
|
220 |
+
f0 = f0[:targ_len]
|
221 |
+
elif diff < 0:
|
222 |
+
f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0)
|
223 |
+
|
224 |
+
f0_offset = 0.0
|
225 |
+
seg_f0s = []
|
226 |
+
for dur in durations:
|
227 |
+
f0_dur = dur.item() * f0_code_ratio
|
228 |
+
seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)]
|
229 |
+
seg_f0 = seg_f0[seg_f0 != 0]
|
230 |
+
if len(seg_f0) == 0:
|
231 |
+
seg_f0 = torch.tensor(0).type(seg_f0.type())
|
232 |
+
else:
|
233 |
+
seg_f0 = seg_f0.mean()
|
234 |
+
seg_f0s.append(seg_f0)
|
235 |
+
f0_offset += f0_dur
|
236 |
+
|
237 |
+
assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
|
238 |
+
return torch.tensor(seg_f0s)
|
239 |
+
|
240 |
+
|
241 |
+
class Paddings(object):
|
242 |
+
def __init__(self, code_val, dur_val=0, f0_val=-2.0):
|
243 |
+
self.code = code_val
|
244 |
+
self.dur = dur_val
|
245 |
+
self.f0 = f0_val
|
246 |
+
|
247 |
+
|
248 |
+
class Shifts(object):
|
249 |
+
def __init__(self, shifts_str, pads):
|
250 |
+
self._shifts = list(map(int, shifts_str.split(",")))
|
251 |
+
assert len(self._shifts) == 2, self._shifts
|
252 |
+
assert all(s >= 0 for s in self._shifts)
|
253 |
+
self.extra_length = max(s for s in self._shifts)
|
254 |
+
self.pads = pads
|
255 |
+
|
256 |
+
@property
|
257 |
+
def dur(self):
|
258 |
+
return self._shifts[0]
|
259 |
+
|
260 |
+
@property
|
261 |
+
def f0(self):
|
262 |
+
return self._shifts[1]
|
263 |
+
|
264 |
+
@staticmethod
|
265 |
+
def shift_one(seq, left_pad_num, right_pad_num, pad):
|
266 |
+
assert seq.ndim == 1
|
267 |
+
bos = seq.new_full((left_pad_num,), pad)
|
268 |
+
eos = seq.new_full((right_pad_num,), pad)
|
269 |
+
seq = torch.cat([bos, seq, eos])
|
270 |
+
mask = torch.ones_like(seq).bool()
|
271 |
+
mask[left_pad_num : len(seq) - right_pad_num] = 0
|
272 |
+
return seq, mask
|
273 |
+
|
274 |
+
def __call__(self, code, dur, f0):
|
275 |
+
if self.extra_length == 0:
|
276 |
+
code_mask = torch.zeros_like(code).bool()
|
277 |
+
dur_mask = torch.zeros_like(dur).bool()
|
278 |
+
f0_mask = torch.zeros_like(f0).bool()
|
279 |
+
return code, code_mask, dur, dur_mask, f0, f0_mask
|
280 |
+
|
281 |
+
code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code)
|
282 |
+
dur, dur_mask = self.shift_one(
|
283 |
+
dur, self.dur, self.extra_length - self.dur, self.pads.dur
|
284 |
+
)
|
285 |
+
f0, f0_mask = self.shift_one(
|
286 |
+
f0, self.f0, self.extra_length - self.f0, self.pads.f0
|
287 |
+
)
|
288 |
+
return code, code_mask, dur, dur_mask, f0, f0_mask
|
289 |
+
|
290 |
+
|
291 |
+
class CodeDataset(FairseqDataset):
|
292 |
+
def __init__(
|
293 |
+
self,
|
294 |
+
manifest,
|
295 |
+
dictionary,
|
296 |
+
dur_dictionary,
|
297 |
+
f0_dictionary,
|
298 |
+
config,
|
299 |
+
discrete_dur,
|
300 |
+
discrete_f0,
|
301 |
+
log_f0,
|
302 |
+
normalize_f0_mean,
|
303 |
+
normalize_f0_std,
|
304 |
+
interpolate_f0,
|
305 |
+
return_filename=False,
|
306 |
+
strip_filename=True,
|
307 |
+
shifts="0,0",
|
308 |
+
return_continuous_f0=False,
|
309 |
+
):
|
310 |
+
random.seed(1234)
|
311 |
+
self.dictionary = dictionary
|
312 |
+
self.dur_dictionary = dur_dictionary
|
313 |
+
self.f0_dictionary = f0_dictionary
|
314 |
+
self.config = config
|
315 |
+
|
316 |
+
# duration config
|
317 |
+
self.discrete_dur = discrete_dur
|
318 |
+
|
319 |
+
# pitch config
|
320 |
+
self.discrete_f0 = discrete_f0
|
321 |
+
self.log_f0 = log_f0
|
322 |
+
self.normalize_f0_mean = normalize_f0_mean
|
323 |
+
self.normalize_f0_std = normalize_f0_std
|
324 |
+
self.interpolate_f0 = interpolate_f0
|
325 |
+
|
326 |
+
self.return_filename = return_filename
|
327 |
+
self.strip_filename = strip_filename
|
328 |
+
self.f0_code_ratio = config.code_hop_size / (
|
329 |
+
config.sampling_rate * F0_FRAME_SPACE
|
330 |
+
)
|
331 |
+
|
332 |
+
# use lazy loading to avoid sharing file handlers across workers
|
333 |
+
self.manifest = manifest
|
334 |
+
self._codes = None
|
335 |
+
self._durs = None
|
336 |
+
self._f0s = None
|
337 |
+
with open(f"{manifest}.leng.txt", "r") as f:
|
338 |
+
lengs = [int(line.rstrip()) for line in f]
|
339 |
+
edges = np.cumsum([0] + lengs)
|
340 |
+
self.starts, self.ends = edges[:-1], edges[1:]
|
341 |
+
with open(f"{manifest}.path.txt", "r") as f:
|
342 |
+
self.file_names = [line.rstrip() for line in f]
|
343 |
+
logger.info(f"num entries: {len(self.starts)}")
|
344 |
+
|
345 |
+
if os.path.exists(f"{manifest}.f0_stat.pt"):
|
346 |
+
self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
|
347 |
+
elif config.f0_stats:
|
348 |
+
self.f0_stats = torch.load(config.f0_stats)
|
349 |
+
|
350 |
+
self.multispkr = config.multispkr
|
351 |
+
if config.multispkr:
|
352 |
+
with open(f"{manifest}.speaker.txt", "r") as f:
|
353 |
+
self.spkrs = [line.rstrip() for line in f]
|
354 |
+
self.id_to_spkr = sorted(self.spkrs)
|
355 |
+
self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}
|
356 |
+
|
357 |
+
self.pads = Paddings(
|
358 |
+
dictionary.pad(),
|
359 |
+
0, # use 0 for duration padding
|
360 |
+
f0_dictionary.pad() if discrete_f0 else -5.0,
|
361 |
+
)
|
362 |
+
self.shifts = Shifts(shifts, pads=self.pads)
|
363 |
+
self.return_continuous_f0 = return_continuous_f0
|
364 |
+
|
365 |
+
def get_data_handlers(self):
|
366 |
+
logging.info(f"loading data for {self.manifest}")
|
367 |
+
self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
|
368 |
+
self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")
|
369 |
+
|
370 |
+
if self.discrete_f0:
|
371 |
+
if self.config.f0_vq_type == "precomp":
|
372 |
+
self._f0s = np.load(
|
373 |
+
f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
|
374 |
+
)
|
375 |
+
elif self.config.f0_vq_type == "naive":
|
376 |
+
self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
|
377 |
+
quantizers_path = self.config.get_f0_vq_naive_quantizer(
|
378 |
+
self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
|
379 |
+
)
|
380 |
+
quantizers = torch.load(quantizers_path)
|
381 |
+
n_units = self.config.f0_vq_n_units
|
382 |
+
self._f0_quantizer = torch.from_numpy(quantizers[n_units])
|
383 |
+
else:
|
384 |
+
raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
|
385 |
+
else:
|
386 |
+
self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
|
387 |
+
|
388 |
+
def preprocess_f0(self, f0, stats):
|
389 |
+
"""
|
390 |
+
1. interpolate
|
391 |
+
2. log transform (keep unvoiced frame 0)
|
392 |
+
"""
|
393 |
+
# TODO: change this to be dependent on config for naive quantizer
|
394 |
+
f0 = f0.clone()
|
395 |
+
if self.interpolate_f0:
|
396 |
+
f0 = interpolate_f0(f0)
|
397 |
+
|
398 |
+
mask = f0 != 0 # only process voiced frames
|
399 |
+
if self.log_f0:
|
400 |
+
f0[mask] = f0[mask].log()
|
401 |
+
if self.normalize_f0_mean:
|
402 |
+
mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
|
403 |
+
f0[mask] = f0[mask] - mean
|
404 |
+
if self.normalize_f0_std:
|
405 |
+
std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
|
406 |
+
f0[mask] = f0[mask] / std
|
407 |
+
return f0
|
408 |
+
|
409 |
+
def _get_raw_item(self, index):
|
410 |
+
start, end = self.starts[index], self.ends[index]
|
411 |
+
if self._codes is None:
|
412 |
+
self.get_data_handlers()
|
413 |
+
code = torch.from_numpy(np.array(self._codes[start:end])).long()
|
414 |
+
dur = torch.from_numpy(np.array(self._durs[start:end]))
|
415 |
+
f0 = torch.from_numpy(np.array(self._f0s[start:end]))
|
416 |
+
return code, dur, f0
|
417 |
+
|
418 |
+
def __getitem__(self, index):
|
419 |
+
code, dur, f0 = self._get_raw_item(index)
|
420 |
+
code = torch.cat([code.new([self.dictionary.bos()]), code])
|
421 |
+
|
422 |
+
# use 0 for eos and bos
|
423 |
+
dur = torch.cat([dur.new([0]), dur])
|
424 |
+
if self.discrete_dur:
|
425 |
+
dur = self.dur_dictionary.encode_line(
|
426 |
+
" ".join(map(str, dur.tolist())), append_eos=False
|
427 |
+
).long()
|
428 |
+
else:
|
429 |
+
dur = dur.float()
|
430 |
+
|
431 |
+
# TODO: find a more elegant approach
|
432 |
+
raw_f0 = None
|
433 |
+
if self.discrete_f0:
|
434 |
+
if self.config.f0_vq_type == "precomp":
|
435 |
+
f0 = self.f0_dictionary.encode_line(
|
436 |
+
" ".join(map(str, f0.tolist())), append_eos=False
|
437 |
+
).long()
|
438 |
+
else:
|
439 |
+
f0 = f0.float()
|
440 |
+
f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
|
441 |
+
if self.return_continuous_f0:
|
442 |
+
raw_f0 = f0
|
443 |
+
raw_f0 = torch.cat([raw_f0.new([self.f0_dictionary.bos()]), raw_f0])
|
444 |
+
f0 = naive_quantize(f0, self._f0_quantizer)
|
445 |
+
f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
|
446 |
+
else:
|
447 |
+
f0 = f0.float()
|
448 |
+
if self.multispkr:
|
449 |
+
f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
|
450 |
+
else:
|
451 |
+
f0 = self.preprocess_f0(f0, self.f0_stats)
|
452 |
+
f0 = torch.cat([f0.new([0]), f0])
|
453 |
+
|
454 |
+
if raw_f0 is not None:
|
455 |
+
*_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
|
456 |
+
else:
|
457 |
+
raw_f0_mask = None
|
458 |
+
|
459 |
+
code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
|
460 |
+
if raw_f0_mask is not None:
|
461 |
+
assert (raw_f0_mask == f0_mask).all()
|
462 |
+
|
463 |
+
# is a padded frame if either input or output is padded
|
464 |
+
feats = {
|
465 |
+
"source": code[:-1],
|
466 |
+
"target": code[1:],
|
467 |
+
"mask": code_mask[1:].logical_or(code_mask[:-1]),
|
468 |
+
"dur_source": dur[:-1],
|
469 |
+
"dur_target": dur[1:],
|
470 |
+
"dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
|
471 |
+
"f0_source": f0[:-1],
|
472 |
+
"f0_target": f0[1:],
|
473 |
+
"f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
|
474 |
+
}
|
475 |
+
|
476 |
+
if raw_f0 is not None:
|
477 |
+
feats["raw_f0"] = raw_f0[1:]
|
478 |
+
|
479 |
+
if self.return_filename:
|
480 |
+
fname = self.file_names[index]
|
481 |
+
feats["filename"] = (
|
482 |
+
fname if not self.strip_filename else Path(fname).with_suffix("").name
|
483 |
+
)
|
484 |
+
return feats
|
485 |
+
|
486 |
+
def __len__(self):
|
487 |
+
return len(self.starts)
|
488 |
+
|
489 |
+
def size(self, index):
|
490 |
+
return self.ends[index] - self.starts[index] + self.shifts.extra_length
|
491 |
+
|
492 |
+
def num_tokens(self, index):
|
493 |
+
return self.size(index)
|
494 |
+
|
495 |
+
def collater(self, samples):
|
496 |
+
pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
|
497 |
+
if len(samples) == 0:
|
498 |
+
return {}
|
499 |
+
|
500 |
+
src_tokens = data_utils.collate_tokens(
|
501 |
+
[s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
|
502 |
+
)
|
503 |
+
|
504 |
+
tgt_tokens = data_utils.collate_tokens(
|
505 |
+
[s["target"] for s in samples],
|
506 |
+
pad_idx=pad_idx,
|
507 |
+
eos_idx=pad_idx, # appending padding, eos is there already
|
508 |
+
left_pad=False,
|
509 |
+
)
|
510 |
+
|
511 |
+
src_durs, tgt_durs = [
|
512 |
+
data_utils.collate_tokens(
|
513 |
+
[s[k] for s in samples],
|
514 |
+
pad_idx=self.pads.dur,
|
515 |
+
eos_idx=self.pads.dur,
|
516 |
+
left_pad=False,
|
517 |
+
)
|
518 |
+
for k in ["dur_source", "dur_target"]
|
519 |
+
]
|
520 |
+
|
521 |
+
src_f0s, tgt_f0s = [
|
522 |
+
data_utils.collate_tokens(
|
523 |
+
[s[k] for s in samples],
|
524 |
+
pad_idx=self.pads.f0,
|
525 |
+
eos_idx=self.pads.f0,
|
526 |
+
left_pad=False,
|
527 |
+
)
|
528 |
+
for k in ["f0_source", "f0_target"]
|
529 |
+
]
|
530 |
+
|
531 |
+
mask, dur_mask, f0_mask = [
|
532 |
+
data_utils.collate_tokens(
|
533 |
+
[s[k] for s in samples],
|
534 |
+
pad_idx=1,
|
535 |
+
eos_idx=1,
|
536 |
+
left_pad=False,
|
537 |
+
)
|
538 |
+
for k in ["mask", "dur_mask", "f0_mask"]
|
539 |
+
]
|
540 |
+
|
541 |
+
src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
|
542 |
+
n_tokens = sum(len(s["source"]) for s in samples)
|
543 |
+
|
544 |
+
result = {
|
545 |
+
"nsentences": len(samples),
|
546 |
+
"ntokens": n_tokens,
|
547 |
+
"net_input": {
|
548 |
+
"src_tokens": src_tokens,
|
549 |
+
"src_lengths": src_lengths,
|
550 |
+
"dur_src": src_durs,
|
551 |
+
"f0_src": src_f0s,
|
552 |
+
},
|
553 |
+
"target": tgt_tokens,
|
554 |
+
"dur_target": tgt_durs,
|
555 |
+
"f0_target": tgt_f0s,
|
556 |
+
"mask": mask,
|
557 |
+
"dur_mask": dur_mask,
|
558 |
+
"f0_mask": f0_mask,
|
559 |
+
}
|
560 |
+
|
561 |
+
if "filename" in samples[0]:
|
562 |
+
result["filename"] = [s["filename"] for s in samples]
|
563 |
+
|
564 |
+
# TODO: remove this hack into the inference dataset
|
565 |
+
if "prefix" in samples[0]:
|
566 |
+
result["prefix"] = [s["prefix"] for s in samples]
|
567 |
+
|
568 |
+
if "raw_f0" in samples[0]:
|
569 |
+
raw_f0s = data_utils.collate_tokens(
|
570 |
+
[s["raw_f0"] for s in samples],
|
571 |
+
pad_idx=self.pads.f0,
|
572 |
+
eos_idx=self.pads.f0,
|
573 |
+
left_pad=False,
|
574 |
+
)
|
575 |
+
result["raw_f0"] = raw_f0s
|
576 |
+
return result
|
modules/voice_conversion/fairseq/data/colorize_dataset.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from . import BaseWrapperDataset
|
9 |
+
|
10 |
+
|
11 |
+
class ColorizeDataset(BaseWrapperDataset):
|
12 |
+
"""Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""
|
13 |
+
|
14 |
+
def __init__(self, dataset, color_getter):
|
15 |
+
super().__init__(dataset)
|
16 |
+
self.color_getter = color_getter
|
17 |
+
|
18 |
+
def collater(self, samples):
|
19 |
+
base_collate = super().collater(samples)
|
20 |
+
if len(base_collate) > 0:
|
21 |
+
base_collate["net_input"]["colors"] = torch.tensor(
|
22 |
+
list(self.color_getter(self.dataset, s["id"]) for s in samples),
|
23 |
+
dtype=torch.long,
|
24 |
+
)
|
25 |
+
return base_collate
|
modules/voice_conversion/fairseq/data/concat_dataset.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import bisect
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from torch.utils.data.dataloader import default_collate
|
10 |
+
|
11 |
+
from . import FairseqDataset
|
12 |
+
|
13 |
+
|
14 |
+
class ConcatDataset(FairseqDataset):
|
15 |
+
@staticmethod
|
16 |
+
def cumsum(sequence, sample_ratios):
|
17 |
+
r, s = [], 0
|
18 |
+
for e, ratio in zip(sequence, sample_ratios):
|
19 |
+
curr_len = int(ratio * len(e))
|
20 |
+
r.append(curr_len + s)
|
21 |
+
s += curr_len
|
22 |
+
return r
|
23 |
+
|
24 |
+
def __init__(self, datasets, sample_ratios=1):
|
25 |
+
super(ConcatDataset, self).__init__()
|
26 |
+
assert len(datasets) > 0, "datasets should not be an empty iterable"
|
27 |
+
self.datasets = list(datasets)
|
28 |
+
if isinstance(sample_ratios, int):
|
29 |
+
sample_ratios = [sample_ratios] * len(self.datasets)
|
30 |
+
self.sample_ratios = sample_ratios
|
31 |
+
self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
|
32 |
+
self.real_sizes = [len(d) for d in self.datasets]
|
33 |
+
|
34 |
+
def __len__(self):
|
35 |
+
return self.cumulative_sizes[-1]
|
36 |
+
|
37 |
+
def __getitem__(self, idx):
|
38 |
+
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
39 |
+
return self.datasets[dataset_idx][sample_idx]
|
40 |
+
|
41 |
+
def _get_dataset_and_sample_index(self, idx: int):
|
42 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
43 |
+
if dataset_idx == 0:
|
44 |
+
sample_idx = idx
|
45 |
+
else:
|
46 |
+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
47 |
+
sample_idx = sample_idx % self.real_sizes[dataset_idx]
|
48 |
+
return dataset_idx, sample_idx
|
49 |
+
|
50 |
+
def collater(self, samples, **extra_args):
|
51 |
+
# For now only supports datasets with same underlying collater implementations
|
52 |
+
if hasattr(self.datasets[0], "collater"):
|
53 |
+
return self.datasets[0].collater(samples, **extra_args)
|
54 |
+
else:
|
55 |
+
return default_collate(samples, **extra_args)
|
56 |
+
|
57 |
+
def size(self, idx: int):
|
58 |
+
"""
|
59 |
+
Return an example's size as a float or tuple.
|
60 |
+
"""
|
61 |
+
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
62 |
+
return self.datasets[dataset_idx].size(sample_idx)
|
63 |
+
|
64 |
+
def num_tokens(self, index: int):
|
65 |
+
return np.max(self.size(index))
|
66 |
+
|
67 |
+
def attr(self, attr: str, index: int):
|
68 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
|
69 |
+
return getattr(self.datasets[dataset_idx], attr, None)
|
70 |
+
|
71 |
+
@property
|
72 |
+
def sizes(self):
|
73 |
+
_dataset_sizes = []
|
74 |
+
for ds, sr in zip(self.datasets, self.sample_ratios):
|
75 |
+
if isinstance(ds.sizes, np.ndarray):
|
76 |
+
_dataset_sizes.append(np.tile(ds.sizes, sr))
|
77 |
+
else:
|
78 |
+
# Only support underlying dataset with single size array.
|
79 |
+
assert isinstance(ds.sizes, list)
|
80 |
+
_dataset_sizes.append(np.tile(ds.sizes[0], sr))
|
81 |
+
return np.concatenate(_dataset_sizes)
|
82 |
+
|
83 |
+
@property
|
84 |
+
def supports_prefetch(self):
|
85 |
+
return all(d.supports_prefetch for d in self.datasets)
|
86 |
+
|
87 |
+
def ordered_indices(self):
|
88 |
+
"""
|
89 |
+
Returns indices sorted by length. So less padding is needed.
|
90 |
+
"""
|
91 |
+
if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
|
92 |
+
# special handling for concatenating lang_pair_datasets
|
93 |
+
indices = np.arange(len(self))
|
94 |
+
sizes = self.sizes
|
95 |
+
tgt_sizes = (
|
96 |
+
sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
|
97 |
+
)
|
98 |
+
src_sizes = (
|
99 |
+
sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
|
100 |
+
)
|
101 |
+
# sort by target length, then source length
|
102 |
+
if tgt_sizes is not None:
|
103 |
+
indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
|
104 |
+
return indices[np.argsort(src_sizes[indices], kind="mergesort")]
|
105 |
+
else:
|
106 |
+
return np.argsort(self.sizes)
|
107 |
+
|
108 |
+
def prefetch(self, indices):
|
109 |
+
frm = 0
|
110 |
+
for to, ds in zip(self.cumulative_sizes, self.datasets):
|
111 |
+
real_size = len(ds)
|
112 |
+
if getattr(ds, "supports_prefetch", False):
|
113 |
+
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
|
114 |
+
frm = to
|
115 |
+
|
116 |
+
@property
|
117 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
118 |
+
return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
|
119 |
+
|
120 |
+
def set_epoch(self, epoch):
|
121 |
+
super().set_epoch(epoch)
|
122 |
+
for ds in self.datasets:
|
123 |
+
if hasattr(ds, "set_epoch"):
|
124 |
+
ds.set_epoch(epoch)
|