Spaces:
Runtime error
Runtime error
File size: 3,772 Bytes
55cc64a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import requests
import os
from PIL import Image, ImageOps
import cv2
import numpy as np
import socket
import torchvision.transforms.functional as TF
def load_img(path : str, shape=None, use_alpha_as_mask=False):
    """Load an init image, optionally extracting its alpha channel as a mask.

    Args:
        path: Local file path or http(s) URL of the image.
        shape: Optional size tuple passed to ``Image.resize`` (LANCZOS).
        use_alpha_as_mask: If True, read the image's alpha channel and return
            it as a single-channel ('L') mask image.

    Returns:
        ``(image, mask_image)``. ``image`` is RGB. ``mask_image`` is an 'L'
        image, or None when no mask was requested or when the alpha channel is
        uniformly 0 or uniformly 255 (it then carries no mask information).
    """
    # Request RGBA up front when the alpha channel is needed. Loading as RGB
    # first would discard it, and a later RGB->RGBA conversion only produces a
    # fully opaque alpha, which the blank check below would always reject.
    image = load_image(path, mode='RGBA' if use_alpha_as_mask else 'RGB')
    if shape is not None:
        image = image.resize(shape, resample=Image.LANCZOS)
    mask_image = None
    if use_alpha_as_mask:
        # Split the alpha channel into a mask image, then drop it from the image.
        mask_image = image.getchannel('A')
        image = image.convert('RGB')
        # A uniform alpha (fully transparent or fully opaque) is not a usable mask.
        extrema = mask_image.getextrema()
        if extrema in ((0, 0), (255, 255)):
            print("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.")
            print("ignoring alpha as mask.")
            mask_image = None
    return image, mask_image


def load_image(image_path :str, mode='RGB', timeout=30):
    """Open an image from a local path or an http(s) URL.

    Args:
        image_path: Local file path, or a URL starting with ``http://`` or
            ``https://``.
        mode: PIL mode to convert the image to. The default 'RGB' matches the
            historical behaviour; pass None to keep the file's own mode.
        timeout: Timeout in seconds handed to ``requests.get`` so a stalled
            server cannot block the caller indefinitely.

    Returns:
        A fully loaded PIL image. No file handle or HTTP connection is left open.

    Raises:
        ConnectionError: The connectivity probe fails, the download fails, or
            the server does not answer with status 200.
        RuntimeError: ``image_path`` is a local path that does not exist.
    """
    if image_path.startswith(('http://', 'https://')):
        import io  # only needed for downloads

        try:
            # Reachability probe so the user gets a clear offline message.
            # NOTE(review): assumes www.google.com is reachable whenever the
            # internet is; networks that block it will fail here.
            host = socket.gethostbyname("www.google.com")
            with socket.create_connection((host, 80), 2):
                pass
        except OSError as e:  # gaierror, timeout and refused are all OSError
            raise ConnectionError("There is no active internet connection available - please use local masks and init files only.") from e
        try:
            response = requests.get(image_path, timeout=timeout)
        except requests.exceptions.RequestException as e:
            raise ConnectionError("Failed to download image due to no internet connection. Error: {}".format(e)) from e
        if response.status_code != 200:
            raise ConnectionError("Init image url or mask image url is not valid")
        # response.content is already content-decoded (e.g. gzip), unlike response.raw.
        image = Image.open(io.BytesIO(response.content))
    else:
        if not os.path.exists(image_path):
            raise RuntimeError("Init image path or mask image path is not valid")
        image = Image.open(image_path)
    with image:
        # convert()/copy() force a full decode, so the source can be closed here.
        return image.convert(mode) if mode is not None else image.copy()
