File size: 17,655 Bytes
4a3c603
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
# coding=utf-8
# author: xusong <xusong28@jd.com>
# time: 2022/8/23 12:58

"""
## TODO:

1. 下拉框,选择类目。   gr.Radio(['服饰','箱包', '鞋靴']
2. 支持输入特效
  - 示例:https://huggingface.co/uer/gpt2-chinese-lyric
  - 参考 https://github.com/huggingface/hub-docs/blob/main/js/src/lib/components/InferenceWidget/shared/WidgetTextarea/WidgetTextarea.svelte
3. 待开放参数:No Repeat Ngram Size、Length Penalty、Number of Beams。topk-sampling, topp-sampling,


num_beam_groups = return_sequences数吗?

## badcase:

1. 结尾容易出多个句号。为啥?
2. 重复

## 解码demo (能够调整解码参数的demo)

- https://huggingface.co/spaces/THUDM/GLM-130B

## 解码参数示例

**greedy策略**

**sample策略**
- moss: do_sample=True, temperature=0.7, top_p=0.8, top_k=40, repetition_penalty=1.02
- chatglm:do_sample=True, temperature=0.95, top_p=0.7, max_length=2048
- chatglm2:do_sample=True, top_p=0.8, temperature=0.8  https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1023
- glm130b:
    - temperature=1, top_p=0.7, top_k=0, no_repeat_ngram_size=3, length_penalty=1, num_beams=2
- vicuna: do_sample=True, temperature=0.7, top_p=1, top_k=-1, repetition_penalty=1
- chatgpt
- baichuan-chat: do_sample=True, temperature=0.3, top_p=0.85, top_k=5, repetition_penalty=1.05  https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_config.json
- internlm-chat: do_sample=True, temperature=0.8, top_p=0.8  https://huggingface.co/internlm/internlm-chat-7b-v1_1/blob/main/modeling_internlm.py#L783
    - 解决重复问题,需要添加 repetition_penalty=1.05  https://github.com/InternLM/InternLM/issues/28
- llama2-chat: top_p=0.6, temperature=0.9
- qwen: top_p=0.8, top_k= 0, repetition_penalty=1.1   https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/generation_config.json
- gpt4:
    - temperature=1, top_p=1,
    https://platform.openai.com/docs/api-reference/chat/object
- claude:
-

**beam_search策略**
- tensor2tensor:
- opennmt:
- transformers:
    - asr: num_beams=5, max_length=200   https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/README.md
    - wmt:
      - "num_beams=5:10:15 length_penalty=0.6:0.7:0.8:0.9:1.0:1.1"  https://github.com/huggingface/transformers/blob/main/scripts/fsmt/eval-allenai-wmt16.sh
      - num_beams=5 length_penalty=0.8:1.2 early_stopping=true:false  https://github.com/huggingface/transformers/tree/main/examples/legacy/seq2seq
    - tensor2tensor: beam_size=4, alpha=0.6     https://github.com/tensorflow/tensor2tensor/tree/master#walkthrough



## 解码参数

- generate官方文档:
    - https://huggingface.co/blog/how-to-generate
    - https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md
    - https://github.com/huggingface/transformers/blob/main/src/transformers/generation/configuration_utils.py
- generate 解码策略介绍:

-


- 去重
    - no_repeat_ngram_size
        - 源码: [NoRepeatNGramLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L490)
        - 逻辑:
        - 取值: 默认0, If set to int > 0, all ngrams of that size can only occur once
               no_repeat_ngram_size=6 即代表: 6-gram不出现2次
        - 兼容:与greedy、sampling、beam_search 兼容
        - 缺陷:
            - 这个可能把GPT的输入都算进去了。比如商品文案写作场景,输入"雅诗兰黛小棕瓶",加入no_repeat_ngram_size参数可能就不能输出"雅诗兰黛小棕瓶"了
            - "only occur once", 需要一个参数 调整成最大允许次数
    - encoder_no_repeat_ngram_size
        - 源码:[EncoderNoRepeatNGramLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L525)
        - 逻辑:
        - 兼容:与greedy、sampling、beam_search 兼容
    - repetition_penalty:
        - 源码:[RepetitionPenaltyLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L206)
        - 逻辑:对input_ids 做去重逻辑,其中 input_ids 是随着解码动态变化的。对于 logits>0 会 logits/=penalty,才叫惩罚。
               类似 coverage mechanism
        - 取值:取值范围(0, inf),>1 才叫惩罚,<1 就叫奖励了,=1 就是 no penalty。论文里说 1.2 能够balance truthful generation and lack of repetition.
        - 公式:
             - 默认                         p=softmax(logits)
             - 加 temperature后             p=softmax(logits/T)
             - 加 repetition_penalty Θ 后   p=softmax(logits/(T* (Θ if i∈g else 1) )  ,其中 i∈g 表示已经生成过的 token
        - 缺陷:未考虑重复次数,也就是 重复2次和重复100次的惩罚是一样的。
        - 兼容:与greedy、sampling、beam_search 兼容
    - encoder_repetition_penalty
        - 源码:[EncoderRepetitionPenaltyLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L228)
        - 逻辑:只对 self.encoder_input_ids 做去重逻辑,self.encoder_input_ids 是静态的
        - 冲突:
- 多样性:
    - do_sample
    - temperature:
        - 取值范围(0, inf),大于1 则会平均化(inf则相当于均匀采样,更多样化),小于1则会集中化(逼近0则相当于greedy)
        - 理解:温度越高,系统越混乱,熵越大(概率越平均化,不确定性越大,生成文本的自由创作空间越大)。温度越低,生成的文本越偏保守。
        - 公式: p=softmax(logits) , 加 temperature后 p=softmax(logits/T)
        - Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    - diversity_penalty
    -
- 长度
    - max_length
    - max_new_tokens
    - min_length
    - min_new_tokens
    - early_stopping
    - max_tim
    - length_penalty: 长度惩罚因子
      - 取值(-inf, inf),大于0会生成更长的序列,小于0会生成更短的序列。默认值=1.0。
      - 应用场景:仅用于 beam search。(sampling策略建议也加上)
      - 公式: score = sum_logprobs / (generated_len**self.length_penalty)  即:长度越长,当前生成序列(路径)的得分越低。
      - 源码:https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/generation/beam_search.py#L965
      - 参考文档:https://cloud.tencent.com/developer/article/2295947
    - exponential_decay_length_penalty
      - ss
      - 公式
      - 源码:
- 截止符
    - eos_token_id
- 禁用词黑名单
    - 源码:[NoBadWordsLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L590)
    - bad_words_ids
    - suppress_tokens
- 强制解码词
    - force_words_ids
    - constraints
- 其他参数
    - top_p: only used in sample-based generation
        - 又称Nucleus Sampling
        - 每个时间步,按照字出现的概率由高到底排序,当概率之和大于top-p的时候,就不取后面的样本了。然后对取到的这些字的概率重新归一化后,进行采样。
        - 取值范围:0-1
        - 0表示?
    - top_k: only used in sample-based generation
        - 取值范围:
    - top-P采样方法往往与top-K采样方法结合使用,每次选取两者中最小的采样范围进行采样,可以减少预测分布过于平缓时采样到极小概率单词的几率。


## TODO:

- counted_repetition_penalty: 解决 repetition_penalty 不考虑重复次数的问题,重复越多惩罚越大
- no_repeat_ngram_size:
    - {"ngram": 3, "max_repeat": 1, "ignore_prefix": False}
                   "max_allowed_repetition":

"""

