Spaces:
Build error
Build error
import warnings

import torch
from PIL import Image
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
class GeoPainting:
    """Text-to-image generation conditioned on a control image via ControlNet.

    Builds a ``StableDiffusionControlNetPipeline`` from a ControlNet checkpoint
    (``lllyasviel/control_v11f1p_sd15_depth`` by default) and a base diffusion
    checkpoint, swaps in a UniPC scheduler, and exposes
    :meth:`generate_painting` to render one image per call.
    """

    DEFAULT_CONTROLNET_MODEL = "lllyasviel/control_v11f1p_sd15_depth"
    # NOTE(review): presumably a local directory or Hub repo id holding the base
    # diffusion model -- confirm where it is expected to live.
    DEFAULT_DIFFUSER_MODEL = "geospatial_diffuser"
    # Defaults that used to be hard-coded in __init__ / generate_painting.
    DEFAULT_NEGATIVE_PROMPT = "ugly, disfigured, low quality, blurry, nsfw"
    DEFAULT_NUM_INFERENCE_STEPS = 20
    DEFAULT_SEED = 2

    def __init__(self, controlnet_model_path=DEFAULT_CONTROLNET_MODEL, diffuser_model=DEFAULT_DIFFUSER_MODEL):
        """Load the ControlNet and build the diffusion pipeline around it.

        Args:
            controlnet_model_path: Identifier/path passed to
                ``ControlNetModel.from_pretrained``.
            diffuser_model: Identifier/path passed to
                ``StableDiffusionControlNetPipeline.from_pretrained``.
        """
        # Half precision only when a GPU will run the model: without CUDA the
        # pipeline stays on the CPU, where float16 is slow and many kernels do
        # not implement it, so use float32 there.
        dtype = self._select_dtype()
        self.controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=dtype)
        # Shared CPU generator: successive generate_painting calls (without an
        # explicit seed) draw from this one seeded stream, so a call's output
        # depends on how many calls preceded it.
        self.generator = self._make_generator(self.DEFAULT_SEED)
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            diffuser_model,
            low_cpu_mem_usage=False,
            device_map=None,
            controlnet=self.controlnet,
            torch_dtype=dtype,
        )
        # Replace the checkpoint's scheduler with UniPC, reusing its config.
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        if torch.cuda.is_available():
            self.pipe.enable_model_cpu_offload()
            self._try_enable_xformers()

    @staticmethod
    def _select_dtype():
        """Return float16 when CUDA is available, otherwise float32."""
        return torch.float16 if torch.cuda.is_available() else torch.float32

    @staticmethod
    def _make_generator(seed):
        """Return a new CPU ``torch.Generator`` seeded with ``seed``."""
        return torch.Generator(device="cpu").manual_seed(seed)

    def _try_enable_xformers(self):
        """Enable xformers attention on ``self.pipe``; warn and continue if unavailable."""
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
        except (ImportError, ValueError) as err:
            # xformers is an optional memory optimisation (missing package ->
            # ModuleNotFoundError); the pipeline still works with its default
            # attention implementation, so don't abort construction.
            warnings.warn(
                f"xformers memory-efficient attention not enabled: {err}",
                RuntimeWarning,
            )

    def generate_painting(
        self,
        input_promp,
        control_image,
        *,
        negative_prompt=DEFAULT_NEGATIVE_PROMPT,
        num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
        seed=None,
    ):
        """Render one image for a prompt, conditioned on ``control_image``.

        Args:
            input_promp: Text prompt passed to the pipeline (parameter name
                kept as-is for backward compatibility with keyword callers).
            control_image: Conditioning image. Either a ``PIL.Image.Image``,
                used as-is, or an array-like with ``astype`` (e.g. a NumPy
                array), which is cast to ``uint8`` and wrapped with
                ``Image.fromarray``.
            negative_prompt: Negative prompt forwarded to the pipeline.
            num_inference_steps: Number of denoising steps.
            seed: If given, a fresh generator seeded with it is used for this
                call only, so the result does not depend on earlier calls.
                If ``None``, the shared instance generator is used and advanced.

        Returns:
            The first image in the pipeline output's ``images``.
        """
        if isinstance(control_image, Image.Image):
            image = control_image
        else:
            image = Image.fromarray(control_image.astype('uint8'))
        generator = self.generator if seed is None else self._make_generator(seed)
        output = self.pipe(
            input_promp,
            image,
            negative_prompt=negative_prompt,
            generator=generator,
            num_inference_steps=num_inference_steps,
        )
        return output.images[0]