johnsu6616 committed on
Commit
5e9c7fb
1 Parent(s): eb1ff2f

修改句子輸出,看起來整齊點

Browse files
Files changed (3) hide show
  1. README.md +3 -3
  2. app.py +145 -53
  3. requirements.txt +2 -1
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: SD_Helper_01
3
  emoji: 📊
4
- colorFrom: blue
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.24.1
8
  app_file: app.py
9
  pinned: false
10
  license: openrail
 
1
  ---
2
  title: SD_Helper_01
3
  emoji: 📊
4
+ colorFrom: gray
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 3.30.0
8
  app_file: app.py
9
  pinned: false
10
  license: openrail
app.py CHANGED
@@ -27,51 +27,44 @@ zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
27
  en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
28
  en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
29
 
30
- def load_prompter():
31
- prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
32
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
33
- tokenizer.pad_token = tokenizer.eos_token
34
- tokenizer.padding_side = "left"
35
- return prompter_model, tokenizer
36
 
37
- prompter_model, prompter_tokenizer = load_prompter()
 
38
 
39
- def generate_prompter(plain_text, max_new_tokens=75, num_return_sequences=3):
40
- input_ids = prompter_tokenizer(plain_text.strip() + " Rephrase:", return_tensors="pt").input_ids
41
- eos_id = prompter_tokenizer.eos_token_id
42
- outputs = prompter_model.generate(
43
- input_ids,
44
- do_sample=False,
45
- max_new_tokens=75,
46
- num_beams=6,
47
- num_return_sequences=num_return_sequences,
48
- eos_token_id=eos_id,
49
- pad_token_id=eos_id,
50
- length_penalty=-1
51
 
52
- )
53
 
54
- output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
55
- result = ""
56
- for output_text in output_texts:
57
- result.append(output_text.replace(plain_text + " Rephrase:", "").strip())
58
 
59
- return "\n".join(result)
60
 
61
- def translate_zh2en(text):
62
- with torch.no_grad():
63
- text = text.replace('\n', ',').replace('\r', ',')
64
- text = re.sub('^,+', ',', text)
65
  encoded = zh2en_tokenizer([text], return_tensors='pt')
66
  sequences = zh2en_model.generate(**encoded)
67
- return zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
 
 
 
 
 
68
 
69
  def translate_en2zh(text):
70
  with torch.no_grad():
 
71
  encoded = en2zh_tokenizer([text], return_tensors="pt")
72
  sequences = en2zh_model.generate(**encoded)
73
  return en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
74
 
 
 
 
 
 
 
 
 
 
75
  def text_generate(text):
76
  seed = random.randint(100, 1000000)
77
  set_seed(seed)
@@ -83,53 +76,118 @@ def text_generate(text):
83
  list = []
84
  for sequence in sequences:
85
 
 
86
  line = sequence['generated_text'].strip()
87
 
88
- if line != text_in_english and len(line) > (len(text_in_english) + 4) and line.endswith(
89
- (':', '-', '—')) is False:
90
- list.append(line)
 
 
91
 
92
- result = "\n".join(list)
93
 
94
  result = re.sub('[^ ]+\.[^ ]+', '', result)
95
 
96
- result = result.replace('<', '').replace('>', '').replace('"', '')
 
97
  if result != '':
98
  break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- return result, "\n".join(translate_en2zh(line) for line in result.split("\n") if len(line) > 0)
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  def get_prompt_from_image(input_image):
103
  image = input_image.convert('RGB')
104
  pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
105
  generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_length=50)
106
  generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
107
- print(generated_caption)
108
- return generated_caption
109
-
110
 
111
  with gr.Blocks() as block:
112
  with gr.Column():
113
- with gr.Tab('文生文'):
114
  with gr.Row():
115
  input_text = gr.Textbox(lines=12, label='輸入文字', placeholder='在此输入文字...')
116
-
117
  with gr.Row():
118
- txt_prompter_btn = gr.Button('執行')
119
-
120
- with gr.Tab('圖生文'):
121
  with gr.Row():
122
- input_image = gr.Image(type='pil')
123
-
124
  with gr.Row():
125
- pic_prompter_btn = gr.Button('執行')
126
-
127
- Textbox_1 = gr.Textbox(lines=6, label='輸出結果')
128
- Textbox_2 = gr.Textbox(lines=6, label='中文翻譯')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  txt_prompter_btn.click(
131
-
132
- fn=text_generate,
133
  inputs=input_text,
134
  outputs=[Textbox_1,Textbox_2]
135
  )
