Spaces:
Runtime error
Runtime error
File size: 3,786 Bytes
ffaa8aa f1bf71a ffaa8aa f1bf71a ffaa8aa f1bf71a ffaa8aa f1bf71a ffaa8aa f1bf71a ffaa8aa f1bf71a ffaa8aa f1bf71a ffaa8aa f1bf71a ffaa8aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import argparse
from pathlib import Path
from typing import Optional

import lightning.pytorch as pl
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from PIL import Image

from scripts.api_utils import accelerator
# Seed the global RNGs once, at import time, through Lightning.
# The pipeline call in ImageGenerator.inference passes no explicit generator,
# so its noise sampling draws from the torch RNG seeded here.
# Repeated runs are therefore reproducible, up to GPU nondeterminism.
pl.seed_everything(seed=42)
class ImageGenerator:
    """
    Generate images with a ControlNet-conditioned Stable Diffusion XL pipeline.

    Attributes:
        controlnet (ControlNetModel): The ControlNet model, loaded in float16.
        pipeline (StableDiffusionXLControlNetPipeline): The Stable Diffusion XL
            pipeline with ``controlnet`` attached, moved to ``accelerator()``.
    """

    def __init__(self, controlnet_model_name, sd_pipeline_model_name):
        """
        Load the ControlNet and the SDXL pipeline with float16 "fp16"-variant weights.

        Args:
            controlnet_model_name (str): Name or path passed to
                ``ControlNetModel.from_pretrained``.
            sd_pipeline_model_name (str): Name or path passed to
                ``StableDiffusionXLControlNetPipeline.from_pretrained``.
        """
        self.controlnet = ControlNetModel.from_pretrained(
            controlnet_model_name, torch_dtype=torch.float16, variant="fp16"
        )
        # NOTE(review): accelerator() comes from scripts.api_utils; presumably
        # it returns the target device (e.g. "cuda") - confirm there.
        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            sd_pipeline_model_name,
            torch_dtype=torch.float16,
            variant="fp16",
            controlnet=self.controlnet,
        ).to(accelerator())

    def inference(
        self,
        prompt,
        negative_prompt,
        height,
        width,
        guidance_scale,
        num_images_per_prompt,
        num_inference_steps,
        image_path,
        controlnet_conditioning_scale,
        control_guidance_end,
        output_path: Optional[str] = None,
    ):
        """
        Generate images from a prompt, conditioned on a control image.

        Args:
            prompt (str): The prompt for image generation.
            negative_prompt (str): The negative prompt for image generation.
            height (int): The height of the generated images.
            width (int): The width of the generated images.
            guidance_scale (float): The classifier-free guidance scale.
            num_images_per_prompt (int): The number of images to generate per prompt.
            num_inference_steps (int): The number of denoising steps.
            image_path (str): Path of the control image fed to the ControlNet.
            controlnet_conditioning_scale (float): The conditioning scale for ControlNet.
            control_guidance_end (float): Fraction of the denoising steps after
                which the ControlNet stops being applied.
            output_path (Optional[str]): Directory into which the images are
                written as ``output_0.png``, ``output_1.png``, ...; it is
                created if missing. When ``None`` or empty, nothing is saved.

        Returns:
            list: The generated images. They are returned whether or not they
            were also saved to ``output_path``.
        """
        # Open the control image in a context manager so the file handle is
        # closed. copy() loads the pixel data, so the image stays usable after
        # the handle is gone.
        with Image.open(image_path) as source:
            control_image = source.copy()

        images_list = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
            image=control_image,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
        ).images

        if output_path:
            output_dir = Path(output_path)
            # Make sure the directory exists before saving, so a missing folder
            # does not throw away an already-finished (and slow) generation.
            output_dir.mkdir(parents=True, exist_ok=True)
            for index, image in enumerate(images_list):
                image.save(output_dir / f"output_{index}.png")

        # Always hand the images back, matching the documented return value.
        return images_list
if __name__ == "__main__":
    # Command-line entry point. Every option defaults to the value this script
    # used to hard-code, so running it without arguments behaves as before,
    # while the machine-specific paths and models can now be overridden.
    parser = argparse.ArgumentParser(
        description="Generate SDXL ControlNet images from a control image."
    )
    parser.add_argument(
        "--controlnet-model", default="destitech/controlnet-inpaint-dreamer-sdxl"
    )
    parser.add_argument("--sd-model", default="RunDiffusion/Juggernaut-XL-v9")
    parser.add_argument("--prompt", default='Park')
    parser.add_argument("--negative-prompt", default='low Resolution , Bad Resolution')
    parser.add_argument("--height", type=int, default=1080)
    parser.add_argument("--width", type=int, default=1920)
    parser.add_argument("--guidance-scale", type=float, default=7.5)
    parser.add_argument("--num-images", type=int, default=4)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument(
        "--image-path", default='/home/PicPilot/sample_data/example1.jpg'
    )
    parser.add_argument("--controlnet-conditioning-scale", type=float, default=0.9)
    parser.add_argument("--control-guidance-end", type=float, default=0.9)
    parser.add_argument("--output-path", default='/home/PicPilot/output')
    args = parser.parse_args()

    generator = ImageGenerator(
        controlnet_model_name=args.controlnet_model,
        sd_pipeline_model_name=args.sd_model,
    )
    generator.inference(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        height=args.height,
        width=args.width,
        guidance_scale=args.guidance_scale,
        num_images_per_prompt=args.num_images,
        num_inference_steps=args.steps,
        image_path=args.image_path,
        controlnet_conditioning_scale=args.controlnet_conditioning_scale,
        control_guidance_end=args.control_guidance_end,
        output_path=args.output_path,
    )
|