Spaces:
Sleeping
Sleeping
import ast
import itertools
import pickle

import cshogi
import pandas as pd
import streamlit as st
from datasets import load_dataset
from IPython.display import display
from transformers import T5ForConditionalGeneration, T5Tokenizer

import tools
# Load the tokenizer and model.
# Streamlit re-executes this whole script on every widget interaction, so the
# download/initialisation is cached with st.cache_resource. It then happens once
# per server process instead of on every rerun.
@st.cache_resource
def _load_model():
    """Return the ``(tokenizer, model)`` pair used to generate commentary."""
    # NOTE(review): the tokenizer is loaded from "pizzagatakasugi/shogi_t5" but the
    # model from "pizzagatakasugi/shogi_t5_v1" -- confirm they share a vocabulary.
    tok = T5Tokenizer.from_pretrained("pizzagatakasugi/shogi_t5", is_fast=True)
    mdl = T5ForConditionalGeneration.from_pretrained("pizzagatakasugi/shogi_t5_v1")
    # Inference only: put the module in eval mode (disables dropout).
    mdl.eval()
    return tok, mdl


tokenizer, model = _load_model()
# Page heading, followed by the sample dataset whose rows are selected by index.
st.title(body="将棋解説文の自動生成")
df = pd.read_csv(filepath_or_buffer="./dataset10.csv")
# How many moves are taken from the best / second-best predicted lines.
_BEST_LINE_MOVES = 3
_SECOND_LINE_MOVES = 1


def _other_side(mark):
    """Return "△" for "▲", and "▲" for any other mark."""
    return "△" if mark == "▲" else "▲"


def _format_line(moves, last_mover, limit):
    """Concatenate the first *limit* entries of *moves*, each prefixed by a side mark.

    The marks alternate, starting with the opponent of *last_mover*. The first
    predicted move is therefore attributed to the side that did not play the
    selected move.
    """
    parts = []
    side = last_mover
    for move in itertools.islice(moves, limit):
        side = _other_side(side)
        parts.append(side + move)
    return "".join(parts)


def _moves_up_to(movelist, idx):
    """Concatenate entries 0..idx (inclusive) of *movelist*, keeping each only up to its first "("."""
    return "".join(kif.split("(")[0] for kif in itertools.islice(movelist, idx + 1))


num = st.text_input("0から9の数字を入力")
if num in {str(d) for d in range(10)}:
    row = df.iloc[int(num)]
    st.write(row["game_type"], row["precedence_name"], row["follower_name"])
    sfen = row["sfen"].split("\n")
    # These cells are presumably Python list literals. ast.literal_eval parses
    # them without executing arbitrary code, which eval() would do.
    bestlist = ast.literal_eval(row["bestlist"])
    best2list = ast.literal_eval(row["best2list"])
    te = []
    te_sf = []
    # Normalise the characters of the record.
    movelist = tools.nomalize_sfen(sfen)
    # Board display.
    # NOTE(review): nothing in this file appends to `te` (selectbox options) or
    # `te_sf` (te_sf[idx + 1] is passed to cshogi.Board as the SFEN to show).
    # As written, the selectbox has no options and the generation branch below
    # never runs. TODO: populate both from `movelist`.
    s = st.selectbox(label="手数を選択", options=te)
    with st.expander("parameter"):
        # min_value is 0.01, not 0.0: with do_sample=True, transformers raises
        # ValueError for a non-positive temperature.
        temp = st.slider(
            "temperature", min_value=0.01, max_value=1.0, step=0.01, value=0.1, key=1
        )
        top_k = st.slider("top_k", min_value=1, max_value=50, step=1, value=5, key=2)
        top_p = st.slider(
            "top_p", min_value=0.5, max_value=1.0, step=0.01, value=0.90, key=3
        )
        beams = st.slider("num_beams", min_value=1, max_value=10, step=1, value=1, key=4)
        tokens = st.slider("min_new_tokens", min_value=0, max_value=50, value=20, key=5)
    reload = st.button("盤面生成", key=0)
    if s in te and reload:
        idx = te.index(s)
        board = cshogi.Board(sfen=te_sf[idx + 1])
        st.markdown(board.to_svg(), unsafe_allow_html=True)
        # Side that played the selected (0-based) move: even -> ▲, odd -> △.
        # NOTE(review): assumes ▲ makes the first move; confirm for handicap games.
        teban = "△" if idx % 2 == 1 else "▲"
        # Build the model input.
        kifs = _moves_up_to(movelist, idx)
        best = _format_line(bestlist[idx], teban, _BEST_LINE_MOVES)
        best2 = _format_line(best2list[idx], teban, _SECOND_LINE_MOVES)
        input_text = (
            sfen[0]
            + sfen[1]
            + kifs
            + "最善手の予測手順は"
            + best
            + "次善手の予測手順は"
            + best2
        )
        with st.spinner("推論中です..."):
            # No padding for a single sequence. The attention mask is passed
            # explicitly, so the encoder does not process 512 tokens every time.
            encoded = tokenizer(
                input_text, max_length=512, truncation=True, return_tensors="pt"
            )
            output_ids = model.generate(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_length=512,
                do_sample=True,
                temperature=temp,
                num_beams=beams,
                top_k=top_k,
                top_p=top_p,
                min_new_tokens=tokens,
            )
            output_text = tokenizer.decode(
                output_ids[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        st.write(output_text)
# temperature = st.slider("temperature",min_value=0.0,max_value=1.0,step=0.01,value=0.3,key=1) | |
# num_beams = st.slider("num_beams",min_value=1,max_value=5,step=1,value=1,key=2) | |
# min_new_tokens = st.slider("min_new_tokens",min_value=0,max_value=100,value=30,key=3) | |