Spaces:
Runtime error
Runtime error
# This is a sample Python script. | |
# Press ⌃R to execute it or replace it with your code. | |
# Press Double ⇧ to search everywhere for classes, files, tool windows, actions, and settings. | |
import base64 | |
import datetime | |
import json | |
import cv2 | |
import requests | |
from PIL import Image, ImageDraw, ImageFont, ImageOps | |
import numpy as np | |
from io import BytesIO | |
import time | |
import os

# Local development paths. main_image_path is read only by the standalone
# helpers getImageSegmentation() and crop_image(); temp_dir is referenced only
# by commented-out code. Neither is used by the Gradio app.
main_image_path = "/Users/aaron/Documents/temp/16pic_2415206_s.png"

# Never hard-code the Hugging Face token: committing it in plain text leaks the
# credential. Read it from the environment instead (on a Hugging Face Space,
# configure it as a secret named HF_TOKEN).
# NOTE(review): the token previously committed here should be revoked.
API_TOKEN = os.environ.get("HF_TOKEN", "")
# Only send an Authorization header when a token is actually configured.
headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}

# Endpoint queried by query() (and therefore by getImageSegmentation()).
# API_URL = "https://api-inference.huggingface.co/models/hustvl/yolos-tiny"
API_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50"
# Endpoint queried by queryObjectDetection(); it must be defined, otherwise
# calling that helper raises NameError.
API_OBJECT_URL = "https://api-inference.huggingface.co/models/microsoft/resnet-50"
# Endpoint queried by send_request_to_api(), which drives the Gradio app.
API_SEGMENTATION_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50-panoptic"
# Alternative endpoint; not referenced by any code in this file.
API_SEGMENTATION_URL_2 = "https://api-inference.huggingface.co/models/nvidia/segformer-b0-finetuned-ade-512-512"
temp_dir = "/Users/aaron/Documents/temp/imageai/"
def query(filename):
    """Send the raw bytes of ``filename`` to ``API_URL`` and return the parsed JSON reply.

    No timeout is passed to requests, so the call waits until the server answers.
    """
    with open(filename, "rb") as image_file:
        payload = image_file.read()
    reply = requests.request("POST", API_URL, headers=headers, data=payload)
    text = reply.content.decode("utf-8")
    return json.loads(text)
def queryObjectDetection(filename):
    """Send the raw bytes of ``filename`` to ``API_OBJECT_URL`` and return the parsed JSON reply.

    The request uses a 6-second timeout, and the response object is printed
    before its body is decoded. NOTE: ``API_OBJECT_URL`` must be defined at
    module level for this helper to run.
    """
    with open(filename, "rb") as fh:
        raw_bytes = fh.read()
    reply = requests.request(
        "POST",
        API_OBJECT_URL,
        headers=headers,
        data=raw_bytes,
        timeout=6,
    )
    print(reply)
    decoded = reply.content.decode("utf-8")
    return json.loads(decoded)
def getImageSegmentation():
    """Run :func:`query` on ``main_image_path``, print the reply, and return it."""
    segments = query(main_image_path)
    print(segments)
    return segments
def crop_image(box, image_path=None):
    """Crop a rectangular region out of an image file.

    Args:
        box: mapping with pixel coordinates under the keys 'xmin', 'ymin',
            'xmax' and 'ymax' (e.g. the 'box' entry of a query() result item).
        image_path: file to crop; defaults to the module-level main_image_path.

    Returns:
        A new PIL image containing only the cropped region.
    """
    if image_path is None:
        image_path = main_image_path
    # PIL expects (left, upper, right, lower).
    crop_area = (box['xmin'], box['ymin'], box['xmax'], box['ymax'])
    # Use a context manager so the file handle opened by Image.open is closed.
    # crop() loads the pixel data into a new image, so the result remains valid
    # after the source file is closed.
    with Image.open(image_path) as image:
        return image.crop(crop_area)
