import gc
import os
from collections import OrderedDict
from typing import ForwardRef, List, Optional, Union
import torch
from safetensors.torch import save_file, load_file
from jobs.process.BaseProcess import BaseProcess
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
add_base_model_info_to_meta
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
import random
class GenerateConfig:
    """Configuration for a batch image-generation run.

    Built from the ``generate`` section of a job config. ``prompts`` may be
    either a list of prompt strings or a path to a UTF-8 text file with one
    prompt per line (blank lines are skipped). When ``random_prompts`` is
    set, ``max_images`` synthetic prompts are built, each a comma-joined
    random sample of the source prompts.

    Raises:
        ValueError: if ``prompts`` is missing, or is a string that does not
            point to an existing file.
    """

    def __init__(self, **kwargs):
        self.sampler = kwargs.get('sampler', 'ddpm')
        self.width = kwargs.get('width', 512)
        self.height = kwargs.get('height', 512)
        # optional list of (width, height) pairs; when set, run() picks one
        # at random per image instead of using width/height
        self.size_list: Union[List[int], None] = kwargs.get('size_list', None)
        self.neg = kwargs.get('neg', '')
        self.seed = kwargs.get('seed', -1)
        self.guidance_scale = kwargs.get('guidance_scale', 7)
        self.sample_steps = kwargs.get('sample_steps', 20)
        self.prompt_2 = kwargs.get('prompt_2', None)
        self.neg_2 = kwargs.get('neg_2', None)
        self.prompts: Union[List[str], str, None] = kwargs.get('prompts', None)
        self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
        self.compile = kwargs.get('compile', False)
        self.ext = kwargs.get('ext', 'png')
        self.prompt_file = kwargs.get('prompt_file', False)
        self.num_repeats = kwargs.get('num_repeats', 1)
        if self.prompts is None:
            raise ValueError("Prompts must be set")
        # Resolve a file path into the list of prompts it contains; a list
        # passes through unchanged.
        self.prompts_in_file = self.prompts
        if isinstance(self.prompts, str):
            if os.path.exists(self.prompts):
                with open(self.prompts, 'r', encoding='utf-8') as f:
                    self.prompts_in_file = f.read().splitlines()
                self.prompts_in_file = [p.strip() for p in self.prompts_in_file if len(p.strip()) > 0]
            else:
                raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts")
        self.random_prompts = kwargs.get('random_prompts', False)
        self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1)
        self.max_images = kwargs.get('max_images', 10000)
        if self.random_prompts:
            # Clamp to at least 1 so a misconfigured value (<1) doesn't make
            # random.randint raise an opaque ValueError.
            per_prompt_cap = max(1, self.max_random_per_prompt)
            self.prompts = []
            for _ in range(self.max_images):
                num_prompts = random.randint(1, per_prompt_cap)
                prompt_list = [random.choice(self.prompts_in_file) for _ in range(num_prompts)]
                self.prompts.append(", ".join(prompt_list))
        else:
            self.prompts = self.prompts_in_file
            if kwargs.get('shuffle', False):
                # shuffle the prompts in place
                random.shuffle(self.prompts)
class GenerateProcess(BaseProcess):
    """Job process that loads a StableDiffusion model and generates a batch
    of images from the configured prompts.

    Reads ``output_folder``, ``model``, ``device``, ``generate`` and ``dtype``
    from the job config; image generation happens in :meth:`run`.
    """
    process_id: int
    config: OrderedDict
    progress_bar: ForwardRef('tqdm') = None
    sd: StableDiffusion

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.device = self.get_conf('device', self.job.device)
        self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
        self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16'))
        self.progress_bar = None
        # Wrapper is constructed here; weights are not loaded until run().
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.model_config.dtype,
        )
        print(f"Using device {self.device}")

    def clean_prompt(self, prompt: str):
        # remove any non alpha numeric characters or ,'" from prompt
        return ''.join(e for e in prompt if e.isalnum() or e in ", '\"")

    def run(self):
        with torch.no_grad():
            super().run()
            print("Loading model...")
            self.sd.load_model()
            self.sd.pipeline.to(self.device, self.torch_dtype)
            if self.generate_config.compile:
                # only announce compilation when it actually happens
                print("Compiling model...")
                self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
            print(f"Generating {len(self.generate_config.prompts)} images")
            # build one GenerateImageConfig per prompt, repeated num_repeats times
            prompt_image_configs = []
            for _ in range(self.generate_config.num_repeats):
                for prompt in self.generate_config.prompts:
                    width = self.generate_config.width
                    height = self.generate_config.height
                    if self.generate_config.size_list is not None:
                        # randomly select a size
                        width, height = random.choice(self.generate_config.size_list)
                    prompt_image_configs.append(GenerateImageConfig(
                        prompt=prompt,
                        prompt_2=self.generate_config.prompt_2,
                        width=width,
                        height=height,
                        num_inference_steps=self.generate_config.sample_steps,
                        guidance_scale=self.generate_config.guidance_scale,
                        negative_prompt=self.generate_config.neg,
                        negative_prompt_2=self.generate_config.neg_2,
                        seed=self.generate_config.seed,
                        guidance_rescale=self.generate_config.guidance_rescale,
                        output_ext=self.generate_config.ext,
                        output_folder=self.output_folder,
                        add_prompt_file=self.generate_config.prompt_file
                    ))
            # generate images
            self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)
            print("Done generating images")
            # cleanup: drop the model and release cached GPU memory
            # (empty_cache is a no-op when CUDA was never initialized)
            del self.sd
            gc.collect()
            torch.cuda.empty_cache()