"""Gradio web UI for the NSFW audio classifier.

At import time this module reads ``ckpt/config.json``, builds an
``AudioClassifier`` from its ``"model"`` section, and loads the newest
checkpoint from ``ckpt/``. Running it as a script serves a single-input Gradio
interface that classifies an uploaded audio file.
"""

import argparse
import json
import re
from pathlib import Path

import gradio as gr
import torch

from models import AudioClassifier
from utils import logger

ckpt_dir = Path("ckpt/")
config_path = ckpt_dir / "config.json"
# Raise instead of `assert`: assertions are stripped when Python runs with -O.
if not config_path.exists():
    raise FileNotFoundError(f"config.json not found in {ckpt_dir}")
# JSON is UTF-8 by spec; don't depend on the platform's default encoding.
config = json.loads(config_path.read_text(encoding="utf-8"))

device = "cuda" if torch.cuda.is_available() else "cpu"
# The "model" section of config.json is forwarded verbatim as constructor kwargs.
model = AudioClassifier(device=device, **config["model"]).to(device)


def _natural_key(path: Path) -> tuple:
    """Sort key that compares embedded integers numerically (``ep9`` < ``ep10``)."""
    # re.split with a capturing group alternates text / digit-run tokens, so
    # odd indices are always digit runs and each position has a fixed type.
    parts = re.split(r"(\d+)", path.name)
    return tuple(int(tok) if i % 2 else tok for i, tok in enumerate(parts))


def _latest_checkpoint(directory: Path) -> Path:
    """Return ``model_final.pth`` if present, else the highest-numbered ``*.pth``.

    Raises:
        FileNotFoundError: if ``directory`` contains no ``*.pth`` file.
    """
    final = directory / "model_final.pth"
    if final.exists():
        return final
    # Natural sort: plain lexicographic order would rank "model_9" above "model_10".
    candidates = sorted(directory.glob("*.pth"), key=_natural_key)
    if not candidates:
        raise FileNotFoundError(f"No *.pth checkpoint found in {directory}")
    return candidates[-1]


ckpt = _latest_checkpoint(ckpt_dir)
logger.info(f"Loading {ckpt}...")
# map_location lets a checkpoint saved on GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load(ckpt, map_location=device))
# Inference only: switch layers such as dropout / batch-norm (if any) to eval behaviour.
model.eval()


def classify_audio(audio_file: str):
    """Classify one audio file and return the model's prediction.

    Args:
        audio_file: Path to the uploaded file (the Gradio ``Audio`` input is
            configured with ``type="filepath"``). May be ``None`` when the user
            submits without providing audio.

    Returns:
        Whatever ``model.infer_from_file`` returns; Gradio renders it as text.

    Raises:
        gr.Error: if no audio was provided (shown to the user in the UI).
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file.")
    logger.info(f"Classifying {audio_file}...")
    # No gradients are needed for inference; saves memory and time.
    with torch.no_grad():
        output = model.infer_from_file(audio_file)
    logger.success(f"Predicted: {output}")
    return output


# Markdown shown above the UI (in Japanese). It says the output is the
# probability of three classes: "usual" (ordinary speech), "aegi" (moaning),
# and "chupa" (sucking / kissing sounds).
desc = """
# NSFW音声分類器

出力は以下の3つのクラスの確率です。

- usual: 通常の音声
- aegi: 喘ぎ声
- chupa: チュパ音(フェラやキス音声)
"""

iface = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(label="Input audio", type="filepath"),
    outputs=gr.Text(label="Classification"),
    description=desc,
    allow_flagging="never",
)

if __name__ == "__main__":
    # Launch after the interface is fully constructed (not inside a `with`
    # context), and only when run as a script so importing has no server side effect.
    iface.launch()