import gradio as gr from PIL import Image from skimage import color, feature from skimage.segmentation import slic, mark_boundaries from skimage.filters import threshold_otsu import numpy as np import cv2 import os # 驗證輸入的圖片是否為有效的 NumPy 陣列格式 def validate_image(image): # 檢查圖像是否為 NumPy 陣列 if not isinstance(image, np.ndarray): raise ValueError("Input is not a valid NumPy array.") # 如果圖像是 RGBA 格式,轉換為 RGB 格式(去除 alpha 通道) if image.ndim == 3 and image.shape[2] == 4: # 如果是 RGBA 圖像 image = image[:, :, :3] # 保留 RGB 三個通道,去除 alpha 通道 # 如果是灰階圖像,將其轉換為 3 通道 RGB 圖像 if image.ndim == 2: # 灰階圖像 image = np.stack([image] * 3, axis=-1) # 擴展為 3 通道的 RGB elif image.ndim == 3 and image.shape[2] != 3: raise ValueError("Image must have 3 channels (RGB).") # 返回處理過的圖像,並強制將數值類型轉為 uint8 格式 return image.astype(np.uint8) # 邊緣檢測函數 def edge_detection(image, low_threshold, high_threshold): # 先將圖像轉為灰度圖 gray = color.rgb2gray(image) # 使用 Canny 邊緣檢測算法,傳入低和高閾值參數 edges = feature.canny(gray, sigma=low_threshold) # 將檢測結果轉為 0-255 範圍內的圖像格式 return (edges * 255).astype(np.uint8) # 影像分割函數 def image_segmentation(image, num_segments): # 使用 SLIC 算法進行圖像超像素分割 segments = slic(image, n_segments=num_segments, start_label=1) # 標註分割結果並返回分割後的圖像 segmented_image = mark_boundaries(image, segments) # 轉換為 uint8 類型以便於處理 return (segmented_image * 255).astype(np.uint8) # 修復影像函數(圖像修補) def image_inpainting(image, mask=None): # 如果沒有提供遮罩,則自動生成遮罩 if mask is None: # 將圖像轉為灰階圖 gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) # 使用 Otsu 閾值算法進行二值化,並取反(將背景設為白色,前景設為黑色) _, mask = cv2.threshold(gray, threshold_otsu(gray), 255, cv2.THRESH_BINARY_INV) # 擴大遮罩範圍 mask = cv2.dilate(mask, None, iterations=3) # 如果遮罩是 RGB 圖像,則將其轉換為灰階 if mask.ndim == 3: mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY) # 使用 OpenCV 進行圖像修復,將遮罩區域修復 inpainted = cv2.inpaint(image, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA) return inpainted # 主處理函數,用來根據選擇的任務進行處理 def process_image(task, image, mask, param1, param2): try: # 驗證輸入圖像是否有效 image = validate_image(image) # 根據任務選擇進行處理 if task == "Edge Detection": result = edge_detection(image, param1, param2) # 邊緣檢測 elif task == "Image Segmentation": result = image_segmentation(image, int(param1)) # 影像分割 elif task == "Image Inpainting": result = image_inpainting(image, mask) # 影像修補 else: raise ValueError(f"Unknown task: {task}") # 確保返回的結果是有效圖像,並強制轉換為 uint8 類型 result = np.clip(result, 0, 255).astype(np.uint8) # 確保像素值在 0 到 255 範圍內 return result except Exception as e: # 捕獲異常,並返回一個空白圖像作為回應 print(f"錯誤: {e}") return np.zeros((100, 100, 3), dtype=np.uint8) # 生成範例資料,將範例圖片與對應任務綁定 def generate_grouped_examples(task_names, examples_dir): examples = [] # 讀取範例資料夾中的所有圖片檔案 images = sorted([f for f in os.listdir(examples_dir) if f.endswith((".png", ".jpg", ".jpeg"))]) # 遍歷每一張圖片,生成對應的範例 for file_name in images: file_path = os.path.join(examples_dir, file_name) try: # 讀取圖片並轉換為 NumPy 陣列(RGB 格式) image_array = np.array(Image.open(file_path).convert("RGB")) # 強制將圖像轉為 uint8 格式 image_array = image_array.astype(np.uint8) # 將圖像數據格式轉換為 PIL.Image 格式以便 Gradio 處理 image_pil = Image.fromarray(image_array) # 為每個任務生成對應的範例(此處假設每個任務都使用相同的圖片) for task in task_names: examples.append([task, image_pil, None, 2, 5]) # 傳入 PIL 圖像 except Exception as e: # 處理加載圖片過程中的錯誤 print(f"Error loading example {file_name}: {e}") # 返回所有範例資料 return examples # 設定支持的任務類型 tasks = ["Edge Detection", "Image Segmentation", "Image Inpainting"] # 範例圖片存放目錄 examples_dir = "examples" # 生成範例資料 examples = generate_grouped_examples(tasks, examples_dir) # Gradio 介面設定 interface = gr.Interface( fn=process_image, # 主要處理函數 inputs=[ # 輸入元件 gr.Dropdown(choices=tasks, label="Task"), # 任務選擇 gr.Image(type="numpy", label="Input Image"), # 圖片輸入 gr.Image(type="numpy", label="Mask (Optional), Note: This is for Image Inpainting function"), # 可選遮罩(僅限修補) gr.Slider(1, 10, value=1, label="Parameter 1 (e.g., Edge Sigma, Num Segments)"), # 參數 1 gr.Slider(0, 10, value=1, label="Parameter 2 (e.g., High Threshold)") # 參數 2 ], outputs="image", # 輸出圖像 examples=examples, # 例子資料 title="Computer Vision Web App", # 網頁應用標題 description="Perform Edge Detection, Image Segmentation, or Inpainting on uploaded images." # 網頁描述 ) # 啟動 Gradio 介面 if __name__ == "__main__": interface.launch()