# 9fo912 / model.py
# Author: pengdaqian, commit 65c3ba9 ("fix more")
# (Copied from a hosted file viewer: raw / history / blame, 1.49 kB)
import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler, \
OnnxStableDiffusionPipeline
import pipeline_openvino_stable_diffusion
from optimum.intel.openvino import OVStableDiffusionPipeline
def get_sd_21():
    """Load the Stable Diffusion 2.1 base pipeline with an Euler discrete scheduler.

    When CUDA is available the weights are loaded in float16 and the pipeline
    is moved to the GPU. Otherwise the pipeline stays on the CPU in float32.
    The previous version also requested float16 on CPU, where half-precision
    kernels are generally unavailable, so inference there would fail.

    Returns:
        StableDiffusionPipeline: the loaded pipeline (on "cuda" when available).
    """
    model_id = "stabilityai/stable-diffusion-2-1-base"
    scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

    use_cuda = torch.cuda.is_available()
    # Half precision only on GPU; the CPU path needs full precision.
    dtype = torch.float16 if use_cuda else torch.float32

    # A single load path for both devices: only the dtype and the final
    # device placement differ.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        scheduler=scheduler,
        # safety_checker=None was left disabled by the original author, so the
        # pipeline's default safety checker is kept.
        # The "fp16" revision is requested on both devices; the weights are
        # loaded with ``dtype`` as the requested torch dtype.
        revision="fp16",
        torch_dtype=dtype,
    )
    if use_cuda:
        pipe = pipe.to("cuda")
    return pipe
def get_sd_small():
    """Load OFA-Sys' small Stable Diffusion model as an OpenVINO pipeline.

    The pipeline is created with ``compile=False`` and then compiled
    explicitly before being returned.

    Returns:
        OVStableDiffusionPipeline: the compiled pipeline.
    """
    model_id = "OFA-Sys/small-stable-diffusion-v0"
    # NOTE(review): the previous version also built a DPMSolverMultistepScheduler
    # from this repo but never passed it to the pipeline. That was a wasted hub
    # fetch, so it was removed. If DPM-Solver sampling was intended, wire it in
    # explicitly and confirm the output change is wanted.
    pipe = OVStableDiffusionPipeline.from_pretrained(model_id, compile=False)
    pipe.compile()
    return pipe
def get_sd_tiny():
    """Build the quantized SD 2.1 OpenVINO pipeline with a fixed input shape.

    Loading uses ``compile=False`` so the model can be reshaped to a single
    512x512 image per prompt before it is compiled.

    Returns:
        OVStableDiffusionPipeline: the reshaped and compiled pipeline.
    """
    repo_id = "OpenVINO/stable-diffusion-2-1-quantized"
    pipeline = OVStableDiffusionPipeline.from_pretrained(repo_id, compile=False)

    # Static shape applied before compilation: one 512x512 image per call.
    static_shape = dict(batch_size=1, height=512, width=512, num_images_per_prompt=1)
    pipeline.reshape(**static_shape)

    pipeline.compile()
    return pipeline