OzoneAsai commited on
Commit
3a66846
1 Parent(s): 151def0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -1
app.py CHANGED
@@ -1,3 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # blur_toggle_keyの初期値を設定
2
  blur_toggle_key = "blur_toggle_unique_key_for_sidebar"
3
 
@@ -21,6 +104,54 @@ def save_uploaded_photo(uploaded_photo, file_name):
21
 
22
  image.save(os.path.join(photos_folder, file_name), "PNG")
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Streamlit アプリケーションのメイン部分
25
  def main():
26
  st.sidebar.title("アップロードオプション")
@@ -86,6 +217,6 @@ def main():
86
  if st.button("写真をダウンロード (ZIP)"):
87
  file_paths = [os.path.join(photos_folder, file) for file in os.listdir(photos_folder)]
88
  download_photos_as_zip(file_paths)
89
-
90
  if __name__ == "__main__":
91
  main()
 
1
+ print("[1/7] Importing streamlit...", end=" ", flush=True)
2
+
3
+ import streamlit as st
4
+ print("done")
5
+ print("[2/7] Importing something...", end=" ", flush=True)
6
+ import os
7
+ import pandas as pd
8
+ from datetime import datetime
9
+ from PIL import Image
10
+ from PIL import ImageFilter
11
+ from io import BytesIO
12
+ import zipfile
13
+ import base64
14
+ print("done")
15
+ print("[3/7] Importing deepdanbooru...", end=" ", flush=True)
16
+
17
+ import deepdanbooru as dd
18
+ print("done")
19
+ print("[4/7] Importing huggingface_hub...", end=" ", flush=True)
20
+
21
+ import huggingface_hub
22
+ print("done")
23
+ print("[5/7] Importing tensorflow...", end=" ", flush=True)
24
+
25
+ import tensorflow as tf
26
+ print("done")
27
+ print("[6/7] Importing numpy...", end=" ", flush=True)
28
+
29
+ import numpy as np
30
+ print("done")
31
+ print("[6/7] Importing transformers...", end=" ", flush=True)
32
+
33
+ from transformers import pipeline
34
+ print("done")
35
+
36
+ # ページごとの表示数
37
+ PAGE_SIZE = 20
38
+
39
+ # ファイルの保存先フォルダー
40
+ photos_folder = "photos"
41
+
42
+ # インデックスファイルのパス
43
+ index_file_path = "index.csv"
44
+
45
+ # タグ付け関数
46
+ def predict_tags(image: Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
47
+ _, height, width, _ = model.input_shape
48
+ image = np.asarray(image)
49
+ image = tf.image.resize(image, size=(height, width), method=tf.image.ResizeMethod.AREA, preserve_aspect_ratio=True)
50
+ image = image.numpy()
51
+ image = dd.image.transform_and_pad_image(image, width, height)
52
+ image = image / 255.
53
+ probs = model.predict(image[None, ...])[0]
54
+ probs = probs.astype(float)
55
+
56
+ indices = np.argsort(probs)[::-1]
57
+ result_all = dict()
58
+ result_threshold = dict()
59
+ for index in indices:
60
+ label = labels[index]
61
+ prob = probs[index]
62
+ result_all[label] = prob
63
+ if prob < score_threshold:
64
+ break
65
+ result_threshold[label] = prob
66
+ result_text = ', '.join(result_all.keys())
67
+ return result_threshold, result_all, result_text
68
+
69
+ # NSFW 判定関数
70
+ def predict_nsfw(image: Image.Image) -> dict[str, float]:
71
+ classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
72
+ result = classifier(image)[0]
73
+ return {result['label']: result['score']}
74
+
75
+ # ブラーをかける関数
76
+ def blur_image(image: Image.Image, blur_enabled: bool, nsfw_score: float) -> Image.Image:
77
+ if blur_enabled:
78
+ nsfw_result = predict_nsfw(image)
79
+ nsfw_score = nsfw_result.get('nsfw', 0.0)
80
+ if nsfw_score >= 0.75:
81
+ image = image.filter(ImageFilter.BLUR)
82
+ return image
83
+
84
  # blur_toggle_keyの初期値を設定
85
  blur_toggle_key = "blur_toggle_unique_key_for_sidebar"
86
 
 
104
 
105
  image.save(os.path.join(photos_folder, file_name), "PNG")
106
 
107
+ # ページングと並び替えのためのデータを取得する関数
108
+ def load_data():
109
+ if os.path.exists(index_file_path):
110
+ return pd.read_csv(index_file_path)
111
+ else:
112
+ return pd.DataFrame(columns=["File Name", "Timestamp", "Tags"])
113
+
114
+ # アップロードされた写真を表示する関数
115
+ def display_photos(photos):
116
+ for photo_info in photos.iterrows():
117
+ row = photo_info[1]
118
+ photo_path = os.path.join(photos_folder, row["File Name"])
119
+ image = Image.open(photo_path)
120
+
121
+ # タグを予測して表示
122
+ result_threshold, result_all, result_text = predict_tags(image, 0.7)
123
+
124
+ st.image(image, caption=row["File Name"], use_column_width=True)
125
+
126
+ # タグを表示
127
+ st.write("タグ:", ", ".join(result_all.keys()))
128
+
129
+ # フォルダーの中身を zip ファイルとしてダウンロード
130
+ def download_photos_as_zip(file_paths):
131
+ # Zip ファイル作成
132
+ with BytesIO() as zip_buffer:
133
+ with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED) as zip_file:
134
+ for file_path in file_paths:
135
+ zip_file.write(file_path, os.path.basename(file_path))
136
+
137
+ # ダウンロードリンク表示
138
+ st.markdown(
139
+ f"**[ダウンロード ZIPファイル](data:application/zip;base64,{base64.b64encode(zip_buffer.getvalue()).decode()})**",
140
+ unsafe_allow_html=True
141
+ )
142
+
143
+ # モデルとラベルをダウンロードする関数
144
+ def load_model():
145
+ path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru', 'model-resnet_custom_v3.h5')
146
+ model = tf.keras.models.load_model(path)
147
+ return model
148
+
149
+ def load_labels():
150
+ path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru', 'tags.txt')
151
+ with open(path) as f:
152
+ labels = [line.strip() for line in f.readlines()]
153
+ return labels
154
+
155
  # Streamlit アプリケーションのメイン部分
156
  def main():
157
  st.sidebar.title("アップロードオプション")
 
217
  if st.button("写真をダウンロード (ZIP)"):
218
  file_paths = [os.path.join(photos_folder, file) for file in os.listdir(photos_folder)]
219
  download_photos_as_zip(file_paths)
220
+ print("all function defined.")
221
  if __name__ == "__main__":
222
  main()