|
# Each import group announces itself before it runs, so the console shows
# which import is in progress; "done" follows once the group has loaded.
print("[1/7] Importing streamlit...", end=" ", flush=True)
import streamlit as st
print("done")

print("[2/7] Importing something...", end=" ", flush=True)
import os
import pandas as pd
from datetime import datetime
from PIL import Image
from PIL import ImageFilter
from io import BytesIO
import zipfile
import base64
print("done")

print("[3/7] Importing deepdanbooru...", end=" ", flush=True)
import deepdanbooru as dd
print("done")

print("[4/7] Importing huggingface_hub...", end=" ", flush=True)
import huggingface_hub
print("done")

print("[5/7] Importing tensorflow...", end=" ", flush=True)
import tensorflow as tf
print("done")

print("[6/7] Importing numpy...", end=" ", flush=True)
import numpy as np
print("done")

# This is the seventh step; the label previously repeated "[6/7]".
print("[7/7] Importing transformers...", end=" ", flush=True)
from transformers import pipeline
print("done")
|
|
|
def extract_zip_to_temp(zip_file_path, temp_folder):
    """Unpack every member of a ZIP archive into a folder.

    Args:
        zip_file_path: Path of the archive to open for reading.
        temp_folder: Destination directory handed to ``ZipFile.extractall``.
    """
    with zipfile.ZipFile(zip_file_path, mode='r') as archive:
        archive.extractall(path=temp_folder)
|
# Unpack the bundled archive into the working folder. These statements sit at
# module top level, so they execute whenever the script itself executes.
zip_file_path = "wolf_2m.zip"
temp_folder = "temp"
extract_zip_to_temp(zip_file_path=zip_file_path, temp_folder=temp_folder)
|
|
|
|
|
|
|
|
|
|
|
# Number of index rows that make up one gallery page.
PAGE_SIZE: int = 20

# Directory that uploaded photos are written to and read back from.
photos_folder: str = "photos"

# CSV file holding the photo index (file name, timestamp, tags).
index_file_path: str = "index.csv"
|
|
|
|
|
|
|
import random |
|
import string |
|
|
|
blur_toggle_key = "blurOrNot"

# A Streamlit widget only exists in a run in which it is called. Creating the
# checkbox only while "blur_enabled" was missing from the session meant it was
# rendered on the very first run and never again, so the stored value could
# not be changed. Call it unconditionally and mirror its current value into
# the session so the rest of the script keeps reading `blur_enabled`.
st.session_state.blur_enabled = st.sidebar.checkbox(
    "NSFW画像にブラーをかける", value=True, key=blur_toggle_key
)

blur_enabled = st.session_state.blur_enabled
|
|
|
|
|
# Free-text tag query; individual tags are separated by commas.
tag_search_input = st.sidebar.text_input("タグ検索 (カンマでセパレート)", "", key="tag_search")

# Starts out as an empty index with the expected columns; main() later
# rebinds this global to the index returned by load_data().
photo_df = pd.DataFrame(columns=["File Name", "Timestamp", "Tags"])

