DWizard commited on
Commit
5e8cc27
2 Parent(s): fbd52e8 fccaaea

Merge branch 'oop-refactor' of https://github.com/project-kxkg/project-t into oop-refactor

Browse files
configs/local_launch.yaml CHANGED
@@ -1,5 +1,4 @@
1
  # launch config for local environment
2
- model: "gpt-4"
3
  local_dump: ./local_dump
4
- output_type: srt
5
  environ: local
 
1
  # launch config for local environment
 
2
  local_dump: ./local_dump
3
+ # dictionary_path: ./domain_dict
4
  environ: local
configs/task_config.yaml CHANGED
@@ -1,7 +1,7 @@
1
  # configuration for each task
2
  source_lang: EN
3
  target_lang: ZH
4
- field: SC2
5
 
6
  # ASR config
7
  ASR:
 
1
  # configuration for each task
2
  source_lang: EN
3
  target_lang: ZH
4
+ field: General
5
 
6
  # ASR config
7
  ASR:
src/srt_util/srt.py CHANGED
@@ -52,6 +52,8 @@ punctuation_dict = {
52
  },
53
  }
54
 
 
 
55
  class SrtSegment(object):
56
  def __init__(self, src_lang, tgt_lang, *args) -> None:
57
  self.src_lang = src_lang
@@ -150,11 +152,19 @@ class SrtSegment(object):
150
 
151
 
152
  class SrtScript(object):
153
- def __init__(self, src_lang, tgt_lang, segments) -> None:
 
154
  self.src_lang = src_lang
155
  self.tgt_lang = tgt_lang
156
  self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]
157
 
 
 
 
 
 
 
 
158
  @classmethod
159
  def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
160
  with open(path, 'r', encoding="utf-8") as f:
@@ -429,6 +439,12 @@ class SrtScript(object):
429
  def correct_with_force_term(self):
430
  ## force term correction
431
  logging.info("performing force term correction")
 
 
 
 
 
 
432
  # load term dictionary
433
  with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
434
  term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
@@ -478,6 +494,12 @@ class SrtScript(object):
478
 
479
  def spell_check_term(self):
480
  logging.info("performing spell check")
 
 
 
 
 
 
481
  import enchant
482
  dict = enchant.Dict('en_US')
483
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
 
52
  },
53
  }
54
 
55
+ dict_path = "./domain_dict"
56
+
57
  class SrtSegment(object):
58
  def __init__(self, src_lang, tgt_lang, *args) -> None:
59
  self.src_lang = src_lang
 
152
 
153
 
154
  class SrtScript(object):
155
+ def __init__(self, src_lang, tgt_lang, segments, domain="General") -> None:
156
+ self.domain = domain
157
  self.src_lang = src_lang
158
  self.tgt_lang = tgt_lang
159
  self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]
160
 
161
+ if self.domain != "General":
162
+ if os.path.exists(f"{dict_path}/{self.domain}"):
163
+ # TODO: load dictionary
164
+ ...
165
+ else:
166
+ logging.error(f"domain {self.domain} doesn't exist")
167
+
168
  @classmethod
169
  def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
170
  with open(path, 'r', encoding="utf-8") as f:
 
439
  def correct_with_force_term(self):
440
  ## force term correction
441
  logging.info("performing force term correction")
442
+
443
+ # check domain
444
+ if self.domain == "General":
445
+ logging.info("General domain could not perform correct_with_force_term. skip this step.")
446
+ pass
447
+
448
  # load term dictionary
449
  with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
450
  term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
 
494
 
495
  def spell_check_term(self):
496
  logging.info("performing spell check")
497
+
498
+ # check domain
499
+ if self.domain == "General":
500
+ logging.info("General domain could not perform spell_check_term. skip this step.")
501
+ pass
502
+
503
  import enchant
504
  dict = enchant.Dict('en_US')
505
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
src/task.py CHANGED
@@ -157,7 +157,7 @@ class Task:
157
  # after get the transcript, release the gpu resource
158
  torch.cuda.empty_cache()
159
 
160
- self.SRT_Script = SrtScript(self.source_lang, self.target_lang, transcript['segments'])
161
  # save the srt script to local
162
  self.SRT_Script.write_srt_file_src(src_srt_path)
163
 
 
157
  # after get the transcript, release the gpu resource
158
  torch.cuda.empty_cache()
159
 
160
+ self.SRT_Script = SrtScript(self.source_lang, self.target_lang, transcript['segments'], self.field)
161
  # save the srt script to local
162
  self.SRT_Script.write_srt_file_src(src_srt_path)
163