File size: 4,181 Bytes
b5e825c
 
91394e0
b5e825c
 
91394e0
b5e825c
 
91394e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a23964
91394e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780954b
f1b8d35
b5e825c
91394e0
 
 
 
 
 
 
 
 
 
780954b
91394e0
 
780954b
91394e0
 
 
b5e825c
 
 
91394e0
 
 
 
f1b8d35
91394e0
 
 
 
 
 
 
b5e825c
91394e0
b5e825c
 
 
 
91394e0
1a42cf5
91394e0
 
 
1a42cf5
91394e0
 
05779d3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from __future__ import annotations

import time
from pathlib import Path

import librosa
import soundfile as sf
import torch

from modules.asr import get_asr_model
from modules.llm import get_llm_model
from modules.svs import get_svs_model
from evaluation.svs_eval import load_evaluators, run_evaluation
from modules.melody import MelodyController
from modules.utils.text_normalize import clean_llm_output


class SingingDialoguePipeline:
    """Spoken-input to sung-reply pipeline.

    Chains three model stages built from a config dict:

    1. ASR transcribes the input audio to text.
    2. An LLM generates a reply. Its system prompt is extended with melody
       constraints from a ``MelodyController``.
    3. SVS synthesizes singing audio from a score that the same
       ``MelodyController`` builds from the cleaned reply text.
    """

    def __init__(self, config: dict):
        """Build every pipeline stage from ``config``.

        Args:
            config: Pipeline settings.

                Required keys: ``cache_dir``, ``asr_model``,
                ``llm_model``, ``svs_model``, ``melody_source``.

                Optional keys:

                - ``device``: if absent, ``"cuda"`` is used when
                  ``torch.cuda.is_available()``, else ``"cpu"``.
                - ``max_sentences``: default 2.
                - ``track_latency``: default False.
                - ``evaluators``: a dict whose ``"svs"`` entry is passed
                  to ``load_evaluators``. Missing/null values are treated
                  as empty.
        """
        if "device" in config:
            self.device = config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cache_dir = config["cache_dir"]

        # Reuse the setters so model construction lives in one place; they
        # rely on self.device / self.cache_dir, which are set above.
        self.set_asr_model(config["asr_model"])
        self.set_llm_model(config["llm_model"])
        self.set_svs_model(config["svs_model"])
        self.set_melody_controller(config["melody_source"])

        self.max_sentences = config.get("max_sentences", 2)
        self.track_latency = config.get("track_latency", False)

        # `or` guards against keys that are present but null (e.g. an empty
        # YAML mapping), which previously raised AttributeError.
        evaluator_config = config.get("evaluators") or {}
        self.evaluators = load_evaluators(evaluator_config.get("svs") or [])

    def set_asr_model(self, asr_model: str):
        """Replace the ASR stage with the model named ``asr_model``."""
        self.asr = get_asr_model(
            asr_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_llm_model(self, llm_model: str):
        """Replace the LLM stage with the model named ``llm_model``."""
        self.llm = get_llm_model(
            llm_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_svs_model(self, svs_model: str):
        """Replace the SVS stage with the model named ``svs_model``."""
        self.svs = get_svs_model(
            svs_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_melody_controller(self, melody_source: str):
        """Replace the melody controller with one built from ``melody_source``."""
        self.melody_controller = MelodyController(melody_source, self.cache_dir)

    def run(
        self,
        audio_path,
        language,
        system_prompt,
        speaker,
        output_audio_path: Path | str | None = None,
    ):
        """Turn one spoken utterance into a sung reply.

        Args:
            audio_path: Input speech file. It is loaded via
                ``librosa.load`` with ``sr=16000``.
            language: Language passed to ASR, text cleaning, scoring and SVS.
            system_prompt: Base LLM system prompt. Melody constraints are
                appended to it.
            speaker: Speaker passed through to the SVS model.
            output_audio_path: If given (and truthy), the synthesized audio
                is also written there. Parent directories are created as
                needed.

        Returns:
            A dict with these keys:

            - ``"asr_text"``: the ASR result.
            - ``"llm_text"``: the cleaned LLM reply.
            - ``"svs_audio"``: a ``(sample_rate, audio)`` tuple.
            - ``"output_audio_path"``: present only when a path was given.
            - ``"metrics"``: present only when ``track_latency`` is set.
              It maps ``asr_latency``, ``llm_latency`` and ``svs_latency``
              to durations in seconds.
        """
        # Timing is cheap, so always measure; only report when requested.
        # perf_counter is monotonic and high-resolution, unlike time.time().
        timings = {}

        # ASR latency deliberately includes decoding/resampling the input.
        start = time.perf_counter()
        audio_array, audio_sample_rate = librosa.load(audio_path, sr=16000)
        asr_result = self.asr.transcribe(
            audio_array, audio_sample_rate=audio_sample_rate, language=language
        )
        timings["asr_latency"] = time.perf_counter() - start

        melody_prompt = self.melody_controller.get_melody_constraints(
            max_num_phrases=self.max_sentences
        )

        start = time.perf_counter()
        output = self.llm.generate(asr_result, system_prompt + melody_prompt)
        timings["llm_latency"] = time.perf_counter() - start

        llm_response = clean_llm_output(
            output, language=language, max_sentences=self.max_sentences
        )
        score = self.melody_controller.generate_score(llm_response, language)

        start = time.perf_counter()
        singing_audio, sample_rate = self.svs.synthesize(
            score, language=language, speaker=speaker
        )
        timings["svs_latency"] = time.perf_counter() - start

        results = {
            "asr_text": asr_result,
            "llm_text": llm_response,
            "svs_audio": (sample_rate, singing_audio),
        }
        if output_audio_path:
            Path(output_audio_path).parent.mkdir(parents=True, exist_ok=True)
            sf.write(output_audio_path, singing_audio, sample_rate)
            results["output_audio_path"] = output_audio_path
        if self.track_latency:
            results["metrics"] = timings
        return results

    def evaluate(self, audio_path):
        """Run the evaluators loaded at construction on ``audio_path``.

        Returns whatever ``run_evaluation`` returns.
        """
        return run_evaluation(audio_path, self.evaluators)