liuyizhang commited on
Commit
dc3f3fe
·
1 Parent(s): 1c6fa47

add prompt

Browse files
app.py CHANGED
@@ -1,6 +1,5 @@
1
-
2
  import gradio as gr
3
- import sys
4
  import random
5
  import paddlehub as hub
6
  from loguru import logger
@@ -12,32 +11,59 @@ def getTextTrans(text, source='zh', target='en'):
12
  return text_translation
13
  except Exception as e:
14
  return text
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- model_ids = {
18
- "models/stabilityai/stable-diffusion-2-1":"sd-v2-1",
19
- "models/stabilityai/stable-diffusion-2":"sd-v2-0",
20
- "models/runwayml/stable-diffusion-v1-5":"sd-v1-5",
21
- # "models/CompVis/stable-diffusion-v1-4":"sd-v1-4",
22
- "models/prompthero/openjourney":"openjourney",
23
- "models/hakurei/waifu-diffusion":"waifu-diffusion",
24
- "models/Linaqruf/anything-v3.0":"anything-v3.0",
25
- }
26
  tab_actions = []
27
  tab_titles = []
28
- for model_id in model_ids.keys():
29
- print(model_id, model_ids[model_id])
 
 
 
 
30
  try:
31
- tab = gr.Interface.load(model_id)
32
  tab_actions.append(tab)
33
- tab_titles.append(model_ids[model_id])
34
- except:
35
- logger.info(f"load_fail__{model_id}_")
36
-
37
- def infer(prompt):
38
- logger.info(f"infer_1__")
39
- prompt = getTextTrans(prompt, source='zh', target='en') + f',{random.randint(0,sys.maxsize)}'
40
- return prompt
41
 
42
  start_work = """async() => {
43
  function isMobile() {
@@ -47,7 +73,6 @@ start_work = """async() => {
47
  return false;
48
  }
49
  }
50
-
51
  function getClientHeight()
