Bingsu's picture
Upload files: v0.2.1
cd267d9
raw
history blame contribute delete
No virus
1.57 kB
from __future__ import annotations
from functools import cached_property
from diffusers import (
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
)
from asdff.base import AdPipelineBase
class AdPipeline(AdPipelineBase, StableDiffusionPipeline):
    """Stable Diffusion text-to-image pipeline combined with ``AdPipelineBase``.

    Defines two hooks: the text-to-image pipeline class and an inpainting
    pipeline assembled from this pipeline's own components.

    NOTE(review): how ``AdPipelineBase`` consumes these hooks is not visible
    in this file -- confirm against ``asdff.base``.
    """

    @property
    def txt2img_class(self):
        """Return the text-to-image pipeline class, ``StableDiffusionPipeline``."""
        return StableDiffusionPipeline

    @cached_property
    def inpaint_pipeline(self):
        """Build a ``StableDiffusionInpaintPipeline`` from this pipeline's components.

        The component objects held by this pipeline are handed over as-is
        (no copies are made here), and ``requires_safety_checker`` is read
        from ``self.config``.

        ``cached_property`` means the pipeline is constructed on first access
        and the same object is returned afterwards, so attributes reassigned
        on this pipeline after that first access are not picked up.
        """
        # Keys are listed in the constructor's keyword order; the attribute
        # reads happen in exactly this sequence.
        components = {
            "vae": self.vae,
            "text_encoder": self.text_encoder,
            "tokenizer": self.tokenizer,
            "unet": self.unet,
            "scheduler": self.scheduler,
            "safety_checker": self.safety_checker,
            "feature_extractor": self.feature_extractor,
            "requires_safety_checker": self.config.requires_safety_checker,
        }
        return StableDiffusionInpaintPipeline(**components)
class AdCnPipeline(AdPipelineBase, StableDiffusionControlNetPipeline):
    """ControlNet text-to-image pipeline combined with ``AdPipelineBase``.

    Defines two hooks: the text-to-image pipeline class and a ControlNet
    inpainting pipeline assembled from this pipeline's own components.

    NOTE(review): how ``AdPipelineBase`` consumes these hooks is not visible
    in this file -- confirm against ``asdff.base``.
    """

    @property
    def txt2img_class(self):
        """Return the text-to-image class, ``StableDiffusionControlNetPipeline``."""
        return StableDiffusionControlNetPipeline

    @cached_property
    def inpaint_pipeline(self):
        """Build a ``StableDiffusionControlNetInpaintPipeline`` from this pipeline.

        The component objects held by this pipeline, including its
        ``controlnet``, are handed over as-is (no copies are made here), and
        ``requires_safety_checker`` is read from ``self.config``.

        ``cached_property`` means the pipeline is constructed on first access
        and the same object is returned afterwards, so attributes reassigned
        on this pipeline after that first access are not picked up.
        """
        # Keys are listed in the constructor's keyword order; the attribute
        # reads happen in exactly this sequence.
        components = {
            "vae": self.vae,
            "text_encoder": self.text_encoder,
            "tokenizer": self.tokenizer,
            "unet": self.unet,
            "controlnet": self.controlnet,
            "scheduler": self.scheduler,
            "safety_checker": self.safety_checker,
            "feature_extractor": self.feature_extractor,
            "requires_safety_checker": self.config.requires_safety_checker,
        }
        return StableDiffusionControlNetInpaintPipeline(**components)