File size: 9,531 Bytes
2bb0b26
 
 
 
 
dcf7e4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bb0b26
 
 
 
5b47e63
c10350f
 
2bb0b26
01920f9
 
2bb0b26
dcf7e4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bb0b26
dcf7e4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bb0b26
 
dcf7e4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bb0b26
 
 
 
dcf7e4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a39e93b
 
dcf7e4b
2bb0b26
a39e93b
5b47e63
 
2bb0b26
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# coding=utf-8
# author: xusong <xusong28@jd.com>
# time: 2022/8/23 12:58

"""
## TODO:

1. 下拉框,选择类目。   gr.Radio(['服饰','箱包', '鞋靴']
2. 支持输入特效
  - 示例:https://huggingface.co/uer/gpt2-chinese-lyric
  - 参考 https://github.com/huggingface/hub-docs/blob/main/js/src/lib/components/InferenceWidget/shared/WidgetTextarea/WidgetTextarea.svelte
3. 待开放参数:No Repeat Ngram Size、Length Penalty、Number of Beams。topk-sampling, topp-sampling,


num_beam_groups = return_sequences数吗?

## badcase:

1. 结尾容易出多个句号。为啥?

## 参考

- generate官方文档:https://huggingface.co/blog/how-to-generate
- generate参数介绍:https://github.com/huggingface/transformers/blob/7f1cdf18958efef6339040ba91edb32ae7377720/src/transformers/generation/utils.py#L470
- https://huggingface.co/spaces/THUDM/GLM-130B

"""

import torch
import gradio as gr
from info import article
from kplug import modeling_kplug_s2s_patch
from transformers import BertTokenizer, BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("eson/kplug-base-cepsum-jiadian")  # cnn指的是cnn daily mail
tokenizer = BertTokenizer.from_pretrained("eson/kplug-base-cepsum-jiadian")

gen_mode_params = {
    "greedy": {
        "num_beams": 1,
        "do_sample": False,
    },
    # 核心:next_tokens = torch.multinomial(next_token_probs, num_samples=1)
    "sampling": {
        "num_beams": 1,
        "do_sample": True,
    },
    "beam search": {
        "num_beams": 10,
        "do_sample": False,
    },
    "contrastive search": {
        "top_k": 4,
        "penalty_alpha": 0.2,
    },
    "diverse beam search": {
        "num_beams": 5,
        "num_beam_groups": 5,
        "num_return_sequences": 5,
        "diversity_penalty": 1.0,
    }
}

all_decoding_strategys = list(gen_mode_params.keys())


def summarize(text, prefix_text, constrained_text, decoding_strategys):
    """
    prefix_text: 能叫 prompt吗?
    constrained_text: 受限解码效果怎么这么差.
    gen_modes: Search Strategy、Decoding strategy、
    """
    # bad_words_ids  num_return_sequences=1, no_repeat_ngram_size=1, remove_invalid_values=True,
    common_params = {"min_length": 20, "max_length": 100}
    inputs = tokenizer([text], max_length=512, return_tensors="pt")

    # prompt_text = GPT2里的参数. 这里是 decoder_input_ids。 shape=(batch_size, n)
    if prefix_text:
        decoder_input_ids = tokenizer([prefix_text], max_length=30, return_tensors="pt")
        # decoder_input_ids = tokenizer(["采用优质的"], max_length=30, return_tensors="pt")
        decoder_input_ids = decoder_input_ids.input_ids[:, :-1]
        decoder_input_ids[:, 0] = model.config.decoder_start_token_id
        common_params["decoder_input_ids"] = decoder_input_ids

    #
    if constrained_text:
        common_params["force_words_ids"] = tokenizer(
            [constrained_text], add_special_tokens=False, max_length=30).input_ids

    result = {}
    print(decoding_strategys)
    for strategy in decoding_strategys:
        if constrained_text and strategy in ["greedy", "sampling", "diverse beam search"]:
            # `num_beams` needs to be greater than 1 for constrained generation.
            # `num_beam_groups` not supported yet for constrained generation.
            result[strategy] = "不支持受限解码"
            continue

        summary_ids = model.generate(inputs["input_ids"][:, 1:], **common_params, **gen_mode_params[strategy])
        summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                         clean_up_tokenization_spaces=False)
        print(strategy, summary)
        result[strategy] = summary[0]

    return result
    # return pd.DataFrame([result])


