File size: 3,991 Bytes
ed6ea08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import defaultdict
import os

import gradio as gr
from threading import Thread
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.generation.streamers import TextIteratorStreamer
import torch

from project_settings import project_path


def get_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with ``max_new_tokens`` (int), ``top_p`` (float),
        ``temperature`` (float), ``repetition_penalty`` (float) and
        ``device`` (str; defaults to "cuda" when torch sees a GPU, else "cpu").
    """
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser = argparse.ArgumentParser()

    # (flag, default, type) for every option, registered in this order.
    option_specs = (
        ("--max_new_tokens", 512, int),
        ("--top_p", 0.9, float),
        ("--temperature", 0.35, float),
        ("--repetition_penalty", 1.0, float),
        ('--device', fallback_device, str),
    )
    for flag, default_value, value_type in option_specs:
        parser.add_argument(flag, default=default_value, type=value_type)

    return parser.parse_args()


# Markdown rendered above the Gradio interface.
description = "\n## GPT2 Chat\n"

# Module-level example list. It is currently empty; main() passes its own
# inline examples to gr.Interface instead of using this list.
examples = []


def clean_stream_chunk(chunk: str) -> str:
    """Strip tokenizer artifacts from one decoded chunk yielded by the streamer.

    All spaces are removed. This is intended for Chinese text, where spaces
    between decoded tokens are not meaningful; note that it also removes real
    spaces from any non-Chinese text. ``[CLS]`` and ``[UNK]`` markers are
    dropped, and ``[SEP]`` is rendered as a line break.

    Args:
        chunk: A text fragment produced by ``TextIteratorStreamer``.

    Returns:
        The cleaned fragment, ready to append to the displayed output.
    """
    chunk = chunk.replace(" ", "")
    chunk = chunk.replace("[CLS]", "")
    chunk = chunk.replace("[SEP]", "\n")
    chunk = chunk.replace("[UNK]", "")
    return chunk


def main():
    """Build and launch the Gradio demo that streams GPT-2 generations.

    NOTE(review): of the options parsed by get_args(), only ``--device`` is
    used. The sampling flags (--max_new_tokens, --top_p, --temperature,
    --repetition_penalty) are ignored; the UI sliders carry their own defaults.
    """
    args = get_args()

    # "auto" selects CUDA when torch can see a GPU; any other value is used as-is.
    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    default_model_name = "qgyd2021/lib_service_4chan"

    # (tokenizer, model) pairs keyed by model name, so the weights are loaded
    # once instead of on every request.
    model_cache = {}

    def load_model(model_name: str):
        """Return the cached (tokenizer, model) pair for *model_name*, loading it on first use."""
        if model_name not in model_cache:
            tokenizer = BertTokenizer.from_pretrained(model_name)
            model = GPT2LMHeadModel.from_pretrained(model_name)
            # The model must sit on the same device as the input ids built below.
            model = model.to(device).eval()
            model_cache[model_name] = (tokenizer, model)
        return model_cache[model_name]

    input_text_box = gr.Text(label="text")
    output_text_box = gr.Text(lines=4, label="generated_content")

    def fn_stream(text: str,
                  max_new_tokens: int = 200,
                  top_p: float = 0.85,
                  temperature: float = 0.35,
                  repetition_penalty: float = 1.2,
                  model_name: str = default_model_name,
                  is_chat: bool = True,
                  ):
        """Generate a continuation of *text* and yield the growing output string.

        The prompt is encoded as ``[CLS] + tokens`` and, in chat mode,
        ``+ [SEP]``. In chat mode ``[SEP]`` also serves as the end-of-sequence
        token. Generation runs in a background thread, and each streamed chunk
        is cleaned with clean_stream_chunk() and appended to the accumulated
        output, which is yielded after every chunk.
        """
        # The dropdown may submit None when nothing has been selected.
        model_name = model_name or default_model_name
        tokenizer, model = load_model(model_name)

        prompt_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        input_ids = [tokenizer.cls_token_id, *prompt_ids]
        if is_chat:
            input_ids.append(tokenizer.sep_token_id)
        input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)

        streamer = TextIteratorStreamer(tokenizer=tokenizer)

        generation_kwargs = dict(
            inputs=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.sep_token_id if is_chat else None,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output = ""
        for chunk in streamer:
            output += clean_stream_chunk(chunk)
            yield output
        # The streamer is exhausted once generate() finishes; reap the thread.
        thread.join()

    demo = gr.Interface(
        fn=fn_stream,
        inputs=[
            input_text_box,
            gr.Slider(minimum=0, maximum=512, value=512, step=1, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
            gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
            gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
            gr.Dropdown(choices=[default_model_name], value=default_model_name, label="model_name"),
            gr.Checkbox(label="is_chat")
        ],
        outputs=[output_text_box],
        examples=[
            ["怎样擦屁股才能擦的干净", 512, 0.75, 0.35, 1.2, "qgyd2021/lib_service_4chan", True],
        ],
        cache_examples=False,
        examples_per_page=50,
        title="H Novel Generate",
        description=description,
    )
    demo.queue().launch()

    return


# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()