File size: 17,055 Bytes
76852be
 
 
 
 
2fdc5aa
76852be
 
 
 
 
 
294c364
76852be
6ccbaf5
76852be
 
 
 
307dca0
76852be
 
 
f0d08e9
8b064fc
76852be
 
236b291
 
 
76852be
 
 
 
5c615f1
 
 
 
 
 
76852be
5c615f1
 
e48285e
11a95be
76852be
5c615f1
 
76852be
e48285e
76852be
307dca0
 
 
 
 
76852be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c615f1
4449bf5
 
 
 
 
 
0a372fd
 
4449bf5
5c615f1
607075f
 
 
 
f0d08e9
 
 
 
 
f8b3240
f0d08e9
5c3ff40
f8b3240
 
5c3ff40
 
 
 
 
 
 
 
 
 
 
 
f8b3240
5c3ff40
 
 
 
 
 
 
 
 
f8b3240
5c3ff40
 
41e6913
 
 
5c3ff40
 
 
 
 
 
 
 
 
 
7c2da7f
5c3ff40
 
 
f8b3240
5c3ff40
 
f8b3240
5c3ff40
 
f8b3240
5c3ff40
5c615f1
d9bbbdf
 
7c2da7f
d9bbbdf
 
 
7c2da7f
 
 
 
 
d9bbbdf
7c2da7f
d9bbbdf
7c2da7f
d9bbbdf
 
7c2da7f
 
 
607075f
 
 
bd1a3cd
 
5c615f1
607075f
 
5c615f1
 
 
 
607075f
 
 
5c615f1
21632b2
5c615f1
 
 
 
 
21632b2
5c615f1
 
 
 
 
 
 
 
 
76852be
ad83a39
5c615f1
327dec3
76852be
 
236b291
9d59894
 
5c615f1
 
327dec3
5c615f1
 
 
 
76852be
 
5c615f1
 
76852be
 
ad83a39
76852be
5c615f1
76852be
 
 
 
 
 
5c615f1
76852be
 
 
 
 
5c615f1
 
 
c489323
 
76852be
 
e48285e
001f93a
76852be
 
 
 
 
 
e48285e
 
 
 
f634c57
e48285e
 
307dca0
 
 
 
 
 
6c4432e
5c615f1
76852be
 
 
 
 
a8a9b70
 
 
 
 
 
 
 
 
94c9fad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97e1982
e4f71ac
 
293e141
637dd12
 
94c9fad
e4f71ac
94c9fad
e4f71ac
94c9fad
76852be
 
 
8f024bc
0f4400e
236b291
0f4400e
8f024bc
7d98ac6
8741d34
8f024bc
 
14775e7
c4dcb9f
8f024bc
5d609d1
76852be
1c77893
db3ea2c
7d98ac6
72978f5
07eafc9
 
 
1b72445
 
5d609d1
 
8f024bc
7d98ac6
21632b2
7d98ac6
5d609d1
7d98ac6
5d609d1
0f4400e
 
76852be
5c615f1
76852be
5c615f1
 
 
6c6b3e9
5c615f1
76852be
 
ce95207
 
 
 
18b3978
ce95207
 
7c69f0a
ce95207
66b39ee
7a447d5
542b2b6
 
 
 
94c9fad
a8a9b70
97e1982
542b2b6
5c615f1
 
 
ad83a39
327dec3
c95c330
7c2da7f
 
 
 
3d05a7b
6c6b3e9
3d05a7b
 
 
5c615f1
c595c8a
 
76852be
 
94c9fad
76852be
 
 
 
 
94c9fad
76852be
c407c22
 
7c69f0a
 
 
 
f8b3240
 
76852be
c7d0bf0
 
236b291
76852be
 
c7d0bf0
76852be
3d05a7b
 
76852be
7fcb361
558a56a
294c364
 
af80463
7fcb361
cd988db
a721fed
558a56a
236b291
dfcb7a3
b8895a4
89fb56f
5c615f1
 
 
 
76852be
89fb56f
5c615f1
 
76852be
 
 
 
 
 
6c6b3e9
2c4a7c2
76852be
 