def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0):
    """Build a single-channel mask image ready for use in the web UI.

    Args:
        mask_input: Either a PIL image, used as-is, or a path/URL handed to
            ``load_image``.
        mask_shape: Target size passed to ``resize`` (LANCZOS resampling).
        mask_brightness_adjust: Brightness factor; 1 leaves brightness untouched.
        mask_contrast_adjust: Contrast factor; 1 leaves contrast untouched.

    Returns:
        The resized, adjusted mask converted to mode 'L'.
    """
    source = mask_input if isinstance(mask_input, Image.Image) else load_image(mask_input)
    resized = source.resize(mask_shape, resample=Image.LANCZOS)

    # Applied in order: brightness first, then contrast. A factor of exactly 1
    # is skipped entirely.
    adjustments = (
        (TF.adjust_brightness, mask_brightness_adjust),
        (TF.adjust_contrast, mask_contrast_adjust),
    )
    for adjust, factor in adjustments:
        if factor != 1:
            resized = adjust(resized, factor)

    return resized.convert('L')
def check_mask_for_errors(mask_input, invert_mask=False):
    """Return ``mask_input``, or None when it would act as a blank mask.

    A mask counts as blank when every pixel is 0. When ``invert_mask`` is
    True, it counts as blank when every pixel is 255, because it becomes all 0
    after inversion.

    Args:
        mask_input: Image-like object exposing ``getextrema()`` that returns a
            ``(min, max)`` pair (e.g. the 'L' image built by ``prepare_mask``).
        invert_mask: Whether the caller will invert the mask before use.

    Returns:
        ``mask_input`` unchanged when it carries information, otherwise None.
    """
    extrema = mask_input.getextrema()
    if invert_mask:
        if extrema == (255, 255):
            print("after inverting mask will be blank. ignoring mask")
            return None
    elif extrema == (0, 0):
        print("mask is blank. ignoring mask")
        return None
    # Reached by every non-blank mask, including the invert_mask=True path,
    # which must not fall through to an implicit None.
    return mask_input
def get_mask(args):
    """Prepare the mask named by ``args.mask_file`` and drop it if blank.

    Reads ``args.mask_file``, ``args.W``, ``args.H``,
    ``args.mask_brightness_adjust`` and ``args.mask_contrast_adjust``.

    Returns:
        The prepared 'L' mask, or None when ``check_mask_for_errors`` finds it
        blank.
    """
    # Keyword arguments keep brightness and contrast in their matching
    # parameters of prepare_mask; positional order is easy to get wrong.
    return check_mask_for_errors(
        prepare_mask(
            args.mask_file,
            (args.W, args.H),
            mask_brightness_adjust=args.mask_brightness_adjust,
            mask_contrast_adjust=args.mask_contrast_adjust,
        )
    )
def get_mask_from_file(mask_file, args):
    """Prepare the mask at ``mask_file`` and drop it if blank.

    Args:
        mask_file: PIL image, local path or URL accepted by ``prepare_mask``.
        args: Object providing ``W``, ``H``, ``mask_brightness_adjust`` and
            ``mask_contrast_adjust``.

    Returns:
        The prepared 'L' mask, or None when ``check_mask_for_errors`` finds it
        blank.
    """
    # Keyword arguments keep brightness and contrast in their matching
    # parameters of prepare_mask; positional order is easy to get wrong.
    return check_mask_for_errors(
        prepare_mask(
            mask_file,
            (args.W, args.H),
            mask_brightness_adjust=args.mask_brightness_adjust,
            mask_contrast_adjust=args.mask_contrast_adjust,
        )
    )
def blank_if_none(mask, w, h, mode):
    """Return ``mask`` itself, or a new all-zero ``w`` x ``h`` image when it is None.

    Args:
        mask: Existing mask, or None.
        w: Width of the substitute image.
        h: Height of the substitute image.
        mode: PIL mode of the substitute image.
    """
    if mask is not None:
        return mask
    return Image.new(mode, (w, h), 0)
def none_if_blank(mask):
    """Return None when ``mask`` is missing or entirely zero, otherwise ``mask``.

    Single-band images report ``getextrema()`` as one ``(min, max)`` pair;
    multi-band images report one pair per band. Both forms are checked, so an
    all-zero image of any mode (e.g. one created by ``blank_if_none``) is
    treated as blank.
    """
    if mask is None:
        return None
    extrema = mask.getextrema()
    if extrema and isinstance(extrema[0], tuple):
        # Multi-band: blank only if every band is uniformly 0.
        is_blank = all(band == (0, 0) for band in extrema)
    else:
        is_blank = extrema == (0, 0)
    return None if is_blank else mask
|