# test-model / handler.py
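"""RunPod serverless handler that proxies requests to a local Stable Diffusion
WebUI API, reports generation progress to a backend, runs an NSFW safety check
on outputs, and uploads result images to S3 or Google Cloud Storage."""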
import runpod
import base64
import concurrent.futures
import requests
from requests.adapters import HTTPAdapter, Retry
import time
import logging
import os
import uuid
import copy
from io import BytesIO
from PIL import Image
from firebase_admin import credentials, initialize_app, storage
from logging.handlers import SysLogHandler
import sentry_sdk
import boto3
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from apscheduler.schedulers.background import BackgroundScheduler
import safetensors_worker
from rembg import remove, new_session
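# Session for talking to the local Stable Diffusion WebUI API. The server may
# still be booting when a job arrives, so transient gateway errors
# (502/503/504) are retried with a short backoff.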
progress_url = "http://127.0.0.1:3000/sdapi/v1/progress"
automatic_session = requests.Session()
retries = Retry(total=30, backoff_factor=0.1, status_forcelist=[502, 503, 504])
automatic_session.mount('http://', HTTPAdapter(max_retries=retries))
convert_quality = int(os.environ.get('WEBP_CONVERT_QUALITY', 90))
convert_method = int(os.environ.get('WEBP_CONVERT_METHOD', 5))
convert_is_lossless = os.environ.get('WEBP_CONVERT_IS_LOSSLESS') == 'True'
images_base_url = os.environ.get(
    'IMAGES_BASE_URL', 'https://images-dev.infero.ai')
api_base_url = os.environ.get('API_BASE_URL')
api_key = os.environ.get('API_KEY')
storage_bucket = os.environ.get(
    'STORAGE_BUCKET', 'sd-app-dev-329dd.appspot.com')
env_name = os.environ.get('ENV_NAME', 'development')
gc_service_account_filename = os.environ.get(
    'GC_SERVICE_ACCOUNT_FILENAME', 'upload-only-dev.json')
is_s3_storage = os.environ.get('STORAGE') == 's3'
sync_progress_value_interval = 1000  # milliseconds between progress polls
cred = credentials.Certificate(
    f'./gc-service-accounts/{gc_service_account_filename}')
initialize_app(cred, {'storageBucket': storage_bucket})
google_storage = storage.bucket()
s3 = boto3.client('s3',
                  aws_access_key_id=os.environ.get('AWS_ACCESS_KEY'),
                  aws_secret_access_key=os.environ.get('AWS_SECRET_KEY'))
bucket_name = os.environ.get('AWS_S3_BUCKET')
syslog = SysLogHandler(address=('logs.papertrailapp.com', 25104))
log_format = '%(asctime)s SD: %(message)s'
formatter = logging.Formatter(log_format, datefmt='%b %d %H:%M:%S')
syslog.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(syslog)
logger.setLevel(logging.INFO)
sentry_sdk.init(
    dsn="https://5a9f4a774b78460f9762480e9bf17a57@o4504769730576384.ingest.sentry.io/4504769740406784",
    traces_sample_rate=1.0,
    environment=env_name
)

class TimeMeasurer:
    def __init__(self):
        self.start_time = time.time()

    def log_time_end(self, name):
        end_time = time.time()
        elapsed_time = (end_time - self.start_time) * 1000
        logger.info(f"Operation: {name}, Elapsed time: {elapsed_time:.2f} ms")
safety_model_id = "./stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None

def check_lora(path):
    try:
        result = safetensors_worker.CheckLoRA({}, path)
        return result == 0
    except Exception as e:
        logger.error(f"Lora check error: {e}")
        return False

def load_lora(url):
    destination_dir = '/stable-diffusion-webui/models/Lora'
    try:
        # Ensure the destination directory exists
        os.makedirs(destination_dir, exist_ok=True)
        # Extract the file name from the URL
        file_name = os.path.basename(url)
        # Construct the full path for the destination file
        destination_file = os.path.join(destination_dir, file_name)
        # If the file already exists, skip the download and return its name
        if os.path.exists(destination_file):
            logger.info(f"File '{file_name}' already exists in '{destination_dir}'.")
            return file_name
        # Download the file from the URL
        response = requests.get(url)
        # Check if the download was successful (status code 200)
        if response.status_code == 200:
            with open(destination_file, 'wb') as f:
                f.write(response.content)
            logger.info(f"File '{file_name}' downloaded and saved to '{destination_dir}'.")
            # Optionally verify the downloaded file is a valid LoRA
            # if not check_lora(destination_file):
            #     os.remove(destination_file)
            #     raise ValueError(f"Failed to verify file '{file_name}' as a valid LoRA file.")
            return file_name
        raise ValueError(f"Failed to download file from '{url}'. Status code: {response.status_code}")
    except ValueError:
        # Don't re-wrap errors we raised ourselves
        raise
    except Exception as e:
        raise ValueError(f"An error occurred while loading LoRA '{url}': {e}")

def refresh_loras():
    try:
        url = 'http://127.0.0.1:3000/sdapi/v1/refresh-loras'
        response = automatic_session.post(url)
        # Check if the request was successful (status code 200)
        if response.status_code == 200:
            logger.info("Refreshed Loras successfully.")
            return True
        else:
            logger.error(f"Failed to refresh Loras. Status code: {response.status_code}")
            return False
    except Exception as e:
        logger.error(f"An error occurred while refreshing Loras: {e}")
        return False

xl_img2img_initialized = False

def init_xl_img2img(model):
    global xl_img2img_initialized
    if xl_img2img_initialized:
        return
    try:
        # Warm up the XL img2img path with a minimal one-step txt2img call
        url = 'http://127.0.0.1:3000/sdapi/v1/txt2img'
        data = {
            "steps": 1,
            "cfg_scale": 1,
            "override_settings": {"sd_model_checkpoint": model}
        }
        txt2img_time = TimeMeasurer()
        automatic_session.post(url, json=data)
        txt2img_time.log_time_end("txt2img")
        xl_img2img_initialized = True
    except Exception as e:
        logger.error(f"Failed to init XL img2img: {e}")

def init_safety_checker():
    global safety_feature_extractor, safety_checker
    if safety_feature_extractor is None:
        safety_feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")

def check_safety(x_image):
    safety_checker_input = safety_feature_extractor(
        x_image, return_tensors="pt")
    # Only clip_input drives the NSFW classification; the images argument is
    # used by the checker solely for blacking out flagged outputs, which this
    # handler ignores.
    _, has_nsfw_concept = safety_checker(
        images=safety_checker_input.pixel_values, clip_input=safety_checker_input.pixel_values)
    return has_nsfw_concept[0]

def update_progress_status(job_id, progress_status, progress_value=None):
    try:
        params = {
            "id": job_id,
            "apiKey": api_key,
            "progressStatus": progress_status,
            "progressValue": progress_value
        }
        requests.post(f'{api_base_url}/updateJobProgressStatus', json=params)
    except Exception as e:
        logger.exception(f'update_progress_status: Exception {e}')

sd_available = False

def check_sd_availability():
    global sd_available
    while True:
        try:
            response = automatic_session.get(progress_url)
            if response.status_code != 200:
                raise Exception(
                    f"API is not available. Status code: {response.status_code}")
            sd_available = True
            logger.info(
                f'check_sd_availability: Done. status_code: {response.status_code} content: {response.content}')
            return
        except requests.exceptions.RequestException as e:
            logger.warning(
                f"check_sd_availability: API is not available, retrying... ({e})")
        except Exception as e:
            logger.exception(f'check_sd_availability: Exception {e}')
        time.sleep(sync_progress_value_interval / 1000)

def upload_image(image_base64):
    try:
        image_bytes = base64.b64decode(image_base64)
        pil_image = Image.open(BytesIO(image_bytes))
        is_nsfw = check_safety(pil_image)
        postfix = '_nsfw' if is_nsfw else ''
        webp_image = BytesIO()
        pil_image.save(webp_image, format="WebP", lossless=convert_is_lossless,
                       quality=convert_quality, method=convert_method)
        webp_image.seek(0)  # reset the BytesIO object to the beginning
        image_name = f'generated/{uuid.uuid4()}{postfix}.webp'
        # Route the upload to the configured storage backend
        if is_s3_storage:
            upload_image_s3(image_name, webp_image)
        else:
            upload_image_google_storage(image_name, webp_image)
        return f'{images_base_url}/{image_name}'
    except Exception as e:
        logger.exception(f'Upload exception: {e}')
        raise

def upload_image_google_storage(image_name, webp_image):
    blob = google_storage.blob(image_name)
    blob.upload_from_file(webp_image, content_type='image/webp')
    blob.make_public()
    logger.info('GOOGLE UPLOADED')

def upload_image_s3(image_name, webp_image):
    s3.upload_fileobj(
        webp_image,
        bucket_name,
        image_name,
        ExtraArgs={'ContentType': 'image/webp'}
    )
    logger.info('S3 UPLOADED')

def upload_images_in_parallel(images):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        uploaded_urls = list(executor.map(upload_image, images))
    return uploaded_urls

def remove_base64_prefix(base64_string):
    # Data URIs look like 'data:image/png;base64,<payload>'; the comma
    # separates the MIME prefix from the payload
    if ',' in base64_string:
        # Split on the first comma and keep the payload
        return base64_string.split(',', 1)[1]
    # No prefix present; return the string unchanged
    return base64_string

def prepare_mask(image):
    # Convert the image to RGBA if it's not already
    if image.mode != 'RGBA':
        image = image.convert('RGBA')
    # Process each pixel
    pixels = image.load()
    for i in range(image.width):
        for j in range(image.height):
            r, g, b, a = pixels[i, j]
            # Map brightness to transparency:
            # black (0, 0, 0) becomes fully opaque white (255, 255, 255, 255),
            # white (255, 255, 255) becomes fully transparent (255, 255, 255, 0),
            # and shades of gray become partially transparent white.
            new_alpha = 255 - int((r + g + b) / 3)
            pixels[i, j] = (255, 255, 255, new_alpha)
    return image
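
# An approximately equivalent, much faster alternative (illustrative sketch,
# not the shipped implementation): PIL channel ops avoid the per-pixel Python
# loop. Note convert('L') uses ITU-R 601 luma weights rather than the plain
# RGB mean, so alpha values differ slightly; requires `from PIL import ImageOps`.
#
# def prepare_mask_vectorized(image):
#     alpha = ImageOps.invert(image.convert('L'))  # dark pixels -> high alpha
#     white = Image.new('RGBA', image.size, (255, 255, 255, 0))
#     white.putalpha(alpha)
#     return white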

def pil_image_to_base64(pil_img, format="PNG"):
    buffered = BytesIO()
    pil_img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    # Keep the data URI's MIME type in sync with the actual encoding format
    return f'data:image/{format.lower()};base64,{img_str}'

def get_removed_bg_mask(base64_image):
    decoded_image = base64.b64decode(base64_image)
    with Image.open(BytesIO(decoded_image)) as input_image:
        session = new_session("silueta")
        # output_mask = remove(input_image, session=session, only_mask=True, post_process_mask=True, alpha_matting=True, alpha_matting_erode_size=5)
        output_mask = remove(input_image, session=session, only_mask=True)
        processed_image = prepare_mask(output_mask)
        return pil_image_to_base64(processed_image)

def get_image_base64_by_url(image_url):
    response = requests.get(image_url)
    return base64.b64encode(response.content).decode('utf-8')

def sync_progress_value(job_id, params):
    try:
        response = automatic_session.get(progress_url)
        result = response.json()
        state = result["state"]
        adetailers_count = len(params.get("alwayson_scripts", {}).get("adetailer", {}).get("args", []))
        images_count = params.get("batch_size", 1)
        adetailer_jobs_count = adetailers_count * images_count
        job_no = state.get('job_no', 0)
        job_count = state.get('job_count', 0)
        job_progress = result["progress"]
        with_adetailer = adetailer_jobs_count > 0
        if job_progress is not None and job_progress > 0.01 and job_count > 0:
            logger.info(f'STATE: {result["state"]}, textinfo: {result["textinfo"]}')
            if job_count > 1 and with_adetailer:
                # Map the adetailer pass onto the second half of the progress
                # bar, clamping the raw fraction to [0.1, 0.9] before scaling
                adetailer_progress = job_no / adetailer_jobs_count
                adetailer_progress = max(0.1, min(adetailer_progress, 0.9))
                final_adetailer_progress = 0.5 + adetailer_progress / 2
                logger.info(f'job_no: {job_no}, job_count: {job_count}, adetailer_jobs_count: {adetailer_jobs_count}, final_adetailer_progress: {final_adetailer_progress}')
                update_progress_status(
                    job_id=job_id, progress_status="FIXING_DETAILS", progress_value=final_adetailer_progress)
            else:
                # Generation fills the first half of the bar when an adetailer
                # pass will follow, otherwise the whole bar
                generating_progress = job_progress / 2 if with_adetailer else job_progress
                update_progress_status(
                    job_id=job_id, progress_status="GENERATING", progress_value=generating_progress)
    except Exception as e:
        logger.exception(f'sync_progress_value: Exception {e}')

def handler(event):
    """RunPod entrypoint: dispatches the job to the local WebUI API,
    reports progress, and uploads any resulting images."""
    global sd_available
    progress_sync_scheduler = BackgroundScheduler()
    scheduler_started = False
    try:
        job_id = event.get("id")
        input_event = event.get('input', {})
        endpoint = input_event.get("endpoint", "txt2img")
        params = input_event.get("params", {})
        method = input_event.get("method", "post")
        is_xl = input_event.get("is_xl", False)
        images_field = input_event.get("images_field", "images")
        image_field = input_event.get("image_field", "image")
        lora_urls = params.get("lora_urls")
        is_fix_outpaint = params.get('is_fix_outpaint', False)
        logger.info(f'is_xl {is_xl}')
        if endpoint == "img2bgmask":
            init_image = params.get('image')
            if init_image.startswith("https"):
                init_image = get_image_base64_by_url(init_image)
            mask = get_removed_bg_mask(remove_base64_prefix(init_image))
            return {"mask": mask}
        if not sd_available:
            update_progress_status(
                job_id=job_id, progress_status="SERVER_STARTING")
            check_sd_availability()
        logger.info('run handler')
        if is_xl and endpoint == 'img2img':
            model = params.get('override_settings', {}).get('sd_model_checkpoint')
            init_xl_img2img(model)
        if endpoint == 'txt2img' or endpoint == 'img2img':
            scheduler_started = True
            progress_sync_scheduler.add_job(
                sync_progress_value, 'interval', seconds=sync_progress_value_interval / 1000, args=(job_id, params))
            progress_sync_scheduler.start()
        # Make sure the safety checker is loaded before any uploads run
        init_safety_checker()
        if endpoint == "extra-single-image":
            image_url = params.get('image')
            params['image'] = get_image_base64_by_url(image_url)
        if endpoint == "img2img":
            init_image = params.get('init_images', [])[0]
            if init_image.startswith("https"):
                image_base64 = get_image_base64_by_url(init_image)
                params['init_images'] = [image_base64]
        if lora_urls is not None:
            for lora_url in lora_urls:
                if lora_url is not None:
                    load_lora(lora_url)
            refresh_loras()
        host = f'http://127.0.0.1:3000/sdapi/v1/{endpoint}'
        request_time = TimeMeasurer()
        response = automatic_session.request(method=method, url=host, json=params)
        result = response.json()
        request_time.log_time_end("request_time")
        if scheduler_started:
            progress_sync_scheduler.shutdown()
            scheduler_started = False  # avoid a second shutdown in the except block
        if is_fix_outpaint:
            first_result = copy.deepcopy(result)
            update_progress_status(
                job_id=job_id, progress_status="FIXING_OUTPAINT")
            new_params = copy.deepcopy(params)
            new_params['denoising_strength'] = params['outpaint_fix_noise_strength']
            new_params['mask_blur'] = 30
            new_params['inpainting_fill'] = 1
            new_params['init_images'] = first_result[images_field]
            response = automatic_session.request(method=method, url=host, json=new_params)
            result = response.json()
        if result is not None and images_field in result:
            update_progress_status(
                job_id=job_id, progress_status="UPLOADING_IMAGES")
            result[images_field] = upload_images_in_parallel(
                result[images_field])
        if result is not None and image_field in result:
            update_progress_status(
                job_id=job_id, progress_status="UPLOADING_IMAGES")
            result[image_field] = upload_image(result[image_field])
        return result
    except Exception as e:
        logger.exception(f'An exception occurred: {e}')
        if scheduler_started:
            progress_sync_scheduler.shutdown()
        raise
runpod.serverless.start({"handler": handler})
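
# Illustrative event shape this handler expects (field names taken from the
# parsing in handler() above; values are examples only):
#
# {
#     "id": "job-123",
#     "input": {
#         "endpoint": "txt2img",  # or img2img, img2bgmask, extra-single-image
#         "method": "post",
#         "is_xl": False,
#         "images_field": "images",
#         "image_field": "image",
#         "params": {
#             "prompt": "a photo of a cat",
#             "batch_size": 1,
#             "lora_urls": ["https://example.com/my-lora.safetensors"],
#             "is_fix_outpaint": False
#         }
#     }
# }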