Spaces:
Build error
Build error
File size: 2,978 Bytes
96960b9 4fe6ecb fa32203 6e67c16 fa32203 e2ebd5a fa32203 080c0ac 6c850b1 f419a09 f74e97b fa32203 6c850b1 5626570 6c850b1 0695699 6c850b1 0695699 6c850b1 fa32203 6c850b1 fa32203 6c850b1 5626570 fa32203 5626570 6c850b1 fa32203 6c850b1 fa32203 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch
from diffusers import AutoPipelineForInpainting,DiffusionPipeline
from diffusers.utils import load_image
from scripts.api_utils import accelerator, ImageAugmentation
import hydra
from omegaconf import DictConfig
from PIL import Image
def load_pipeline(model_name: str, device, enable_compile: bool = True, *, torch_dtype: torch.dtype = torch.float16):
    """Load an inpainting pipeline from ``model_name`` and move it to ``device``.

    Args:
        model_name: model id or local path passed to
            ``AutoPipelineForInpainting.from_pretrained``.
        device: target device, passed unchanged to ``pipeline.to``.
        enable_compile: when true, switch the UNet to channels-last memory
            format and wrap it with ``torch.compile`` using
            ``mode='reduce-overhead'`` and ``fullgraph=True``.
        torch_dtype: dtype the weights are loaded in. Defaults to
            ``torch.float16``, the value this function always used before.
            Pass ``torch.float32`` for devices where half precision may be
            unsupported or slow (e.g. some CPU setups).

    Returns:
        The loaded pipeline, moved to ``device``.
    """
    pipeline = AutoPipelineForInpainting.from_pretrained(model_name, torch_dtype=torch_dtype)
    if enable_compile:
        # Channels-last layout is applied before compiling so the compiled
        # graph is traced against that memory format.
        pipeline.unet.to(memory_format=torch.channels_last)
        pipeline.unet = torch.compile(pipeline.unet, mode='reduce-overhead', fullgraph=True)
    pipeline.to(device)
    return pipeline
class AutoPaintingPipeline:
    """Bind an inpainting pipeline to one image/mask pair and a fixed output size.

    The constructor only stores its arguments. ``run_inference`` forwards them,
    together with per-call sampling settings, to ``pipeline`` as keyword
    arguments.
    """

    def __init__(
        self,
        pipeline,
        image: "Image.Image",
        mask_image: "Image.Image",
        target_width: int,
        target_height: int,
    ):
        """Store the pipeline, inputs and output size.

        Args:
            pipeline: callable pipeline (e.g. the object returned by
                ``load_pipeline``). It is called with keyword arguments and its
                result must expose an indexable ``images`` attribute.
            image: image passed to the pipeline as ``image``.
            mask_image: mask passed to the pipeline as ``mask_image``.
            target_width: output width, passed as ``width``.
            target_height: output height, passed as ``height``.
        """
        self.pipeline = pipeline
        self.image = image
        self.mask_image = mask_image
        self.target_width = target_width
        self.target_height = target_height

    def run_inference(
        self,
        prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        strength: float,
        guidance_scale: float,
        num_images: int = 1,
    ) -> "Image.Image":
        """Run the pipeline once and return the first generated image.

        Args:
            prompt: text prompt.
            negative_prompt: text the generation should avoid.
            num_inference_steps: number of denoising steps.
            strength: passed through as ``strength``.
            guidance_scale: passed through as ``guidance_scale``.
            num_images: forwarded as ``num_images_per_prompt``. Defaults to 1.
                Previously this argument was required, which broke callers
                that omitted it.

        Returns:
            ``images[0]`` of the pipeline result. Any additional images
            produced when ``num_images > 1`` are discarded.
        """
        result = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=self.image,
            mask_image=self.mask_image,
            num_inference_steps=num_inference_steps,
            strength=strength,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
            height=self.target_height,
            width=self.target_width,
        )
        return result.images[0]
@hydra.main(version_base=None, config_path="../configs", config_name="inpainting")
def inference(cfg: DictConfig):
    """Run one outpainting job end to end, driven by the Hydra config.

    Required config keys: ``model``, ``target_width``, ``target_height``,
    ``segmentation_model``, ``detection_model``, ``prompt``,
    ``negative_prompt``, ``num_inference_steps``, ``strength``,
    ``guidance_scale`` and ``output_path``.

    Optional config keys (defaults reproduce the previous hard-coded behavior):

    * ``image_path``: input image, default ``../sample_data/example3.jpg``.
      A relative path is resolved against the current working directory at
      run time.
    * ``num_images``: forwarded as ``num_images``, default 1.

    Writes ``output.jpg`` (the image returned by ``run_inference``) and
    ``mask.jpg`` (the inverted mask) into ``output_path``. The directory is
    created if it does not exist.
    """
    from pathlib import Path

    # Load the pipeline (UNet compiled) onto the device chosen by accelerator().
    # This happens once per invocation; nothing is cached across runs.
    pipeline = load_pipeline(cfg.model, accelerator(), True)

    # Prepare the input. The augmenter is configured with the target size;
    # the mask is derived from a detected bounding box on the extended image.
    augmenter = ImageAugmentation(target_width=cfg.target_width, target_height=cfg.target_height)
    image_path = cfg.get("image_path", "../sample_data/example3.jpg")
    image = Image.open(image_path)
    extended_image = augmenter.extend_image(image)
    mask_image = augmenter.generate_mask_from_bbox(extended_image, cfg.segmentation_model, cfg.detection_model)
    # NOTE(review): presumably inverted so the detected object is kept and the
    # surrounding area is repainted -- confirm against ImageAugmentation.
    mask_image = augmenter.invert_mask(mask_image)

    painting_pipeline = AutoPaintingPipeline(
        pipeline=pipeline,
        image=extended_image,
        mask_image=mask_image,
        target_height=cfg.target_height,
        target_width=cfg.target_width,
    )

    # num_images is always passed explicitly. Without it, run_inference's
    # signature may raise a TypeError.
    output = painting_pipeline.run_inference(
        prompt=cfg.prompt,
        negative_prompt=cfg.negative_prompt,
        num_inference_steps=cfg.num_inference_steps,
        strength=cfg.strength,
        guidance_scale=cfg.guidance_scale,
        num_images=cfg.get("num_images", 1),
    )

    # Create the destination first; saving into a missing directory would
    # otherwise fail with FileNotFoundError.
    output_dir = Path(cfg.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    output.save(output_dir / "output.jpg")
    mask_image.save(output_dir / "mask.jpg")
if __name__ == '__main__':
    # No arguments: the @hydra.main decorator builds ``cfg`` from
    # ../configs/inpainting plus any command-line overrides.
    inference()
|