"""RunPod serverless worker that proxies jobs to a local Stable Diffusion WebUI API.

The worker forwards a job's ``params`` to ``http://127.0.0.1:3000/sdapi/v1/<endpoint>``,
reports progress to the backend API, runs the safety checker on produced
images, converts them to WebP and uploads them to S3.
"""
import base64
import concurrent.futures  # plain ``import concurrent`` does not load the ``futures`` submodule
import copy
import json
import logging
import os
import socket
import subprocess
import time
import uuid
from io import BytesIO
from logging.handlers import SysLogHandler

import boto3
import requests
import runpod
import sentry_sdk
import torch
from apscheduler.schedulers.background import BackgroundScheduler
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from firebase_admin import credentials, initialize_app, storage
from PIL import Image, ImageOps
from rembg import remove, new_session
from requests.adapters import HTTPAdapter, Retry
from transformers import AutoFeatureExtractor

import safetensors_worker

progress_url = "http://127.0.0.1:3000/sdapi/v1/progress"

# Session used for every call to the local WebUI API; retries on gateway errors.
automatic_session = requests.Session()
retries = Retry(total=30, backoff_factor=0.1, status_forcelist=[502, 503, 504])
automatic_session.mount('http://', HTTPAdapter(max_retries=retries))

# (connect, read) timeout in seconds for calls to services outside this worker.
# Without it a stalled peer would block the handler or the progress thread forever.
external_request_timeout = (10, 60)

convert_quality = int(os.environ.get('WEBP_CONVERT_QUALITY', 90))
convert_method = int(os.environ.get('WEBP_CONVERT_METHOD', 5))
convert_is_lossless = os.environ.get('WEBP_CONVERT_IS_LOSSLESS') == 'True'
images_base_url = os.environ.get(
    'IMAGES_BASE_URL', 'https://images-dev.infero.ai')
api_base_url = os.environ.get('API_BASE_URL')
api_key = os.environ.get('API_KEY')
storage_bucket = os.environ.get(
    'STORAGE_BUCKET', 'sd-app-dev-329dd.appspot.com')
env_name = os.environ.get('ENV_NAME', 'development')
gc_service_account_filename = os.environ.get(
    'GC_SERVICE_ACCOUNT_FILENAME', 'upload-only-dev.json')
# NOTE(review): this flag is not consulted anywhere in this file; uploads
# always go through upload_image_s3 -- confirm whether that is intended.
is_s3_storage = os.environ.get('STORAGE') == 's3'
# Progress polling interval, in milliseconds.
sync_progress_value_interval = 1000

cred = credentials.Certificate(
    f'./gc-service-accounts/{gc_service_account_filename}')
initialize_app(cred, {'storageBucket': storage_bucket})
google_storage = storage.bucket()

s3 = boto3.client('s3',
                  aws_access_key_id=os.environ.get('AWS_ACCESS_KEY'),
                  aws_secret_access_key=os.environ.get('AWS_SECRET_KEY'))
bucket_name = os.environ.get('AWS_S3_BUCKET')

syslog = SysLogHandler(address=('logs.papertrailapp.com', 25104))
# Named ``log_format`` so the builtin ``format`` is not shadowed module-wide.
log_format = '%(asctime)s SD: %(message)s'
formatter = logging.Formatter(log_format, datefmt='%b %d %H:%M:%S')
syslog.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(syslog)
logger.setLevel(logging.INFO)

sentry_sdk.init(
    dsn="https://5a9f4a774b78460f9762480e9bf17a57@o4504769730576384.ingest.sentry.io/4504769740406784",
    traces_sample_rate=1.0,
    environment=env_name
)


class TimeMesurer:
    """Stopwatch that starts at construction and logs elapsed milliseconds."""

    def __init__(self):
        self.start_time = time.time()

    def log_time_end(self, name):
        """Log the time elapsed since construction under the label *name*."""
        end_time = time.time()
        elapsed_time = (end_time - self.start_time) * 1000
        logger.info(f"Operation: {name}, Elapsed time: {elapsed_time:.2f} ms")


