|
import math

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.components import containers
from PIL import Image

|
def _normalized_to_pixel_coordinates(
        normalized_x: float, normalized_y: float,
        image_width: int, image_height: int):
    """Converts a normalized coordinate pair to pixel coordinates."""

    def is_valid_normalized_value(value: float) -> bool:
        return (value > 0 or math.isclose(0, value)) and \
               (value < 1 or math.isclose(1, value))

    if not (is_valid_normalized_value(normalized_x)
            and is_valid_normalized_value(normalized_y)):
        # The point lies outside the image, so there is no pixel to map to.
        return None
    x_px = min(math.floor(normalized_x * image_width), image_width - 1)
    y_px = min(math.floor(normalized_y * image_height), image_height - 1)
    return x_px, y_px
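
# For example, _normalized_to_pixel_coordinates(0.5, 0.25, 640, 480) returns
# (320, 120), while out-of-range inputs such as x = 1.2 return None.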
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# float16 weights need a GPU; fall back to float32 when running on CPU.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
).to(device)
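
# Optional, if VRAM is tight (not required for this demo):
# pipe.enable_attention_slicing()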
|
|
BG_COLOR = (192, 192, 192)
MASK_COLOR = (255, 255, 255)

RegionOfInterest = vision.InteractiveSegmenterRegionOfInterest
NormalizedKeypoint = containers.keypoint.NormalizedKeypoint

# 'model.tflite' should be a MediaPipe interactive-segmentation model (for
# example the MagicTouch model) placed alongside this script.
base_options = python.BaseOptions(model_asset_path='model.tflite')
# InteractiveSegmenter expects InteractiveSegmenterOptions, not
# ImageSegmenterOptions.
options = vision.InteractiveSegmenterOptions(base_options=base_options,
                                             output_category_mask=True)
|
|
|
|
|
def create_bounding_box_mask(image):
    # The category mask is 0 on the selected object and nonzero elsewhere,
    # so invert it to make the object's pixels the nonzero ones.
    image = (image == 0).astype(np.uint8)

    y_indices, x_indices = np.nonzero(image)
    if not y_indices.size or not x_indices.size:
        return None

    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()

    # Fill the object's tight axis-aligned bounding box with ones.
    bounding_mask = np.zeros_like(image, dtype=np.uint8)
    bounding_mask[y_min:y_max + 1, x_min:x_max + 1] = 1

    return bounding_mask
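
# Sanity check with a toy mask (hypothetical values): zeros spanning rows 1-2
# and columns 1-2 of a 4x4 array yield a box of ones over that rectangle.
# toy = np.full((4, 4), 255, dtype=np.uint8)
# toy[1:3, 1:3] = 0
# create_bounding_box_mask(toy)  # ones in rows 1-2, cols 1-2; zeros elsewhere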
|
|
|
|
|
|
|
def segment_2(image_np, coordinates):
    OVERLAY_COLOR = (255, 105, 180)  # hot pink

    with vision.InteractiveSegmenter.create_from_options(options) as segmenter:
        image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_np)

        # Parse the "(x, y)" string produced by the click handler into
        # normalized floats.
        values = coordinates.strip("()").split(',')
        x = float(values[0])
        y = float(values[1])

        # Segment the object around the clicked keypoint.
        roi = RegionOfInterest(format=RegionOfInterest.Format.KEYPOINT,
                               keypoint=NormalizedKeypoint(x, y))
        segmentation_result = segmenter.segment(image, roi)
        category_mask = segmentation_result.category_mask

        # The category mask is 0 on the selected object and nonzero elsewhere.
        mask = category_mask.numpy_view().astype(np.uint8) * 255

        # Gradio already supplies RGB data, so no BGR-to-RGB conversion is
        # needed here.
        image_data = image.numpy_view()

        # Blend a translucent overlay onto the selected object.
        overlay_image = np.zeros(image_data.shape, dtype=np.uint8)
        overlay_image[:] = OVERLAY_COLOR
        alpha = np.stack((category_mask.numpy_view(),) * 3, axis=-1) <= 0.1
        alpha = alpha.astype(float) * 0.5
        output_image = image_data * (1 - alpha) + overlay_image * alpha
        output_image = output_image.astype(np.uint8)

        # Draw the clicked keypoint as a white dot with a black outline.
        thickness, radius = 6, -1
        keypoint_px = _normalized_to_pixel_coordinates(x, y, image.width, image.height)
        if keypoint_px is not None:
            cv2.circle(output_image, keypoint_px, thickness + 5, (0, 0, 0), radius)
            cv2.circle(output_image, keypoint_px, thickness, (255, 255, 255), radius)

        # Build an inpainting mask from the object's bounding box. Note that
        # shape[:2] is (height, width), and PIL's resize returns a new image
        # rather than resizing in place.
        image_height, image_width = output_image.shape[:2]
        bounding_mask = create_bounding_box_mask(mask)
        bbox_mask_image = Image.fromarray((bounding_mask * 255).astype(np.uint8))
        bbox_mask_image = bbox_mask_image.convert("RGB").resize((image_width, image_height))

        return output_image, bbox_mask_image
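
# In bbox_mask_image, white (255) marks the rectangle the inpainting pipeline
# may repaint; black pixels are left untouched.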
|
|
|
|
|
def generate_2(image_np, bbox_image, prompt):
    # Gradio hands us a numpy array, not a file path.
    img = Image.fromarray(image_np).convert("RGB")

    # Stable Diffusion 2 inpainting expects dimensions divisible by 8 and is
    # trained at 512x512, so resize the image and mask together.
    img = img.resize((512, 512))
    mask = bbox_image.resize((512, 512))

    # A fixed seed keeps the three samples reproducible between runs.
    images = pipe(prompt=prompt,
                  image=img,
                  mask_image=mask,
                  generator=torch.Generator(device=device).manual_seed(0),
                  num_images_per_prompt=3).images

    def image_grid(imgs, rows, cols):
        assert len(imgs) == rows * cols

        w, h = imgs[0].size
        grid = Image.new('RGB', size=(cols * w, rows * h))

        # Paste each sample into its cell, left to right, top to bottom.
        for i, im in enumerate(imgs):
            grid.paste(im, box=(i % cols * w, i // cols * h))
        return grid

    return image_grid(images, 1, 3)
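
# image_grid assumes every sample has the same size, which holds here since
# the pipeline generates all three images at one resolution.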
|
|
|
|
|
def onclick(evt: gr.SelectData, image):
    if evt:
        x, y = evt.index

        # evt.index is (x, y) in pixels; normalize by width and height so the
        # segmenter's keypoint is resolution-independent.
        normalized_x = round(x / image.shape[1], 2)
        normalized_y = round(y / image.shape[0], 2)
        # A single textbox receives this value, so format it as one string in
        # the "(x, y)" form that segment_2 parses.
        return f"({normalized_x}, {normalized_y})"
    return ""
|
|
def callback(image, coordinates, prompt):
    # Segment the clicked object, then inpaint within its bounding box.
    segmented_image, bbox_image = segment_2(image, coordinates)
    grid_image = generate_2(image, bbox_image, prompt)
    return segmented_image, grid_image
|
|
|
with gr.Blocks() as demo:
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload Image", interactive=True)
        coordinates_output = gr.Textbox(label="Coordinates")
    with gr.Row():
        prompt_input = gr.Textbox(label="What do you want to change?")
        submit_button = gr.Button("Submit")
    with gr.Row():
        segmented_image_output = gr.Image(type="numpy", label="Segmented Image")
        grid_image_output = gr.Image(type="pil", label="Generated Image Grid")

    # Clicking the uploaded image records normalized coordinates; Submit runs
    # segmentation followed by inpainting.
    image_input.select(onclick, inputs=[image_input], outputs=coordinates_output)
    submit_button.click(fn=callback,
                        inputs=[image_input, coordinates_output, prompt_input],
                        outputs=[segmented_image_output, grid_image_output])

demo.launch(debug=True)
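
# Tip: demo.launch(share=True) would also expose a temporary public URL
# (an optional Gradio feature; local-only is fine for this demo).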
|
|