""" This Module contains funstions for loading the segmentation model and inpainting models, and editing top using a example image or text prompt. """ # Imports from diffusers import DiffusionPipeline from diffusers import StableDiffusionInpaintPipeline from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation from torchvision.transforms.functional import to_pil_image from PIL import Image import torch import numpy as np import urllib.request # Functions def load_seg(model_card: str = "mattmdjaga/segformer_b2_clothes"): """ Load The Segmentation Extractor and Model. Parameters: model_card: HuggingFace Model Card. Default: mattmdjaga/segformer_b2_clothes Returns: extractor: Feature Extractor model: Segformer Model For Segmentation """ extractor = AutoFeatureExtractor.from_pretrained(model_card) model = SegformerForSemanticSegmentation.from_pretrained(model_card) return extractor, model def load_inpainting(using_prompt: bool = False, fast: bool = False): """ Load Inpaining Model. Parameters: using_prompt: If using a prompt based inpainting model or image based inpainting model. Default: False Returns: pipe: Diffusion Pipeline mounted onto the device """ device = "cuda" if torch.cuda.is_available() else "cpu" if using_prompt: if fast: pipe = StableDiffusionInpaintPipeline.from_pretrained( "runwayml/stable-diffusion-inpainting", revision="fp16", torch_dtype=torch.float16, ) else: pipe = StableDiffusionInpaintPipeline.from_pretrained( "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float32, ) else: if fast: pipe = DiffusionPipeline.from_pretrained( "Fantasy-Studio/Paint-by-Example", torch_dtype=torch.float16, ) else: pipe = DiffusionPipeline.from_pretrained( "Fantasy-Studio/Paint-by-Example", torch_dtype=torch.float32, ) pipe = pipe.to(device) return pipe def generate_mask(image_name: str, extractor, model): """ Generate mask using Image Path and Segmentation Model. Parameters: image_name: Path to Input Image extractor: Feature Extractor model: Segmentation Model Returns: image: PIL Image of Input Image mask: PIL Image of Generated Mask """ try: image = Image.open(image_name) except Exception as e: image = Image.open(urllib.request.urlopen(image_name)) inputs = extractor(images=image, return_tensors="pt") outputs = model(**inputs) logits = outputs.logits.cpu() upsampled_logits = torch.nn.functional.interpolate( logits, size=image.size[::-1], mode="bilinear", align_corners=False, ) pred_seg = upsampled_logits.argmax(dim=1)[0] pred_seg[pred_seg != 4] = 0 pred_seg[pred_seg == 4] = 1 pred_seg = pred_seg.to(dtype=torch.float32) # pred_seg = pred_seg.unsqueeze(dim = 0) mask = to_pil_image(pred_seg) return image, mask def get_cloth(cloth_name, extractor, model): cloth_image, cloth_mask = generate_mask(cloth_name, extractor, model) cloth = np.array(cloth_image) cloth[np.array(cloth_mask) == 0] = 255 return to_pil_image(cloth) def generate_image(image, mask, pipe, example_name=None, prompt=None): """ Generate Edited Image. Uses Example Image or Prompt. Parameters: image: PIL Image of The Image to Edit. mask: PIL Image of the Mask. pipe: DiffusionPipeline example_name: Path to Image of the cloth. prompt: Editing Prompt, if not using Example Image. Returns: image: PIL Image of Input Image mask: PIL Image of Generated Mask gen: PIL Image of Generated Preview """ if example_name: try: example = Image.open(example_name) except Exception as e: example = Image.open(urllib.request.urlopen(example_name)) gen = pipe( image=image.resize((512, 512)), mask_image=mask.resize((512, 512)), example_image=example.resize((512, 512)), ).images[0] elif prompt: gen = pipe(prompt=prompt, image=image, mask_image=mask).images[0] else: gen = None print("Neither Example Image nor Prompt provided.") return image, mask, gen def generate_image_with_mask(image, mask, pipe, extractor, model, example_name=None, prompt=None): """ Generate Edited Image. Uses Example Image or Prompt. Extracts the Cloth from the cloth image. Parameters: image: PIL Image of The Image to Edit. mask: PIL Image of the Mask. pipe: DiffusionPipeline example_name: Path to Image of the cloth. prompt: Editing Prompt, if not using Example Image. Returns: image: PIL Image of Input Image mask: PIL Image of Generated Mask gen: PIL Image of Generated Preview """ if example_name: cloth = get_cloth(example_name, extractor, model) gen = pipe( image=image.resize((512, 512)), mask_image=mask.resize((512, 512)), example_image=cloth.resize((512, 512)), ).images[0] elif prompt: gen = pipe(prompt=prompt, image=image, mask_image=mask).images[0] else: gen = None print("Neither Example Image nor Prompt provided.") return image, mask, gen def load(using_prompt=False): """ Loads Segmentation and Inpainting Model. Parameters: using_prompt: If using a prompt based inpainting model or image based inpainting model. Default: False Returns: extractor: Feature Extractor model: Segformer Model For Segmentation pipe: Diffusion Pipeline loaded onto the device """ extractor, model = load_seg() pipe = load_inpainting(using_prompt) return extractor, model, pipe def generate(image_name, extractor, model, pipe, example_name=None, prompt=None): """ Generate Preview. Parameters: image_name: Path to Input Image extractor: Feature Extractor model: Segmentation Model pipe: DiffusionPipeline example_name: Path to Image of the cloth. prompt: Editing Prompt, if not using Example Image. Returns: gen: PIL Image of Generated Preview """ image, mask = generate_mask(image_name, extractor, model) res = int(mask.size[1] * 512 / mask.size[0]) image, mask, gen = generate_image(image, mask, pipe, example_name, prompt) return gen.resize((512, res)) def generate_with_mask(image_name, extractor, model, pipe, example_name=None, prompt=None): """ Generate Preview. Parameters: image_name: Path to Input Image extractor: Feature Extractor model: Segmentation Model pipe: DiffusionPipeline example_name: Path to Image of the cloth. prompt: Editing Prompt, if not using Example Image. Returns: gen: PIL Image of Generated Preview """ image, mask = generate_mask(image_name, extractor, model) res = int(mask.size[1] * 512 / mask.size[0]) image, mask, gen = generate_image_with_mask(image, mask, pipe, extractor, model, example_name, prompt) return gen.resize((512, res))