fifa-tryon-demo / rembg /session_factory.py
hasibzunair's picture
added files
4a285f6
import hashlib
import os
import sys
from contextlib import redirect_stdout
from pathlib import Path
from typing import Type
import gdown
import onnxruntime as ort
from .session_base import BaseSession
from .session_cloth import ClothSession
from .session_simple import SimpleSession
def new_session(model_name: str) -> BaseSession:
session_class: Type[BaseSession]
if model_name == "u2netp":
md5 = "8e83ca70e441ab06c318d82300c84806"
url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
session_class = SimpleSession
elif model_name == "u2net":
md5 = "60024c5c889badc19c04ad937298a77b"
url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
session_class = SimpleSession
elif model_name == "u2net_human_seg":
md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
session_class = SimpleSession
elif model_name == "u2net_cloth_seg":
md5 = "2434d1f3cb744e0e49386c906e5a08bb"
url = "https://drive.google.com/uc?id=15rKbQSXQzrKCQurUjZFg8HqzZad8bcyz"
session_class = ClothSession
else:
assert AssertionError(
"Choose between u2net, u2netp, u2net_human_seg or u2net_cloth_seg"
)
home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
path = Path(home).expanduser() / f"{model_name}.onnx"
path.parents[0].mkdir(parents=True, exist_ok=True)
if not path.exists():
with redirect_stdout(sys.stderr):
gdown.download(url, str(path), use_cookies=False)
else:
hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
if hashing.hexdigest() != md5:
with redirect_stdout(sys.stderr):
gdown.download(url, str(path), use_cookies=False)
sess_opts = ort.SessionOptions()
if "OMP_NUM_THREADS" in os.environ:
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
return session_class(
model_name,
ort.InferenceSession(
str(path), providers=ort.get_available_providers(), sess_options=sess_opts
),
)