@@ -137,7 +195,41 @@ with gr.Blocks() as block:
137
  pic_prompter_btn.click(
138
  fn=get_prompt_from_image,
139
  inputs=input_image,
140
- outputs=Textbox_1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  )
142
 
143
  block.queue(max_size=64).launch(show_api=False, enable_queue=True, debug=True, share=False, server_name='0.0.0.0')
 
27
  en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
28
  en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
29
 
 
 
 
 
 
 
30
 
31
+ def translate_zh2en(text):
32
+ with torch.no_grad():
33
 
34
+ text = re.sub(r'([^\u4e00-\u9fa5])([\u4e00-\u9fa5])', r'\1\n\2', text)
35
+ text = re.sub(r'([\u4e00-\u9fa5])([^\u4e00-\u9fa5])', r'\1\n\2', text)
 
 
 
 
 
 
 
 
 
 
36
 
37
+ text = text.replace('\n', ',')
38
 
39
+ text =re.sub(r'(?<![a-zA-Z])\s+|\s+(?![a-zA-Z])', '', text)
 
 
 
40
 
41
+ text = re.sub(r',+', ',', text)
42
 
 
 
 
 
43
  encoded = zh2en_tokenizer([text], return_tensors='pt')
44
  sequences = zh2en_model.generate(**encoded)
45
+ result = zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
46
+
47
+ result = result.strip()
48
+
49
+ return result
50
+
51
 
52
  def translate_en2zh(text):
53
  with torch.no_grad():
54
+
55
  encoded = en2zh_tokenizer([text], return_tensors="pt")
56
  sequences = en2zh_model.generate(**encoded)
57
  return en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
58
 
59
+ def test05(text):
60
+
61
+ return text
62
+
63
+ def test06(text):
64
+
65
+ return text
66
+
67
+
68
  def text_generate(text):
69
  seed = random.randint(100, 1000000)
70
  set_seed(seed)
 
76
  list = []
77
  for sequence in sequences:
78
 
79
+
80
  line = sequence['generated_text'].strip()
81
 
82
+ if line != text_in_english and len(line) > (len(text_in_english) + 4):
83
+
84
+ list.append(translate_en2zh(line)+"\n")
85
+ list.append(line+"\n")
86
+ list.append("\n")
87
 
88
+ result = "".join(list)
89
 
90
  result = re.sub('[^ ]+\.[^ ]+', '', result)
91
 
92
+ result = result.replace('<', '').replace('>', '')
93
+
94
  if result != '':
95
  break
96
+ return result
97
+
98
+
99
+ def load_prompter():
100
+ prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
101
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
102
+ tokenizer.pad_token = tokenizer.eos_token
103
+ tokenizer.padding_side = "left"
104
+ return prompter_model, tokenizer
105
+
106
+ prompter_model, prompter_tokenizer = load_prompter()
107
+
108
+ def generate_prompter(text):
109
+ text = translate_zh2en(text)
110
+
111
+ input_ids = prompter_tokenizer(text.strip()+" Rephrase:", return_tensors="pt").input_ids
112
+ eos_id = prompter_tokenizer.eos_token_id
113
+ outputs = prompter_model.generate(
114
+ input_ids,
115
+ do_sample=False,
116
+ max_new_tokens=75,
117
+ num_beams=3,
118
+ num_return_sequences=3,
119
+ eos_token_id=eos_id,
120
+ pad_token_id=eos_id,
121
+ length_penalty=-1.0
122
+ )
123
+ output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
124
+
125
+ result = []
126
+ for output_text in output_texts:
127
 
128
+ output_text = output_text.replace('<', '').replace('>', '')
129
+ output_text = output_text.split("Rephrase:", 1)[-1].strip()
130
+
131
+ result.append(translate_en2zh(output_text)+"\n")
132
+ result.append(output_text+"\n")
133
+ result.append("\n")
134
+ return "".join(result)
135
+
136
+ def combine_text(text):
137
+ text01 = generate_prompter(text)
138
+ text02 = text_generate(text)
139
+ return text01,text02
140
 
141
  def get_prompt_from_image(input_image):
142
  image = input_image.convert('RGB')
143
  pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
144
  generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_length=50)
145
  generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
146
+ result01 = generate_prompter(generated_caption)
147
+ result02 = text_generate(generated_caption)
148
+ return result01,result02
149
 
