File size: 5,100 Bytes
d5e1b1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (C) 2023 Deforum LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# Contact the authors: https://deforum.github.io/

import requests
import os
from PIL import Image
import socket
import torchvision.transforms.functional as TF
from .general_utils import clean_gradio_path_strings

def load_img(path: str, image_box: Image.Image, shape=None, use_alpha_as_mask=False):
    """
    Load an init image and, optionally, derive a mask from its alpha channel.

    :param path: local path or http(s) URL of the image; only used when image_box is not a PIL image.
    :param image_box: an already-loaded PIL image; if it is one, it is used instead of path.
    :param shape: optional (width, height) passed to Image.resize with LANCZOS resampling.
    :param use_alpha_as_mask: read the alpha channel of the image as the mask image.
    :return: (image, mask_image). image is RGB. mask_image is an 'L' image taken from the alpha
             channel, or None when use_alpha_as_mask is False or the alpha channel is uniformly
             0 or uniformly 255.
    """
    # The alpha channel has to survive loading, so only force RGB when alpha is not needed.
    image = load_image(path, image_box, mode=None if use_alpha_as_mask else 'RGB')
    image = image.convert('RGBA') if use_alpha_as_mask else image.convert('RGB')
    if shape is not None:
        image = image.resize(shape, resample=Image.LANCZOS)

    mask_image = None
    if use_alpha_as_mask:
        # Only the alpha channel is of interest; it becomes the greyscale mask.
        mask_image = image.getchannel('A').convert('L')
        image = image.convert('RGB')

        # A uniform alpha channel (fully transparent or fully opaque) carries no mask information.
        if mask_image.getextrema() in ((0, 0), (255, 255)):
            print("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.")
            print("ignoring alpha as mask.")
            mask_image = None

    return image, mask_image


def _open_image(source, mode):
    """Open source with PIL and return an independent, fully loaded copy (converted to mode unless mode is None)."""
    # convert()/copy() both load the pixel data into a new image, so the opened file can be closed safely.
    with Image.open(source) as img:
        return img.convert(mode) if mode is not None else img.copy()


def load_image(image_path: str, image_box: Image.Image, mode='RGB'):
    """
    Return image_box if it is a PIL image, otherwise load the image from image_path.

    :param image_path: local path or http(s) URL; passed through clean_gradio_path_strings first.
    :param image_box: if this is a PIL image (e.g. from an init_image_box), it is returned as-is,
                      without any mode conversion.
    :param mode: PIL mode the loaded image is converted to (default 'RGB'). Pass None to keep the
                 file's own mode, e.g. to preserve an alpha channel.
    :raises ConnectionError: the connectivity probe fails, the download fails, or the URL does not answer 200.
    :raises RuntimeError: the local path does not exist.
    """
    # If init_image_box was used then no need to fetch the image via URL, just return the Image object directly.
    if isinstance(image_box, Image.Image):
        return image_box

    image_path = clean_gradio_path_strings(image_path)
    if image_path.startswith(('http://', 'https://')):
        # Cheap connectivity probe so the user gets a clear message when offline.
        try:
            host = socket.gethostbyname("www.google.com")
            with socket.create_connection((host, 80), 2):
                pass
        except OSError as e:
            raise ConnectionError("There is no active internet connection available (couldn't connect to google.com as a network test) - please use *local* masks and init files only.") from e
        try:
            # With stream=True the timeout bounds the connect and each read, not the whole download.
            response = requests.get(image_path, stream=True, timeout=30)
        except requests.exceptions.RequestException as e:
            raise ConnectionError(f"Failed to download image {image_path} due to no internet connection. Error: {e}") from e
        with response:
            if response.status_code != 200:
                raise ConnectionError(f"Init image url or mask image url is not valid: {image_path}")
            return _open_image(response.raw, mode)

    image_path = os.path.realpath(image_path)
    if not os.path.exists(image_path):
        raise RuntimeError(f"Init image path or mask image path is not valid: {image_path}")
    return _open_image(image_path, mode)

def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0):
    """
    Prepare a mask for use in the webui.

    mask_input may be either a path/URL or an already-loaded PIL image. It is handed to
    load_image as both the path and the image box, so whichever form applies is used.
    The mask is resized to mask_shape with LANCZOS resampling. Brightness and then contrast
    are adjusted, each only when its factor differs from 1. The result is returned in 'L' mode.
    """
    resized = load_image(mask_input, mask_input).resize(mask_shape, resample=Image.LANCZOS)
    # Applied in this order: brightness first, then contrast.
    adjustments = (
        (TF.adjust_brightness, mask_brightness_adjust),
        (TF.adjust_contrast, mask_contrast_adjust),
    )
    for adjust, factor in adjustments:
        if factor != 1:
            resized = adjust(resized, factor)
    return resized.convert('L')

# "check_mask_for_errors" may have prevented errors in composable masks,
# but it CAUSES errors on any frame where it's all black.
# Bypassing the check below until we can fix it even better.
# This may break composable masks, but it makes ACTUAL masks usable.
def check_mask_for_errors(mask_input, invert_mask=False):
    """
    Return mask_input, or None when the mask would be blank.

    :param mask_input: object exposing getextrema() (e.g. a single-band PIL image).
    :param invert_mask: when True the mask is going to be inverted, so an all-white mask
                        (extrema (255, 255)) is the blank case. Otherwise an all-black mask
                        (extrema (0, 0)) is the blank case.
    :return: mask_input, or None if it is blank for the given inversion setting.
    """
    extrema = mask_input.getextrema()
    if invert_mask:
        if extrema == (255, 255):
            print("after inverting mask will be blank. ignoring mask")
            return None
    elif extrema == (0, 0):
        print("mask is blank. ignoring mask")
        return None
    # Reached for every usable mask, including when invert_mask is True.
    return mask_input
 
def get_mask(args):
    """
    Prepare the mask named by args.mask_file at size (args.W, args.H).

    args.mask_brightness_adjust and args.mask_contrast_adjust are forwarded to prepare_mask.
    """
    # Pass the adjustments by keyword: prepare_mask takes brightness before contrast,
    # and the positional call had them swapped.
    return prepare_mask(
        args.mask_file,
        (args.W, args.H),
        mask_brightness_adjust=args.mask_brightness_adjust,
        mask_contrast_adjust=args.mask_contrast_adjust,
    )

def get_mask_from_file(mask_file, args):
    """
    Prepare the mask given by mask_file at size (args.W, args.H).

    mask_file may be anything prepare_mask accepts. args.mask_brightness_adjust and
    args.mask_contrast_adjust are forwarded to prepare_mask.
    """
    # Pass the adjustments by keyword: prepare_mask takes brightness before contrast,
    # and the positional call had them swapped.
    return prepare_mask(
        mask_file,
        (args.W, args.H),
        mask_brightness_adjust=args.mask_brightness_adjust,
        mask_contrast_adjust=args.mask_contrast_adjust,
    )

def blank_if_none(mask, w, h, mode):
    """Return mask itself, or a new all-zero image of the given mode and size (w, h) when mask is None."""
    if mask is not None:
        return mask
    return Image.new(mode, (w, h), 0)

def none_if_blank(mask):
    """Return None when mask.getextrema() is exactly (0, 0) (an all-black single-band image); otherwise return mask."""
    if mask.getextrema() != (0, 0):
        return mask
    return None