TenzinGayche commited on
Commit
e05d640
1 Parent(s): 7bd8037

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +252 -77
handler.py CHANGED
@@ -1,94 +1,269 @@
1
- from typing import Dict, Any,Union
2
- import librosa
3
- import numpy as np
4
- import torch
5
- import pyewts
6
- import noisereduce as nr
7
- from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
8
- from num2tib.core import convert
9
- from num2tib.core import convert2text
10
- import base64
11
  import re
 
 
 
 
 
 
 
12
  import requests
13
- converter = pyewts.pyewts()
14
- def download_file(url, destination):
15
- response = requests.get(url)
16
- with open(destination, 'wb') as file:
17
- file.write(response.content)
18
 
19
- # Example usage:
20
- download_file('https://huggingface.co/openpecha/speecht5-tts-01/resolve/main/female_2.npy', 'female_2.npy')
21
- def replace_numbers_with_convert(sentence, wylie=True):
22
- pattern = r'\d+(\.\d+)?'
23
- def replace(match):
24
- return convert(match.group(), wylie)
25
- result = re.sub(pattern, replace, sentence)
26
-
27
- return result
28
 
29
- def cleanup_text(inputs):
30
- for src, dst in replacements:
31
- inputs = inputs.replace(src, dst)
32
- return inputs
33
 
34
- speaker_embeddings = {
35
- "Lhasa(female)": "female_2.npy",
 
 
36
 
37
- }
38
 
39
- replacements = [
40
- ('_', '_'),
41
- ('*', 'v'),
42
- ('`', ';'),
43
- ('~', ','),
44
- ('+', ','),
45
- ('\\', ';'),
46
- ('|', ';'),
47
- ('╚',''),
48
- ('╗','')
49
- ]
50
 
 
51
 
 
 
 
 
 
 
 
 
 
52
 
53
 
 
54
 
55
- class EndpointHandler():
56
- def __init__(self, path=""):
57
- # load the model
58
- self.processor = SpeechT5Processor.from_pretrained("TenzinGayche/TTS_run3_ep20_174k_b")
59
- self.model = SpeechT5ForTextToSpeech.from_pretrained("TenzinGayche/TTS_run3_ep20_174k_b")
60
- self.model.to('cuda')
61
- self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
62
 
 
 
63
 
64
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Union[int, str]]:
65
- """_summary_
 
66
 
67
- Args:
68
- data (Dict[str, Any]): _description_
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- Returns:
71
- bytes: _description_
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  """
73
- text = data.pop("inputs",data)
74
-
75
- # process input
76
-
77
- if len(text.strip()) == 0:
78
- return (16000, np.zeros(0).astype(np.int16))
79
- text = converter.toWylie(text)
80
- text=cleanup_text(text)
81
- text=replace_numbers_with_convert(text)
82
- inputs = self.processor(text=text, return_tensors="pt")
83
- # limit input length
84
- input_ids = inputs["input_ids"]
85
- input_ids = input_ids[..., :self.model.config.max_text_positions]
86
- speaker_embedding = np.load(speaker_embeddings['Lhasa(female)'])
87
- speaker_embedding = torch.tensor(speaker_embedding)
88
- speech = self.model.generate_speech(input_ids.to('cuda'), speaker_embedding.to('cuda'), vocoder=self.vocoder.to('cuda'))
89
- speech = nr.reduce_noise(y=speech.to('cpu'), sr=16000)
90
- return {
91
- "sample_rate": 16000,
92
- "audio": base64.b64encode(speech.tostring()).decode("utf-8"),
93
-
94
- }
 
1
+ import subprocess
2
+ from typing import Dict, List, Any
3
+ import os
4
+ import json
5
+ import logging
6
+ import sys
7
+ import tempfile
8
+ import time
9
+ from pathlib import Path
 
10
  import re
11
+ import shutil
12
+ import stat
13
+ import subprocess
14
+
15
+ import uuid
16
+ from contextlib import contextmanager
17
+
18
  import requests
19
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
20
+ # Git clone command
21
+ git_clone_command = "git clone https://github.com/OpenPecha/tibetan-aligner"
 
 
22
 
23
+ # Run the command using subprocess
24
+ try:
25
+ subprocess.run(git_clone_command, shell=True, check=True)
26
+ print("Git clone successful!")
27
+ except subprocess.CalledProcessError as e:
28
+ print(f"Error while running Git clone command: {e}")
 
 
 
29
 
 
 
 
 
30
 
31
+ ALIGNER_SCRIPT_DIR = Path("./tibetan-aligner").resolve()
32
+ ALIGNER_SCRIPT_NAME = "align_tib_en.sh"
33
+ ALIGNER_SCRIPT_PATH = ALIGNER_SCRIPT_DIR / ALIGNER_SCRIPT_NAME
34
+ assert ALIGNER_SCRIPT_PATH.is_file()
35
 
36
+ import requests
37
 
