import streamlit as st from Zmaker import Zmaker if __name__ == "__main__": #ファインチューニング済みモデルの読み込み with st.spinner(text = "loading GPT-2..."): if not ("AI" in st.session_state.keys()): st.session_state["AI"] = Zmaker( ft_path = "model/gpt2-ft/" ) #設定用サイドバーの設定 with st.sidebar: st.title("GPT-2のパラメータ") #max_lenの設定用スライダ sld_max_len = st.sidebar.slider( "length of the sentence", min_value = 0, max_value = 256, value = (25, 75), step = 1, key = "length" ) #temperatureの設定用スライダ sld_temp = st.sidebar.slider( "temperature", min_value = 0.1, max_value = 1.5, value = 0.1, step = 0.1, key = "temp" ) #top_kの設定用スライダ sld_top_k = st.sidebar.slider( "top_k", min_value = 0, max_value = 500, value = 40, step = 1, key = "top_k" ) #top_pの設定用スライダ sld_top_p = st.sidebar.slider( "top_p", min_value = 0.01, max_value = 1.0, value = 0.95, step = 0.01, key = "top_p" ) #repeat_ngram_sizeの設定用スライダ sld_top_p = st.sidebar.slider( "repeat_ngram_size ", min_value = 1, max_value = 10, value = 1, step = 1, key = "repeat_ngram_size" ) #メインフォームの設定 with st.form(key = "Letter Form", clear_on_submit = False): st.title("おてがみ 入力欄") body = st.empty() if ("letter_body" in st.session_state.keys()): ret = body.text_area( label = "お手紙を途中まで漢字+ひらがなで書いてください。続きをAIが生成します。\n"\ "本アプリで生成できるのは本文のみです。", value = st.session_state["letter_body"] ) else: ret = body.text_area( label = "お手紙を途中まで漢字+ひらがなで書いてください。\n"\ "続きをAIが生成します。", value = "ズッポシ村へようこそ!" ) sub = st.form_submit_button("Generate") #注意事項 with st.expander("注意事項"): st.text( "※このAIは「どうぶつの森e+実況プレイ」"\ " (https://www.nicovideo.jp/mylist/45062007)において"\ " 稲葉百万鉄氏により作成された文章を学習データに用いております。\n" " また,教師データの作成においてmintmama氏の作成した"\ " 「ズッポシむら手紙集」(https://www.nicovideo.jp/series/85494)\n"\ "を用いております。" ) #submitボタンが押された if sub == True: #predictに必要な条件をGUIで設定した値に更新 st.session_state["AI"].min_len = st.session_state["length"][0] st.session_state["AI"].max_len = st.session_state["length"][-1] st.session_state["AI"].top_k = st.session_state["top_k"] st.session_state["AI"].top_p = st.session_state["top_p"] st.session_state["AI"].temp = st.session_state["temp"] st.session_state["AI"].repeat_ngram_size = st.session_state["repeat_ngram_size"] #AIによる予測を実行 with st.spinner(text = "generating..."): prompt = ret text = str(st.session_state["AI"].GenLetter(""+prompt)[0]) text = text.replace('', '') text = text.replace('', '') st.session_state["letter_body"] = text st.experimental_rerun()