File size: 6,253 Bytes
1c72248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gc
import os
from collections import OrderedDict
from typing import ForwardRef, List, Optional, Union

import torch
from safetensors.torch import save_file, load_file

from jobs.process.BaseProcess import BaseProcess
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
    add_base_model_info_to_meta
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
import random


class GenerateConfig:
    """Options for one image-generation run, populated from keyword arguments.

    ``prompts`` is mandatory. It is either a list of prompt strings or a
    path to a UTF-8 text file with one prompt per line (lines that are blank
    after stripping are dropped). Every other key is optional and falls back
    to the default shown in ``__init__``.

    Derived attributes:
        prompts_in_file: the source prompt pool -- the list that was passed
            in, or the lines read from the prompt file.
        prompts: the prompts that will actually be rendered. Without
            ``random_prompts`` this is a copy of the pool; with it, it holds
            ``max_images`` entries, each built by joining 1 to
            ``max_random_per_prompt`` randomly chosen pool entries with
            ``", "``. ``shuffle=True`` randomizes the final order.

    Raises:
        ValueError: if ``prompts`` is missing, or is a string that does not
            name an existing path.
    """

    def __init__(self, **kwargs):
        # Validate the only required key first so a bad config fails fast.
        source_prompts = kwargs.get('prompts', None)
        if source_prompts is None:
            raise ValueError("Prompts must be set")

        self.prompts: List[str]
        self.sampler = kwargs.get('sampler', 'ddpm')
        self.width = kwargs.get('width', 512)
        self.height = kwargs.get('height', 512)
        # Each entry is unpacked as (width, height) by GenerateProcess.run.
        self.size_list: Union[List[List[int]], None] = kwargs.get('size_list', None)
        self.neg = kwargs.get('neg', '')
        self.seed = kwargs.get('seed', -1)
        self.guidance_scale = kwargs.get('guidance_scale', 7)
        self.sample_steps = kwargs.get('sample_steps', 20)
        self.prompt_2 = kwargs.get('prompt_2', None)
        self.neg_2 = kwargs.get('neg_2', None)
        self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
        self.compile = kwargs.get('compile', False)
        self.ext = kwargs.get('ext', 'png')
        self.prompt_file = kwargs.get('prompt_file', False)
        self.num_repeats = kwargs.get('num_repeats', 1)
        self.random_prompts = kwargs.get('random_prompts', False)
        self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1)
        self.max_images = kwargs.get('max_images', 10000)

        # Resolve the prompt pool: a string is a path to a prompt file,
        # anything else is taken to be the prompts themselves.
        self.prompts_in_file = source_prompts
        if isinstance(source_prompts, str):
            if not os.path.exists(source_prompts):
                raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts")
            with open(source_prompts, 'r', encoding='utf-8') as f:
                stripped_lines = (line.strip() for line in f.read().splitlines())
                self.prompts_in_file = [line for line in stripped_lines if line]

        if self.random_prompts:
            # Build max_images composite prompts out of random pool entries.
            self.prompts = []
            for _ in range(self.max_images):
                num_prompts = random.randint(1, self.max_random_per_prompt)
                picks = [random.choice(self.prompts_in_file) for _ in range(num_prompts)]
                self.prompts.append(", ".join(picks))
        else:
            # Copy the pool: the optional shuffle below works in place and
            # must not reorder the list object owned by the caller's config.
            self.prompts = list(self.prompts_in_file)

        if kwargs.get('shuffle', False):
            random.shuffle(self.prompts)


class GenerateProcess(BaseProcess):
    """Job process that loads a Stable Diffusion model and renders prompts.

    Values read from the process config via ``get_conf``:
        output_folder: required; passed to every image config as its
            output folder.
        model: required; expanded into ``ModelConfig``.
        generate: required; expanded into ``GenerateConfig``.
        device: optional; ``self.job.device`` is passed as the fallback.
        dtype: optional; ``'float16'`` is passed as the fallback and the
            result is converted with ``get_torch_dtype``.
    """
    process_id: int
    config: OrderedDict
    progress_bar: ForwardRef('tqdm') = None
    sd: StableDiffusion

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.device = self.get_conf('device', self.job.device)
        self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
        self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16'))

        self.progress_bar = None
        # NOTE(review): the model wrapper is built with the model config's
        # dtype, while run() moves the pipeline to self.torch_dtype (from the
        # process-level 'dtype' key) -- confirm the two are meant to differ.
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.model_config.dtype,
        )

        print(f"Using device {self.device}")

    def clean_prompt(self, prompt: str):
        """Return ``prompt`` keeping only alphanumerics, commas, spaces and quotes.

        Not currently applied by ``run``; kept for callers that want it.
        """
        return ''.join(e for e in prompt if e.isalnum() or e in ", '\"")

    def _build_image_configs(self) -> list:
        """Build one GenerateImageConfig per prompt for each of ``num_repeats`` passes."""
        gen_cfg = self.generate_config
        image_configs = []
        for _ in range(gen_cfg.num_repeats):
            for prompt in gen_cfg.prompts:
                if gen_cfg.size_list is not None:
                    # A size list overrides the fixed width/height with a
                    # randomly selected pair for each image.
                    width, height = random.choice(gen_cfg.size_list)
                else:
                    width, height = gen_cfg.width, gen_cfg.height

                image_configs.append(GenerateImageConfig(
                    prompt=prompt,
                    prompt_2=gen_cfg.prompt_2,
                    width=width,
                    height=height,
                    num_inference_steps=gen_cfg.sample_steps,
                    guidance_scale=gen_cfg.guidance_scale,
                    negative_prompt=gen_cfg.neg,
                    negative_prompt_2=gen_cfg.neg_2,
                    seed=gen_cfg.seed,
                    guidance_rescale=gen_cfg.guidance_rescale,
                    output_ext=gen_cfg.ext,
                    output_folder=self.output_folder,
                    add_prompt_file=gen_cfg.prompt_file
                ))
        return image_configs

    def run(self):
        """Load the model, optionally compile the UNet, and render all prompts.

        Everything runs under ``torch.no_grad()``. The model wrapper is
        dropped, garbage is collected and the CUDA cache is emptied when the
        run ends, whether it succeeded or raised.
        """
        gen_cfg = self.generate_config
        with torch.no_grad():
            try:
                super().run()
                print("Loading model...")
                self.sd.load_model()
                self.sd.pipeline.to(self.device, self.torch_dtype)

                # Only announce compilation when it is actually performed.
                if gen_cfg.compile:
                    print("Compiling model...")
                    self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")

                prompt_image_configs = self._build_image_configs()
                # Report the real image count, which includes num_repeats.
                print(f"Generating {len(prompt_image_configs)} images")
                self.sd.generate_images(prompt_image_configs, sampler=gen_cfg.sampler)

                print("Done generating images")
            finally:
                # Release the model even on failure so a crashed run does
                # not keep it (and its cached GPU memory) alive.
                del self.sd
                gc.collect()
                torch.cuda.empty_cache()