# app.py — Streamlit demo: Chinese text-summary (title) generation with a GPT-2 model.
import streamlit as st
from model import GPT2LMHeadModel
from transformers import BertTokenizer
import argparse
import os
import torch
import time
from generate_title import predict_one_sample
# Configure the page; Streamlit requires this to be the first st.* call in the script.
st.set_page_config(page_title="Demo", initial_sidebar_state="auto", layout="wide")
# NOTE(review): caching was left disabled here; on modern Streamlit,
# @st.cache_resource is the appropriate decorator for heavyweight model
# objects (st.cache_data's old allow_output_mutation flag no longer exists).
def get_model(device, vocab_path, model_path):
    """Load the tokenizer and summarization model for inference.

    Args:
        device: torch.device to place the model on.
        vocab_path: path to the BERT vocab file used by the tokenizer.
        model_path: directory containing the pretrained GPT2LMHeadModel checkpoint.

    Returns:
        A ``(tokenizer, model)`` pair, with the model in eval mode on *device*.
    """
    bert_tokenizer = BertTokenizer.from_pretrained(vocab_path, do_lower_case=True)
    gpt2_model = GPT2LMHeadModel.from_pretrained(model_path)
    gpt2_model.to(device)
    gpt2_model.eval()
    return bert_tokenizer, gpt2_model
# --- Device selection and one-time model load (module-level side effects) ---
device_ids = 0  # index of the physical GPU to expose to this process
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# BUG FIX: the environment variable is CUDA_VISIBLE_DEVICES (plural); the
# original "CUDA_VISIBLE_DEVICE" is not recognized by CUDA and had no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_ids)
# BUG FIX: once only one GPU is visible via CUDA_VISIBLE_DEVICES, it is always
# addressed as index 0; the original hard-coded "cuda:1", which raises on any
# host that exposes a single device.
device = torch.device("cuda:0" if torch.cuda.is_available() and int(device_ids) >= 0 else "cpu")
tokenizer, model = get_model(device, "vocab.txt", "checkpoint-55922")
def writer():
    """Render the demo page: sidebar decoding parameters, an input text area,
    and on-click summary/title generation via ``predict_one_sample``.
    """
    st.markdown(
        """
## Text Summary DEMO
"""
    )
    # Sidebar: decoding hyper-parameters for generation.
    st.sidebar.subheader("配置参数")
    batch_size = st.sidebar.slider("batch_size", min_value=0, max_value=10, value=3)
    generate_max_len = st.sidebar.number_input("generate_max_len", min_value=0, max_value=64, value=32, step=1)
    repetition_penalty = st.sidebar.number_input("repetition_penalty", min_value=0.0, max_value=10.0, value=1.2,
                                                 step=0.1)
    top_k = st.sidebar.slider("top_k", min_value=0, max_value=10, value=3, step=1)
    top_p = st.sidebar.number_input("top_p", min_value=0.0, max_value=1.0, value=0.95, step=0.01)
    # Pack the sidebar values into an argparse Namespace because
    # predict_one_sample expects an args-style object.
    # NOTE(review): parse_args() also consumes the real CLI argv; under
    # `streamlit run` that is normally empty, so the sidebar defaults win —
    # confirm if this script is ever invoked with extra arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=batch_size, type=int, help='生成标题的个数')
    parser.add_argument('--generate_max_len', default=generate_max_len, type=int, help='生成标题的最大长度')
    parser.add_argument('--repetition_penalty', default=repetition_penalty, type=float, help='重复处罚率')
    parser.add_argument('--top_k', default=top_k, type=float, help='解码时保留概率最高的多少个标记')
    parser.add_argument('--top_p', default=top_p, type=float, help='解码时保留概率累加大于多少的标记')
    parser.add_argument('--max_len', type=int, default=512, help='输入模型的最大长度,要比config中n_ctx小')
    args = parser.parse_args()
    content = st.text_area("输入正文", value="近期美元指数大幅攀升,一度升至二十年高位,对此,市场人士认为,主要投资者对全球经济形势悲观预期升温,加上近日美联储决议鹰派基调,导致美元避险需求抬头。", max_chars=512)
    # Generate summaries only when the button is pressed; otherwise halt the
    # script run so nothing below executes.
    if st.button("一键生成摘要"):
        start_message = st.empty()
        start_message.write("正在抽取,请等待...")
        start_time = time.time()
        titles = predict_one_sample(model, tokenizer, device, args, content)
        end_time = time.time()
        # Overwrite the "please wait" placeholder with the elapsed time.
        start_message.write("抽取完成,耗时{}s".format(end_time - start_time))
        # Show each generated title in its own (editable) text box.
        for i, title in enumerate(titles):
            st.text_input("第{}个结果".format(i + 1), title)
    else:
        st.stop()
if __name__ == '__main__':
    # Entry point when the script is executed directly (e.g. `streamlit run app.py`).
    writer()