txt2img-consumer / Pipeline.py
akinv's picture
Fixed
f879398
#!/usr/bin/env python
import os
from dotenv import load_dotenv
from Helpers import name_formatter, weights_dir, capture_message
from contextlib import ExitStack
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler
import base64
from io import BytesIO
from PIL import Image
import uuid
device = (
    torch.device('cuda')
    if torch.cuda.is_available()
    else torch.device('cpu')
)
# Fail fast at import time: Pipeline.run creates its random generator on
# 'cuda' unconditionally, so a CPU-only machine cannot be used.
if device.type not in ('cuda',):
    raise ValueError("need to run on GPU")
class Pipeline:
    """Generate Stable Diffusion images for every preset in a job payload.

    ``data`` must provide ``'name'`` (formatted into each preset's prompt via
    ``name_formatter``) and ``'presets'`` (a list of preset dicts, see ``run``).
    Images are saved as PNG files whose path starts with the ``OUTPUT_IMAGES``
    environment variable (loaded through ``load_dotenv``).
    """

    def __init__(self, data):
        """Load ``.env`` settings and keep the job payload.

        Args:
            data: job payload; ``generate`` iterates ``data['presets']`` and
                ``run`` reads ``data['name']``.
        """
        load_dotenv()
        # Used verbatim as a path *prefix* (plain string concatenation), so it
        # normally needs a trailing path separator.
        self.output_images = os.getenv('OUTPUT_IMAGES')
        self.data = data
        # Loaded lazily on first use and then reused for every preset, instead
        # of re-reading the weights from disk on each run() call.
        self._txt2img_pipe = None

    def generate(self):
        """Run the pipeline once for every preset in ``data['presets']``."""
        for preset in self.data['presets']:
            self.run(preset)

    def _get_txt2img_pipe(self):
        """Return the cached txt2img pipeline, loading it on the first call."""
        if self._txt2img_pipe is None:
            scheduler = DDIMScheduler(beta_start=0.00085,
                                      beta_end=0.012,
                                      beta_schedule="scaled_linear",
                                      clip_sample=False,
                                      set_alpha_to_one=False)
            pipe = StableDiffusionPipeline.from_pretrained(
                weights_dir(),
                scheduler=scheduler,
                safety_checker=None,
                torch_dtype=torch.float16,
                # Credentials must never live in source control (the token
                # previously hard-coded here is public and must be revoked).
                # When unset, None lets the Hub client use its own login.
                use_auth_token=os.getenv('HUGGING_FACE_HUB_TOKEN'),
            )
            self._txt2img_pipe = pipe.to(device)
        return self._txt2img_pipe

    def run(self, preset):
        """Generate and save the images described by a single preset.

        Preset keys (all optional except ``'prompt'``):
            prompt: %-format string filled with name_formatter(data['name']).
            n_samples: prompts per pipeline call (default 1).
            n_iter: number of pipeline calls (default 1).
            guidance_scale: forwarded as ``guidance_scale`` (default 7.5).
            ddim_steps: forwarded as ``num_inference_steps`` (default 50).
            ddim_eta: forwarded as ``eta`` (default 0.0).
            height, width: forwarded only when present, so the pipeline's
                own default size is used otherwise.
            seed: integer seed for the CUDA generator (default 0).
            preset_id: prefix of the saved file names (default 1).

        Raises:
            ValueError: if the OUTPUT_IMAGES environment variable is not set.
        """
        torch.cuda.empty_cache()
        capture_message('Pipeline: Run')
        # Check before the expensive generation; previously a missing value
        # only surfaced as a TypeError when saving the first image.
        if self.output_images is None:
            raise ValueError("OUTPUT_IMAGES is not set")

        prompt = preset['prompt'] % (name_formatter(self.data['name']))
        n_samples = preset.get('n_samples', 1)
        n_iter = preset.get('n_iter', 1)
        preset_id = preset.get('preset_id', 1)
        # torch.Generator.manual_seed only accepts integers; the old default
        # of 7.5 made every preset without an explicit 'seed' fail.
        seed = int(preset.get('seed', 0))

        # Preset keys use CompVis-script names; map them onto the diffusers
        # StableDiffusionPipeline.__call__ keyword arguments.
        call_kwargs = {
            'guidance_scale': preset.get('guidance_scale', 7.5),
            'num_inference_steps': preset.get('ddim_steps', 50),
            'eta': preset.get('ddim_eta', 0.0),
        }
        for size_key in ('height', 'width'):
            if size_key in preset:
                call_kwargs[size_key] = preset[size_key]

        txt2img_pipe = self._get_txt2img_pipe()
        g_cuda = torch.Generator(device='cuda')
        g_cuda.manual_seed(seed)
        # No autocast context is needed: this module only runs on CUDA and
        # the weights are already loaded in float16.
        for _ in range(n_iter):
            images = txt2img_pipe(
                prompt=[prompt] * n_samples,
                generator=g_cuda,
                **call_kwargs,
            ).images
            for img in images:
                img.save(f"{self.output_images}{preset_id}----{uuid.uuid4()}.png")