# NOTE(review): local model path; init_safety_checker currently loads from the
# "CompVis/stable-diffusion-safety-checker" id instead -- confirm which is wanted.
safety_model_id = "./stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None


def check_lora(path):
    """Return True when ``safetensors_worker.CheckLoRA`` returns 0 for *path*.

    Any exception raised by the check is logged and reported as False.
    """
    try:
        result = safetensors_worker.CheckLoRA({}, path)
        return result == 0
    except Exception as e:
        logger.error(f"Lora check error: {e}")
        return False


def load_lora(url):
    """Download the LoRA file at *url* into the WebUI Lora directory.

    The file name is the last path component of *url*. If a file with that
    name already exists, nothing is downloaded.

    Returns:
        The file name (not the full path) of the LoRA.

    Raises:
        ValueError: if the download fails or any other error occurs; the
            original exception is attached as ``__cause__``.
    """
    destination_dir = '/stable-diffusion-webui/models/Lora'
    try:
        # Ensure the destination directory exists
        os.makedirs(destination_dir, exist_ok=True)

        file_name = os.path.basename(url)
        destination_file = os.path.join(destination_dir, file_name)

        # Already downloaded: reuse the existing file
        if os.path.exists(destination_file):
            logger.info(f"File '{file_name}' already exists in '{destination_dir}'.")
            return file_name

        response = requests.get(url, timeout=external_request_timeout)
        if response.status_code != 200:
            raise ValueError(f"Failed to download file from '{url}'. Status code: {response.status_code}")

        with open(destination_file, 'wb') as f:
            f.write(response.content)
        logger.info(f"File '{file_name}' downloaded and saved to '{destination_dir}'.")

        # NOTE: verification of the downloaded file is currently disabled:
        # if not check_lora(destination_file):
        #     os.remove(destination_file)
        #     raise ValueError(f"Failed to verify file '{file_name}' as a valid LoRA file.")
        return file_name
    except Exception as e:
        raise ValueError(f"An error occurred: {e}") from e


def refresh_loras():
    """Ask the local WebUI API to rescan its LoRA directory.

    Returns:
        True on HTTP 200, False on any other status or on an exception
        (both are logged).
    """
    try:
        url = 'http://127.0.0.1:3000/sdapi/v1/refresh-loras'
        response = automatic_session.post(url)
        if response.status_code == 200:
            logger.info("Refreshed Loras successfully.")
            return True
        logger.error(f"Failed to refresh Loras. Status code: {response.status_code}")
        return False
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        return False


xl_img2img_initialized = False


def init_xl_img2img(model):
    """Issue a one-step txt2img request with checkpoint *model*, once per process.

    Subsequent calls are no-ops once the request has completed without
    raising. Failures are logged and leave the flag unset so the next call
    retries.
    """
    global xl_img2img_initialized
    if xl_img2img_initialized:
        return
    try:
        url = 'http://127.0.0.1:3000/sdapi/v1/txt2img'
        data = {
            "steps": 1,
            "cfg_scale": 1,
            "override_settings": {"sd_model_checkpoint": model}
        }
        txt2img_time = TimeMesurer()
        automatic_session.post(url, json=data)
        txt2img_time.log_time_end("txt2img")
        xl_img2img_initialized = True
    except Exception as e:
        logger.error(f"Failed to init XL img2img: {e}")


def init_safety_checker():
    """Load the safety feature extractor and checker if not loaded yet.

    Each global is checked on its own so that a failure while loading the
    checker (after the extractor succeeded) is retried on the next call
    instead of leaving ``safety_checker`` as None.
    """
    global safety_feature_extractor, safety_checker
    if safety_feature_extractor is None:
        safety_feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
    if safety_checker is None:
        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")


