Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# blur_toggle_keyの初期値を設定
|
2 |
blur_toggle_key = "blur_toggle_unique_key_for_sidebar"
|
3 |
|
@@ -21,6 +104,54 @@ def save_uploaded_photo(uploaded_photo, file_name):
|
|
21 |
|
22 |
image.save(os.path.join(photos_folder, file_name), "PNG")
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
# Streamlit アプリケーションのメイン部分
|
25 |
def main():
|
26 |
st.sidebar.title("アップロードオプション")
|
@@ -86,6 +217,6 @@ def main():
|
|
86 |
if st.button("写真をダウンロード (ZIP)"):
|
87 |
file_paths = [os.path.join(photos_folder, file) for file in os.listdir(photos_folder)]
|
88 |
download_photos_as_zip(file_paths)
|
89 |
-
|
90 |
if __name__ == "__main__":
|
91 |
main()
|
|
|
1 |
+
print("[1/7] Importing streamlit...", end=" ", flush=True)
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
print("done")
|
5 |
+
print("[2/7] Importing something...", end=" ", flush=True)
|
6 |
+
import os
|
7 |
+
import pandas as pd
|
8 |
+
from datetime import datetime
|
9 |
+
from PIL import Image
|
10 |
+
from PIL import ImageFilter
|
11 |
+
from io import BytesIO
|
12 |
+
import zipfile
|
13 |
+
import base64
|
14 |
+
print("done")
|
15 |
+
print("[3/7] Importing deepdanbooru...", end=" ", flush=True)
|
16 |
+
|
17 |
+
import deepdanbooru as dd
|
18 |
+
print("done")
|
19 |
+
print("[4/7] Importing huggingface_hub...", end=" ", flush=True)
|
20 |
+
|
21 |
+
import huggingface_hub
|
22 |
+
print("done")
|
23 |
+
print("[5/7] Importing tensorflow...", end=" ", flush=True)
|
24 |
+
|
25 |
+
import tensorflow as tf
|
26 |
+
print("done")
|
27 |
+
print("[6/7] Importing numpy...", end=" ", flush=True)
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
print("done")
|
31 |
+
print("[6/7] Importing transformers...", end=" ", flush=True)
|
32 |
+
|
33 |
+
from transformers import pipeline
|
34 |
+
print("done")
|
35 |
+
|
36 |
+
# ページごとの表示数
|
37 |
+
PAGE_SIZE = 20
|
38 |
+
|
39 |
+
# ファイルの保存先フォルダー
|
40 |
+
photos_folder = "photos"
|
41 |
+
|
42 |
+
# インデックスファイルのパス
|
43 |
+
index_file_path = "index.csv"
|
44 |
+
|
45 |
+
# タグ付け関数
|
46 |
+
def predict_tags(image: Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
|
47 |
+
_, height, width, _ = model.input_shape
|
48 |
+
image = np.asarray(image)
|
49 |
+
image = tf.image.resize(image, size=(height, width), method=tf.image.ResizeMethod.AREA, preserve_aspect_ratio=True)
|
50 |
+
image = image.numpy()
|
51 |
+
image = dd.image.transform_and_pad_image(image, width, height)
|
52 |
+
image = image / 255.
|
53 |
+
probs = model.predict(image[None, ...])[0]
|
54 |
+
probs = probs.astype(float)
|
55 |
+
|
56 |
+
indices = np.argsort(probs)[::-1]
|
57 |
+
result_all = dict()
|
58 |
+
result_threshold = dict()
|
59 |
+
for index in indices:
|
60 |
+
label = labels[index]
|
61 |
+
prob = probs[index]
|
62 |
+
result_all[label] = prob
|
63 |
+
if prob < score_threshold:
|
64 |
+
break
|
65 |
+
result_threshold[label] = prob
|
66 |
+
result_text = ', '.join(result_all.keys())
|
67 |
+
return result_threshold, result_all, result_text
|
68 |
+
|
69 |
+
# NSFW 判定関数
|
70 |
+
def predict_nsfw(image: Image.Image) -> dict[str, float]:
|
71 |
+
classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
|
72 |
+
result = classifier(image)[0]
|
73 |
+
return {result['label']: result['score']}
|
74 |
+
|
75 |
+
# ブラーをかける関数
|
76 |
+
def blur_image(image: Image.Image, blur_enabled: bool, nsfw_score: float) -> Image.Image:
|
77 |
+
if blur_enabled:
|
78 |
+
nsfw_result = predict_nsfw(image)
|
79 |
+
nsfw_score = nsfw_result.get('nsfw', 0.0)
|
80 |
+
if nsfw_score >= 0.75:
|
81 |
+
image = image.filter(ImageFilter.BLUR)
|
82 |
+
return image
|
83 |
+
|
84 |
# blur_toggle_keyの初期値を設定
|
85 |
blur_toggle_key = "blur_toggle_unique_key_for_sidebar"
|
86 |
|
|
|
104 |
|
105 |
image.save(os.path.join(photos_folder, file_name), "PNG")
|
106 |
|
107 |
+
# ページングと並び替えのためのデータを取得する関数
|
108 |
+
def load_data():
|
109 |
+
if os.path.exists(index_file_path):
|
110 |
+
return pd.read_csv(index_file_path)
|
111 |
+
else:
|
112 |
+
return pd.DataFrame(columns=["File Name", "Timestamp", "Tags"])
|
113 |
+
|
114 |
+
# アップロードされた写真を表示する関数
|
115 |
+
def display_photos(photos):
|
116 |
+
for photo_info in photos.iterrows():
|
117 |
+
row = photo_info[1]
|
118 |
+
photo_path = os.path.join(photos_folder, row["File Name"])
|
119 |
+
image = Image.open(photo_path)
|
120 |
+
|
121 |
+
# タグを予測して表示
|
122 |
+
result_threshold, result_all, result_text = predict_tags(image, 0.7)
|
123 |
+
|
124 |
+
st.image(image, caption=row["File Name"], use_column_width=True)
|
125 |
+
|
126 |
+
# タグを表示
|
127 |
+
st.write("タグ:", ", ".join(result_all.keys()))
|
128 |
+
|
129 |
+
# フォルダーの中身を zip ファイルとしてダウンロード
|
130 |
+
def download_photos_as_zip(file_paths):
|
131 |
+
# Zip ファイル作成
|
132 |
+
with BytesIO() as zip_buffer:
|
133 |
+
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED) as zip_file:
|
134 |
+
for file_path in file_paths:
|
135 |
+
zip_file.write(file_path, os.path.basename(file_path))
|
136 |
+
|
137 |
+
# ダウンロードリンク表示
|
138 |
+
st.markdown(
|
139 |
+
f"**[ダウンロード ZIPファイル](data:application/zip;base64,{base64.b64encode(zip_buffer.getvalue()).decode()})**",
|
140 |
+
unsafe_allow_html=True
|
141 |
+
)
|
142 |
+
|
143 |
+
# モデルとラベルをダウンロードする関数
|
144 |
+
def load_model():
|
145 |
+
path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru', 'model-resnet_custom_v3.h5')
|
146 |
+
model = tf.keras.models.load_model(path)
|
147 |
+
return model
|
148 |
+
|
149 |
+
def load_labels():
|
150 |
+
path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru', 'tags.txt')
|
151 |
+
with open(path) as f:
|
152 |
+
labels = [line.strip() for line in f.readlines()]
|
153 |
+
return labels
|
154 |
+
|
155 |
# Streamlit アプリケーションのメイン部分
|
156 |
def main():
|
157 |
st.sidebar.title("アップロードオプション")
|
|
|
217 |
if st.button("写真をダウンロード (ZIP)"):
|
218 |
file_paths = [os.path.join(photos_folder, file) for file in os.listdir(photos_folder)]
|
219 |
download_photos_as_zip(file_paths)
|
220 |
+
print("all function defined.")
|
221 |
if __name__ == "__main__":
|
222 |
main()
|