import random
import re

import gradio as gr
import torch

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
    set_seed,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# GIT image-captioning model: turns an uploaded picture into a short English caption.
big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

# Three text-generation pipelines that expand a short description into a full prompt.
pipeline_01 = pipeline('text-generation', model='succinctly/text2image-prompt-generator', max_new_tokens=256)
pipeline_02 = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', max_new_tokens=256)
pipeline_03 = pipeline('text-generation', model='johnsu6616/ModelExport', max_new_tokens=256)

# OPUS-MT translation models: Chinese -> English and English -> Chinese.
zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')

en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
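
# Note: everything above loads on CPU; .to(device) is only used for the captioning
# model further down. A minimal sketch for running the rest on GPU (an assumption,
# not something the app requires) would move models and inputs over, e.g.:
#   zh2en_model.to(device)
#   encoded = zh2en_tokenizer([text], return_tensors='pt').to(device)
# and pass device=0 to pipeline(...) for the text-generation pipelines.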


def translate_zh2en(text):
    with torch.no_grad():
        # Strip punctuation that tends to confuse the translator.
        text = re.sub(r"[:\-–.!;?_#]", '', text)

        # Insert a break between Chinese and non-Chinese runs, then join the pieces with commas.
        text = re.sub(r'([^\u4e00-\u9fa5])([\u4e00-\u9fa5])', r'\1\n\2', text)
        text = re.sub(r'([\u4e00-\u9fa5])([^\u4e00-\u9fa5])', r'\1\n\2', text)
        text = text.replace('\n', ',')

        # Drop all whitespace except between Latin words, and collapse repeated commas.
        text = re.sub(r'(?<![a-zA-Z])\s+|\s+(?![a-zA-Z])', '', text)
        text = re.sub(r',+', ',', text)

        encoded = zh2en_tokenizer([text], return_tensors='pt')
        sequences = zh2en_model.generate(**encoded)
        result = zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]

        result = result.strip()

        # The model occasionally degenerates to "No,no,"; fall back to the cleaned input.
        if result == "No,no,":
            result = text

        # Remove leftover tags and collapse immediately repeated words.
        result = re.sub(r'<.*?>', '', result)
        result = re.sub(r'\b(\w+)\b(?:\W+\1\b)+', r'\1', result, flags=re.IGNORECASE)
        return result
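
# Usage sketch (illustrative; the exact wording depends on the opus-mt-zh-en weights):
#   translate_zh2en("一隻貓,一個窗戶")  ->  roughly "a cat,a window"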


def translate_en2zh(text):
    with torch.no_grad():
        encoded = en2zh_tokenizer([text], return_tensors="pt")
        sequences = en2zh_model.generate(**encoded)
        result = en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]

        # Collapse immediately repeated words, which this model tends to emit.
        result = re.sub(r'\b(\w+)\b(?:\W+\1\b)+', r'\1', result, flags=re.IGNORECASE)
        return result


def load_prompter():
    prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
    # GPT-2 has no pad token, so reuse eos and left-pad for decoder-only generation.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return prompter_model, tokenizer


prompter_model, prompter_tokenizer = load_prompter()
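
# Promptist (https://huggingface.co/microsoft/Promptist) is a GPT-2-based model fine-tuned
# to rewrite plain descriptions into Stable-Diffusion-style prompts, which is why the stock
# GPT-2 tokenizer (eos/pad token id 50256) is reused here.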


def _generate_with_pipeline(pipe, text):
    # Shared worker for the three text-generation pipelines: translate the input
    # to English, sample three candidates, and keep only those that meaningfully
    # extend the input.
    seed = random.randint(100, 1000000)
    set_seed(seed)  # re-seed per call so repeated clicks give fresh samples
    text_in_english = translate_zh2en(text)
    response = pipe(text_in_english, num_return_sequences=3)
    response_list = []
    for x in response:
        resp = x['generated_text'].strip()
        if resp != text_in_english and len(resp) > (len(text_in_english) + 4):
            response_list.append(translate_en2zh(resp) + "\n")
            response_list.append(resp + "\n")
            response_list.append("\n")

    result = "".join(response_list)
    # Drop tokens with an internal period (URL/filename fragments), then stray angle brackets.
    result = re.sub(r'[^ ]+\.[^ ]+', '', result)
    result = result.replace("<", "").replace(">", "")
    return result


def generate_prompter_pipeline_01(text):
    return _generate_with_pipeline(pipeline_01, text)
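
# Usage sketch (illustrative): generate_prompter_pipeline_01("夕陽下的城市") yields up to
# three candidates, each rendered as a Chinese translation line followed by the English
# prompt line, or "" when every sample is rejected by the length filter.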


def generate_prompter_tokenizer_01(text):
    text_in_english = translate_zh2en(text)

    input_ids = prompter_tokenizer(text_in_english.strip() + " Rephrase:", return_tensors="pt").input_ids

    outputs = prompter_model.generate(
        input_ids,
        do_sample=False,
        num_beams=3,
        num_return_sequences=3,
        pad_token_id=50256,
        eos_token_id=50256,
        length_penalty=-1.0
    )
    output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    result = []
    for output_text in output_texts:
        output_text = output_text.replace('<', '').replace('>', '')
        # Keep only the text after the "Rephrase:" marker.
        output_text = output_text.split("Rephrase:", 1)[-1].strip()

        result.append(translate_en2zh(output_text) + "\n")
        result.append(output_text + "\n")
        result.append("\n")
    return "".join(result)


def generate_prompter_pipeline_02(text):
    return _generate_with_pipeline(pipeline_02, text)


def generate_prompter_pipeline_03(text):
    return _generate_with_pipeline(pipeline_03, text)


# Map each radio-button label to its generator; the same table drives both the
# text and the image workflows.
PROMPTER_CHOICES = {
    '★pipeline模式(succinctly)': generate_prompter_pipeline_01,
    '★★tokenizer模式': generate_prompter_tokenizer_01,
    '★★★pipeline模型(Gustavosta)': generate_prompter_pipeline_02,
    'pipeline模型(John)_自訓測試,資料不穩定': generate_prompter_pipeline_03,
}


def generate_render(text, choice):
    outputs = PROMPTER_CHOICES[choice](text)
    return outputs, choice
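
# generate_render returns (prompt_text, choice); the pair feeds the two output boxes
# below, so Textbox_1 shows the generated prompt and Textbox_2 echoes which mode made it.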


def get_prompt_from_image(input_image, choice):
    # Caption the image with GIT, strip punctuation from the caption, then expand
    # it with whichever generator the radio button selects.
    image = input_image.convert('RGB')
    pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
    generated_ids = big_model.to(device).generate(pixel_values=pixel_values)
    generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    text = re.sub(r"[:\-–.!;?_#]", '', generated_caption)

    return PROMPTER_CHOICES[choice](text)


with gr.Blocks() as block:
    with gr.Column():
        with gr.Tab('工作區'):
            with gr.Row():
                input_text = gr.Textbox(lines=12, label='輸入文字', placeholder='在此输入文字...')
                input_image = gr.Image(type='pil', label="選擇圖片(辨識度不佳)")
            with gr.Row():
                txt_prompter_btn = gr.Button('文生文')
                pic_prompter_btn = gr.Button('圖生文')
            with gr.Row():
                radio_btn = gr.Radio(
                    label="請選擇產出方式",
                    # The dict keys double as the radio labels, so dispatch in
                    # generate_render() cannot drift out of sync with the UI.
                    choices=list(PROMPTER_CHOICES),
                    value='★pipeline模式(succinctly)'
                )

            with gr.Row():
                Textbox_1 = gr.Textbox(lines=6, label='提示詞生成')
            with gr.Row():
                Textbox_2 = gr.Textbox(lines=6, label='測試資訊')

        with gr.Tab('測試區'):
            with gr.Row():
                input_test01 = gr.Textbox(lines=2, label='中英翻譯', placeholder='在此输入文字...')
                test01_btn = gr.Button('執行')
                Textbox_test01 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test02 = gr.Textbox(lines=2, label='英中翻譯(不精準)', placeholder='在此输入文字...')
                test02_btn = gr.Button('執行')
                Textbox_test02 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test03 = gr.Textbox(lines=2, label='★pipeline模式(succinctly)', placeholder='在此输入文字...')
                test03_btn = gr.Button('執行')
                Textbox_test03 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test04 = gr.Textbox(lines=2, label='★★tokenizer模式', placeholder='在此输入文字...')
                test04_btn = gr.Button('執行')
                Textbox_test04 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test05 = gr.Textbox(lines=2, label='★★★pipeline模型(Gustavosta)', placeholder='在此输入文字...')
                test05_btn = gr.Button('執行')
                Textbox_test05 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test06 = gr.Textbox(lines=2, label='pipeline模型(John)_自訓測試,資料不穩定', placeholder='在此输入文字...')
                test06_btn = gr.Button('執行')
                Textbox_test06 = gr.Textbox(lines=2, label='輸出結果')

    txt_prompter_btn.click(
        fn=generate_render,
        inputs=[input_text, radio_btn],
        outputs=[Textbox_1, Textbox_2]
    )

    pic_prompter_btn.click(
        fn=get_prompt_from_image,
        inputs=[input_image, radio_btn],
        outputs=Textbox_1
    )

    test01_btn.click(
        fn=translate_zh2en,
        inputs=input_test01,
        outputs=Textbox_test01
    )

    test02_btn.click(
        fn=translate_en2zh,
        inputs=input_test02,
        outputs=Textbox_test02
    )

    test03_btn.click(
        fn=generate_prompter_pipeline_01,
        inputs=input_test03,
        outputs=Textbox_test03
    )

    test04_btn.click(
        fn=generate_prompter_tokenizer_01,
        inputs=input_test04,
        outputs=Textbox_test04
    )

    test05_btn.click(
        fn=generate_prompter_pipeline_02,
        inputs=input_test05,
        outputs=Textbox_test05
    )

    test06_btn.click(
        fn=generate_prompter_pipeline_03,
        inputs=input_test06,
        outputs=Textbox_test06
    )
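
# share=False keeps the app local (no public gradio.live link); server_name='0.0.0.0'
# binds all interfaces so the UI is reachable from other machines on the network.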
block.launch(show_api=False, debug=True, share=False, server_name='0.0.0.0')