# imageProcess / app.py
# risekid's picture
# Update app.py
# b315d19
# This is a sample Python script.
# Press ⌃R to execute it or replace it with your code.
# Press Double ⇧ to search everywhere for classes, files, tool windows, actions, and settings.
import base64
import datetime
import json
import cv2
import requests
from PIL import Image, ImageDraw, ImageFont, ImageOps
import numpy as np
from io import BytesIO
import time
import os

# Local image used by the (commented-out) object-detection experiment in this file.
main_image_path = "/Users/aaron/Documents/temp/16pic_2415206_s.png"
# Hugging Face Inference API token. It is read from the environment instead of
# being hard-coded: a token committed to source control is public and must be
# revoked and rotated.
API_TOKEN = os.environ.get("HF_TOKEN", "")
# Only send an Authorization header when a token is actually configured; an
# empty "Bearer " credential is worse than none.
headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
# Object-detection endpoint used by query().
# API_URL = "https://api-inference.huggingface.co/models/hustvl/yolos-tiny"
API_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50"
# Endpoint used by queryObjectDetection(). It must be defined, otherwise that
# function raises NameError when called.
API_OBJECT_URL = "https://api-inference.huggingface.co/models/microsoft/resnet-50"
# Panoptic-segmentation endpoint used by send_request_to_api().
API_SEGMENTATION_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50-panoptic"
# Alternative segmentation endpoint (not referenced by the code visible here).
API_SEGMENTATION_URL_2 = "https://api-inference.huggingface.co/models/nvidia/segformer-b0-finetuned-ade-512-512"
# Scratch directory for cropped images written by the commented-out experiment.
temp_dir = "/Users/aaron/Documents/temp/imageai/"
def query(filename, *, timeout=60):
    """POST the raw bytes of an image file to the API_URL endpoint.

    Args:
        filename: Path of the image file to upload.
        timeout: Seconds to wait for the HTTP response (``None`` waits forever).
            Without a timeout a stalled connection blocks the caller indefinitely.

    Returns:
        The response body decoded as UTF-8 and parsed as JSON.
    """
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.request("POST", API_URL, headers=headers, data=data, timeout=timeout)
    return json.loads(response.content.decode("utf-8"))
def queryObjectDetection(filename):
    """Upload an image file to the API_OBJECT_URL endpoint and return its JSON reply.

    The raw ``requests.Response`` is printed for debugging, then its body is
    decoded as UTF-8 and parsed as JSON. The request times out after 6 seconds.
    """
    with open(filename, "rb") as image_file:
        payload = image_file.read()
    reply = requests.post(API_OBJECT_URL, headers=headers, data=payload, timeout=6)
    print(reply)
    body_text = reply.content.decode("utf-8")
    return json.loads(body_text)
def getImageSegmentation():
    """Send ``main_image_path`` through ``query``, print the parsed reply, and return it."""
    detections = query(main_image_path)
    print(detections)
    return detections
def crop_image(box, image_path=None):
    """Crop a rectangular region out of an image file.

    Args:
        box: Mapping with pixel-coordinate keys ``'xmin'``, ``'ymin'``,
            ``'xmax'`` and ``'ymax'``.
        image_path: Image file to crop. Defaults to the module-level
            ``main_image_path`` so existing ``crop_image(box)`` calls still work.

    Returns:
        A new PIL image containing only the cropped region.
    """
    if image_path is None:
        image_path = main_image_path
    # Compute the crop area as a (left, upper, right, lower) tuple.
    crop_area = (box['xmin'], box['ymin'], box['xmax'], box['ymax'])
    # Open inside a context manager so the file handle is released. crop() loads
    # the pixel data and returns an independent image, so the result stays valid
    # after the source file is closed.
    with Image.open(image_path) as image:
        return image.crop(crop_area)
