import random
import re

import gradio as gr
import torch

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
    set_seed,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

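# Image-captioning model (GIT fine-tuned on COCO) for the image-to-text tab.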
big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

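# GPT-2-based pipeline that expands a short English phrase into image-prompt candidates.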
text_pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')

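# Marian translation models: zh->en for the input, en->zh for displaying results.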
zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')

en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")

def load_prompter():
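    """Load Promptist (a GPT-2 fine-tune) and a GPT-2 tokenizer configured for left padding."""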
    prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return prompter_model, tokenizer

prompter_model, prompter_tokenizer = load_prompter()

# Promptist-based rephrasing; defined here but not wired into the UI below.
def generate_prompter(plain_text, max_new_tokens=75, num_return_sequences=3):
    prompt = plain_text.strip() + " Rephrase:"
    input_ids = prompter_tokenizer(prompt, return_tensors="pt").input_ids
    eos_id = prompter_tokenizer.eos_token_id
    outputs = prompter_model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        num_beams=6,
        num_return_sequences=num_return_sequences,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
        length_penalty=-1.0,
    )
    output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Strip the echoed prompt so only each rephrasing remains.
    result = []
    for output_text in output_texts:
        result.append(output_text.replace(prompt, "").strip())
    return "\n".join(result)

def translate_zh2en(text):
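    """Translate Chinese text to English with the Marian zh-en model."""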
    with torch.no_grad():
        # Flatten line breaks to commas, then strip any leading commas left behind.
        text = text.replace('\n', ',').replace('\r', ',')
        text = re.sub(r'^,+', '', text)
        encoded = zh2en_tokenizer([text], return_tensors='pt')
        sequences = zh2en_model.generate(**encoded)
        return zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]

def translate_en2zh(text):
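    """Translate English text back to Chinese with the Marian en-zh model."""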
    with torch.no_grad():
        encoded = en2zh_tokenizer([text], return_tensors="pt")
        sequences = en2zh_model.generate(**encoded)
        return en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]

def text_generate(text):
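    """Translate the input, sample candidate prompts, and return English prompts plus Chinese back-translations."""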
    # Reseed per call so repeated runs produce different prompt sets.
    seed = random.randint(100, 1000000)
    set_seed(seed)

    text_in_english = translate_zh2en(text)
    result = ""
    # Up to 6 rounds: sample 8 candidates, keep lines that extend the input
    # and do not end mid-thought, and stop as soon as anything survives.
    for _ in range(6):
        sequences = text_pipe(text_in_english, max_length=random.randint(60, 90), num_return_sequences=8)
        lines = []
        for sequence in sequences:
            line = sequence['generated_text'].strip()
            if line != text_in_english and len(line) > (len(text_in_english) + 4) and not line.endswith((':', '-', '—')):
                lines.append(line)

        result = "\n".join(lines)
        # Drop tokens with an internal dot (file names, URLs, run-together
        # sentences), then stray markup characters.
        result = re.sub(r'[^ ]+\.[^ ]+', '', result)
        result = result.replace('<', '').replace('>', '').replace('"', '')
        if result != '':
            break

    return result, "\n".join(translate_en2zh(line) for line in result.split("\n") if len(line) > 0)

def get_prompt_from_image(input_image):
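    """Caption the uploaded image with GIT and return the caption as a prompt."""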
    image = input_image.convert('RGB')
    pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
    generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_length=50)
    generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_caption)
    return generated_caption


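# Gradio UI: two tabs (text-to-text and image-to-text) sharing the output boxes below.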
with gr.Blocks() as block:
    with gr.Column():
        with gr.Tab('文生文'):  # "Text to text"
            with gr.Row():
                input_text = gr.Textbox(lines=12, label='輸入文字', placeholder='在此输入文字...')  # "Input text" / "Enter text here..."

            with gr.Row():
                txt_prompter_btn = gr.Button('執行')  # "Run"

        with gr.Tab('圖生文'):  # "Image to text"
            with gr.Row():
                input_image = gr.Image(type='pil')

            with gr.Row():
                pic_prompter_btn = gr.Button('執行')  # "Run"

    output_text = gr.Textbox(lines=6, label='輸出結果')  # "Output"
    output_text_zh = gr.Textbox(lines=6, label='中文翻譯')  # "Chinese translation"

    txt_prompter_btn.click(
        fn=text_generate,
        inputs=input_text,
        outputs=[output_text, output_text_zh],
    )

    pic_prompter_btn.click(
        fn=get_prompt_from_image,
        inputs=input_image,
        outputs=output_text,
    )

# queue() already enables request queuing; listen on all interfaces.
block.queue(max_size=64).launch(show_api=False, debug=True, share=False, server_name='0.0.0.0')