hahafofo committed on
Commit
0db3431
1 Parent(s): d65da85
Files changed (2)
  1. app.py +149 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,149 @@
+ import random
+ import re
+
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import AutoModelForSeq2SeqLM
+ from transformers import AutoProcessor
+ from transformers import pipeline, set_seed
+
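+ # Run on GPU when available, otherwise fall back to CPU.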
+ device = "cuda" if torch.cuda.is_available() else "cpu"
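+ # GIT (microsoft/git-base-coco): image-captioning model used by the image-to-prompt tab.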
+ big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
+ big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
+
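+ # Text-generation pipeline used to expand a short English phrase into longer prompt candidates.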
+ text_pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')
+
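+ # MarianMT translation models: Chinese -> English for user input, English -> Chinese for display.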
+ zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
+ zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
+ en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
+ en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
+
+
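+ # microsoft/Promptist rewrites a plain description into a Stable-Diffusion-style prompt.
+ # It shares GPT-2's vocabulary, so the stock gpt2 tokenizer is used, left-padded for generation.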
+ def load_prompter():
+     prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
+     tokenizer = AutoTokenizer.from_pretrained("gpt2")
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "left"
+     return prompter_model, tokenizer
+
+
+ prompter_model, prompter_tokenizer = load_prompter()
+
+
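+ # Beam-search several Promptist rewrites of the (already English) text and return them one per line.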
+ def generate_prompter(plain_text, max_new_tokens=75, num_beams=8, num_return_sequences=8, length_penalty=-1.0):
+     plain_text = plain_text.strip()
+     input_ids = prompter_tokenizer(plain_text + " Rephrase:", return_tensors="pt").input_ids
+     eos_id = prompter_tokenizer.eos_token_id
+     outputs = prompter_model.generate(
+         input_ids,
+         do_sample=False,
+         max_new_tokens=max_new_tokens,
+         num_beams=num_beams,
+         num_return_sequences=num_return_sequences,
+         eos_token_id=eos_id,
+         pad_token_id=eos_id,
+         length_penalty=length_penalty
+     )
+     output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+     result = []
+     for output_text in output_texts:
+         # Drop the echoed '<input> Rephrase:' prefix so only the rewritten prompt remains.
+         result.append(output_text.replace(plain_text + " Rephrase:", "").strip())
+
+     return "\n".join(result)
+
+
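+ # Thin wrappers around the MarianMT models; inference only, so gradients are disabled.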
+ def translate_zh2en(text):
+     with torch.no_grad():
+         encoded = zh2en_tokenizer([text], return_tensors='pt')
+         sequences = zh2en_model.generate(**encoded)
+     return zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+
+
+ def translate_en2zh(text):
+     with torch.no_grad():
+         encoded = en2zh_tokenizer([text], return_tensors="pt")
+         sequences = en2zh_model.generate(**encoded)
+     return en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+
+
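+ # Sample up to six batches of prompt candidates, retrying until one batch survives the filters below.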
+ def text_generate(text_in_english):
+     seed = random.randint(100, 1000000)
+     set_seed(seed)
+
+     result = ""
+     for _ in range(6):
+         sequences = text_pipe(text_in_english, max_length=random.randint(60, 90), num_return_sequences=8)
+         lines = []
+         for sequence in sequences:
+             line = sequence['generated_text'].strip()
+             # Keep only candidates that actually extend the input and do not end mid-phrase.
+             if (line != text_in_english and len(line) > (len(text_in_english) + 4)
+                     and not line.endswith((':', '-', '—'))):
+                 lines.append(line)
+
+         result = "\n".join(lines)
+         # Strip tokens containing an unspaced dot (URL- or filename-like) and any angle brackets.
+         result = re.sub(r'[^ ]+\.[^ ]+', '', result)
+         result = result.replace('<', '').replace('>', '')
+         if result != '':
+             break
+     return result, "\n".join(translate_en2zh(line) for line in result.split("\n") if len(line) > 0)
+
+
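+ # Caption an uploaded image with GIT and return the caption as a prompt.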
+ def get_prompt_from_image(input_image):
+     image = input_image.convert('RGB')
+     pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
+
+     generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_length=50)
+     generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     print(generated_caption)
+     return generated_caption
+
+
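+ # Gradio UI: a text tab (translate -> optimize or improvise) and an image tab (caption -> prompt).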
+ with gr.Blocks() as block:
+     with gr.Column():
+         with gr.Tab('Text generation'):
+             with gr.Row():
+                 input_text = gr.Textbox(lines=6, label='Your idea', placeholder='Enter text here...')
+                 translate_output = gr.Textbox(lines=6, label='Translation result (prompt input)')
+
+             with gr.Accordion('SD optimization settings', open=False):
+                 max_new_tokens = gr.Slider(1, 255, 75, label='max_new_tokens', step=1)
+                 num_beams = gr.Slider(1, 30, 8, label='num_beams', step=1)
+                 num_return_sequences = gr.Slider(1, 30, 8, label='num_return_sequences', step=1)
+                 length_penalty = gr.Slider(-1.0, 1.0, -1.0, label='length_penalty')
+
+             generate_prompter_output = gr.Textbox(lines=6, label='SD-optimized prompt')
+
+             output = gr.Textbox(lines=6, label='Improvised prompt')
+             output_zh = gr.Textbox(lines=6, label='Improvised prompt (zh)')
+             with gr.Row():
+                 translate_btn = gr.Button('Translate')
+                 generate_prompter_btn = gr.Button('SD optimize')
+                 gpt_btn = gr.Button('Improvise')
+
+         with gr.Tab('Generate from image'):
+             with gr.Row():
+                 input_image = gr.Image(type='pil')
+                 img_btn = gr.Button('Submit')
+             output_image = gr.Textbox(lines=6, label='Generated prompt')
+     translate_btn.click(
+         fn=translate_zh2en,
+         inputs=input_text,
+         outputs=translate_output
+     )
+     generate_prompter_btn.click(
+         fn=generate_prompter,
+         inputs=[translate_output, max_new_tokens, num_beams, num_return_sequences, length_penalty],
+         outputs=generate_prompter_output
+     )
+     gpt_btn.click(
+         fn=text_generate,
+         inputs=translate_output,
+         outputs=[output, output_zh]
+     )
+     img_btn.click(
+         fn=get_prompt_from_image,
+         inputs=input_image,
+         outputs=output_image
+     )
+
+ block.queue(max_size=64).launch(show_api=False, enable_queue=True, debug=True, share=False, server_name='0.0.0.0')
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ transformers==4.27.4
+ #sentencepiece
+ #sacremoses