Yurii Paniv committed on
Commit
1ce668d
1 Parent(s): b75a2aa

Load model once, remove gradients from accentor

Browse files
Files changed (2) hide show
  1. app.py +17 -28
  2. stress.py +2 -1
app.py CHANGED
@@ -2,7 +2,6 @@ import tempfile
2
 
3
  import gradio as gr
4
 
5
- from TTS.utils.manage import ModelManager
6
  from TTS.utils.synthesizer import Synthesizer
7
  import requests
8
  from os.path import exists
@@ -11,19 +10,11 @@ from datetime import datetime
11
  from stress import sentence_to_stress
12
  from enum import Enum
13
  import torch
14
- import gc
15
 
16
  class StressOption(Enum):
17
  ManualStress = "Наголоси вручну"
18
  AutomaticStress = "Автоматичні наголоси (Beta)"
19
 
20
- MODEL_NAMES = [
21
- "uk/mykyta/vits-tts"
22
- ]
23
- MODELS = {}
24
-
25
- manager = ModelManager()
26
-
27
 
28
  def download(url, file_name):
29
  if not exists(file_name):
@@ -35,39 +26,36 @@ def download(url, file_name):
35
  print(f"Found {file_name}. Skipping download...")
36
 
37
 
38
- for MODEL_NAME in MODEL_NAMES:
39
- print(f"downloading {MODEL_NAME}")
40
- release_number = "v2.0.0-beta"
41
- model_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/model-inference.pth"
42
- config_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/config.json"
43
 
44
- model_path = "model.pth"
45
- config_path = "config.json"
46
 
47
- download(model_link, model_path)
48
- download(config_link, config_path)
49
 
50
-
51
- #MODELS[MODEL_NAME] = synthesizer
52
 
 
 
 
 
 
 
53
 
54
  def tts(text: str, stress: str):
55
  text = preprocess_text(text)
56
  text_limit = 1200
57
  text = text if len(text) < text_limit else text[0:text_limit] # mitigate crashes on hf space
58
  text = sentence_to_stress(text) if stress == StressOption.AutomaticStress.value else text
59
- print(text, datetime.utcnow())
60
 
61
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
62
  with torch.no_grad():
63
- synthesizer = Synthesizer(
64
- model_path, config_path, None, None, None,
65
- )
66
- if synthesizer is None:
67
- raise NameError("model not found")
68
  wavs = synthesizer.tts(text)
69
  synthesizer.save_wav(wavs, fp)
70
- gc.collect()
71
  return fp.name
72
 
73
 
@@ -91,7 +79,8 @@ iface = gr.Interface(
91
  "Github: [https://github.com/robinhad/ukrainian-tts](https://github.com/robinhad/ukrainian-tts)",
92
  examples=[
93
  ["Введ+іть, б+удь л+аска, сво+є р+ечення.", StressOption.ManualStress.value],
94
- ["Привіт, як тебе звати?", StressOption.AutomaticStress.value]
 
95
  ]
96
  )
97
  iface.launch(enable_queue=True, prevent_thread_lock=True)
2
 
3
  import gradio as gr
4
 
 
5
  from TTS.utils.synthesizer import Synthesizer
6
  import requests
7
  from os.path import exists
10
  from stress import sentence_to_stress
11
  from enum import Enum
12
  import torch
 
13
 
14
  class StressOption(Enum):
15
  ManualStress = "Наголоси вручну"
16
  AutomaticStress = "Автоматичні наголоси (Beta)"
17
 
 
 
 
 
 
 
 
18
 
19
  def download(url, file_name):
20
  if not exists(file_name):
26
  print(f"Found {file_name}. Skipping download...")
27
 
28
 
29
+ print("downloading uk/mykyta/vits-tts")
30
+ release_number = "v2.0.0-beta"
31
+ model_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/model-inference.pth"
32
+ config_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/config.json"
 
33
 
34
+ model_path = "model.pth"
35
+ config_path = "config.json"
36
 
37
+ download(model_link, model_path)
38
+ download(config_link, config_path)
39
 
 
 
40
 
41
+ synthesizer = Synthesizer(
42
+ model_path, config_path, None, None, None,
43
+ )
44
+
45
+ if synthesizer is None:
46
+ raise NameError("model not found")
47
 
48
  def tts(text: str, stress: str):
49
  text = preprocess_text(text)
50
  text_limit = 1200
51
  text = text if len(text) < text_limit else text[0:text_limit] # mitigate crashes on hf space
52
  text = sentence_to_stress(text) if stress == StressOption.AutomaticStress.value else text
53
+ print(text, stress, datetime.utcnow())
54
 
55
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
56
  with torch.no_grad():
 
 
 
 
 
57
  wavs = synthesizer.tts(text)
58
  synthesizer.save_wav(wavs, fp)
 
59
  return fp.name
60
 
61
 
79
  "Github: [https://github.com/robinhad/ukrainian-tts](https://github.com/robinhad/ukrainian-tts)",
80
  examples=[
81
  ["Введ+іть, б+удь л+аска, сво+є р+ечення.", StressOption.ManualStress.value],
82
+ ["Введіть, будь ласка, своє речення.", StressOption.ManualStress.value],
83
+ ["Привіт, як тебе звати?", StressOption.AutomaticStress.value],
84
  ]
85
  )
86
  iface.launch(enable_queue=True, prevent_thread_lock=True)
stress.py CHANGED
@@ -14,7 +14,8 @@ replace_accents = importer.load_pickle("uk-accentor", "replace_accents")
14
  alphabet = "абгґдеєжзиіїйклмнопрстуфхцчшщьюя"
15
 
16
  def accent_word(word):
17
- stressed_words = accentor.predict([word], mode='stress')
 
18
  plused_words = [replace_accents(x) for x in stressed_words]
19
  return plused_words[0]
20
 
14
  alphabet = "абгґдеєжзиіїйклмнопрстуфхцчшщьюя"
15
 
16
  def accent_word(word):
17
+ with torch.no_grad():
18
+ stressed_words = accentor.predict([word], mode='stress')
19
  plused_words = [replace_accents(x) for x in stressed_words]
20
  return plused_words[0]
21