Spaces:
Runtime error
Runtime error
import io | |
import os | |
import sys | |
from typing import List, Optional | |
from urllib.parse import urlparse | |
import cv2 | |
import numpy as np | |
import torch | |
from PIL import Image, ImageOps | |
from loguru import logger | |
from torch.hub import download_url_to_file, get_dir | |
def get_cache_path_by_url(url):
    """Return the path where ``url`` is cached inside torch hub's ``checkpoints`` dir.

    The file name is the last component of the URL *path*, so any query string
    or fragment is ignored. The ``checkpoints`` directory is created if it is
    missing; the cached file itself is neither checked nor downloaded here.

    Args:
        url: URL of the remote checkpoint.

    Returns:
        Absolute-or-relative path (as returned by ``torch.hub.get_dir``) of the
        cache file for ``url``.
    """
    model_dir = os.path.join(get_dir(), "checkpoints")
    # exist_ok avoids a race where another process creates the directory
    # between an existence check and makedirs (which would raise).
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(model_dir, filename)
def download_model(url):
    """Ensure the checkpoint at ``url`` is cached locally and return its path.

    The cache location comes from ``get_cache_path_by_url``. If a file already
    exists there it is reused as-is; otherwise it is downloaded with a progress
    bar and no hash verification.
    """
    target = get_cache_path_by_url(url)
    if os.path.exists(target):
        return target
    sys.stderr.write('Downloading: "{}" to {}\n'.format(url, target))
    # Third positional argument is the expected hash prefix; None skips the check.
    download_url_to_file(url, target, None, progress=True)
    return target
