liuyizhang commited on
Commit
a6c7ac9
1 Parent(s): 6075ff9

add text_to_image

Browse files
Files changed (3) hide show
  1. app.py +121 -24
  2. baidu_translate/module.py +104 -0
  3. requirements.txt +3 -1
app.py CHANGED
@@ -1,9 +1,18 @@
1
  from pyChatGPT import ChatGPT
2
  import gradio as gr
3
- import os, json
4
  from loguru import logger
 
5
  import random
6
 
 
 
 
 
 
 
 
 
7
  session_token = os.environ.get('SessionToken')
8
  # logger.info(f"session_token_: {session_token}")
9
 
@@ -19,15 +28,42 @@ def get_response_from_chatbot(text):
19
  response = "Sorry, I'm busy. Try again later."
20
  return response
21
 
22
- def chat(message, chat_history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  out_chat = []
24
  if chat_history != '':
25
  out_chat = json.loads(chat_history)
26
- response = get_response_from_chatbot(message)
27
- out_chat.append((message, response))
28
- chat_history = json.dumps(out_chat)
29
- logger.info(f"out_chat_: {len(out_chat)}")
30
- return out_chat, chat_history
 
 
 
 
 
31
 
32
  start_work = """async() => {
33
  function isMobile() {
@@ -84,9 +120,9 @@ start_work = """async() => {
84
  if (isMobile()) {
85
  window['gradioEl'].querySelectorAll('#component-1')[0].style.display = "none";
86
  window['gradioEl'].querySelectorAll('#component-2')[0].style.display = "none";
87
- new_height = (clientHeight - 200) + 'px';
88
  } else {
89
- new_height = (clientHeight - 300) + 'px';
90
  }
91
  chat_row.style.height = new_height;
92
  window['chat_bot'].style.height = new_height;
@@ -95,23 +131,80 @@ start_work = """async() => {
95
  window['chat_bot1'].children[2].style.height = new_height;
96
  prompt_row.children[0].style.flex = 'auto';
97
  prompt_row.children[0].style.width = '100%';
 
 
98
  prompt_row.children[0].setAttribute('style','flex-direction: inherit; flex: 1 1 auto; width: 100%;border-color: green;border-width: 1px !important;')
99
-
 
 
 
100
  window['checkChange'] = function checkChange() {
101
  try {
102
- if (window['chat_bot'].children[2].children[0].children.length > window['div_count']) {
103
- new_len = window['chat_bot'].children[2].children[0].children.length - window['div_count'];
104
- for (var i = 0; i < new_len; i++) {
105
- new_div = window['chat_bot'].children[2].children[0].children[window['div_count'] + i].cloneNode(true);
106
- window['chat_bot1'].children[2].children[0].appendChild(new_div);
 
 
 
 
 
 
 
 
 
107
  }
108
- window['div_count'] = chat_bot.children[2].children[0].children.length;
109
- window['chat_bot1'].children[2].scrollTop = window['chat_bot1'].children[2].scrollHeight;
110
- }
111
- if (window['chat_bot'].children[0].children.length > 1) {
112
- window['chat_bot1'].children[1].textContent = window['chat_bot'].children[0].children[1].textContent;
113
  } else {
114
- window['chat_bot1'].children[1].textContent = '';
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  }
116
 
117
  } catch(e) {
@@ -138,16 +231,20 @@ with gr.Blocks(title='Talk to chatGPT') as demo:
138
  chatbot = gr.Chatbot(elem_id="chat_bot", visible=False).style(color_map=("green", "blue"))
139
  chatbot1 = gr.Chatbot(elem_id="chat_bot1").style(color_map=("green", "blue"))
140
  with gr.Row(elem_id="prompt_row"):
141
- prompt_input = gr.Textbox(lines=2, label="prompt",show_label=False)
 
142
  chat_history = gr.Textbox(lines=4, label="prompt", visible=False)
 
143
  submit_btn = gr.Button(value = "submit",elem_id="submit-btn").style(
144
  margin=True,
145
  rounded=(True, True, True, True),
146
  width=100
147
  )
148
  submit_btn.click(fn=chat,
149
- inputs=[prompt_input, chat_history],
150
- outputs=[chatbot, chat_history],
151
  )
 
 
152
 
153
  demo.launch(debug = True)
 
1
  from pyChatGPT import ChatGPT
2
  import gradio as gr
3
+ import os, sys, json
4
  from loguru import logger
5
+ import paddlehub as hub
6
  import random
7
 
8
+ language_translation_model = hub.Module(directory=f'./baidu_translate')
9
+ def getTextTrans(text, source='zh', target='en'):
10
+ try:
11
+ text_translation = language_translation_model.translate(text, source, target)
12
+ return text_translation
13
+ except Exception as e:
14
+ return text
15
+
16
  session_token = os.environ.get('SessionToken')
17
  # logger.info(f"session_token_: {session_token}")
18
 
 
28
  response = "Sorry, I'm busy. Try again later."
29
  return response
30
 
31
+ model_ids = {
32
+ # "models/stabilityai/stable-diffusion-2-1":"sd-v2-1",
33
+ # "models/stabilityai/stable-diffusion-2":"sd-v2-0",
34
+ # "models/runwayml/stable-diffusion-v1-5":"sd-v1-5",
35
+ # "models/CompVis/stable-diffusion-v1-4":"sd-v1-4",
36
+ "models/prompthero/openjourney":"openjourney",
37
+ # "models/ShadoWxShinigamI/Midjourney-Rangoli":"midjourney",
38
+ # "models/hakurei/waifu-diffusion":"waifu-diffusion",
39
+ # "models/Linaqruf/anything-v3.0":"anything-v3.0",
40
+ }
41
+
42
+ tab_actions = []
43
+ tab_titles = []
44
+ for model_id in model_ids.keys():
45
+ print(model_id, model_ids[model_id])
46
+ try:
47
+ tab = gr.Interface.load(model_id)
48
+ tab_actions.append(tab)
49
+ tab_titles.append(model_ids[model_id])
50
+ except:
51
+ logger.info(f"load_fail__{model_id}_")
52
+
53
+ def chat(input0, input1, chat_radio, chat_history):
54
  out_chat = []
55
  if chat_history != '':
56
  out_chat = json.loads(chat_history)
57
+ if chat_radio == "Talk to chatGPT":
58
+ response = get_response_from_chatbot(input0)
59
+ out_chat.append((input0, response))
60
+ chat_history = json.dumps(out_chat)
61
+ logger.info(f"out_chat_: {len(out_chat)} / {chat_radio}")
62
+ return out_chat, input1, chat_history
63
+ else:
64
+ prompt_en = getTextTrans(input0, source='zh', target='en') + f',{random.randint(0,sys.maxsize)}'
65
+ return out_chat, prompt_en, chat_history
66
+
67
 
68
  start_work = """async() => {
69
  function isMobile() {
 
120
  if (isMobile()) {
121
  window['gradioEl'].querySelectorAll('#component-1')[0].style.display = "none";
122
  window['gradioEl'].querySelectorAll('#component-2')[0].style.display = "none";
123
+ new_height = (clientHeight - 250) + 'px';
124
  } else {
125
+ new_height = (clientHeight - 350) + 'px';
126
  }
127
  chat_row.style.height = new_height;
128
  window['chat_bot'].style.height = new_height;
 
131
  window['chat_bot1'].children[2].style.height = new_height;
132
  prompt_row.children[0].style.flex = 'auto';
133
  prompt_row.children[0].style.width = '100%';
134
+ window['gradioEl'].querySelectorAll('#chat_radio')[0].style.flex = 'auto';
135
+ window['gradioEl'].querySelectorAll('#chat_radio')[0].style.width = '100%';
136
  prompt_row.children[0].setAttribute('style','flex-direction: inherit; flex: 1 1 auto; width: 100%;border-color: green;border-width: 1px !important;')
137
+
138
+ window['prevPrompt'] = '';
139
+ window['doCheckPrompt'] = 0;
140
+ window['prevImgSrc'] = '';
141
  window['checkChange'] = function checkChange() {
142
  try {
143
+ if (window['gradioEl'].querySelectorAll('.gr-radio')[0].checked) {
144
+ if (window['chat_bot'].children[2].children[0].children.length > window['div_count']) {
145
+ new_len = window['chat_bot'].children[2].children[0].children.length - window['div_count'];
146
+ for (var i = 0; i < new_len; i++) {
147
+ new_div = window['chat_bot'].children[2].children[0].children[window['div_count'] + i].cloneNode(true);
148
+ window['chat_bot1'].children[2].children[0].appendChild(new_div);
149
+ }
150
+ window['div_count'] = chat_bot.children[2].children[0].children.length;
151
+ window['chat_bot1'].children[2].scrollTop = window['chat_bot1'].children[2].scrollHeight;
152
+ }
153
+ if (window['chat_bot'].children[0].children.length > 1) {
154
+ window['chat_bot1'].children[1].textContent = window['chat_bot'].children[0].children[1].textContent;
155
+ } else {
156
+ window['chat_bot1'].children[1].textContent = '';
157
  }
 
 
 
 
 
158
  } else {
159
+ texts = window['gradioEl'].querySelectorAll('textarea');
160
+ text0 = texts[0];
161
+ text1 = texts[1];
162
+ img_index = 0;
163
+ if (window['doCheckPrompt'] === 0 && window['prevPrompt'] !== text1.value) {
164
+ console.log('_____new prompt___[' + text1.value + ']_');
165
+ window['doCheckPrompt'] = 1;
166
+ window['prevPrompt'] = text1.value;
167
+ for (var i = 3; i < texts.length; i++) {
168
+ setNativeValue(texts[i], text1.value);
169
+ texts[i].dispatchEvent(new Event('input', { bubbles: true }));
170
+ }
171
+ setTimeout(function() {
172
+ img_submit_btns = window['gradioEl'].querySelectorAll('#tab_img')[0].querySelectorAll("button");
173
+ for (var i = 0; i < img_submit_btns.length; i++) {
174
+ if (img_submit_btns[i].innerText == 'Submit') {
175
+ img_submit_btns[i].click();
176
+ }
177
+ }
178
+ window['doCheckPrompt'] = 0;
179
+ }, 10);
180
+ }
181
+ tabitems = window['gradioEl'].querySelectorAll('.tabitem');
182
+ imgs = tabitems[img_index].children[0].children[1].children[1].children[0].querySelectorAll("img");
183
+ if (imgs.length > 0) {
184
+ if (window['prevImgSrc'] !== imgs[0].src) {
185
+ var user_div = document.createElement("div");
186
+ user_div.className = "px-3 py-2 rounded-[22px] rounded-br-none text-white text-sm chat-message svelte-rct66g";
187
+ user_div.style.backgroundColor = "#16a34a";
188
+ user_div.innerHTML = "<p>" + text0.value + "</p>";
189
+ window['chat_bot1'].children[2].children[0].appendChild(user_div);
190
+
191
+ var bot_div = document.createElement("div");
192
+ bot_div.className = "px-3 py-2 rounded-[22px] rounded-bl-none place-self-start text-white text-sm chat-message svelte-rct66g";
193
+ bot_div.style.backgroundColor = "#2563eb";
194
+ bot_div.style.width = "50%";
195
+ bot_div.style.padding = "0.2rem";
196
+ bot_div.appendChild(imgs[0].cloneNode(true));
197
+ window['chat_bot1'].children[2].children[0].appendChild(bot_div);
198
+
199
+ window['chat_bot1'].children[2].scrollTop = window['chat_bot1'].children[2].scrollHeight;
200
+ window['prevImgSrc'] = imgs[0].src;
201
+ }
202
+ }
203
+ if (tabitems[img_index].children[0].children[1].children[1].children[0].children[0].children.length > 1) {
204
+ window['chat_bot1'].children[1].textContent = tabitems[img_index].children[0].children[1].children[1].children[0].children[0].children[1].textContent;
205
+ } else {
206
+ window['chat_bot1'].children[1].textContent = '';
207
+ }
208
  }
209
 
210
  } catch(e) {
 
231
  chatbot = gr.Chatbot(elem_id="chat_bot", visible=False).style(color_map=("green", "blue"))
232
  chatbot1 = gr.Chatbot(elem_id="chat_bot1").style(color_map=("green", "blue"))
233
  with gr.Row(elem_id="prompt_row"):
234
+ prompt_input0 = gr.Textbox(lines=2, label="prompt",show_label=False)
235
+ prompt_input1 = gr.Textbox(lines=4, label="prompt", visible=False)
236
  chat_history = gr.Textbox(lines=4, label="prompt", visible=False)
237
+ chat_radio = gr.Radio(["Talk to chatGPT", "Text to Image"], elem_id="chat_radio",value="Talk to chatGPT", show_label=False)
238
  submit_btn = gr.Button(value = "submit",elem_id="submit-btn").style(
239
  margin=True,
240
  rounded=(True, True, True, True),
241
  width=100
242
  )
243
  submit_btn.click(fn=chat,
244
+ inputs=[prompt_input0, prompt_input1, chat_radio, chat_history],
245
+ outputs=[chatbot, prompt_input1, chat_history],
246
  )
247
+ with gr.Row(elem_id='tab_img', visible=False).style(height=5):
248
+ tab_img = gr.TabbedInterface(tab_actions, tab_titles)
249
 
250
  demo.launch(debug = True)
baidu_translate/module.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import random
3
+ from hashlib import md5
4
+ from typing import Optional
5
+
6
+ import requests
7
+
8
+ import paddlehub as hub
9
+ from paddlehub.module.module import moduleinfo
10
+ from paddlehub.module.module import runnable
11
+ from paddlehub.module.module import serving
12
+
13
+
14
+ def make_md5(s, encoding='utf-8'):
15
+ return md5(s.encode(encoding)).hexdigest()
16
+
17
+
18
+ @moduleinfo(name="baidu_translate",
19
+ version="1.0.0",
20
+ type="text/machine_translation",
21
+ summary="",
22
+ author="baidu-nlp",
23
+ author_email="paddle-dev@baidu.com")
24
+ class BaiduTranslate:
25
+
26
+ def __init__(self, appid=None, appkey=None):
27
+ """
28
+ :param appid: appid for requesting Baidu translation service.
29
+ :param appkey: appkey for requesting Baidu translation service.
30
+ """
31
+ # Set your own appid/appkey.
32
+ if appid == None:
33
+ self.appid = '20201015000580007'
34
+ else:
35
+ self.appid = appid
36
+ if appkey is None:
37
+ self.appkey = 'IFJB6jBORFuMmVGDRud1'
38
+ else:
39
+ self.appkey = appkey
40
+ self.url = 'http://api.fanyi.baidu.com/api/trans/vip/translate'
41
+
42
+ def translate(self, query: str, from_lang: Optional[str] = "en", to_lang: Optional[int] = "zh"):
43
+ """
44
+ Create image by text prompts using ErnieVilG model.
45
+
46
+ :param query: Text to be translated.
47
+ :param from_lang: Source language.
48
+ :param to_lang: Dst language.
49
+
50
+ Return translated string.
51
+ """
52
+ # Generate salt and sign
53
+ salt = random.randint(32768, 65536)
54
+ sign = make_md5(self.appid + query + str(salt) + self.appkey)
55
+
56
+ # Build request
57
+ headers = {'Content-Type': 'application/x-www-form-urlencoded'}
58
+ payload = {'appid': self.appid, 'q': query, 'from': from_lang, 'to': to_lang, 'salt': salt, 'sign': sign}
59
+
60
+ # Send request
61
+ try:
62
+ r = requests.post(self.url, params=payload, headers=headers)
63
+ result = r.json()
64
+ except Exception as e:
65
+ error_msg = str(e)
66
+ raise RuntimeError(error_msg)
67
+ if 'error_code' in result:
68
+ raise RuntimeError(result['error_msg'])
69
+ return result['trans_result'][0]['dst']
70
+
71
+ @runnable
72
+ def run_cmd(self, argvs):
73
+ """
74
+ Run as a command.
75
+ """
76
+ self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
77
+ prog='hub run {}'.format(self.name),
78
+ usage='%(prog)s',
79
+ add_help=True)
80
+ self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
81
+ self.add_module_input_arg()
82
+ args = self.parser.parse_args(argvs)
83
+ if args.appid is not None and args.appkey is not None:
84
+ self.appid = args.appid
85
+ self.appkey = args.appkey
86
+ result = self.translate(args.query, args.from_lang, args.to_lang)
87
+ return result
88
+
89
+ @serving
90
+ def serving_method(self, query, from_lang, to_lang):
91
+ """
92
+ Run as a service.
93
+ """
94
+ return self.translate(query, from_lang, to_lang)
95
+
96
+ def add_module_input_arg(self):
97
+ """
98
+ Add the command input options.
99
+ """
100
+ self.arg_input_group.add_argument('--query', type=str)
101
+ self.arg_input_group.add_argument('--from_lang', type=str, default='en', help="源语言")
102
+ self.arg_input_group.add_argument('--to_lang', type=str, default='zh', help="目标语言")
103
+ self.arg_input_group.add_argument('--appid', type=str, default=None, help="注册得到的个人appid")
104
+ self.arg_input_group.add_argument('--appkey', type=str, default=None, help="注册得到的个人appkey")
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  pyChatGPT
2
- loguru
 
 
 
1
  pyChatGPT
2
+ loguru
3
+ paddlepaddle==2.3.2
4
+ paddlehub