#!/usr/bin/env python from __future__ import annotations import deepdanbooru as dd import gradio as gr import huggingface_hub import numpy as np import PIL.Image import tensorflow as tf def load_model() -> tf.keras.Model: path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru', 'model-resnet_custom_v3.h5') model = tf.keras.models.load_model(path) return model def load_labels() -> list[str]: path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru', 'tags.txt') with open(path) as f: labels = [line.strip() for line in f.readlines()] return labels model = load_model() labels = load_labels() def predict(image: PIL.Image.Image, score_threshold: float): _, height, width, _ = model.input_shape image = np.asarray(image) image = tf.image.resize(image, size=(height, width), method=tf.image.ResizeMethod.AREA, preserve_aspect_ratio=True) image = image.numpy() image = dd.image.transform_and_pad_image(image, width, height) image = image / 255. probs = model.predict(image[None, ...])[0] probs = probs.astype(float) indices = np.argsort(probs)[::-1] result_all = dict() result_threshold = dict() result_html = '' for index in indices: label = labels[index] prob = probs[index] result_all[label] = prob if prob < score_threshold: break result_html = result_html + '

' + str(label) + '' + str(round(prob, 3)) + '

' result_threshold[label] = prob result_text = ', '.join(result_threshold.keys()) result_text = '
' + str(result_text) + '
' result_html = '
' + str(result_html) + '
' return result_html, result_text js = """ async () => { document.addEventListener('click', function(event) { let tagItem = event.target.closest('.m5dd_list') let resultArea = event.target.closest('#m5dd_result') if (tagItem){ if (tagItem.classList.contains('use')){ tagItem.classList.remove('use') }else{ tagItem.classList.add('use') } document.getElementById('m5dd_result').innerText = Array.from(document.querySelectorAll('.m5dd_list.use>span:nth-child(1)')) .map(v=>v.innerText) .join(', ') }else if (resultArea){ const selection = window.getSelection() selection.removeAllRanges() const range = document.createRange() range.selectNodeContents(resultArea) selection.addRange(range) }else{ return } }) } """ with gr.Blocks(css="style.css") as demo: with gr.Row(): with gr.Column(scale=1): image = gr.Image(label='Input', type='pil') score_threshold = gr.Slider(label='Score threshold', minimum=0, maximum=1, step=0.05, value=0.5) run_button = gr.Button('Run') result_text = gr.HTML(value="
") with gr.Column(scale=3): result_html = gr.HTML(value="
") run_button.click( fn=predict, inputs=[image, score_threshold], outputs=[result_html, result_text], api_name='predict', ) demo.load(None,None,None,_js=js) demo.queue().launch()