def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod`` (unchanged if already one)."""
    remainder = x % mod
    return x if remainder == 0 else (x // mod + 1) * mod
def load_jit_model(url_or_path, device):
    """Load a TorchScript model from a local file or URL and switch it to eval mode.

    Args:
        url_or_path: path to a TorchScript archive on disk; if no file exists at
            that path it is treated as a URL and fetched (and cached) via
            ``download_model``.
        device: target passed to ``Module.to``.

    Returns:
        The loaded scripted module, in evaluation mode.

    Exits the process with status -1 if the archive cannot be loaded.
    """
    # Resolve the argument: an existing local file wins, otherwise download.
    # (Loading ``os.getcwd()`` instead would always fail: it is a directory.)
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)
    logger.info(f"Load model from: {model_path}")
    try:
        model = torch.jit.load(model_path).to(device)
    except Exception:
        # logger.exception also records the traceback, which the bare
        # ``except:`` previously discarded; KeyboardInterrupt now propagates.
        logger.exception(
            f"Failed to load {model_path}, delete model and restart lama-cleaner"
        )
        sys.exit(-1)
    model.eval()
    return model
def load_model(model: torch.nn.Module, url_or_path, device):
    """Load a state dict into ``model``, move it to ``device`` and set eval mode.

    Args:
        model: module whose parameters are replaced (``strict=True``, so the
            checkpoint keys must match exactly).
        url_or_path: path to a checkpoint on disk; if no file exists at that
            path it is treated as a URL and fetched via ``download_model``.
        device: target passed to ``Module.to``.

    Returns:
        The same ``model`` instance, loaded and in evaluation mode.

    Exits the process with status -1 if loading fails.
    """
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path)
    try:
        # NOTE(review): torch.load unpickles the file, which can execute code;
        # only load checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit are
        # not swallowed; logger.exception keeps the traceback.
        logger.exception(
            f"Failed to load {model_path}, delete model and restart lama-cleaner"
        )
        sys.exit(-1)
    logger.info(f"Load model from: {model_path}")
    model.eval()
    return model
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """Encode an image array as an in-memory image file.

    Args:
        image_numpy: image array passed unchanged to ``cv2.imencode``
            (channel order is not converted here).
        ext: file extension without the dot, e.g. ``"png"`` or ``"jpg"``;
            selects the encoder.

    Returns:
        The encoded file contents.

    Raises:
        ValueError: if OpenCV reports that encoding failed.
    """
    ok, buffer = cv2.imencode(
        f".{ext}",
        image_numpy,
        # JPEG quality 100 and PNG compression level 0.
        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
    )
    # The success flag was previously discarded, so a failed encode could
    # silently yield empty/invalid bytes.
    if not ok:
        raise ValueError(f"Failed to encode image as .{ext}")
    return buffer.tobytes()
def load_img(img_bytes, gray: bool = False):
    """Decode encoded image bytes into a numpy array.

    Args:
        img_bytes: raw bytes of an image file readable by PIL.
        gray: if True, convert to PIL ``'L'`` mode (single channel).

    Returns:
        ``(np_img, alpha_channel)`` where ``np_img`` is an (H, W) array when
        ``gray`` is set and an RGB (H, W, 3) array otherwise, and
        ``alpha_channel`` is the (H, W) alpha plane when the input is RGBA
        (and ``gray`` is False), else ``None``.
    """
    alpha_channel = None
    image = Image.open(io.BytesIO(img_bytes))
    try:
        # Apply the EXIF orientation so pixels are upright. This is
        # best-effort: images with unusable EXIF data are used as decoded.
        # ``Exception`` (not a bare except) lets KeyboardInterrupt propagate.
        image = ImageOps.exif_transpose(image)
    except Exception:
        pass

    if gray:
        image = image.convert('L')
        np_img = np.array(image)
    elif image.mode == 'RGBA':
        np_img = np.array(image)
        # Keep the alpha plane separately so callers can restore it later.
        alpha_channel = np_img[:, :, -1]
        np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
    else:
        image = image.convert('RGB')
        np_img = np.array(image)
    return np_img, alpha_channel
def norm_img(np_img):
    """Convert an (H, W, C) or (H, W) image to a (C, H, W) float32 array scaled by 1/255.

    A 2-D input is treated as single-channel and gains a channel axis first.
    """
    if np_img.ndim == 2:
        np_img = np_img[..., None]
    chw = np_img.transpose(2, 0, 1)
    return chw.astype(np.float32) / 255
def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    """Downscale ``np_img`` so its longer side is ``size_limit``, keeping aspect ratio.

    Images whose longer side is already ``<= size_limit`` are returned
    unchanged (same object); images are never upscaled.

    Args:
        np_img: image array; its first two dimensions are height and width.
        size_limit: maximum allowed length of the longer side, in pixels.
        interpolation: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The original array or a resized copy.
    """
    h, w = np_img.shape[:2]
    longest = max(h, w)
    if longest <= size_limit:
        return np_img
    ratio = size_limit / longest
    # Round to nearest, but never below 1 px: a very thin image could
    # otherwise round one side to 0, which cv2.resize rejects.
    new_w = max(1, int(w * ratio + 0.5))
    new_h = max(1, int(h * ratio + 0.5))
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """Pad an image on the bottom/right so both sides are multiples of ``mod``.

    Padding uses ``np.pad(..., mode="symmetric")``, i.e. the border is mirrored
    rather than filled with a constant.

    Args:
        img: [H, W, C] array; a 2-D [H, W] input gains a trailing channel axis.
        mod: each output side is rounded up to a multiple of this.
        square: whether to pad to a square whose side is the larger padded side.
        min_size: optional lower bound for each output side; must itself be a
            multiple of ``mod``.

    Returns:
        The padded [H', W', C] array.

    Raises:
        ValueError: if ``min_size`` is not a multiple of ``mod``.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)

    if min_size is not None:
        # Explicit check instead of ``assert`` so validation still runs
        # under ``python -O``.
        if min_size % mod != 0:
            raise ValueError(
                f"min_size ({min_size}) must be a multiple of mod ({mod})"
            )
        out_width = max(min_size, out_width)
        out_height = max(min_size, out_height)

    if square:
        out_height = out_width = max(out_height, out_width)

    return np.pad(
        img,
        ((0, out_height - height), (0, out_width - width), (0, 0)),
        mode="symmetric",
    )
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """Return one bounding box per external contour found in ``mask``.

    Args:
        mask: (h, w, 1) 0~255 mask, binarised with threshold 127.

    Returns:
        List of int arrays ``[x1, y1, x2, y2]`` where ``x2 = x1 + box width``
        and ``y2 = y1 + box height``; x values are clipped to ``[0, w]`` and
        y values to ``[0, h]``.
    """
    img_h, img_w = mask.shape[:2]
    _, binary = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    def _to_box(contour):
        # Convert one contour into a clipped [x1, y1, x2, y2] int array.
        left, top, box_w, box_h = cv2.boundingRect(contour)
        box = np.array([left, top, left + box_w, top + box_h]).astype(int)
        box[0::2] = np.clip(box[0::2], 0, img_w)
        box[1::2] = np.clip(box[1::2], 0, img_h)
        return box

    return [_to_box(contour) for contour in contours]
def only_keep_largest_contour(mask: np.ndarray) -> np.ndarray:
    """Keep only the largest-area external contour of ``mask``, drawn filled.

    The return annotation was previously ``List[np.ndarray]``, but a single
    mask array is returned.

    Args:
        mask: (h, w) 0~255 mask, binarised with threshold 127.

    Returns:
        A new zero mask (same shape/dtype as ``mask``) with the largest contour
        filled at 255. Ties go to the first contour found. If there are no
        contours, or none has positive area, ``mask`` itself is returned.
    """
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return mask

    # max() returns the first maximal item, matching the original strict-">"
    # scan; a best area of 0 means nothing worth keeping.
    max_index, max_area = max(
        enumerate(cv2.contourArea(cnt) for cnt in contours),
        key=lambda item: item[1],
    )
    if max_area <= 0:
        return mask

    new_mask = np.zeros_like(mask)
    # Thickness -1 fills the contour interior.
    return cv2.drawContours(new_mask, contours, max_index, 255, -1)