Yurii Paniv commited on
Commit
64fcafd
1 Parent(s): e883b68

#8 Add docs

Browse files
Files changed (3) hide show
  1. app.py +2 -0
  2. requirements.txt +2 -3
  3. ukrainian_tts/tts.py +29 -6
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
3
  from datetime import datetime
4
  from enum import Enum
5
  from ukrainian_tts.tts import TTS
 
6
 
7
  class StressOption(Enum):
8
  AutomaticStress = "Автоматичні наголоси (за словником) 📖"
@@ -16,6 +17,7 @@ class VoiceOption(Enum):
16
  Dmytro = "Дмитро (чоловічий) 👨"
17
  Olga = "Ольга (жіночий) 👩"
18
 
 
19
 
20
  badge = (
21
  "https://visitor-badge-reloaded.herokuapp.com/badge?page_id=robinhad.ukrainian-tts"
3
  from datetime import datetime
4
  from enum import Enum
5
  from ukrainian_tts.tts import TTS
6
+ from torch.cuda import is_available
7
 
8
  class StressOption(Enum):
9
  AutomaticStress = "Автоматичні наголоси (за словником) 📖"
17
  Dmytro = "Дмитро (чоловічий) 👨"
18
  Olga = "Ольга (жіночий) 👩"
19
 
20
+ print(f"CUDA available? {is_available}")
21
 
22
  badge = (
23
  "https://visitor-badge-reloaded.herokuapp.com/badge?page_id=robinhad.ukrainian-tts"
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
- TTS==0.8.0
 
2
  torch==1.12.1
3
  --extra-index-url https://download.pytorch.org/whl/cu113
4
- ukrainian-word-stress==1.0.1
5
- git+https://github.com/egorsmkv/ukrainian-accentor.git@5b7971c4e135e3ff3283336962e63fc0b1c80f4c
1
+ # requirements for HuggingFace demo. Installs local package.
2
+ .
3
  torch==1.12.1
4
  --extra-index-url https://download.pytorch.org/whl/cu113
 
 
ukrainian_tts/tts.py CHANGED
@@ -1,12 +1,13 @@
1
  from io import BytesIO
2
  import requests
3
- from os.path import exists
4
  from TTS.utils.synthesizer import Synthesizer
5
  from enum import Enum
6
  from .formatter import preprocess_text
7
  from torch import no_grad
8
 
9
  class Voices(Enum):
 
10
  Olena = "olena"
11
  Mykyta = "mykyta"
12
  Lada = "lada"
@@ -15,22 +16,39 @@ class Voices(Enum):
15
 
16
 
17
  class StressOption(Enum):
 
 
 
18
  Dictionary = "dictionary"
19
  Model = "model"
20
 
21
 
22
  class TTS:
 
 
 
23
  def __init__(self, cache_folder=None) -> None:
 
 
 
 
24
  self.__setup_cache(cache_folder)
25
 
26
 
27
  def tts(self, text: str, voice: str, stress: str, output_fp=BytesIO()):
 
 
 
 
 
 
 
28
  autostress_with_model = (
29
  True if stress == StressOption.Model.value else False
30
  )
31
 
32
  if voice not in [option.value for option in Voices]:
33
- raise ValueError("Invalid value for voice selected! Please use one of the following values: {', '.join([option.value for option in Voices])}.")
34
 
35
  text = preprocess_text(text, autostress_with_model)
36
 
@@ -44,15 +62,19 @@ class TTS:
44
 
45
 
46
  def __setup_cache(self, cache_folder=None):
 
47
  print("downloading uk/mykyta/vits-tts")
48
  release_number = "v3.0.0"
49
  model_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/model-inference.pth"
50
  config_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/config.json"
51
  speakers_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/speakers.pth"
52
 
53
- model_path = "model.pth"
54
- config_path = "config.json"
55
- speakers_path = "speakers.pth"
 
 
 
56
 
57
  self.__download(model_link, model_path)
58
  self.__download(config_link, config_path)
@@ -67,10 +89,11 @@ class TTS:
67
  )
68
 
69
  if self.synthesizer is None:
70
- raise NameError("model not found")
71
 
72
 
73
  def __download(self, url, file_name):
 
74
  if not exists(file_name):
75
  print(f"Downloading {file_name}")
76
  r = requests.get(url, allow_redirects=True)
1
  from io import BytesIO
2
  import requests
3
+ from os.path import exists, join
4
  from TTS.utils.synthesizer import Synthesizer
5
  from enum import Enum
6
  from .formatter import preprocess_text
7
  from torch import no_grad
8
 
9
  class Voices(Enum):
10
+ """List of available voices for the model."""
11
  Olena = "olena"
12
  Mykyta = "mykyta"
13
  Lada = "lada"
16
 
17
 
18
  class StressOption(Enum):
19
+ """Options how to stress sentence.
20
+ - `dictionary` - performs lookup in dictionary, taking into account grammatical case of a word and its' neighbors
21
+ - `model` - stress using transformer model"""
22
  Dictionary = "dictionary"
23
  Model = "model"
24
 
25
 
26
  class TTS:
27
+ """
28
+
29
+ """
30
  def __init__(self, cache_folder=None) -> None:
31
+ """
32
+ Class to setup a text-to-speech engine, from download to model creation. \n
33
+ Downloads or uses files from `cache_folder` directory. \n
34
+ By default stores in current directory."""
35
  self.__setup_cache(cache_folder)
36
 
37
 
38
  def tts(self, text: str, voice: str, stress: str, output_fp=BytesIO()):
39
+ """
40
+ Run a Text-to-Speech engine and output to `output_fp` BytesIO-like object.
41
+ - `text` - your model input text.
42
+ - `voice` - one of predefined voices from `Voices` enum.
43
+ - `stress` - stress method options, predefined in `StressOption` enum.
44
+ - `output_fp` - file-like object output. Stores in RAM by default.
45
+ """
46
  autostress_with_model = (
47
  True if stress == StressOption.Model.value else False
48
  )
49
 
50
  if voice not in [option.value for option in Voices]:
51
+ raise ValueError(f"Invalid value for voice selected! Please use one of the following values: {', '.join([option.value for option in Voices])}.")
52
 
53
  text = preprocess_text(text, autostress_with_model)
54
 
62
 
63
 
64
  def __setup_cache(self, cache_folder=None):
65
+ """Downloads models and stores them into `cache_folder`. By default stores in current directory."""
66
  print("downloading uk/mykyta/vits-tts")
67
  release_number = "v3.0.0"
68
  model_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/model-inference.pth"
69
  config_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/config.json"
70
  speakers_link = f"https://github.com/robinhad/ukrainian-tts/releases/download/{release_number}/speakers.pth"
71
 
72
+ if cache_folder is None:
73
+ cache_folder = "."
74
+
75
+ model_path = join(cache_folder, "model.pth")
76
+ config_path = join(cache_folder, "config.json")
77
+ speakers_path = join(cache_folder, "speakers.pth")
78
 
79
  self.__download(model_link, model_path)
80
  self.__download(config_link, config_path)
89
  )
90
 
91
  if self.synthesizer is None:
92
+ raise NameError("Model not found")
93
 
94
 
95
  def __download(self, url, file_name):
96
+ """Downloads file from `url` into local `file_name` file."""
97
  if not exists(file_name):
98
  print(f"Downloading {file_name}")
99
  r = requests.get(url, allow_redirects=True)