import torch
import gradio as gr
from info import article
from kplug import modeling_kplug_s2s_patch
from transformers import BertTokenizer, BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("eson/kplug-base-cepsum-jiadian")  # cnn指的是cnn daily mail
tokenizer = BertTokenizer.from_pretrained("eson/kplug-base-cepsum-jiadian")


"""
解码策略
  https://zhuanlan.zhihu.com/p/267471193
  https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/utils.py#L473
"""
gen_mode_params = {
    "greedy": {
        "num_beams": 1,
        "do_sample": False,
    },

    # 核心:next_tokens = torch.multinomial(next_token_probs, num_samples=1)
    "sampling": {
        "num_beams": 1,
        "do_sample": True,
        "repetition_penalty": 1.2
        # temperature # 大于1 则会平均化(inf则相当于均匀采样,更多样化),小于1则会集中化(0则相当于greedy)
        # top_p
        # top_k
    },

    # TODO:
    #  typical sampling:
    #    https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L332


    # TODO:
    # Truncation Sampling: EtaLogitsWarper、EpsilonLogitsWarper
    #   https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L387



    "beam search": {
        "num_beams": 10,
        "do_sample": False,
    },

    # 算法? 复杂度?
    "contrastive search": {
        "top_k": 4,
        "penalty_alpha": 0.2,
    },

    # 算法? 复杂度?
    # 网格波束搜索(Hokamp和Liu,2017)和约束波束搜索(Anderson等,2017)  https://blog.csdn.net/qq_36533552/article/details/106317720
    "diverse beam search": {
        "num_beams": 5,
        "num_beam_groups": 5,
        "num_return_sequences": 5,
        "diversity_penalty": 1.0,
    }
}

all_decoding_strategys = list(gen_mode_params.keys())