38
+ GITHUB_USERNAME = "pechawa"
39
+ GITHUB_ACCESS_TOKEN = "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"
40
+ GITHUB_TOKEN = "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"
41
+ GITHUB_EMAIL = "openpecha-bot@openpecha.org"
42
+ GITHUB_ORG = "MonlamAI"
43
+ MAI_TM_PUBLISH_TODO_REPO = "MonlamAI_TMs_Publish_TODO"
44
+ GITHUB_API_ENDPOINT = f"https://api.github.com/orgs/{GITHUB_ORG}/repos"
 
 
 
 
45
 
46
+ DEBUG = False
47
 
48
+ quiet = "-q" if DEBUG else ""
49
+ def make_dir_executable(dir_path: Path):
50
+ for fn in dir_path.iterdir():
51
+ st = os.stat(fn)
52
+ os.chmod(fn, st.st_mode | stat.S_IEXEC)
53
+ st = os.stat(fn)
54
+ os.chmod(fn, st.st_mode | stat.S_IXGRP)
55
+ st = os.stat(fn)
56
+ os.chmod(fn, st.st_mode | stat.S_IXOTH)
57
 
58
 
59
+ make_dir_executable(ALIGNER_SCRIPT_DIR)
60
 
 
 
 
 
 
 
 
61
 
62
+ def create_github_repo(repo_path: Path, repo_name: str):
63
+ logging.info("[INFO] Creating GitHub repo...")
64
 
65
+ # configure git users
66
+ subprocess.run(f"git config --global user.name {GITHUB_USERNAME}".split())
67
+ subprocess.run(f"git config --global user.email {GITHUB_EMAIL}".split())
68
 
69
+ # Initialize a Git repository
70
+ subprocess.run(f"git init {quiet}".split(), cwd=str(repo_path))
71
+
72
+ # Commit the changes
73
+ subprocess.run("git add . ".split(), cwd=str(repo_path))
74
+ subprocess.run(
75
+ f"git commit {quiet} -m".split() + ["Initial commit"], cwd=str(repo_path)
76
+ )
77
+
78
+ # Create a new repository on GitHub
79
+ response = requests.post(
80
+ GITHUB_API_ENDPOINT,
81
+ json={
82
+ "name": repo_name,
83
+ "private": True,
84
+ },
85
+ auth=(GITHUB_USERNAME, GITHUB_ACCESS_TOKEN),
86
+ )
87
+ response.raise_for_status()
88
+
89
+ time.sleep(3)
90
+
91
+ # Add the GitHub remote to the local Git repository and push the changes
92
+ remote_url = f"https://{GITHUB_ORG}:{GITHUB_ACCESS_TOKEN}@github.com/{GITHUB_ORG}/{repo_name}.git"
93
+ subprocess.run(
94
+ f"git remote add origin {remote_url}", cwd=str(repo_path), shell=True
95
+ )
96
+ # rename default branch to main
97
+ subprocess.run("git branch -M main".split(), cwd=str(repo_path))
98
+ subprocess.run(f"git push {quiet} -u origin main".split(), cwd=str(repo_path))
99
+
100
+ return response.json()["html_url"]
101
+
102
+
103
+ def convert_raw_align_to_tm(align_fn: Path, tm_path: Path):
104
+ if DEBUG:
105
+ logging.debug("[INFO] Conerting raw alignment to TM repo...")
106
+
107
+ def load_alignment(fn: Path):
108
+ content = fn.read_text()
109
+ print("Content !!! \n\n"+content)
110
+ if not content:
111
+ return []
112
+
113
+ for seg_pair in content.splitlines():
114
+ if not seg_pair:
115
+ continue
116
+
117
+ if "\t" in seg_pair:
118
+ try:
119
+ bo_seg, en_seg = seg_pair.split("\t", 1)
120
+ except Exception as e:
121
+ logging.error(f"{e} in {fn}")
122
+ raise
123
+
124
+ else:
125
+ bo_seg = seg_pair
126
+ en_seg = "\n"
127
+ yield bo_seg, en_seg
128
+
129
+ text_bo_fn = tm_path / f"{tm_path.name}-bo.txt"
130
+ text_en_fn = tm_path / f"{tm_path.name}-en.txt"
131
+
132
+ with open(text_bo_fn, "w", encoding="utf-8") as bo_file, open(
133
+ text_en_fn, "w", encoding="utf-8"
134
+ ) as en_file:
135
+ for bo_seg, en_seg in load_alignment(align_fn):
136
+ bo_file.write(bo_seg + "\n")
137
+ en_file.write(en_seg + "\n")
138
+
139
+ return tm_path
140
+
141
+
142
+ def get_github_dev_url(raw_github_url: str) -> str:
143
+ base_url = "https://github.dev"
144
+ _, file_path = raw_github_url.split(".com")
145
+ blob_file_path = file_path.replace("main", "blob/main")
146
+ return base_url + blob_file_path
147
+
148
+
149
+ def add_input_in_readme(input_dict: Dict[str, str], path: Path) -> Path:
150
+ input_readme_fn = path / "README.md"
151
+ text_id = input_dict["text_id"]
152
+ bo_file_url = get_github_dev_url(input_dict["bo_file_url"])
153
+ en_file_url = get_github_dev_url(input_dict["en_file_url"])
154
+ input_string = "## Input\n- [BO{}]({})\n- [EN{}]({})".format(
155
+ text_id, bo_file_url, text_id, en_file_url
156
+ )
157
+
158
+ input_readme_fn.write_text(input_string)
159
 