# # 示例 | |
# image_path = "path/to/your/image.jpg" | |
# box = {'xmin': 186, 'ymin': 75, 'xmax': 252, 'ymax': 123} | |
# | |
# cropped_image = crop_image(box)  # crops main_image_path
# cropped_image.show() # 显示裁剪后的图片 | |
# cropped_image.save("path/to/save/cropped_image.jpg") # 保存裁剪后的图片 | |
# Press the green button in the gutter to run the script. | |
# if __name__ == '__main__': | |
# data = getImageSegmentation() | |
# for item in data: | |
# box = item['box'] | |
# cropped_image = crop_image(box) | |
# temp_image_path = temp_dir + str(int(datetime.datetime.now().timestamp() * 1000000)) + ".png" | |
# print(temp_image_path) | |
# cropped_image.save(temp_image_path) | |
# object_data = queryObjectDetection(temp_image_path) | |
# print(object_data) | |
# flag = False | |
# for obj in object_data: | |
# # 检查字典中是否包含 'error' 键 | |
# if 'error' in obj and obj['error'] is not None: | |
# flag = True | |
# print("找到了一个包含 'error' 键的字典,且其值不为 None") | |
# else: | |
# print("字典不包含 'error' 键,或其值为 None") | |
# if flag: | |
# continue | |
# item['label'] = object_data[0]['label'] | |
# print(data) | |
# | |
# ###下面就是画个图,和上面主流程无关,仅仅用于测试
# image = Image.open(main_image_path) | |
# draw = ImageDraw.Draw(image) | |
# | |
# # 设置边框颜色和字体 | |
# border_color = (255, 0, 0) # 红色 | |
# text_color = (255, 255, 255) # 白色 | |
# font = ImageFont.truetype("Geneva.ttf", 12) # 使用 系统Geneva 字体,大小为 12
# | |
# # 遍历对象列表,画边框和标签 | |
# for obj in data: | |
# label = obj['label'] | |
# box = obj['box'] | |
# xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax'] | |
# | |
# # 画边框 | |
# draw.rectangle([xmin, ymin, xmax, ymax], outline=border_color, width=2) | |
# | |
# # 画标签 | |
# text_size = draw.textsize(label, font=font) | |
# draw.rectangle([xmin, ymin, xmin + text_size[0], ymin + text_size[1]], fill=border_color) | |
# draw.text((xmin, ymin), label, font=font, fill=text_color) | |
# | |
# image.show() | |
import numpy as np | |
from PIL import Image | |
import gradio as gr | |
def send_request_to_api(img_byte_arr, max_retries=3, wait_time=60, timeout=None):
    """POST encoded image bytes to API_SEGMENTATION_URL, retrying on error replies.

    Args:
        img_byte_arr: encoded image bytes (the caller in this file sends PNG bytes).
        max_retries: maximum number of POST attempts.
        wait_time: seconds to sleep between a failed attempt and the next one.
        timeout: optional requests timeout in seconds; None (the default) waits
            indefinitely, as before.

    Returns:
        The decoded JSON reply (the caller iterates it as a list of dicts).

    Raises:
        RuntimeError: if no attempt produced a non-error JSON reply.
    """
    for attempt in range(1, max_retries + 1):
        response = requests.request("POST", API_SEGMENTATION_URL, headers=headers,
                                    data=img_byte_arr, timeout=timeout)
        response_content = response.content.decode("utf-8")
        try:
            json_obj = json.loads(response_content)
        except ValueError:
            # Non-JSON body (e.g. an HTML error page): treat it as a failed attempt.
            json_obj = None
        # Inspect the parsed reply rather than searching the raw text for "error".
        # This way a valid reply whose labels happen to contain that word is
        # not rejected; only a JSON object carrying an "error" key counts as a failure.
        is_error = json_obj is None or (isinstance(json_obj, dict) and "error" in json_obj)
        if not is_error:
            return json_obj
        print(f"Error: {response_content}")
        # Don't sleep after the final attempt; we are about to give up anyway.
        if attempt < max_retries:
            time.sleep(wait_time)
    raise RuntimeError("Failed to get a valid response from the API after multiple retries.")
def _text_bottom_center(draw, position, text, font):
    """Return the (x, y) point at the bottom centre of ``text`` drawn at ``position``.

    ImageDraw.textsize() was removed in Pillow 10; textbbox() (Pillow >= 8.0)
    replaces it. Fall back to textsize() on older Pillow releases.
    """
    if hasattr(draw, "textbbox"):
        left, _top, right, bottom = draw.textbbox(position, text, font=font)
        return (left + right) // 2, bottom
    width, height = draw.textsize(text, font=font)
    return position[0] + width // 2, position[1] + height


