gallaries / app.py
OzoneAsai's picture
Update app.py
d83ba87
raw
history blame
10.8 kB
print("[1/7] Importing streamlit...", end=" ", flush=True)
import streamlit as st
print("done")
print("[2/7] Importing something...", end=" ", flush=True)
import os
import pandas as pd
from datetime import datetime
from PIL import Image
from PIL import ImageFilter
from io import BytesIO
import zipfile
import base64
print("done")
print("[3/7] Importing deepdanbooru...", end=" ", flush=True)
import deepdanbooru as dd
print("done")
print("[4/7] Importing huggingface_hub...", end=" ", flush=True)
import huggingface_hub
print("done")
print("[5/7] Importing tensorflow...", end=" ", flush=True)
import tensorflow as tf
print("done")
print("[6/7] Importing numpy...", end=" ", flush=True)
import numpy as np
print("done")
print("[6/7] Importing transformers...", end=" ", flush=True)
from transformers import pipeline
print("done")
def extract_zip_to_temp(zip_file_path, temp_folder):
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
zip_ref.extractall(temp_folder)
zip_file_path = 'wolf_2m.zip'
temp_folder = 'temp'
extract_zip_to_temp(zip_file_path, temp_folder)
# ページごとの表示数
PAGE_SIZE = 20
# ファイルの保存先フォルダー
photos_folder = "photos"
# インデックスファイルのパス
index_file_path = "index.csv"
#blur_enabled = st.sidebar.checkbox("NSFW画像にブラーをかける", value=True, key=blur_toggle_key)
# ブラーのトグル用のキー
import random
import string
blur_toggle_key = "blurOrNot"
# st.session_stateにblur_enabledがない場合に定義
if "blur_enabled" not in st.session_state:
st.session_state.blur_enabled = st.sidebar.checkbox("NSFW画像にブラーをかける", value=True, key=blur_toggle_key)
# blur_enabledの値を取得
blur_enabled = st.session_state.blur_enabled
# タグ検索用のテキスト入力
tag_search_input = st.sidebar.text_input("タグ検索 (カンマでセパレート)", "", key="tag_search")
# インデックスを初期化
photo_df = pd.DataFrame(columns=["File Name", "Timestamp", "Tags"])
# ページングのための変数
page_num = st.sidebar.number_input("ページ番号", value=1, min_value=1, max_value=(len(photo_df) // PAGE_SIZE) + 1, key="page_num")
# 並び替えのための選択肢
sort_options = ["TimeStamp 昇順", "TimeStamp 降順", "名前 昇順", "名前 降順"]
selected_sort = st.sidebar.selectbox("写真の並び替え", sort_options, key="selected_sort")
# タグ検索があればフィルタリング
if tag_search_input:
tags_to_search = [tag.strip() for tag in tag_search_input.split(',')]
current_page = photo_df[photo_df["Tags"].apply(lambda x: all(tag in x.split(', ') for tag in tags_to_search))]
else:
current_page = photo_df.copy()
# 並び替え
if "TimeStamp 昇順" in selected_sort:
current_page = current_page.sort_values(by="Timestamp", ascending=True)
elif "TimeStamp 降順" in selected_sort:
current_page = current_page.sort_values(by="Timestamp", ascending=False)
elif "名前 昇順" in selected_sort:
current_page = current_page.sort_values(by="File Name", ascending=True)
elif "名前 降順" in selected_sort:
current_page = current_page.sort_values(by="File Name", ascending=False)
if tag_search_input:
st.sidebar.subheader("タグ検索結果")
st.sidebar.write(f"検索したタグ: {', '.join(tags_to_search)}")
st.sidebar.write(f"結果数: {len(current_page)}")
# タグ付け関数
def predict_tags(image: Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
_, height, width, _ = model.input_shape
image = np.asarray(image)
image = tf.image.resize(image, size=(height, width), method=tf.image.ResizeMethod.AREA, preserve_aspect_ratio=True)
image = image.numpy()
image = dd.image.transform_and_pad_image(image, width, height)
image = image / 255.
probs = model.predict(image[None, ...])[0]
probs = probs.astype(float)
indices = np.argsort(probs)[::-1]
result_all = dict()
result_threshold = dict()
for index in indices:
label = labels[index]
prob = probs[index]
result_all[label] = prob
if prob < score_threshold:
break
result_threshold[label] = prob
result_text = ', '.join(result_all.keys())
return result_threshold, result_all, result_text
# NSFW 判定関数
def predict_nsfw(image: Image.Image) -> dict[str, float]:
classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
result = classifier(image)[0]
return {result['label']: result['score']}
# ブラーをかける関数
def blur_image(image: Image.Image, blur_enabled: bool, nsfw_score: float) -> Image.Image:
if blur_enabled:
nsfw_result = predict_nsfw(image)
nsfw_score = nsfw_result.get('nsfw', 0.0)
if nsfw_score >= 0.75:
blur_radius=20*nsfw_score
image = image.filter(ImageFilter.GaussianBlur(radius=blur_radius))
return image
# ページングと並び替えのためのデータを取得する関数
def load_data():
if os.path.exists(index_file_path):
return pd.read_csv(index_file_path)
else:
return pd.DataFrame(columns=["File Name", "Timestamp", "Tags"])
# ブラーのトグル用のキー
#blur_enabled = st.sidebar.checkbox("NSFW画像にブラーをかける", value=True, key=blur_toggle_key)
if "blur_enabled" not in st.session_state:
st.session_state.blur_enabled = st.sidebar.checkbox("NSFW画像にブラーをかける", value=True, key=blur_toggle_key)
# アップロードされた写真を保存する関数
def save_uploaded_photo(uploaded_photo, file_name):
if not os.path.exists(photos_folder):
os.makedirs(photos_folder)
# アップロードされた写真を保存
image = Image.open(uploaded_photo)
# ブラーの有効/無効をトグルで制御
#blur_enabled = st.sidebar.checkbox("NSFW画像にブラーをかける", value=True, key=blur_toggle_key)
# ブラーの適用
image = blur_image(image, blur_enabled, 0.75)
image.save(os.path.join(photos_folder, file_name), "PNG")
# アップロードされた写真を表示する関数
def display_photos(photos):
for photo_info in photos.iterrows():
row = photo_info[1]
photo_path = os.path.join(photos_folder, row["File Name"])
image = Image.open(photo_path)
# タグを予測して表示
result_threshold, result_all, result_text = predict_tags(image, 0.7)
st.image(image, caption=row["File Name"], use_column_width=True)
# タグを表示
st.write("タグ:", ", ".join(result_all.keys()))
# フォルダーの中身を zip ファイルとしてダウンロード
def download_photos_as_zip(file_paths):
# Zip ファイル作成
with BytesIO() as zip_buffer:
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED) as zip_file:
for file_path in file_paths:
zip_file.write(file_path, os.path.basename(file_path))
# ダウンロードリンク表示
st.markdown(
f"**[ダウンロード ZIPファイル](data:application/zip;base64,{base64.b64encode(zip_buffer.getvalue()).decode()})**",
unsafe_allow_html=True
)
# モデルとラベルをダウンロードする関数
def load_model():
path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru', 'model-resnet_custom_v3.h5')
model = tf.keras.models.load_model(path)
return model
def load_labels():
path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru', 'tags.txt')
with open(path) as f:
labels = [line.strip() for line in f.readlines()]
return labels
# Streamlit アプリケーションのメイン部分
def main():
st.sidebar.title("アップロードオプション")
uploaded_photos = st.sidebar.file_uploader("写真をアップロードしてください", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
# モデルとラベルをダウンロードする
global model, labels
model = load_model()
labels = load_labels()
# インデックスを初期化
global photo_df
photo_df = load_data()
if uploaded_photos:
for uploaded_photo in uploaded_photos:
file_name = f"{datetime.now().strftime('%Y%m%d%H%M%S%f')}.png"
save_uploaded_photo(uploaded_photo, file_name)
# インデックスに追加
photo_df = pd.concat([photo_df, pd.DataFrame([[os.path.basename(file_name), datetime.now(), ""]], columns=["File Name", "Timestamp", "Tags"])], ignore_index=True)
# インデックスを更新
photo_df.sort_values(by="Timestamp", inplace=True, ascending=False)
photo_df.reset_index(drop=True, inplace=True)
photo_df.to_csv(index_file_path, index=False)
st.subheader("アップロードされた写真")
# ページングのための変数
page_num = st.sidebar.number_input("ページ番号", value=1, min_value=1, max_value=(len(photo_df) // PAGE_SIZE) + 1)
start_idx = (page_num - 1) * PAGE_SIZE
end_idx = min(start_idx + PAGE_SIZE, len(photo_df))
current_page = photo_df.iloc[start_idx:end_idx]
# 並び替えのための選択肢
#selected_sort = st.sidebar.selectbox("写真の並び替え", ["TimeStamp 昇順", "TimeStamp 降順", "名前 昇順", "名前 降順"])
# 並び替え
if selected_sort == "TimeStamp 昇順":
current_page = current_page.sort_values(by="Timestamp", ascending=True)
elif selected_sort == "TimeStamp 降順":
current_page = current_page.sort_values(by="Timestamp", ascending=False)
elif selected_sort == "名前 昇順":
current_page = current_page.sort_values(by="File Name", ascending=True)
elif selected_sort == "名前 降順":
current_page = current_page.sort_values(by="File Name", ascending=False)
# タグ検索があれば結果を表示
if tag_search_input:
st.sidebar.subheader("タグ検索結果")
st.sidebar.write(f"検索したタグ: {', '.join(tags_to_search)}")
st.sidebar.write(f"結果数: {len(current_page)}")
# ページに写真を表示
display_photos(current_page)
# Next ボタン
if st.button("Next"):
page_num += 1
st.experimental_rerun()
# Previous ボタン
if st.button("Previous"):
page_num -= 1
st.experimental_rerun()
# Zip ダウンロードボタン
if st.button("写真をダウンロード (ZIP)"):
file_paths = [os.path.join(photos_folder, file) for file in os.listdir(photos_folder)]
download_photos_as_zip(file_paths)
if __name__ == "__main__":
main()