File size: 3,786 Bytes
ffaa8aa
f1bf71a
 
ffaa8aa
 
 
 
f1bf71a
ffaa8aa
 
 
f1bf71a
ffaa8aa
 
 
 
f1bf71a
ffaa8aa
 
 
f1bf71a
ffaa8aa
 
 
 
 
 
 
 
 
 
 
 
 
 
f1bf71a
ffaa8aa
 
 
f1bf71a
ffaa8aa
 
 
 
 
 
 
 
 
 
 
f1bf71a
ffaa8aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from pathlib import Path
from typing import Optional

import lightning.pytorch as pl
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from PIL import Image

from scripts.api_utils import accelerator
# Fix the global random seed once, at import time, so that repeated runs of
# this module are intended to produce the same images.
pl.seed_everything(seed=42)

class ImageGenerator:
    """
    A class to generate images using ControlNet and Stable Diffusion XL pipelines.

    Attributes:
        controlnet (ControlNetModel): The ControlNet model, loaded in float16
            from its "fp16" weight variant.
        pipeline (StableDiffusionXLControlNetPipeline): The Stable Diffusion XL
            pipeline with ControlNet, moved to the device returned by
            ``accelerator()``.
    """

    def __init__(self, controlnet_model_name, sd_pipeline_model_name):
        """
        Initializes the ImageGenerator with the specified models.

        Both models are loaded with ``torch_dtype=torch.float16`` and
        ``variant="fp16"``. The assembled pipeline is then moved to the
        device chosen by ``accelerator()``.

        Args:
            controlnet_model_name (str): The name of the ControlNet model
                (e.g. a Hugging Face Hub id).
            sd_pipeline_model_name (str): The name of the Stable Diffusion XL
                pipeline model (e.g. a Hugging Face Hub id).
        """
        self.controlnet = ControlNetModel.from_pretrained(
            controlnet_model_name, torch_dtype=torch.float16, variant="fp16"
        )
        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            sd_pipeline_model_name,
            torch_dtype=torch.float16,
            variant="fp16",
            controlnet=self.controlnet,
        ).to(accelerator())

    def inference(
        self,
        prompt,
        negative_prompt,
        height,
        width,
        guidance_scale,
        num_images_per_prompt,
        num_inference_steps,
        image_path,
        controlnet_conditioning_scale,
        control_guidance_end,
        output_path: Optional[str] = None,
    ):
        """
        Generates images based on the provided parameters.

        Args:
            prompt (str): The prompt for image generation.
            negative_prompt (str): The negative prompt for image generation.
            height (int): The height of the generated images.
            width (int): The width of the generated images.
            guidance_scale (float): The classifier-free guidance scale.
            num_images_per_prompt (int): The number of images to generate per prompt.
            num_inference_steps (int): The number of denoising steps.
            image_path (str): Path to the conditioning image passed to ControlNet.
            controlnet_conditioning_scale (float): Multiplier applied to the
                ControlNet outputs.
            control_guidance_end (float): Fraction of the total steps after
                which ControlNet stops being applied.
            output_path (Optional[str]): Directory to save the images into as
                ``output_{i}.png``. It is created if it does not exist. When
                falsy, nothing is written to disk.

        Returns:
            list: The generated images, exactly as returned by the pipeline's
            ``.images`` attribute. They are returned whether or not they were
            also saved to ``output_path``.
        """
        # Use a context manager so the conditioning image's file handle is
        # released once the pipeline has consumed it.
        with Image.open(image_path) as control_image:
            images_list = self.pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_images_per_prompt,
                num_inference_steps=num_inference_steps,
                image=control_image,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                control_guidance_end=control_guidance_end,
            ).images

        if output_path:
            output_dir = Path(output_path)
            # Create the target directory rather than failing on the first save.
            output_dir.mkdir(parents=True, exist_ok=True)
            for i, image in enumerate(images_list):
                image.save(output_dir / f"output_{i}.png")

        return images_list
        
if __name__ == "__main__":
    # Example run. Load the inpaint ControlNet together with the Juggernaut XL
    # pipeline, then generate four 1920x1080 images for the prompt "Park",
    # conditioned on a sample image, and save them into the output directory.
    demo = ImageGenerator(
        controlnet_model_name="destitech/controlnet-inpaint-dreamer-sdxl",
        sd_pipeline_model_name="RunDiffusion/Juggernaut-XL-v9",
    )
    run_settings = {
        "prompt": 'Park',
        "negative_prompt": 'low Resolution , Bad Resolution',
        "height": 1080,
        "width": 1920,
        "guidance_scale": 7.5,
        "num_images_per_prompt": 4,
        "num_inference_steps": 100,
        "image_path": '/home/PicPilot/sample_data/example1.jpg',
        "controlnet_conditioning_scale": 0.9,
        "control_guidance_end": 0.9,
        "output_path": '/home/PicPilot/output',
    }
    demo.inference(**run_settings)