File size: 4,183 Bytes
f6ff4fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse

import gradio as gr
import torch
from transformers import BloomTokenizerFast, BloomForCausalLM

from project_settings import project_path


def get_args():
    """Parse the command-line options of the demo script.

    Options:
        --trained_model_path: path of the trained model, as a string.
            Defaults to ``trained_models/bloom-1b4-sft`` under
            ``project_path``.
        --device: device string; defaults to ``'auto'``.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    default_model_dir = project_path / "trained_models/bloom-1b4-sft"

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--trained_model_path',
        type=str,
        default=default_model_dir.as_posix(),
    )
    arg_parser.add_argument('--device', type=str, default='auto')

    return arg_parser.parse_args()


def main():
    """Load the trained Bloom model and serve it through a Gradio interface.

    Reads the model path and device from the command line (see ``get_args``),
    loads the tokenizer and the causal LM, places the model on the selected
    device, and launches a Gradio demo whose callback generates a sampled
    completion for the submitted text.

    Returns:
        None. ``demo.launch()`` is called before returning.
    """
    args = get_args()

    # 'auto' picks CUDA when available, otherwise falls back to the CPU.
    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    # pretrained model
    tokenizer = BloomTokenizerFast.from_pretrained(args.trained_model_path)
    model = BloomForCausalLM.from_pretrained(args.trained_model_path)
    # The prompt tensor is moved to `device` inside `fn`, so the weights must
    # live on the same device; otherwise generate() fails with a device
    # mismatch whenever `device` is not the CPU.
    model.to(device)
    # Inference only: switch off dropout and other training-time behavior.
    model.eval()

    description = """
    FireflyBloom1b4
    
    基于 [YeungNLP/bloom-1b4-zh](https://huggingface.co/YeungNLP/bloom-1b4-zh) 预训练模型,
    基于 [YeungNLP/firefly-train-1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) 数据集, 
    训练的等同于 [YeungNLP/firefly-bloom-1b4](https://huggingface.co/YeungNLP/firefly-bloom-1b4) 的问答模型. 
    
    训练代码是自己编写的, 在 examples 里, 总共训练了 3 个 epoch. 感觉效果还可以. 

    """

    def fn(text: str,
           max_new_tokens: int = 200,
           top_p: float = 0.85,
           temperature: float = 0.35,
           repetition_penalty: float = 1.2
           ):
        """Generate a sampled completion for ``text``.

        Args:
            text: user prompt; wrapped as ``<s>{text}</s></s>`` before
                tokenization.
            max_new_tokens: upper bound on generated tokens. Coerced to
                ``int`` because the Number widget may deliver a float.
            top_p: nucleus-sampling probability mass.
            temperature: sampling temperature.
            repetition_penalty: penalty passed through to ``generate``.

        Returns:
            The decoded continuation with ``</s>`` markers removed and
            surrounding whitespace stripped.
        """
        print(text)
        prompt = '<s>{}</s></s>'.format(text)
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(device)

        # No gradients are needed for generation; skipping them saves memory.
        with torch.no_grad():
            outputs = model.generate(input_ids,
                                     max_new_tokens=int(max_new_tokens),
                                     do_sample=True,
                                     top_p=top_p,
                                     temperature=temperature,
                                     repetition_penalty=repetition_penalty,
                                     eos_token_id=tokenizer.eos_token_id
                                     )

        # generate() returns prompt + continuation for a causal LM. Slice the
        # prompt off by token count instead of string-replacing it, which
        # breaks whenever decoding does not reproduce the prompt verbatim.
        new_token_ids = outputs[0][input_ids.shape[-1]:]
        output = tokenizer.decode(new_token_ids).replace('</s>', "").strip()
        print(output)
        return output

    demo = gr.Interface(
        fn=fn,
        inputs=[
            gr.Text(label="text"),
            gr.Number(value=200, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, label="top_p"),
            gr.Slider(minimum=0, maximum=1, value=0.35, label="temperature"),
            gr.Number(value=1.2, label="repetition_penalty"),
        ],
        outputs=[gr.Text(label="output")],
        examples=[
            [
                "将下面句子翻译成现代文:\n石中央又生一树,高百余尺,条干偃阴为五色,翠叶如盘,花径尺余,色深碧,蕊深红,异香成烟,著物霏霏。",
                200, 0.85, 0.35, 1.2
            ],
            [
                "实体识别: 1949年10月1日,人们在北京天安门广场参加开国大典。",
                200, 0.85, 0.35, 1.2
            ],
            [
                "把这句话翻译成英文: 1949年10月1日,人们在北京天安门广场参加开国大典。",
                200, 0.85, 0.35, 1.2
            ],
            [
                "晚上睡不着该怎么办. 请给点详细的介绍.",
                200, 0.85, 0.35, 1.2
            ],
            [
                "将下面的句子翻译成文言文:结婚率下降, 离婚率暴增, 生育率下降, 人民焦虑迷茫, 到底是谁的错.",
                200, 0.85, 0.35, 1.2
            ],
            [
                "对联:厌烟沿檐烟燕眼.",
                200, 0.85, 0.35, 1.2
            ],
            [
                "写一首咏雪的古诗, 标题为 \"沁园春, 雪\".",
                200, 0.85, 0.35, 1.2
            ],
        ],
        examples_per_page=50,
        title="Firefly Bloom 1b4",
        description=description,
    )
    demo.launch()

    return


# Start the demo only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()