import io
import re

import numpy as np
import requests
import torch
import torchaudio
from pydub import AudioSegment
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from uroman import uroman

from src.reduce_noise import smooth_and_reduce_noise, model_remove_noise, model, df_state
# from src.pynote_speaker_embedding import create_speaker_embedding
from src.speechbrain_speaker_embedding import create_speaker_embedding

from datasets import load_dataset

# Precomputed speaker x-vector embeddings (SpeechBrain) for Vietnamese voices.
dataset = load_dataset("truong-xuan-linh/vi-xvector-speechbrain",
                       download_mode="force_redownload",
                       verification_mode="no_checks",
                       cache_dir="temp/",
                       revision="5ea5e4345258333cbc6d1dd2544f6c658e66a634")
dataset = dataset["train"].to_list()

# Map speaker_id -> precomputed x-vector speaker embedding.
dataset_dict = {rc["speaker_id"]: rc["embedding"] for rc in dataset}

# HiFi-GAN vocoder that converts predicted mel-spectrograms into waveforms.
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

def remove_special_characters(sentence):
    # Keep only Latin letters (including Vietnamese diacritics in U+00C0-U+1EF9),
    # whitespace, commas and periods; everything else is replaced by a space.
    return re.sub(r'[^a-zA-Z\s,.\u00C0-\u1EF9]', ' ', sentence)

from scipy.signal import butter, lfilter

def butter_bandpass(lowcut, highcut, fs, order=5):
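    # Design Butterworth band-pass filter coefficients for the given low/high cutoff band.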
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
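    # Apply the band-pass filter; currently only referenced by a commented-out step in Model.inference.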
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = lfilter(b, a, data)
    return y

def korean_splitter(string):
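    # Return all contiguous Hangul (Korean) substrings found in the input string.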
    pattern = re.compile('[가-힣]+')
    matches = pattern.findall(string)
    return matches

def uroman_normalization(string):
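    # Replace any Hangul substrings with their uroman romanization, leaving the rest of the text intact.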
    korean_inputs = korean_splitter(string)
    for korean_input in korean_inputs:
        korean_roman = uroman(korean_input)
        string = string.replace(korean_input, korean_roman)
    return string

class Model:

    def __init__(self, model_name):
        self.model_name = model_name
        self.processor = SpeechT5Processor.from_pretrained(model_name)
        self.model = SpeechT5ForTextToSpeech.from_pretrained(model_name)
        # self.model.generate = partial(self.model.generate, use_cache=True)

        self.model.eval()
        # Default speaker embedding; voiceclone checkpoints overwrite it per request in inference().
        if model_name in ("truong-xuan-linh/speecht5-vietnamese-commonvoice", "truong-xuan-linh/speecht5-irmvivoice"):
            self.speaker_embeddings = torch.zeros((1, 512))  # or load xvectors from a file
        else:
            self.speaker_embeddings = torch.ones((1, 512))  # or load xvectors from a file
            
    def inference(self, text, speaker_id=None, speaker_url=""):
        # Synthesize `text` to a waveform. For voice-clone checkpoints, the speaker embedding is
        # either looked up by speaker_id or computed from a reference WAV downloaded from speaker_url.
        # if self.model_name == "truong-xuan-linh/speecht5-vietnamese-voiceclone-v2":
        #     # self.speaker_embeddings = torch.tensor(dataset_dict_v2[speaker_id])
        #     wavform, _ = torchaudio.load(speaker_id)
        #     self.speaker_embeddings = create_speaker_embedding(wavform)[0]

        if "voiceclone" in self.model_name:
            if not speaker_url:
                # Use a precomputed x-vector looked up by speaker_id.
                self.speaker_embeddings = torch.tensor(dataset_dict[speaker_id])
            else:
                # Download a reference WAV, convert it to 16 kHz mono 16-bit,
                # and compute a fresh speaker embedding from it.
                response = requests.get(speaker_url)
                audio_stream = io.BytesIO(response.content)
                audio_segment = AudioSegment.from_file(audio_stream, format="wav")
                audio_segment = audio_segment.set_channels(1)
                audio_segment = audio_segment.set_frame_rate(16000)
                audio_segment = audio_segment.set_sample_width(2)
                wavform, _ = torchaudio.load(audio_segment.export())
                self.speaker_embeddings = create_speaker_embedding(wavform)[0]
            # self.speaker_embeddings = create_speaker_embedding(speaker_id)[0]
            # wavform, _ = torchaudio.load("voices/kcbn1.wav")
            # self.speaker_embeddings = create_speaker_embedding(wavform)[0]
            # wavform, _ = torchaudio.load(wav_file)
            # self.speaker_embeddings = create_speaker_embedding(wavform)[0]
        
            
        with torch.no_grad():
            full_speech = []
            # Split the input on sentence-final punctuation and synthesize each chunk separately.
            separators = r";|\.|!|\?|\n"
            text = uroman_normalization(text)
            # Replace spaces with the SentencePiece word-boundary marker "▁" before tokenization.
            text = text.replace(" ", "▁")
            split_texts = re.split(separators, text)

            for split_text in split_texts:
                if split_text != "▁":
                    # split_text = remove_special_characters(" ," + split_text) + " ,"
                    split_text = split_text.lower() + "▁"
                    print(split_text)
                    inputs = self.processor.tokenizer(text=split_text, return_tensors="pt")
                    speech = self.model.generate_speech(inputs["input_ids"], threshold=0.5,
                                                        speaker_embeddings=self.speaker_embeddings,
                                                        vocoder=vocoder)
                    full_speech.append(speech.numpy())
                    # full_speech.append(butter_bandpass_filter(speech.numpy(), lowcut=10, highcut=5000, fs=16000, order=2))
            # Concatenate the chunks and run the noise-reduction model from src.reduce_noise.
            out_audio = model_remove_noise(model, df_state, np.concatenate(full_speech))
            return out_audio
    
    @staticmethod
    def moving_average(data, window_size):
        # Simple moving-average smoother (not called in the current inference path).
        return np.convolve(data, np.ones(window_size) / window_size, mode='same')
    
# Female speaker IDs: VIVOSSPK26, VIVOSSPK02, VIVOSSPK40

# Male speaker IDs: VIVOSSPK28, VIVOSSPK36, VIVOSDEV09, VIVOSSPK33, VIVOSSPK23
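
# Example usage (illustrative sketch; the voiceclone checkpoint name and the 16 kHz output
# rate are assumptions, and the speaker ID comes from the VIVOS lists above):
#
#   import soundfile as sf
#   tts = Model("truong-xuan-linh/speecht5-vietnamese-voiceclone-v2")
#   audio = tts.inference("xin chào các bạn", speaker_id="VIVOSSPK26")
#   sf.write("output.wav", audio, 16000)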