import hashlib
import os
import sys
from contextlib import redirect_stdout
from pathlib import Path
from typing import Type

import onnxruntime as ort
import pooch

from .session_base import BaseSession
from .session_cloth import ClothSession
from .session_simple import SimpleSession

# Every released model lives at "<_RELEASE_URL>/<model_name>.onnx".
_RELEASE_URL = "https://github.com/danielgatis/rembg/releases/download/v0.0.0"

# Model whose weights are used when the requested name is not in _MODELS.
_DEFAULT_MODEL = "u2net"

# model name -> (MD5 of the released .onnx file, session class wrapping it)
_MODELS = {
    "u2net": ("60024c5c889badc19c04ad937298a77b", SimpleSession),
    "u2netp": ("8e83ca70e441ab06c318d82300c84806", SimpleSession),
    "u2net_human_seg": ("c09ddc2e0104f800e3e1bb4652583d1f", SimpleSession),
    "u2net_cloth_seg": ("2434d1f3cb744e0e49386c906e5a08bb", ClothSession),
    "silueta": ("55e59e0d8062d2f5d013f4725ee84782", SimpleSession),
}


def new_session(model_name: str = "u2net") -> BaseSession:
    """Fetch the ONNX weights for *model_name* and wrap them in a session.

    The model file is retrieved with ``pooch.retrieve`` (checked against a
    known MD5) into the directory named by ``U2NET_HOME``; when that variable
    is unset the directory is ``$XDG_DATA_HOME/.u2net``, or ``~/.u2net`` if
    ``XDG_DATA_HOME`` is unset as well.

    Args:
        model_name: One of ``"u2net"``, ``"u2netp"``, ``"u2net_human_seg"``,
            ``"u2net_cloth_seg"`` or ``"silueta"``. Any other value falls
            back to the ``u2net`` weights and ``SimpleSession``, but the file
            is still stored as ``<model_name>.onnx`` and the returned session
            is still constructed with the name that was passed in.

    Returns:
        A ``ClothSession`` for ``"u2net_cloth_seg"``, otherwise a
        ``SimpleSession``, built around an ``onnxruntime.InferenceSession``
        that uses every provider reported by
        ``ort.get_available_providers()``.

    Raises:
        ValueError: If ``OMP_NUM_THREADS`` is set but is not an integer.
    """
    session_class: Type[BaseSession]

    # Unknown names deliberately reuse the default model's weights.
    resolved = model_name if model_name in _MODELS else _DEFAULT_MODEL
    md5, session_class = _MODELS[resolved]
    url = f"{_RELEASE_URL}/{resolved}.onnx"

    u2net_home = os.getenv(
        "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
    )
    # Expand "~" once and reuse the result for both download and load.
    model_dir = Path(u2net_home).expanduser()

    # The local file is named after the *requested* model, not the resolved one.
    fname = f"{model_name}.onnx"
    full_path = model_dir / fname

    pooch.retrieve(
        url,
        f"md5:{md5}",
        fname=fname,
        path=model_dir,
        progressbar=True,
    )

    sess_opts = ort.SessionOptions()
    if "OMP_NUM_THREADS" in os.environ:
        # Only the inter-op thread count is tied to OMP_NUM_THREADS here.
        sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])

    return session_class(
        model_name,
        ort.InferenceSession(
            str(full_path),
            providers=ort.get_available_providers(),
            sess_options=sess_opts,
        ),
    )