# pictero/helpers.py
# Commit a494f64 ("fixing minor bug") by nickyreinert-vml.
import torch
from diffusers import (
DDPMScheduler,
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
DPMSolverMultistepScheduler,
)
import base64
def get_variant(str_variant):
    """Normalize a config "variant" value: the string 'none' (any case) means no variant.

    Returns None for 'none'/'None'/None, otherwise the value unchanged.
    """
    return None if str(str_variant).lower() == 'none' else str_variant
def get_bool(str_bool):
    """Parse a config boolean: only the string 'false' (any case) is False.

    NOTE(review): every other value — including '0', 'no', or '' — is
    treated as True; this mirrors the original permissive behavior.
    """
    return str(str_bool).lower() != 'false'
def get_data_type(str_data_type):
    """Map a dtype name from the config to a torch dtype.

    'bfloat16' -> torch.bfloat16 (not supported on MPS as of 01/2024),
    'float32'  -> torch.float32,
    anything else falls back to torch.float16 (half-precision weights save
    GPU memory, per https://huggingface.co/docs/diffusers/main/en/optimization/fp16).
    """
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    return dtype_map.get(str_data_type, torch.float16)
def get_tensorfloat32(allow_tensorfloat32):
    """Parse the TF32 switch: True only for the string 'true' (any case)."""
    return str(allow_tensorfloat32).lower() == 'true'
def get_pipeline(config):
    """Build a diffusers pipeline from a config dict and move it to the target device.

    Reads config keys: "pipeline" (optional), "model", "use_safetensors",
    "data_type", "variant", "device".

    Returns the pipeline instance placed on config["device"].
    """
    # The two original branches were identical except for the pipeline class;
    # select the class once and make the from_pretrained call in one place.
    if config.get("pipeline") == "StableDiffusion3Pipeline":
        from diffusers import StableDiffusion3Pipeline as pipeline_class
    else:
        from diffusers import DiffusionPipeline as pipeline_class
    return pipeline_class.from_pretrained(
        config["model"],
        use_safetensors=get_bool(config["use_safetensors"]),
        torch_dtype=get_data_type(config["data_type"]),
        variant=get_variant(config["variant"]),
    ).to(config["device"])
def get_scheduler(scheduler, pipeline_config):
    """Instantiate a diffusers scheduler by name from a pipeline's scheduler config.

    Unknown names fall back to DPMSolverMultistepScheduler (the original
    if/elif chain had a redundant final branch duplicating this default).
    """
    # Dispatch table replaces the long if/elif chain; all classes are
    # imported at module level.
    scheduler_classes = {
        "DDPMScheduler": DDPMScheduler,
        "DDIMScheduler": DDIMScheduler,
        "PNDMScheduler": PNDMScheduler,
        "LMSDiscreteScheduler": LMSDiscreteScheduler,
        "EulerAncestralDiscreteScheduler": EulerAncestralDiscreteScheduler,
        "EulerDiscreteScheduler": EulerDiscreteScheduler,
        "DPMSolverMultistepScheduler": DPMSolverMultistepScheduler,
    }
    scheduler_class = scheduler_classes.get(scheduler, DPMSolverMultistepScheduler)
    return scheduler_class.from_config(pipeline_config)
def dict_list_to_markdown_table(config_history):
    """Render a list of config dicts as a horizontally-scrollable markdown table.

    Each row gets a share link whose query string carries the base64-encoded
    repr of the config. Returns "" for an empty/None history.

    Columns come from the keys of the FIRST dict; later dicts contribute ""
    for missing keys (extra keys in later dicts are dropped).
    """
    from urllib.parse import quote

    if not config_history:
        return ""

    def _cell(value):
        # Escape pipes so a value containing '|' cannot break the table layout.
        return str(value).replace("|", "\\|")

    headers = list(config_history[0].keys())
    markdown_table = "| share | " + " | ".join(_cell(h) for h in headers) + " |\n"
    markdown_table += "| --- | " + " | ".join(["---"] * len(headers)) + " |\n"
    for config in config_history:
        encoded_config = base64.b64encode(str(config).encode()).decode()
        # Standard base64 contains '+', '/' and '=' which are not query-string
        # safe ('+' decodes to a space); percent-encode so URL parsing restores
        # the exact base64 string on the receiving side.
        share_link = f'<a target="_blank" href="?config={quote(encoded_config)}">πŸ“Ž</a>'
        markdown_table += f"| {share_link} | " + " | ".join(_cell(config.get(key, "")) for key in headers) + " |\n"
    markdown_table = '<div style="overflow-x: auto;">\n\n' + markdown_table + '</div>'
    return markdown_table