import hashlib
import os
import sys
from contextlib import redirect_stdout
from pathlib import Path
from typing import Dict, Tuple, Type

import gdown
import onnxruntime as ort

from .session_base import BaseSession
from .session_cloth import ClothSession
from .session_simple import SimpleSession

# Known models: name -> (expected MD5 of the .onnx file, download URL, session class).
_MODELS: Dict[str, Tuple[str, str, Type[BaseSession]]] = {
    "u2netp": (
        "8e83ca70e441ab06c318d82300c84806",
        "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR",
        SimpleSession,
    ),
    "u2net": (
        "60024c5c889badc19c04ad937298a77b",
        "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab",
        SimpleSession,
    ),
    "u2net_human_seg": (
        "c09ddc2e0104f800e3e1bb4652583d1f",
        "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j",
        SimpleSession,
    ),
    "u2net_cloth_seg": (
        "2434d1f3cb744e0e49386c906e5a08bb",
        "https://drive.google.com/uc?id=15rKbQSXQzrKCQurUjZFg8HqzZad8bcyz",
        ClothSession,
    ),
}

# Read size used when hashing, so large model files are never loaded into memory whole.
_HASH_CHUNK_SIZE = 1 << 20


def _md5_hexdigest(path: Path) -> str:
    """Return the hexadecimal MD5 digest of the file at *path*, read in chunks."""
    # usedforsecurity=False: MD5 is only an integrity check here, not a security control.
    digest = hashlib.new("md5", usedforsecurity=False)
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(_HASH_CHUNK_SIZE), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download(url: str, path: Path) -> None:
    """Download *url* to *path* with gdown, sending gdown's stdout output to stderr."""
    # NOTE(review): presumably redirected so stdout stays free for program output
    # (e.g. data piped by a CLI caller) — confirm against callers.
    with redirect_stdout(sys.stderr):
        gdown.download(url, str(path), use_cookies=False)


def new_session(model_name: str) -> BaseSession:
    """Create an inference session for one of the bundled U^2-Net models.

    The model file is cached as ``<U2NET_HOME>/<model_name>.onnx``. U2NET_HOME
    defaults to ``~/.u2net``. The file is (re)downloaded when it is missing or
    its MD5 digest does not match the expected value.

    If the ``OMP_NUM_THREADS`` environment variable is set, its integer value
    is used for both the inter-op and intra-op thread counts of the ONNX
    Runtime session.

    Args:
        model_name: One of ``"u2net"``, ``"u2netp"``, ``"u2net_human_seg"``
            or ``"u2net_cloth_seg"``.

    Returns:
        An instance of the model's session class, wrapping an
        ``onnxruntime.InferenceSession`` that uses every available execution
        provider.

    Raises:
        ValueError: If ``model_name`` is not a known model.
    """
    # The original used `assert AssertionError(...)`, which never fires; an
    # unknown name then crashed later with UnboundLocalError. Fail loudly instead.
    try:
        md5, url, session_class = _MODELS[model_name]
    except KeyError:
        raise ValueError(
            "Choose between u2net, u2netp, u2net_human_seg or u2net_cloth_seg"
        ) from None

    home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
    path = Path(home).expanduser() / f"{model_name}.onnx"
    path.parent.mkdir(parents=True, exist_ok=True)

    # Fetch the model when absent, or when the cached copy fails the checksum
    # (e.g. an interrupted earlier download).
    if not path.exists() or _md5_hexdigest(path) != md5:
        _download(url, path)

    sess_opts = ort.SessionOptions()
    if "OMP_NUM_THREADS" in os.environ:
        threads = int(os.environ["OMP_NUM_THREADS"])
        # Keep the pre-existing inter-op setting. Also apply the value to
        # intra-op parallelism, which is the within-operator threading that
        # OMP_NUM_THREADS conventionally controls.
        sess_opts.inter_op_num_threads = threads
        sess_opts.intra_op_num_threads = threads

    return session_class(
        model_name,
        ort.InferenceSession(
            str(path), providers=ort.get_available_providers(), sess_options=sess_opts
        ),
    )