Spaces:
Sleeping
Sleeping
Merge branch 'oop-refactor' of https://github.com/project-kxkg/project-t into oop-refactor
Browse files- configs/local_launch.yaml +1 -2
- configs/task_config.yaml +1 -1
- src/srt_util/srt.py +23 -1
- src/task.py +1 -1
configs/local_launch.yaml
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
# launch config for local environment
|
2 |
-
model: "gpt-4"
|
3 |
local_dump: ./local_dump
|
4 |
-
|
5 |
environ: local
|
|
|
1 |
# launch config for local environment
|
|
|
2 |
local_dump: ./local_dump
|
3 |
+
# dictionary_path: ./domain_dict
|
4 |
environ: local
|
configs/task_config.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
# configuration for each task
|
2 |
source_lang: EN
|
3 |
target_lang: ZH
|
4 |
-
field:
|
5 |
|
6 |
# ASR config
|
7 |
ASR:
|
|
|
1 |
# configuration for each task
|
2 |
source_lang: EN
|
3 |
target_lang: ZH
|
4 |
+
field: General
|
5 |
|
6 |
# ASR config
|
7 |
ASR:
|
src/srt_util/srt.py
CHANGED
@@ -52,6 +52,8 @@ punctuation_dict = {
|
|
52 |
},
|
53 |
}
|
54 |
|
|
|
|
|
55 |
class SrtSegment(object):
|
56 |
def __init__(self, src_lang, tgt_lang, *args) -> None:
|
57 |
self.src_lang = src_lang
|
@@ -150,11 +152,19 @@ class SrtSegment(object):
|
|
150 |
|
151 |
|
152 |
class SrtScript(object):
|
153 |
-
def __init__(self, src_lang, tgt_lang, segments) -> None:
|
|
|
154 |
self.src_lang = src_lang
|
155 |
self.tgt_lang = tgt_lang
|
156 |
self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
@classmethod
|
159 |
def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
|
160 |
with open(path, 'r', encoding="utf-8") as f:
|
@@ -429,6 +439,12 @@ class SrtScript(object):
|
|
429 |
def correct_with_force_term(self):
|
430 |
## force term correction
|
431 |
logging.info("performing force term correction")
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
# load term dictionary
|
433 |
with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
|
434 |
term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
|
@@ -478,6 +494,12 @@ class SrtScript(object):
|
|
478 |
|
479 |
def spell_check_term(self):
|
480 |
logging.info("performing spell check")
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
import enchant
|
482 |
dict = enchant.Dict('en_US')
|
483 |
term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
|
|
|
52 |
},
|
53 |
}
|
54 |
|
55 |
+
dict_path = "./domain_dict"
|
56 |
+
|
57 |
class SrtSegment(object):
|
58 |
def __init__(self, src_lang, tgt_lang, *args) -> None:
|
59 |
self.src_lang = src_lang
|
|
|
152 |
|
153 |
|
154 |
class SrtScript(object):
|
155 |
+
def __init__(self, src_lang, tgt_lang, segments, domain="General") -> None:
|
156 |
+
self.domain = domain
|
157 |
self.src_lang = src_lang
|
158 |
self.tgt_lang = tgt_lang
|
159 |
self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]
|
160 |
|
161 |
+
if self.domain != "General":
|
162 |
+
if os.path.exists(f"{dict_path}/{self.domain}"):
|
163 |
+
# TODO: load dictionary
|
164 |
+
...
|
165 |
+
else:
|
166 |
+
logging.error(f"domain {self.domain} doesn't exist")
|
167 |
+
|
168 |
@classmethod
|
169 |
def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
|
170 |
with open(path, 'r', encoding="utf-8") as f:
|
|
|
439 |
def correct_with_force_term(self):
|
440 |
## force term correction
|
441 |
logging.info("performing force term correction")
|
442 |
+
|
443 |
+
# check domain
|
444 |
+
if self.domain == "General":
|
445 |
+
logging.info("General domain could not perform correct_with_force_term. skip this step.")
|
446 |
+
pass
|
447 |
+
|
448 |
# load term dictionary
|
449 |
with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
|
450 |
term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
|
|
|
494 |
|
495 |
def spell_check_term(self):
|
496 |
logging.info("performing spell check")
|
497 |
+
|
498 |
+
# check domain
|
499 |
+
if self.domain == "General":
|
500 |
+
logging.info("General domain could not perform spell_check_term. skip this step.")
|
501 |
+
pass
|
502 |
+
|
503 |
import enchant
|
504 |
dict = enchant.Dict('en_US')
|
505 |
term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
|
src/task.py
CHANGED
@@ -157,7 +157,7 @@ class Task:
|
|
157 |
# after get the transcript, release the gpu resource
|
158 |
torch.cuda.empty_cache()
|
159 |
|
160 |
-
self.SRT_Script = SrtScript(self.source_lang, self.target_lang, transcript['segments'])
|
161 |
# save the srt script to local
|
162 |
self.SRT_Script.write_srt_file_src(src_srt_path)
|
163 |
|
|
|
157 |
# after get the transcript, release the gpu resource
|
158 |
torch.cuda.empty_cache()
|
159 |
|
160 |
+
self.SRT_Script = SrtScript(self.source_lang, self.target_lang, transcript['segments'], self.field)
|
161 |
# save the srt script to local
|
162 |
self.SRT_Script.write_srt_file_src(src_srt_path)
|
163 |
|