5c615f1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
import runpod
import base64
import subprocess
import concurrent
import requests
from requests.adapters import HTTPAdapter, Retry
import time
import logging
import socket
import json
import os
import uuid
import copy
from io import BytesIO
from PIL import Image, ImageOps
from firebase_admin import credentials, initialize_app, storage
from logging.handlers import SysLogHandler
import sentry_sdk
import torch
import boto3
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from apscheduler.schedulers.background import BackgroundScheduler
import safetensors_worker
from rembg import remove, new_session

# Progress endpoint of the local WebUI API (port 3000). It is polled for job
# progress in sync_progress_value() and as a readiness probe in
# check_sd_availability().
progress_url = "http://127.0.0.1:3000/sdapi/v1/progress"
# Shared HTTP session for every call to the local WebUI API.
automatic_session = requests.Session()
# Retry up to 30 times on 502/503/504 with a short exponential backoff.
# NOTE(review): with urllib3's default `allowed_methods`, status-based retries
# are not applied to POST requests — confirm whether POSTs (txt2img,
# refresh-loras) are expected to be retried on these statuses.
retries = Retry(total=30, backoff_factor=0.1, status_forcelist=[502, 503, 504])
automatic_session.mount('http://', HTTPAdapter(max_retries=retries))

# --- Runtime configuration, read from environment variables at import time ---
# WebP encoding options used by upload_image().
convert_quality = int(os.environ.get('WEBP_CONVERT_QUALITY', 90))
convert_method = int(os.environ.get('WEBP_CONVERT_METHOD', 5))
# Only the exact string 'True' enables lossless encoding.
convert_is_lossless = os.environ.get('WEBP_CONVERT_IS_LOSSLESS') == 'True'
# Public URL prefix that upload_image() prepends to uploaded object keys.
images_base_url = os.environ.get(
    'IMAGES_BASE_URL', 'https://images-dev.infero.ai')
# Backend API and key used by update_progress_status() to report job progress.
api_base_url = os.environ.get('API_BASE_URL')
api_key = os.environ.get('API_KEY')
# Firebase / Google Cloud Storage bucket.
storage_bucket = os.environ.get(
    'STORAGE_BUCKET', 'sd-app-dev-329dd.appspot.com')
# Deployment environment name, reported to Sentry.
env_name = os.environ.get('ENV_NAME', 'development')
# Service-account JSON file name inside ./gc-service-accounts/.
gc_service_account_filename = os.environ.get(
    'GC_SERVICE_ACCOUNT_FILENAME', 'upload-only-dev.json')
# NOTE(review): is_s3_storage is not read anywhere in this file; upload_image()
# always uploads to S3 — confirm whether this flag was meant to pick the backend.
is_s3_storage = os.environ.get('STORAGE') == 's3'
# Polling period in milliseconds for progress sync and the availability probe.
sync_progress_value_interval = 1000

# Firebase Admin app authenticated with the service-account JSON selected
# above; google_storage is the app's default bucket (storage_bucket).
cred = credentials.Certificate(
    f'./gc-service-accounts/{gc_service_account_filename}')
initialize_app(cred, {'storageBucket': storage_bucket})
google_storage = storage.bucket()

# S3 client and target bucket used by upload_image_s3(); credentials and
# bucket name come from environment variables.
s3 = boto3.client('s3',
                  aws_access_key_id=os.environ.get('AWS_ACCESS_KEY'),
                  aws_secret_access_key=os.environ.get('AWS_SECRET_KEY'))
bucket_name = os.environ.get('AWS_S3_BUCKET')

# Ship logs to Papertrail over syslog; each record is prefixed with "SD:".
syslog = SysLogHandler(address=('logs.papertrailapp.com', 25104))
# NOTE(review): this module-level name shadows the builtin `format` for the
# rest of the module; consider renaming it (e.g. log_format).
format = '%(asctime)s SD: %(message)s'
formatter = logging.Formatter(format, datefmt='%b %d %H:%M:%S')
syslog.setFormatter(formatter)

# Attach the handler to the root logger at INFO, so records that propagate to
# the root logger at INFO or above are shipped as well.
logger = logging.getLogger()
logger.addHandler(syslog)
logger.setLevel(logging.INFO)

# Error reporting / performance tracing. The DSN and trace sample rate can be
# overridden per deployment via SENTRY_DSN and SENTRY_TRACES_SAMPLE_RATE; the
# defaults are the previously hard-coded values, so behaviour is unchanged
# when those variables are not set.
sentry_sdk.init(
    dsn=os.environ.get(
        'SENTRY_DSN',
        "https://5a9f4a774b78460f9762480e9bf17a57@o4504769730576384.ingest.sentry.io/4504769740406784"),
    traces_sample_rate=float(
        os.environ.get('SENTRY_TRACES_SAMPLE_RATE', 1.0)),
    environment=env_name
)


