procrastinya commited on
Commit
1fd054e
1 Parent(s): 92d9076

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +391 -0
  2. module.py +104 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ import gradio as gr
3
+ import random
4
+ import string
5
+ import paddlehub as hub
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ from loguru import logger
9
+
10
+ language_translation_model = hub.Module(directory=f'./baidu_translate')
11
+ def getTextTrans(text, source='zh', target='en'):
12
+ def is_chinese(string):
13
+ for ch in string:
14
+ if u'\u4e00' <= ch <= u'\u9fff':
15
+ return True
16
+ return False
17
+
18
+ if not is_chinese(text) and target == 'en':
19
+ return text
20
+
21
+ try:
22
+ text_translation = language_translation_model.translate(text, source, target)
23
+ return text_translation
24
+ except Exception as e:
25
+ return text
26
+
27
+ space_ids = {
28
+ "spaces/stabilityai/stable-diffusion": "SD 2.1",
29
+ "spaces/runwayml/stable-diffusion-v1-5": "SD 1.5",
30
+ "spaces/stabilityai/stable-diffusion-1": "SD 1.0",
31
+ "dalle_mini_tab": "Dalle mini",
32
+ "spaces/IDEA-CCNL/Taiyi-Stable-Diffusion-Chinese": "Taiyi(太乙)",
33
+ }
34
+
35
+ tab_actions = []
36
+ tab_titles = []
37
+
38
+ extend_prompt_1 = True
39
+ extend_prompt_2 = True
40
+ extend_prompt_3 = True
41
+
42
+ thanks_info = "Thanks: "
43
+ if extend_prompt_1:
44
+ extend_prompt_pipe = pipeline('text-generation', model='yizhangliu/prompt-extend', max_length=77, pad_token_id=0)
45
+ thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/daspartho/prompt-extend' _blank><font style='color:blue;weight:bold;'>prompt-extend(1)</font></a>]"
46
+ if extend_prompt_2:
47
+ def load_prompter():
48
+ prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
49
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
50
+ tokenizer.pad_token = tokenizer.eos_token
51
+ tokenizer.padding_side = "left"
52
+ return prompter_model, tokenizer
53
+ prompter_model, prompter_tokenizer = load_prompter()
54
+ def extend_prompt_microsoft(in_text):
55
+ input_ids = prompter_tokenizer(in_text.strip()+" Rephrase:", return_tensors="pt").input_ids
56
+ eos_id = prompter_tokenizer.eos_token_id
57
+ outputs = prompter_model.generate(input_ids, do_sample=False, max_new_tokens=75, num_beams=8, num_return_sequences=8, eos_token_id=eos_id, pad_token_id=eos_id, length_penalty=-1.0)
58
+ output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
59
+ res = output_texts[0].replace(in_text+" Rephrase:", "").strip()
60
+ return res
61
+ thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/microsoft/Promptist' _blank><font style='color:blue;weight:bold;'>Promptist(2)</font></a>]"
62
+ if extend_prompt_3:
63
+ MagicPrompt = gr.Interface.load("spaces/Gustavosta/MagicPrompt-Stable-Diffusion")
64
+ thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/Gustavosta/MagicPrompt-Stable-Diffusion' _blank><font style='color:blue;weight:bold;'>MagicPrompt(3)</font></a>]"
65
+
66
+ do_dreamlike_photoreal = False
67
+ if do_dreamlike_photoreal:
68
+ def add_random_noise(prompt, noise_level=0.1):
69
+ # Get the percentage of characters to add as noise
70
+ percentage_noise = noise_level * 5
71
+ # Get the number of characters to add as noise
72
+ num_noise_chars = int(len(prompt) * (percentage_noise/100))
73
+ # Get the indices of the characters to add noise to
74
+ noise_indices = random.sample(range(len(prompt)), num_noise_chars)
75
+ # Add noise to the selected characters
76
+ prompt_list = list(prompt)
77
+ for index in noise_indices:
78
+ prompt_list[index] = random.choice(string.ascii_letters + string.punctuation)
79
+ new_prompt = "".join(prompt_list)
80
+ return new_prompt
81
+
82
+ dreamlike_photoreal_2_0 = gr.Interface.load("models/dreamlike-art/dreamlike-photoreal-2.0")
83
+ dreamlike_image = gr.Image(label="Dreamlike Photoreal 2.0")
84
+
85
+ tab_actions.append(dreamlike_image)
86
+ tab_titles.append("Dreamlike_2.0")
87
+ thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/dreamlike-art/dreamlike-photoreal-2.0' _blank><font style='color:blue;weight:bold;'>dreamlike-photoreal-2.0</font></a>]"
88
+
89
+ for space_id in space_ids.keys():
90
+ print(space_id, space_ids[space_id])
91
+ try:
92
+ tab_title = space_ids[space_id]
93
+ tab_titles.append(tab_title)
94
+ if (tab_title == 'Dalle mini'):
95
+ tab_content = gr.Blocks(elem_id='dalle_mini')
96
+ tab_actions.append(tab_content)
97
+ else:
98
+ tab_content = gr.Interface.load(space_id)
99
+ tab_actions.append(tab_content)
100
+ thanks_info += f"[<a style='display:inline-block' href='https://huggingface.co/{space_id}' _blank><font style='color:blue;weight:bold;'>{tab_title}</font></a>]"
101
+ except Exception as e:
102
+ logger.info(f"load_fail__{space_id}_{e}")
103
+
104
+ start_work = """async() => {
105
+ function isMobile() {
106
+ try {
107
+ document.createEvent("TouchEvent"); return true;
108
+ } catch(e) {
109
+ return false;
110
+ }
111
+ }
112
+ function getClientHeight()
113
+ {
114
+ var clientHeight=0;
115
+ if(document.body.clientHeight&&document.documentElement.clientHeight) {
116
+ var clientHeight = (document.body.clientHeight<document.documentElement.clientHeight)?document.body.clientHeight:document.documentElement.clientHeight;
117
+ } else {
118
+ var clientHeight = (document.body.clientHeight>document.documentElement.clientHeight)?document.body.clientHeight:document.documentElement.clientHeight;
119
+ }
120
+ return clientHeight;
121
+ }
122
+
123
+ function setNativeValue(element, value) {
124
+ const valueSetter = Object.getOwnPropertyDescriptor(element.__proto__, 'value').set;
125
+ const prototype = Object.getPrototypeOf(element);
126
+ const prototypeValueSetter = Object.getOwnPropertyDescriptor(prototype, 'value').set;
127
+
128
+ if (valueSetter && valueSetter !== prototypeValueSetter) {
129
+ prototypeValueSetter.call(element, value);
130
+ } else {
131
+ valueSetter.call(element, value);
132
+ }
133
+ }
134
+ window['tab_advanced'] = 0;
135
+
136
+ var gradioEl = document.querySelector('body > gradio-app').shadowRoot;
137
+ if (!gradioEl) {
138
+ gradioEl = document.querySelector('body > gradio-app');
139
+ }
140
+
141
+ if (typeof window['gradioEl'] === 'undefined') {
142
+ window['gradioEl'] = gradioEl;
143
+ tabitems = window['gradioEl'].querySelectorAll('.tabitem');
144
+ tabitems_title = window['gradioEl'].querySelectorAll('#tab_demo')[0].children[0].children[0].children;
145
+ window['dalle_mini_block'] = null;
146
+ window['dalle_mini_iframe'] = null;
147
+ for (var i = 0; i < tabitems.length; i++) {
148
+ if (tabitems_title[i].innerText.indexOf('SD') >= 0) {
149
+ tabitems[i].childNodes[0].children[0].style.display='none';
150
+ for (var j = 0; j < tabitems[i].childNodes[0].children[1].children.length; j++) {
151
+ if (j != 1) {
152
+ tabitems[i].childNodes[0].children[1].children[j].style.display='none';
153
+ }
154
+ }
155
+ if (tabitems_title[i].innerText.indexOf('SD 1') >= 0) {
156
+ for (var j = 0; j < 4; j++) {
157
+ tabitems[i].childNodes[0].children[1].children[3].children[1].children[j].children[2].removeAttribute("disabled");
158
+ }
159
+ } else if (tabitems_title[i].innerText.indexOf('SD 2') >= 0) {
160
+ tabitems[i].children[0].children[1].children[3].children[0].click();
161
+ }
162
+ } else if (tabitems_title[i].innerText.indexOf('Taiyi') >= 0) {
163
+ tabitems[i].children[0].children[0].children[1].style.display='none';
164
+ tabitems[i].children[0].children[0].children[0].children[0].children[1].style.display='none';
165
+ } else if (tabitems_title[i].innerText.indexOf('Dreamlike') >= 0) {
166
+ tabitems[i].childNodes[0].children[0].children[1].style.display='none';
167
+ } else if (tabitems_title[i].innerText.indexOf('Dalle mini') >= 0) {
168
+ window['dalle_mini_block']= tabitems[i];
169
+ }
170
+ }
171
+
172
+ tab_demo = window['gradioEl'].querySelectorAll('#tab_demo')[0];
173
+ tab_demo.style.display = "block";
174
+ tab_demo.setAttribute('style', 'height: 100%;');
175
+ const page1 = window['gradioEl'].querySelectorAll('#page_1')[0];
176
+ const page2 = window['gradioEl'].querySelectorAll('#page_2')[0];
177
+
178
+ btns_1 = window['gradioEl'].querySelector('#input_col1_row3').children;
179
+ btns_1_split = 100 / btns_1.length;
180
+ for (var i = 0; i < btns_1.length; i++) {
181
+ btns_1[i].setAttribute('style', 'min-width:0px;width:' + btns_1_split + '%;');
182
+ }
183
+ page1.style.display = "none";
184
+ page2.style.display = "block";
185
+ prompt_work = window['gradioEl'].querySelectorAll('#prompt_work');
186
+ for (var i = 0; i < prompt_work.length; i++) {
187
+ prompt_work[i].style.display='none';
188
+ }
189
+
190
+ window['prevPrompt'] = '';
191
+ window['doCheckPrompt'] = 0;
192
+ window['checkPrompt'] = function checkPrompt() {
193
+ try {
194
+ prompt_work = window['gradioEl'].querySelectorAll('#prompt_work');
195
+ if (prompt_work.length > 0 && prompt_work[0].children.length > 1) {
196
+ prompt_work[0].children[1].style.display='none';
197
+ prompt_work[0].style.display='block';
198
+ }
199
+ text_value = window['gradioEl'].querySelectorAll('#prompt_work')[0].querySelectorAll('textarea')[0].value;
200
+ progress_bar = window['gradioEl'].querySelectorAll('.progress-bar');
201
+ if (window['doCheckPrompt'] === 0 && window['prevPrompt'] !== text_value && progress_bar.length == 0) {
202
+ console.log('_____new prompt___[' + text_value + ']_');
203
+ window['doCheckPrompt'] = 1;
204
+ window['prevPrompt'] = text_value;
205
+ tabitems = window['gradioEl'].querySelectorAll('.tabitem');
206
+ for (var i = 0; i < tabitems.length; i++) {
207
+ if (tabitems_title[i].innerText.indexOf('Dalle mini') >= 0) {
208
+ if (window['dalle_mini_block']) {
209
+ if (window['dalle_mini_iframe'] === null) {
210
+ window['dalle_mini_iframe'] = document.createElement('iframe');
211
+ window['dalle_mini_iframe'].height = 1000;
212
+ window['dalle_mini_iframe'].width = '100%';
213
+ window['dalle_mini_iframe'].id = 'dalle_iframe';
214
+ window['dalle_mini_block'].appendChild(window['dalle_mini_iframe']);
215
+ }
216
+ window['dalle_mini_iframe'].src = 'https://yizhangliu-dalleclone.hf.space/index.html?prompt=' + encodeURI(text_value);
217
+ console.log('dalle_mini');
218
+ }
219
+ continue;
220
+ }
221
+ inputText = null;
222
+ if (tabitems_title[i].innerText.indexOf('SD') >= 0) {
223
+ text_value = window['gradioEl'].querySelectorAll('#prompt_work')[0].querySelectorAll('textarea')[0].value;
224
+ inputText = tabitems[i].children[0].children[1].children[0].querySelectorAll('.gr-text-input')[0];
225
+ } else if (tabitems_title[i].innerText.indexOf('Taiyi') >= 0) {
226
+ text_value = window['gradioEl'].querySelectorAll('#prompt_work_zh')[0].querySelectorAll('textarea')[0].value;
227
+ inputText = tabitems[i].children[0].children[0].children[1].querySelectorAll('.gr-text-input')[0];
228
+ }
229
+ if (inputText) {
230
+ setNativeValue(inputText, text_value);
231
+ inputText.dispatchEvent(new Event('input', { bubbles: true }));
232
+ }
233
+ }
234
+
235
+ setTimeout(function() {
236
+ btns = window['gradioEl'].querySelectorAll('button');
237
+ for (var i = 0; i < btns.length; i++) {
238
+ if (['Generate image','Run', '生成图像(Generate)'].includes(btns[i].innerText)) {
239
+ btns[i].click();
240
+ }
241
+ }
242
+ window['doCheckPrompt'] = 0;
243
+ }, 10);
244
+ }
245
+ } catch(e) {
246
+ }
247
+ }
248
+ window['checkPrompt_interval'] = window.setInterval("window.checkPrompt()", 100);
249
+ }
250
+
251
+ return false;
252
+ }"""
253
+
254
+ switch_tab_advanced = """async() => {
255
+ window['tab_advanced'] = 1 - window['tab_advanced'];
256
+ if (window['tab_advanced']==0) {
257
+ action = 'none';
258
+ } else {
259
+ action = 'block';
260
+ }
261
+ tabitems = window['gradioEl'].querySelectorAll('.tabitem');
262
+ tabitems_title = window['gradioEl'].querySelectorAll('#tab_demo')[0].children[0].children[0].children;
263
+ for (var i = 0; i < tabitems.length; i++) {
264
+ if (tabitems_title[i].innerText.indexOf('SD') >= 0) {
265
+ //tabitems[i].childNodes[0].children[1].children[0].style.display=action;
266
+ //tabitems[i].childNodes[0].children[1].children[4].style.display=action;
267
+ for (var j = 0; j < tabitems[i].childNodes[0].children[1].children.length; j++) {
268
+ if (j != 1) {
269
+ tabitems[i].childNodes[0].children[1].children[j].style.display=action;
270
+ }
271
+ }
272
+ } else if (tabitems_title[i].innerText.indexOf('Taiyi') >= 0) {
273
+ tabitems[i].children[0].children[0].children[1].style.display=action;
274
+ }
275
+ }
276
+ return false;
277
+ }"""
278
+
279
+ def prompt_extend(prompt, PM):
280
+ prompt_en = getTextTrans(prompt, source='zh', target='en')
281
+ if PM == 1:
282
+ extend_prompt_en = extend_prompt_pipe(prompt_en+',', num_return_sequences=1)[0]["generated_text"]
283
+ elif PM == 2:
284
+ extend_prompt_en = extend_prompt_microsoft(prompt_en)
285
+ elif PM == 3:
286
+ extend_prompt_en = MagicPrompt(prompt_en)
287
+
288
+ if (prompt != prompt_en):
289
+ logger.info(f"extend_prompt__1_PM=[{PM}]_")
290
+ extend_prompt_out = getTextTrans(extend_prompt_en, source='en', target='zh')
291
+ else:
292
+ logger.info(f"extend_prompt__2_PM=[{PM}]_")
293
+ extend_prompt_out = extend_prompt_en
294
+
295
+ return extend_prompt_out
296
+
297
+ def prompt_extend_1(prompt):
298
+ extend_prompt_out = prompt_extend(prompt, 1)
299
+ return extend_prompt_out
300
+
301
+ def prompt_extend_2(prompt):
302
+ extend_prompt_out = prompt_extend(prompt, 2)
303
+ return extend_prompt_out
304
+
305
+ def prompt_extend_3(prompt):
306
+ extend_prompt_out = prompt_extend(prompt, 3)
307
+ return extend_prompt_out
308
+
309
+ def prompt_draw_1(prompt, noise_level):
310
+ prompt_en = getTextTrans(prompt, source='zh', target='en')
311
+ if (prompt != prompt_en):
312
+ logger.info(f"draw_prompt______1__")
313
+ prompt_zh = prompt
314
+ else:
315
+ logger.info(f"draw_prompt______2__")
316
+ prompt_zh = getTextTrans(prompt, source='en', target='zh')
317
+
318
+ prompt_with_noise = add_random_noise(prompt_en, noise_level)
319
+ dreamlike_output = dreamlike_photoreal_2_0(prompt_with_noise)
320
+ return prompt_en, prompt_zh, dreamlike_output
321
+
322
+ def prompt_draw_2(prompt):
323
+ prompt_en = getTextTrans(prompt, source='zh', target='en')
324
+ if (prompt != prompt_en):
325
+ logger.info(f"draw_prompt______1__")
326
+ prompt_zh = prompt
327
+ else:
328
+ logger.info(f"draw_prompt______2__")
329
+ prompt_zh = getTextTrans(prompt, source='en', target='zh')
330
+ return prompt_en, prompt_zh
331
+
332
+ with gr.Blocks(title='Text-to-Image') as demo:
333
+ with gr.Group(elem_id="page_1", visible=True) as page_1:
334
+ with gr.Box():
335
+ with gr.Row():
336
+ start_button = gr.Button("Let's GO!", elem_id="start-btn", visible=True)
337
+ start_button.click(fn=None, inputs=[], outputs=[], _js=start_work)
338
+
339
+ with gr.Group(elem_id="page_2", visible=False) as page_2:
340
+ with gr.Row(elem_id="prompt_row0"):
341
+ with gr.Column(id="input_col1"):
342
+ with gr.Row(elem_id="input_col1_row1"):
343
+ prompt_input0 = gr.Textbox(lines=2, label="Original prompt", visible=True)
344
+ with gr.Row(elem_id="input_col1_row2"):
345
+ prompt_work = gr.Textbox(lines=1, label="prompt_work", elem_id="prompt_work", visible=True)
346
+ with gr.Row(elem_id="input_col1_row3"):
347
+ with gr.Column(elem_id="input_col1_row2_col0"):
348
+ draw_btn_0 = gr.Button(value = "Generate(original)", elem_id="draw-btn-0")
349
+ if extend_prompt_1:
350
+ with gr.Column(elem_id="input_col1_row2_col1"):
351
+ extend_btn_1 = gr.Button(value = "Extend_1",elem_id="extend-btn-1")
352
+ if extend_prompt_2:
353
+ with gr.Column(elem_id="input_col1_row2_col2"):
354
+ extend_btn_2 = gr.Button(value = "Extend_2",elem_id="extend-btn-2")
355
+ if extend_prompt_3:
356
+ with gr.Column(elem_id="input_col1_row2_col3"):
357
+ extend_btn_3 = gr.Button(value = "Extend_3",elem_id="extend-btn-3")
358
+ with gr.Column(id="input_col2"):
359
+ prompt_input1 = gr.Textbox(lines=2, label="Extend prompt", visible=True)
360
+ draw_btn_1 = gr.Button(value = "Generate(extend)", elem_id="draw-btn-1")
361
+ with gr.Row(elem_id="prompt_row1"):
362
+ with gr.Column(id="input_col3"):
363
+ with gr.Row(elem_id="input_col3_row2"):
364
+ prompt_work_zh = gr.Textbox(lines=1, label="prompt_work_zh", elem_id="prompt_work_zh", visible=False)
365
+ with gr.Row(elem_id='tab_demo', visible=True).style(height=200):
366
+ tab_demo = gr.TabbedInterface(tab_actions, tab_titles)
367
+ if do_dreamlike_photoreal:
368
+ with gr.Row():
369
+ noise_level=gr.Slider(minimum=0.1, maximum=3, step=0.1, label="Dreamlike noise Level: [Higher noise level produces more diverse outputs, while lower noise level produces similar outputs.]")
370
+ with gr.Row():
371
+ switch_tab_advanced_btn = gr.Button(value = "Switch_tab_advanced", elem_id="switch_tab_advanced_btn")
372
+ switch_tab_advanced_btn.click(fn=None, inputs=[], outputs=[], _js=switch_tab_advanced)
373
+ with gr.Row():
374
+ gr.HTML(f"<p>{thanks_info}</p>")
375
+
376
+ if extend_prompt_1:
377
+ extend_btn_1.click(fn=prompt_extend_1, inputs=[prompt_input0], outputs=[prompt_input1])
378
+ if extend_prompt_2:
379
+ extend_btn_2.click(fn=prompt_extend_2, inputs=[prompt_input0], outputs=[prompt_input1])
380
+ if extend_prompt_3:
381
+ extend_btn_3.click(fn=prompt_extend_3, inputs=[prompt_input0], outputs=[prompt_input1])
382
+
383
+ if do_dreamlike_photoreal:
384
+ draw_btn_0.click(fn=prompt_draw_1, inputs=[prompt_input0, noise_level], outputs=[prompt_work, prompt_work_zh, dreamlike_image])
385
+ draw_btn_1.click(fn=prompt_draw_1, inputs=[prompt_input1, noise_level], outputs=[prompt_work, prompt_work_zh, dreamlike_image])
386
+ else:
387
+ draw_btn_0.click(fn=prompt_draw_2, inputs=[prompt_input0], outputs=[prompt_work, prompt_work_zh])
388
+ draw_btn_1.click(fn=prompt_draw_2, inputs=[prompt_input1], outputs=[prompt_work, prompt_work_zh])
389
+
390
+ demo.queue()
391
+ demo.launch()
module.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import random
3
+ from hashlib import md5
4
+ from typing import Optional
5
+
6
+ import requests
7
+
8
+ import paddlehub as hub
9
+ from paddlehub.module.module import moduleinfo
10
+ from paddlehub.module.module import runnable
11
+ from paddlehub.module.module import serving
12
+
13
+
14
+ def make_md5(s, encoding='utf-8'):
15
+ return md5(s.encode(encoding)).hexdigest()
16
+
17
+
18
+ @moduleinfo(name="baidu_translate",
19
+ version="1.0.0",
20
+ type="text/machine_translation",
21
+ summary="",
22
+ author="baidu-nlp",
23
+ author_email="paddle-dev@baidu.com")
24
+ class BaiduTranslate:
25
+
26
+ def __init__(self, appid=None, appkey=None):
27
+ """
28
+ :param appid: appid for requesting Baidu translation service.
29
+ :param appkey: appkey for requesting Baidu translation service.
30
+ """
31
+ # Set your own appid/appkey.
32
+ if appid == None:
33
+ self.appid = '20201015000580007'
34
+ else:
35
+ self.appid = appid
36
+ if appkey is None:
37
+ self.appkey = 'IFJB6jBORFuMmVGDRud1'
38
+ else:
39
+ self.appkey = appkey
40
+ self.url = 'http://api.fanyi.baidu.com/api/trans/vip/translate'
41
+
42
+ def translate(self, query: str, from_lang: Optional[str] = "en", to_lang: Optional[int] = "zh"):
43
+ """
44
+ Create image by text prompts using ErnieVilG model.
45
+
46
+ :param query: Text to be translated.
47
+ :param from_lang: Source language.
48
+ :param to_lang: Dst language.
49
+
50
+ Return translated string.
51
+ """
52
+ # Generate salt and sign
53
+ salt = random.randint(32768, 65536)
54
+ sign = make_md5(self.appid + query + str(salt) + self.appkey)
55
+
56
+ # Build request
57
+ headers = {'Content-Type': 'application/x-www-form-urlencoded'}
58
+ payload = {'appid': self.appid, 'q': query, 'from': from_lang, 'to': to_lang, 'salt': salt, 'sign': sign}
59
+
60
+ # Send request
61
+ try:
62
+ r = requests.post(self.url, params=payload, headers=headers)
63
+ result = r.json()
64
+ except Exception as e:
65
+ error_msg = str(e)
66
+ raise RuntimeError(error_msg)
67
+ if 'error_code' in result:
68
+ raise RuntimeError(result['error_msg'])
69
+ return result['trans_result'][0]['dst']
70
+
71
+ @runnable
72
+ def run_cmd(self, argvs):
73
+ """
74
+ Run as a command.
75
+ """
76
+ self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
77
+ prog='hub run {}'.format(self.name),
78
+ usage='%(prog)s',
79
+ add_help=True)
80
+ self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
81
+ self.add_module_input_arg()
82
+ args = self.parser.parse_args(argvs)
83
+ if args.appid is not None and args.appkey is not None:
84
+ self.appid = args.appid
85
+ self.appkey = args.appkey
86
+ result = self.translate(args.query, args.from_lang, args.to_lang)
87
+ return result
88
+
89
+ @serving
90
+ def serving_method(self, query, from_lang, to_lang):
91
+ """
92
+ Run as a service.
93
+ """
94
+ return self.translate(query, from_lang, to_lang)
95
+
96
+ def add_module_input_arg(self):
97
+ """
98
+ Add the command input options.
99
+ """
100
+ self.arg_input_group.add_argument('--query', type=str)
101
+ self.arg_input_group.add_argument('--from_lang', type=str, default='en', help="源语言")
102
+ self.arg_input_group.add_argument('--to_lang', type=str, default='zh', help="目标语言")
103
+ self.arg_input_group.add_argument('--appid', type=str, default=None, help="注册得到的个人appid")
104
+ self.arg_input_group.add_argument('--appkey', type=str, default=None, help="注册得到的个人appkey")
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ ftfy
2
+ spacy
3
+ transformers
4
+ torch
5
+ gradio
6
+ paddlepaddle==2.3.2
7
+ paddlehub
8
+ loguru