from __future__ import annotations
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes
from typing import Iterable
from gradio.themes.base import Base
from transformers import AutoTokenizer,AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("ThornRugal/DSDescription")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
class DarkSouls(Base):
def __init__(
self
):
super().__init__(
)
super().set(
body_background_fill='url("cover.jpg") no-repeat center/cover',
block_background_fill='#222',
input_background_fill='#222',
input_shadow="inset 0 0 .3em .1em #531",
body_text_color="#ccc",
border_color_primary="#222",
block_title_text_color="#ccc",
input_radius="0",
button_secondary_background_fill="transparent",
button_secondary_text_color="#ccc"
)
def generate_item_description(ipt):
ipt = tokenizer(ipt.replace("\n",""),return_tensors="pt").to(model.device)
ipt["do_sample"] = False
res = tokenizer.decode(model.generate(**ipt,max_length=256,no_repeat_ngram_size=1)[0]).replace("","")
wrap_chrs = {"……","——"}
for uchar in res:
if (uchar < u'\u4e00' or uchar > u'\u9fa5') and uchar not in ('"','"','“','”',"…","—"):
wrap_chrs.add(uchar)
for i in list(wrap_chrs):
res = res.replace(i,i+"\n")
return res
#iface = gr.Interface(fn=generate_item_description, inputs=gr.Textbox(), outputs="text",theme=Seafoam)
#iface.launch()
with gr.Blocks(theme=DarkSouls(),css="""
span,textarea,button,hl{letter-spacing: 1px;font-family: 'simfang';font-size:18px}
}
button{color:transparent}
#DS_output textarea{line-height: 27px;text-align:center;}
"""
,title="黑暗之魂物品描述生成"
) as demo:
with gr.Column():
gr.Markdown("# 黑暗之魂物品描述生成")
gr.Markdown("""此为利用hugging-face预训练模型来生成物品描述的应用
目标是能够输入物品的前半段描述,让模型输出一个相关的背景故事
可以使用以下的例子作为参考输入:
被称为“邪妖”的剑,随着斩杀敌人数量的增多而变强。
伊扎里斯咒术中最为可怖的一个。牺牲自己的生命将身躯化为火焰钻入敌人身体内部
被称为“干将”与“莫邪”的对剑,是由有名的剑匠夫妇为暴君打造
被人们冠以“正义”的剑,能够驱散黑暗并治疗自身
详细的过程可以看[CSDN链接](https://blog.csdn.net/thorn_r/article/details/137139136)
受作者技术力限制、物品描述段落过短等一系列限制,模型效果可能不理想还有可能“说胡话”,还望各位海涵""")
textbox_ipt = gr.Textbox(label="输入的描述",lines=3)
button = gr.Button(icon="favicon.ico",value="转换")
textbox_output = gr.Textbox(label="输出的描述",lines=12,elem_id="DS_output")
button.click(generate_item_description, textbox_ipt, textbox_output)
if __name__=="__main__":
demo.launch(show_api=False)