gallaries / app.py
OzoneAsai's picture
Update app.py
3a66846
raw
history blame
No virus
8.01 kB
print("[1/7] Importing streamlit...", end=" ", flush=True)
import streamlit as st
print("done")
print("[2/7] Importing something...", end=" ", flush=True)
import os
import pandas as pd
from datetime import datetime
from PIL import Image
from PIL import ImageFilter
from io import BytesIO
import zipfile
import base64
print("done")
print("[3/7] Importing deepdanbooru...", end=" ", flush=True)
import deepdanbooru as dd
print("done")
print("[4/7] Importing huggingface_hub...", end=" ", flush=True)
import huggingface_hub
print("done")
print("[5/7] Importing tensorflow...", end=" ", flush=True)
import tensorflow as tf
print("done")
print("[6/7] Importing numpy...", end=" ", flush=True)
import numpy as np
print("done")
print("[6/7] Importing transformers...", end=" ", flush=True)
from transformers import pipeline
print("done")
# ページごとの表示数
PAGE_SIZE = 20
# ファイルの保存先フォルダー
photos_folder = "photos"
# インデックスファイルのパス
index_file_path = "index.csv"
# タグ付け関数
def predict_tags(image: Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
_, height, width, _ = model.input_shape
image = np.asarray(image)
image = tf.image.resize(image, size=(height, width), method=tf.image.ResizeMethod.AREA, preserve_aspect_ratio=True)
image = image.numpy()
image = dd.image.transform_and_pad_image(image, width, height)
image = image / 255.
probs = model.predict(image[None, ...])[0]
probs = probs.astype(float)
indices = np.argsort(probs)[::-1]
result_all = dict()
result_threshold = dict()
for index in indices:
label = labels[index]
prob = probs[index]
result_all[label] = prob
if prob < score_threshold:
break
result_threshold[label] = prob
result_text = ', '.join(result_all.keys())
return result_threshold, result_all, result_text
# NSFW 判定関数
def predict_nsfw(image: Image.Image) -> dict[str, float]:
classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
result = classifier(image)[0]
return {result['label']: result['score']}
# ブラーをかける関数
def blur_image(image: Image.Image, blur_enabled: bool, nsfw_score: float) -> Image.Image:
if blur_enabled:
nsfw_result = predict_nsfw(image)
nsfw_score = nsfw_result.get('nsfw', 0.0)
if nsfw_score >= 0.75:
image = image.filter(ImageFilter.BLUR)
return image
# blur_toggle_keyの初期値を設定
blur_toggle_key = "blur_toggle_unique_key_for_sidebar"
# ブラーの有効/無効をトグルで制御する関数
def get_blur_enabled():
return st.sidebar.checkbox("NSFW画像にブラーをかける", value=True, key=blur_toggle_key)
# アップロードされた写真を保存する関数
def save_uploaded_photo(uploaded_photo, file_name):
if not os.path.exists(photos_folder):
os.makedirs(photos_folder)
# アップロードされた写真を保存
image = Image.open(uploaded_photo)
# ブラーの有効/無効をトグルで制御
blur_enabled = get_blur_enabled()
# ブラーの適用
image = blur_image(image, blur_enabled, 0.75)
image.save(os.path.join(photos_folder, file_name), "PNG")
# ページングと並び替えのためのデータを取得する関数
def load_data():
if os.path.exists(index_file_path):
return pd.read_csv(index_file_path)
else:
return pd.DataFrame(columns=["File Name", "Timestamp", "Tags"])
# アップロードされた写真を表示する関数
def display_photos(photos):
for photo_info in photos.iterrows():
row = photo_info[1]
photo_path = os.path.join(photos_folder, row["File Name"])
image = Image.open(photo_path)
# タグを予測して表示
result_threshold, result_all, result_text = predict_tags(image, 0.7)
st.image(image, caption=row["File Name"], use_column_width=True)
# タグを表示
st.write("タグ:", ", ".join(result_all.keys()))
# フォルダーの中身を zip ファイルとしてダウンロード
def download_photos_as_zip(file_paths):
# Zip ファイル作成
with BytesIO() as zip_buffer:
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED) as zip_file:
for file_path in file_paths:
zip_file.write(file_path, os.path.basename(file_path))
# ダウンロードリンク表示
st.markdown(
f"**[ダウンロード ZIPファイル](data:application/zip;base64,{base64.b64encode(zip_buffer.getvalue()).decode()})**",
unsafe_allow_html=True
)
# モデルとラベルをダウンロードする関数
def load_model():
path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru', 'model-resnet_custom_v3.h5')
model = tf.keras.models.load_model(path)
return model
def load_labels():
path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru', 'tags.txt')
with open(path) as f:
labels = [line.strip() for line in f.readlines()]
return labels
# Streamlit アプリケーションのメイン部分
def main():
st.sidebar.title("アップロードオプション")
uploaded_photos = st.sidebar.file_uploader("写真をアップロードしてください", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
# モデルとラベルをダウンロードする
global model, labels
model = load_model()
labels = load_labels()
# インデックスを初期化
global photo_df
photo_df = load_data()
if uploaded_photos:
for uploaded_photo in uploaded_photos:
file_name = f"{datetime.now().strftime('%Y%m%d%H%M%S%f')}.png"
save_uploaded_photo(uploaded_photo, file_name)
# インデックスに追加
photo_df = pd.concat([photo_df, pd.DataFrame([[os.path.basename(file_name), datetime.now(), ""]], columns=["File Name", "Timestamp", "Tags"])], ignore_index=True)
# インデックスを更新
photo_df.sort_values(by="Timestamp", inplace=True, ascending=False)
photo_df.reset_index(drop=True, inplace=True)
photo_df.to_csv(index_file_path, index=False)
st.subheader("アップロードされた写真")
# ページングのための変数
page_num = st.sidebar.number_input("ページ番号", value=1, min_value=1, max_value=(len(photo_df) // PAGE_SIZE) + 1)
start_idx = (page_num - 1) * PAGE_SIZE
end_idx = min(start_idx + PAGE_SIZE, len(photo_df))
current_page = photo_df.iloc[start_idx:end_idx]
# 並び替えのための選択肢
selected_sort = st.sidebar.selectbox("写真の並び替え", ["TimeStamp 昇順", "TimeStamp 降順", "名前 昇順", "名前 降順"])
# 並び替え
if selected_sort == "TimeStamp 昇順":
current_page = current_page.sort_values(by="Timestamp", ascending=True)
elif selected_sort == "TimeStamp 降順":
current_page = current_page.sort_values(by="Timestamp", ascending=False)
elif selected_sort == "名前 昇順":
current_page = current_page.sort_values(by="File Name", ascending=True)
elif selected_sort == "名前 降順":
current_page = current_page.sort_values(by="File Name", ascending=False)
# ページに写真を表示
display_photos(current_page)
# Next ボタン
if st.button("Next"):
page_num += 1
st.experimental_rerun()
# Previous ボタン
if st.button("Previous"):
page_num -= 1
st.experimental_rerun()
# Zip ダウンロードボタン
if st.button("写真をダウンロード (ZIP)"):
file_paths = [os.path.join(photos_folder, file) for file in os.listdir(photos_folder)]
download_photos_as_zip(file_paths)
print("all function defined.")
if __name__ == "__main__":
main()