truong-xuan-linh committed
Commit
79309e0
1 Parent(s): d37e44c
Files changed (2)
  1. app.py +5 -3
  2. src/model.py +22 -17
app.py CHANGED
@@ -14,6 +14,7 @@ if "model_name" not in st.session_state:
     st.session_state.model_name = None
     st.session_state.audio = None
     st.session_state.wav_file = None
+    st.session_state.speaker_url = ""
 
 with st.sidebar.form("my_form"):
 
@@ -33,15 +34,16 @@ with st.sidebar.form("my_form"):
     speaker_id = st.selectbox("source voice", options= list(dataset_dict.keys()))
     speaker_url = st.text_input("speaker url", value="")
     # speaker_id = st.selectbox("source voice", options= glob.glob("voices/*.wav"))
-    if st.session_state.model_name != model_name :
+    if st.session_state.model_name != model_name or speaker_url != st.session_state.speaker_url :
         st.session_state.model_name = model_name
-        st.session_state.model = Model(model_name=model_name)
+        st.session_state.model = Model(model_name=model_name, speaker_url=speaker_url)
         st.session_state.speaker_id = speaker_id
+        st.session_state.speaker_url = speaker_url
 
     # Every form must have a submit button.
     submitted = st.form_submit_button("Submit")
     if submitted:
-        st.session_state.audio = st.session_state.model.inference(text=text, speaker_id=speaker_id, speaker_url=speaker_url)
+        st.session_state.audio = st.session_state.model.inference(text=text, speaker_id=speaker_id)
 
 audio_holder = st.empty()
 audio_holder.audio(st.session_state.audio, sample_rate=16000)
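Net effect of the app.py change: the reference-voice download moves from submit time to model construction, and the Model cached in st.session_state is rebuilt whenever either the selected checkpoint or the speaker URL changes. A minimal sketch of that invalidation pattern, factored into a helper (get_model is illustrative and not part of the commit; the import path assumes src/model.py as in this repo):

import streamlit as st
from src.model import Model  # assumed import, matching this repo's layout

def get_model(model_name: str, speaker_url: str) -> Model:
    # Rebuild only when an input changes, mirroring the guard this commit
    # adds; otherwise reuse the instance cached in session_state.
    if (st.session_state.get("model_name") != model_name
            or st.session_state.get("speaker_url") != speaker_url):
        st.session_state.model = Model(model_name=model_name, speaker_url=speaker_url)
        st.session_state.model_name = model_name
        st.session_state.speaker_url = speaker_url
    return st.session_state.model

Streamlit's st.cache_resource decorator, keyed on the same two arguments, would give the same reuse without the manual bookkeeping.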
src/model.py CHANGED
@@ -3,7 +3,7 @@ import torch
 import requests
 import torchaudio
 import numpy as np
-from src.reduce_noise import smooth_and_reduce_noise, model_remove_noise, model, df_state
+# from src.reduce_noise import smooth_and_reduce_noise, model_remove_noise, model, df_state
 import io
 from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
 from pydub import AudioSegment
@@ -60,36 +60,41 @@ def uroman_normalization(string):
 
 class Model():
 
-    def __init__(self, model_name):
+    def __init__(self, model_name, speaker_url=""):
         self.model_name = model_name
         self.processor = SpeechT5Processor.from_pretrained(model_name)
         self.model = SpeechT5ForTextToSpeech.from_pretrained(model_name)
         # self.model.generate = partial(self.model.generate, use_cache=True)
 
         self.model.eval()
+
+        self.speaker_url = speaker_url
+        if speaker_url:
+
+            print(f"download speaker_url")
+            response = requests.get(speaker_url)
+            audio_stream = io.BytesIO(response.content)
+            audio_segment = AudioSegment.from_file(audio_stream, format="wav")
+            audio_segment = audio_segment.set_channels(1)
+            audio_segment = audio_segment.set_frame_rate(16000)
+            audio_segment = audio_segment.set_sample_width(2)
+            wavform, _ = torchaudio.load(audio_segment.export())
+            self.speaker_embeddings = create_speaker_embedding(wavform)[0]
+        else:
+            self.speaker_embeddings = None
+
         if model_name == "truong-xuan-linh/speecht5-vietnamese-commonvoice" or model_name == "truong-xuan-linh/speecht5-irmvivoice":
             self.speaker_embeddings = torch.zeros((1, 512)) # or load xvectors from a file
-        else:
-            self.speaker_embeddings = torch.ones((1, 512)) # or load xvectors from a file
 
-    def inference(self, text, speaker_id=None, speaker_url=""):
+    def inference(self, text, speaker_id=None):
         # if self.model_name == "truong-xuan-linh/speecht5-vietnamese-voiceclone-v2":
         #     # self.speaker_embeddings = torch.tensor(dataset_dict_v2[speaker_id])
         #     wavform, _ = torchaudio.load(speaker_id)
         #     self.speaker_embeddings = create_speaker_embedding(wavform)[0]
 
         if "voiceclone" in self.model_name:
-            if not speaker_url:
+            if not self.speaker_url:
                 self.speaker_embeddings = torch.tensor(dataset_dict[speaker_id])
-            else:
-                response = requests.get(speaker_url)
-                audio_stream = io.BytesIO(response.content)
-                audio_segment = AudioSegment.from_file(audio_stream, format="wav")
-                audio_segment = audio_segment.set_channels(1)
-                audio_segment = audio_segment.set_frame_rate(16000)
-                audio_segment = audio_segment.set_sample_width(2)
-                wavform, _ = torchaudio.load(audio_segment.export())
-                self.speaker_embeddings = create_speaker_embedding(wavform)[0]
             # self.speaker_embeddings = create_speaker_embedding(speaker_id)[0]
             # wavform, _ = torchaudio.load("voices/kcbn1.wav")
             # self.speaker_embeddings = create_speaker_embedding(wavform)[0]
@@ -114,8 +119,8 @@ class Model():
             speech = self.model.generate_speech(inputs["input_ids"], threshold=0.5, speaker_embeddings=self.speaker_embeddings, vocoder=vocoder)
             full_speech.append(speech.numpy())
             # full_speech.append(butter_bandpass_filter(speech.numpy(), lowcut=10, highcut=5000, fs=16000, order=2))
-        out_audio = model_remove_noise(model, df_state, np.concatenate(full_speech))
-        return out_audio
+        # out_audio = model_remove_noise(model, df_state, np.concatenate(full_speech))
+        return np.concatenate(full_speech)
 
     @staticmethod
     def moving_average(data, window_size):
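In src/model.py, the download-and-normalize pipeline that previously ran on every inference() call now runs once in __init__, and inference() returns the raw np.concatenate(full_speech) since the src.reduce_noise denoiser is commented out. A self-contained sketch of that preprocessing step (load_reference_waveform is illustrative; the raise_for_status() check and the explicit format="wav" are additions, since the commit calls export() bare and pydub's export defaults to MP3):

import io
import requests
import torchaudio
from pydub import AudioSegment

def load_reference_waveform(speaker_url: str):
    # Fetch the reference clip and normalize it to 16 kHz mono 16-bit PCM,
    # matching the preprocessing this commit moves into Model.__init__.
    response = requests.get(speaker_url)
    response.raise_for_status()  # assumption: the commit does not check the response
    segment = AudioSegment.from_file(io.BytesIO(response.content), format="wav")
    segment = segment.set_channels(1).set_frame_rate(16000).set_sample_width(2)
    # export() with no output path returns a rewound file-like object,
    # which torchaudio.load accepts directly
    buffer = segment.export(format="wav")  # format made explicit (pydub defaults to mp3)
    wavform, sample_rate = torchaudio.load(buffer)
    return wavform, sample_rate

The returned waveform is what create_speaker_embedding (defined elsewhere in src/model.py) converts into the speaker embedding consumed by generate_speech.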