andyhsih8's picture
Update app.py: changed the examples and optimized. -v2
b940194
import gradio as gr
from PIL import Image
from skimage import color, feature
from skimage.segmentation import slic, mark_boundaries
from skimage.filters import threshold_otsu
import numpy as np
import cv2
import os
# Validate that the input image is a NumPy array and normalise it to RGB.
def validate_image(image):
    """Return *image* as a three-channel ``uint8`` array.

    Accepted layouts:
      * 2-D ``(H, W)`` grayscale, replicated into three identical channels;
      * 3-D ``(H, W, 3)`` RGB, passed through;
      * 3-D ``(H, W, 4)`` RGBA, with the alpha channel dropped.

    Args:
        image: The candidate image.

    Returns:
        A NumPy array of shape ``(H, W, 3)`` and dtype ``uint8``. Values are
        cast with ``astype`` and are not rescaled.

    Raises:
        ValueError: If *image* is not a NumPy array, has a channel count
            other than 3 or 4, or is neither 2-D nor 3-D.
    """
    if not isinstance(image, np.ndarray):
        raise ValueError("Input is not a valid NumPy array.")

    if image.ndim == 2:  # Grayscale to RGB
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3:
        channels = image.shape[2]
        if channels == 4:  # RGBA to RGB
            image = image[:, :, :3]
        elif channels != 3:
            raise ValueError("Image must have 3 channels (RGB).")
    else:
        # Previously any other rank (e.g. a 1-D array) slipped through and
        # failed later inside the processing functions.
        raise ValueError(
            "Image must be a 2-D (grayscale) or 3-D (RGB/RGBA) array."
        )

    return image.astype(np.uint8)
# Edge detection
def edge_detection(image, low_threshold, high_threshold):
    """Run Canny edge detection on an RGB image.

    Args:
        image: RGB image array; it is converted to grayscale first.
        low_threshold: Despite its name, this value is passed to
            ``feature.canny`` as ``sigma`` (the Gaussian smoothing width).
        high_threshold: Accepted for interface compatibility but not used
            by the current implementation.

    Returns:
        A ``uint8`` array in which edge pixels are 255 and all others 0.
    """
    grayscale = color.rgb2gray(image)
    edge_map = feature.canny(grayscale, sigma=low_threshold)
    # Scale the edge map to the 0/255 range expected for display.
    return edge_map.astype(np.uint8) * 255
# Image segmentation
def image_segmentation(image, num_segments):
    """Segment *image* with SLIC and draw the segment boundaries on it.

    Args:
        image: Image array to segment.
        num_segments: Value passed to ``slic`` as ``n_segments``.

    Returns:
        A ``uint8`` array: the output of ``mark_boundaries`` multiplied by
        255 and cast to ``uint8``.
    """
    labels = slic(image, n_segments=num_segments, start_label=1)
    outlined = mark_boundaries(image, labels)
    scaled = 255 * outlined
    return scaled.astype(np.uint8)
# Image restoration (inpainting)
def image_inpainting(image, mask=None):
    """Repair the masked regions of *image* with OpenCV's Telea inpainting.

    Args:
        image: 8-bit RGB image array.
        mask: Optional mask whose non-zero pixels mark the area to repair.
            A 3-D mask is reduced to a single channel, and a mask whose
            height/width differ from the image is resized to match. When
            ``None``, a mask is derived from the image itself: pixels at or
            below the Otsu threshold are selected and the selection is
            dilated three times.

    Returns:
        The inpainted image as returned by ``cv2.inpaint``.
    """
    if mask is None:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # THRESH_BINARY_INV marks the darker side of the Otsu split.
        _, mask = cv2.threshold(
            gray, threshold_otsu(gray), 255, cv2.THRESH_BINARY_INV
        )
        mask = cv2.dilate(mask, None, iterations=3)

    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)

    # cv2.inpaint requires the mask to have the same size as the image.
    # A mismatched mask (e.g. a fixed-size placeholder) used to make the call
    # fail; nearest-neighbour resizing keeps the mask values unblended.
    height, width = image.shape[:2]
    if mask.shape[:2] != (height, width):
        mask = cv2.resize(
            mask, (width, height), interpolation=cv2.INTER_NEAREST
        )

    return cv2.inpaint(image, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
# Main entry point: dispatch to the processing function for the chosen task.
def process_image(task, image, mask, param1, param2):
    """Run the selected computer-vision task and return a PIL image.

    Args:
        task: One of ``"Edge Detection"``, ``"Image Segmentation"`` or
            ``"Image Inpainting"``.
        image: Input image array; normalised by ``validate_image``.
        mask: Optional mask, forwarded only to ``image_inpainting``.
        param1: Forwarded to ``edge_detection`` as its first numeric
            argument, or converted to ``int`` and used as the segment count
            for ``image_segmentation``.
        param2: Forwarded to ``edge_detection`` as its second numeric
            argument; ignored by the other tasks.

    Returns:
        A ``PIL.Image.Image``. If anything fails (including an unknown
        task), the error is printed and a black 100x100 RGB image is
        returned instead of raising, so the UI always receives an image.
    """
    try:
        # Validate and normalise the input image.
        image = validate_image(image)
        if task == "Edge Detection":
            result = edge_detection(image, param1, param2)
        elif task == "Image Segmentation":
            result = image_segmentation(image, int(param1))
        elif task == "Image Inpainting":
            result = image_inpainting(image, mask)
        else:
            raise ValueError(f"Unknown task: {task}")
        result = np.clip(result, 0, 255).astype(np.uint8)
        return Image.fromarray(result)
    except Exception as e:
        # Deliberate best-effort boundary for the UI: report and fall back.
        print(f"錯誤: {e}")
        # Return the same type as the success path (PIL image, not ndarray).
        return Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))
