File size: 19,084 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
#!/usr/bin/env python
# pylint: disable=no-member
"""generate batches of images from prompts and upscale them

params: run with `--help`

default workflow runs infinite loop and prints stats when interrupted:
1. choose random scheduler lookup all available and pick one
2. generate dynamic prompt based on styles, embeddings, places, artists, suffixes
3. beautify prompt
4. generate 3x3 images
5. create image grid
6. upscale images with face restoration
"""

import argparse
import asyncio
import base64
import io
import json
import logging
import math
import os
import pathlib
import secrets
import time
import sys
import importlib

from random import randrange
from PIL import Image
from PIL.ExifTags import TAGS
from PIL.TiffImagePlugin import ImageFileDirectory_v2

from sdapi import close, get, interrupt, post, session
from util import Map, log, safestring


sd = {}
random = {}
stats = Map({ 'images': 0, 'wall': 0, 'generate': 0, 'upscale': 0 })
avg = {}


def grid(data):
    if len(data.image) > 1:
        w, h = data.image[0].size
        rows = round(math.sqrt(len(data.image)))
        cols = math.ceil(len(data.image) / rows)
        image = Image.new('RGB', size = (cols * w, rows * h), color = 'black')
        for i, img in enumerate(data.image):
            image.paste(img, box=(i % cols * w, i // cols * h))
        short = data.info.prompt[:min(len(data.info.prompt), 96)] # limit prompt part of filename to 96 chars
        name = '{seed:0>9} {short}'.format(short = short, seed = data.info.all_seeds[0]) # pylint: disable=consider-using-f-string
        name = safestring(name) + '.jpg'
        f = os.path.join(sd.paths.root, sd.paths.grid, name)
        log.info({ 'grid': { 'name': f, 'size': image.size, 'images': len(data.image) } })
        image.save(f, 'JPEG', exif = exif(data.info, None, 'grid'), optimize = True, quality = 70)
        return image
    return data.image


def exif(info, i = None, op = 'generate'):
    seed = [info.all_seeds[i]] if len(info.all_seeds) > 0 and i is not None else info.all_seeds # always returns list
    seed = ', '.join([str(x) for x in seed]) # int list to str list to single str
    template = '{prompt} | negative {negative_prompt} | seed {s} | steps {steps} | cfgscale {cfg_scale} | sampler {sampler_name} | batch {batch_size} | timestamp {job_timestamp} | model {model} | vae {vae}'.format(s = seed, model = sd.options['sd_model_checkpoint'], vae = sd.options['sd_vae'], **info) # pylint: disable=consider-using-f-string
    if op == 'upscale':
        template += ' | faces gfpgan' if sd.upscale.gfpgan_visibility > 0 else ''
        template += ' | faces codeformer' if sd.upscale.codeformer_visibility > 0 else ''
        template += ' | upscale {resize}x {upscaler}'.format(resize = sd.upscale.upscaling_resize, upscaler = sd.upscale.upscaler_1) if sd.upscale.upscaler_1 != 'None' else '' # pylint: disable=consider-using-f-string
        template += ' | upscale {resize}x {upscaler}'.format(resize = sd.upscale.upscaling_resize, upscaler = sd.upscale.upscaler_2) if sd.upscale.upscaler_2 != 'None' else '' # pylint: disable=consider-using-f-string
    if op == 'grid':
        template += ' | grid {num}'.format(num = sd.generate.batch_size * sd.generate.n_iter) # pylint: disable=consider-using-f-string
    ifd = ImageFileDirectory_v2()
    exif_stream = io.BytesIO()
    _TAGS = {v: k for k, v in TAGS.items()} # enumerate possible exif tags
    ifd[_TAGS['ImageDescription']] = template
    ifd.save(exif_stream)
    val = b'Exif\x00\x00' + exif_stream.getvalue()
    return val


def randomize(lst):
    if len(lst) > 0:
        return secrets.choice(lst)
    else:
        return ''


def prompt(params): # generate dynamic prompt or use one if provided
    sd.generate.prompt = params.prompt if params.prompt != 'dynamic' else randomize(random.prompts)
    sd.generate.negative_prompt = params.negative if params.negative != 'dynamic' else randomize(random.negative)
    embedding = params.embedding if params.embedding != 'random' else randomize(random.embeddings)
    sd.generate.prompt = sd.generate.prompt.replace('<embedding>', embedding)
    artist = params.artist if params.artist != 'random' else randomize(random.artists)
    sd.generate.prompt = sd.generate.prompt.replace('<artist>', artist)
    style = params.style if params.style != 'random' else randomize(random.styles)
    sd.generate.prompt = sd.generate.prompt.replace('<style>', style)
    suffix = params.suffix if params.suffix != 'random' else randomize(random.suffixes)
    sd.generate.prompt = sd.generate.prompt.replace('<suffix>', suffix)
    place = params.suffix if params.suffix != 'random' else randomize(random.places)
    sd.generate.prompt = sd.generate.prompt.replace('<place>', place)
    if params.prompts or params.debug:
        log.info({ 'random initializers': random })
    if params.prompt == 'dynamic':
        log.info({ 'dynamic prompt': sd.generate.prompt })
    return sd.generate.prompt


def sampler(params, options): # find sampler
    if params.sampler == 'random':
        sd.generate.sampler_name = randomize(options.samplers)
        log.info({ 'random sampler': sd.generate.sampler_name })
    else:
        found = [i for i in options.samplers if i.startswith(params.sampler)]
        if len(found) == 0:
            log.error({ 'sampler error': sd.generate.sampler_name, 'available': options.samplers})
            exit()
        sd.generate.sampler_name = found[0]
    return sd.generate.sampler_name


async def generate(prompt = None, options = None, quiet = False): # pylint: disable=redefined-outer-name
    global sd # pylint: disable=global-statement
    if options:
        sd = Map(options)
    if prompt is not None:
        sd.generate.prompt = prompt
    if not quiet:
        log.info({ 'generate': sd.generate })
    if sd.get('options', None) is None:
        sd['options'] = await get('/sdapi/v1/options')
    names = []
    b64s = []
    images = []
    info = Map({})
    data = await post('/sdapi/v1/txt2img', sd.generate)
    if 'error' in data:
        log.error({ 'generate': data['error'], 'reason': data['reason'] })
        return Map({})
    info = Map(json.loads(data['info']))
    log.debug({ 'info': info })
    images = data['images']
    short = info.prompt[:min(len(info.prompt), 96)] # limit prompt part of filename to 64 chars
    for i in range(len(images)):
        b64s.append(images[i])
        images[i] = Image.open(io.BytesIO(base64.b64decode(images[i].split(',',1)[0])))
        name = '{seed:0>9} {short}'.format(short = short, seed = info.all_seeds[i]) # pylint: disable=consider-using-f-string
        name = safestring(name) + '.jpg'
        f = os.path.join(sd.paths.root, sd.paths.generate, name)
        names.append(f)
        if not quiet:
            log.info({ 'image': { 'name': f, 'size': images[i].size } })
        images[i].save(f, 'JPEG', exif = exif(info, i), optimize = True, quality = 70)
    return Map({ 'name': names, 'image': images, 'b64': b64s, 'info': info })


async def upscale(data):
    data.upscaled = []
    if sd.upscale.upscaling_resize <=1:
        return data
    sd.upscale.image = ''
    log.info({ 'upscale': sd.upscale })
    for i in range(len(data.image)):
        f = data.name[i].replace(sd.paths.generate, sd.paths.upscale)
        sd.upscale.image = data.b64[i]
        res = await post('/sdapi/v1/extra-single-image', sd.upscale)
        image = Image.open(io.BytesIO(base64.b64decode(res['image'].split(',',1)[0])))
        data.upscaled.append(image)
        log.info({ 'image': { 'name': f, 'size': image.size } })
        image.save(f, 'JPEG', exif = exif(data.info, i, 'upscale'), optimize = True, quality = 70)
    return data


async def init():
    '''
    import torch
    log.info({ 'torch': torch.__version__, 'available': torch.cuda.is_available() })
    current_device = torch.cuda.current_device()
    mem_free, mem_total = torch.cuda.mem_get_info()
    log.info({ 'cuda': torch.version.cuda, 'available': torch.cuda.is_available(), 'arch': torch.cuda.get_arch_list(), 'device': torch.cuda.get_device_name(current_device), 'memory': { 'free': round(mem_free / 1024 / 1024), 'total': (mem_total / 1024 / 1024) } })
    '''
    options = Map({})
    options.flags = await get('/sdapi/v1/cmd-flags')
    log.debug({ 'flags': options.flags })
    data = await get('/sdapi/v1/sd-models')
    options.models = [obj['title'] for obj in data]
    log.debug({ 'registered models': options.models })
    found = sd.options.sd_model_checkpoint if sd.options.sd_model_checkpoint in options.models else None
    if found is None:
        found = [i for i in options.models if i.startswith(sd.options.sd_model_checkpoint)]
    if len(found) == 0:
        log.error({ 'model error': sd.generate.sd_model_checkpoint, 'available': options.models})
        exit()
    sd.options.sd_model_checkpoint = found[0]
    data = await get('/sdapi/v1/samplers')
    options.samplers = [obj['name'] for obj in data]
    log.debug({ 'registered samplers': options.samplers })
    data = await get('/sdapi/v1/upscalers')
    options.upscalers = [obj['name'] for obj in data]
    log.debug({ 'registered upscalers': options.upscalers })
    data = await get('/sdapi/v1/face-restorers')
    options.restorers = [obj['name'] for obj in data]
    log.debug({ 'registered face restorers': options.restorers })
    await interrupt()
    await post('/sdapi/v1/options', sd.options)
    options.options = await get('/sdapi/v1/options')
    log.info({ 'target models': { 'diffuser': options.options['sd_model_checkpoint'], 'vae': options.options['sd_vae'] } })
    log.info({ 'paths': sd.paths })
    options.queue = await get('/queue/status')
    log.info({ 'queue': options.queue })
    pathlib.Path(sd.paths.root).mkdir(parents = True, exist_ok = True)
    pathlib.Path(os.path.join(sd.paths.root, sd.paths.generate)).mkdir(parents = True, exist_ok = True)
    pathlib.Path(os.path.join(sd.paths.root, sd.paths.upscale)).mkdir(parents = True, exist_ok = True)
    pathlib.Path(os.path.join(sd.paths.root, sd.paths.grid)).mkdir(parents = True, exist_ok = True)
    return options


def args(): # parse cmd arguments
    global sd # pylint: disable=global-statement
    global random # pylint: disable=global-statement
    parser = argparse.ArgumentParser(description = 'sd pipeline')
    parser.add_argument('--config', type = str, default = 'generate.json', required = False, help = 'configuration file')
    parser.add_argument('--random', type = str, default = 'random.json', required = False, help = 'prompt file with randomized sections')
    parser.add_argument('--max', type = int, default = 1, required = False, help = 'maximum number of generated images')
    parser.add_argument('--prompt', type = str, default = 'dynamic', required = False, help = 'prompt')
    parser.add_argument('--negative', type = str, default = 'dynamic', required = False, help = 'negative prompt')
    parser.add_argument('--artist', type = str, default = 'random', required = False, help = 'artist style, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--embedding', type = str, default = 'random', required = False, help = 'use embedding, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--style', type = str, default = 'random', required = False, help = 'image style, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--suffix', type = str, default = 'random', required = False, help = 'style suffix, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--place', type = str, default = 'random', required = False, help = 'place locator, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--faces', default = False, action='store_true', help = 'restore faces during upscaling')
    parser.add_argument('--steps', type = int, default = 0, required = False, help = 'number of steps')
    parser.add_argument('--batch', type = int, default = 0, required = False, help = 'batch size, limited by gpu vram')
    parser.add_argument('--n', type = int, default = 0, required = False, help = 'number of iterations')
    parser.add_argument('--cfg', type = int, default = 0, required = False, help = 'classifier free guidance scale')
    parser.add_argument('--sampler', type = str, default = 'random', required = False, help = 'sampler')
    parser.add_argument('--seed', type = int, default = 0, required = False, help = 'seed, default is random')
    parser.add_argument('--upscale', type = int, default = 0, required = False, help = 'upscale factor, disabled if 0')
    parser.add_argument('--model', type = str, default = '', required = False, help = 'diffusion model')
    parser.add_argument('--vae', type = str, default = '', required = False, help = 'vae model')
    parser.add_argument('--path', type = str, default = '', required = False, help = 'output path')
    parser.add_argument('--width', type = int, default = 0, required = False, help = 'width')
    parser.add_argument('--height', type = int, default = 0, required = False, help = 'height')
    parser.add_argument('--beautify', default = False, action='store_true', help = 'beautify prompt')
    parser.add_argument('--prompts', default = False, action='store_true', help = 'print dynamic prompt templates')
    parser.add_argument('--debug', default = False, action='store_true', help = 'print extra debug information')
    params = parser.parse_args()
    if params.debug:
        log.setLevel(logging.DEBUG)
        log.debug({ 'debug': True })
    log.debug({ 'args': params.__dict__ })
    home = pathlib.Path(sys.argv[0]).parent
    if os.path.isfile(params.config):
        try:
            with open(params.config, 'r', encoding='utf-8') as f:
                data = json.load(f)
                sd = Map(data)
                log.debug({ 'config': sd })
        except Exception as e:
            log.error({ 'config error': params.config, 'exception': e })
            exit()
    elif os.path.isfile(os.path.join(home, params.config)):
        try:
            with open(os.path.join(home, params.config), 'r', encoding='utf-8') as f:
                data = json.load(f)
                sd = Map(data)
                log.debug({ 'config': sd })
        except Exception as e:
            log.error({ 'config error': params.config, 'exception': e })
            exit()
    else:
        log.error({ 'config file not found': params.config})
        exit()
    if params.prompt == 'dynamic':
        log.info({ 'prompt template': params.random })
        if os.path.isfile(params.random):
            try:
                with open(params.random, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    random = Map(data)
                    log.debug({ 'random template': sd })
            except Exception:
                log.error({ 'random template error': params.random})
                exit()
        elif os.path.isfile(os.path.join(home, params.random)):
            try:
                with open(os.path.join(home, params.random), 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    random = Map(data)
                    log.debug({ 'random template': sd })
            except Exception:
                log.error({ 'random template error': params.random})
                exit()
        else:
            log.error({ 'random template file not found': params.random})
            exit()
        _dynamic = prompt(params)

    sd.paths.root = params.path if params.path != '' else sd.paths.root
    sd.generate.restore_faces = params.faces if params.faces is not None else sd.generate.restore_faces
    sd.generate.seed = params.seed if params.seed > 0 else sd.generate.seed
    sd.generate.sampler_name = params.sampler if params.sampler != 'random' else sd.generate.sampler_name
    sd.generate.batch_size = params.batch if params.batch > 0 else sd.generate.batch_size
    sd.generate.cfg_scale = params.cfg if params.cfg > 0 else sd.generate.cfg_scale
    sd.generate.n_iter = params.n if params.n > 0 else sd.generate.n_iter
    sd.generate.width = params.width if params.width > 0 else sd.generate.width
    sd.generate.height = params.height if params.height > 0 else sd.generate.height
    sd.generate.steps = params.steps if params.steps > 0 else sd.generate.steps
    sd.upscale.upscaling_resize = params.upscale if params.upscale > 0 else sd.upscale.upscaling_resize
    sd.upscale.codeformer_visibility = 1 if params.faces else sd.upscale.codeformer_visibility
    sd.options.sd_vae = params.vae if params.vae != '' else sd.options.sd_vae
    sd.options.sd_model_checkpoint = params.model if params.model != '' else sd.options.sd_model_checkpoint
    sd.upscale.upscaler_1 = 'SwinIR_4x' if params.upscale > 1 else sd.upscale.upscaler_1
    if sd.generate.cfg_scale == 0:
        sd.generate.cfg_scale = randrange(5, 10)
    return params


async def main():
    params = args()
    sess = await session()
    if sess is None:
        await close()
        exit()
    options = await init()
    iteration = 0
    while True:
        iteration += 1
        log.info('')
        log.info({ 'iteration': iteration, 'batch': sd.generate.batch_size, 'n': sd.generate.n_iter, 'total': sd.generate.n_iter * sd.generate.batch_size })
        dynamic = prompt(params)
        if params.beautify:
            try:
                promptist = importlib.import_module('modules.promptist')
                sd.generate.prompt = promptist.beautify(dynamic)
            except Exception as e:
                log.error({ 'beautify': e })
        scheduler = sampler(params, options)
        t0 = time.perf_counter()
        data = await generate() # generate returns list of images
        if 'image' not in data:
            break
        stats.images += len(data.image)
        t1 = time.perf_counter()
        if len(data.image) > 0:
            avg[scheduler] = (t1 - t0) / len(data.image)
        stats.generate += t1 - t0
        _image = grid(data)
        data = await upscale(data)
        t2 = time.perf_counter()
        stats.upscale += t2 - t1
        stats.wall += t2 - t0
        its = sd.generate.steps / ((t1 - t0) / len(data.image)) if len(data.image) > 0 else 0
        avg_time = round((t1 - t0) / len(data.image)) if len(data.image) > 0 else 0
        log.info({ 'time' : { 'wall': round(t1 - t0), 'average': avg_time, 'upscale': round(t2 - t1), 'its': round(its, 2) } })
        log.info({ 'generated': stats.images, 'max': params.max, 'progress': round(100 * stats.images / params.max, 1) })
        if params.max != 0 and stats.images >= params.max:
            break


if __name__ == '__main__':
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        asyncio.run(interrupt())
        asyncio.run(close())
        log.info({ 'interrupt': True })
    finally:
        log.info({ 'sampler performance': avg })
        log.info({ 'stats' : stats })
        asyncio.run(close())