Geo-Painting / geo_painting.py
DavidTamayo's picture
General clean up
ca559d4
import torch
from PIL import Image
from diffusers import (
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler,
ControlNetModel
)
class GeoPainting:
    """Generate images with a Stable Diffusion pipeline guided by a ControlNet.

    The constructor loads a ControlNet and a Stable Diffusion ControlNet
    pipeline, swaps in a ``UniPCMultistepScheduler`` and, when CUDA is
    available, enables model CPU offload and xformers attention.

    A single seeded ``torch.Generator`` is created at construction time and
    reused for every call to :meth:`generate_painting`, so its state advances
    from one call to the next.
    """

    DEFAULT_CONTROLNET_MODEL = "lllyasviel/control_v11f1p_sd15_depth"
    DEFAULT_DIFFUSER_MODEL = "geospatial_diffuser"
    # Generation settings that used to be hard-coded in generate_painting().
    DEFAULT_NEGATIVE_PROMPT = "ugly, disfigured, low quality, blurry, nsfw"
    DEFAULT_NUM_INFERENCE_STEPS = 20
    # Seed of the generator shared by all generate_painting() calls.
    DEFAULT_SEED = 2

    def __init__(self, controlnet_model_path=DEFAULT_CONTROLNET_MODEL, diffuser_model=DEFAULT_DIFFUSER_MODEL):
        """Load the ControlNet and the diffusion pipeline.

        Args:
            controlnet_model_path: Identifier or path handed to
                ``ControlNetModel.from_pretrained``.
            diffuser_model: Identifier or path handed to
                ``StableDiffusionControlNetPipeline.from_pretrained``.
        """
        use_cuda = torch.cuda.is_available()
        # Both models must share one dtype; half precision is only used when
        # a CUDA device will actually run the computation.
        dtype = self._select_dtype(use_cuda)

        self.controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=dtype)
        self.generator = torch.Generator(device="cpu").manual_seed(self.DEFAULT_SEED)
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            diffuser_model,
            low_cpu_mem_usage=False,
            device_map=None,
            controlnet=self.controlnet,
            torch_dtype=dtype,
        )
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)

        if use_cuda:
            # GPU-only optimizations; xformers attention additionally needs
            # the optional ``xformers`` package to be installed.
            self.pipe.enable_model_cpu_offload()
            self.pipe.enable_xformers_memory_efficient_attention()

    @staticmethod
    def _select_dtype(use_cuda):
        """Return ``torch.float16`` when CUDA is usable, else ``torch.float32``."""
        return torch.float16 if use_cuda else torch.float32

    def generate_painting(
        self,
        input_promp,
        control_image,
        *,
        negative_prompt=DEFAULT_NEGATIVE_PROMPT,
        num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
    ):
        """Run the pipeline once and return the first generated image.

        Args:
            input_promp: Prompt passed to the pipeline as its first argument.
                (The parameter name is kept as-is for caller compatibility.)
            control_image: Conditioning image. Either a ``PIL.Image.Image``,
                which is used unchanged, or an array-like object exposing
                ``astype`` (e.g. a NumPy array), which is cast to ``uint8``
                and converted with ``Image.fromarray``.
            negative_prompt: Negative prompt forwarded to the pipeline.
            num_inference_steps: Number of denoising steps forwarded to the
                pipeline.

        Returns:
            The first element of the pipeline output's ``images``.
        """
        if isinstance(control_image, Image.Image):
            image = control_image
        else:
            image = Image.fromarray(control_image.astype('uint8'))

        output = self.pipe(
            input_promp,
            image,
            negative_prompt=negative_prompt,
            generator=self.generator,
            num_inference_steps=num_inference_steps,
        )
        return output.images[0]