52
  {
53
  var clientHeight=0;
@@ -70,7 +95,6 @@ start_work = """async() => {
70
  valueSetter.call(element, value);
71
  }
72
  }
73
-
74
  var gradioEl = document.querySelector('body > gradio-app').shadowRoot;
75
  if (!gradioEl) {
76
  gradioEl = document.querySelector('body > gradio-app');
@@ -80,21 +104,35 @@ start_work = """async() => {
80
  window['gradioEl'] = gradioEl;
81
 
82
  tabitems = window['gradioEl'].querySelectorAll('.tabitem');
 
83
  for (var i = 0; i < tabitems.length; i++) {
84
- tabitems[i].childNodes[0].children[0].style.display='none';
85
- tabitems[i].childNodes[0].children[1].children[0].style.display='none';
86
- tabitems[i].childNodes[0].children[1].children[1].children[0].children[1].style.display="none";
87
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  tab_demo = window['gradioEl'].querySelectorAll('#tab_demo')[0];
89
  tab_demo.style.display = "block";
90
  tab_demo.setAttribute('style', 'height: 100%;');
91
-
92
  const page1 = window['gradioEl'].querySelectorAll('#page_1')[0];
93
- const page2 = window['gradioEl'].querySelectorAll('#page_2')[0];
 
 
94
 
95
  page1.style.display = "none";
96
  page2.style.display = "block";
97
-
98
  window['prevPrompt'] = '';
99
  window['doCheckPrompt'] = 0;
100
  window['checkPrompt'] = function checkPrompt() {
@@ -102,19 +140,32 @@ start_work = """async() => {
102
  texts = window['gradioEl'].querySelectorAll('textarea');
103
  text0 = texts[0];
104
  text1 = texts[1];
 
 
 
 
 
 
105
  progress_bar = window['gradioEl'].querySelectorAll('.progress-bar');
106
- if (window['doCheckPrompt'] === 0 && window['prevPrompt'] !== text1.value && progress_bar.length == 0) {
107
- console.log('_____new prompt___[' + text1.value + ']_');
108
  window['doCheckPrompt'] = 1;
109
- window['prevPrompt'] = text1.value;
110
- for (var i = 2; i < texts.length; i++) {
111
- setNativeValue(texts[i], text1.value);
112
- texts[i].dispatchEvent(new Event('input', { bubbles: true }));
113
- }
 
 
 
 
 
 
 
114
  setTimeout(function() {
115
  btns = window['gradioEl'].querySelectorAll('button');
116
  for (var i = 0; i < btns.length; i++) {
117
- if (btns[i].innerText == 'Submit') {
118
  btns[i].click();
119
  }
120
  }
@@ -130,28 +181,40 @@ start_work = """async() => {
130
  return false;
131
  }"""
132
 
133
- with gr.Blocks(title='Text to Image') as demo:
 
 
 
 
 
 
 
 
134
  with gr.Group(elem_id="page_1", visible=True) as page_1:
135
  with gr.Box():
136
  with gr.Row():
137
- start_button = gr.Button("Let's GO!", elem_id="start-btn", visible=True)
138
- start_button.click(fn=None, inputs=[], outputs=[], _js=start_work)
139
 
140
- with gr.Group(elem_id="page_2", visible=False) as page_2:
141
- with gr.Row(elem_id="prompt_row"):
142
- prompt_input0 = gr.Textbox(lines=4, label="prompt")
143
- prompt_input1 = gr.Textbox(lines=4, label="prompt", visible=False)
 
 
 
 
 
 
144
  with gr.Row():
145
- submit_btn = gr.Button(value = "submit",elem_id="erase-btn").style(
146
  margin=True,
147
  rounded=(True, True, True, True),
148
  )
149
- with gr.Row(elem_id='tab_demo', visible=True).style(height=5):
 
150
  tab_demo = gr.TabbedInterface(tab_actions, tab_titles)
 
 
151
 
152
- submit_btn.click(fn=infer, inputs=[prompt_input0], outputs=[prompt_input1])
153
-
154
- if __name__ == "__main__":
155
- demo.launch()
156
-
157
-
 
1
+ from transformers import pipeline
2
  import gradio as gr
 
3
  import random
4
  import paddlehub as hub
5
  from loguru import logger
 
11
  return text_translation
12
  except Exception as e:
13
  return text
14
+
15
+ extend_prompt_pipe = pipeline('text-generation', model='./model', max_length=77)
16
+
17
+ def extend_prompt(prompt):
18
+ prompt_en = getTextTrans(prompt, source='zh', target='en')
19
+ extend_prompt_en = extend_prompt_pipe(prompt_en+',', num_return_sequences=1)[0]["generated_text"]
20
+ if (prompt != prompt_en):
21
+ extend_prompt_zh = getTextTrans(extend_prompt_en, source='en', target='zh')
22
+ extend_prompt_out = f'{extend_prompt_zh} 【{extend_prompt_en}】'
23
+ else:
24
+ extend_prompt_out = extend_prompt_en
25
+
26
+ return prompt_en, extend_prompt_en, extend_prompt_out
27
 
28
+ examples = [
29
+ ['elon musk as thor'],
30
+ ["giant dragon flying in the sky"],
31
+ ['psychedelic liquids space'],
32
+ ["a coconut laying on the beach"],
33
+ ["peaceful village landscape"],
34
+ ]
35
+
36
+ # model_ids = {
37
+ # # "models/stabilityai/stable-diffusion-2-1":"sd-v2-1",
38
+ # "models/stabilityai/stable-diffusion-2":"sd-v2-0",
39
+ # # "models/runwayml/stable-diffusion-v1-5":"sd-v1-5",
40
+ # # "models/CompVis/stable-diffusion-v1-4":"sd-v1-4",
41
+ # "models/prompthero/openjourney":"openjourney",
42
+ # "models/hakurei/waifu-diffusion":"waifu-diffusion",
43
+ # "models/Linaqruf/anything-v3.0":"anything-v3.0",
44
+ # }
45
+
46
+ space_ids = {
47
+ "spaces/stabilityai/stable-diffusion":"Stable Diffusion 2.1",
48
+ "spaces/stabilityai/stable-diffusion-1":"Stable Diffusion 1.0",
49
+ # "spaces/hakurei/waifu-diffusion-demo":"waifu-diffusion",
50
+ }
51
 
 
 
 
 
 
 
 
 
 
52
  tab_actions = []
53
  tab_titles = []
54
+
55
+ thanks_info = "Thanks: "
56
+ thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/daspartho/prompt-extend' _blank><font style='color:blue;weight:bold;'>prompt-extend</font></a>]"
57
+
58
+ for space_id in space_ids.keys():
59
+ print(space_id, space_ids[space_id])
60
  try:
61
+ tab = gr.Interface.load(space_id)
62
  tab_actions.append(tab)
63
+ tab_titles.append(space_ids[space_id])
64
+ thanks_info += f"[<a style='display:inline-block' href='https://huggingface.co/{space_id}' _blank><font style='color:blue;weight:bold;'>{space_ids[space_id]}</font></a>]"
65
+ except Exception as e:
66
+ logger.info(f"load_fail__{space_id}_{e}")
 
 
 
 
67
 
68
  start_work = """async() => {
69
  function isMobile() {
 
73
  return false;
74
  }
75
  }
 
76
  function getClientHeight()
77
  {
78
  var clientHeight=0;
 
95
  valueSetter.call(element, value);
96
  }
97
  }
 
98
  var gradioEl = document.querySelector('body > gradio-app').shadowRoot;
99
  if (!gradioEl) {
100
  gradioEl = document.querySelector('body > gradio-app');
 
104
  window['gradioEl'] = gradioEl;
105
 
106
  tabitems = window['gradioEl'].querySelectorAll('.tabitem');
107
+
108
  for (var i = 0; i < tabitems.length; i++) {
109
+ if ([0, 1].includes(i)) {
110
+ tabitems[i].childNodes[0].children[0].style.display='none';
111
+ for (var j = 0; j < tabitems[i].childNodes[0].children[1].children.length; j++) {
112
+ if (j != 1) {
113
+ tabitems[i].childNodes[0].children[1].children[j].style.display='none';
114
+ }
115
+ }
116
+ } else if (i==2) {
117
+ tabitems[i].childNodes[0].children[0].style.display='none';
118
+ tabitems[i].childNodes[0].children[1].style.display='none';
119
+ tabitems[i].childNodes[0].children[2].children[0].style.display='none';
120
+ tabitems[i].childNodes[0].children[3].style.display='none';
121
+
122
+ }
123
+
124
+ }
125
+
126
  tab_demo = window['gradioEl'].querySelectorAll('#tab_demo')[0];
127
  tab_demo.style.display = "block";
128
  tab_demo.setAttribute('style', 'height: 100%;');
 
129
  const page1 = window['gradioEl'].querySelectorAll('#page_1')[0];
130
+ const page2 = window['gradioEl'].querySelectorAll('#page_2')[0];
131
+ window['gradioEl'].querySelectorAll('.gr-radio')[0].disabled = "";
132
+ window['gradioEl'].querySelectorAll('.gr-radio')[1].disabled = "";
133
 
134
  page1.style.display = "none";
135
  page2.style.display = "block";
 
136
  window['prevPrompt'] = '';
137
  window['doCheckPrompt'] = 0;
138
  window['checkPrompt'] = function checkPrompt() {
 
140
  texts = window['gradioEl'].querySelectorAll('textarea');
141
  text0 = texts[0];
142
  text1 = texts[1];
143
+ text2 = texts[2];
144
+ if (window['gradioEl'].querySelectorAll('.gr-radio')[0].checked) {
145
+ text_value = text1.value;
146
+ } else {
147
+ text_value = text2.value;
148
+ }
149
  progress_bar = window['gradioEl'].querySelectorAll('.progress-bar');
150
+ if (window['doCheckPrompt'] === 0 && window['prevPrompt'] !== text_value && progress_bar.length == 0) {
151
+ console.log('_____new prompt___[' + text_value + ']_');
152
  window['doCheckPrompt'] = 1;
153
+ window['prevPrompt'] = text_value;
154
+ tabitems = window['gradioEl'].querySelectorAll('.tabitem');
155
+ for (var i = 0; i < tabitems.length; i++) {
156
+ if ([0, 1].includes(i)) {
157
+ inputText = tabitems[i].children[0].children[1].children[0].querySelectorAll('.gr-text-input')[0];
158
+ } else if (i==2) {
159
+ inputText = tabitems[i].childNodes[0].children[2].children[0].children[0].querySelectorAll('.gr-text-input')[0];
160
+ }
161
+ setNativeValue(inputText, text_value);
162
+ inputText.dispatchEvent(new Event('input', { bubbles: true }));
163
+ }
164
+
165
  setTimeout(function() {
166
  btns = window['gradioEl'].querySelectorAll('button');
167
  for (var i = 0; i < btns.length; i++) {
168
+ if (['Generate image','Run'].includes(btns[i].innerText)) {
169
  btns[i].click();
170
  }
171
  }
 
181
  return false;
182
  }"""
183
 
184
+ descriptions = "Thanks: "
185
+ descriptions += "[<a style='display:inline-block' href='https://huggingface.co/spaces/daspartho/prompt-extend' _blank><font style='color:blue;weight:bold;'>prompt-extend</font></a>]"
186
+ descriptions += "[<a style='display:inline-block' href='https://huggingface.co/spaces/stabilityai/stable-diffusion-1' _blank><font style='color:blue;weight:bold;'>Stable Diffusion 1.0</font></a>]"
187
+ descriptions += "[<a style='display:inline-block' href='https://huggingface.co/spaces/stabilityai/stable-diffusion-1' _blank><font style='color:blue;weight:bold;'>Stable Diffusion 1.0</font></a>]"
188
+ descriptions += "[<a style='display:inline-block' href='https://huggingface.co/spaces/hakurei/waifu-diffusion-demo' _blank><font style='color:blue;weight:bold;'>waifu-diffusion-demo</font></a>]"
189
+ descriptions = f"<p>{descriptions}</p>"
190
+
191
+ with gr.Blocks(title='prompt-extend/') as demo:
192
+ # gr.HTML(descriptions)
193
  with gr.Group(elem_id="page_1", visible=True) as page_1:
194
  with gr.Box():
195
  with gr.Row():
196
+ start_button = gr.Button("Let's GO!", elem_id="start-btn", visible=True)
197
+ start_button.click(fn=None, inputs=[], outputs=[], _js=start_work)
198
 
199
+ with gr.Group(elem_id="page_2", visible=False) as page_2:
200
+ with gr.Row(elem_id="prompt_row0"):
201
+ with gr.Column(id="input_col1"):
202
+ prompt_input0 = gr.Textbox(lines=1, label="Original prompt", visible=True)
203
+ prompt_input0_en = gr.Textbox(lines=1, label="Original prompt", visible=False)
204
+ prompt_radio = gr.Radio(["Original prompt", "Extend prompt"], elem_id="prompt_radio",value="Extend prompt", show_label=False)
205
+ # with gr.Row(elem_id="prompt_row1"):
206
+ with gr.Column(id="input_col2"):
207
+ prompt_input1 = gr.Textbox(lines=2, label="Extend prompt", visible=False)
208
+ prompt_input2 = gr.Textbox(lines=2, label="Extend prompt", visible=True)
209
  with gr.Row():
210
+ submit_btn = gr.Button(value = "submit",elem_id="submit-btn").style(
211
  margin=True,
212
  rounded=(True, True, True, True),
213
  )
214
+ submit_btn.click(fn=extend_prompt, inputs=[prompt_input0], outputs=[prompt_input0_en, prompt_input1, prompt_input2])
215
+ with gr.Row(elem_id='tab_demo', visible=True).style(height=200):
216
  tab_demo = gr.TabbedInterface(tab_actions, tab_titles)
217
+ with gr.Row():
218
+ gr.HTML(f"<p>{thanks_info}</p>")
219
 
220
+ demo.launch()
 
 
 
 
 
model/README.md ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - generated_from_trainer
5
+ model-index:
6
+ - name: prompt-extend
7
+ results: []
8
+ ---
9
+ [![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/daspartho/prompt-extend)
10
+
11
+ # Prompt Extend
12
+
13
+ Text generation model for generating suitable style cues given the main idea for a prompt.
14
+
15
+ It is a GPT-2 model trained on [dataset](https://huggingface.co/datasets/daspartho/stable-diffusion-prompts) of stable diffusion prompts.
16
+
17
+ ### Training hyperparameters
18
+
19
+ The following hyperparameters were used during training:
20
+ - learning_rate: 0.0001
21
+ - train_batch_size: 128
22
+ - eval_batch_size: 256
23
+ - seed: 42
24
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
25
+ - lr_scheduler_type: cosine
26
+ - lr_scheduler_warmup_ratio: 0.1
27
+ - num_epochs: 5
28
+ - mixed_precision_training: Native AMP
29
+
30
+ ### Training results
31
+
32
+ | Training Loss | Epoch | Step | Validation Loss |
33
+ |:-------------:|:-----:|:-----:|:---------------:|
34
+ | 3.7436 | 1.0 | 12796 | 2.5429 |
35
+ | 2.3292 | 2.0 | 25592 | 2.0711 |
36
+ | 1.9439 | 3.0 | 38388 | 1.8447 |
37
+ | 1.7059 | 4.0 | 51184 | 1.7325 |
38
+ | 1.5775 | 5.0 | 63980 | 1.7110 |
39
+
40
+
41
+ ### Framework versions
42
+
43
+ - Transformers 4.24.0
44
+ - Pytorch 1.13.0+cu117
45
+ - Datasets 2.7.1
46
+ - Tokenizers 0.13.2
model/config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "gpt2",
3
+ "activation_function": "gelu_new",
4
+ "architectures": [
5
+ "GPT2LMHeadModel"
6
+ ],
7
+ "attn_pdrop": 0.1,
8
+ "bos_token_id": 0,
9
+ "embd_pdrop": 0.1,
10
+ "eos_token_id": 0,
11
+ "initializer_range": 0.02,
12
+ "layer_norm_epsilon": 1e-05,
13
+ "model_type": "gpt2",
14
+ "n_ctx": 128,
15
+ "n_embd": 768,
16
+ "n_head": 12,
17
+ "n_inner": null,
18
+ "n_layer": 12,
19
+ "n_positions": 1024,
20
+ "reorder_and_upcast_attn": false,
21
+ "resid_pdrop": 0.1,
22
+ "scale_attn_by_inverse_layer_idx": false,
23
+ "scale_attn_weights": true,
24
+ "summary_activation": null,
25
+ "summary_first_dropout": 0.1,
26
+ "summary_proj_to_labels": true,
27
+ "summary_type": "cls_index",
28
+ "summary_use_proj": true,
29
+ "task_specific_params": {
30
+ "text-generation": {
31
+ "do_sample": true,
32
+ "max_length": 50
33
+ }
34
+ },
35
+ "torch_dtype": "float32",
36
+ "transformers_version": "4.24.0",
37
+ "use_cache": true,
38
+ "vocab_size": 52000
39
+ }
model/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b2b0c3fa1c9ca9c46fcde7690c72facfdc736802ec22d9144db4f82ae9a4f9e
3
+ size 515752509
model/special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>",
5
+ "unk_token": "<|endoftext|>"
6
+ }
model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
model/tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": "<|endoftext|>",
4
+ "eos_token": "<|endoftext|>",
5
+ "model_max_length": 1024,
6
+ "name_or_path": "daspartho/prompt-tokenizer",
7
+ "special_tokens_map_file": null,
8
+ "tokenizer_class": "GPT2Tokenizer",
9
+ "unk_token": "<|endoftext|>"
10
+ }
model/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:feb1020a4f463c27cf23421abb9ec75003c289d820d4ebbe52b15d9af77b46c2
3
+ size 3387
model/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- torch
2
  ftfy
3
  spacy
4
- diffusers
5
- transformers
 
6
  paddlepaddle==2.3.2
7
  paddlehub
8
  loguru
 
 
1
  ftfy
2
  spacy
3
+ transformers
4
+ torch
5
+ gradio
6
  paddlepaddle==2.3.2
7
  paddlehub
8
  loguru