# # 示例
# image_path = "path/to/your/image.jpg"
# box = {'xmin': 186, 'ymin': 75, 'xmax': 252, 'ymax': 123}
#
# cropped_image = crop_image(image_path, box)
# cropped_image.show() # 显示裁剪后的图片
# cropped_image.save("path/to/save/cropped_image.jpg") # 保存裁剪后的图片
# Press the green button in the gutter to run the script.
# if __name__ == '__main__':
# data = getImageSegmentation()
# for item in data:
# box = item['box']
# cropped_image = crop_image(box)
# temp_image_path = temp_dir + str(int(datetime.datetime.now().timestamp() * 1000000)) + ".png"
# print(temp_image_path)
# cropped_image.save(temp_image_path)
# object_data = queryObjectDetection(temp_image_path)
# print(object_data)
# flag = False
# for obj in object_data:
# # 检查字典中是否包含 'error' 键
# if 'error' in obj and obj['error'] is not None:
# flag = True
# print("找到了一个包含 'error' 键的字典,且其值不为 None")
# else:
# print("字典不包含 'error' 键,或其值为 None")
# if flag:
# continue
# item['label'] = object_data[0]['label']
# print(data)
#
# ###下面就是画个图,和上面主流程无关,仅仅用于测试
# image = Image.open(main_image_path)
# draw = ImageDraw.Draw(image)
#
# # 设置边框颜色和字体
# border_color = (255, 0, 0) # 红色
# text_color = (255, 255, 255) # 白色
# font = ImageFont.truetype("Geneva.ttf", 12) # 使用 系统Geneva 字体,大小为 12
#
# # 遍历对象列表,画边框和标签
# for obj in data:
# label = obj['label']
# box = obj['box']
# xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax']
#
# # 画边框
# draw.rectangle([xmin, ymin, xmax, ymax], outline=border_color, width=2)
#
# # 画标签
# text_size = draw.textsize(label, font=font)
# draw.rectangle([xmin, ymin, xmin + text_size[0], ymin + text_size[1]], fill=border_color)
# draw.text((xmin, ymin), label, font=font, fill=text_color)
#
# image.show()
import numpy as np
from PIL import Image
import gradio as gr
def send_request_to_api(img_byte_arr, max_retries=3, wait_time=60, *, timeout=120):
    """POST encoded image bytes to the segmentation endpoint, retrying on failure.

    Args:
        img_byte_arr: Encoded image bytes sent as the request body.
        max_retries: Maximum number of attempts.
        wait_time: Seconds to sleep between failed attempts (not after the last one).
        timeout: Per-request HTTP timeout in seconds (``None`` waits forever).

    Returns:
        The parsed JSON response of the first successful attempt.

    Raises:
        Exception: if no attempt produced a usable response. The last
            network/decoding error, if any, is chained as the cause.
    """
    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.request("POST", API_SEGMENTATION_URL, headers=headers,
                                        data=img_byte_arr, timeout=timeout)
            response_content = response.content.decode("utf-8")
            json_obj = json.loads(response_content)
        except (requests.RequestException, ValueError) as exc:
            # Network failure, timeout, or a body that is not UTF-8 JSON
            # (e.g. an HTML gateway error page): treat as transient and retry.
            last_exc = exc
            print(f"Error: {exc}")
        else:
            # API failures are expected as a JSON object carrying an "error" key.
            # Checking the parsed key (rather than searching the raw text) avoids
            # rejecting successful payloads that merely contain the substring "error".
            if isinstance(json_obj, dict) and "error" in json_obj:
                print(f"Error: {response_content}")
            else:
                return json_obj
        # Only wait if another attempt follows; sleeping before the raise is wasted time.
        if attempt < max_retries:
            time.sleep(wait_time)
    raise Exception("Failed to get a valid response from the API after multiple retries.") from last_exc
def _measure_text(draw, text, font):
    """Return the (width, height) in pixels of ``text`` rendered with ``font``.

    ``ImageDraw.textsize`` was removed in Pillow 10, so ``textbbox`` is preferred
    when available; ``textsize`` is kept as a fallback for older Pillow releases.
    """
    if hasattr(draw, "textbbox"):
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right - left, bottom - top
    return draw.textsize(text, font=font)


