# kplug / demo_sum.py
# xusong28
# update
# d0547d2
# coding=utf-8
# author: xusong <xusong28@jd.com>
# time: 2022/8/23 12:58
"""
## TODO:
1. 下拉框,选择类目。 gr.Radio(['服饰', '箱包', '鞋靴'])
2. 支持输入特效
- 示例:https://huggingface.co/uer/gpt2-chinese-lyric
- 参考 https://github.com/huggingface/hub-docs/blob/main/js/src/lib/components/InferenceWidget/shared/WidgetTextarea/WidgetTextarea.svelte
3. 待开放参数:No Repeat Ngram Size、Length Penalty、Number of Beams。topk-sampling, topp-sampling,
num_beam_groups = return_sequences数吗?
## badcase:
1. 结尾容易出多个句号。为啥?
2. 重复
## 解码参数示例
- moss: do_sample=True, temperature=0.7, top_p=0.8, top_k=40, repetition_penalty=1.02
- chatglm:do_sample=True, temperature=0.95, top_p=0.7, max_length=2048
- vicuna: do_sample=True, temperature=0.7, top_p=1, top_k=-1, repetition_penalty=1
- chatgpt
## 参考
- generate官方文档:https://huggingface.co/blog/how-to-generate
- generate 解码策略介绍:
- (待补充)
- https://huggingface.co/spaces/THUDM/GLM-130B
- 去重
- no_repeat_ngram_size
- 源码: [NoRepeatNGramLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L490)
- 逻辑:
- 取值:
- 兼容:与greedy、sampling、beam_search 兼容
- encoder_no_repeat_ngram_size
- 源码:[EncoderNoRepeatNGramLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L525)
- 逻辑:
- 兼容:与greedy、sampling、beam_search 兼容
- repetition_penalty:
- 源码:[RepetitionPenaltyLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L206)
- 逻辑:对 input_ids(随着解码动态变化)中已出现过的 token 做惩罚:logits>0 时 logits/=penalty,logits<0 时 logits*=penalty,因此 penalty>1 时两种情况都会降低这些 token 的概率。
- 取值:>1 才叫惩罚,<1 就叫奖励了,=1 就是 no penalty。
- 兼容:与greedy、sampling、beam_search 兼容
- encoder_repetition_penalty
- 源码:[EncoderRepetitionPenaltyLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L228)
- 逻辑:只作用于 self.encoder_input_ids(静态,不随解码变化)中的 token,且内部使用 1/penalty,所以 penalty>1 时反而提高原文中出现过的 token 的概率(即惩罚原文之外的 token,抑制幻觉),并不是去重。
- 冲突:
- 长度
- (待补充)
- 禁用词黑名单
- 源码:[NoBadWordsLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L590)
- (待补充)
"""
import torch
import gradio as gr
from info import article
from kplug import modeling_kplug_s2s_patch
from transformers import BertTokenizer, BartForConditionalGeneration
# KPLUG checkpoint (BART architecture, BERT vocabulary) used for BOTH the model
# and its tokenizer, so the two can never be loaded from different checkpoints.
# Per its name it is presumably fine-tuned on CEPSum home-appliance ("jiadian")
# product data -- TODO confirm on the model card.
_CHECKPOINT = "eson/kplug-base-cepsum-jiadian"
model = BartForConditionalGeneration.from_pretrained(_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(_CHECKPOINT)
"""
解码策略
https://zhuanlan.zhihu.com/p/267471193
https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/utils.py#L473
"""
gen_mode_params = {
"greedy": {
"num_beams": 1,
"do_sample": False,
},
# 核心:next_tokens = torch.multinomial(next_token_probs, num_samples=1)
"sampling": {
"num_beams": 1,
"do_sample": True,
"repetition_penalty": 1.2
# temperature # 大于1 则会平均化(inf则相当于均匀采样,更多样化),小于1则会集中化(0则相当于greedy)
# top_p
# top_k
},
# TODO:
# typical sampling:
# https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L332
# TODO:
# Truncation Sampling: EtaLogitsWarper、EpsilonLogitsWarper
# https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L387
"beam search": {
"num_beams": 10,
"do_sample": False,
},
# 算法? 复杂度?
"contrastive search": {
"top_k": 4,
"penalty_alpha": 0.2,
},
# 算法? 复杂度?
"diverse beam search": {
"num_beams": 5,
"num_beam_groups": 5,
"num_return_sequences": 5,
"diversity_penalty": 1.0,
}
}
all_decoding_strategys = list(gen_mode_params.keys())
def summarize(text, prefix_text, constrained_text, decoding_strategys):
    """Generate product marketing copy for ``text`` with each selected decoding strategy.

    Args:
        text: Product information fed to the encoder (truncated to 512 tokens).
        prefix_text: Optional prompt-like prefix the output must start with;
            passed to ``generate`` as ``decoder_input_ids``. Empty or
            whitespace-only means no prefix.
        constrained_text: Optional phrase the output is forced to contain,
            passed as ``force_words_ids`` (constrained beam search). Empty or
            whitespace-only means no constraint. Output quality with a
            constraint has been observed to be poor.
        decoding_strategys: Keys of ``gen_mode_params`` to run, in order.
            ``None`` is treated as an empty selection.

    Returns:
        dict mapping each requested strategy to the list of strings returned
        by ``tokenizer.batch_decode``, or to the message
        "不支持 constrained text" when that strategy cannot honour
        ``constrained_text``.

    Raises:
        KeyError: if a strategy name is not a key of ``gen_mode_params``.
    """
    # Strategies that cannot honour force_words_ids. Constrained generation
    # needs num_beams > 1, no sampling and no beam groups. Contrastive search
    # sets neither num_beams nor do_sample, so it runs with the model defaults
    # (presumably a single beam). Depending on the transformers version,
    # generate() then either rejects it or runs plain contrastive search and
    # silently drops the constraint, so it is skipped like the other three.
    unconstrainable = {"greedy", "sampling", "contrastive search", "diverse beam search"}

    # Whitespace-only values count as "not provided". A blank constraint would
    # otherwise tokenize to an empty phrase ([[]]), which generate() cannot handle.
    prefix_text = (prefix_text or "").strip()
    constrained_text = (constrained_text or "").strip()

    common_params = {"min_length": 20, "max_length": 100}

    # truncation=True makes the 512-token cap explicit instead of relying on
    # the tokenizer's "truncation not explicitly activated" fallback.
    inputs = tokenizer([text], max_length=512, truncation=True, return_tensors="pt")
    # The leading [CLS] that BertTokenizer adds is not fed to the encoder.
    encoder_input_ids = inputs["input_ids"][:, 1:]

    if prefix_text:
        # decoder_input_ids, shape (1, n): drop the trailing [SEP] and replace
        # the leading [CLS] with the decoder start token, so generation
        # continues from the prefix.
        decoder_input_ids = tokenizer(
            [prefix_text], max_length=30, truncation=True, return_tensors="pt"
        ).input_ids[:, :-1]
        decoder_input_ids[:, 0] = model.config.decoder_start_token_id
        common_params["decoder_input_ids"] = decoder_input_ids

    if constrained_text:
        # One phrasal constraint: the whole tokenized phrase must appear.
        common_params["force_words_ids"] = tokenizer(
            [constrained_text], add_special_tokens=False, max_length=30, truncation=True
        ).input_ids

    result = {}
    print(decoding_strategys)
    for strategy in decoding_strategys or ():
        if constrained_text and strategy in unconstrainable:
            result[strategy] = "不支持 constrained text"
            continue
        # Strategy-specific settings win over the shared ones, instead of
        # raising a duplicate-keyword TypeError if both ever define a key.
        gen_kwargs = {**common_params, **gen_mode_params[strategy]}
        summary_ids = model.generate(encoder_input_ids, **gen_kwargs)
        summary = tokenizer.batch_decode(
            summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        print(strategy, summary)
        result[strategy] = summary
    return result
# Example rows for the Gradio interface. Each row follows summarize()'s input
# order: [product info, prefix text, constrained text, decoding strategies].
# The first three rows reuse one fridge description and vary only the
# prefix / constraint; every row selects all decoding strategies.
_fridge_info = "美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开门俯视图,全开门俯视图,预留参考图"
_fan_info = "爱家乐新加坡电风扇静音无叶风扇健康空气循环扇儿童球形风扇落地扇外观,宁静节能,产品结构,现代科技的结晶,品质,气家,未来风新时代,动里,空让,健康,低至13分贝/DC直流马达/低耗24,亲密玩伴,24W功率,/低耗,别加坡国民品牌,气流通道,增强室内空气运动,过尘栅网,1-12档风力调速,涡轮风扇,吸气口,大于6米随心掌控,电源适配暑,装箱明细,摆头角度,手动摇摆轨道,操作方式,与空调同时使用不仅可以让室温快速均衡作,电源插口,适用环境,还可以在短时间内,导引出风口,产品类型,快件重量,电机,暖空气向上冷空气向下,线长,使房间温度均衡,省电环保,定时,功率,将凉风或热风送给到附近的房间,轻松享受生活,左右自动(上下手动)摇摆9度,进风口,能够很快中和空气温度差"
_washer_info = "海尔8公斤节能静音高温消毒烫烫净全自动滚筒洗衣机靠实力说话,一掌控时间掌控自由,i-time智能时间洗,8公斤容量全家衣物一次清洗,细节绝不含糊,真正实力派,自动添加洗衣盒,洗羽绒服,就要专属程序,羊毛,牛仔,习绒,海尔洗衣机蓝晶系列滚筒,个性范儿,按照程序需求自动冲入洗衣机内,灵活旋钮,创新下排水洁净不残留,强力筋内筒,AMT防霉窗垫,LED大屏显示,洗衣液,消毒剂分别置放在洗衣盒中,从根本上解决污水残留问题避免,全新LD面板显示,更宽阔更大气操作信息一目了然,宽阔大气操作信息一目了然,右槽:消毒剂,简化洗衣程序,弹力筋中间的凹槽内分布,无残留排水模块,海尔洗衣机具有专业级羽绒洗护程序,为羽绒服营造洗护,一体化环境彻底告别手洗或者机洗,左槽:洗涤剂,我的智慧生活,中槽:柔顺剂,满足各种洗涤需求,告别昂贵洗衣店,自家"

sum_examples = [
    [info, prefix, constraint, all_decoding_strategys]
    for info, prefix, constraint in (
        (_fridge_info, "", ""),
        (_fridge_info, "智能", ""),
        (_fridge_info, "", "风冷无霜"),
        (_fan_info, "", ""),
        (_washer_info, "", ""),
    )
]
# Default product text pre-filled in the first input box.
_product_info_default = (
    "美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:"
    "M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷"
    "无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,"
    "建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开"
    "门俯视图,全开门俯视图,预留参考图"
)

# Inputs map positionally onto
# summarize(text, prefix_text, constrained_text, decoding_strategys).
_sum_inputs = [
    gr.Textbox(label="商品信息(Product Info)", value=_product_info_default),
    gr.Textbox(value="", label="前缀词(Prefix Text)"),
    gr.Textbox(value="", label="限定词(Constrained Text)"),
    gr.CheckboxGroup(
        choices=all_decoding_strategys,
        value=all_decoding_strategys[0:1],
        label="解码策略(Decoding Strategy)",
    ),
]

sum_iface = gr.Interface(
    fn=summarize,
    inputs=_sum_inputs,
    # Per-strategy results (dict of lists) are rendered as JSON.
    # TODO: hide the numeric indices gr.JSON shows for list values.
    outputs=gr.JSON(label="文本摘要(Summarization)"),
    examples=sum_examples,
    title="生成式摘要(Abstractive Summarization)",
    description='生成式摘要,用于电商领域的商品营销文案写作。输入商品信息,输出商品的营销文案。',
    article=article,
)

if __name__ == "__main__":
    # Start the Gradio app.
    sum_iface.launch()