omkar56 commited on
Commit
9a6d49e
1 Parent(s): 5d94527

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -19
app.py CHANGED
@@ -1,6 +1,5 @@
1
  # OCR Translate v0.2
2
- # 创建人:曾逸夫
3
- # 创建时间:2022-07-19
4
 
5
  import os
6
 
@@ -12,32 +11,91 @@ import pyclip
12
  import pytesseract
13
  from nltk.tokenize import sent_tokenize
14
  from transformers import MarianMTModel, MarianTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  nltk.download('punkt')
17
 
18
  OCR_TR_DESCRIPTION = '''# OCR Translate v0.2
19
  <div id="content_align">OCR translation system based on Tesseract</div>'''
20
 
21
- # 图片路径
22
  img_dir = "./data"
23
 
24
- # 获取tesseract语言列表
25
  choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]
26
 
27
 
28
- # 翻译模型选择
29
  def model_choice(src="en", trg="zh"):
30
  # https://huggingface.co/Helsinki-NLP/opus-mt-zh-en
31
  # https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
32
- model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}" # 模型名称
33
 
34
- tokenizer = MarianTokenizer.from_pretrained(model_name) # 分词器
35
- model = MarianMTModel.from_pretrained(model_name) # 模型
36
 
37
  return tokenizer, model
38
 
39
 
40
- # tesseract语言列表转pytesseract语言
41
  def ocr_lang(lang_list):
42
  lang_str = ""
43
  lang_len = len(lang_list)
@@ -57,12 +115,12 @@ def ocr_tesseract(img, languages):
57
  return ocr_str
58
 
59
 
60
- # 清除
61
  def clear_content():
62
  return None
63
 
64
 
65
- # 复制到剪贴板
66
  def cp_text(input_text):
67
  # sudo apt-get install xclip
68
  try:
@@ -72,18 +130,18 @@ def cp_text(input_text):
72
  print(e)
73
 
74
 
75
- # 清除剪贴板
76
  def cp_clear():
77
  pyclip.clear()
78
 
79
 
80
- # 翻译
81
  def translate(input_text, inputs_transStyle):
82
- # 参考:https://huggingface.co/docs/transformers/model_doc/marian
83
  if input_text is None or input_text == "":
84
  return "System prompt: There is no content to translate!"
85
 
86
- # 选择翻译模型
87
  trans_src, trans_trg = inputs_transStyle.split("-")[0], inputs_transStyle.split("-")[1]
88
  tokenizer, model = model_choice(trans_src, trans_trg)
89
 
@@ -110,7 +168,7 @@ def main():
110
  with gr.Blocks(css='style.css') as ocr_tr:
111
  gr.Markdown(OCR_TR_DESCRIPTION)
112
 
113
- # -------------- OCR 文字提取 --------------
114
  with gr.Box():
115
 
116
  with gr.Row():
@@ -147,7 +205,7 @@ def main():
147
  ["./data/test03.png", ["chi_sim"]]]
148
  gr.Examples(example_list, [inputs_img, inputs_lang], outputs_text, ocr_tesseract, cache_examples=False)
149
 
150
- # -------------- 翻译 --------------
151
  with gr.Box():
152
 
153
  with gr.Row():
@@ -165,11 +223,11 @@ def main():
165
  outputs_text,])
166
  clear_img_btn.click(fn=clear_content, inputs=[], outputs=[inputs_img])
167
 
168
- # ---------------------- 翻译 ----------------------
169
  translate_btn.click(fn=translate, inputs=[outputs_text, inputs_transStyle], outputs=[outputs_tr_text])
170
  clear_text_btn.click(fn=clear_content, inputs=[], outputs=[outputs_text])
171
 
172
- # ---------------------- 复制到剪贴板 ----------------------
173
  cp_btn.click(fn=cp_text, inputs=[outputs_tr_text], outputs=[])
174
  cp_clear_btn.click(fn=cp_clear, inputs=[], outputs=[])
175
 
 
1
  # OCR Translate v0.2
2
+
 
3
 
4
  import os
5
 
 
11
  import pytesseract
12
  from nltk.tokenize import sent_tokenize
13
  from transformers import MarianMTModel, MarianTokenizer