150
  with gr.Blocks() as block:
151
  with gr.Column():
152
+ with gr.Tab('工作區'):
153
  with gr.Row():
154
  input_text = gr.Textbox(lines=12, label='輸入文字', placeholder='在此输入文字...')
155
+ input_image = gr.Image(type='pil')
156
  with gr.Row():
157
+ txt_prompter_btn = gr.Button('文生文')
158
+ pic_prompter_btn = gr.Button('圖生文')
 
159
  with gr.Row():
160
+ Textbox_1 = gr.Textbox(lines=6, label='生成方式A')
 
161
  with gr.Row():
162
+ Textbox_2 = gr.Textbox(lines=6, label='生成方式B')
163
+ with gr.Tab('測試區'):
164
+ with gr.Row():
165
+ input_test01 = gr.Textbox(lines=2, label='中英翻譯', placeholder='在此输入文字...')
166
+ test01_btn = gr.Button('執行')
167
+ Textbox_test01 = gr.Textbox(lines=2, label='輸出結果')
168
+ with gr.Row():
169
+ input_test02 = gr.Textbox(lines=2, label='英中翻譯', placeholder='在此输入文字...')
170
+ test02_btn = gr.Button('執行')
171
+ Textbox_test02 = gr.Textbox(lines=2, label='輸出結果')
172
+ with gr.Row():
173
+ input_test03 = gr.Textbox(lines=2, label='SD模式', placeholder='在此输入文字...')
174
+ test03_btn = gr.Button('執行')
175
+ Textbox_test03 = gr.Textbox(lines=2, label='輸出結果')
176
+ with gr.Row():
177
+ input_test04 = gr.Textbox(lines=2, label='瞎掰模式', placeholder='在此输入文字...')
178
+ test04_btn = gr.Button('執行')
179
+ Textbox_test04 = gr.Textbox(lines=2, label='輸出結果')
180
+ with gr.Row():
181
+ input_test05 = gr.Textbox(lines=2, label='沒作用', placeholder='在此输入文字...')
182
+ test05_btn = gr.Button('執行')
183
+ Textbox_test05 = gr.Textbox(lines=2, label='輸出結果')
184
+ with gr.Row():
185
+ input_test06 = gr.Textbox(lines=2, label='沒作用', placeholder='在此输入文字...')
186
+ test06_btn = gr.Button('執行')
187
+ Textbox_test06 = gr.Textbox(lines=2, label='輸出結果')
188
 
189
  txt_prompter_btn.click(
190
+ fn=combine_text,
 
191
  inputs=input_text,
192
  outputs=[Textbox_1,Textbox_2]
193
  )
 
195
  pic_prompter_btn.click(
196
  fn=get_prompt_from_image,
197
  inputs=input_image,
198
+ outputs=[Textbox_1,Textbox_2]
199
+ )
200
+
201
+ test01_btn.click(
202
+ fn=translate_zh2en,
203
+ inputs=input_test01,
204
+ outputs=Textbox_test01
205
+ )
206
+
207
+ test02_btn.click(
208
+ fn=translate_en2zh,
209
+ inputs=input_test02,
210
+ outputs=Textbox_test02
211
+ )
212
+
213
+ test03_btn.click(
214
+ fn=generate_prompter,
215
+ inputs=input_test03,
216
+ outputs=Textbox_test03
217
+ )
218
+
219
+ test04_btn.click(
220
+ fn=text_generate,
221
+ inputs=input_test04,
222
+ outputs=Textbox_test04
223
+ )
224
+ test05_btn.click(
225
+ fn=test05,
226
+ inputs=input_test05,
227
+ outputs=Textbox_test05
228
+ )
229
+ test06_btn.click(
230
+ fn=test06,
231
+ inputs=input_test06,
232
+ outputs=Textbox_test06
233
  )
234
 
235
  block.queue(max_size=64).launch(show_api=False, enable_queue=True, debug=True, share=False, server_name='0.0.0.0')
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  transformers==4.27.4
2
  torch==2.0.0
3
- gradio==3.24.1
 
4
  sentencepiece==0.1.97
5
  sacremoses==0.0.53
 
1
  transformers==4.27.4
2
  torch==2.0.0
3
+ pytorch_lightning==1.7.7
4
+ gradio==3.30.0
5
  sentencepiece==0.1.97
6
  sacremoses==0.0.53