def getSegmentationMaskImage(input_img, blur_kernel_size=21):
    """Segment ``input_img`` remotely and render one highlighted image per segment.

    The image is shrunk to at most 600 px wide, keeping its aspect ratio. It is
    then sent as PNG via send_request_to_api(). For every returned segment, an
    image is produced in which:
    - the masked region keeps the original pixels;
    - everything else is Gaussian-blurred;
    - the segment label is written in red at (10, 30);
    - a red line runs from the label to the top centre of the mask.

    Args:
        input_img: PIL image to process. It is copied, not modified in place.
        blur_kernel_size: Gaussian kernel size passed to cv2.GaussianBlur
            (OpenCV requires an odd value).

    Returns:
        List of RGBA PIL images, one per segment whose label does not start
        with "LABEL".
    """
    # Work on a copy so the caller's image is not resized in place by thumbnail().
    work_img = input_img.copy()
    target_width = 600
    aspect_ratio = float(work_img.height) / float(work_img.width)
    target_height = int(target_width * aspect_ratio)
    work_img.thumbnail((target_width, target_height))
    buffer = BytesIO()
    work_img.save(buffer, format='PNG')
    json_obj = send_request_to_api(buffer.getvalue())
    print(json_obj)

    # Convert to RGBA so the sharp and blurred layers share one mode for composite().
    original_image = work_img.copy()
    if original_image.mode != 'RGBA':
        original_image = original_image.convert('RGBA')

    # The blurred background is the same for every segment, so compute it once.
    original_image_cv2 = cv2.cvtColor(np.array(original_image.convert('RGB')), cv2.COLOR_RGB2BGR)
    blurred_image_cv2 = cv2.GaussianBlur(original_image_cv2, (blur_kernel_size, blur_kernel_size), 0)
    blurred_image = Image.fromarray(cv2.cvtColor(blurred_image_cv2, cv2.COLOR_BGR2RGB)).convert(original_image.mode)

    font = ImageFont.load_default()
    text_position = (10, 30)
    red = (255, 0, 0)

    output_images = []
    for item in json_obj:
        label = item['label']
        # Skip items whose label starts with "LABEL".
        if label.startswith("LABEL"):
            continue
        # The mask arrives as base64-encoded image data.
        mask_image = Image.open(BytesIO(base64.b64decode(item['mask'])))
        # Keep original pixels where the mask is set, blurred pixels elsewhere.
        process_image = Image.composite(original_image, blurred_image, mask_image)

        draw = ImageDraw.Draw(process_image)
        draw.text(text_position, label, font=font, fill=red)
        mask_bbox = mask_image.getbbox()
        # getbbox() returns None for an empty mask; there is nothing to point at.
        if mask_bbox is not None:
            mask_top_center = ((mask_bbox[0] + mask_bbox[2]) // 2, mask_bbox[1])
            text_bottom_center = _text_bottom_center(draw, text_position, label, font)
            draw.line([text_bottom_center, mask_top_center], fill=red, width=2)
        output_images.append(process_image)
    return output_images
def sepia(input_img):
    """Gradio callback: segment an uploaded image and stack the per-segment results.

    Despite its name, no sepia filter is applied. The name is kept because
    imageDemo() passes this function to gr.Interface.

    Args:
        input_img: numpy image array from the Gradio input component
            (presumably uint8 HxWxC; float arrays in [0, 1] are also accepted),
            or None when no image was provided.

    Returns:
        One of the following:
        - a numpy array with all per-segment images stacked vertically;
        - the input image as an array, when no segments were produced;
        - None, when there is no input.
    """
    if input_img is None:
        return None
    # Float images in [0, 1] must be scaled to 8-bit before PIL can wrap them.
    if input_img.dtype.kind == "f" and np.max(input_img) <= 1.0:
        input_img = (input_img * 255).astype(np.uint8)
    pil_img = Image.fromarray(input_img)
    output_images = getSegmentationMaskImage(pil_img)
    if not output_images:
        # np.vstack raises on an empty list; show the input unchanged instead.
        return np.array(pil_img)
    return np.vstack([np.array(img) for img in output_images])
def imageDemo():
    """Build and launch the Gradio UI that wraps :func:`sepia`.

    gr.Image is used for both input and output. The ``gr.outputs`` namespace
    and gr.Image's ``shape`` argument were deprecated in Gradio 3 and removed
    in Gradio 4, so the previous call fails at startup on current releases.
    """
    demo = gr.Interface(
        sepia,
        gr.Image(type="numpy"),
        gr.Image(label="Processed Images", type="numpy"),
        title='Image Processing Demo',
    )
    demo.launch()
# Launch the Gradio app only when this file is executed directly.
if __name__ == "__main__":
    imageDemo()
#######---------gif输出方式 | |
# def sepia(input_img): | |
# input_img = Image.fromarray((input_img * 255).astype(np.uint8)) | |
# | |
# output_images = getSegmentationMaskImage(input_img) | |
# | |
# # 生成GIF动画 | |
# buffered = BytesIO() | |
# output_images[0].save(buffered, format='GIF', save_all=True, append_images=output_images[1:], duration=3000, loop=0) | |
# gif_str = base64.b64encode(buffered.getvalue()).decode() | |
# return f'<img src="data:image/gif;base64,{gif_str}" width="400" />' | |
# | |
# | |
# def imageDemo(): | |
# demo = gr.Interface(sepia, gr.Image(shape=(200, 200)), gr.outputs.HTML(label="Processed Animation"), title='Sepia Filter Demo') | |
# demo.launch() | |