John Doe commited on
Commit
04bdba9
1 Parent(s): 684ff8a

app.py, Zmaker.py, requirements.txtのアップロード

Browse files

app.py : streamlitによるGUI制御
Zmaker.py : fine-tuning済みのGPT-2で推論を行うためのコード

Files changed (2) hide show
  1. app.py +94 -0
  2. requirements.txt +67 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from Zmaker import Zmaker
3
+
4
+ if __name__ == "__main__":
5
+
6
+ #ファインチューニング済みモデルの読み込み
7
+ with st.spinner(text = "loading GPT-2..."):
8
+ if not ("AI" in st.session_state.keys()):
9
+ st.session_state["AI"] = Zmaker(
10
+ ft_path = "model/gpt2-ft/"
11
+ )
12
+
13
+ #設定用サイドバーの設定
14
+ with st.sidebar:
15
+ st.title("GPT-2のパラメータ")
16
+
17
+ #max_lenの設定用スライダ
18
+ sld_max_len = st.sidebar.slider(
19
+ "length of the sentence", min_value = 0, max_value = 256,
20
+ value = (25, 75), step = 1, key = "length"
21
+ )
22
+
23
+ #temperatureの設定用スライダ
24
+ sld_temp = st.sidebar.slider(
25
+ "temperature", min_value = 0.1, max_value = 1.5,
26
+ value = 0.1, step = 0.1, key = "temp"
27
+ )
28
+
29
+ #top_kの設定用スライダ
30
+ sld_top_k = st.sidebar.slider(
31
+ "top_k", min_value = 0, max_value = 500,
32
+ value = 40, step = 1, key = "top_k"
33
+ )
34
+
35
+ #top_pの設定用スライダ
36
+ sld_top_p = st.sidebar.slider(
37
+ "top_p", min_value = 0.01, max_value = 1.0,
38
+ value = 0.95, step = 0.01, key = "top_p"
39
+ )
40
+
41
+ #repeat_ngram_sizeの設定用スライダ
42
+ sld_top_p = st.sidebar.slider(
43
+ "repeat_ngram_size ", min_value = 1, max_value = 10,
44
+ value = 1, step = 1, key = "repeat_ngram_size"
45
+ )
46
+
47
+ #メインフォームの設定
48
+ with st.form(key = "Letter Form", clear_on_submit = False):
49
+ st.title("おてがみ 入力欄")
50
+ body = st.empty()
51
+ if ("letter_body" in st.session_state.keys()):
52
+ ret = body.text_area(
53
+ label = "お手紙を途中まで漢字+ひらがなで書いてください。続きをAIが生成します。\n"\
54
+ "本アプリで生成できるのは本文のみです。",
55
+ value = st.session_state["letter_body"]
56
+ )
57
+ else:
58
+ ret = body.text_area(
59
+ label = "お手紙を途中まで漢字+ひらがなで書いてください。\n"\
60
+ "続きをAIが生成します。",
61
+ value = "ズッポシ村へようこそ!"
62
+ )
63
+ sub = st.form_submit_button("Generate")
64
+
65
+ #注意事項
66
+ with st.expander("注意事項"):
67
+ st.text(
68
+ "※このAIは「どうぶつの森e+実況プレイ」"\
69
+ " (https://www.nicovideo.jp/mylist/45062007)において"\
70
+ " 稲葉百万鉄氏により作成された文章を学習データに用いております。\n"
71
+ " また,教師データの作成においてmintmama氏の作成した"\
72
+ " 「ズッポシむら手紙集」(https://www.nicovideo.jp/series/85494)\n"\
73
+ "を用いております。"
74
+ )
75
+
76
+
77
+ #submitボタンが押された
78
+ if sub == True:
79
+ #predictに必要な条件をGUIで設定した値に更新
80
+ st.session_state["AI"].min_len = st.session_state["length"][0]
81
+ st.session_state["AI"].max_len = st.session_state["length"][-1]
82
+ st.session_state["AI"].top_k = st.session_state["top_k"]
83
+ st.session_state["AI"].top_p = st.session_state["top_p"]
84
+ st.session_state["AI"].temp = st.session_state["temp"]
85
+ st.session_state["AI"].repeat_ngram_size = st.session_state["repeat_ngram_size"]
86
+
87
+ #AIによる予測を実行
88
+ with st.spinner(text = "generating..."):
89
+ prompt = ret
90
+ text = str(st.session_state["AI"].GenLetter("<s>"+prompt)[0])
91
+ text = text.replace('<s>', '')
92
+ text = text.replace('</s>', '')
93
+ st.session_state["letter_body"] = text
94
+ st.experimental_rerun()
requirements.txt ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.18.0
2
+ altair==4.2.2
3
+ attrs==23.1.0
4
+ blinker==1.6.2
5
+ cachetools==5.3.0
6
+ certifi==2022.12.7
7
+ charset-normalizer==3.1.0
8
+ click==8.1.3
9
+ colorama==0.4.6
10
+ decorator==5.1.1
11
+ entrypoints==0.4
12
+ filelock==3.12.0
13
+ fsspec==2023.4.0
14
+ gitdb==4.0.10
15
+ GitPython==3.1.31
16
+ huggingface-hub==0.14.1
17
+ idna==3.4
18
+ importlib-metadata==6.6.0
19
+ Jinja2==3.1.2
20
+ JsonForm==0.0.2
21
+ jsonschema==4.17.3
22
+ JsonSir==0.0.2
23
+ markdown-it-py==2.2.0
24
+ MarkupSafe==2.1.2
25
+ mdurl==0.1.2
26
+ mojimoji==0.0.12
27
+ mpmath==1.3.0
28
+ networkx==3.1
29
+ numpy==1.24.3
30
+ packaging==23.1
31
+ pandas==2.0.1
32
+ Pillow==9.5.0
33
+ protobuf==3.20.3
34
+ psutil==5.9.5
35
+ pyarrow==12.0.0
36
+ pydeck==0.8.1b0
37
+ Pygments==2.15.1
38
+ Pympler==1.0.1
39
+ pyrsistent==0.19.3
40
+ python-dateutil==2.8.2
41
+ Python-EasyConfig==0.1.7
42
+ pytz==2023.3
43
+ pytz-deprecation-shim==0.1.0.post0
44
+ PyYAML==6.0
45
+ regex==2023.3.23
46
+ requests==2.29.0
47
+ rich==13.3.5
48
+ sentencepiece==0.1.99
49
+ six==1.16.0
50
+ smmap==5.0.0
51
+ streamlit==1.22.0
52
+ sympy==1.11.1
53
+ tenacity==8.2.2
54
+ tokenizers==0.13.3
55
+ toml==0.10.2
56
+ toolz==0.12.0
57
+ torch==2.0.0
58
+ tornado==6.3.1
59
+ tqdm==4.65.0
60
+ transformers==4.28.1
61
+ typing_extensions==4.5.0
62
+ tzdata==2023.3
63
+ tzlocal==4.3
64
+ urllib3==1.26.15
65
+ validators==0.20.0
66
+ watchdog==3.0.0
67
+ zipp==3.15.0