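"""Gradio demo: erase objects from an image with the LaMa inpainting model via lama-cleaner.

Upload an image, sketch a mask over the region to remove, and press "Erase"
to get the inpainted result.
"""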
import gradio as gr
import PIL
from PIL import Image
import numpy as np
import os
import uuid
import torch
from torch import autocast
import cv2
from io import BytesIO
from matplotlib import pyplot as plt
from torchvision import transforms
import io
import logging
import multiprocessing
import random
import time
import imghdr
from pathlib import Path
from typing import Union
from loguru import logger
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
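# Best-effort: disable torch's JIT fusers (any failure here is silently ignored).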
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except:
    pass
from lama_cleaner.helper import (
    load_img,
    numpy_to_bytes,
    resize_max_size,
)
NUM_THREADS = str(multiprocessing.cpu_count())
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
HF_TOKEN_SD = os.environ.get('HF_TOKEN_SD')
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'device = {device}')
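# Guess an image format from raw bytes, falling back to JPEG when undetectable.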
def get_image_ext(img_bytes):
    w = imghdr.what("", img_bytes)
    if w is None:
        w = "jpeg"
    return w
def read_content(file_path):
    """Read the raw bytes of the target file."""
    with open(file_path, 'rb') as f:
        content = f.read()
    return content
model = None
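# model_process: resize the image/mask, build an inpainting Config, run the
# model, and return the result as a PIL image. `model` is assigned a
# ModelManager instance further below.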
def model_process(image, mask):
    global model

    # The sketch mask can come back transposed relative to the image;
    # rotate the image 90 degrees so its dimensions match the mask.
    if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
        image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]

    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    size_limit = 1080  # or "Original" to keep the input resolution
    if size_limit == "Original":
        size_limit = max(image.shape)
    else:
        size_limit = int(size_limit)

    config = Config(
        ldm_steps=25,
        ldm_sampler='plms',
        zits_wireframe=True,
        hd_strategy='Original',
        hd_strategy_crop_margin=196,
        hd_strategy_crop_trigger_size=1280,
        hd_strategy_resize_limit=2048,
        prompt='',
        use_croper=False,
        croper_x=0,
        croper_y=0,
        croper_height=512,
        croper_width=512,
        sd_mask_blur=5,
        sd_strength=0.75,
        sd_steps=50,
        sd_guidance_scale=7.5,
        sd_sampler='ddim',
        sd_seed=42,
        cv2_flag='INPAINT_NS',
        cv2_radius=5,
    )
    if config.sd_seed == -1:
        config.sd_seed = random.randint(1, 999999999)

    print(f"Origin image shape_0_: {original_shape} / {size_limit}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    print(f"Resized image shape_1_: {image.shape}")

    print(f"mask image shape_0_: {mask.shape} / {type(mask)}")
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
    print(f"mask image shape_1_: {mask.shape} / {type(mask)}")

    if model is None:
        return None

    res_np_img = model(image, mask, config)
    torch.cuda.empty_cache()

    image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
    return image
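# Load the LaMa inpainting model once, on GPU if available.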
model = ModelManager(
    name='lama',
    device=device,
)
image_type = 'pil'  # or 'filepath'
def predict(input):
    if image_type == 'filepath':
        # input: {'image': '/tmp/tmp8mn9xw93.png', 'mask': '/tmp/tmpn5ars4te.png'}
        origin_image_bytes = read_content(input["image"])
        print('origin_image_bytes =', type(origin_image_bytes), len(origin_image_bytes))
        image, _ = load_img(origin_image_bytes)
        mask, _ = load_img(read_content(input["mask"]), gray=True)
    elif image_type == 'pil':
        # input: {'image': pil, 'mask': pil}
        image_pil = input['image']
        mask_pil = input['mask']
        image = np.array(image_pil)
        mask = np.array(mask_pil.convert("L"))
    output = model_process(image, mask)
    return output
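# Example call (hypothetical inputs), matching the 'pil' mode used above:
#   result = predict({'image': Image.open('photo.png'),
#                     'mask': Image.open('mask.png')})
#   result is the inpainted PIL image (or None if the model is not loaded).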
css = '''
.container {max-width: 98%;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:768px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 768px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
'''
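# UI: left column holds the upload/sketch canvas and the "Erase" button;
# right column shows the inpainted output.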
image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    with gr.Group():
        with gr.Box():
            with gr.Row():
                with gr.Column():
                    image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type=image_type, label="Upload").style(mobile_collapse=False, width=768)
                    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                        btn_in = gr.Button("Erase").style(
                            margin=True,
                            rounded=(True, True, True, True),
                            full_width=True,
                        )
                with gr.Column():
                    image_out = gr.Image(label="Output", elem_id="image_output", visible=True)  # .style(width=768)
    btn_in.click(fn=predict, inputs=[image], outputs=[image_out])

image_blocks.launch()