class TimeMesurer:
    """Measure the time elapsed since construction and log it.

    Create an instance right before an operation, then call
    ``log_time_end(name)`` once it has finished.
    """

    def __init__(self):
        # perf_counter() is monotonic and high-resolution, so the measurement
        # is not skewed by system clock adjustments the way time.time() is.
        self.start_time = time.perf_counter()

    def log_time_end(self, name):
        """Log ``name`` with the milliseconds elapsed since construction.

        Returns:
            The elapsed time in milliseconds.
        """
        elapsed_time = (time.perf_counter() - self.start_time) * 1000
        logger.info(f"Operation: {name}, Elapsed time: {elapsed_time:.2f} ms")
        return elapsed_time


# NSFW safety-checker models, loaded lazily by init_safety_checker().
# NOTE(review): safety_model_id is not referenced anywhere in this file;
# init_safety_checker() loads "CompVis/stable-diffusion-safety-checker" by name
# instead — confirm whether this local path was meant to be used.
safety_model_id = "./stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None

def check_lora(path):
    """Return True when ``safetensors_worker.CheckLoRA`` reports status 0 for ``path``.

    Any exception raised by the validator is logged and reported as False.
    """
    try:
        return safetensors_worker.CheckLoRA({}, path) == 0
    except Exception as err:
        logger.error(f"Lora check error: {err}")
        return False