def summarize(text, prefix_text, constrained_text, decoding_strategys):
    """
    prefix_text: 能叫 prompt吗?
    constrained_text: 受限解码效果怎么这么差.
    gen_modes: Search Strategy、Decoding strategy、
    """
    # bad_words_ids  num_return_sequences=1, no_repeat_ngram_size=1, remove_invalid_values=True,
    common_params = {"min_length": 20, "max_length": 100}
    inputs = tokenizer([text], max_length=512, return_tensors="pt")

    # prompt_text = GPT2里的参数. 这里是 decoder_input_ids。 shape=(batch_size, n)
    if prefix_text:
        decoder_input_ids = tokenizer([prefix_text], max_length=30, return_tensors="pt")
        # decoder_input_ids = tokenizer(["采用优质的"], max_length=30, return_tensors="pt")
        decoder_input_ids = decoder_input_ids.input_ids[:, :-1]
        decoder_input_ids[:, 0] = model.config.decoder_start_token_id
        common_params["decoder_input_ids"] = decoder_input_ids

    #
    if constrained_text:
        common_params["force_words_ids"] = tokenizer(
            [constrained_text], add_special_tokens=False, max_length=30).input_ids

    result = {}
    print(decoding_strategys)
    for strategy in decoding_strategys:
        if constrained_text and strategy in ["greedy", "sampling", "diverse beam search"]:
            # `num_beams` needs to be greater than 1 for constrained generation.
            # `num_beam_groups` not supported yet for constrained generation.
            result[strategy] = "不支持 constrained text"
            continue

        summary_ids = model.generate(inputs["input_ids"][:, 1:], **common_params, **gen_mode_params[strategy])
        summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                         clean_up_tokenization_spaces=False)
        print(strategy, summary)
        result[strategy] = summary

    return result
    # return pd.DataFrame([result])


sum_examples = [
    [
        "美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开门俯视图,全开门俯视图,预留参考图",
        "", "", all_decoding_strategys],
    [
        "美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开门俯视图,全开门俯视图,预留参考图",
        "智能", "", all_decoding_strategys],
    [
        "美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开门俯视图,全开门俯视图,预留参考图",
        "", "风冷无霜", all_decoding_strategys],

    [
        "爱家乐新加坡电风扇静音无叶风扇健康空气循环扇儿童球形风扇落地扇外观,宁静节能,产品结构,现代科技的结晶,品质,气家,未来风新时代,动里,空让,健康,低至13分贝/DC直流马达/低耗24,亲密玩伴,24W功率,/低耗,别加坡国民品牌,气流通道,增强室内空气运动,过尘栅网,1-12档风力调速,涡轮风扇,吸气口,大于6米随心掌控,电源适配暑,装箱明细,摆头角度,手动摇摆轨道,操作方式,与空调同时使用不仅可以让室温快速均衡作,电源插口,适用环境,还可以在短时间内,导引出风口,产品类型,快件重量,电机,暖空气向上冷空气向下,线长,使房间温度均衡,省电环保,定时,功率,将凉风或热风送给到附近的房间,轻松享受生活,左右自动(上下手动)摇摆9度,进风口,能够很快中和空气温度差",
        "", "", all_decoding_strategys],
    [
        "海尔8公斤节能静音高温消毒烫烫净全自动滚筒洗衣机靠实力说话,一掌控时间掌控自由,i-time智能时间洗,8公斤容量全家衣物一次清洗,细节绝不含糊,真正实力派,自动添加洗衣盒,洗羽绒服,就要专属程序,羊毛,牛仔,习绒,海尔洗衣机蓝晶系列滚筒,个性范儿,按照程序需求自动冲入洗衣机内,灵活旋钮,创新下排水洁净不残留,强力筋内筒,AMT防霉窗垫,LED大屏显示,洗衣液,消毒剂分别置放在洗衣盒中,从根本上解决污水残留问题避免,全新LD面板显示,更宽阔更大气操作信息一目了然,宽阔大气操作信息一目了然,右槽:消毒剂,简化洗衣程序,弹力筋中间的凹槽内分布,无残留排水模块,海尔洗衣机具有专业级羽绒洗护程序,为羽绒服营造洗护,一体化环境彻底告别手洗或者机洗,左槽:洗涤剂,我的智慧生活,中槽:柔顺剂,满足各种洗涤需求,告别昂贵洗衣店,自家",
        "", "", all_decoding_strategys],
]

sum_iface = gr.Interface(
    fn=summarize,
    inputs=[
        gr.Textbox(
            label="商品信息(Product Info)",
            value="美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:"
                  "M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷"
                  "无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,"
                  "建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开"
                  "门俯视图,全开门俯视图,预留参考图"),
        gr.Textbox(
            "",
            label="前缀词(Prefix Text)"
        ),
        gr.Textbox(
            "",
            label="限定词(Constrained Text)"
        ),
        gr.Checkboxgroup(
            all_decoding_strategys, value=all_decoding_strategys[0:1],
            label="解码策略(Decoding Strategy)"
        ),
    ],
    # outputs=gr.Textbox(
    #     label="文本摘要(Summarization)",
    #     lines=4,
    # ),
    # outputs=gr.DataFrame(
    #     label="文本摘要(Summarization)",
    # ),
    outputs=gr.JSON(  # TODO:去掉json array的数字标号
        label="文本摘要(Summarization)",
    ),

    examples=sum_examples,
    title="生成式摘要(Abstractive Summarization)",
    description='生成式摘要,用于电商领域的商品营销文案写作。输入商品信息,输出商品的营销文案。',
    article=article
)

if __name__ == "__main__":
    sum_iface.launch()