# fashionsd / sdfile.py
# Author: Abhi5ingh
# Last change: commit 575593e "Update sdfile.py" (verified); file size 3.09 kB; malware scan: no virus found.
import gc
import datetime
import os
import re
from typing import Literal
import streamlit as st
import torch
from diffusers import (
StableDiffusionPipeline,
StableDiffusionControlNetPipeline,
ControlNetModel,
EulerDiscreteScheduler,
DDIMScheduler,
)
PIPELINES = Literal["txt2img", "sketch2img"]

# Model sources used by get_pipelines. The cache-directory default keeps the
# original (Windows) location; pass ``cache_dir`` to get_pipelines to override it.
_BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
_CONTROLNET_MODEL_ID = "Abhi5ingh/ControlnetDresscode"
_DEFAULT_CACHE_DIR = "D:/huggingface/CACHE/"
# Directory handed to `unet.load_attn_procs` (the current working directory).
# NOTE(review): presumably holds fine-tuned attention/LoRA weights — confirm.
_ATTN_PROCS_DIR = "./"


def _disabled_safety_checker(images, **kwargs):
    """Replacement safety checker: return *images* unchanged and flag none of them."""
    return images, [False] * len(images)


@st.cache_resource(max_entries=1)
def get_pipelines(
    name: PIPELINES,
    enable_cpu_offload=False,
    cache_dir=_DEFAULT_CACHE_DIR,
) -> "StableDiffusionPipeline | StableDiffusionControlNetPipeline":
    """Load the diffusers pipeline called *name* and prepare it for inference.

    Both pipelines start from ``_BASE_MODEL_ID`` loaded in float16. They get the
    attention processors from ``_ATTN_PROCS_DIR`` loaded into the UNet. Their
    safety checker is replaced with one that never flags an image.

    ``st.cache_resource(max_entries=1)`` keeps at most one cached pipeline.
    Requesting a different one evicts the previous cache entry.

    Args:
        name: ``"txt2img"`` for plain text-to-image, or ``"sketch2img"`` for the
            ControlNet (``_CONTROLNET_MODEL_ID``) sketch-conditioned pipeline.
        enable_cpu_offload: if true, call ``enable_model_cpu_offload()``
            instead of moving the whole pipeline to ``"cuda"``.
        cache_dir: download cache directory passed to every
            ``from_pretrained`` call.

    Returns:
        The configured ``StableDiffusionPipeline`` (txt2img) or
        ``StableDiffusionControlNetPipeline`` (sketch2img).

    Raises:
        ValueError: if *name* is not a known pipeline.
    """
    if name == "txt2img":
        pipe = StableDiffusionPipeline.from_pretrained(
            _BASE_MODEL_ID, torch_dtype=torch.float16, cache_dir=cache_dir
        )
    elif name == "sketch2img":
        controlnet = ControlNetModel.from_pretrained(
            _CONTROLNET_MODEL_ID, torch_dtype=torch.float16, cache_dir=cache_dir
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            _BASE_MODEL_ID,
            controlnet=controlnet,
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
        )
    else:
        # ValueError subclasses Exception, so existing `except Exception` handlers still catch it.
        raise ValueError(f"Pipeline not Found {name}")

    # Identical setup for both pipelines.
    pipe.unet.load_attn_procs(_ATTN_PROCS_DIR)
    pipe.safety_checker = _disabled_safety_checker

    if enable_cpu_offload:
        print("Enabling cpu offloading for the given pipeline")
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to("cuda")
    return pipe
def generate(
prompt,
pipeline_name: PIPELINES,
image = None,
num_inference_steps = 30,
negative_prompt = None,
width = 512,
height = 512,
guidance_scale = 7.5,
controlnet_conditioning_scale = None,
enable_cpu_offload= False):
negative_prompt = negative_prompt if negative_prompt else None
p = st.progress(0)
callback = lambda step,*_: p.progress(step/num_inference_steps)
pipe = get_pipelines(pipeline_name,enable_cpu_offload=enable_cpu_offload)
torch.cuda.empty_cache()
kwargs = dict(
prompt = prompt,
negative_prompt=negative_prompt,
num_inference_steps=num_inference_steps,
callback=callback,
guidance_scale=guidance_scale,
)
print("kwargs",kwargs)
if pipeline_name =="sketch2img" and image:
kwargs.update(image=image,controlnet_conditioning_scale=controlnet_conditioning_scale)
elif pipeline_name == "txt2img":
kwargs.update(width = width, height = height)
else:
raise Exception(
f"Cannot generate image for pipeline {pipeline_name} and {prompt}")
images = pipe(**kwargs).images
image = images[0]
os.makedirs("outputs", exist_ok=True)
filename = (
"outputs/"
+ re.sub(r"\s+", "_",prompt)[:30]
+ f"_{datetime.datetime.now().timestamp()}"
)
image.save(f"{filename}.png")
with open(f"{filename}.txt", "w") as f:
f.write(f"Prompt: {prompt}\n\nNegative Prompt:{negative_prompt}")
return image