# Build the example rows for the UI from the images in a directory.
def generate_grouped_examples(task_names, examples_dir):
    """Return one example row per (image, task) pair.

    Each row is ``[task, image, mask, 2, 5]``, where ``image`` is the file
    loaded as an RGB PIL image. For the first two images (in sorted filename
    order) ``mask`` is ``None``; for all later images it is a black
    100x100 RGB placeholder.

    Args:
        task_names: Iterable of task names; every image is paired with each.
        examples_dir: Directory scanned (non-recursively) for ``.png``,
            ``.jpg`` and ``.jpeg`` files, matched case-insensitively.

    Returns:
        A list of example rows. Empty if the directory does not exist or
        contains no matching images. Files that fail to load are reported
        with ``print`` and skipped.
    """
    # A missing directory must not crash the app at import time.
    if not os.path.isdir(examples_dir):
        print(f"Examples directory not found: {examples_dir}")
        return []

    image_suffixes = (".png", ".jpg", ".jpeg")
    images = sorted(
        f for f in os.listdir(examples_dir)
        if f.lower().endswith(image_suffixes)
    )
    if not images:
        return []

    # All-black mask used as a placeholder.
    black_mask = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))

    examples = []
    for i, file_name in enumerate(images):
        file_path = os.path.join(examples_dir, file_name)
        try:
            # The context manager releases the file handle promptly;
            # convert() returns an independent, fully loaded copy.
            with Image.open(file_path) as source:
                image_pil = source.convert("RGB")
        except Exception as e:
            print(f"Error loading example {file_name}: {e}")
            continue

        # First two images: no mask. Later images: the black placeholder,
        # for every task.
        mask = None if i < 2 else black_mask
        examples.extend([task, image_pil, mask, 2, 5] for task in task_names)
    return examples
# Task names offered in the UI dropdown; process_image compares against
# these exact strings.
tasks = ["Edge Detection", "Image Segmentation", "Image Inpainting"]
examples_dir = "examples" # Assumes the example images live in this directory
# Example rows are built once, when the module is loaded.
examples = generate_grouped_examples(tasks, examples_dir)
# Gradio interface configuration: the five inputs map, in order, onto the
# (task, image, mask, param1, param2) parameters of process_image.
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Dropdown(choices=tasks, label="Task"),
        gr.Image(type="numpy", label="Input Image"),
        gr.Image(
            type="numpy",
            label="Mask (Optional), Note: This is for Image Inpainting function",
        ),
        gr.Slider(
            minimum=1,
            maximum=10,
            value=1,
            label="Parameter 1 (e.g., Edge Sigma, Num Segments)",
        ),
        gr.Slider(
            minimum=0,
            maximum=10,
            value=1,
            label="Parameter 2 (e.g., High Threshold)",
        ),
    ],
    # The result is shown as an image (PIL format).
    outputs=gr.Image(type="pil"),
    examples=examples,
    title="Computer Vision Web App",
    description="Perform Edge Detection, Image Segmentation, or Inpainting on uploaded images.",
)
# Start the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    interface.launch()