sum_examples = [
    [
        "美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开门俯视图,全开门俯视图,预留参考图",
        "", "", all_decoding_strategys],
    [
        "美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开门俯视图,全开门俯视图,预留参考图",
        "智能", "", all_decoding_strategys],
    [
        "美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开门俯视图,全开门俯视图,预留参考图",
        "", "风冷无霜", all_decoding_strategys],

    [
        "爱家乐新加坡电风扇静音无叶风扇健康空气循环扇儿童球形风扇落地扇外观,宁静节能,产品结构,现代科技的结晶,品质,气家,未来风新时代,动里,空让,健康,低至13分贝/DC直流马达/低耗24,亲密玩伴,24W功率,/低耗,别加坡国民品牌,气流通道,增强室内空气运动,过尘栅网,1-12档风力调速,涡轮风扇,吸气口,大于6米随心掌控,电源适配暑,装箱明细,摆头角度,手动摇摆轨道,操作方式,与空调同时使用不仅可以让室温快速均衡作,电源插口,适用环境,还可以在短时间内,导引出风口,产品类型,快件重量,电机,暖空气向上冷空气向下,线长,使房间温度均衡,省电环保,定时,功率,将凉风或热风送给到附近的房间,轻松享受生活,左右自动(上下手动)摇摆9度,进风口,能够很快中和空气温度差",
        "", "", all_decoding_strategys],
    [
        "海尔8公斤节能静音高温消毒烫烫净全自动滚筒洗衣机靠实力说话,一掌控时间掌控自由,i-time智能时间洗,8公斤容量全家衣物一次清洗,细节绝不含糊,真正实力派,自动添加洗衣盒,洗羽绒服,就要专属程序,羊毛,牛仔,习绒,海尔洗衣机蓝晶系列滚筒,个性范儿,按照程序需求自动冲入洗衣机内,灵活旋钮,创新下排水洁净不残留,强力筋内筒,AMT防霉窗垫,LED大屏显示,洗衣液,消毒剂分别置放在洗衣盒中,从根本上解决污水残留问题避免,全新LD面板显示,更宽阔更大气操作信息一目了然,宽阔大气操作信息一目了然,右槽:消毒剂,简化洗衣程序,弹力筋中间的凹槽内分布,无残留排水模块,海尔洗衣机具有专业级羽绒洗护程序,为羽绒服营造洗护,一体化环境彻底告别手洗或者机洗,左槽:洗涤剂,我的智慧生活,中槽:柔顺剂,满足各种洗涤需求,告别昂贵洗衣店,自家",
        "", "", all_decoding_strategys],
]

sum_iface = gr.Interface(
    fn=summarize,
    inputs=[
        gr.Textbox(
            label="商品信息(Product Info)",
            value="美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:"
                  "M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷"
                  "无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,"
                  "建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开"
                  "门俯视图,全开门俯视图,预留参考图"),
        gr.Textbox(
            "",
            label="prefix text"
        ),
        gr.Textbox(
            "",
            label="constrained text"
        ),
        gr.Checkboxgroup(
            all_decoding_strategys, value=all_decoding_strategys[0:1],
            label="decoding strategy"
        ),
    ],
    # outputs=gr.Textbox(
    #     label="文本摘要(Summarization)",
    #     lines=4,
    # ),
    # outputs=gr.DataFrame(
    #     label="文本摘要(Summarization)",
    # ),
    outputs=gr.JSON(
        label="文本摘要(Summarization)",
    ),

    examples=sum_examples,
    title="生成式摘要(Abstractive Summarization)",
    description='生成式摘要,用于电商领域的商品营销文案写作。输入商品信息,输出商品的营销文案。',
    article=article
)

if __name__ == "__main__":
    sum_iface.launch()