# File: picpilot-server/scripts/controlnet_outpainting.py (3.79 kB)
# Author: VikramSingh178
# Last commit: c94cf04 — "chore: Remove unused matplotlib import in controlnet_outpainting.py"
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
import torch
from PIL import Image
import lightning.pytorch as pl
from scripts.api_utils import accelerator
from typing import Optional
# Seed the global RNGs once at import time via Lightning so that runs of this
# script start from the same random state.
pl.seed_everything(seed=42)
class ImageGenerator:
    """
    Generate images with a ControlNet-conditioned Stable Diffusion XL pipeline.

    Attributes:
        controlnet (ControlNetModel): The ControlNet model, loaded with fp16 weights.
        pipeline (StableDiffusionXLControlNetPipeline): The SDXL pipeline wrapping
            ``controlnet``, moved to the device returned by ``accelerator()``.
    """

    def __init__(self, controlnet_model_name, sd_pipeline_model_name):
        """
        Load the ControlNet and SDXL pipeline weights.

        Args:
            controlnet_model_name (str): Hub id or local path of the ControlNet model.
            sd_pipeline_model_name (str): Hub id or local path of the SDXL pipeline model.
        """
        # Both models are loaded from their half-precision ("fp16") weight variant.
        self.controlnet = ControlNetModel.from_pretrained(
            controlnet_model_name, torch_dtype=torch.float16, variant="fp16"
        )
        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            sd_pipeline_model_name,
            torch_dtype=torch.float16,
            variant="fp16",
            controlnet=self.controlnet,
        ).to(accelerator())

    def inference(
        self,
        prompt,
        negative_prompt,
        height,
        width,
        guidance_scale,
        num_images_per_prompt,
        num_inference_steps,
        image_path,
        controlnet_conditioning_scale,
        control_guidance_end,
        output_path: Optional[str] = None,
    ):
        """
        Generate images conditioned on the control image at ``image_path``.

        Args:
            prompt (str): The prompt for image generation.
            negative_prompt (str): The negative prompt for image generation.
            height (int): The height of the generated images.
            width (int): The width of the generated images.
            guidance_scale (float): The classifier-free guidance scale.
            num_images_per_prompt (int): The number of images to generate per prompt.
            num_inference_steps (int): The number of denoising steps.
            image_path (str): Path to the control image fed to ControlNet.
            controlnet_conditioning_scale (float): The conditioning scale for ControlNet.
            control_guidance_end (float): Fraction of the denoising steps after
                which ControlNet guidance stops being applied.
            output_path (Optional[str]): If truthy, a directory in which each image
                is saved as ``output_{i}.png``. The directory is created if missing.

        Returns:
            list: The generated images. They are returned whether or not they were
            also saved to ``output_path``.
        """
        # Image.open is lazy and keeps the file handle open until the pixels are
        # read. Load the pixels now, inside a `with`, so the handle is released.
        # The pipeline's control-image processor converts inputs to RGB itself,
        # so doing the conversion here does not change the conditioning.
        with Image.open(image_path) as source:
            control_image = source.convert("RGB")

        images_list = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
            image=control_image,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
        ).images

        if output_path:
            self._save_images(images_list, output_path)
        # Always hand the images back, as the docstring promises. Previously
        # this returned None whenever output_path was set.
        return images_list

    @staticmethod
    def _save_images(images, output_dir):
        """
        Save each image as ``output_{i}.png`` in ``output_dir``, creating the directory if needed.

        Args:
            images (Iterable): Objects exposing a ``save(path)`` method, e.g. PIL images.
            output_dir (str): Destination directory.

        Returns:
            list: The written file paths (as ``str``), in input order.
        """
        from pathlib import Path

        directory = Path(output_dir)
        directory.mkdir(parents=True, exist_ok=True)
        saved_paths = []
        for index, image in enumerate(images):
            target = directory / f"output_{index}.png"
            image.save(target)
            saved_paths.append(str(target))
        return saved_paths
if __name__ == "__main__":
    # Demo entry point: load both models, then run one generation and have
    # `inference` write the results as PNG files into the output directory.
    generator = ImageGenerator(
        controlnet_model_name="destitech/controlnet-inpaint-dreamer-sdxl",
        sd_pipeline_model_name="RunDiffusion/Juggernaut-XL-v9",
    )
    # NOTE(review): the image and output paths below are machine-specific
    # absolute paths; adjust them for your environment.
    demo_kwargs = {
        "prompt": "Park",
        "negative_prompt": "low Resolution , Bad Resolution",
        "height": 1080,
        "width": 1920,
        "guidance_scale": 7.5,
        "num_images_per_prompt": 4,
        "num_inference_steps": 100,
        "image_path": "/home/PicPilot/sample_data/example1.jpg",
        "controlnet_conditioning_scale": 0.9,
        "control_guidance_end": 0.9,
        "output_path": "/home/PicPilot/output",
    }
    generator.inference(**demo_kwargs)