# TryOn/utils/model.py
"""
This Module contains funstions for loading the segmentation model and inpainting models, and editing top using a example image or text prompt.
"""
# Imports
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionInpaintPipeline
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import torch
import numpy as np
import urllib.request
# Functions
def load_seg(model_card: str = "mattmdjaga/segformer_b2_clothes"):
"""
Load The Segmentation Extractor and Model.
Parameters:
model_card: HuggingFace Model Card. Default: mattmdjaga/segformer_b2_clothes
Returns:
extractor: Feature Extractor
model: Segformer Model For Segmentation
"""
extractor = AutoFeatureExtractor.from_pretrained(model_card)
model = SegformerForSemanticSegmentation.from_pretrained(model_card)
return extractor, model
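
# Example usage (a minimal sketch; the checkpoint downloads from the
# Hugging Face Hub on first call):
#
#     extractor, model = load_seg()
#     print(model.config.id2label)  # label map; class 4 should be upper clothes
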
def load_inpainting(using_prompt: bool = False, fast: bool = False):
"""
Load Inpaining Model.
Parameters:
using_prompt: If using a prompt based inpainting model or image based inpainting model. Default: False
Returns:
pipe: Diffusion Pipeline mounted onto the device
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
if using_prompt:
if fast:
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
revision="fp16",
torch_dtype=torch.float16,
)
else:
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
torch_dtype=torch.float32,
)
else:
if fast:
pipe = DiffusionPipeline.from_pretrained(
"Fantasy-Studio/Paint-by-Example",
torch_dtype=torch.float16,
)
else:
pipe = DiffusionPipeline.from_pretrained(
"Fantasy-Studio/Paint-by-Example",
torch_dtype=torch.float32,
)
pipe = pipe.to(device)
return pipe
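
# Example usage (a sketch; fast=True loads fp16 weights and assumes a CUDA GPU,
# since half precision is impractical on CPU):
#
#     pipe = load_inpainting(using_prompt=True, fast=torch.cuda.is_available())
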
def generate_mask(image_name: str, extractor, model):
"""
Generate mask using Image Path and Segmentation Model.
Parameters:
image_name: Path to Input Image
extractor: Feature Extractor
model: Segmentation Model
Returns:
image: PIL Image of Input Image
mask: PIL Image of Generated Mask
"""
    try:
        image = Image.open(image_name)
    except Exception:
        # Fall back to treating image_name as a URL.
        image = Image.open(urllib.request.urlopen(image_name))
inputs = extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits.cpu()
upsampled_logits = torch.nn.functional.interpolate(
logits,
size=image.size[::-1],
mode="bilinear",
align_corners=False,
)
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    # Keep only class 4, which corresponds to upper clothes in the
    # mattmdjaga/segformer_b2_clothes label map, and binarize the mask.
    pred_seg[pred_seg != 4] = 0
    pred_seg[pred_seg == 4] = 1
    pred_seg = pred_seg.to(dtype=torch.float32)
    mask = to_pil_image(pred_seg)
return image, mask
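
# Example usage (a sketch; "person.jpg" is a hypothetical local file, and an
# http(s) URL also works via the urllib fallback):
#
#     extractor, model = load_seg()
#     image, mask = generate_mask("person.jpg", extractor, model)
#     mask.save("mask.png")  # white where upper clothes were detected
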
def get_cloth(cloth_name, extractor, model):
    """
    Extract the cloth from a cloth image, whitening everything outside the mask.
    """
    cloth_image, cloth_mask = generate_mask(cloth_name, extractor, model)
    cloth = np.array(cloth_image)
    # Whiten every pixel outside the predicted clothing region.
    cloth[np.array(cloth_mask) == 0] = 255
    return to_pil_image(cloth)
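
# Example usage (a sketch; "cloth.jpg" is a hypothetical local file):
#
#     cloth = get_cloth("cloth.jpg", extractor, model)  # garment on white background
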
def generate_image(image, mask, pipe, example_name=None, prompt=None):
"""
Generate Edited Image. Uses Example Image or Prompt.
Parameters:
image: PIL Image of The Image to Edit.
mask: PIL Image of the Mask.
pipe: DiffusionPipeline
        example_name: Path or URL to the Image of the cloth.
prompt: Editing Prompt, if not using Example Image.
Returns:
image: PIL Image of Input Image
mask: PIL Image of Generated Mask
gen: PIL Image of Generated Preview
"""
    if example_name:
        try:
            example = Image.open(example_name)
        except Exception:
            # Fall back to treating example_name as a URL.
            example = Image.open(urllib.request.urlopen(example_name))
gen = pipe(
image=image.resize((512, 512)),
mask_image=mask.resize((512, 512)),
example_image=example.resize((512, 512)),
).images[0]
    elif prompt:
        gen = pipe(
            prompt=prompt,
            image=image.resize((512, 512)),
            mask_image=mask.resize((512, 512)),
        ).images[0]
    else:
        raise ValueError("Neither Example Image nor Prompt provided.")
return image, mask, gen
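
# Example usage (a sketch; assumes pipe came from load_inpainting(using_prompt=False),
# i.e. the Paint-by-Example pipeline that accepts example_image, and that
# "cloth.jpg" is a hypothetical local file):
#
#     image, mask, gen = generate_image(image, mask, pipe, example_name="cloth.jpg")
#     gen.save("preview.png")
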
def generate_image_with_mask(image, mask, pipe, extractor, model, example_name=None, prompt=None):
"""
Generate Edited Image. Uses Example Image or Prompt. Extracts the Cloth from the cloth image.
Parameters:
image: PIL Image of The Image to Edit.
mask: PIL Image of the Mask.
        pipe: DiffusionPipeline
        extractor: Feature Extractor
        model: Segmentation Model
        example_name: Path or URL to the Image of the cloth.
        prompt: Editing Prompt, if not using Example Image.
Returns:
image: PIL Image of Input Image
mask: PIL Image of Generated Mask
gen: PIL Image of Generated Preview
"""
if example_name:
cloth = get_cloth(example_name, extractor, model)
gen = pipe(
image=image.resize((512, 512)),
mask_image=mask.resize((512, 512)),
example_image=cloth.resize((512, 512)),
).images[0]
    elif prompt:
        gen = pipe(
            prompt=prompt,
            image=image.resize((512, 512)),
            mask_image=mask.resize((512, 512)),
        ).images[0]
    else:
        raise ValueError("Neither Example Image nor Prompt provided.")
return image, mask, gen
def load(using_prompt=False, fast=False):
    """
    Loads Segmentation and Inpainting Model.
    Parameters:
        using_prompt: If using a prompt based inpainting model or image based inpainting model. Default: False
        fast: If True, load fp16 weights for faster inference on GPU. Default: False
    Returns:
        extractor: Feature Extractor
        model: Segformer Model For Segmentation
        pipe: Diffusion Pipeline loaded onto the device
    """
    extractor, model = load_seg()
    pipe = load_inpainting(using_prompt, fast)
return extractor, model, pipe
def generate(image_name, extractor, model, pipe, example_name=None, prompt=None):
"""
Generate Preview.
Parameters:
image_name: Path to Input Image
extractor: Feature Extractor
model: Segmentation Model
pipe: DiffusionPipeline
        example_name: Path or URL to the Image of the cloth.
prompt: Editing Prompt, if not using Example Image.
Returns:
gen: PIL Image of Generated Preview
"""
image, mask = generate_mask(image_name, extractor, model)
    # Preserve the input aspect ratio at an output width of 512.
    res = int(mask.size[1] * 512 / mask.size[0])
image, mask, gen = generate_image(image, mask, pipe, example_name, prompt)
return gen.resize((512, res))
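
# Example usage (a sketch; pass prompt with a pipe loaded via using_prompt=True,
# or example_name with an image-based pipe; "person.jpg" is hypothetical):
#
#     gen = generate("person.jpg", extractor, model, pipe, prompt="a red hoodie")
#     gen.save("preview.png")
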
def generate_with_mask(image_name, extractor, model, pipe, example_name=None, prompt=None):
"""
    Generate Preview. Extracts the Cloth from the example image before inpainting.
Parameters:
image_name: Path to Input Image
extractor: Feature Extractor
model: Segmentation Model
pipe: DiffusionPipeline
        example_name: Path or URL to the Image of the cloth.
prompt: Editing Prompt, if not using Example Image.
Returns:
gen: PIL Image of Generated Preview
"""
image, mask = generate_mask(image_name, extractor, model)
    # Preserve the input aspect ratio at an output width of 512.
    res = int(mask.size[1] * 512 / mask.size[0])
image, mask, gen = generate_image_with_mask(image, mask, pipe, extractor, model, example_name, prompt)
return gen.resize((512, res))
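
# A minimal end-to-end sketch, guarded so that importing this module stays cheap.
# "person.jpg" and "cloth.jpg" are hypothetical local files; model weights are
# downloaded from the Hugging Face Hub on first run.
if __name__ == "__main__":
    extractor, model, pipe = load(using_prompt=False)
    preview = generate_with_mask(
        "person.jpg", extractor, model, pipe, example_name="cloth.jpg"
    )
    preview.save("preview.png")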