File size: 9,531 Bytes
2bb0b26 dcf7e4b 2bb0b26 5b47e63 c10350f 2bb0b26 01920f9 2bb0b26 dcf7e4b 2bb0b26 dcf7e4b 2bb0b26 dcf7e4b 2bb0b26 dcf7e4b a39e93b dcf7e4b 2bb0b26 a39e93b 5b47e63 2bb0b26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
# coding=utf-8
# author: xusong <xusong28@jd.com>
# time: 2022/8/23 12:58
"""
## TODO:
1. 下拉框,选择类目。 gr.Radio(['服饰','箱包', '鞋靴'])
2. 支持输入特效
- 示例:https://huggingface.co/uer/gpt2-chinese-lyric
- 参考 https://github.com/huggingface/hub-docs/blob/main/js/src/lib/components/InferenceWidget/shared/WidgetTextarea/WidgetTextarea.svelte
3. 待开放参数:No Repeat Ngram Size、Length Penalty、Number of Beams。topk-sampling, topp-sampling,
num_beam_groups = return_sequences数吗?
## badcase:
1. 结尾容易出多个句号。为啥?
## 参考
- generate官方文档:https://huggingface.co/blog/how-to-generate
- generate参数介绍:https://github.com/huggingface/transformers/blob/7f1cdf18958efef6339040ba91edb32ae7377720/src/transformers/generation/utils.py#L470
- https://huggingface.co/spaces/THUDM/GLM-130B
"""
import torch
import gradio as gr
from info import article
from kplug import modeling_kplug_s2s_patch
from transformers import BertTokenizer, BartForConditionalGeneration
# Hugging Face Hub checkpoint shared by the seq2seq model and its tokenizer.
# The model class is BART-based while the vocabulary is loaded with a BERT
# tokenizer (so encoded inputs carry [CLS] ... [SEP]).
_CHECKPOINT = "eson/kplug-base-cepsum-jiadian"
model = BartForConditionalGeneration.from_pretrained(_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(_CHECKPOINT)
# Generation kwargs for each decoding strategy, forwarded verbatim to
# ``model.generate``. Every entry pins ``do_sample`` (and single-beam entries
# pin ``num_beams``) so the chosen strategy does not silently change with the
# generation defaults stored in the checkpoint's config.
gen_mode_params = {
    "greedy": {
        "num_beams": 1,
        "do_sample": False,
    },
    # Sampling draws every next token from the predicted distribution
    # (inside transformers: next_tokens = torch.multinomial(next_token_probs, num_samples=1)).
    "sampling": {
        "num_beams": 1,
        "do_sample": True,
    },
    "beam search": {
        "num_beams": 10,
        "do_sample": False,
    },
    # Contrastive search is selected by penalty_alpha > 0 and top_k > 1 with a
    # single beam and no sampling; all four are stated explicitly.
    "contrastive search": {
        "num_beams": 1,
        "do_sample": False,
        "top_k": 4,
        "penalty_alpha": 0.2,
    },
    # Group (diverse) beam search: 5 beams in 5 groups, with a penalty for a
    # group emitting the same token as another group at the same step.
    # Group beam search cannot run in sampling mode, hence do_sample=False.
    "diverse beam search": {
        "num_beams": 5,
        "num_beam_groups": 5,
        "num_return_sequences": 5,
        "diversity_penalty": 1.0,
        "do_sample": False,
    },
}

# Strategy names in insertion order; used as the UI choices and in the examples.
all_decoding_strategys = list(gen_mode_params.keys())
# Strategies whose transformers implementation cannot honour ``force_words_ids``:
# greedy, sampling and contrastive search run with a single beam (constrained
# generation needs num_beams > 1), and diverse beam search uses
# num_beam_groups > 1, which constrained generation does not support.
_CONSTRAINT_UNSUPPORTED_STRATEGIES = frozenset(
    {"greedy", "sampling", "contrastive search", "diverse beam search"}
)


def summarize(text, prefix_text, constrained_text, decoding_strategys):
    """Generate marketing copy for ``text`` with each selected decoding strategy.

    Args:
        text: Product information to summarize.
        prefix_text: Optional text the generated copy must start with. It is
            fed to the decoder as ``decoder_input_ids`` (prompt-like).
        constrained_text: Optional phrase forced into the output through
            ``force_words_ids``. Strategies in
            ``_CONSTRAINT_UNSUPPORTED_STRATEGIES`` report "不支持受限解码"
            instead of generating. Author's note: constrained-decoding
            quality has been observed to be poor.
        decoding_strategys: Names of strategies (keys of ``gen_mode_params``)
            to run; ``None`` or empty yields an empty result.

    Returns:
        dict mapping each requested strategy name to its first decoded
        sequence, or to "不支持受限解码" when the strategy cannot honour
        ``constrained_text``.
    """
    # Other candidate kwargs: bad_words_ids, num_return_sequences,
    # no_repeat_ngram_size, remove_invalid_values.
    common_params = {"min_length": 20, "max_length": 100}

    # truncation=True is explicit: max_length alone only truncates through a
    # deprecated fallback that also logs a warning.
    inputs = tokenizer([text], max_length=512, truncation=True, return_tensors="pt")
    # Drop the leading [CLS]; presumably the checkpoint was fine-tuned on
    # encoder inputs without it -- TODO confirm.
    encoder_ids = inputs["input_ids"][:, 1:]

    if prefix_text:
        # Prefix tokens become decoder_input_ids, shape (batch_size, n).
        prefix_ids = tokenizer(
            [prefix_text], max_length=30, truncation=True, return_tensors="pt"
        ).input_ids[:, :-1]  # drop the trailing [SEP]
        # Swap the leading [CLS] for the decoder start token so generation
        # continues right after the prefix.
        prefix_ids[:, 0] = model.config.decoder_start_token_id
        common_params["decoder_input_ids"] = prefix_ids

    if constrained_text:
        force_ids = tokenizer(
            [constrained_text], add_special_tokens=False, max_length=30, truncation=True
        ).input_ids
        # Whitespace-only text tokenizes to an empty list, which generate()
        # rejects as a forced word; treat it as "no constraint".
        if force_ids[0]:
            common_params["force_words_ids"] = force_ids
    constrained = "force_words_ids" in common_params

    result = {}
    print(decoding_strategys)
    for strategy in decoding_strategys or []:
        if constrained and strategy in _CONSTRAINT_UNSUPPORTED_STRATEGIES:
            result[strategy] = "不支持受限解码"
            continue
        summary_ids = model.generate(encoder_ids, **common_params, **gen_mode_params[strategy])
        summary = tokenizer.batch_decode(
            summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        print(strategy, summary)
        # Strategies returning several sequences (diverse beam search) only
        # show their first candidate.
        result[strategy] = summary[0]
    return result
# Product descriptions used by the demo examples (one per line, verbatim).
_FRIDGE_INFO = "美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开门俯视图,全开门俯视图,预留参考图"
_FAN_INFO = "爱家乐新加坡电风扇静音无叶风扇健康空气循环扇儿童球形风扇落地扇外观,宁静节能,产品结构,现代科技的结晶,品质,气家,未来风新时代,动里,空让,健康,低至13分贝/DC直流马达/低耗24,亲密玩伴,24W功率,/低耗,别加坡国民品牌,气流通道,增强室内空气运动,过尘栅网,1-12档风力调速,涡轮风扇,吸气口,大于6米随心掌控,电源适配暑,装箱明细,摆头角度,手动摇摆轨道,操作方式,与空调同时使用不仅可以让室温快速均衡作,电源插口,适用环境,还可以在短时间内,导引出风口,产品类型,快件重量,电机,暖空气向上冷空气向下,线长,使房间温度均衡,省电环保,定时,功率,将凉风或热风送给到附近的房间,轻松享受生活,左右自动(上下手动)摇摆9度,进风口,能够很快中和空气温度差"
_WASHER_INFO = "海尔8公斤节能静音高温消毒烫烫净全自动滚筒洗衣机靠实力说话,一掌控时间掌控自由,i-time智能时间洗,8公斤容量全家衣物一次清洗,细节绝不含糊,真正实力派,自动添加洗衣盒,洗羽绒服,就要专属程序,羊毛,牛仔,习绒,海尔洗衣机蓝晶系列滚筒,个性范儿,按照程序需求自动冲入洗衣机内,灵活旋钮,创新下排水洁净不残留,强力筋内筒,AMT防霉窗垫,LED大屏显示,洗衣液,消毒剂分别置放在洗衣盒中,从根本上解决污水残留问题避免,全新LD面板显示,更宽阔更大气操作信息一目了然,宽阔大气操作信息一目了然,右槽:消毒剂,简化洗衣程序,弹力筋中间的凹槽内分布,无残留排水模块,海尔洗衣机具有专业级羽绒洗护程序,为羽绒服营造洗护,一体化环境彻底告别手洗或者机洗,左槽:洗涤剂,我的智慧生活,中槽:柔顺剂,满足各种洗涤需求,告别昂贵洗衣店,自家"

# Rows for the Gradio examples table, in the same order as summarize()'s
# arguments: product info, prefix text, constrained text, strategies.
# Every row selects all decoding strategies.
sum_examples = [
    [product_info, prefix, constraint, all_decoding_strategys]
    for product_info, prefix, constraint in (
        (_FRIDGE_INFO, "", ""),
        (_FRIDGE_INFO, "智能", ""),
        (_FRIDGE_INFO, "", "风冷无霜"),
        (_FAN_INFO, "", ""),
        (_WASHER_INFO, "", ""),
    )
]
# Gradio UI wiring summarize(): four inputs in the same order as its
# parameters, and its {strategy: text} dict rendered as JSON.
sum_iface = gr.Interface(
    fn=summarize,
    inputs=[
        gr.Textbox(
            label="商品信息(Product Info)",
            value="美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:"
                  "M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷"
                  "无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,"
                  "建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开"
                  "门俯视图,全开门俯视图,预留参考图",
        ),
        gr.Textbox("", label="prefix text"),
        gr.Textbox("", label="constrained text"),
        # Canonical component name (``Checkboxgroup`` is only a legacy alias).
        # By default only the first strategy is selected.
        gr.CheckboxGroup(
            all_decoding_strategys,
            value=all_decoding_strategys[0:1],
            label="decoding strategy",
        ),
    ],
    outputs=gr.JSON(label="文本摘要(Summarization)"),
    examples=sum_examples,
    title="生成式摘要(Abstractive Summarization)",
    description='生成式摘要,用于电商领域的商品营销文案写作。输入商品信息,输出商品的营销文案。',
    article=article,
)
def main():
    """Start the Gradio demo server for ``sum_iface``."""
    sum_iface.launch()


if __name__ == "__main__":
    main()
|