def load_lora(url):
    """Download a LoRA file into the WebUI Lora folder unless it is already there.

    The local file name is the last segment of the URL *path*, so a query
    string or fragment (e.g. a signed-URL token) is not part of the name.

    Args:
        url: HTTP(S) URL of the LoRA file.

    Returns:
        The file name (not the full path) inside the Lora folder.

    Raises:
        ValueError: if no file name can be derived from ``url``, the server
            does not answer 200, or any other error occurs while downloading
            or saving the file.
    """
    import tempfile
    from urllib.parse import urlsplit

    destination_dir = '/stable-diffusion-webui/models/Lora'
    try:
        # Ensure the destination directory exists
        os.makedirs(destination_dir, exist_ok=True)

        file_name = os.path.basename(urlsplit(url).path)
        # '', '.' or '..' would resolve to the Lora folder itself or its parent.
        if file_name in ('', '.', '..'):
            raise ValueError(f"Cannot derive a LoRA file name from '{url}'")
        destination_file = os.path.join(destination_dir, file_name)

        # Downloads are only renamed into place once complete (see below), so
        # an existing file is a finished earlier download and can be reused.
        if os.path.exists(destination_file):
            logger.info(f"File '{file_name}' already exists in '{destination_dir}'.")
            return file_name

        # Stream to disk instead of holding the whole file in memory. The
        # timeout bounds the connect time and each wait for more data.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            if response.status_code != 200:
                raise ValueError(f"Failed to download file from '{url}'. Status code: {response.status_code}")

            # Write to a temp file in the same directory, then rename it
            # atomically. An interrupted download therefore never leaves a
            # truncated file that the existence check above would accept.
            fd, tmp_path = tempfile.mkstemp(dir=destination_dir, suffix='.part')
            try:
                with os.fdopen(fd, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
                os.replace(tmp_path, destination_file)
            except BaseException:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise

        logger.info(f"File '{file_name}' downloaded and saved to '{destination_dir}'.")

        # NOTE: LoRA validation is currently disabled. To re-enable it:
        # if not check_lora(destination_file):
        #     os.remove(destination_file)
        #     raise ValueError(f"Failed to verify file '{file_name}' as a valid LoRA file.")

        return file_name
    except ValueError:
        # Already descriptive; do not wrap it a second time.
        raise
    except Exception as e:
        raise ValueError(f"An error occurred: {e}") from e

def refresh_loras():
    """Ask the local WebUI to rescan its Lora folder.

    Returns:
        True on HTTP 200. False on any other status code or on any exception;
        failures are logged, never raised.
    """
    refresh_url = 'http://127.0.0.1:3000/sdapi/v1/refresh-loras'
    try:
        reply = automatic_session.post(refresh_url)
        if reply.status_code != 200:
            logger.error(f"Failed to refresh Loras. Status code: {reply.status_code}")
            return False
        logger.info("Refreshed Loras successfully.")
        return True
    except Exception as err:
        logger.error(f"An error occurred: {err}")
        return False

# Set once the XL img2img warm-up request below has succeeded.
xl_img2img_initialized = False

def init_xl_img2img(model):
    """Run a one-step txt2img with ``model`` once per process as a warm-up.

    Failures are logged, not raised, and leave the flag unset, so the next
    XL img2img job retries the warm-up.
    NOTE(review): why a txt2img call is needed before XL img2img is not visible
    here — presumably it loads the checkpoint; confirm.

    Args:
        model: checkpoint name sent as ``override_settings.sd_model_checkpoint``.
    """
    global xl_img2img_initialized
    if xl_img2img_initialized:
        return
    try:
        url = 'http://127.0.0.1:3000/sdapi/v1/txt2img'
        data = {
            "steps": 1,
            "cfg_scale": 1,
            "override_settings": {"sd_model_checkpoint": model}
        }
        txt2img_time = TimeMesurer()
        response = automatic_session.post(url, json=data)
        txt2img_time.log_time_end("txt2img")
        # Only remember success. An error response (e.g. an unknown model)
        # must not disable the warm-up for the rest of the process.
        if response.status_code != 200:
            logger.error(f"Failed to init XL img2img: status code {response.status_code}")
            return
        xl_img2img_initialized = True
    except Exception as e:
        logger.error(f"Failed to init XL img2img: {e}")

def init_safety_checker():
    """Load the NSFW feature extractor and safety checker on first use.

    Both models are loaded into locals before either global is assigned.
    A failure while loading the second model therefore cannot leave the pair
    half-initialised. Otherwise later calls would skip loading and
    check_safety() would call a None checker.
    """
    global safety_feature_extractor, safety_checker
    if safety_feature_extractor is not None and safety_checker is not None:
        return
    model_name = "CompVis/stable-diffusion-safety-checker"
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    checker = StableDiffusionSafetyChecker.from_pretrained(model_name)
    safety_feature_extractor, safety_checker = feature_extractor, checker


def check_safety(x_image):
    """Return the safety checker's NSFW flag for a single image.

    Requires init_safety_checker() to have run first. The feature extractor's
    pixel values are passed as both ``images`` and ``clip_input``.

    Args:
        x_image: image accepted by the feature extractor (upload_image() passes
            a PIL image).

    Returns:
        The first entry of the checker's ``has_nsfw_concept`` result.
    """
    pixel_values = safety_feature_extractor(x_image, return_tensors="pt").pixel_values
    _, nsfw_flags = safety_checker(images=pixel_values, clip_input=pixel_values)
    return nsfw_flags[0]


def update_progress_status(job_id, progress_status, progress_value=None, timeout=10):
    """Report a job's progress to the backend API (best effort).

    POSTs ``{"id", "apiKey", "progressStatus", "progressValue"}`` as JSON to
    ``{api_base_url}/updateJobProgressStatus``. Errors are logged, never raised.

    Args:
        job_id: RunPod job id.
        progress_status: status label, e.g. "GENERATING" or "UPLOADING_IMAGES".
        progress_value: optional progress number sent as ``progressValue``.
        timeout: seconds to wait for the backend. This keeps a slow or
            unreachable API from stalling the handler or the progress-sync
            scheduler thread indefinitely.
    """
    try:
        params = {
            "id": job_id,
            "apiKey": api_key,
            "progressStatus": progress_status,
            "progressValue": progress_value
        }
        response = requests.post(
            f'{api_base_url}/updateJobProgressStatus', json=params, timeout=timeout)
        if not response.ok:
            logger.warning(
                f'update_progress_status: {progress_status} rejected with status code {response.status_code}')
    except Exception as e:
        logger.exception(f'update_progress_status: Exception {e}')


# True once the local WebUI API has answered the availability probe.
sd_available = False


def check_sd_availability():
    """Block until the WebUI progress endpoint answers 200, then set sd_available.

    Retries forever, sleeping ``sync_progress_value_interval`` ms between
    attempts. Request errors are printed to stdout; any other error
    (including a non-200 status) is logged with its traceback.
    """
    global sd_available

    while True:
        try:
            probe = automatic_session.get(progress_url)
            status = probe.status_code
            if status != 200:
                raise Exception(
                    f"API is not available. Status code: {status}")

            sd_available = True
            logger.info(
                f'check_api_availability: Done. status_code: {status} content: {probe.content}')
            return
        except requests.exceptions.RequestException as err:
            print(f"check_api_availability: API is not available, retrying... ({err})")
        except Exception as err:
            logger.exception(f'check_api_availability: Exception {err}')
        time.sleep(sync_progress_value_interval / 1000)


def upload_image(image_base64):
    """Tag NSFW output, re-encode a generated image as WebP, and upload it to S3.

    Args:
        image_base64: base64-encoded image bytes (no data-URI header).

    Returns:
        The public URL ``{images_base_url}/generated/<uuid>[_nsfw].webp``.

    Raises:
        Any error, after logging it.

    NOTE(review): this always uploads to S3; is_s3_storage and
    upload_image_google_storage() are not used here — confirm that is intended.
    """
    try:
        image_bytes = base64.b64decode(image_base64)
        with Image.open(BytesIO(image_bytes)) as pil_image:
            # NSFW images are still uploaded, but their key gets a "_nsfw" suffix.
            postfix = '_nsfw' if check_safety(pil_image) else ''

            webp_image = BytesIO()
            pil_image.save(webp_image, format="WebP", lossless=convert_is_lossless,
                           quality=convert_quality, method=convert_method)
        webp_image.seek(0)  # rewind so the uploader reads from the beginning

        image_name = f'generated/{uuid.uuid4()}{postfix}.webp'
        upload_image_s3(image_name, webp_image)

        return f'{images_base_url}/{image_name}'
    except Exception as e:
        logger.exception(f'Upload exception: {e}')
        raise

def upload_image_google_storage(image_name, webp_image):
    """Upload WebP bytes to the Firebase bucket as ``image_name`` and make the object public."""
    target_blob = google_storage.blob(image_name)
    target_blob.upload_from_file(webp_image, content_type='image/webp')
    target_blob.make_public()
    logger.info('GOOGLE UPLOADED')

def upload_image_s3(image_name, webp_image):
    """Upload the file object ``webp_image`` to S3 bucket ``bucket_name`` under key ``image_name``."""
    extra_args = {'ContentType': 'image/webp'}
    s3.upload_fileobj(webp_image, bucket_name, image_name, ExtraArgs=extra_args)
    logger.info('S3 UPLOADED')

def upload_images_in_parallel(images):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        uploaded_urls = list(executor.map(upload_image, images))
    return uploaded_urls

def remove_base64_prefix(base64_string):
    """Strip a data-URI header such as ``data:image/png;base64,`` if present.

    Everything up to and including the first comma is dropped. A string
    without a comma is returned unchanged.
    """
    _header, separator, payload = base64_string.partition(',')
    return payload if separator else base64_string

def prepare_mask(image):
    """Turn a mask into white pixels whose opacity is the inverted brightness.

    Every output pixel is ``(255, 255, 255, 255 - floor((R + G + B) / 3))``:
    black becomes opaque white, white becomes fully transparent, and greys map
    to partially transparent white. The input's own alpha channel is ignored.

    Args:
        image: PIL image; converted to RGBA first if it is not already.

    Returns:
        The RGBA result. An image that is already RGBA is modified in place
        and returned; other modes yield a new converted image.
    """
    import numpy as np

    if image.mode != 'RGBA':
        image = image.convert('RGBA')
    if image.width == 0 or image.height == 0:
        return image

    # Compute the new alpha channel for the whole image at once instead of
    # visiting each pixel in Python. Summing in uint16 avoids uint8 overflow,
    # and floor division matches int((r + g + b) / 3) for these non-negative values.
    channel_sum = np.asarray(image)[..., :3].sum(axis=-1, dtype=np.uint16)
    alpha = Image.fromarray((255 - channel_sum // 3).astype(np.uint8))

    # Fill RGB with white, then replace the alpha band; both modify in place.
    image.paste((255, 255, 255, 255), (0, 0, image.width, image.height))
    image.putalpha(alpha)
    return image

def pil_image_to_base64(pil_img, format="PNG"):
    """Encode an image as a base64 ``data:`` URI.

    Args:
        pil_img: image to encode (anything with a PIL-style ``save(fp, format=...)``).
        format: Pillow format name used for encoding (default "PNG").

    Returns:
        ``data:image/<format>;base64,<payload>``. The MIME subtype follows
        ``format`` so the URI matches the encoded bytes ("PNG" -> image/png,
        "JPEG" -> image/jpeg, "WEBP" -> image/webp).
    """
    buffered = BytesIO()
    pil_img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f'data:image/{format.lower()};base64,{img_str}'

# rembg sessions keyed by model name; created on first use and then reused.
_rembg_sessions = {}


def get_removed_bg_mask(base64_image):
    """Return a PNG data-URI foreground mask for a base64-encoded image.

    Runs rembg's "silueta" model in mask-only mode, then post-processes the
    mask with prepare_mask().

    Args:
        base64_image: base64-encoded image bytes without a data-URI header.
    """
    decoded_image = base64.b64decode(base64_image)
    with Image.open(BytesIO(decoded_image)) as input_image:
        # new_session() builds a fresh model session; build it once per process
        # and reuse it for later requests.
        session = _rembg_sessions.get("silueta")
        if session is None:
            session = _rembg_sessions["silueta"] = new_session("silueta")
        # Alternative settings tried earlier:
        # remove(input_image, session=session, only_mask=True, post_process_mask=True, alpha_matting=True, alpha_matting_erode_size=5)
        output_mask = remove(input_image, session=session, only_mask=True)
        processed_image = prepare_mask(output_mask)

        return pil_image_to_base64(processed_image)

def get_image_base64_by_url(image_url, timeout=60):
    """Download ``image_url`` and return its body base64-encoded as a str.

    Args:
        image_url: URL of the image to fetch.
        timeout: seconds to wait when connecting and between received bytes.

    Raises:
        requests.HTTPError: on a 4xx/5xx response, so an error page is never
            forwarded to the WebUI as if it were image data.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    return base64.b64encode(response.content).decode('utf-8')

def sync_progress_value(job_id, params):
    """Poll WebUI progress once and forward it to the backend (best effort).

    Nothing is reported until progress exceeds 0.01 and ``job_count`` is
    positive. Without ADetailer, the raw progress is sent as "GENERATING".
    With ADetailer configured in ``params``:
    - generation progress is halved ("GENERATING");
    - once ``job_count > 1``, ``job_no / (adetailer args * batch_size)``,
      clamped to [0.1, 0.9], is mapped into [0.55, 0.95] as "FIXING_DETAILS".

    NOTE(review): this presumes job_count exceeds 1 only while ADetailer passes
    run — confirm against the WebUI's progress state.

    Args:
        job_id: RunPod job id forwarded to update_progress_status().
        params: the WebUI request body currently being processed.
    """
    try:
        payload = automatic_session.get(progress_url).json()

        state = payload["state"]
        adetailer_args = params.get("alwayson_scripts", {}).get("adetailer", {}).get("args", [])
        adetailer_jobs_count = len(adetailer_args) * params.get("batch_size", 1)
        job_no = state.get('job_no', 0)
        job_count = state.get('job_count', 0)
        job_progress = payload["progress"]
        with_adetailer = adetailer_jobs_count > 0

        if not (job_progress is not None and job_progress > 0.01 and job_count > 0):
            return

        logger.info(f'STATE: {payload["state"]}, textinfo: {payload["textinfo"]}')

        if job_count > 1 and with_adetailer:
            clamped = min(max(job_no / adetailer_jobs_count, 0.1), 0.9)
            final_adetailer_progress = 0.5 + clamped / 2
            logger.info(f'job_no: {job_no}, job_count: {job_count}, adetailer_jobs_count: {adetailer_jobs_count}, final_adetailer_progress: {final_adetailer_progress}')
            update_progress_status(
                job_id=job_id, progress_status="FIXING_DETAILS", progress_value=final_adetailer_progress)
            return

        update_progress_status(
            job_id=job_id,
            progress_status="GENERATING",
            progress_value=job_progress / 2 if with_adetailer else job_progress)
    except Exception as e:
        logger.exception(f'sync_progress_value: Exception {e}')


def handler(event):
    """RunPod serverless entry point: proxy one job to the local WebUI API.

    ``event["input"]`` may contain:
        endpoint: route under /sdapi/v1/ (default "txt2img"), or the special
            "img2bgmask", which is handled locally with rembg.
        params: JSON body forwarded to the WebUI.
        method: HTTP method (default "post").
        is_xl: SDXL flag; triggers a one-time warm-up before img2img.
        images_field / image_field: result keys holding base64 images to upload.

    Images in the WebUI result are uploaded and replaced with their public
    URLs. For txt2img/img2img, a background scheduler reports progress every
    ``sync_progress_value_interval`` ms while the main request runs.

    Returns:
        The WebUI JSON response (image fields replaced by URLs), or
        ``{"mask": <data URI>}`` for "img2bgmask".

    Raises:
        Any exception, after logging it.
    """
    global sd_available

    progress_sync_scheduler = BackgroundScheduler()

    try:
        job_id = event.get("id")
        input_event = event.get('input', {})
        endpoint = input_event.get("endpoint", "txt2img")
        params = input_event.get("params", {})
        method = input_event.get("method", "post")
        is_xl = input_event.get("is_xl", False)
        images_field = input_event.get("images_field", "images")
        image_field = input_event.get("image_field", "image")
        lora_urls = params.get("lora_urls")
        is_fix_outpaint = params.get('is_fix_outpaint', False)
        logger.info(f'is_xl {is_xl}')

        if endpoint == "img2bgmask":
            # Background masks are computed locally and never reach the WebUI.
            init_image = params.get('image')

            if init_image.startswith("https"):
                init_image = get_image_base64_by_url(init_image)
            mask = get_removed_bg_mask(remove_base64_prefix(init_image))
            return {"mask": mask}

        if not sd_available:
            update_progress_status(
                job_id=job_id, progress_status="SERVER_STARTING")
            check_sd_availability()
            logger.info('run handler')

        if is_xl and endpoint == 'img2img':
            model = params.get('override_settings', {}).get('sd_model_checkpoint')
            init_xl_img2img(model)

        if endpoint == 'txt2img' or endpoint == 'img2img':
            progress_sync_scheduler.add_job(
                sync_progress_value, 'interval', seconds=sync_progress_value_interval/1000, args=(job_id, params))
            progress_sync_scheduler.start()

        init_safety_checker()

        if endpoint == "extra-single-image":
            image_url = params.get('image')
            params['image'] = get_image_base64_by_url(image_url)

        if endpoint == "img2img":
            init_image = params.get('init_images', [])[0]

            if init_image.startswith("https"):
                image_base64 = get_image_base64_by_url(init_image)
                params['init_images'] = [image_base64]

        if lora_urls is not None:
            for lora_url in lora_urls:
                if lora_url is not None:
                    load_lora(lora_url)
            refresh_loras()

        host = f'http://127.0.0.1:3000/sdapi/v1/{endpoint}'

        request_time = TimeMesurer()
        response = automatic_session.request(method=method, url=host, json=params)

        result = response.json()
        request_time.log_time_end("request_time")

        # Stop progress polling as soon as the main generation request returns.
        if progress_sync_scheduler.running:
            progress_sync_scheduler.shutdown()

        if is_fix_outpaint:
            # Second img2img pass over the first result to clean up outpainted edges.
            first_result = copy.deepcopy(result)
            update_progress_status(
                job_id=job_id, progress_status="FIXING_OUTPAINT")
            new_params = copy.deepcopy(params)
            new_params['denoising_strength'] = params['outpaint_fix_noise_strength']
            new_params['mask_blur'] = 30
            new_params['inpainting_fill'] = 1
            new_params['init_images'] = first_result[images_field]
            response = automatic_session.request(method=method, url=host, json=new_params)
            result = response.json()

        if result is not None and images_field in result:
            update_progress_status(
                job_id=job_id, progress_status="UPLOADING_IMAGES")
            result[images_field] = upload_images_in_parallel(
                result[images_field])

        if result is not None and image_field in result:
            update_progress_status(
                job_id=job_id, progress_status="UPLOADING_IMAGES")
            result[image_field] = upload_image(result[image_field])

        return result

    except Exception as e:
        logger.exception(f'An exception occurred: {e}')
        raise
    finally:
        # Runs on every exit path: early return, success, or error. Checking
        # `running` prevents a second shutdown(), which raises
        # SchedulerNotRunningError and would mask the original exception.
        if progress_sync_scheduler.running:
            progress_sync_scheduler.shutdown()


# Register handler() with the RunPod serverless worker and start serving jobs.
runpod.serverless.start({"handler": handler})