import gc

import gradio as gr
import numpy as np
import spaces  # needed for the @spaces.GPU decorator on ZeroGPU Spaces
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, Pipeline, pipeline
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Zero-shot detector (Grounding DINO) + segmenter (SAM) used to build the inpainting mask.
GDINO_MODEL_NAME = "IDEA-Research/grounding-dino-tiny"
SAM_MODEL_NAME = "facebook/sam-vit-base"
GDINO = pipeline(model=GDINO_MODEL_NAME, task="zero-shot-object-detection", device=DEVICE)
SAM = AutoModelForMaskGeneration.from_pretrained(SAM_MODEL_NAME).to(DEVICE)
SAM_PROCESSOR = AutoProcessor.from_pretrained(SAM_MODEL_NAME)

# SDXL inpainting pipeline with an IP-Adapter, so the garment image conditions the generation.
# Note: torch.float16 assumes a CUDA device; on CPU the pipeline would need torch.float32.
SD_MODEL = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
SD_PIPELINE = AutoPipelineForInpainting.from_pretrained(SD_MODEL, torch_dtype=torch.float16).to(DEVICE)

IP_ADAPTER = "h94/IP-Adapter"
SUB_FOLDER = "sdxl_models"
IP_WEIGHT_NAME = "ip-adapter_sdxl.bin"
SD_PIPELINE.load_ip_adapter(IP_ADAPTER, subfolder=SUB_FOLDER, weight_name=IP_WEIGHT_NAME)

IP_SCALE = 0.6  # strength of the garment-image conditioning (0 = text prompt only, 1 = image only)
SD_PIPELINE.set_ip_adapter_scale(IP_SCALE)

GEN_STEPS = 100

def refine_masks(masks: torch.BoolTensor) -> np.ndarray:
    # Average SAM's per-box mask proposals into one soft (float) mask per detected object.
    masks = masks.permute(0, 2, 3, 1)
    masks = masks.float().mean(axis=-1)
    return masks.cpu().numpy()
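
# Illustrative sketch (not called by the app): with two detections ("face." plus the garment
# label), SAM's post-processed masks typically have shape (2, num_proposals, H, W), and
# refine_masks collapses them to a (2, H, W) float array, so index 1 is expected to be the
# garment. The shapes below are assumptions about the SAM processor output, not measured values.
def _refine_masks_example() -> None:
    fake_masks = torch.zeros((2, 3, 512, 512), dtype=torch.bool)  # hypothetical SAM output
    soft = refine_masks(fake_masks)
    assert soft.shape == (2, 512, 512)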

def get_boxes(detections: list) -> list:
    # Convert detector output into the nested [[x_min, y_min, x_max, y_max], ...] list
    # (one inner list per image) that the SAM processor expects as input_boxes.
    boxes = []
    for det in detections:
        boxes.append([det["box"]["xmin"], det["box"]["ymin"],
                      det["box"]["xmax"], det["box"]["ymax"]])
    return [boxes]
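
# Illustrative sketch (not called by the app): the zero-shot-object-detection pipeline returns
# dicts of the form {"score": ..., "label": ..., "box": {"xmin", "ymin", "xmax", "ymax"}},
# which get_boxes flattens into a single-image batch of boxes. The values below are made up.
def _get_boxes_example() -> None:
    dets = [
        {"score": 0.9, "label": "face.", "box": {"xmin": 10, "ymin": 12, "xmax": 60, "ymax": 70}},
        {"score": 0.8, "label": "dress.", "box": {"xmin": 5, "ymin": 80, "xmax": 120, "ymax": 300}},
    ]
    assert get_boxes(dets) == [[[10, 12, 60, 70], [5, 80, 120, 300]]]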

def get_mask(img: Image.Image, prompt: str, d_model: Pipeline, s_model: AutoModelForMaskGeneration,
             s_processor: AutoProcessor, device: str, threshold: float = 0.3) -> np.ndarray:
    # Grounding DINO expects labels to end with a period, so normalise both prompts.
    labels = [label if label.endswith(".") else label + "." for label in ["face", prompt]]
    dets = d_model(img, candidate_labels=labels, threshold=threshold)
    boxes = get_boxes(dets)
    # Segment each detected box with SAM, then soften the proposals into per-object masks.
    inputs = s_processor(images=img, input_boxes=boxes, return_tensors="pt").to(device)
    outputs = s_model(**inputs)
    masks = s_processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes,
    )[0]
    return refine_masks(masks)
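
# Illustrative sketch (not called by the app): running the mask step on its own.
# "person.jpg" is a hypothetical local file, not part of this repository.
def _get_mask_example() -> np.ndarray:
    person = Image.open("person.jpg").convert("RGB")
    masks = get_mask(person, "t-shirt", GDINO, SAM, SAM_PROCESSOR, DEVICE)
    return masks[1]  # index 1 corresponds to the garment label, index 0 to "face."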

def generate_result(model_img: Image.Image, cloth_img: Image.Image,
                    masks: np.ndarray, prompt: str, sd_pipeline: AutoPipelineForInpainting,
                    n_steps: int = 100) -> Image.Image:
    width, height = model_img.size
    cloth_mask = masks[1]  # mask for the garment label (index 0 is the "face." detection)
    # Fixed seed keeps results reproducible across runs.
    generator = torch.Generator(device="cpu").manual_seed(4)
    images = sd_pipeline(
        prompt=prompt,
        image=model_img,
        mask_image=cloth_mask,
        ip_adapter_image=cloth_img,
        generator=generator,
        num_inference_steps=n_steps,
    ).images
    # The pipeline may return a different resolution, so resize back to the input image size.
    return images[0].resize((width, height))
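
# Illustrative sketch (not called by the app): a faster, lower-fidelity inpainting pass, e.g.
# for debugging. The 30-step count and prompt text are assumptions, not values used by this app.
def _generate_result_quick(model_img: Image.Image, cloth_img: Image.Image,
                           masks: np.ndarray) -> Image.Image:
    return generate_result(model_img, cloth_img, masks,
                           "a person wearing the garment", SD_PIPELINE, n_steps=30)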

@spaces.GPU
def run(model_img: Image.Image, cloth_img: Image.Image, cloth_class: str,
        cloth_description: str) -> Image.Image:
    # 1) Segment the garment on the model photo, 2) inpaint that region using the garment image.
    masks = get_mask(model_img, cloth_class, GDINO, SAM, SAM_PROCESSOR, DEVICE)
    result = generate_result(model_img, cloth_img, masks, cloth_description, SD_PIPELINE, GEN_STEPS)
    # Release cached memory between requests.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result
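
# Illustrative sketch (not called by the app): exercising run() outside Gradio.
# Both image paths and the prompts are hypothetical, not files shipped with this repository.
def _run_local_example() -> Image.Image:
    person = Image.open("person.jpg").convert("RGB")
    garment = Image.open("garment.jpg").convert("RGB")
    return run(person, garment, "t-shirt", "a person wearing a blue denim jacket")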

gr.Interface(
    run,
    title="Virtual Try-On",
    inputs=[
        gr.Image(sources="upload", label="Model image", type="pil"),
        gr.Image(sources="upload", label="Cloth image", type="pil"),
        gr.Textbox(label="Cloth class"),
        gr.Textbox(label="Cloth description"),
    ],
    outputs=[
        gr.Image(),
    ],
).launch(debug=True, share=True)