|
import logging |
|
from typing import Tuple, Union |
|
|
|
import gdown |
|
import numpy as np |
|
import numpy.typing as npt |
|
import torch |
|
from torch import device as torch_device |
|
from torch.nn import Module |
|
|
|
from .module.yolov5 import YOLO_DIR |
|
|
|
|
|
# Identifiers of the pretrained weight files.
# NOTE(review): presumably Google Drive file IDs meant to be passed to
# download_weight() as `file_id` -- confirm against the callers.
DETECTOR_WEIGHT_ID = "1YHH7pLoZEdyxw2AoLz9G4lrq6uuxweYB"
REMOVER_WEIGHT_ID = "1Hd79M8DhCwjFuT198R-QB7ozQbHRGcGM"
|
|
|
|
|
def select_device(device: str = "") -> torch_device:
    """Return the ``torch.device`` matching the requested device string.

    Args:
        device: Requested device. ``"cpu"`` (case-insensitive) forces the
            CPU. ``"cuda:N"`` or a bare ordinal ``"N"`` selects GPU ``N``.
            Any other value, including the default empty string, selects
            the first GPU (``cuda:0``).

    Returns:
        The CPU device when it is requested explicitly or when CUDA is not
        available; otherwise the selected CUDA device.
    """
    requested = device.strip().lower()
    if requested == "cpu" or not torch.cuda.is_available():
        return torch_device("cpu")

    # Accept both "cuda:N" and the bare ordinal "N".
    ordinal = requested[len("cuda:"):] if requested.startswith("cuda:") else requested
    # Default to the first GPU rather than a hard-coded second one, which
    # does not exist on single-GPU machines.
    index = int(ordinal) if ordinal.isdecimal() else 0
    return torch_device(f"cuda:{index}")
|
|
|
|
|
def load_yolo_model(weight_path: str, device: str) -> Tuple[Module, int]:
    """Load a YOLOv5 model from a weight file through ``torch.hub``.

    The model is built by the ``"custom"`` entry point of the local hub
    repository located at ``YOLO_DIR``.

    Args:
        weight_path: Path of the weight file, forwarded as ``path``.
        device: Device specifier, forwarded as ``device`` to the hub
            entry point.

    Returns:
        A ``(model, stride)`` tuple, where ``stride`` is read from
        ``model.stride``.
        NOTE(review): the annotation says ``int``; the actual type depends
        on the object returned by the hub entry point -- confirm.
    """
    model = torch.hub.load(
        str(YOLO_DIR),
        "custom",
        path=weight_path,
        source="local",
        force_reload=True,
        device=device,
    )
    # Report through logging instead of a bare print so library users
    # control whether this diagnostic is shown.
    logging.getLogger(__name__).debug("Loaded YOLOv5 weights from %s", weight_path)
    return model, model.stride
|
|
|
|
|
def download_weight(file_id: str, output: Union[str, None] = None, quiet: bool = False) -> None:
    """Fetch a model weight file from Google Drive by its file ID.

    The download is best-effort: any exception raised by
    ``gdown.cached_download`` is printed together with a hint for
    downloading the file manually, and the function then returns normally
    instead of re-raising.

    Args:
        file_id: Google Drive file ID used to build the download URL.
        output: Forwarded to ``gdown.cached_download`` as ``path``.
        quiet: Forwarded to ``gdown.cached_download`` as ``quiet``.
    """
    download_url = f"https://drive.google.com/uc?id={file_id}"
    try:
        gdown.cached_download(url=download_url, path=output, quiet=quiet)
    except Exception as error:
        manual_hint = (
            "Check your internet connection or manually download the weight "
            f"at https://drive.google.com/file/d/{file_id}/view?usp=sharing"
        )
        print(error)
        print("Something went wrong when downloading the weight")
        print(manual_hint)
|
|
|
|
|
def check_image_shape(image: npt.NDArray) -> None:
    """Validate that ``image`` is a 3-D NumPy array whose last axis has size 3.

    Args:
        image: Object to validate.

    Raises:
        TypeError: If ``image`` is not an ``np.ndarray``.
        ValueError: If ``image`` does not have exactly three dimensions, or
            if its last axis does not have size 3.
    """
    if not isinstance(image, np.ndarray):
        raise TypeError(
            f"Invalid type: image must be of type np.ndarray, got {type(image).__name__}"
        )
    if image.ndim != 3:
        raise ValueError(
            f"Invalid image shape: image must be 3 dimensional, got {image.ndim} dimension(s)"
        )
    # The last axis is treated as the channel axis and must hold 3 values.
    if image.shape[-1] != 3:
        raise ValueError(
            f"Invalid image shape: last axis must have 3 channels, got {image.shape[-1]}"
        )
|
|