File size: 2,970 Bytes
96960b9
4fe6ecb
fa32203
d030f4c
fa32203
e2ebd5a
fa32203
080c0ac
 
6c850b1
 
 
 
 
 
f419a09
f74e97b
fa32203
6c850b1
 
 
 
5626570
 
6c850b1
0695699
6c850b1
 
 
 
 
 
 
 
0695699
6c850b1
 
 
 
fa32203
6c850b1
 
fa32203
6c850b1
 
 
 
5626570
 
fa32203
 
 
 
5626570
6c850b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa32203
6c850b1
 
 
 
fa32203
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
from diffusers import AutoPipelineForInpainting,DiffusionPipeline
from diffusers.utils import load_image
from api_utils import accelerator, ImageAugmentation
import hydra
from omegaconf import DictConfig
from PIL import Image

def load_pipeline(model_name: str, device, enable_compile: bool = True):
    """Build a float16 inpainting pipeline and place it on ``device``.

    Args:
        model_name: Identifier forwarded to
            ``AutoPipelineForInpainting.from_pretrained``.
        device: Target handed to ``pipeline.to``.
        enable_compile: If true, the UNet is switched to channels-last memory
            format and wrapped with ``torch.compile`` (``reduce-overhead``,
            full graph) before the pipeline is moved.

    Returns:
        The loaded pipeline (with a compiled UNet when requested).
    """
    pipe = AutoPipelineForInpainting.from_pretrained(
        model_name, torch_dtype=torch.float16
    )

    if enable_compile:
        unet = pipe.unet
        # nn.Module.to(memory_format=...) converts the parameters in place.
        unet.to(memory_format=torch.channels_last)
        pipe.unet = torch.compile(unet, mode='reduce-overhead', fullgraph=True)

    pipe.to(device)
    return pipe


class AutoPaintingPipeline:
    """Run a preloaded inpainting pipeline on a fixed image/mask pair.

    The source image, mask and output size are bound at construction time,
    so each ``run_inference`` call only supplies the sampling parameters.
    """

    def __init__(
        self,
        pipeline,
        image: "Image.Image",
        mask_image: "Image.Image",
        target_width: int,
        target_height: int,
    ):
        """Store the pipeline and the inputs shared by every inference call.

        Args:
            pipeline: Callable pipeline. It is invoked with keyword arguments,
                and its result must expose an ``images`` sequence.
            image: Image forwarded as ``image=``.
            mask_image: Mask forwarded as ``mask_image=``.
            target_width: Output width forwarded as ``width=``.
            target_height: Output height forwarded as ``height=``.
        """
        self.pipeline = pipeline
        self.image = image
        self.mask_image = mask_image
        self.target_width = target_width
        self.target_height = target_height

    def run_inference(
        self,
        prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        strength: float,
        guidance_scale: float,
        num_images: int = 1,
        return_all: bool = False,
    ):
        """Call the wrapped pipeline once with the bound image, mask and size.

        Args:
            prompt: Text prompt forwarded as ``prompt=``.
            negative_prompt: Forwarded as ``negative_prompt=``.
            num_inference_steps: Forwarded as ``num_inference_steps=``.
            strength: Forwarded as ``strength=``.
            guidance_scale: Forwarded as ``guidance_scale=``.
            num_images: Forwarded as ``num_images_per_prompt=``. Defaults to 1
                so callers that only pass the sampling parameters keep working.
            return_all: If true, return every generated image as a list.
                Otherwise return only the first one (the original behavior).

        Returns:
            The first generated image, or a list of all images when
            ``return_all`` is true.
        """
        result = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=self.image,
            mask_image=self.mask_image,
            num_inference_steps=num_inference_steps,
            strength=strength,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
            height=self.target_height,
            width=self.target_width,
        )
        if return_all:
            return list(result.images)
        # Without return_all, any images beyond the first are discarded.
        return result.images[0]

@hydra.main(version_base=None, config_path="../configs", config_name="inpainting")
def inference(cfg: DictConfig):
    """Inpaint one sample image as described by the ``inpainting`` Hydra config.

    Config keys read: ``model``, ``target_width``, ``target_height``,
    ``segmentation_model``, ``detection_model``, ``prompt``,
    ``negative_prompt``, ``num_inference_steps``, ``strength``,
    ``guidance_scale`` and ``output_path``. Optional keys:

    - ``image_path`` defaults to ``../sample_data/example3.jpg``. A relative
      path is resolved against the directory the script was launched from.
    - ``num_images`` defaults to 1.

    Side effects: writes ``output.jpg`` and ``mask.jpg`` into
    ``cfg.output_path``, creating that directory if it does not exist.
    """
    import os

    from hydra.utils import to_absolute_path

    # Load the pipeline once (with torch.compile enabled) and reuse it.
    pipeline = load_pipeline(cfg.model, accelerator(), True)

    # Image augmentation and preparation.
    augmenter = ImageAugmentation(target_width=cfg.target_width, target_height=cfg.target_height)
    # With version_base=None Hydra may chdir into its run directory, which
    # would break a path relative to the launch directory, so resolve it
    # against the original working directory.
    raw_path = cfg.image_path if "image_path" in cfg else "../sample_data/example3.jpg"
    image_path = to_absolute_path(raw_path)
    image = Image.open(image_path)
    extended_image = augmenter.extend_image(image)
    mask_image = augmenter.generate_mask_from_bbox(extended_image, cfg.segmentation_model, cfg.detection_model)
    # NOTE(review): the detected mask is inverted, presumably so that the
    # region outside the detected object is repainted; confirm against
    # ImageAugmentation.invert_mask.
    mask_image = augmenter.invert_mask(mask_image)

    # Wrap the already-loaded pipeline together with the prepared inputs.
    painting_pipeline = AutoPaintingPipeline(
        pipeline=pipeline,
        image=extended_image,
        mask_image=mask_image,
        target_height=cfg.target_height,
        target_width=cfg.target_width,
    )

    # Run inference. num_images is passed explicitly so the call does not
    # rely on run_inference having a default for it.
    num_images = cfg.num_images if "num_images" in cfg else 1
    output = painting_pipeline.run_inference(
        prompt=cfg.prompt,
        negative_prompt=cfg.negative_prompt,
        num_inference_steps=cfg.num_inference_steps,
        strength=cfg.strength,
        guidance_scale=cfg.guidance_scale,
        num_images=num_images,
    )

    # Save output and mask images; PIL cannot save into a missing directory.
    os.makedirs(cfg.output_path, exist_ok=True)
    output.save(os.path.join(cfg.output_path, "output.jpg"))
    mask_image.save(os.path.join(cfg.output_path, "mask.jpg"))

if __name__ == "__main__":
    # @hydra.main composes `cfg` from the "inpainting" config (plus any
    # command-line overrides) and injects it, so no argument is passed here.
    inference()