import os
import time
import shutil
from pathlib import Path
from typing import Optional

import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image

from handler import EndpointHandler

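# Asset source and local destination, all overridable via environment variables.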
REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
REVISION = os.environ.get("ASSETS_REVISION")
MODEL_DIR = os.environ.get("MODEL_DIR", "./assets")

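# Accept any of the common environment variable names for a Hugging Face token;
# a token is only needed if the assets repo is gated or private.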
HF_TOKEN = (
    os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGINGFACE_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)

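# Files the handler expects to find in MODEL_DIR, under exactly these names.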
REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]


def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str):
    """
    1) snapshot_download the upstream repo (cached by the HF Hub)
    2) copy the required files into `target_dir` with the exact filenames expected
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    missing = [f for f in REQUIRED_FILES if not (target / f).exists()]
    if not missing:
        return

    snapshot_path = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        allow_patterns=REQUIRED_FILES,
        token=HF_TOKEN,
    )

    for fname in REQUIRED_FILES:
        src = Path(snapshot_path) / fname
        dst = target / fname
        if not src.exists():
            raise FileNotFoundError(
                f"Expected '{fname}' not found in snapshot for {repo_id} @ {revision or 'default'}"
            )
        shutil.copyfile(src, dst)


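# Download the assets eagerly at startup so the first request does not pay the cost.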
ensure_assets(REPO_ID, REVISION, MODEL_DIR)

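# The handler loads the model from MODEL_DIR; `handler.device` is expected to be
# a string such as "cuda" or "cpu" (an assumption based on how it is used below).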
handler = EndpointHandler(MODEL_DIR)
DEVICE_LABEL = f"Device: {handler.device.upper()}"


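# Translate the UI inputs into the handler's request payload, run inference,
# and unpack the tag lists plus timing metadata for display.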
def run_inference(
    source_choice: str,
    image: Optional[Image.Image],
    url: str,
    general_threshold: float,
    character_threshold: float,
):
    if source_choice == "Upload image":
        if image is None:
            raise gr.Error("Please upload an image.")
        inputs = image
    else:
        if not url or not url.strip():
            raise gr.Error("Please provide an image URL.")
        inputs = {"url": url.strip()}

    data = {
        "inputs": inputs,
        "parameters": {
            "general_threshold": float(general_threshold),
            "character_threshold": float(character_threshold),
        },
    }

    started = time.time()
    try:
        out = handler(data)
    except Exception as e:
        raise gr.Error(f"Inference error: {e}") from e
    latency = round(time.time() - started, 4)

    features = ", ".join(sorted(out.get("feature", []))) or "-"
    characters = ", ".join(sorted(out.get("character", []))) or "-"
    ips = ", ".join(out.get("ip", [])) or "-"

    meta = {
        "device": handler.device,
        "latency_s_total": latency,
        **out.get("_timings", {}),
    }

    return features, characters, ips, meta, out


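# Gradio UI: choose an upload or a URL, tune thresholds, and inspect the predicted tags.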
with gr.Blocks(title="PixAI Tagger v0.9 - Demo", fill_height=True) as demo:
    gr.Markdown(
        """
# PixAI Tagger v0.9 - Gradio Demo
Downloads model assets from **pixai-labs/pixai-tagger-v0.9** on first run,
then uses the imported `EndpointHandler` to predict **general**, **character**, and **IP** tags.

**Expected local filenames** (kept unchanged):
- `model_v0.9.pth`
- `tags_v0.9_13k.json`
- `char_ip_map.json`

Configure via env vars:
- `ASSETS_REPO_ID` (default: `pixai-labs/pixai-tagger-v0.9`)
- `ASSETS_REVISION` (optional)
- `MODEL_DIR` (default: `./assets`)
"""
    )
    with gr.Row():
        gr.Markdown(f"**{DEVICE_LABEL}**")

    with gr.Row():
        source_choice = gr.Radio(
            choices=["Upload image", "From URL"],
            value="Upload image",
            label="Image source",
        )

    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            image = gr.Image(label="Upload image", type="pil", visible=True)
            url = gr.Textbox(label="Image URL", placeholder="https://…", visible=False)

            def toggle_inputs(choice):
                # Show exactly one of the two input widgets, depending on the selected source.
                return (
                    gr.update(visible=(choice == "Upload image")),
                    gr.update(visible=(choice == "From URL")),
                )

            source_choice.change(toggle_inputs, [source_choice], [image, url])

        with gr.Column(scale=1):
            general_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.30, label="General threshold"
            )
            character_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.85, label="Character threshold"
            )
            run_btn = gr.Button("Run", variant="primary")
            clear_btn = gr.Button("Clear")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Predicted Tags")
            features_out = gr.Textbox(label="General tags", lines=4)
            characters_out = gr.Textbox(label="Character tags", lines=4)
            ip_out = gr.Textbox(label="IP tags", lines=2)

        with gr.Column():
            gr.Markdown("### Metadata & Raw Output")
            meta_out = gr.JSON(label="Timings/Device")
            raw_out = gr.JSON(label="Raw JSON")

    examples = gr.Examples(
        label="Examples (URL mode)",
        examples=[
            [
                "From URL",
                None,
                "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg",
                0.30,
                0.85,
            ],
        ],
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        cache_examples=False,
    )

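    # Reset inputs to their defaults and clear all outputs.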
    def clear():
        return (None, "", 0.30, 0.85, "", "", "", {}, {})

    run_btn.click(
        run_inference,
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        outputs=[features_out, characters_out, ip_out, meta_out, raw_out],
        api_name="predict",
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[
            image, url, general_threshold, character_threshold,
            features_out, characters_out, ip_out, meta_out, raw_out,
        ],
    )


if __name__ == "__main__":
    demo.queue(max_size=8).launch()