import argparse import random import numpy as np import pytorch_lightning as pl import torch from dataset import KobeDataModule from model import KobeModel from options import add_args, add_options from typing import List import warnings import logging from transformers.models.bert.tokenization_bert import BertTokenizer import sentencepiece as spm import streamlit as st import matplotlib.image as mpimg import matplotlib.pyplot as plt @st.cache(suppress_st_warning=True) def init(): torch.manual_seed(1) parser = argparse.ArgumentParser() add_options(parser) args = parser.parse_args() args.vocab_file= "bert-base-chinese" args.cond_vocab_file = "./vocab.cond.model" add_args(args) model = KobeModel(args) #model = model.load_from_checkpoint("./epoch=19-step=66080.ckpt", args=args) model = model.load_from_checkpoint("./epoch=11-step=396384.ckpt", args=args) text_tokenizer = BertTokenizer.from_pretrained(args.vocab_file) cond_tokenizer = spm.SentencePieceProcessor() cond_tokenizer.Load(args.cond_vocab_file) # Read Images trainer = pl.Trainer() for param in model.parameters(): param.requires_grad = False return model, text_tokenizer, cond_tokenizer, args, trainer img = mpimg.imread('figure.png') # Output Images st.image(img, width=700) model, text_tokenizer, cond_tokenizer, args, trainer = init() class Example: title_token_ids: List[int] condition_token_ids: List[int] fact_token_ids: List[int] def __init__(self, title_token_ids:List[int], condition_token_ids: List[int], fact_token_ids: List[int]): self.title_token_ids = title_token_ids self.condition_token_ids = condition_token_ids self.fact_token_ids = fact_token_ids st.session_state['title'] = '' st.session_state['fact'] = '' st.session_state['aspect'] = '' st.session_state['cat'] = '' input1 = st.selectbox( 'please choose:', ("1. 天猫直发百诺碳纤维三脚架单反相机三角架c2690tb1摄影脚架", "2. dacom飞鱼p10运动型跑步蓝牙耳机入耳式头戴挂耳塞式7级防水苹果安卓手机通用可接听电话音乐篮", "3. coach蔻驰贝壳包pvc单肩手提斜挎大号贝壳包女包5828", "4. highcook韩库韩国进口蓝宝石近无烟炒锅家用不粘锅电磁炉炒菜锅", "5. 欧式复古亚麻布料沙发面料定做飘窗垫窗台垫榻榻米垫抱枕diy布艺", "6. 飞利浦电动剃须刀sp9851充电式带多功能理容配件和智能清洁系统", "7. 不锈钢牛排刀叉西餐餐具全套筷子刀叉勺三件套欧式加厚24件礼盒装", "8. 香百年汽车香水挂件车载香水香薰悬挂吊坠车用车内装饰挂饰精油", "9. 迪士尼小学生书包儿童男孩13一46年级美国队长男童减负12周岁男", "10. 半饱良味潮汕猪肉脯宅人食堂潮汕小吃特产碳烤猪肉干120g")) st.write('You selected:', input1) input1=int(input1.split(". ")[0]) title ="" fact = "" st.write('Fact:') if input1==1: st.session_state.title= "天猫直发百诺碳纤维三脚架单反相机三角架c2690tb1摄影脚架" st.session_state.fact = "太字节Terabyte,计算机存储容量单位,也常用TB来表示。百诺公司创建于1996年,早期与日本合作,后通过自身技术创新与努力,逐渐在国内外抑尘设备行业赢得一席单反就是指单镜头反光,即SLRSingleLensReflex,单反相机就是拥有这个功能的相机。技巧拍摄往往都离不开三脚架的帮助,如夜景拍摄、微距拍摄等方面。" if input1==2: st.session_state.title="dacom飞鱼p10运动型跑步蓝牙耳机入耳式头戴挂耳塞式7级防水苹果安卓手机通用可接听电话音乐篮牙" st.session_state.fact = "移动电话,或称为无线电话,通常称为手机,原本只是一种通讯工具,早期又有大哥大的俗称,是可以在较广范围运动型是德国精神病学家克雷奇默提出的身体类型之一。跑步,是指陆生动物使用足部,移动最快捷的方法。蓝牙耳机就是将蓝牙技术应用在免持耳机上,让使用者可以免除恼人电线的牵绊,自在地以各种方式轻松通话。" #st.write('Fact: \n', st.session_state.fact) if input1==3: st.session_state.title="coach蔻驰贝壳包pvc单肩手提斜挎大号贝壳包女包5828" st.session_state.fact = "聚氯乙烯,英文简称PVCPolyvinylchloride,是氯乙烯单体vinylchloridem女包,这个名词是箱包的性别分类衍生词。贝壳包beikebao女士包种类的一种,因为其外形酷似贝壳的外形而得名。蔻驰为美国经典皮件品牌COACH,一像以简洁、耐用的风格特色赢得消费者的喜爱。" #st.write('Fact: \n', st.session_state.fact) if input1==4: st.session_state.title= "highcook韩库韩国进口蓝宝石近无烟炒锅家用不粘锅电磁炉炒菜锅" st.session_state.fact = "蓝宝石,是刚玉宝石中除红色的红宝石之外,其它颜色刚玉宝石的通称,主要成分是氧化铝Al2O3。电磁炉又称为电磁灶,1957年第一台家用电磁炉诞生于德国。家用是汉语词汇,出自管子权修,解释为家庭日常使用的。不粘锅即做饭不会粘锅底的锅,是因为锅底采用了不粘涂层,常见的、不粘性能最好的有特氟龙涂层和陶瓷涂层。" #st.write('Fact: \n', st.session_state.fact) if input1==5: st.session_state.title= "欧式复古亚麻布料沙发面料定做飘窗垫窗台垫榻榻米垫抱枕diy布艺" st.session_state.fact = "沙发是个外来词,根据英语单词sofa音译而来。面料就是用来制作服装的材料。飘窗垫,就是放在飘窗的台面上的垫子。复古与怀旧,有时候很难区分。" #st.write('Fact: \n', st.session_state.fact) if input1==6: st.session_state.title= "飞利浦电动剃须刀sp9851充电式带多功能理容配件和智能清洁系统" st.session_state.fact = "能够完成一种或者几种生理功能的多个器官按照一定的次序组合在一起的结构叫做系统。配件,指装配机械的零件或部件;也指损坏后重新安装上的零件或部件。清洁是由奥利维耶阿萨亚斯执导,张曼玉、尼克诺尔蒂主演的剧情片,于2004年9月1日在法国上映。飞利浦,1891年成立于荷兰,主要生产照明、家庭电器、医疗系统方面的产品。" #st.write('Fact: \n', st.session_state.fact) if input1==7: st.session_state.title= "不锈钢牛排刀叉西餐餐具全套筷子刀叉勺三件套欧式加厚24件礼盒装" st.session_state.fact = "西餐餐具具体有大盘子、小盘子、浅碟、深碟、吃沙拉用的叉子、叉肉用的叉子、喝汤用的汤匙、吃甜点用的汤匙三件咳嗽,贫穷与爱情触不到的恋人...不锈钢指耐空气、蒸汽、水等弱腐蚀介质和酸、碱、盐等化学浸蚀性介质腐蚀的钢,又称不锈耐酸钢。2,4D丁酯,无色油状液体。" #st.write('Fact: \n', st.session_state.fact) if input1==8: st.session_state.title= "香百年汽车香水挂件车载香水香薰悬挂吊坠车用车内装饰挂饰精油" st.session_state.fact = "悬挂系统是汽车的车架与车桥或车轮之间的一切传力连接装置的总称,其作用是传递作用在车轮和车架之间的力和吊坠,一种首饰,配戴在脖子上的饰品,多为金属制,特别是不锈钢制和银制,也有矿石、水晶、玉石等制的,主汽车香水AutoPerfume是一种混合了香精油、固定剂与酒精的液体,用来让汽车车内拥有持久且悦人的精油是从植物的花、叶、茎、根或果实中,通过水蒸气蒸馏法、挤压法、冷浸法或溶剂提取法提炼萃取的挥发性芳" #st.write('Fact: \n', st.session_state.fact) if input1==9: st.session_state.title = "迪士尼小学生书包儿童男孩13一46年级美国队长男童减负12周岁男" st.session_state.fact= "美国队长是每一男孩心中的英雄人物,迪士尼美国队长款的小学生书包,按照美国队长的防护盾牌设计,泛着丝丝银光,帅气有型,而且还有很多英雄款式哦脊背处采用柔软舒适的脊椎防护设计,减轻孩子的背部压力。而且前方盾牌还能拆卸下来,当做斜挎包使用,满足淘气小男孩的英雄梦。" #st.write('Fact: \n', st.session_state.fact) if input1==10: st.session_state.title = "半饱良味潮汕猪肉脯宅人食堂潮汕小吃特产碳烤猪肉干120g" st.session_state.fact = "潮汕,不是潮州潮州一词始于隋文帝开皇十年,距今不到两千年。猪肉脯是一种用猪肉经腌制、烘烤的片状肉制品,食用方便、制作考究、美味可口、耐贮藏和便于运输的中式传统发达国家都有全国统一的急救电话号码。特产指某地特有的或特别著名的产品,有文化内涵或历史,亦指只有在某地才生产的一种产品。" #st.write('Fact: \n', st.session_state.fact) st.write(st.session_state.fact) input2 = st.selectbox( 'please choose category:', ("1. 家庭主妇", "2. 烹饪达人", "3. 买鞋控", "4. 数码达人", "5. 吃货", "6. 爱包人", "7. 高富帅", )) st.write('You selected:', input2) input2=input2.split(". ")[0] st.session_state.aspect = "<"+str(input2)+">" input3 = st.selectbox( 'please choose aspect:', ("1. appearance", "2. texture", "3. function")) st.write('You selected:', input3) input3=int(input3.split(". ")[0]) cond = "" if input3==1: st.session_state.cat="" if input3==2: st.session_state.cat="" if input3==3: st.session_state.cat="" #cond = cond+" "+aspect #print(title) #print(fact) #print(cond) if st.button('result'): tokenizer = text_tokenizer title_token_ids=tokenizer.encode(st.session_state.title, add_special_tokens=False) condition_token_ids=cond_tokenizer.EncodeAsIds(st.session_state.aspect+" "+st.session_state.cat) fact_token_ids=tokenizer.encode(st.session_state.fact, add_special_tokens=False) e = Example(title_token_ids, condition_token_ids, fact_token_ids) dm = KobeDataModule( [e], args.text_vocab_path, args.max_seq_len, 1, 1, ) for d in dm.test_dataloader(): #st.write(st.session_state.title) #st.write(st.session_state.fact) #st.write(st.session_state.aspect+" "+st.session_state.cat) st.write("result:") st.write(''.join(model.test_step(d ,1)).replace(" ",""))