# vessel-detection/run_local.py
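"""Bootstrap the Vessel Detection Gradio demo for local use.

Creates a virtual environment, installs requirements.txt into it, downloads
the model weights from the Hugging Face Hub, and launches app.py. Typical
invocations (derived from the CLI flags defined in main() below):

    python run_local.py                  # full install, then launch the demo
    python run_local.py --download-only  # fetch the model weights and exit
    python run_local.py --skip-install --host 0.0.0.0 --port 8080
"""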
from __future__ import annotations
import argparse
import os
import subprocess
import urllib.request
import venv
from pathlib import Path

# All paths are anchored to this script's directory, so the bootstrap works
# regardless of the caller's current working directory.
ROOT = Path(__file__).resolve().parent
VENV_DIR = ROOT / ".venv"
MODEL_DIR = ROOT / "models"
MODEL_PATH = MODEL_DIR / "best.pt"
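# Direct file download via the Hub's `resolve` endpoint, pinned to the main branch.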
MODEL_URL = "https://huggingface.co/DefendIntelligence/vessel-detection/resolve/main/models/best.pt"


def _venv_python() -> Path:
    """Return the venv's interpreter (Scripts/ on Windows, bin/ elsewhere)."""
    if os.name == "nt":
        return VENV_DIR / "Scripts" / "python.exe"
    return VENV_DIR / "bin" / "python"


def _run(command: list[str | os.PathLike[str]], env: dict[str, str] | None = None) -> None:
    """Echo a command, then run it from the repo root, raising on failure."""
    printable = " ".join(str(part) for part in command)
    print(f"\n$ {printable}", flush=True)
    subprocess.check_call([str(part) for part in command], cwd=ROOT, env=env)


def _ensure_venv() -> Path:
    """Create .venv with pip on first use and return its python executable."""
    python_path = _venv_python()
    if not python_path.exists():
        print(f"Creating virtual environment: {VENV_DIR}", flush=True)
        venv.EnvBuilder(with_pip=True).create(VENV_DIR)
    return python_path


def _install_dependencies(python_path: Path) -> None:
    """Upgrade pip, then install the project requirements into the venv."""
    _run([python_path, "-m", "pip", "install", "--upgrade", "pip"])
    _run([python_path, "-m", "pip", "install", "-r", "requirements.txt"])


def _download_model() -> None:
    """Download models/best.pt from Hugging Face if it is not already present."""
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    if MODEL_PATH.exists() and MODEL_PATH.stat().st_size > 0:
        print(f"Model already present: {MODEL_PATH}", flush=True)
        return
    # Stream to a temporary file, then rename, so an interrupted download
    # never leaves a truncated best.pt behind.
    tmp_path = MODEL_PATH.with_suffix(".pt.tmp")
    print(f"Downloading model from Hugging Face:\n{MODEL_URL}", flush=True)
    with urllib.request.urlopen(MODEL_URL) as response, tmp_path.open("wb") as handle:
        total = int(response.headers.get("Content-Length") or 0)
        downloaded = 0
        while True:
            chunk = response.read(1024 * 1024)
            if not chunk:
                break
            handle.write(chunk)
            downloaded += len(chunk)
            # Content-Length may be absent; fall back to a running byte count.
            # flush=True is needed because "\r" does not trigger line buffering.
            if total:
                percent = downloaded * 100 / total
                print(
                    f"\r{downloaded / 1_000_000:.1f} MB / {total / 1_000_000:.1f} MB ({percent:.0f}%)",
                    end="",
                    flush=True,
                )
            else:
                print(f"\r{downloaded / 1_000_000:.1f} MB", end="", flush=True)
    print()
    tmp_path.replace(MODEL_PATH)
    print(f"Saved model to: {MODEL_PATH}", flush=True)


def main() -> None:
    parser = argparse.ArgumentParser(description="Install and run the Vessel Detection Gradio demo locally.")
    parser.add_argument("--skip-install", action="store_true", help="Do not install Python dependencies.")
    parser.add_argument("--download-only", action="store_true", help="Download the model and exit.")
    parser.add_argument("--host", default="127.0.0.1", help="Gradio server host.")
    parser.add_argument("--port", default="7860", help="Gradio server port.")
    args = parser.parse_args()

    # The venv is only needed to install dependencies or launch the app;
    # a pure --download-only --skip-install run can skip creating it.
    python_path: Path | None = None
    if not (args.download_only and args.skip_install):
        python_path = _ensure_venv()

    if not args.skip_install:
        if python_path is None:
            python_path = _ensure_venv()
        _install_dependencies(python_path)

    _download_model()
    if args.download_only:
        return

    if python_path is None:
        python_path = _ensure_venv()

    # Gradio reads its bind address and port from these environment variables,
    # so app.py needs no CLI parsing of its own.
    env = os.environ.copy()
    env["GRADIO_SERVER_NAME"] = args.host
    env["GRADIO_SERVER_PORT"] = args.port
    print(f"\nStarting Gradio at http://{args.host}:{args.port}", flush=True)
    _run([python_path, "app.py"], env=env)


if __name__ == "__main__":
    main()