import runpod
import base64
import subprocess
import concurrent.futures
import requests
from requests.adapters import HTTPAdapter, Retry
import time
import logging
import socket
import json
import os
import uuid
import copy
from io import BytesIO
from PIL import Image, ImageOps
from firebase_admin import credentials, initialize_app, storage
from logging.handlers import SysLogHandler
import sentry_sdk
import torch
import boto3
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from apscheduler.schedulers.background import BackgroundScheduler
import safetensors_worker
from rembg import remove, new_session

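# HTTP session for the local Automatic1111 WebUI API, with retries on transient 502/503/504 responses.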
progress_url = "http://127.0.0.1:3000/sdapi/v1/progress"
automatic_session = requests.Session()
retries = Retry(total=30, backoff_factor=0.1, status_forcelist=[502, 503, 504])
automatic_session.mount('http://', HTTPAdapter(max_retries=retries))

convert_quality = int(os.environ.get('WEBP_CONVERT_QUALITY', 90))
convert_method = int(os.environ.get('WEBP_CONVERT_METHOD', 5))
convert_is_lossless = os.environ.get('WEBP_CONVERT_IS_LOSSLESS') == 'True'
images_base_url = os.environ.get(
    'IMAGES_BASE_URL', 'https://images-dev.infero.ai')
api_base_url = os.environ.get('API_BASE_URL')
api_key = os.environ.get('API_KEY')
storage_bucket = os.environ.get(
    'STORAGE_BUCKET', 'sd-app-dev-329dd.appspot.com')
env_name = os.environ.get('ENV_NAME', 'development')
gc_service_account_filename = os.environ.get(
    'GC_SERVICE_ACCOUNT_FILENAME', 'upload-only-dev.json')
is_s3_storage = os.environ.get('STORAGE') == 's3'
sync_progress_value_interval = 1000  # interval between progress updates, in milliseconds

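# Upload targets: the Firebase (Google Cloud Storage) bucket and an S3 bucket.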
cred = credentials.Certificate(
    f'./gc-service-accounts/{gc_service_account_filename}')
initialize_app(cred, {'storageBucket': storage_bucket})
google_storage = storage.bucket()

s3 = boto3.client('s3',
                  aws_access_key_id=os.environ.get('AWS_ACCESS_KEY'),
                  aws_secret_access_key=os.environ.get('AWS_SECRET_KEY'))
bucket_name = os.environ.get('AWS_S3_BUCKET')

syslog = SysLogHandler(address=('logs.papertrailapp.com', 25104))
log_format = '%(asctime)s SD: %(message)s'
formatter = logging.Formatter(log_format, datefmt='%b %d %H:%M:%S')
syslog.setFormatter(formatter)

logger = logging.getLogger()
logger.addHandler(syslog)
logger.setLevel(logging.INFO)

sentry_sdk.init(
    dsn="https://5a9f4a774b78460f9762480e9bf17a57@o4504769730576384.ingest.sentry.io/4504769740406784",
    traces_sample_rate=1.0,
    environment=env_name
)

class TimeMeasurer:
    """Measure and log how long a named operation took, in milliseconds."""

    def __init__(self):
        self.start_time = time.time()

    def log_time_end(self, name):
        end_time = time.time()
        elapsed_time = (end_time - self.start_time) * 1000
        logger.info(f"Operation: {name}, Elapsed time: {elapsed_time:.2f} ms")


safety_model_id = "./stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None


def check_lora(path):
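    """Return True if safetensors_worker reports the LoRA file at `path` as valid (check result 0)."""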
    try:
        result = safetensors_worker.CheckLoRA({}, path)
        return result == 0
    except Exception as e:
        logger.error(f"Lora check error: {e}")
        return False


def load_lora(url):
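    """Download the LoRA file at `url` into the WebUI Lora directory and return its file name.

    If a file with the same name is already present, the download is skipped.
    """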
    destination_dir = '/stable-diffusion-webui/models/Lora'
    try:
        os.makedirs(destination_dir, exist_ok=True)

        file_name = os.path.basename(url)
        destination_file = os.path.join(destination_dir, file_name)

        if os.path.exists(destination_file):
            logger.info(f"File '{file_name}' already exists in '{destination_dir}'.")
            return file_name

        response = requests.get(url)

        if response.status_code == 200:
            with open(destination_file, 'wb') as f:
                f.write(response.content)
            logger.info(f"File '{file_name}' downloaded and saved to '{destination_dir}'.")
            return file_name
        else:
            raise ValueError(f"Failed to download file from '{url}'. Status code: {response.status_code}")
    except Exception as e:
        raise ValueError(f"An error occurred: {e}")


def refresh_loras():
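    """Ask the WebUI API to rescan the Lora directory so newly downloaded files become usable."""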
    try:
        url = 'http://127.0.0.1:3000/sdapi/v1/refresh-loras'
        response = automatic_session.post(url)

        if response.status_code == 200:
            logger.info("Refreshed Loras successfully.")
            return True
        else:
            logger.error(f"Failed to refresh Loras. Status code: {response.status_code}")
            return False
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        return False


xl_img2img_initialized = False


def init_xl_img2img(model):
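    """Run a one-step txt2img request with the given checkpoint once, so the SDXL model is loaded before the first img2img call."""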
    global xl_img2img_initialized
    if xl_img2img_initialized:
        return
    try:
        url = 'http://127.0.0.1:3000/sdapi/v1/txt2img'
        data = {
            "steps": 1,
            "cfg_scale": 1,
            "override_settings": {"sd_model_checkpoint": model}
        }
        txt2img_time = TimeMeasurer()
        response = automatic_session.post(url, json=data)
        txt2img_time.log_time_end("txt2img")
        xl_img2img_initialized = True
    except Exception as e:
        logger.error(f"Failed to init XL img2img: {e}")


def init_safety_checker():
    global safety_feature_extractor, safety_checker
    if safety_feature_extractor is None:
        safety_feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")


def check_safety(x_image):
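    """Return whether the Stable Diffusion safety checker flags the given PIL image as NSFW."""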
    safety_checker_input = safety_feature_extractor(
        x_image, return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(
        images=safety_checker_input.pixel_values, clip_input=safety_checker_input.pixel_values)

    return has_nsfw_concept[0]


def update_progress_status(job_id, progress_status, progress_value=None):
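    """Post the job's progress status (and optional progress value) to the backend API; errors are logged and swallowed."""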
    try:
        params = {
            "id": job_id,
            "apiKey": api_key,
            "progressStatus": progress_status,
            "progressValue": progress_value
        }
        requests.request(
            method='post', url=f'{api_base_url}/updateJobProgressStatus', json=params)
    except Exception as e:
        logger.exception(f'update_progress_status: Exception {e}')


sd_available = False


def check_sd_availability():
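    """Block until the local WebUI API answers the progress endpoint, then mark sd_available."""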
    global sd_available

    while True:
        try:
            response = automatic_session.get(progress_url)

            if response.status_code != 200:
                raise Exception(
                    f"API is not available. Status code: {response.status_code}")

            sd_available = True

            logger.info(
                f'check_api_availability: Done. status_code: {response.status_code} content: {response.content}')
            return
        except requests.exceptions.RequestException as e:
            print(
                f"check_api_availability: API is not available, retrying... ({e})")
        except Exception as e:
            logger.exception(f'check_api_availability: Exception {e}')
        time.sleep(sync_progress_value_interval/1000)


def upload_image(image_base64):
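    """Decode a base64 image, flag NSFW content in the file name, convert to WebP, upload to S3, and return the public URL."""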
    try:
        image_bytes = base64.b64decode(image_base64)
        pil_image = Image.open(BytesIO(image_bytes))

        is_nsfw = check_safety(pil_image)

        postfix = ''
        if is_nsfw:
            postfix = '_nsfw'

        webp_image = BytesIO()
        pil_image.save(webp_image, format="WebP", lossless=convert_is_lossless,
                       quality=convert_quality, method=convert_method)
        webp_image.seek(0)
        copy_image = BytesIO(webp_image.getvalue())
        copy_image.seek(0)

        image_name = f'generated/{uuid.uuid4()}{postfix}.webp'

        upload_image_s3(image_name, webp_image)

        return f'{images_base_url}/{image_name}'
    except Exception as e:
        logger.exception(f'Upload exception: {e}')
        raise


def upload_image_google_storage(image_name, webp_image):
    blob = google_storage.blob(image_name)
    blob.upload_from_file(webp_image, content_type='image/webp')
    blob.make_public()
    logger.info('GOOGLE UPLOADED')


def upload_image_s3(image_name, webp_image):
    s3.upload_fileobj(
        webp_image,
        bucket_name,
        image_name,
        ExtraArgs={'ContentType': 'image/webp'}
    )
    logger.info('S3 UPLOADED')


def upload_images_in_parallel(images):
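    """Upload a list of base64 images concurrently and return their public URLs in the same order."""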
    with concurrent.futures.ThreadPoolExecutor() as executor:
        uploaded_urls = list(executor.map(upload_image, images))
    return uploaded_urls


def remove_base64_prefix(base64_string):
    if ',' in base64_string:
        return base64_string.split(',', 1)[1]
    else:
        return base64_string


def prepare_mask(image):
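    """Turn a grayscale mask into a white RGBA image whose alpha channel is the inverted pixel brightness (dark areas become opaque)."""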
    if image.mode != 'RGBA':
        image = image.convert('RGBA')

    pixels = image.load()
    for i in range(image.width):
        for j in range(image.height):
            r, g, b, a = pixels[i, j]
            new_alpha = 255 - int((r + g + b) / 3)
            pixels[i, j] = (255, 255, 255, new_alpha)

    return image


def pil_image_to_base64(pil_img, format="PNG"):
    buffered = BytesIO()
    pil_img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f'data:image/png;base64,{img_str}'


def get_removed_bg_mask(base64_image):
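    """Run rembg (the "silueta" model) on a base64 image and return an inverted background mask as a base64 data URI."""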
    decoded_image = base64.b64decode(base64_image)
    with Image.open(BytesIO(decoded_image)) as input_image:
        session = new_session("silueta")

        output_mask = remove(input_image, session=session, only_mask=True)
        processed_image = prepare_mask(output_mask)

        return pil_image_to_base64(processed_image)


def get_image_base64_by_url(image_url):
    response = requests.get(image_url)
    return base64.b64encode(response.content).decode('utf-8')


def sync_progress_value(job_id, params):
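    """Poll the WebUI progress endpoint and push a normalized progress value to the backend.

    When ADetailer jobs are present, generation reports 0-0.5 ("GENERATING") and the
    detail-fixing passes report roughly 0.55-0.95 ("FIXING_DETAILS"); otherwise the raw
    progress is reported as-is.
    """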
    try:
        response = automatic_session.get(progress_url)
        result = response.json()

        state = result["state"]
        adetailers_count = len(params.get("alwayson_scripts", {}).get("adetailer", {}).get("args", []))
        images_count = params.get("batch_size", 1)
        adetailer_jobs_count = adetailers_count * images_count
        job_no = state.get('job_no', 0)
        job_count = state.get('job_count', 0)
        job_progress = result["progress"]
        with_adetailer = adetailer_jobs_count > 0

        if job_progress is not None and job_progress > 0.01 and job_count > 0:
            logger.info(f'STATE: {result["state"]}, textinfo: {result["textinfo"]}')

            if job_count > 1 and with_adetailer:
                adetailer_progress = job_no / adetailer_jobs_count
                if adetailer_progress < 0.1:
                    adetailer_progress = 0.1
                if adetailer_progress > 0.9:
                    adetailer_progress = 0.9
                final_adetailer_progress = 0.5 + adetailer_progress / 2
                logger.info(f'job_no: {job_no}, job_count: {job_count}, adetailer_jobs_count: {adetailer_jobs_count}, final_adetailer_progress: {final_adetailer_progress}')

                update_progress_status(
                    job_id=job_id, progress_status="FIXING_DETAILS", progress_value=final_adetailer_progress)
            else:
                generating_progress = job_progress / 2 if with_adetailer else job_progress
                update_progress_status(
                    job_id=job_id, progress_status="GENERATING", progress_value=generating_progress)
    except Exception as e:
        logger.exception(f'sync_progress_value: Exception {e}')


def handler(event):
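    """RunPod serverless entry point.

    Proxies the job to the local WebUI API endpoint named in the input, pushes progress
    to the backend on a background scheduler while txt2img/img2img runs, and replaces
    returned base64 images with uploaded image URLs. The img2bgmask endpoint is handled
    locally with rembg.
    """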
    global sd_available

    progress_sync_scheduler = BackgroundScheduler()
    scheduler_started = False

    try:
        job_id = event.get("id")
        input_event = event.get('input', {})
        endpoint = input_event.get("endpoint", "txt2img")
        params = input_event.get("params", {})
        method = input_event.get("method", "post")
        is_xl = input_event.get("is_xl", False)
        images_field = input_event.get("images_field", "images")
        image_field = input_event.get("image_field", "image")
        lora_urls = params.get("lora_urls")
        is_fix_outpaint = params.get('is_fix_outpaint', False)
        logger.info(f'is_xl {is_xl}')

        if endpoint == "img2bgmask":
            init_image = params.get('image')

            if init_image.startswith("https"):
                init_image = get_image_base64_by_url(init_image)
            mask = get_removed_bg_mask(remove_base64_prefix(init_image))
            return {"mask": mask}

        if not sd_available:
            update_progress_status(
                job_id=job_id, progress_status="SERVER_STARTING")
            check_sd_availability()
        logger.info('run handler')

        if is_xl and endpoint == 'img2img':
            model = params.get('override_settings', {}).get('sd_model_checkpoint')
            init_xl_img2img(model)

        if endpoint == 'txt2img' or endpoint == 'img2img':
            scheduler_started = True
            progress_sync_scheduler.add_job(
                sync_progress_value, 'interval', seconds=sync_progress_value_interval/1000, args=(job_id, params))
            progress_sync_scheduler.start()

        init_safety_checker()

        if endpoint == "extra-single-image":
            image_url = params.get('image')
            params['image'] = get_image_base64_by_url(image_url)

        if endpoint == "img2img":
            init_image = params.get('init_images', [])[0]

            if init_image.startswith("https"):
                image_base64 = get_image_base64_by_url(init_image)
                params['init_images'] = [image_base64]

        if lora_urls is not None:
            for lora_url in lora_urls:
                if lora_url is not None:
                    load_lora(lora_url)
            refresh_loras()

        host = f'http://127.0.0.1:3000/sdapi/v1/{endpoint}'

        request_time = TimeMeasurer()
        response = automatic_session.request(method=method, url=host, json=params)

        result = response.json()
        request_time.log_time_end("request_time")

        if endpoint == 'txt2img' or endpoint == 'img2img':
            progress_sync_scheduler.shutdown()
            scheduler_started = False

        if is_fix_outpaint:
            first_result = copy.deepcopy(result)
            update_progress_status(
                job_id=job_id, progress_status="FIXING_OUTPAINT")
            new_params = copy.deepcopy(params)
            new_params['denoising_strength'] = params['outpaint_fix_noise_strength']
            new_params['mask_blur'] = 30
            new_params['inpainting_fill'] = 1
            new_params['init_images'] = first_result[images_field]
            response = automatic_session.request(method=method, url=host, json=new_params)
            result = response.json()

        if result is not None and images_field in result:
            update_progress_status(
                job_id=job_id, progress_status="UPLOADING_IMAGES")
            result[images_field] = upload_images_in_parallel(
                result[images_field])

        if result is not None and image_field in result:
            update_progress_status(
                job_id=job_id, progress_status="UPLOADING_IMAGES")
            result[image_field] = upload_image(result[image_field])

        return result

    except Exception as e:
        logger.exception(f'An exception occurred: {e}')
        if scheduler_started:
            progress_sync_scheduler.shutdown()
        raise


runpod.serverless.start({"handler": handler})