def check_safety(x_image):
    """Run the safety checker on *x_image* and return its NSFW flag.

    Requires ``init_safety_checker()`` to have been called beforehand.
    """
    safety_checker_input = safety_feature_extractor(
        x_image, return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(
        images=safety_checker_input.pixel_values,
        clip_input=safety_checker_input.pixel_values)
    return has_nsfw_concept[0]


def update_progress_status(job_id, progress_status, progress_value=None):
    """POST the job's progress to ``{api_base_url}/updateJobProgressStatus``.

    Best effort: any exception (including a timeout) is logged and swallowed
    so that progress reporting never fails the job.
    """
    try:
        params = {
            "id": job_id,
            "apiKey": api_key,
            "progressStatus": progress_status,
            "progressValue": progress_value
        }
        requests.request(
            method='post',
            url=f'{api_base_url}/updateJobProgressStatus',
            json=params,
            timeout=external_request_timeout)
    except Exception as e:
        logger.exception(f'update_progress_status: Exception {e}')


sd_available = False


def check_sd_availability():
    """Block until the local WebUI progress endpoint answers with HTTP 200.

    Polls every ``sync_progress_value_interval`` milliseconds and sets the
    module-level ``sd_available`` flag on success.
    """
    global sd_available
    while True:
        try:
            response = automatic_session.get(progress_url)
            if response.status_code != 200:
                raise Exception(
                    f"API is not available. Status code: {response.status_code}")
            sd_available = True
            logger.info(
                f'check_api_availability: Done. status_code: {response.status_code} content: {response.content}')
            return
        except requests.exceptions.RequestException as e:
            print(
                f"check_api_availability: API is not available, retrying... ({e})")
        except Exception as e:
            logger.exception(f'check_api_availability: Exception {e}')
        time.sleep(sync_progress_value_interval/1000)


def upload_image(image_base64):
    """Safety-check, convert to WebP and upload one base64-encoded image.

    Images flagged by the safety checker get an ``_nsfw`` suffix in their
    object name.

    Returns:
        The public URL of the uploaded image under ``images_base_url``.

    Raises:
        Exception: any failure is logged and re-raised unchanged.
    """
    try:
        image_bytes = base64.b64decode(image_base64)
        pil_image = Image.open(BytesIO(image_bytes))
        postfix = '_nsfw' if check_safety(pil_image) else ''

        webp_image = BytesIO()
        pil_image.save(webp_image, format="WebP", lossless=convert_is_lossless,
                       quality=convert_quality, method=convert_method)
        webp_image.seek(0)  # rewind so the upload reads from the start

        image_name = f'generated/{uuid.uuid4()}{postfix}.webp'
        upload_image_s3(image_name, webp_image)
        return f'{images_base_url}/{image_name}'
    except Exception as e:
        logger.exception(f'Upload exception: {e}')
        raise


def upload_image_google_storage(image_name, webp_image):
    """Upload *webp_image* to the Firebase storage bucket and make it public."""
    blob = google_storage.blob(image_name)
    blob.upload_from_file(webp_image, content_type='image/webp')
    blob.make_public()
    logger.info('GOOGLE UPLOADED')


def upload_image_s3(image_name, webp_image):
    """Upload *webp_image* to the configured S3 bucket under key *image_name*."""
    s3.upload_fileobj(
        webp_image,
        bucket_name,
        image_name,
        ExtraArgs={'ContentType': 'image/webp'}
    )
    logger.info('S3 UPLOADED')


def upload_images_in_parallel(images):
    """Upload every base64 image in *images* concurrently.

    Returns:
        The list of uploaded URLs, in the same order as *images*.
    """
    with concurrent.futures.ThreadPoolExecutor() as executor:
        uploaded_urls = list(executor.map(upload_image, images))
    return uploaded_urls


def remove_base64_prefix(base64_string):
    """Strip everything up to and including the first comma, if any.

    A comma typically terminates the ``data:<mime>;base64,`` prefix; strings
    without a comma are returned unchanged.
    """
    if ',' in base64_string:
        return base64_string.split(',', 1)[1]
    return base64_string


def prepare_mask(image):
    """Turn a brightness mask into a white image whose alpha encodes darkness.

    Black (0, 0, 0) becomes fully opaque white, white (255, 255, 255) becomes
    fully transparent white, and grays map to intermediate alpha values. The
    input's own alpha channel is ignored.

    Returns:
        The converted RGBA image (modified in place when the input was
        already RGBA).
    """
    if image.mode != 'RGBA':
        image = image.convert('RGBA')

    pixels = image.load()
    for i in range(image.width):
        for j in range(image.height):
            r, g, b, a = pixels[i, j]
            # Alpha is the inverse of the mean channel brightness
            new_alpha = 255 - int((r + g + b) / 3)
            pixels[i, j] = (255, 255, 255, new_alpha)
    return image


def pil_image_to_base64(pil_img, format="PNG"):
    """Encode *pil_img* as a base64 data URI in the given image *format*.

    The MIME subtype of the data URI follows *format* (lower-cased), so the
    default still yields ``data:image/png;base64,...``.
    """
    buffered = BytesIO()
    pil_img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f'data:image/{format.lower()};base64,{img_str}'


def get_removed_bg_mask(base64_image):
    """Compute a background-removal mask for a base64-encoded image.

    Runs rembg's "silueta" session with ``only_mask=True``, converts the
    result with ``prepare_mask`` and returns it as a PNG data URI.
    """
    decoded_image = base64.b64decode(base64_image)
    with Image.open(BytesIO(decoded_image)) as input_image:
        session = new_session("silueta")
        output_mask = remove(input_image, session=session, only_mask=True)
        processed_image = prepare_mask(output_mask)
        return pil_image_to_base64(processed_image)


def get_image_base64_by_url(image_url):
    """Download *image_url* and return its body base64-encoded as ``str``.

    Raises:
        requests.exceptions.RequestException: on connection errors, timeouts
            or a non-success HTTP status, so that an error page is never
            encoded and passed on as if it were an image.
    """
    response = requests.get(image_url, timeout=external_request_timeout)
    response.raise_for_status()
    return base64.b64encode(response.content).decode('utf-8')


def sync_progress_value(job_id, params):
    """Poll the WebUI progress endpoint and forward the progress of *job_id*.

    While plain generation runs the status is "GENERATING"; when ADetailer
    passes are configured in *params*, generation is scaled to the first half
    of the range and the detail passes ("FIXING_DETAILS") to the second half.
    Any exception is logged and swallowed (this runs on a scheduler thread).
    """
    try:
        response = automatic_session.get(progress_url)
        result = response.json()
        state = result["state"]

        adetailers_count = len(
            params.get("alwayson_scripts", {}).get("adetailer", {}).get("args", []))
        images_count = params.get("batch_size", 1)
        adetailer_jobs_count = adetailers_count * images_count
        with_adetailer = adetailer_jobs_count > 0

        job_no = state.get('job_no', 0)
        job_count = state.get('job_count', 0)
        job_progress = result["progress"]

        # Nothing worth reporting yet
        if job_progress is None or job_progress <= 0.01 or job_count <= 0:
            return

        logger.info(f'STATE: {result["state"]}, textinfo: {result["textinfo"]}')
        if job_count > 1 and with_adetailer:
            # Clamp to [0.1, 0.9], then map onto the upper half of the range
            adetailer_progress = min(max(job_no / adetailer_jobs_count, 0.1), 0.9)
            final_adetailer_progress = 0.5 + adetailer_progress / 2
            logger.info(f'job_no: {job_no}, job_count: {job_count}, adetailer_jobs_count: {adetailer_jobs_count}, final_adetailer_progress: {final_adetailer_progress}')
            update_progress_status(
                job_id=job_id,
                progress_status="FIXING_DETAILS",
                progress_value=final_adetailer_progress)
        else:
            generating_progress = job_progress / 2 if with_adetailer else job_progress
            update_progress_status(
                job_id=job_id,
                progress_status="GENERATING",
                progress_value=generating_progress)
    except Exception as e:
        logger.exception(f'sync_progress_value: Exception {e}')


def handler(event):
    """RunPod entry point: execute one job against the local WebUI API.

    ``event["input"]`` may contain ``endpoint`` (default "txt2img"),
    ``params``, ``method`` (default "post"), ``is_xl``, ``images_field`` and
    ``image_field``. The "img2bgmask" endpoint is served locally; every other
    endpoint is forwarded to ``/sdapi/v1/<endpoint>``. Images found under
    ``images_field`` / ``image_field`` in the response are uploaded and
    replaced by their URLs.

    Returns:
        ``{"mask": <data URI>}`` for "img2bgmask", otherwise the (post-
        processed) JSON response of the WebUI API.

    Raises:
        Exception: any failure is logged and re-raised after the progress
            scheduler has been stopped.
    """
    global sd_available
    progress_sync_scheduler = BackgroundScheduler()
    # True only while the scheduler is actually running; shutdown() on a
    # stopped scheduler raises and would mask the original error.
    scheduler_started = False
    try:
        job_id = event.get("id")
        input_event = event.get('input', {})
        endpoint = input_event.get("endpoint", "txt2img")
        params = input_event.get("params", {})
        method = input_event.get("method", "post")
        is_xl = input_event.get("is_xl", False)
        images_field = input_event.get("images_field", "images")
        image_field = input_event.get("image_field", "image")
        lora_urls = params.get("lora_urls")
        is_fix_outpaint = params.get('is_fix_outpaint', False)
        is_generation = endpoint in ('txt2img', 'img2img')
        logger.info(f'is_xl {is_xl}')

        # Background-mask extraction does not need the WebUI at all
        if endpoint == "img2bgmask":
            init_image = params.get('image')
            if init_image.startswith("https"):
                init_image = get_image_base64_by_url(init_image)
            mask = get_removed_bg_mask(remove_base64_prefix(init_image))
            return {"mask": mask}

        if not sd_available:
            update_progress_status(
                job_id=job_id, progress_status="SERVER_STARTING")
            check_sd_availability()
        logger.info('run handler')

        if is_xl and endpoint == 'img2img':
            model = params.get('override_settings', {}).get('sd_model_checkpoint')
            init_xl_img2img(model)

        if is_generation:
            progress_sync_scheduler.add_job(
                sync_progress_value, 'interval',
                seconds=sync_progress_value_interval/1000,
                args=(job_id, params))
            progress_sync_scheduler.start()
            scheduler_started = True
            init_safety_checker()

        if endpoint == "extra-single-image":
            image_url = params.get('image')
            params['image'] = get_image_base64_by_url(image_url)

        if endpoint == "img2img":
            init_image = params.get('init_images', [])[0]
            if init_image.startswith("https"):
                image_base64 = get_image_base64_by_url(init_image)
                params['init_images'] = [image_base64]

        if lora_urls is not None:
            for lora_url in lora_urls:
                if lora_url is not None:
                    load_lora(lora_url)
            refresh_loras()

        host = f'http://127.0.0.1:3000/sdapi/v1/{endpoint}'
        request_time = TimeMesurer()
        response = automatic_session.request(method=method, url=host, json=params)
        result = response.json()
        request_time.log_time_end("request_time")

        if scheduler_started:
            progress_sync_scheduler.shutdown()
            scheduler_started = False

        if is_fix_outpaint:
            update_progress_status(
                job_id=job_id, progress_status="FIXING_OUTPAINT")
            # Second pass over the first result; a shallow merge is enough
            # because only top-level keys are replaced.
            new_params = {
                **params,
                'denoising_strength': params['outpaint_fix_noise_strength'],
                'mask_blur': 30,
                'inpainting_fill': 1,
                'init_images': result[images_field],
            }
            response = automatic_session.request(method=method, url=host, json=new_params)
            result = response.json()

        has_images = result is not None and images_field in result
        has_image = result is not None and image_field in result
        if has_images or has_image:
            # upload_image runs the safety checker, also for endpoints other
            # than txt2img/img2img; the call is a no-op when already loaded.
            init_safety_checker()

        if has_images:
            update_progress_status(
                job_id=job_id, progress_status="UPLOADING_IMAGES")
            result[images_field] = upload_images_in_parallel(
                result[images_field])

        if has_image:
            update_progress_status(
                job_id=job_id, progress_status="UPLOADING_IMAGES")
            result[image_field] = upload_image(result[image_field])

        return result
    except Exception as e:
        logger.exception(f'An exception occurred: {e}')
        if scheduler_started:
            progress_sync_scheduler.shutdown()
        raise


runpod.serverless.start({"handler": handler})