Spaces:
Running
Running
File size: 3,138 Bytes
c570d14 aa3f835 c570d14 4f1095a c570d14 4f1095a aa3f835 c570d14 aa3f835 4f1095a aa3f835 4f1095a aa3f835 4f1095a aa3f835 4f1095a aa3f835 4f1095a aa3f835 4f1095a aa3f835 c570d14 4f1095a c570d14 7b24b7f c570d14 77bce91 4f1095a 77bce91 4f1095a 77bce91 4f1095a 7ad735d 4f1095a 7ad735d 4f1095a 7ad735d 4f1095a 7ad735d 4f1095a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
#!/usr/bin/env python
from __future__ import annotations
import functools
import json
import os
import pathlib
import tarfile
from typing import Callable
import gradio as gr
import huggingface_hub
import PIL.Image
import torch
import torchvision.transforms as T
# Markdown header rendered at the top of the demo page, linking the upstream project.
DESCRIPTION: str = (
    "# [RF5/danbooru-pretrained]"
    "(https://github.com/RF5/danbooru-pretrained)"
)
# Hugging Face Hub repository the ResNet-50 checkpoint and the class-name list
# are downloaded from.
MODEL_REPO: str = "public-data/danbooru-pretrained"
def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the sample images used as Gradio examples, sorted by path.

    If ``./images`` does not exist yet, ``images.tar.gz`` is downloaded from the
    ``hysts/sample-images-TADNE`` dataset repo and extracted into the current
    working directory (presumably the archive's top-level entry is ``images/``
    -- confirm if the dataset changes). Later calls reuse the directory.

    Returns:
        Regular files directly inside ``./images``, sorted.
    """
    image_dir = pathlib.Path("images")
    if not image_dir.exists():
        dataset_repo = "hysts/sample-images-TADNE"
        path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset")
        with tarfile.open(path) as tar:
            # The archive comes from the network: when the interpreter supports
            # extraction filters (3.12+, backported to recent 3.8-3.11 patch
            # releases), reject members with absolute paths, "..", links out of
            # the tree or special files instead of writing them anywhere.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(filter="data")
            else:
                tar.extractall()
    # Only files make sense as example images; skip sub-directories and the like.
    return sorted(entry for entry in image_dir.glob("*") if entry.is_file())
def load_model(device: torch.device) -> torch.nn.Module:
    """Build the RF5 danbooru ResNet-50 and load the weights mirrored on the Hub.

    Args:
        device: Device the model is moved to after its weights are loaded.

    Returns:
        The model in eval mode, located on ``device``.
    """
    path = huggingface_hub.hf_hub_download(MODEL_REPO, "resnet50-13306192.pth")
    # Deserialize onto the CPU first: without map_location, tensors that were
    # saved from a CUDA device cannot be loaded on a CPU-only machine.
    # load_state_dict copies them into the model, which is then moved to `device`.
    # NOTE(review): torch.load unpickles the file, so this trusts MODEL_REPO;
    # consider weights_only=True once the minimum supported torch allows it.
    state_dict = torch.load(path, map_location="cpu")
    # Architecture only (pretrained=False); the weights come from the checkpoint above.
    model = torch.hub.load("RF5/danbooru-pretrained", "resnet50", pretrained=False)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def load_labels() -> list[str]:
    """Download and parse the model's class-name list from ``MODEL_REPO``.

    Returns:
        The decoded contents of ``class_names_6000.json``. ``predict`` pairs
        these names with the model's outputs by position.
    """
    path = huggingface_hub.hf_hub_download(MODEL_REPO, "class_names_6000.json")
    # JSON text is UTF-8; don't depend on the platform's locale default encoding.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
@torch.inference_mode()
def predict(
    image: PIL.Image.Image,
    score_threshold: float,
    transform: Callable,
    device: torch.device,
    model: torch.nn.Module,
    labels: list[str],
) -> dict[str, float]:
    """Tag one image with every label whose sigmoid probability passes a threshold.

    Args:
        image: Input image, or None when the UI's image input is empty.
        score_threshold: Minimum probability (inclusive) for a label to be kept.
        transform: Turns ``image`` into the tensor for a single sample; a batch
            dimension is added here.
        device: Device the input tensor is moved to; should match ``model``.
        model: Network whose first output row holds one logit per label.
        labels: Label names, matched to the logits by position.

    Returns:
        Mapping ``label -> probability`` for the labels at or above
        ``score_threshold``; empty when ``image`` is None.
    """
    # An empty Gradio image input arrives as None; report "no tags" instead of
    # crashing inside the transform.
    if image is None:
        return {}
    batch = transform(image).to(device).unsqueeze(0)
    # Sigmoid, not softmax: each label is scored independently (multi-label).
    probs = torch.sigmoid(model(batch)[0]).cpu().tolist()
    # zip stops at the shorter sequence, so surplus logits or labels are ignored.
    return {label: prob for prob, label in zip(probs, labels) if prob >= score_threshold}
# Example gallery rows for the UI: [image path, score threshold 0.4].
image_paths = load_sample_image_paths()
examples = [[sample.as_posix(), 0.4] for sample in image_paths]

# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = load_model(device)
labels = load_labels()

# Per-image preprocessing: resize (an int size resizes the shorter edge),
# convert to a tensor, then normalize each channel with fixed mean/std values.
transform = T.Compose([
    T.Resize(360),
    T.ToTensor(),
    T.Normalize(
        mean=[0.7137, 0.6628, 0.6519],
        std=[0.2970, 0.3017, 0.2979],
    ),
])

# Pre-bind everything except the two UI inputs (image, score threshold), so
# the UI can call the predictor with just those positional arguments.
fn = functools.partial(
    predict,
    transform=transform,
    device=device,
    model=model,
    labels=labels,
)
# Gradio UI: a row with two columns -- image input, threshold slider and run
# button in the first; the predicted label/probability output in the second.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input", type="pil")
            threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.05, value=0.4)
            run_button = gr.Button()
        with gr.Column():
            result = gr.Label(label="Output")

    # Order matches the positional parameters fn still takes: (image, score_threshold).
    inputs = [image, threshold]
    # Example rows come from the module-level `examples`; example outputs are
    # cached only when the CACHE_EXAMPLES environment variable is exactly "1".
    gr.Examples(
        examples=examples,
        inputs=inputs,
        outputs=result,
        fn=fn,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )
    # The same handler is also exposed through the app's API as "predict".
    run_button.click(
        fn=fn,
        inputs=inputs,
        outputs=result,
        api_name="predict",
    )
if __name__ == "__main__":
    # Turn on request queueing with room for at most 15 pending events, then
    # start serving the app. queue() configures `demo` in place.
    demo.queue(max_size=15)
    demo.launch()
|