import runpod
import base64
import concurrent.futures
import requests
from requests.adapters import HTTPAdapter, Retry
import time
import logging
import os
import uuid
import copy
from io import BytesIO
from PIL import Image
from firebase_admin import credentials, initialize_app, storage
from logging.handlers import SysLogHandler
import sentry_sdk
import boto3
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from apscheduler.schedulers.background import BackgroundScheduler
import safetensors_worker
from rembg import remove, new_session
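# RunPod serverless worker: proxies jobs to a local AUTOMATIC1111 web UI on
# port 3000, reports generation progress, runs an NSFW safety check on the
# outputs, and uploads them as WebP to S3 or Google Cloud Storage.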
progress_url = "http://127.0.0.1:3000/sdapi/v1/progress"
automatic_session = requests.Session()
retries = Retry(total=30, backoff_factor=0.1, status_forcelist=[502, 503, 504])
automatic_session.mount('http://', HTTPAdapter(max_retries=retries))
convert_quality = int(os.environ.get('WEBP_CONVERT_QUALITY', 90))
convert_method = int(os.environ.get('WEBP_CONVERT_METHOD', 5))
convert_is_lossless = os.environ.get('WEBP_CONVERT_IS_LOSSLESS') == 'True'
images_base_url = os.environ.get(
'IMAGES_BASE_URL', 'https://images-dev.infero.ai')
api_base_url = os.environ.get('API_BASE_URL')
api_key = os.environ.get('API_KEY')
storage_bucket = os.environ.get(
'STORAGE_BUCKET', 'sd-app-dev-329dd.appspot.com')
env_name = os.environ.get('ENV_NAME', 'development')
gc_service_account_filename = os.environ.get(
'GC_SERVICE_ACCOUNT_FILENAME', 'upload-only-dev.json')
is_s3_storage = os.environ.get('STORAGE') == 's3'  # selects the image upload backend
sync_progress_value_interval = 1000  # milliseconds
cred = credentials.Certificate(
f'./gc-service-accounts/{gc_service_account_filename}')
initialize_app(cred, {'storageBucket': storage_bucket})
google_storage = storage.bucket()
s3 = boto3.client('s3',
aws_access_key_id=os.environ.get('AWS_ACCESS_KEY'),
aws_secret_access_key=os.environ.get('AWS_SECRET_KEY'))
bucket_name = os.environ.get('AWS_S3_BUCKET')
syslog = SysLogHandler(address=('logs.papertrailapp.com', 25104))
log_format = '%(asctime)s SD: %(message)s'
formatter = logging.Formatter(log_format, datefmt='%b %d %H:%M:%S')
syslog.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(syslog)
logger.setLevel(logging.INFO)
sentry_sdk.init(
dsn="https://5a9f4a774b78460f9762480e9bf17a57@o4504769730576384.ingest.sentry.io/4504769740406784",
traces_sample_rate=1.0,
environment=env_name
)
class TimeMeasurer:
def __init__(self):
self.start_time = time.time()
def log_time_end(self, name):
end_time = time.time()
elapsed_time = (end_time - self.start_time) * 1000
logger.info(f"Operation: {name}, Elapsed time: {elapsed_time:.2f} ms")
safety_model_id = "./stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None
def check_lora(path):
try:
result = safetensors_worker.CheckLoRA({}, path)
return result == 0
except Exception as e:
logger.error(f"Lora check error: {e}")
return False
def load_lora(url):
destination_dir = '/stable-diffusion-webui/models/Lora'
try:
# Ensure the destination directory exists
os.makedirs(destination_dir, exist_ok=True)
# Extract the file name from the URL
file_name = os.path.basename(url)
# Construct the full path for the destination file
destination_file = os.path.join(destination_dir, file_name)
# Check if the file already exists, and if it does, return the filename
if os.path.exists(destination_file):
logger.info(f"File '{file_name}' already exists in '{destination_dir}'.")
return file_name
# Download the file from the URL
response = requests.get(url)
# Check if the download was successful (status code 200)
if response.status_code == 200:
with open(destination_file, 'wb') as f:
f.write(response.content)
logger.info(f"File '{file_name}' downloaded and saved to '{destination_dir}'.")
# Call check_lora to verify the downloaded file
# if not check_lora(destination_file):
# os.remove(destination_file)
# raise ValueError(f"Failed to verify file '{file_name}' as a valid LoRA file.")
return file_name
else:
raise ValueError(f"Failed to download file from '{url}'. Status code: {response.status_code}")
except Exception as e:
raise ValueError(f"An error occurred: {e}")
def refresh_loras():
try:
url = 'http://127.0.0.1:3000/sdapi/v1/refresh-loras'
response = automatic_session.post(url)
# Check if the request was successful (status code 200)
if response.status_code == 200:
logger.info("Refreshed Loras successfully.")
return True
else:
logger.error(f"Failed to refresh Loras. Status code: {response.status_code}")
return False
except Exception as e:
logger.error(f"An error occurred: {e}")
return False
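# SDXL img2img warm-up: a throwaway 1-step txt2img call forces the requested
# checkpoint to load once, so the first real img2img request does not stall
# on model loading.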
xl_img2img_initialized = False
def init_xl_img2img(model):
global xl_img2img_initialized
if xl_img2img_initialized:
return
try:
url = 'http://127.0.0.1:3000/sdapi/v1/txt2img'
data = {
"steps": 1,
"cfg_scale": 1,
"override_settings": {"sd_model_checkpoint": model}
}
        txt2img_time = TimeMeasurer()
response = automatic_session.post(url, json=data)
txt2img_time.log_time_end("txt2img")
xl_img2img_initialized = True
except Exception as e:
logger.error(f"Failed to init XL img2img: {e}")
def init_safety_checker():
    global safety_feature_extractor, safety_checker
    if safety_feature_extractor is None:
        # Load from the local checkpoint directory defined above; the
        # safety_model_id constant was previously unused, so using it here is
        # the assumed intent (it mirrors the CompVis hub checkpoint).
        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
def check_safety(x_image):
    safety_checker_input = safety_feature_extractor(
        x_image, return_tensors="pt")
    # Only the NSFW flag is consumed downstream, so the preprocessed pixel
    # values are passed for both arguments and the returned (possibly
    # blacked-out) image tensor is discarded.
    _, has_nsfw_concept = safety_checker(
        images=safety_checker_input.pixel_values, clip_input=safety_checker_input.pixel_values)
    return has_nsfw_concept[0]
def update_progress_status(job_id, progress_status, progress_value=None):
try:
params = {
"id": job_id,
"apiKey": api_key,
"progressStatus": progress_status,
"progressValue": progress_value
}
        requests.post(
            f'{api_base_url}/updateJobProgressStatus', json=params)
except Exception as e:
logger.exception(f'update_progress_status: Exception {e}')
sd_available = False
def check_sd_availability():
    global sd_available
    while True:
        try:
            response = automatic_session.get(progress_url)
            if response.status_code != 200:
                raise Exception(
                    f"API is not available. Status code: {response.status_code}")
            sd_available = True
            logger.info(
                f'check_sd_availability: Done. status_code: {response.status_code} content: {response.content}')
            return
        except requests.exceptions.RequestException as e:
            logger.warning(
                f"check_sd_availability: API is not available, retrying... ({e})")
        except Exception as e:
            logger.exception(f'check_sd_availability: Exception {e}')
        time.sleep(sync_progress_value_interval / 1000)
def upload_image(image_base64):
    try:
        image_bytes = base64.b64decode(image_base64)
        pil_image = Image.open(BytesIO(image_bytes))
        is_nsfw = check_safety(pil_image)
        postfix = '_nsfw' if is_nsfw else ''
        webp_image = BytesIO()
        pil_image.save(webp_image, format="WebP", lossless=convert_is_lossless,
                       quality=convert_quality, method=convert_method)
        webp_image.seek(0)  # reset the BytesIO object to the beginning
        image_name = f'generated/{uuid.uuid4()}{postfix}.webp'
        # Dispatch on the STORAGE env var; is_s3_storage and
        # upload_image_google_storage were previously defined but never used,
        # so this wiring is the assumed intent.
        if is_s3_storage:
            upload_image_s3(image_name, webp_image)
        else:
            upload_image_google_storage(image_name, webp_image)
        return f'{images_base_url}/{image_name}'
    except Exception as e:
        logger.exception(f'Upload exception: {e}')
        raise
def upload_image_google_storage(image_name, webp_image):
blob = google_storage.blob(image_name)
blob.upload_from_file(webp_image, content_type='image/webp')
blob.make_public()
logger.info('GOOGLE UPLOADED')
def upload_image_s3(image_name, webp_image):
s3.upload_fileobj(
webp_image,
bucket_name,
image_name,
ExtraArgs={'ContentType': 'image/webp'}
)
logger.info('S3 UPLOADED')
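# Uploads each base64 image concurrently; executor.map preserves input order,
# so the returned URLs line up with the images that produced them.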
def upload_images_in_parallel(images):
with concurrent.futures.ThreadPoolExecutor() as executor:
uploaded_urls = list(executor.map(upload_image, images))
return uploaded_urls
def remove_base64_prefix(base64_string):
# Check if the string contains a comma, which typically follows the MIME type in the prefix
if ',' in base64_string:
# Split the string on the first comma and take the second part
return base64_string.split(',', 1)[1]
else:
# If there's no comma, return the string as it is
return base64_string
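# Example:
#   remove_base64_prefix('data:image/png;base64,iVBORw0...') -> 'iVBORw0...'
#   remove_base64_prefix('iVBORw0...')                       -> 'iVBORw0...'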
def prepare_mask(image):
# Convert the image to RGBA if it's not already
if image.mode != 'RGBA':
image = image.convert('RGBA')
# Process each pixel
pixels = image.load()
for i in range(image.width):
for j in range(image.height):
r, g, b, a = pixels[i, j]
# Calculate the transparency based on the brightness of the pixel
# Black (0, 0, 0) becomes white with full opacity (255, 255, 255, 255)
# White (255, 255, 255) becomes fully transparent (255, 255, 255, 0)
# Shades of gray become varying levels of transparent white
new_alpha = 255 - int((r + g + b) / 3)
pixels[i, j] = (255, 255, 255, new_alpha)
return image
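# A vectorized alternative to the per-pixel loop above; a sketch only, not
# the version used by this worker. Assumes NumPy is available (rembg already
# depends on it); prepare_mask_vectorized is a hypothetical helper name.
def prepare_mask_vectorized(image):
    import numpy as np
    rgba = np.asarray(image.convert('RGBA'), dtype=np.uint8)
    out = np.empty_like(rgba)
    # Mean brightness drives the alpha channel: black -> opaque white (255),
    # white -> fully transparent white (0), grays in between.
    out[..., :3] = 255
    out[..., 3] = 255 - rgba[..., :3].mean(axis=-1).astype(np.uint8)
    return Image.fromarray(out, mode='RGBA')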
def pil_image_to_base64(pil_img, format="PNG"):
    buffered = BytesIO()
    pil_img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    # Derive the MIME subtype from the save format instead of hard-coding png
    return f'data:image/{format.lower()};base64,{img_str}'
def get_removed_bg_mask(base64_image):
decoded_image = base64.b64decode(base64_image)
with Image.open(BytesIO(decoded_image)) as input_image:
session = new_session("silueta")
# output_mask = remove(input_image, session=session, only_mask=True, post_process_mask=True, alpha_matting=True, alpha_matting_erode_size=5)
output_mask = remove(input_image, session=session, only_mask=True)
processed_image = prepare_mask(output_mask)
return pil_image_to_base64(processed_image)
def get_image_base64_by_url(image_url):
    response = requests.get(image_url)
    response.raise_for_status()  # fail loudly instead of encoding an error page
    return base64.b64encode(response.content).decode('utf-8')
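# Maps A1111's raw progress onto the job's reported progress. When ADetailer
# passes are queued, base generation is reported over 0.0-0.5 and the detail
# passes over 0.5-1.0 (clamped so the reported value stays in [0.55, 0.95]).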
def sync_progress_value(job_id, params):
try:
response = automatic_session.get(progress_url)
result = response.json()
state = result["state"]
adetailers_count = len(params.get("alwayson_scripts", {}).get("adetailer", {}).get("args", []))
images_count = params.get("batch_size", 1)
adetailer_jobs_count = adetailers_count * images_count
job_no = state.get('job_no', 0)
job_count = state.get('job_count', 0)
job_progress = result["progress"]
with_adetailer = adetailer_jobs_count > 0
if job_progress is not None and job_progress > 0.01 and job_count > 0:
logger.info(f'STATE: {result["state"]}, textinfo: {result["textinfo"]}')
if job_count > 1 and with_adetailer:
adetailer_progress = job_no / adetailer_jobs_count
if adetailer_progress < 0.1:
adetailer_progress = 0.1
if adetailer_progress > 0.9:
adetailer_progress = 0.9
final_adetailer_progress = 0.5 + adetailer_progress / 2
logger.info(f'job_no: {job_no}, job_count: {job_count}, adetailer_jobs_count: {adetailer_jobs_count}, final_adetailer_progress: {final_adetailer_progress}')
update_progress_status(
job_id=job_id, progress_status="FIXING_DETAILS", progress_value=final_adetailer_progress)
else:
generating_progress = job_progress / 2 if with_adetailer else job_progress
update_progress_status(
job_id=job_id, progress_status="GENERATING", progress_value=generating_progress)
except Exception as e:
logger.exception(f'sync_progress_value: Exception {e}')
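# RunPod entry point: routes the job to the matching A1111 endpoint, resolves
# image URLs and LoRA downloads, streams progress updates while generating,
# optionally runs an outpaint-fix img2img pass, and uploads the results.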
def handler(event):
global sd_available
progress_sync_scheduler = BackgroundScheduler()
scheduler_started = False
try:
job_id = event.get("id")
input_event = event.get('input', {})
endpoint = input_event.get("endpoint", "txt2img")
params = input_event.get("params", {})
method = input_event.get("method", "post")
is_xl = input_event.get("is_xl", False)
images_field = input_event.get("images_field", "images")
image_field = input_event.get("image_field", "image")
lora_urls = params.get("lora_urls")
is_fix_outpaint = params.get('is_fix_outpaint', False)
logger.info(f'is_xl {is_xl}')
if endpoint == "img2bgmask":
init_image = params.get('image')
if init_image.startswith("https"):
init_image = get_image_base64_by_url(init_image)
mask = get_removed_bg_mask(remove_base64_prefix(init_image))
return {"mask": mask}
if not sd_available:
update_progress_status(
job_id=job_id, progress_status="SERVER_STARTING")
check_sd_availability()
logger.info('run handler')
if is_xl and endpoint == 'img2img':
model = params.get('override_settings', {}).get('sd_model_checkpoint')
init_xl_img2img(model)
if endpoint == 'txt2img' or endpoint == 'img2img':
            progress_sync_scheduler.add_job(
                sync_progress_value, 'interval', seconds=sync_progress_value_interval / 1000, args=(job_id, params))
            progress_sync_scheduler.start()
            # Set only after start() so the except block never tries to shut
            # down a scheduler that never started.
            scheduler_started = True
init_safety_checker()
if endpoint == "extra-single-image":
image_url = params.get('image')
params['image'] = get_image_base64_by_url(image_url)
if endpoint == "img2img":
init_image = params.get('init_images', [])[0]
if init_image.startswith("https"):
image_base64 = get_image_base64_by_url(init_image)
params['init_images'] = [image_base64]
if lora_urls is not None:
for lora_url in lora_urls:
if lora_url is not None:
load_lora(lora_url)
refresh_loras()
host = f'http://127.0.0.1:3000/sdapi/v1/{endpoint}'
        request_time = TimeMeasurer()
response = automatic_session.request(method=method, url=host, json=params)
result = response.json()
request_time.log_time_end("request_time")
        if endpoint == 'txt2img' or endpoint == 'img2img':
            progress_sync_scheduler.shutdown()
            scheduler_started = False  # prevent a double shutdown in the except block
if is_fix_outpaint:
first_result = copy.deepcopy(result)
update_progress_status(
job_id=job_id, progress_status="FIXING_OUTPAINT")
new_params = copy.deepcopy(params)
new_params['denoising_strength'] = params['outpaint_fix_noise_strength']
new_params['mask_blur'] = 30
new_params['inpainting_fill'] = 1
new_params['init_images'] = first_result[images_field]
response = automatic_session.request(method=method, url=host, json=new_params)
result = response.json()
if result is not None and images_field in result:
update_progress_status(
job_id=job_id, progress_status="UPLOADING_IMAGES")
result[images_field] = upload_images_in_parallel(
result[images_field])
if result is not None and image_field in result:
update_progress_status(
job_id=job_id, progress_status="UPLOADING_IMAGES")
result[image_field] = upload_image(result[image_field])
return result
except Exception as e:
logger.exception(f'An exception occurred: {e}')
if scheduler_started:
progress_sync_scheduler.shutdown()
raise
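# Minimal local smoke test (hypothetical payload; assumes an A1111 instance on
# 127.0.0.1:3000 and configured storage credentials):
# if __name__ == '__main__':
#     print(handler({"id": "local-test", "input": {
#         "endpoint": "txt2img", "params": {"prompt": "a cat", "steps": 20}}}))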
runpod.serverless.start({"handler": handler})