lpw committed
Commit 505f510
Parent(s): 0b08851

Create audio_to_audio.py

Files changed (1): audio_to_audio.py (+143, -0)

audio_to_audio.py (new file):
import json
import os
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch

# from app.pipelines import Pipeline
from app.pipelines.utils import ARG_OVERRIDES_MAP
from fairseq import hub_utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.speech_to_speech.hub_interface import S2SHubInterface
from fairseq.models.speech_to_text.hub_interface import S2THubInterface
from fairseq.models.text_to_speech import CodeHiFiGANVocoder
from fairseq.models.text_to_speech.hub_interface import (
    TTSHubInterface,
    VocoderHubInterface,
)
from huggingface_hub import snapshot_download

class SpeechToSpeechPipeline:
    def __init__(self, model_id: str):
        # Model-specific overrides. TODO: update on the checkpoint side in the future.
        arg_overrides = ARG_OVERRIDES_MAP.get(model_id, {})
        arg_overrides["config_yaml"] = "config.yaml"  # common override
        models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
            model_id,
            arg_overrides=arg_overrides,
            cache_dir=os.getenv("HUGGINGFACE_HUB_CACHE"),
        )
        self.cfg = cfg
        self.model = models[0].cpu()
        self.model.eval()
        self.task = task

        self.sampling_rate = getattr(self.task, "sr", None) or 16_000

        # Checkpoints that prepend a target-language tag namespace their hub
        # settings with a "<tgt_lang>_" prefix.
        tgt_lang = self.task.data_cfg.hub.get("tgt_lang", None)
        pfx = f"{tgt_lang}_" if self.task.data_cfg.prepend_tgt_lang_tag else ""

        generation_args = self.task.data_cfg.hub.get(f"{pfx}generation_args", None)
        if generation_args is not None:
            for key in generation_args:
                setattr(cfg.generation, key, generation_args[key])
        self.generator = task.build_generator([self.model], cfg.generation)

        tts_model_id = self.task.data_cfg.hub.get(f"{pfx}tts_model_id", None)
        self.unit_vocoder = self.task.data_cfg.hub.get(f"{pfx}unit_vocoder", None)
        self.tts_model, self.tts_task, self.tts_generator = None, None, None
        if tts_model_id is not None:
            _id = tts_model_id.split(":")[-1]
            cache_dir = os.getenv("HUGGINGFACE_HUB_CACHE")
            if self.unit_vocoder is not None:
                # Discrete-unit output: fetch a CodeHiFiGAN vocoder checkpoint
                # from the Hub to turn predicted units into a waveform.
                library_name = "fairseq"
                cache_dir = (
                    cache_dir or (Path.home() / ".cache" / library_name).as_posix()
                )
                cache_dir = snapshot_download(
                    f"facebook/{_id}", cache_dir=cache_dir, library_name=library_name
                )

                x = hub_utils.from_pretrained(
                    cache_dir,
                    "model.pt",
                    ".",
                    archive_map=CodeHiFiGANVocoder.hub_models(),
                    config_yaml="config.json",
                    fp16=False,
                    is_vocoder=True,
                )

                with open(f"{x['args']['data']}/config.json") as f:
                    vocoder_cfg = json.load(f)
                assert (
                    len(x["args"]["model_path"]) == 1
                ), "Too many vocoder models in the input"

                vocoder = CodeHiFiGANVocoder(x["args"]["model_path"][0], vocoder_cfg)
                self.tts_model = VocoderHubInterface(vocoder_cfg, vocoder)

            else:
                # Text output: load a full TTS model (with a Griffin-Lim
                # vocoder) to synthesize speech from the translated text.
                (
                    tts_models,
                    tts_cfg,
                    self.tts_task,
                ) = load_model_ensemble_and_task_from_hf_hub(
                    f"facebook/{_id}",
                    arg_overrides={"vocoder": "griffin_lim", "fp16": False},
                    cache_dir=cache_dir,
                )
                self.tts_model = tts_models[0].cpu()
                self.tts_model.eval()
                tts_cfg["task"].cpu = True
                TTSHubInterface.update_cfg_with_data_cfg(
                    tts_cfg, self.tts_task.data_cfg
                )
                self.tts_generator = self.tts_task.build_generator(
                    [self.tts_model], tts_cfg
                )

    def __call__(self, inputs: np.ndarray) -> Tuple[np.ndarray, int, List[str]]:
        """
        Args:
            inputs (:obj:`np.ndarray`):
                The raw waveform of the audio received, sampled at
                `self.sampling_rate` by default. The shape of this array is
                `(T,)`, where `T` is the time axis.
        Return:
            A :obj:`tuple` containing:
            - :obj:`np.ndarray`:
                The output waveform; the shape of the array must be `(C', T')`.
            - :obj:`int`: the sampling rate as an int, in Hz.
            - :obj:`List[str]`: the annotation for each output channel.
                This can be the names of the instruments for audio source
                separation, or some annotation for speech enhancement.
                The length must be `C'`.
        """
        _inputs = torch.from_numpy(inputs).unsqueeze(0)
        sample, text = None, None
        if self.cfg.task._name in ["speech_to_text", "speech_to_text_sharded"]:
            sample = S2THubInterface.get_model_input(self.task, _inputs)
            text = S2THubInterface.get_prediction(
                self.task, self.model, self.generator, sample
            )
        elif self.cfg.task._name in ["speech_to_speech"]:
            s2s_hub_interface = S2SHubInterface(self.cfg, self.task, self.model)
            sample = s2s_hub_interface.get_model_input(self.task, _inputs)
            text = S2SHubInterface.get_prediction(
                self.task, self.model, self.generator, sample
            )

        wav, sr = np.zeros((0,)), self.sampling_rate
        if self.unit_vocoder is not None:
            # Discrete units -> waveform via the unit vocoder; no text to report.
            tts_sample = self.tts_model.get_model_input(text)
            wav, sr = self.tts_model.get_prediction(tts_sample)
            text = ""
        else:
            # Translated text -> waveform via the TTS model.
            tts_sample = TTSHubInterface.get_model_input(self.tts_task, text)
            wav, sr = TTSHubInterface.get_prediction(
                self.tts_task, self.tts_model, self.tts_generator, tts_sample
            )

        return wav.unsqueeze(0).numpy(), sr, [text]
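
For reference, a minimal sketch of how this pipeline might be driven, not part of the commit. It assumes a compatible fairseq speech-to-speech checkpoint on the Hub (the model id below is illustrative) and uses soundfile to load a mono waveform at the model's sampling rate; both the model id and the file names are assumptions.

# Usage sketch (assumptions: "facebook/xm_transformer_s2ut_hk-en" is a
# compatible checkpoint, and input.wav is a mono recording at the model's
# expected sampling rate; neither is part of this commit).
import soundfile as sf

from audio_to_audio import SpeechToSpeechPipeline

pipeline = SpeechToSpeechPipeline("facebook/xm_transformer_s2ut_hk-en")

waveform, sr = sf.read("input.wav", dtype="float32")  # shape (T,)
assert sr == pipeline.sampling_rate  # resample beforehand if this fails

out_wav, out_sr, labels = pipeline(waveform)
print(out_wav.shape, out_sr, labels)  # (C', T'), rate in Hz, one label per channel
sf.write("output.wav", out_wav.T, out_sr)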