def getSegmentationMaskImage(input_img, blur_kernel_size=21):
    """Segment ``input_img`` via the API and return one annotated image per segment.

    A copy of the input is downscaled to at most 600 px wide (aspect ratio kept)
    and sent to the segmentation endpoint. For every returned segment, an RGBA
    image is produced with these properties:

    * Pixels inside the segment's mask are kept as-is.
    * Everything else is Gaussian-blurred.
    * The label is written in red at (10, 30).
    * A red line joins the label to the top-center of the mask's bounding box.

    Args:
        input_img: PIL image to process. It is not modified.
        blur_kernel_size: Gaussian kernel size for the background blur. OpenCV
            requires an odd, positive size, so even values are rounded up and
            values below 1 are clamped to 1.

    Returns:
        List of RGBA PIL images, one per segment whose label does not start
        with "LABEL" (empty if there are none).
    """
    # thumbnail() resizes in place, so work on a copy to leave the caller's image alone.
    target_width = 600
    aspect_ratio = float(input_img.height) / float(input_img.width)
    target_height = int(target_width * aspect_ratio)
    resized_img = input_img.copy()
    resized_img.thumbnail((target_width, target_height))

    png_buffer = BytesIO()
    resized_img.save(png_buffer, format='PNG')
    json_obj = send_request_to_api(png_buffer.getvalue())
    print(json_obj)

    # Composite onto the resized image: that is what the API segmented, so the
    # returned masks are expected to share its size.
    original_image = resized_img if resized_img.mode == 'RGBA' else resized_img.convert('RGBA')

    # The blurred background is identical for every segment, so compute it once.
    ksize = max(1, int(blur_kernel_size))
    if ksize % 2 == 0:
        ksize += 1
    # GaussianBlur filters each channel independently, so blurring the RGB array
    # directly gives the same result as a BGR round trip.
    blurred_rgb = cv2.GaussianBlur(np.array(original_image.convert('RGB')), (ksize, ksize), 0)
    blurred_image = Image.fromarray(blurred_rgb).convert(original_image.mode)

    font = ImageFont.load_default()
    text_position = (10, 30)
    red = (255, 0, 0)

    output_images = []
    for item in json_obj:
        label = item['label']
        # Labels starting with "LABEL" are skipped (presumably generic/unnamed classes).
        if label.startswith("LABEL"):
            continue
        # Each mask arrives as base64-encoded PNG data.
        mask_image = Image.open(BytesIO(base64.b64decode(item['mask'])))
        # Image.composite only accepts "1", "L" or "RGBA" masks.
        if mask_image.mode not in ("1", "L", "RGBA"):
            mask_image = mask_image.convert('L')
        # Keep original pixels where the mask is set, blurred pixels elsewhere.
        process_image = Image.composite(original_image, blurred_image, mask_image)

        draw = ImageDraw.Draw(process_image)
        draw.text(text_position, label, font=font, fill=red)

        # getbbox() returns None for an all-black mask: there is nothing to point at.
        mask_bbox = mask_image.getbbox()
        if mask_bbox is not None:
            mask_top_center = ((mask_bbox[0] + mask_bbox[2]) // 2, mask_bbox[1])
            text_width, text_height = _measure_text(draw, label, font)
            text_bottom_center = (text_position[0] + text_width // 2,
                                  text_position[1] + text_height)
            draw.line([text_bottom_center, mask_top_center], fill=red, width=2)

        output_images.append(process_image)
    return output_images
def sepia(input_img):
    """Gradio callback: produce per-segment blurred images and stack them vertically.

    Despite its name, no sepia filter is applied. Each segment image comes from
    ``getSegmentationMaskImage``.

    Args:
        input_img: Image as a numpy array delivered by the ``gr.Image`` input,
            or ``None`` when the user submitted no image. Floating-point arrays
            with values in [0, 1] are scaled to 0-255. Any other float array is
            assumed to already be on a 0-255 scale. Either way, floats are
            clipped and cast to uint8.

    Returns:
        The segment images stacked top-to-bottom as one numpy array. Returns
        the (unprocessed) input image if no segments were found, and ``None``
        if no image was given.
    """
    if input_img is None:
        return None
    if np.issubdtype(input_img.dtype, np.floating):
        if input_img.size and np.max(input_img) <= 1.0:
            input_img = input_img * 255
        # Image.fromarray cannot build an RGB(A) image from float data.
        input_img = np.clip(input_img, 0, 255).astype(np.uint8)
    pil_img = Image.fromarray(input_img)
    output_images = getSegmentationMaskImage(pil_img)
    if not output_images:
        # np.vstack raises on an empty sequence; show the input instead.
        return np.array(pil_img)
    return np.vstack([np.array(img) for img in output_images])
def imageDemo():
    """Build the Gradio interface around ``sepia`` and launch it.

    Uses ``gr.Image`` for both input and output. The ``gr.outputs`` namespace
    and the ``shape`` argument were removed in Gradio 4. ``shape=None`` was the
    default before that, so dropping it changes nothing on older versions.
    """
    demo = gr.Interface(
        sepia,
        gr.Image(),
        gr.Image(label="Processed Images", type="numpy"),
        title='Image Processing Demo',
    )
    demo.launch()
if __name__ == "__main__":
    # Start the Gradio app only when executed as a script, not on import.
    imageDemo()
#######---------gif输出方式
# def sepia(input_img):
# input_img = Image.fromarray((input_img * 255).astype(np.uint8))
#
# output_images = getSegmentationMaskImage(input_img)
#
# # 生成GIF动画
# buffered = BytesIO()
# output_images[0].save(buffered, format='GIF', save_all=True, append_images=output_images[1:], duration=3000, loop=0)
# gif_str = base64.b64encode(buffered.getvalue()).decode()
# return f'<img src="data:image/gif;base64,{gif_str}" width="400" />'
#
#
# def imageDemo():
# demo = gr.Interface(sepia, gr.Image(shape=(200, 200)), gr.outputs.HTML(label="Processed Animation"), title='Sepia Filter Demo')
# demo.launch()