Spaces:
Runtime error
Runtime error
""" | |
This module contains functions for loading the segmentation and inpainting models, and for editing the top in an image using an example image or a text prompt.
""" | |
# Imports | |
import urllib.parse
import urllib.request

import numpy as np
import torch
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
# Functions | |
def load_seg(model_card: str = "mattmdjaga/segformer_b2_clothes"):
    """
    Build the feature extractor / segmentation model pair used for masking.

    Parameters:
        model_card: HuggingFace model id passed to ``from_pretrained`` for both
            objects. Default: mattmdjaga/segformer_b2_clothes
    Returns:
        extractor: Feature extractor loaded from ``model_card``.
        model: SegformerForSemanticSegmentation loaded from ``model_card``.
    """
    loaded = (
        AutoFeatureExtractor.from_pretrained(model_card),
        SegformerForSemanticSegmentation.from_pretrained(model_card),
    )
    return loaded
def load_inpainting(using_prompt: bool = False, fast: bool = False):
    """
    Load an inpainting pipeline and move it to the available device.

    Parameters:
        using_prompt: If True, load the text-prompt inpainting model
            (runwayml/stable-diffusion-inpainting); otherwise load the
            example-image model (Fantasy-Studio/Paint-by-Example). Default: False
        fast: If True and CUDA is available, load the weights as float16.
            On CPU the pipeline is loaded as float32 instead, because half
            precision is poorly supported (and slow) there. Default: False
    Returns:
        pipe: Diffusion Pipeline mounted onto the device
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Only use half precision where it actually helps: on the GPU.
    dtype = torch.float16 if fast and device == "cuda" else torch.float32

    if using_prompt:
        load_kwargs = {"torch_dtype": dtype}
        if fast:
            # Fast mode pulls the repo's "fp16" revision, as before.
            load_kwargs["revision"] = "fp16"
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            **load_kwargs,
        )
    else:
        pipe = DiffusionPipeline.from_pretrained(
            "Fantasy-Studio/Paint-by-Example",
            torch_dtype=dtype,
        )
    return pipe.to(device)
def generate_mask(image_name: str, extractor, model, label: int = 4):
    """
    Generate a binary mask of one segmentation class for an input image.

    Parameters:
        image_name: Local path to the input image, or a URL with an http,
            https, ftp, file or data scheme (fetched with urllib).
        extractor: Feature Extractor
        model: Segmentation Model
        label: Class index the mask selects. Default: 4, the class index this
            module uses for the top.
    Returns:
        image: PIL Image of the input image, converted to RGB
        mask: PIL Image of the generated mask: white where ``label`` was
            predicted, black elsewhere
    """
    # Decide up front whether this is a URL, rather than retrying with urllib
    # after *any* failure (which buried the real error for bad local files).
    if isinstance(image_name, str):
        scheme = urllib.parse.urlparse(image_name).scheme
    else:
        scheme = ""
    if scheme in ("http", "https", "ftp", "file", "data"):
        with urllib.request.urlopen(image_name) as response:
            image = Image.open(response).convert("RGB")
    else:
        with Image.open(image_name) as raw:
            # The extractor expects 3-channel input; RGBA, palette or
            # grayscale files may otherwise be rejected there.
            image = raw.convert("RGB")

    inputs = extractor(images=image, return_tensors="pt")
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits.cpu()
    # Upsample logits to the input's (height, width) before taking the argmax.
    upsampled_logits = torch.nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    # 1.0 where the requested class wins, 0.0 elsewhere; to_pil_image scales
    # float tensors by 255, producing a black/white mask.
    mask = to_pil_image((pred_seg == label).to(dtype=torch.float32))
    return image, mask
def get_cloth(cloth_name, extractor, model):
    """
    Isolate the segmented garment of an image on a white background.

    Parameters:
        cloth_name: Path or URL of the garment image, as accepted by
            ``generate_mask``.
        extractor: Feature Extractor
        model: Segmentation Model
    Returns:
        PIL Image of the garment image in which every pixel outside the
        generated mask is set to 255.
    """
    source, garment_mask = generate_mask(cloth_name, extractor, model)
    pixels = np.array(source)
    outside = np.array(garment_mask) == 0
    pixels[outside] = 255  # paint everything that is not the garment white
    return to_pil_image(pixels)
def generate_image(image, mask, pipe, example_name=None, prompt=None):
    """
    Generate Edited Image. Uses Example Image or Prompt.

    The example image takes precedence when both are given.

    Parameters:
        image: PIL Image of The Image to Edit.
        mask: PIL Image of the Mask.
        pipe: DiffusionPipeline
        example_name: Local path or URL (http, https, ftp, file or data
            scheme) of the image of the cloth.
        prompt: Editing Prompt, if not using Example Image.
    Returns:
        image: PIL Image of Input Image
        mask: PIL Image of Generated Mask
        gen: PIL Image of Generated Preview, or None if neither
            ``example_name`` nor ``prompt`` was provided
    """
    if example_name:
        # Detect URLs explicitly instead of retrying with urllib after any
        # error, which hid real failures opening local files.
        if isinstance(example_name, str):
            scheme = urllib.parse.urlparse(example_name).scheme
        else:
            scheme = ""
        if scheme in ("http", "https", "ftp", "file", "data"):
            with urllib.request.urlopen(example_name) as response:
                example = Image.open(response)
                example.load()  # read pixels before the response is closed
        else:
            example = Image.open(example_name)
        gen = pipe(
            image=image.resize((512, 512)),
            mask_image=mask.resize((512, 512)),
            example_image=example.resize((512, 512)),
        ).images[0]
    elif prompt:
        gen = pipe(prompt=prompt, image=image, mask_image=mask).images[0]
    else:
        gen = None
        print("Neither Example Image nor Prompt provided.")
    return image, mask, gen
def generate_image_with_mask(image, mask, pipe, extractor, model, example_name=None, prompt=None):
    """
    Inpaint the masked region of ``image``, guided by a garment image or a prompt.

    Unlike ``generate_image``, the example image is first passed through
    ``get_cloth``, so only the segmented garment (on a white background) is
    given to the pipeline. The example image takes precedence over the prompt.

    Parameters:
        image: PIL Image of The Image to Edit.
        mask: PIL Image of the Mask.
        pipe: DiffusionPipeline
        extractor: Feature Extractor used to segment the example image.
        model: Segmentation Model used to segment the example image.
        example_name: Path to Image of the cloth.
        prompt: Editing Prompt, if not using Example Image.
    Returns:
        image: the ``image`` argument, returned as-is
        mask: the ``mask`` argument, returned as-is
        gen: PIL Image of Generated Preview, or None if neither
            ``example_name`` nor ``prompt`` was provided
    """
    if not example_name and not prompt:
        print("Neither Example Image nor Prompt provided.")
        return image, mask, None

    if example_name:
        square = (512, 512)
        garment = get_cloth(example_name, extractor, model)
        result = pipe(
            image=image.resize(square),
            mask_image=mask.resize(square),
            example_image=garment.resize(square),
        )
    else:
        result = pipe(prompt=prompt, image=image, mask_image=mask)
    return image, mask, result.images[0]
def load(using_prompt=False, fast=False):
    """
    Loads Segmentation and Inpainting Model.

    Parameters:
        using_prompt: If using a prompt based inpainting model or image based
            inpainting model. Default: False
        fast: Forwarded to ``load_inpainting`` to request half-precision
            (float16) weights. Default: False
    Returns:
        extractor: Feature Extractor
        model: Segformer Model For Segmentation
        pipe: Diffusion Pipeline loaded onto the device
    """
    extractor, model = load_seg()
    pipe = load_inpainting(using_prompt, fast=fast)
    return extractor, model, pipe
def generate(image_name, extractor, model, pipe, example_name=None, prompt=None):
    """
    Generate Preview.

    Segments the input image, inpaints the masked region using the example
    image (preferred) or the prompt, and resizes the result back to the
    input's aspect ratio.

    Parameters:
        image_name: Path to Input Image
        extractor: Feature Extractor
        model: Segmentation Model
        pipe: DiffusionPipeline
        example_name: Path to Image of the cloth.
        prompt: Editing Prompt, if not using Example Image.
    Returns:
        gen: PIL Image of Generated Preview, 512 px wide with the input
            image's aspect ratio
    Raises:
        ValueError: If neither ``example_name`` nor ``prompt`` is provided.
    """
    # Without an example or a prompt nothing is generated, and resizing the
    # missing result used to fail with AttributeError; check before the
    # (expensive) segmentation runs.
    if not example_name and not prompt:
        raise ValueError("Neither Example Image nor Prompt provided.")
    image, mask = generate_mask(image_name, extractor, model)
    # Height that keeps the input's aspect ratio at a width of 512 px.
    height = int(mask.size[1] * 512 / mask.size[0])
    _, _, gen = generate_image(image, mask, pipe, example_name, prompt)
    return gen.resize((512, height))
def generate_with_mask(image_name, extractor, model, pipe, example_name=None, prompt=None):
    """
    Generate Preview, extracting the cloth from the example image first.

    Like ``generate``, but the example image (if any) is segmented with the
    same extractor/model so that only the garment, on a white background,
    is used as the inpainting example.

    Parameters:
        image_name: Path to Input Image
        extractor: Feature Extractor
        model: Segmentation Model
        pipe: DiffusionPipeline
        example_name: Path to Image of the cloth.
        prompt: Editing Prompt, if not using Example Image.
    Returns:
        gen: PIL Image of Generated Preview, 512 px wide with the input
            image's aspect ratio
    Raises:
        ValueError: If neither ``example_name`` nor ``prompt`` is provided.
    """
    # Fail early and clearly instead of segmenting and then crashing on
    # ``None.resize`` when there is nothing to generate.
    if not example_name and not prompt:
        raise ValueError("Neither Example Image nor Prompt provided.")
    image, mask = generate_mask(image_name, extractor, model)
    # Height that keeps the input's aspect ratio at a width of 512 px.
    height = int(mask.size[1] * 512 / mask.size[0])
    _, _, gen = generate_image_with_mask(
        image, mask, pipe, extractor, model, example_name, prompt
    )
    return gen.resize((512, height))