# picpilot-server/scripts/inpainting_pipeline.py
# VikramSingh178 — chore: Update inpainting pipeline configuration and parameters (5c75d65)
import torch
from diffusers import AutoPipelineForInpainting,DiffusionPipeline
from diffusers.utils import load_image
from api_utils import accelerator, ImageAugmentation
import hydra
from omegaconf import DictConfig
from PIL import Image
def load_pipeline(model_name: str, device, enable_compile: bool = True):
    """Build an inpainting pipeline in float16 and place it on ``device``.

    Args:
        model_name: Model identifier or path handed to
            ``AutoPipelineForInpainting.from_pretrained``.
        device: Target passed to ``pipeline.to``.
        enable_compile: When true, the UNet is switched to the
            ``channels_last`` memory format and wrapped with
            ``torch.compile`` before the pipeline is moved to ``device``.

    Returns:
        The loaded (and optionally compiled) pipeline.
    """
    pipe = AutoPipelineForInpainting.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
    )

    if enable_compile:
        unet = pipe.unet
        unet.to(memory_format=torch.channels_last)
        # fullgraph=True makes torch.compile fail loudly on graph breaks
        # instead of silently falling back to eager execution.
        pipe.unet = torch.compile(unet, mode='reduce-overhead', fullgraph=True)

    pipe.to(device)
    return pipe
class AutoPaintingPipeline:
    """Pair an inpainting pipeline with the image, mask, and output size to use.

    The instance only stores its inputs; all work happens in
    ``run_inference``.
    """

    def __init__(
        self,
        pipeline,
        image: "Image.Image",
        mask_image: "Image.Image",
        target_width: int,
        target_height: int,
    ):
        """Store the pipeline and the inputs it will be called with.

        Args:
            pipeline: Callable inpainting pipeline; it is invoked with keyword
                arguments and must return an object exposing ``.images``.
            image: Source image forwarded as ``image``.
            mask_image: Mask forwarded as ``mask_image``.
            target_width: Value forwarded as ``width``.
            target_height: Value forwarded as ``height``.
        """
        self.pipeline = pipeline
        self.image = image
        self.mask_image = mask_image
        self.target_width = target_width
        self.target_height = target_height

    def run_inference(
        self,
        prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        strength: float,
        guidance_scale: float,
        num_images: int = 1,
    ):
        """Run the stored pipeline and return the first generated image.

        Args:
            prompt: Text prompt forwarded to the pipeline.
            negative_prompt: Negative prompt forwarded to the pipeline.
            num_inference_steps: Forwarded as ``num_inference_steps``.
            strength: Forwarded as ``strength``.
            guidance_scale: Forwarded as ``guidance_scale``.
            num_images: Forwarded as ``num_images_per_prompt``. Defaults to 1
                so callers that do not pass it keep working.

        Returns:
            The first element of the pipeline result's ``images``. Any
            additional images produced when ``num_images > 1`` are discarded.
        """
        result = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=self.image,
            mask_image=self.mask_image,
            num_inference_steps=num_inference_steps,
            strength=strength,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
            height=self.target_height,
            width=self.target_width,
        )
        return result.images[0]
@hydra.main(version_base=None, config_path="../configs", config_name="inpainting")
def inference(cfg: DictConfig):
    """Run one inpainting pass on the sample image and save the results.

    Config keys read: ``model``, ``target_width``, ``target_height``,
    ``segmentation_model``, ``detection_model``, ``prompt``,
    ``negative_prompt``, ``num_inference_steps``, ``strength``,
    ``guidance_scale``, ``output_path`` and, optionally, ``num_images``
    (falls back to 1 when absent).

    Writes ``output.jpg`` and ``mask.jpg`` under ``cfg.output_path``.

    Args:
        cfg: Hydra configuration loaded from ``../configs/inpainting``.
    """
    # Build the pipeline with UNet compilation enabled.
    pipeline = load_pipeline(cfg.model, accelerator(), True)

    # Image augmentation and preparation
    augmenter = ImageAugmentation(
        target_width=cfg.target_width,
        target_height=cfg.target_height,
    )
    image_path = "../sample_data/example3.jpg"
    image = Image.open(image_path)
    extended_image = augmenter.extend_image(image)
    mask_image = augmenter.generate_mask_from_bbox(
        extended_image,
        cfg.segmentation_model,
        cfg.detection_model,
    )
    # The inverted mask is what the pipeline receives.
    # NOTE(review): presumably so the background, not the detected object,
    # gets repainted — confirm against ImageAugmentation.
    mask_image = augmenter.invert_mask(mask_image)

    # Create AutoPaintingPipeline instance around the loaded pipeline
    painting_pipeline = AutoPaintingPipeline(
        pipeline=pipeline,
        image=extended_image,
        mask_image=mask_image,
        target_height=cfg.target_height,
        target_width=cfg.target_width,
    )

    # Run inference. num_images is passed explicitly because run_inference
    # declares it; omitting it raised TypeError.
    output = painting_pipeline.run_inference(
        prompt=cfg.prompt,
        negative_prompt=cfg.negative_prompt,
        num_inference_steps=cfg.num_inference_steps,
        strength=cfg.strength,
        guidance_scale=cfg.guidance_scale,
        num_images=cfg.get("num_images", 1),
    )

    # Save output and mask images
    output.save(f'{cfg.output_path}/output.jpg')
    mask_image.save(f'{cfg.output_path}/mask.jpg')
# Script entry point: inference() is called without arguments because the
# @hydra.main decorator builds and injects the cfg argument.
if __name__ == "__main__":
    inference()