Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import random | |
import torch | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.cluster import KMeans | |
from matplotlib.font_manager import FontProperties | |
import pandas as pd | |
import seaborn as sns | |
from collections import Counter | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# モデルとトークナイザーのロード | |
import time | |
from transformers import AutoModel | |
from matplotlib.offsetbox import OffsetImage, AnnotationBbox | |
import matplotlib.pyplot as plt | |
# モデルとトークナイザーのパス | |
model_path = 'use14/bert-base-japanese-v3/2024-0208-0323/model' | |
tokenizer_path = 'use14/bert-base-japanese-v3/2024-0208-0323/tokenizer' | |
# トークナイザーとモデルのロード | |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) | |
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=14) | |
font_path = 'NotoSansCJKjp-Bold.otf' # 実際のフォントのパスに置き換えてください | |
font_prop = FontProperties(fname=font_path) | |
# Streamlitアプリのタイトル | |
st.title("セリフチェッカー") | |
# セッション状態にキーが存在しない場合は、初期値を設定 | |
if 'button_clicked' not in st.session_state: | |
st.session_state.button_clicked = False | |
def on_button_click(): | |
# ボタンがクリックされた時の処理 | |
st.session_state.button_clicked = True | |
# 吹き出し風表示用のカスタムCSS | |
custom_css = """ | |
<style> | |
.bubble { | |
position: relative; | |
background: #F0F0F0; | |
border-radius: .4em; | |
padding: 10px; | |
max-width: 95%; /* 吹き出しの最大幅を90%に設定 */ | |
word-wrap: break-word; /* 長い単語でも折り返しを保証 */ | |
} | |
.bubble::after { | |
content: ''; | |
position: absolute; | |
top: 10px; | |
left: -10px; | |
width: 0; | |
height: 0; | |
border: 10px solid transparent; | |
border-right-color: #F0F0F0; | |
border-left: 0; | |
margin-top: 5px; | |
margin-left: 0; | |
} | |
</style> | |
""" | |
# CSSを使ってプログレスバーの色を変更 | |
st.markdown(""" | |
<style> | |
/* プログレスバーの色を変更 */ | |
.stProgress > div > div > div > div { | |
background-color: #008000; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
# カテゴリリスト | |
category_list = [ | |
'5_紫上鏡一[教師,大人,タメ口]', '20_見嶋千里[怖め,大人,粗雑な言葉,特徴的な笑い]', '29_白鳥王子[キザ,大人,調子いい]', | |
'30_白城院素子[お嬢様,少女]', '32_水陰那月[ネガティブ,少年,敬語]', '50_緋崎平一郎[元気,少年,やんちゃ]', | |
'76_百知瑠璃[元気,少女,です!,写真]', '91_桜結衣[元気,少女,タメ口,アイドル]', '101_御伽美夜子[落ち着いている,大人,女性口調]', | |
'121_荊棘従道[執事,大人,敬語(丁寧語)]', '133_司馬萌香[姉御,少女,粗雑(ヤンキー)]', '134_菜野花[落ち着いてる,少女,敬語]', | |
'139_黒冬和馬[少しそっけない,少年,タメ口]', '142_四涼礼子[ダウナー,少女,語尾~,めんどくさがり]' | |
] | |
# カテゴリ選択用のセレクトボックス | |
selected_category = st.selectbox("1.目標キャラクターを選択", category_list) | |
# 選択されたカテゴリに対応する画像ファイル名の決定 | |
# カテゴリリストのインデックスを取得し、それに1を加えることで1から始まる画像ファイル番号を作成 | |
image_file_number = category_list.index(selected_category) + 1 | |
image_path = f"img/{image_file_number}.png" | |
image_width=300 | |
judge_text = st.text_input("2.セリフを入力 //例: 貴方達も迷ったんですか?, よし、変身完了だ。, 勿論幽霊は抜きにして…でしょうね。,[ですわ。,だろ!,ですね。]") | |
st.button("🔍 チェックする", on_click=on_button_click) | |
st.divider() | |
# 画面を2つの列に分割 | |
col1, col2 = st.columns([1, 4]) | |
# 左側の列に画像を表示 | |
with col1: | |
st.image(image_path, caption=selected_category, width=120) | |
# 右側の列にテキストボックスを配置 | |
with col2: | |
if judge_text: # ユーザーが何か入力した場合のみ表示 | |
st.markdown(custom_css, unsafe_allow_html=True) # カスタムCSSの適用 | |
st.markdown(f'<div class="bubble">{judge_text}</div>', unsafe_allow_html=True) | |
# 処理ステップ数に応じてプログレスバーを更新する関数 | |
def update_progress(step, total_steps): | |
progress = int((step / total_steps) * 100) | |
progress_bar.progress(progress) | |
total_steps = 5 # 処理を行う総ステップ数 | |
if st.session_state.button_clicked: | |
if judge_text: | |
# プログレスバーの初期化 | |
progress_bar = st.progress(0) | |
# トークナイズとテンソル化 | |
words = tokenizer.tokenize(judge_text) | |
word_ids = tokenizer.convert_tokens_to_ids(words) | |
word_tensor = torch.tensor([word_ids[:512]]) # 最大長を512に制限 | |
update_progress(1, total_steps) | |
# デバイスの自動選択 | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
word_tensor = word_tensor.to(device) | |
model = model.to(device) | |
update_progress(2, total_steps) | |
# 推論 | |
with torch.no_grad(): | |
y = model(word_tensor) | |
update_progress(3, total_steps) | |
# 最も近いカテゴリの決定 | |
pred = y.logits.argmax(-1) | |
st.write(f"最もらしい: {category_list[pred.item()]}") | |
update_progress(4, total_steps) | |
# 各クラスの確率計算 | |
probabilities = torch.softmax(y.logits, dim=-1) | |
top_prob, top_cat_indices = probabilities.topk(len(category_list)) | |
update_progress(5, total_steps) | |
# 確率とカテゴリ名の準備 | |
top_probabilities = top_prob.cpu().numpy()[0] | |
top_categories = [category_list[index] for index in top_cat_indices.cpu().numpy()[0]] | |
# 棒グラフの作成 | |
# 棒グラフの作成 | |
plt.figure(figsize=(10, 6)) | |
bars = plt.bar(range(len(top_categories)), top_probabilities, color='skyblue') | |
# 画像の読み込みと配置 | |
for i, (bar, category) in enumerate(zip(bars, top_categories)): | |
img_path = f'img/{i+1}.png' # ファイル名はcategory_listの順番+1の番号.png | |
image = plt.imread(img_path) | |
imagebox = OffsetImage(image, zoom=0.5) # zoomで画像のサイズを調整 | |
ab = AnnotationBbox(imagebox, (bar.get_x() + bar.get_width() / 2, bar.get_height()), frameon=False, box_alignment=(0.5, -0.2)) | |
plt.gca().add_artist(ab) | |
plt.xlabel('カテゴリ', fontproperties=font_prop) | |
plt.ylabel('確率', fontproperties=font_prop) | |
plt.ylim(0, 0.5) # y軸の範囲設定 | |
plt.xticks(range(len(top_categories)), top_categories, rotation=45, ha="right", fontproperties=font_prop) | |
plt.title('カテゴリ別確率', fontproperties=font_prop) | |
plt.show() | |
progress_bar.progress(100) | |
time.sleep(1) # 1秒待機 | |
progress_bar.empty() # プログレスバーを削除 | |
else: | |
st.write("入力してください。") |