14
+ # Newly added below
15
+ from fastapi import FastAPI, File, UploadFile, Body, Security
16
+ from fastapi.security.api_key import APIKeyHeader
17
+ from fastapi.encoders import jsonable_encoder
18
+
19
+ API_KEY = os.environ.get("API_KEY")
20
+
21
+ app = FastAPI()
22
+ api_key_header = APIKeyHeader(name="api_key", auto_error=False)
23
+
24
+ def get_api_key(api_key: Optional[str] = Depends(security)):
25
+ if api_key is None or api_key != API_KEY:
26
+ raise HTTPException(status_code=401, detail="Unauthorized access")
27
+ return api_key
28
+
29
+ @app.post("/ocr", response_model=dict)
30
+ async def ocr(
31
+ api_key: str = Depends(get_api_key),
32
+ image: UploadFile = File(...),
33
+ languages: list = Body(["eng"])
34
+ ):
35
+ # if api_key != API_KEY:
36
+ # return {"error": "Invalid API key"}, 401
37
+
38
+ try:
39
+ text = image_to_string(await image.read(), lang="+".join(languages))
40
+ except Exception as e:
41
+ return {"error": str(e)}, 500
42
+
43
+ return jsonable_encoder({"text": text})
44
+
45
+
46
+ @app.post("/translate", response_model=dict)
47
+ async def translate(
48
+ api_key: str = Depends(get_api_key),
49
+ text: str = Body(...),
50
+ src: str = "en",
51
+ trg: str = "zh",
52
+ ):
53
+ # if api_key != API_KEY:
54
+ # return {"error": "Invalid API key"}, 401
55
+
56
+ tokenizer, model = get_model(src, trg)
57
+
58
+ translated_text = ""
59
+ for sentence in sent_tokenize(text):
60
+ translated_sub = model.generate(**tokenizer(sentence, return_tensors="pt"))[0]
61
+ translated_text += tokenizer.decode(translated_sub, skip_special_tokens=True) + "\n"
62
+
63
+ return jsonable_encoder({"translated_text": translated_text})
64
+
65
+
66
+ def get_model(src: str, trg: str):
67
+ model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
68
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
69
+ model = MarianMTModel.from_pretrained(model_name)
70
+ return tokenizer, model
71
+
72
+ # ===============================================
73
 
74
  nltk.download('punkt')
75
 
76
  OCR_TR_DESCRIPTION = '''# OCR Translate v0.2
77
  <div id="content_align">OCR translation system based on Tesseract</div>'''
78
 
79
+ # Image path
80
  img_dir = "./data"
81
 
82
+ # Get tesseract language list
83
  choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]
84
 
85
 
86
+ # Translation model selection
87
  def model_choice(src="en", trg="zh"):
88
  # https://huggingface.co/Helsinki-NLP/opus-mt-zh-en
89
  # https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
90
+ model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}" # Model name
91
 
92
+ tokenizer = MarianTokenizer.from_pretrained(model_name) # tokenizer
93
+ model = MarianMTModel.from_pretrained(model_name) # Model
94
 
95
  return tokenizer, model
96
 
97
 
98
+ # Convert tesseract language list to pytesseract language
99
  def ocr_lang(lang_list):
100
  lang_str = ""
101
  lang_len = len(lang_list)
 
115
  return ocr_str
116
 
117
 
118
+ # Clear
119
  def clear_content():
120
  return None
121
 
122
 
123
+ # copy to clipboard
124
  def cp_text(input_text):
125
  # sudo apt-get install xclip
126
  try:
 
130
  print(e)
131
 
132
 
133
+ # clear clipboard
134
  def cp_clear():
135
  pyclip.clear()
136
 
137
 
138
+ # translate
139
  def translate(input_text, inputs_transStyle):
140
+ # reference:https://huggingface.co/docs/transformers/model_doc/marian
141
  if input_text is None or input_text == "":
142
  return "System prompt: There is no content to translate!"
143
 
144
+ # Select translation model
145
  trans_src, trans_trg = inputs_transStyle.split("-")[0], inputs_transStyle.split("-")[1]
146
  tokenizer, model = model_choice(trans_src, trans_trg)
147
 
 
168
  with gr.Blocks(css='style.css') as ocr_tr:
169
  gr.Markdown(OCR_TR_DESCRIPTION)
170
 
171
+ # -------------- OCR text extraction --------------
172
  with gr.Box():
173
 
174
  with gr.Row():
 
205
  ["./data/test03.png", ["chi_sim"]]]
206
  gr.Examples(example_list, [inputs_img, inputs_lang], outputs_text, ocr_tesseract, cache_examples=False)
207
 
208
+ # -------------- translate --------------
209
  with gr.Box():
210
 
211
  with gr.Row():
 
223
  outputs_text,])
224
  clear_img_btn.click(fn=clear_content, inputs=[], outputs=[inputs_img])
225
 
226
+ # ---------------------- translate ----------------------
227
  translate_btn.click(fn=translate, inputs=[outputs_text, inputs_transStyle], outputs=[outputs_tr_text])
228
  clear_text_btn.click(fn=clear_content, inputs=[], outputs=[outputs_text])
229
 
230
+ # ---------------------- copy to clipboard ----------------------
231
  cp_btn.click(fn=cp_text, inputs=[outputs_tr_text], outputs=[])
232
  cp_clear_btn.click(fn=cp_clear, inputs=[], outputs=[])
233