Spaces:
Running
Running
#!/usr/bin/env python | |
from __future__ import annotations | |
import functools | |
import json | |
import os | |
import pathlib | |
import tarfile | |
from typing import Callable | |
import gradio as gr | |
import huggingface_hub | |
import PIL.Image | |
import torch | |
import torchvision.transforms as T | |
# Markdown heading rendered at the top of the demo page.
DESCRIPTION = ('# [RF5/danbooru-pretrained]'
               '(https://github.com/RF5/danbooru-pretrained)')
# Hugging Face Hub model repo that the checkpoint and label list are
# downloaded from.
MODEL_REPO = 'public-data/danbooru-pretrained'
def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the sample images, downloading them on first use.

    If ``./images`` does not exist yet, ``images.tar.gz`` is downloaded from
    the ``hysts/sample-images-TADNE`` dataset repo and extracted into the
    current working directory (presumably the archive creates ``images/``).

    Returns:
        Every entry directly inside ``./images``, sorted by path.
    """
    image_dir = pathlib.Path('images')
    if not image_dir.exists():
        dataset_repo = 'hysts/sample-images-TADNE'
        path = huggingface_hub.hf_hub_download(dataset_repo,
                                               'images.tar.gz',
                                               repo_type='dataset')
        with tarfile.open(path) as f:
            # The 'data' filter refuses members with absolute paths, '..'
            # components, links escaping the destination, device files, etc.
            # Pythons without extraction-filter support fall back to the
            # old unfiltered behaviour.
            if hasattr(tarfile, 'data_filter'):
                f.extractall(filter='data')
            else:
                f.extractall()
    return sorted(image_dir.glob('*'))
def load_model(device: torch.device) -> torch.nn.Module:
    """Build the RF5 ResNet-50 tagger with downloaded weights, ready for inference.

    The architecture comes from the ``RF5/danbooru-pretrained`` torch.hub
    entry point, built with ``pretrained=False``. The weights come from the
    ``resnet50-13306192.pth`` state dict in ``MODEL_REPO``.

    Args:
        device: Device the model is moved to.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    path = huggingface_hub.hf_hub_download(MODEL_REPO, 'resnet50-13306192.pth')
    # Deserialize onto the CPU first. If the checkpoint holds CUDA tensors, a
    # plain torch.load would fail on a CPU-only host. model.to(device) below
    # moves the weights to their final device.
    # NOTE(review): torch.load unpickles the file, so only use a checkpoint
    # source you trust (consider weights_only=True on torch versions that
    # support it).
    state_dict = torch.load(path, map_location='cpu')
    # NOTE: torch.hub.load fetches and runs hubconf.py from the GitHub repo.
    model = torch.hub.load('RF5/danbooru-pretrained',
                           'resnet50',
                           pretrained=False)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def load_labels() -> list[str]:
    """Download and return the model's class names.

    Returns:
        The JSON list stored in ``class_names_6000.json`` in ``MODEL_REPO``.
        Presumably it holds one name per model output, in output order
        (``predict`` pairs them positionally); confirm against upstream.
    """
    path = huggingface_hub.hf_hub_download(MODEL_REPO, 'class_names_6000.json')
    # JSON text is UTF-8 (RFC 8259). Don't depend on the locale's default
    # encoding, which differs across platforms.
    with open(path, encoding='utf-8') as f:
        return json.load(f)
def predict(image: PIL.Image.Image, score_threshold: float,
            transform: Callable, device: torch.device, model: torch.nn.Module,
            labels: list[str]) -> dict[str, float]:
    """Score every label for ``image`` and keep those at or above a threshold.

    Args:
        image: Input image, passed unchanged to ``transform``.
        score_threshold: Minimum sigmoid score a label needs to be returned.
        transform: Turns ``image`` into a single-sample tensor without a batch
            dimension. A batch dimension of size 1 is added here.
        device: Device the input tensor is moved to before inference.
        model: Network whose first output row is read as one logit per label.
        labels: Label names, matched positionally to the model outputs.

    Returns:
        Mapping from label name to its sigmoid score (a Python float) for
        every label scoring ``>= score_threshold``. NaN scores never pass.
    """
    # Without this, the output tracks gradients whenever the model's
    # parameters require grad (the PyTorch default). Converting it to numpy
    # would then raise, and an autograd graph would be built per request.
    with torch.inference_mode():
        batch = transform(image).to(device).unsqueeze(0)
        logits = model(batch)[0]
        scores = torch.sigmoid(logits).cpu().tolist()
    # zip stops at the shorter sequence, so surplus labels/outputs are ignored.
    return {
        label: score
        for score, label in zip(scores, labels)
        if score >= score_threshold
    }
# Example rows for the UI, each one (image path, default score threshold).
image_paths = load_sample_image_paths()
examples = [[sample.as_posix(), 0.4] for sample in image_paths]

# Use the first GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model = load_model(device)
labels = load_labels()

# Preprocessing applied to each input image before inference.
transform = T.Compose([
    T.Resize(360),  # int size: the shorter edge is scaled to 360 px
    T.ToTensor(),
    # NOTE(review): presumably the per-channel statistics published upstream
    # for this model; confirm before changing.
    T.Normalize(mean=[0.7137, 0.6628, 0.6519],
                std=[0.2970, 0.3017, 0.2979]),
])

# Bind everything except the two UI-driven arguments (image, threshold).
fn = functools.partial(
    predict,
    model=model,
    labels=labels,
    transform=transform,
    device=device,
)
# Page layout: an input column (image, threshold slider, run button) beside
# an output column that shows the label scores.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', label='Input')
            threshold = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.4,
                                  label='Score Threshold')
            run_button = gr.Button('Run')
        with gr.Column():
            result = gr.Label(label='Output')
    # The examples table and the button feed the same two inputs to fn.
    inputs = [image, threshold]
    gr.Examples(
        examples=examples,
        inputs=inputs,
        outputs=result,
        fn=fn,
        # Example outputs are cached only when CACHE_EXAMPLES=1 is set.
        cache_examples=os.environ.get('CACHE_EXAMPLES') == '1',
    )
    run_button.click(api_name='predict', fn=fn, inputs=inputs, outputs=result)
demo.queue(max_size=15)
demo.launch()