andy_hsih
Update app.py, Change the Exmaples and optimized.
f316013
import gradio as gr
from PIL import Image
from skimage import color, feature
from skimage.segmentation import slic, mark_boundaries
from skimage.filters import threshold_otsu
import numpy as np
import cv2
import os
# 驗證輸入的圖片是否為有效的 NumPy 陣列格式
def validate_image(image):
# 檢查圖像是否為 NumPy 陣列
if not isinstance(image, np.ndarray):
raise ValueError("Input is not a valid NumPy array.")
# 如果圖像是 RGBA 格式,轉換為 RGB 格式(去除 alpha 通道)
if image.ndim == 3 and image.shape[2] == 4: # 如果是 RGBA 圖像
image = image[:, :, :3] # 保留 RGB 三個通道,去除 alpha 通道
# 如果是灰階圖像,將其轉換為 3 通道 RGB 圖像
if image.ndim == 2: # 灰階圖像
image = np.stack([image] * 3, axis=-1) # 擴展為 3 通道的 RGB
elif image.ndim == 3 and image.shape[2] != 3:
raise ValueError("Image must have 3 channels (RGB).")
# 返回處理過的圖像,並強制將數值類型轉為 uint8 格式
return image.astype(np.uint8)
# 邊緣檢測函數
def edge_detection(image, low_threshold, high_threshold):
# 先將圖像轉為灰度圖
gray = color.rgb2gray(image)
# 使用 Canny 邊緣檢測算法,傳入低和高閾值參數
edges = feature.canny(gray, sigma=low_threshold)
# 將檢測結果轉為 0-255 範圍內的圖像格式
return (edges * 255).astype(np.uint8)
# 影像分割函數
def image_segmentation(image, num_segments):
# 使用 SLIC 算法進行圖像超像素分割
segments = slic(image, n_segments=num_segments, start_label=1)
# 標註分割結果並返回分割後的圖像
segmented_image = mark_boundaries(image, segments)
# 轉換為 uint8 類型以便於處理
return (segmented_image * 255).astype(np.uint8)
# 修復影像函數(圖像修補)
def image_inpainting(image, mask=None):
# 如果沒有提供遮罩,則自動生成遮罩
if mask is None:
# 將圖像轉為灰階圖
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
# 使用 Otsu 閾值算法進行二值化,並取反(將背景設為白色,前景設為黑色)
_, mask = cv2.threshold(gray, threshold_otsu(gray), 255, cv2.THRESH_BINARY_INV)
# 擴大遮罩範圍
mask = cv2.dilate(mask, None, iterations=3)
# 如果遮罩是 RGB 圖像,則將其轉換為灰階
if mask.ndim == 3:
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
# 使用 OpenCV 進行圖像修復,將遮罩區域修復
inpainted = cv2.inpaint(image, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
return inpainted
# 主處理函數,用來根據選擇的任務進行處理
def process_image(task, image, mask, param1, param2):
try:
# 驗證輸入圖像是否有效
image = validate_image(image)
# 根據任務選擇進行處理
if task == "Edge Detection":
result = edge_detection(image, param1, param2) # 邊緣檢測
elif task == "Image Segmentation":
result = image_segmentation(image, int(param1)) # 影像分割
elif task == "Image Inpainting":
result = image_inpainting(image, mask) # 影像修補
else:
raise ValueError(f"Unknown task: {task}")
# 確保返回的結果是有效圖像,並強制轉換為 uint8 類型
result = np.clip(result, 0, 255).astype(np.uint8) # 確保像素值在 0 到 255 範圍內
return result
except Exception as e:
# 捕獲異常,並返回一個空白圖像作為回應
print(f"錯誤: {e}")
return np.zeros((100, 100, 3), dtype=np.uint8)
# 生成範例資料,將範例圖片與對應任務綁定
def generate_grouped_examples(task_names, examples_dir):
examples = []
# 讀取範例資料夾中的所有圖片檔案
images = sorted([f for f in os.listdir(examples_dir) if f.endswith((".png", ".jpg", ".jpeg"))])
# 遍歷每一張圖片,生成對應的範例
for file_name in images:
file_path = os.path.join(examples_dir, file_name)
try:
# 讀取圖片並轉換為 NumPy 陣列(RGB 格式)
image_array = np.array(Image.open(file_path).convert("RGB"))
# 強制將圖像轉為 uint8 格式
image_array = image_array.astype(np.uint8)
# 將圖像數據格式轉換為 PIL.Image 格式以便 Gradio 處理
image_pil = Image.fromarray(image_array)
# 為每個任務生成對應的範例(此處假設每個任務都使用相同的圖片)
for task in task_names:
examples.append([task, image_pil, None, 2, 5]) # 傳入 PIL 圖像
except Exception as e:
# 處理加載圖片過程中的錯誤
print(f"Error loading example {file_name}: {e}")
# 返回所有範例資料
return examples
# 設定支持的任務類型
tasks = ["Edge Detection", "Image Segmentation", "Image Inpainting"]
# 範例圖片存放目錄
examples_dir = "examples"
# 生成範例資料
examples = generate_grouped_examples(tasks, examples_dir)
# Gradio 介面設定
interface = gr.Interface(
fn=process_image, # 主要處理函數
inputs=[ # 輸入元件
gr.Dropdown(choices=tasks, label="Task"), # 任務選擇
gr.Image(type="numpy", label="Input Image"), # 圖片輸入
gr.Image(type="numpy", label="Mask (Optional), Note: This is for Image Inpainting function"), # 可選遮罩(僅限修補)
gr.Slider(1, 10, value=1, label="Parameter 1 (e.g., Edge Sigma, Num Segments)"), # 參數 1
gr.Slider(0, 10, value=1, label="Parameter 2 (e.g., High Threshold)") # 參數 2
],
outputs="image", # 輸出圖像
examples=examples, # 例子資料
title="Computer Vision Web App", # 網頁應用標題
description="Perform Edge Detection, Image Segmentation, or Inpainting on uploaded images." # 網頁描述
)
# 啟動 Gradio 介面
if __name__ == "__main__":
interface.launch()