johnsu6616 commited on
Commit
9daa8f3
1 Parent(s): 856e316

修改字詞的過濾

Browse files
Files changed (1) hide show
  1. app.py +27 -20
app.py CHANGED
@@ -21,9 +21,9 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
21
  big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
22
  big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
23
 
24
- pipeline_01 = pipeline('text-generation', model='succinctly/text2image-prompt-generator')
25
- pipeline_02 = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', tokenizer='gpt2')
26
- pipeline_03 = pipeline('text-generation', model='johnsu6616/ModelExport')
27
 
28
  zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
29
  zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
@@ -33,12 +33,15 @@ en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
33
 
34
  def translate_zh2en(text):
35
  with torch.no_grad():
36
-
37
  text = re.sub(r"[:\-–.!;?_#]", '', text)
 
38
  text = re.sub(r'([^\u4e00-\u9fa5])([\u4e00-\u9fa5])', r'\1\n\2', text)
39
  text = re.sub(r'([\u4e00-\u9fa5])([^\u4e00-\u9fa5])', r'\1\n\2', text)
 
40
  text = text.replace('\n', ',')
 
41
  text =re.sub(r'(?<![a-zA-Z])\s+|\s+(?![a-zA-Z])', '', text)
 
42
  text = re.sub(r',+', ',', text)
43
 
44
  encoded = zh2en_tokenizer([text], return_tensors='pt')
@@ -50,8 +53,12 @@ def translate_zh2en(text):
50
  if result == "No,no," :
51
  result = text
52
 
 
 
 
53
  return result
54
 
 
55
  def translate_en2zh(text):
56
  with torch.no_grad():
57
 
@@ -59,7 +66,7 @@ def translate_en2zh(text):
59
  sequences = en2zh_model.generate(**encoded)
60
  result = en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
61
 
62
- result = re.sub(r'(\b\w+\b)(?:\W+\1\b)+', r'\1', result)
63
  return result
64
 
65
  def load_prompter():
@@ -71,11 +78,12 @@ def load_prompter():
71
 
72
  prompter_model, prompter_tokenizer = load_prompter()
73
 
 
74
  def generate_prompter_pipeline_01(text):
75
  seed = random.randint(100, 1000000)
76
  set_seed(seed)
77
  text_in_english = translate_zh2en(text)
78
- response = pipeline_01(text_in_english, max_new_tokens=80, num_return_sequences=3)
79
  response_list = []
80
  for x in response:
81
  resp = x['generated_text'].strip()
@@ -87,27 +95,27 @@ def generate_prompter_pipeline_01(text):
87
  response_list.append("\n")
88
 
89
  result = "".join(response_list)
90
- result = re.sub('[^ ]+\.[^ ]+', '', result)
91
- result = result.replace('<', '').replace('>', '')
92
 
93
- if result != '':
94
  return result
95
 
 
96
  def generate_prompter_tokenizer_01(text):
97
 
98
  text_in_english = translate_zh2en(text)
99
 
100
  input_ids = prompter_tokenizer(text_in_english.strip()+" Rephrase:", return_tensors="pt").input_ids
101
-
102
- eos_id = 50256
103
  outputs = prompter_model.generate(
104
  input_ids,
105
  do_sample=False,
106
- max_new_tokens=80,
107
  num_beams=3,
108
  num_return_sequences=3,
109
- pad_token_id=eos_id,
110
- eos_token_id=eos_id,
111
  length_penalty=-1.0
112
  )
113
  output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
@@ -123,12 +131,11 @@ def generate_prompter_tokenizer_01(text):
123
  result.append("\n")
124
  return "".join(result)
125
 
126
-
127
  def generate_prompter_pipeline_02(text):
128
  seed = random.randint(100, 1000000)
129
  set_seed(seed)
130
  text_in_english = translate_zh2en(text)
131
- response = pipeline_02(text_in_english, max_new_tokens=80, num_return_sequences=3)
132
  response_list = []
133
  for x in response:
134
  resp = x['generated_text'].strip()
@@ -149,13 +156,12 @@ def generate_prompter_pipeline_03(text):
149
  seed = random.randint(100, 1000000)
150
  set_seed(seed)
151
  text_in_english = translate_zh2en(text)
152
- response = pipeline_03(text_in_english, max_new_tokens=80, num_return_sequences=3)
153
  response_list = []
154
  for x in response:
155
  resp = x['generated_text'].strip()
156
  if resp != text_in_english and len(resp) > (len(text_in_english) + 4):
157
 
158
-
159
  response_list.append(translate_en2zh(resp)+"\n")
160
  response_list.append(resp+"\n")
161
  response_list.append("\n")
@@ -184,7 +190,7 @@ def generate_render(text,choice):
184
  def get_prompt_from_image(input_image,choice):
185
  image = input_image.convert('RGB')
186
  pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
187
- generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_new_tokens=80)
188
  generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
189
  text = re.sub(r"[:\-–.!;?_#]", '', generated_caption)
190
 
@@ -255,7 +261,7 @@ with gr.Blocks() as block:
255
  inputs=[input_text,radio_btn],
256
  outputs=[Textbox_1,Textbox_2]
257
  )
