陶泓
添加提示
28048bc
import streamlit as st
import torch
import numpy as np
from PIL import Image
import requests
from io import BytesIO
from transformers import AutoProcessor, CLIPSegForImageSegmentation
from scipy.ndimage import label, find_objects
import time
# Streamlit 应用标题
st.title("使用图像分割模型分割证件照")
# 输入图像 URL
url = st.text_input("输入图像地址:", "https://i.ibb.co/GRCGQ3n/464.jpg")
# 输入要识别的物体文本
texts_input = st.text_input("输入要检测的对象(以逗号分隔):", "a card")
texts = [text.strip() for text in texts_input.split(',')]
# 选择面积阈值
area_threshold = st.slider("忽略小区域的面积阈值", 0, 10000, 5000)
# 添加 GPU/CPU 选择按钮
device_option = st.radio("选择设备", ("GPU", "CPU"))
# 提交按钮
if st.button('提交'):
# 在按钮点击后确定设备
device = torch.device('cuda' if device_option == 'GPU' and torch.cuda.is_available() else 'cpu')
st.write(f"设备: {device}")
start_time = time.time() # 开始计时
# 加载模型和处理器到选定设备
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
# 下载并处理图像
response = requests.get(url)
image = Image.open(BytesIO(response.content))
# 显示原始图像
st.image(image, caption="原图", use_column_width=True)
# 处理图像和文本
inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to(device)
outputs = model(**inputs)
# 将 logits 转换为概率值并生成掩码
probabilities = torch.sigmoid(outputs.logits)
masks = probabilities > 0.5
masks = masks.detach().cpu().numpy() # 将数据移回 CPU 以进行后续处理
# 获取原始图像的 NumPy 数组
image_np = np.array(image)
# 全局计数器初始化
global_counter = 1
# 对每个物体生成分割图像
for i, mask in enumerate(masks):
# 将掩码调整为与原始图像相同的尺寸
mask_resized = Image.fromarray(mask).resize((image_np.shape[1], image_np.shape[0]), resample=Image.LANCZOS)
mask_resized = np.array(mask_resized) > 0.5
# 标记连通区域
labeled_mask, num_features = label(mask_resized)
object_slices = find_objects(labeled_mask)
for j in range(1, num_features + 1):
# 获取当前连通区域的边界框
object_slice = object_slices[j-1]
area = (object_slice[0].stop - object_slice[0].start) * (object_slice[1].stop - object_slice[1].start)
# 忽略面积小于阈值的区域
if area < area_threshold:
continue
# 创建一个矩形掩码
single_object_mask = np.zeros_like(mask_resized)
single_object_mask[object_slice] = 1
# 仅保留原始图像中与当前矩形掩码匹配的区域
single_segmented_image = np.zeros_like(image_np)
single_segmented_image[single_object_mask.astype(bool)] = image_np[single_object_mask.astype(bool)]
# 去除黑色背景(即裁剪图像到非黑色区域)
non_black_area = np.any(single_segmented_image > 0, axis=-1)
if non_black_area.any():
rows = np.any(non_black_area, axis=1)
cols = np.any(non_black_area, axis=0)
rmin, rmax = np.where(rows)[0][[0, -1]]
cmin, cmax = np.where(cols)[0][[0, -1]]
# 裁剪图像
cropped_image = single_segmented_image[rmin:rmax+1, cmin:cmax+1]
# 显示裁剪后的图像
st.image(cropped_image, caption=f"{texts[i]} - 图 {global_counter}", use_column_width=True)
# 增加全局计数器
global_counter += 1
end_time = time.time() # 结束计时
elapsed_time = end_time - start_time # 计算运行时间
# 显示程序运行时间
st.write(f"程序运行时间: {elapsed_time:.2f} 秒")