Zengyf-CVer committed on
Commit
e16eaa7
1 Parent(s): b1f485a

app update

Browse files
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -1,7 +1,6 @@
1
  # OCR Translate v0.1
2
  # 创建人:曾逸夫
3
  # 创建时间:2022-06-14
4
- # email: zyfiy1314@163.com
5
 
6
  import os
7
 
@@ -13,13 +12,6 @@ from transformers import MarianMTModel, MarianTokenizer
13
 
14
  nltk.download('punkt')
15
 
16
- # ----------- 翻译 -----------
17
- # https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
18
- modchoice = "Helsinki-NLP/opus-mt-en-zh" # 模型名称
19
-
20
- tokenizer = MarianTokenizer.from_pretrained(modchoice) # 分词器
21
- model = MarianMTModel.from_pretrained(modchoice) # 模型
22
-
23
  OCR_TR_DESCRIPTION = '''# OCR Translate v0.1
24
  <div id="content_align">基于Tesseract的OCR翻译系统</div>'''
25
 
@@ -30,6 +22,17 @@ img_dir = "./data"
30
  choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
33
  # tesseract语言列表转pytesseract语言
34
  def ocr_lang(lang_list):
35
  lang_str = ""
@@ -66,11 +69,19 @@ def translate(input_text):
66
  if input_text is None or input_text == "":
67
  return "系统提示:没有可翻译的内容!"
68
 
69
- translated = model.generate(**tokenizer(sent_tokenize(input_text), return_tensors="pt", padding=True))
70
- tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
71
- translate_text = "".join(tgt_text)
 
 
 
 
 
 
 
 
72
 
73
- return translate_text
74
 
75
 
76
  def main():
 
1
  # OCR Translate v0.1
2
  # 创建人:曾逸夫
3
  # 创建时间:2022-06-14
 
4
 
5
  import os
6
 
 
12
 
13
  nltk.download('punkt')
14
 
 
 
 
 
 
 
 
15
  OCR_TR_DESCRIPTION = '''# OCR Translate v0.1
16
  <div id="content_align">基于Tesseract的OCR翻译系统</div>'''
17
 
 
22
  choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]
23
 
24
 
25
# Loaded (tokenizer, model) pairs keyed by (src, trg), so each checkpoint is
# loaded at most once per process instead of on every call.
_MODEL_CACHE = {}


# Translation model selection
def model_choice(src="en", trg="zh"):
    """Return the Marian tokenizer and model for the ``src`` -> ``trg`` pair.

    The checkpoint name is built as ``Helsinki-NLP/opus-mt-{src}-{trg}``
    (for the defaults: https://huggingface.co/Helsinki-NLP/opus-mt-en-zh).

    Args:
        src: Source language code used in the checkpoint name.
        trg: Target language code used in the checkpoint name.

    Returns:
        A ``(tokenizer, model)`` tuple. Repeated calls with the same
        language pair return the same cached objects.
    """
    key = (src, trg)
    if key not in _MODEL_CACHE:
        model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"  # model name

        tokenizer = MarianTokenizer.from_pretrained(model_name)  # tokenizer
        model = MarianMTModel.from_pretrained(model_name)  # model

        # Only cache after both loads succeed, so a failed load is retried.
        _MODEL_CACHE[key] = (tokenizer, model)

    return _MODEL_CACHE[key]
34
+
35
+
36
  # tesseract语言列表转pytesseract语言
37
  def ocr_lang(lang_list):
38
  lang_str = ""
 
69
  if input_text is None or input_text == "":
70
  return "系统提示:没有可翻译的内容!"
71
 
72
+ tokenizer, model = model_choice()
73
+
74
+ translate_text = ""
75
+ input_text_list = input_text.split("\n\n")
76
+
77
+ for i in range(len(input_text_list)):
78
+ translated_sub = model.generate(
79
+ **tokenizer(sent_tokenize(input_text_list[i]), return_tensors="pt", truncation=True, padding=True))
80
+ tgt_text_sub = [tokenizer.decode(t, skip_special_tokens=True) for t in translated_sub]
81
+ translate_text_sub = "".join(tgt_text_sub)
82
+ translate_text = translate_text + "\n\n" + translate_text_sub
83
 
84
+ return translate_text[2:]
85
 
86
 
87
  def main():