258
-
259
  pic_prompter_btn.click(
260
  fn=get_prompt_from_image,
261
  inputs=[input_image,radio_btn],
@@ -292,6 +298,7 @@ with gr.Blocks() as block:
292
  outputs=Textbox_test05
293
  )
294
 
 
295
  test06_btn.click(
296
  fn= generate_prompter_pipeline_03,
297
  inputs= input_test06,
 
21
  big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
22
  big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
23
 
24
+ pipeline_01 = pipeline('text-generation', model='succinctly/text2image-prompt-generator', max_new_tokens=256)
25
+ pipeline_02 = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', max_new_tokens=256)
26
+ pipeline_03 = pipeline('text-generation', model='johnsu6616/ModelExport', max_new_tokens=256)
27
 
28
  zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
29
  zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
 
33
 
34
  def translate_zh2en(text):
35
  with torch.no_grad():
 
36
  text = re.sub(r"[:\-–.!;?_#]", '', text)
37
+
38
  text = re.sub(r'([^\u4e00-\u9fa5])([\u4e00-\u9fa5])', r'\1\n\2', text)
39
  text = re.sub(r'([\u4e00-\u9fa5])([^\u4e00-\u9fa5])', r'\1\n\2', text)
40
+
41
  text = text.replace('\n', ',')
42
+
43
  text =re.sub(r'(?<![a-zA-Z])\s+|\s+(?![a-zA-Z])', '', text)
44
+
45
  text = re.sub(r',+', ',', text)
46
 
47
  encoded = zh2en_tokenizer([text], return_tensors='pt')
 
53
  if result == "No,no," :
54
  result = text
55
 
56
+ result = re.sub(r'<.*?>', '', result)
57
+
58
+ result = re.sub(r'\b(\w+)\b(?:\W+\1\b)+', r'\1', result, flags=re.IGNORECASE)
59
  return result
60
 
61
+
62
  def translate_en2zh(text):
63
  with torch.no_grad():
64
 
 
66
  sequences = en2zh_model.generate(**encoded)
67
  result = en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
68
 
69
+ result = re.sub(r'\b(\w+)\b(?:\W+\1\b)+', r'\1', result, flags=re.IGNORECASE)
70
  return result
71
 
72
  def load_prompter():
 
78
 
79
  prompter_model, prompter_tokenizer = load_prompter()
80
 
81
+
82
  def generate_prompter_pipeline_01(text):
83
  seed = random.randint(100, 1000000)
84
  set_seed(seed)
85
  text_in_english = translate_zh2en(text)
86
+ response = pipeline_01(text_in_english, num_return_sequences=3)
87
  response_list = []
88
  for x in response:
89
  resp = x['generated_text'].strip()
 
95
  response_list.append("\n")
96
 
97
  result = "".join(response_list)
98
+ result = re.sub('[^ ]+\.[^ ]+','', result)
99
+ result = result.replace("<", "").replace(">", "")
100
 
101
+ if result != "":
102
  return result
103
 
104
+
105
  def generate_prompter_tokenizer_01(text):
106
 
107
  text_in_english = translate_zh2en(text)
108
 
109
  input_ids = prompter_tokenizer(text_in_english.strip()+" Rephrase:", return_tensors="pt").input_ids
110
+
 
111
  outputs = prompter_model.generate(
112
  input_ids,
113
  do_sample=False,
114
+
115
  num_beams=3,
116
  num_return_sequences=3,
117
+ pad_token_id= 50256,
118
+ eos_token_id = 50256,
119
  length_penalty=-1.0
120
  )
121
  output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
131
  result.append("\n")
132
  return "".join(result)
133
 
 
134
  def generate_prompter_pipeline_02(text):
135
  seed = random.randint(100, 1000000)
136
  set_seed(seed)
137
  text_in_english = translate_zh2en(text)
138
+ response = pipeline_02(text_in_english, num_return_sequences=3)
139
  response_list = []
140
  for x in response:
141
  resp = x['generated_text'].strip()
 
156
  seed = random.randint(100, 1000000)
157
  set_seed(seed)
158
  text_in_english = translate_zh2en(text)
159
+ response = pipeline_03(text_in_english, num_return_sequences=3)
160
  response_list = []
161
  for x in response:
162
  resp = x['generated_text'].strip()
163
  if resp != text_in_english and len(resp) > (len(text_in_english) + 4):
164
 
 
165
  response_list.append(translate_en2zh(resp)+"\n")
166
  response_list.append(resp+"\n")
167
  response_list.append("\n")
 
190
  def get_prompt_from_image(input_image,choice):
191
  image = input_image.convert('RGB')
192
  pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
193
+ generated_ids = big_model.to(device).generate(pixel_values=pixel_values)
194
  generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
195
  text = re.sub(r"[:\-–.!;?_#]", '', generated_caption)
196
 
 
261
  inputs=[input_text,radio_btn],
262
  outputs=[Textbox_1,Textbox_2]
263
  )
264
+
265
  pic_prompter_btn.click(
266
  fn=get_prompt_from_image,
267
  inputs=[input_image,radio_btn],
 
298
  outputs=Textbox_test05
299
  )
300
 
301
+
302
  test06_btn.click(
303
  fn= generate_prompter_pipeline_03,
304
  inputs= input_test06,