#!/usr/bin/env python
import os
from dotenv import load_dotenv
from Helpers import name_formatter, weights_dir, capture_message
from contextlib import ExitStack
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler
import base64
from io import BytesIO
from PIL import Image
import uuid

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")


class Pipeline:
    """Generate Stable Diffusion images for every preset of a job.

    ``data`` must provide ``data['name']`` (passed through
    ``name_formatter`` and %-interpolated into each preset's prompt) and
    ``data['presets']`` (an iterable of preset dicts, see :meth:`run`).

    Images are saved under the path prefix read from the ``OUTPUT_IMAGES``
    environment variable (``.env`` is loaded on construction).
    """

    def __init__(self, data):
        load_dotenv()
        self.output_images = os.getenv('OUTPUT_IMAGES')
        self.data = data
        # Loaded lazily on first run() and reused for every later preset,
        # instead of re-reading the weights from disk for each one.
        self._txt2img_pipe = None

    def generate(self):
        """Run :meth:`run` once for each preset in ``self.data['presets']``."""
        for preset in self.data['presets']:
            self.run(preset)

    def _get_txt2img_pipe(self):
        """Return the cached fp16 txt2img pipeline, building it on first use."""
        pipe = getattr(self, '_txt2img_pipe', None)
        if pipe is None:
            scheduler = DDIMScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                clip_sample=False,
                set_alpha_to_one=False,
            )
            pipe = StableDiffusionPipeline.from_pretrained(
                weights_dir(),
                scheduler=scheduler,
                safety_checker=None,
                torch_dtype=torch.float16,
                # Never hard-code credentials: the token previously committed
                # here should be revoked. Read it from the environment
                # (.env is loaded in __init__); None is fine when
                # weights_dir() is a local directory.
                use_auth_token=os.getenv('HF_AUTH_TOKEN'),
            )
            pipe = pipe.to(device)
            self._txt2img_pipe = pipe
        return pipe

    def run(self, preset):
        """Generate and save the images described by one preset.

        Preset keys:
            prompt (required): %-format string receiving the formatted name.
            n_samples: images per batch (default 1).
            guidance_scale: classifier-free guidance scale (default 7.5).
            ddim_steps: number of denoising steps (default 50).
            ddim_eta: DDIM eta (default 0.0).
            n_iter: number of batches to generate (default 1).
            height / width: output size; when absent the pipeline's own
                default size is used.
            seed: integer seed for the CUDA generator (default 7).
            preset_id: prefix used in output file names (default 1).

        ``channels`` and ``scale`` keys are ignored: diffusers has no
        equivalent of ``channels``, and ``scale`` duplicates
        ``guidance_scale``.

        Raises:
            ValueError: if ``OUTPUT_IMAGES`` is not set.
        """
        torch.cuda.empty_cache()
        capture_message('Pipeline: Run')

        # Fail before spending GPU time if there is nowhere to write results.
        if self.output_images is None:
            raise ValueError("OUTPUT_IMAGES environment variable is not set")

        prompt = preset['prompt'] % (name_formatter(self.data['name']))
        n_samples = preset.get('n_samples', 1)
        guidance_scale = preset.get('guidance_scale', 7.5)
        # Preset keys use CompVis names; map them to diffusers' arguments.
        # The defaults equal StableDiffusionPipeline.__call__'s own defaults.
        num_inference_steps = preset.get('ddim_steps', 50)
        eta = preset.get('ddim_eta', 0.0)
        n_iter = preset.get('n_iter', 1)
        # None lets the pipeline choose its native resolution, which is what
        # happened before these keys were forwarded.
        height = preset.get('height')
        width = preset.get('width')
        # torch.Generator.manual_seed only accepts integers; the old default
        # of 7.5 (copied from guidance_scale) was rejected by torch.
        seed = int(preset.get('seed', 7))
        preset_id = preset.get('preset_id', 1)

        txt2img_pipe = self._get_txt2img_pipe()

        g_cuda = torch.Generator(device=device)
        g_cuda.manual_seed(seed)

        # No autocast needed: the weights are already float16 and this module
        # refuses to run on anything but CUDA.
        for _ in range(n_iter):
            images = txt2img_pipe(
                prompt=[prompt] * n_samples,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                eta=eta,
                height=height,
                width=width,
                generator=g_cuda,
            ).images
            self._save_images(images, preset_id)

    def _save_images(self, images, preset_id):
        """Save each image as ``<OUTPUT_IMAGES><preset_id>----<uuid4>.png``."""
        for img in images:
            # Plain concatenation is kept on purpose: OUTPUT_IMAGES acts as a
            # path prefix (presumably a directory ending in a separator).
            img.save(self.output_images + str(preset_id) + '----' + '%s.png' % uuid.uuid4())