160
+ return path
161
+
162
+ def add_to_publish_todo_repo(org, repo_name, file_path, access_token):
163
+ base_url = f"https://api.github.com/repos/{org}/{repo_name}/contents/"
164
+
165
+ headers = {
166
+ "Authorization": f"Bearer {access_token}",
167
+ "Accept": "application/vnd.github.v3+json",
168
+ }
169
+
170
+ url = base_url + file_path
171
+
172
+ response = requests.get(url, headers=headers)
173
+
174
+ if response.status_code == 200:
175
+ print(f"[INFO] '{file_path}' already added.")
176
+ return
177
+
178
+ payload = {"message": f"Add {file_path}", "content": ""}
179
+
180
+ response = requests.put(url, headers=headers, json=payload)
181
+
182
+ if response.status_code == 201:
183
+ print(f"[INFO] '{file_path}' added to publish todo")
184
+ else:
185
+ print(f"[ERROR] Failed to add '{file_path}'.")
186
+ print(f"[ERROR] Response: {response.text}")
187
+
188
+
189
+ def create_tm(align_fn: Path, text_pair: Dict[str, str]):
190
+ align_fn = Path(align_fn)
191
+ text_id = text_pair["text_id"]
192
+ with tempfile.TemporaryDirectory() as tmp_dir:
193
+ output_dir = Path(tmp_dir)
194
+ repo_name = f"TM{text_id}"
195
+ tm_path = output_dir / repo_name
196
+ tm_path.mkdir(exist_ok=True, parents=True)
197
+ repo_path = convert_raw_align_to_tm(align_fn, tm_path)
198
+ repo_path = add_input_in_readme(text_pair, tm_path)
199
+ repo_url = create_github_repo(repo_path, repo_name)
200
+ logging.info(f"TM repo created: {repo_url}")
201
+ add_to_publish_todo_repo(GITHUB_ORG, MAI_TM_PUBLISH_TODO_REPO, repo_name, GITHUB_ACCESS_TOKEN)
202
+ return repo_url
203
+
204
+ ##----------------------- MAIN -----------------------##
205
+
206
+
207
+ @contextmanager
208
+ def TemporaryDirectory():
209
+ tmpdir = Path("./output").resolve() / uuid.uuid4().hex[:8]
210
+ tmpdir.mkdir(exist_ok=True, parents=True)
211
+ try:
212
+ yield tmpdir
213
+ finally:
214
+ shutil.rmtree(str(tmpdir))
215
+
216
+
217
+ def download_file(s3_public_url: str, output_fn) -> Path:
218
+ """Download file from a public S3 bucket URL."""
219
+ with requests.get(s3_public_url, stream=True) as r:
220
+ r.raise_for_status()
221
+ with open(output_fn, "wb") as f:
222
+ for chunk in r.iter_content(chunk_size=8192):
223
+ f.write(chunk)
224
+ return output_fn
225
+
226
+
227
+ def _run_align_script(bo_fn, en_fn, output_dir):
228
+ start = time.time()
229
+ cmd = [str(ALIGNER_SCRIPT_PATH), str(bo_fn), str(en_fn), str(output_dir)]
230
+ output = subprocess.run(
231
+ cmd,
232
+ check=True,
233
+ capture_output=True,
234
+ text=True,
235
+ cwd=str(ALIGNER_SCRIPT_DIR),
236
+ )
237
+ output_fn = re.search(r"\[OUTPUT\] (.*)", output.stdout).group(1)
238
+ output_fn = "/" + output_fn.split("//")[-1]
239
+ end = time.time()
240
+ total_time = round((end - start) / 60, 2)
241
+ logging.info(f"Total time taken for Aligning: {total_time} mins")
242
+ return output_fn
243
+ def align(text_pair):
244
+ logging.info(f"Running aligner for TM{text_pair['text_id']}...")
245
+ with TemporaryDirectory() as tmpdir:
246
+ output_dir = Path(tmpdir)
247
+ bo_fn = download_file(text_pair["bo_file_url"], output_fn=output_dir / "bo.tx")
248
+ en_fn = download_file(text_pair["en_file_url"], output_fn=output_dir / "en.tx")
249
+ print("bo_fn: ", bo_fn)
250
+ print("en_fn: ", en_fn)
251
+ aligned_fn = _run_align_script(bo_fn, en_fn, output_dir)
252
+ print("aligned_fn: ", aligned_fn)
253
+ repo_url = create_tm(aligned_fn, text_pair=text_pair)
254
+ return {"tm_repo_url": repo_url}
255
+
256
+ class EndpointHandler():
257
+ def __init__(self, path=""):
258
+ self.path = path
259
+
260
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
261
  """
262
+ Args:
263
+ data (:obj:):
264
+ includes the input data and the parameters for the inference.
265
+ Return:
266
+ A :obj:`list`:. The list contains the embeddings of the inference inputs
267
+ """
268
+ return align(data)
269
+