Spaces:
Running
Running
File size: 3,716 Bytes
fb59cb8 e9c5f95 01b28b7 e9c5f95 fb59cb8 e9c5f95 01b28b7 e9c5f95 01b28b7 8cdd8e4 fb59cb8 bbe49e5 8cdd8e4 bbe49e5 fb59cb8 bbe49e5 8cdd8e4 bbe49e5 8cdd8e4 01b28b7 8cdd8e4 01b28b7 8cdd8e4 01b28b7 bbe49e5 8cdd8e4 bbe49e5 8cdd8e4 01b28b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
#!/usr/bin/env python
from __future__ import annotations

import html

import deepdanbooru as dd
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import tensorflow as tf
def load_model() -> tf.keras.Model:
    """Fetch the DeepDanbooru checkpoint from the Hugging Face Hub and load it.

    Returns:
        The object returned by ``tf.keras.models.load_model`` for the
        ``model-resnet_custom_v3.h5`` file of ``public-data/DeepDanbooru``.
    """
    checkpoint_path = huggingface_hub.hf_hub_download(
        'public-data/DeepDanbooru', 'model-resnet_custom_v3.h5')
    return tf.keras.models.load_model(checkpoint_path)
def load_labels() -> list[str]:
    """Fetch ``tags.txt`` from the Hugging Face Hub and return its lines.

    Every line is stripped of surrounding whitespace. Blank lines are kept as
    empty strings so that list positions keep matching line positions in the
    file (``predict`` looks labels up by model output index).

    Returns:
        One label per line of ``tags.txt``, in file order.
    """
    path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
                                           'tags.txt')
    # Decode explicitly instead of relying on the platform default encoding,
    # which is not UTF-8 everywhere.
    # NOTE(review): assumes tags.txt is UTF-8 (or plain ASCII) - confirm.
    with open(path, encoding='utf-8') as f:
        return [line.strip() for line in f]
# Load the network and its tag list once, at import time; predict() reads
# both as module-level globals.
model = load_model()
labels = load_labels()
def predict(image: PIL.Image.Image, score_threshold: float):
    """Tag an image with the loaded model and render the tags as HTML.

    Args:
        image: Picture to tag. ``None`` (nothing uploaded) yields empty
            results instead of an error from the resize step.
        score_threshold: Minimum probability a tag needs in order to be
            listed.

    Returns:
        A ``(result_html, result_text)`` pair of HTML strings:
        ``result_html`` holds one ``<p class="m5dd_list use">`` row per kept
        tag (label and probability rounded to 3 decimals), ordered by
        descending probability; ``result_text`` is a
        ``<div id="m5dd_result">`` containing the kept labels joined by
        ``', '``. Labels are HTML-escaped in both strings.
    """
    if image is None:
        # Same shape of output as a run in which no tag reached the threshold.
        return '<div></div>', '<div id="m5dd_result"></div>'

    # NOTE(review): presumably (batch, height, width, channels) - confirm
    # against the loaded model.
    _, height, width, _ = model.input_shape
    pixels = np.asarray(image)
    pixels = tf.image.resize(pixels,
                             size=(height, width),
                             method=tf.image.ResizeMethod.AREA,
                             preserve_aspect_ratio=True)
    pixels = pixels.numpy()
    pixels = dd.image.transform_and_pad_image(pixels, width, height)
    pixels = pixels / 255.
    # Add a leading batch axis for predict() and take the single result back.
    probs = model.predict(pixels[None, ...])[0]
    probs = probs.astype(float)

    rows = []
    # Keyed by label so a label that occurs twice is listed once in the text.
    kept = {}
    # Walk tags from most to least probable; the first one below the
    # threshold ends the scan because all remaining ones are lower still.
    for index in np.argsort(probs)[::-1]:
        prob = probs[index]
        if prob < score_threshold:
            break
        label = str(labels[index])
        # Labels are inserted into markup, so characters such as '<', '>'
        # and '&' must be escaped to show up as text.
        rows.append('<p class="m5dd_list use"><span>' + html.escape(label) +
                    '</span><span>' + str(round(prob, 3)) + '</span></p>')
        kept[label] = prob

    result_text = html.escape(', '.join(kept))
    result_text = '<div id="m5dd_result">' + result_text + '</div>'
    result_html = '<div>' + ''.join(rows) + '</div>'
    return result_html, result_text
# JavaScript source handed to demo.load(..., _js=js). It registers one
# document-wide click listener:
#   * a click inside a `.m5dd_list` element toggles that element's `use`
#     class, then rewrites the text of `#m5dd_result` to the first-<span>
#     text of every `.m5dd_list.use` element, joined by ', ';
#   * a click inside `#m5dd_result` selects that element's contents;
#   * any other click is ignored.
js = """
async () => {
document.addEventListener('click', function(event) {
let tagItem = event.target.closest('.m5dd_list')
let resultArea = event.target.closest('#m5dd_result')
if (tagItem){
if (tagItem.classList.contains('use')){
tagItem.classList.remove('use')
}else{
tagItem.classList.add('use')
}
document.getElementById('m5dd_result').innerText =
Array.from(document.querySelectorAll('.m5dd_list.use>span:nth-child(1)'))
.map(v=>v.innerText)
.join(', ')
}else if (resultArea){
const selection = window.getSelection()
selection.removeAllRanges()
const range = document.createRange()
range.selectNodeContents(resultArea)
selection.addRange(range)
}else{
return
}
})
}
"""
# Build the Gradio UI: an image input, a score-threshold slider, a Run button
# and two HTML outputs that receive the pair returned by predict().
# NOTE(review): indentation was lost in this copy of the file, so the exact
# nesting of the `with` blocks below cannot be confirmed from here - TODO
# restore it from the original before running.
with gr.Blocks(css="style.css") as demo:
with gr.Row():
with gr.Column(scale=1):
image = gr.Image(label='Input', type='pil')
# Threshold is selectable from 0 to 1 in steps of 0.05, starting at 0.5.
score_threshold = gr.Slider(label='Score threshold',
minimum=0,
maximum=1,
step=0.05,
value=0.5)
run_button = gr.Button('Run')
# Placeholder markup; replaced by predict()'s second return value.
result_text = gr.HTML(value="<div></div>")
with gr.Column(scale=3):
# Placeholder markup; replaced by predict()'s first return value.
result_html = gr.HTML(value="<div></div>")
# Output order must match predict()'s return order: (result_html, result_text).
run_button.click(
fn=predict,
inputs=[image, score_threshold],
outputs=[result_html, result_text],
api_name='predict',
)
# No Python callback, inputs or outputs: this call only passes the `js`
# snippet along via `_js`.
# NOTE(review): `_js` looks like the Gradio 3.x keyword - confirm against the
# pinned gradio version.
demo.load(None,None,None,_js=js)
demo.queue().launch()
|