# The upper bound is derived from the (still empty) index built just above.
page_num = st.sidebar.number_input(
    "ページ番号",
    value=1,
    min_value=1,
    max_value=(len(photo_df) // PAGE_SIZE) + 1,
    key="page_num",
)

sort_options = ["TimeStamp 昇順", "TimeStamp 降順", "名前 昇順", "名前 降順"]
selected_sort = st.sidebar.selectbox("写真の並び替え", sort_options, key="selected_sort")

# Keep only the rows whose comma-separated tag field contains every requested
# tag; with no query the whole index is used.
if not tag_search_input:
    current_page = photo_df.copy()
else:
    tags_to_search = [tag.strip() for tag in tag_search_input.split(',')]
    current_page = photo_df[
        photo_df["Tags"].apply(
            lambda tag_field: all(tag in tag_field.split(', ') for tag in tags_to_search)
        )
    ]

# (label, column, ascending) triples, tested in order; the first label found
# inside the selected option decides the ordering.
_SIDEBAR_SORT_RULES = (
    ("TimeStamp 昇順", "Timestamp", True),
    ("TimeStamp 降順", "Timestamp", False),
    ("名前 昇順", "File Name", True),
    ("名前 降順", "File Name", False),
)
for _sort_label, _sort_column, _sort_ascending in _SIDEBAR_SORT_RULES:
    if _sort_label in selected_sort:
        current_page = current_page.sort_values(by=_sort_column, ascending=_sort_ascending)
        break

# Summarise the tag query in the sidebar.
if tag_search_input:
    st.sidebar.subheader("タグ検索結果")
    st.sidebar.write(f"検索したタグ: {', '.join(tags_to_search)}")
    st.sidebar.write(f"結果数: {len(current_page)}")
|
|
|
def predict_tags(image: Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
    """Score ``image`` with the tagging model and collect its strongest labels.

    Reads the module-level ``model`` and ``labels`` globals, which ``main``
    assigns before any photo is processed.

    Args:
        image: Picture to tag.
        score_threshold: Minimum probability for a label to count as a match.

    Returns:
        A ``(result_threshold, result_all, result_text)`` tuple.
        ``result_threshold`` maps every label scoring at least
        ``score_threshold`` to its probability, highest first. ``result_all``
        holds those same entries plus the first label that fell below the
        threshold, because the scan records a label before testing it and
        stops there. ``result_text`` is the ``", "``-joined keys of
        ``result_all``.
    """
    _, height, width, channels = model.input_shape

    # Palette, grayscale and alpha images would reach the model with the wrong
    # number of channels. When the model declares a 3-channel input, bring the
    # picture to RGB first so the array's channel axis matches.
    if channels == 3 and image.mode != "RGB":
        image = image.convert("RGB")

    pixels = np.asarray(image)
    resized = tf.image.resize(
        pixels,
        size=(height, width),
        method=tf.image.ResizeMethod.AREA,
        preserve_aspect_ratio=True,
    ).numpy()
    padded = dd.image.transform_and_pad_image(resized, width, height)
    scaled = padded / 255.

    # predict() works on batches: add a leading batch axis, keep the one row.
    probs = model.predict(scaled[None, ...])[0].astype(float)

    result_all = dict()
    result_threshold = dict()
    # Walk the labels from the highest probability down.
    for index in np.argsort(probs)[::-1]:
        label = labels[index]
        prob = probs[index]
        result_all[label] = prob
        if prob < score_threshold:
            break
        result_threshold[label] = prob

    result_text = ', '.join(result_all.keys())
    return result_threshold, result_all, result_text
|
|
|
|
|
# Classifier instance shared by all predict_nsfw() calls; None until first use.
_nsfw_classifier = None


def _get_nsfw_classifier():
    """Return the NSFW image-classification pipeline, building it on first use."""
    global _nsfw_classifier
    if _nsfw_classifier is None:
        _nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
    return _nsfw_classifier


def predict_nsfw(image: Image.Image) -> dict[str, float]:
    """Classify ``image`` with the NSFW detection pipeline.

    The pipeline used to be constructed on every call. It is now kept in a
    module-level variable so that repeated calls (one per uploaded photo)
    reuse the same instance for as long as this module's globals live.

    Args:
        image: Picture to classify.

    Returns:
        A single-entry dict mapping the label of the first prediction the
        pipeline returns (presumably the top-scoring one; confirm against the
        pipeline documentation) to its score.
    """
    first_prediction = _get_nsfw_classifier()(image)[0]
    return {first_prediction['label']: first_prediction['score']}
|
|
|
|
|
def blur_image(image: Image.Image, blur_enabled: bool, nsfw_score: float) -> Image.Image:
    """Blur ``image`` when blurring is enabled and it is classified as NSFW.

    Previously ``nsfw_score`` was overwritten with the detected score before
    it was ever read, and the cut-off was hard-coded to 0.75. The argument is
    now used as that cut-off; the caller in this file passes 0.75, so its
    behaviour is unchanged.

    Args:
        image: Picture to inspect.
        blur_enabled: When false, the picture is returned untouched and no
            classification is run.
        nsfw_score: Minimum detected ``'nsfw'`` score that triggers the blur.

    Returns:
        The original image, or a Gaussian-blurred copy whose radius is 20
        times the detected score.
    """
    if not blur_enabled:
        return image

    # predict_nsfw reports a single label; any label other than 'nsfw'
    # counts as a score of 0.0.
    detected_score = predict_nsfw(image).get('nsfw', 0.0)
    if detected_score < nsfw_score:
        return image

    blur_radius = 20 * detected_score
    return image.filter(ImageFilter.GaussianBlur(radius=blur_radius))
|
|
|
|
|
def load_data(path=None):
    """Load the photo index from CSV, or return an empty index.

    A plain ``read_csv`` returns timestamps as strings and empty tag cells as
    NaN. Strings cannot be sorted together with the ``datetime`` values that
    new uploads add, and NaN has no ``split`` method for tag matching. Every
    column is therefore read as text with empty cells kept as ``""``, and
    ``Timestamp`` is then parsed into datetimes.

    Args:
        path: CSV file to read. Defaults to the module-level
            ``index_file_path``.

    Returns:
        A DataFrame. When the file does not exist it is empty with the
        columns ``File Name``, ``Timestamp`` and ``Tags``.
    """
    csv_path = index_file_path if path is None else path
    if not os.path.exists(csv_path):
        return pd.DataFrame(columns=["File Name", "Timestamp", "Tags"])

    frame = pd.read_csv(csv_path, dtype=str, keep_default_na=False)
    if "Timestamp" in frame.columns:
        # Unparseable or empty cells become NaT instead of aborting the load.
        frame["Timestamp"] = pd.to_datetime(frame["Timestamp"], errors="coerce")
    return frame
|
|
|
|
|
|
|
|
|
# When the session holds no blur preference yet, create the sidebar checkbox
# and store its value as the preference.
if "blur_enabled" not in st.session_state:
    st.session_state["blur_enabled"] = st.sidebar.checkbox(
        "NSFW画像にブラーをかける",
        value=True,
        key=blur_toggle_key,
    )
|
|
|
def save_uploaded_photo(uploaded_photo, file_name):
    """Store an uploaded picture as a PNG inside ``photos_folder``.

    The picture is passed through ``blur_image`` first, using the module-level
    ``blur_enabled`` flag and a 0.75 threshold, so the file on disk is the
    possibly blurred version.

    Args:
        uploaded_photo: File-like object or path accepted by ``Image.open``.
        file_name: Name of the file to create inside ``photos_folder``.
    """
    # exist_ok avoids the gap between checking for the folder and creating it.
    os.makedirs(photos_folder, exist_ok=True)

    # The context manager releases the decoder once the copy is on disk.
    with Image.open(uploaded_photo) as image:
        processed = blur_image(image, blur_enabled, 0.75)
        # Pillow's PNG writer cannot encode CMYK data, which JPEG uploads may
        # carry; convert to RGB so the save does not fail.
        if processed.mode == "CMYK":
            processed = processed.convert("RGB")
        processed.save(os.path.join(photos_folder, file_name), "PNG")
|
|
|
|
|
def display_photos(photos):
    """Render each photo of ``photos`` followed by its predicted tags.

    Args:
        photos: DataFrame with a ``File Name`` column naming files inside
            ``photos_folder``.
    """
    for _, row in photos.iterrows():
        photo_path = os.path.join(photos_folder, row["File Name"])

        # An index row can outlive its file; report it and keep rendering the
        # rest of the page instead of aborting.
        try:
            image = Image.open(photo_path)
        except FileNotFoundError:
            st.warning(f"ファイルが見つかりません: {row['File Name']}")
            continue

        # Close the file handle as soon as the picture has been tagged and
        # handed to Streamlit.
        with image:
            # predict_tags' second result also contains the first label that
            # fell below the threshold, so only the first result, which is
            # limited to labels scoring at least 0.7, is shown.
            result_threshold, _, _ = predict_tags(image, 0.7)
            st.image(image, caption=row["File Name"], use_column_width=True)

        st.write("タグ:", ", ".join(result_threshold.keys()))
|
|
|
|
|
def download_photos_as_zip(file_paths):
    """Bundle files into an in-memory ZIP and show it as a download link.

    Args:
        file_paths: Paths of the files to add. Each one is stored under its
            base name only.
    """
    with BytesIO() as zip_buffer:
        archive = zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED)
        with archive:
            for source_path in file_paths:
                archive.write(source_path, os.path.basename(source_path))

        # The archive is closed at this point, so the buffer holds the
        # finished ZIP; read it before the buffer itself is released.
        encoded_zip = base64.b64encode(zip_buffer.getvalue()).decode()
        st.markdown(
            f"**[ダウンロード ZIPファイル](data:application/zip;base64,{encoded_zip})**",
            unsafe_allow_html=True,
        )
|
|
|
|
|
def load_model():
    """Fetch the DeepDanbooru weights file via the Hugging Face hub and load it.

    Returns:
        The object returned by ``tf.keras.models.load_model`` for the
        downloaded ``model-resnet_custom_v3.h5`` file.
    """
    weights_path = huggingface_hub.hf_hub_download(
        'public-data/DeepDanbooru',
        'model-resnet_custom_v3.h5',
    )
    return tf.keras.models.load_model(weights_path)
|
|
|
def load_labels():
    """Fetch ``tags.txt`` via the Hugging Face hub and return its lines.

    Returns:
        One stripped string per line of the file, in file order.
    """
    tags_path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru', 'tags.txt')
    with open(tags_path) as tags_file:
        return [raw_line.strip() for raw_line in tags_file]
|
|
|
|
|
# Session-state keys owned by main().
_GALLERY_PAGE_KEY = "gallery_page"
_PROCESSED_UPLOADS_KEY = "processed_uploads"
_MODEL_KEY = "dd_model"
_LABELS_KEY = "dd_labels"

# Score a tag must reach to be stored with a photo (display_photos uses the
# same value when it tags a photo for display).
_TAG_SCORE_THRESHOLD = 0.7

# Sort option label -> (index column, ascending).
_SORT_RULES = {
    "TimeStamp 昇順": ("Timestamp", True),
    "TimeStamp 降順": ("Timestamp", False),
    "名前 昇順": ("File Name", True),
    "名前 降順": ("File Name", False),
}


def _step_gallery_page(step, last_page):
    """Button callback: move the stored page by ``step``, clamped to 1..last_page."""
    current = st.session_state[_GALLERY_PAGE_KEY] if _GALLERY_PAGE_KEY in st.session_state else 1
    st.session_state[_GALLERY_PAGE_KEY] = max(1, min(last_page, int(current) + step))


def _normalise_index(frame):
    """Return a copy of the index with datetime timestamps and string tags."""
    frame = frame.copy()
    frame["Timestamp"] = pd.to_datetime(frame["Timestamp"], errors="coerce")
    tags = frame["Tags"]
    frame["Tags"] = tags.astype(object).where(tags.notna(), "").astype(str)
    return frame


def _store_new_uploads(uploaded_photos, frame):
    """Save uploads not yet handled in this session and return the updated index."""
    if _PROCESSED_UPLOADS_KEY not in st.session_state:
        st.session_state[_PROCESSED_UPLOADS_KEY] = set()
    already_stored = st.session_state[_PROCESSED_UPLOADS_KEY]

    new_rows = []
    for uploaded_photo in uploaded_photos or []:
        # The uploader keeps returning the same files on every rerun, so each
        # file is remembered by (name, size) and stored only once per session.
        upload_id = (uploaded_photo.name, uploaded_photo.size)
        if upload_id in already_stored:
            continue

        saved_at = datetime.now()
        file_name = f"{saved_at.strftime('%Y%m%d%H%M%S%f')}.png"
        save_uploaded_photo(uploaded_photo, file_name)

        # Tag the stored file, which is also what display_photos tags later,
        # so that the tag search has something to match against.
        with Image.open(os.path.join(photos_folder, file_name)) as saved_image:
            confident_tags, _, _ = predict_tags(saved_image.convert("RGB"), _TAG_SCORE_THRESHOLD)

        new_rows.append([file_name, saved_at, ", ".join(confident_tags)])
        already_stored.add(upload_id)

    if not new_rows:
        return frame

    additions = pd.DataFrame(new_rows, columns=["File Name", "Timestamp", "Tags"])
    frame = additions if frame.empty else pd.concat([frame, additions], ignore_index=True)
    frame = frame.sort_values(by="Timestamp", ascending=False).reset_index(drop=True)
    frame.to_csv(index_file_path, index=False)
    return frame


def main():
    """Run one pass of the gallery page.

    Stores new uploads, then filters, sorts and paginates the photo index and
    renders the current page with its navigation and download controls.

    Fixes over the previous version:
      * the model and labels are kept in the session instead of being loaded
        again on every rerun;
      * an uploaded file is stored once, not on every rerun while it sits in
        the uploader;
      * predicted tags are stored in the index and the tag search actually
        filters the gallery;
      * filtering and sorting apply to the whole index before it is cut into
        pages, and the page count is no longer one too high when the row
        count is a multiple of ``PAGE_SIZE``;
      * "Next" and "Previous" change the page through button callbacks; they
        used to modify a local variable that was discarded by the rerun;
      * the ZIP download no longer fails when the photos folder is missing.
    """
    global model, labels, photo_df

    st.sidebar.title("アップロードオプション")
    uploaded_photos = st.sidebar.file_uploader(
        "写真をアップロードしてください",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )

    # Keep the model and labels in the session, so they are loaded once per
    # session rather than once per rerun.
    if _MODEL_KEY not in st.session_state:
        st.session_state[_MODEL_KEY] = load_model()
    if _LABELS_KEY not in st.session_state:
        st.session_state[_LABELS_KEY] = load_labels()
    model = st.session_state[_MODEL_KEY]
    labels = st.session_state[_LABELS_KEY]

    photo_df = _store_new_uploads(uploaded_photos, _normalise_index(load_data()))

    st.subheader("アップロードされた写真")

    # Filter by the requested tags; blank entries such as a trailing comma
    # are ignored.
    wanted_tags = [tag.strip() for tag in tag_search_input.split(',') if tag.strip()]
    visible = photo_df
    if wanted_tags:
        required = set(wanted_tags)
        matches = visible["Tags"].map(
            lambda tag_field: required <= set(tag_field.split(', '))
        ).astype(bool)
        visible = visible.loc[matches]

    # Sort the full result before paging so the order holds across pages.
    sort_rule = _SORT_RULES.get(selected_sort)
    if sort_rule is not None:
        sort_column, ascending = sort_rule
        visible = visible.sort_values(by=sort_column, ascending=ascending)

    # Ceiling division; an empty result still has one (empty) page.
    last_page = max(1, -(-len(visible) // PAGE_SIZE))

    # Clamp the stored page before the widget is created: a narrower filter
    # can leave it pointing past the last page.
    stored_page = st.session_state[_GALLERY_PAGE_KEY] if _GALLERY_PAGE_KEY in st.session_state else 1
    st.session_state[_GALLERY_PAGE_KEY] = max(1, min(last_page, int(stored_page)))
    page_num = st.sidebar.number_input(
        "ページ番号", min_value=1, max_value=last_page, key=_GALLERY_PAGE_KEY
    )
    start_idx = (page_num - 1) * PAGE_SIZE
    current_page = visible.iloc[start_idx:start_idx + PAGE_SIZE]

    if wanted_tags:
        st.sidebar.subheader("タグ検索結果")
        st.sidebar.write(f"検索したタグ: {', '.join(wanted_tags)}")
        st.sidebar.write(f"結果数: {len(visible)}")

    display_photos(current_page)

    # The callbacks run before the script reruns, which is the point at which
    # a keyed widget's value may be changed through session state.
    st.button("Next", on_click=_step_gallery_page, args=(1, last_page))
    st.button("Previous", on_click=_step_gallery_page, args=(-1, last_page))

    if st.button("写真をダウンロード (ZIP)"):
        if os.path.isdir(photos_folder):
            file_paths = [os.path.join(photos_folder, name) for name in os.listdir(photos_folder)]
        else:
            file_paths = []
        download_photos_as_zip(file_paths)
|
|
|
# Entry point: run the gallery when this file is executed as the main module.
if __name__ == "__main__":
    main()
|
|