HivisionIDPhotos / idPhotoCreateUtils.py
mikesolar's picture
Upload folder using huggingface_hub
d86aa1d
"""
@author: cuny
@file: idPhotoCreateUtils.py
@time: 2022/4/4 14:37
@description:
证件照制作服务类,新增了人脸矫正函数
"""
from _service import *
from hivisionai.hycv.utils import CV2Bytes
from _lib import AliyunUser, HY_HUMAN_MATTING_WEIGHTS_PATH
from face_judgement_align import IDphotos_create
from error import IDError
import onnxruntime
import time
import cv2
class IdPhotoCreateService(Service, CV2Bytes):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# 设置预加载模型参数,dlib、抠头、抠脖子等等
print("证件照制作对象初始化...")
start = time.time()
self.__human_sess = None
self.fd68 = None # 为本地人脸检测预留接口
self.user = AliyunUser()
print(f"初始化完毕,总耗时{round(time.time() - start, 2)}秒")
@property
def human_sess(self):
if self.__human_sess is None:
print("加载模型...")
self.__human_sess = onnxruntime.InferenceSession(HY_HUMAN_MATTING_WEIGHTS_PATH)
return self.__human_sess
def createMsg(self, status, msg, *args, **kwargs):
"""
本方法用于创建一个用于发送到WebSocket客户端的数据
输入的信息部分,需要有如下几个参数:
1. id,固定为"return-result"
2. status,如果输入为1则status=true, 如果输入为-1则status=false
3. obj_key, 图片的云端路径, 这是输入的msg本身自带的
"""
msg['status'] = True if status >= 1 else False # 最好还是用bool
msg['id'] = "async-back-msg"
msg['type'] = "certificatePhoto"
msg["format"] = "png"
return msg
def process(self,
image_pre,
oss_image_name,
w=295,
h=413,
beauty=False,
upload_path_hd=None,
upload_path_common=None,
if_upload: bool = True):
"""
处理函数
Args:
image_pre: 输入的原图
oss_image_name: 上传阿里云api的尺寸图像
w: 证件照尺寸-宽
h: 证件照尺寸-高
beauty: 是否美颜
upload_path_hd: 高清图上传cos路径
upload_path_common: 标清图上传cos路径
if_upload: 是否上传,不同选择返回的参数不同
Returns:
1. if if_upload is True:
函数会将图像上传,不返回图像仅返回参数
2. if if_upload is False:
函数不会将图像上传,返回图像和一些参数
"""
print("oss_name:", oss_image_name)
result_image_HD, result_image, _, \
typography_arr, typography_rotate, \
relative_x, relative_y, w, h, id_temp_info = IDphotos_create(image_pre,
size=(h, w),
head_height_ratio=0.45,
head_measure_ratio=0.2,
align=True,
beauty=beauty,
fd68=self.fd68,
human_sess=self.load_sess_generator("human_sess"),
oss_image_name=oss_image_name,
user=self.user)
if if_upload:
# 上传图像,云端模式
print("[图像尺寸]: ", result_image_HD.shape)
result_image_HD_byte = self.cv2_byte(result_image_HD, imageType=".png")
self.uploadFile_COS(buffer=result_image_HD_byte, key=upload_path_hd)
result_image_byte = self.cv2_byte(result_image, imageType=".png")
self.uploadFile_COS(buffer=result_image_byte, key=upload_path_common)
print("[image send success]")
return typography_arr, typography_rotate, relative_x, relative_y, w, h, id_temp_info
else:
# 不上传图像,返回处理结果
return result_image_HD, result_image, typography_arr, typography_rotate, relative_x, relative_y, w, h, id_temp_info
def checkKey(self, msg):
print("GET", msg)
try:
uid, send_msg = msg["uid"], msg["send_msg"]
connectionID = None
except KeyError:
connectionID, send_msg = msg["connectionID"], msg["send_msg"]
uid = send_msg["uid"]
download_path: str = send_msg["obj_key"] # 获得cos下载路径
# platform = send_msg["platform"] if "platform" in send_msg else "undefined" # 换装次数
# 获取需要被制作的证件照尺寸
template_info = send_msg["template_info"]
w, h, name = int(template_info["width"]), int(template_info["height"]), template_info["name"]
# 获得cos回传传路径
img_format = send_msg['obj_key'][send_msg['obj_key'].rfind('.') + 1:]
tr = send_msg['obj_key'].replace(img_format, 'png')
upload_path_hd: str = tr.replace("old-image", "new-image/hd")
upload_path_common: str = tr.replace("old-image", "new-image/common")
image_name = f"{uid}_{upload_path_common.split('/')[-1]}"
send_msg["hd_key"] = upload_path_hd # 回传云端结果图片路径(高清照)
send_msg["common_key"] = upload_path_common # 回传云端结果图片路径(高清照)
return (w, h, name), (download_path, upload_path_hd, upload_path_common), image_name, send_msg, (
uid, connectionID)
def __call__(self, msg, *args, **kwargs):
"""
证件照制作算法服务函数
"""
# --------------初始化一些数据-------------- #
print(msg)
backMsg, uid = None, ""
status_id = "0000"
funcDiary = FuncDiary("certificatePhoto")
# noinspection PyBroadException
try:
(w, h, name), (download_path, upload_path_hd, upload_path_common), image_name, backMsg, uid = self.checkKey(
msg)
# ----------------数据获取完毕-------------- #
# 开始处理
print("start...")
# start = time.time()
resp = self.downloadFile_COS(download_path, if_read=False) # 下载图片
image_byte = resp['Body'].get_raw_stream().read() # 读取二进制图片
# 将二进制图片转为cv2格式, 无损格式
image_pre = self.byte_cv2(image_byte, flags=cv2.IMREAD_COLOR)
# cv2.imwrite(f"test_image/cloud_img.{img_format}", image_pre)
# np_arr = np.frombuffer(image_byte, np.uint8)
# image = cv2.imdecode(np_arr, -1)
# 数据图片下载完毕,开始功能处理
print("processing...")
# 证件照制作
# 返回的w和h与输入的w和h不是一回事
backMsg["typography_arr"], backMsg["typography_rotate"], \
backMsg["relative_x"], backMsg["relative_y"], \
backMsg["w_create"], backMsg["h_create"], \
backMsg["id_temp_info"] = self.process(image_pre=image_pre,
oss_image_name=image_name,
w=w,
h=h,
upload_path_hd=upload_path_hd,
upload_path_common=upload_path_common)
except IDError as e:
# ------------处理失败, 错误类型有两种--------------- #
# 一是人像错误,这时候用户上传了一张无人像(太糊)或者两个以上人像的照片
# 此时face_num = 0或者2, back_msg["status"] is True
# 此外为未知错误,此时face_num 不存在于back_msg
# back_msg["status"] is False
# ----------------------------------------------- #
# print(type(e), e.err)
status_id = e.status_id
if e.face_num != -1:
backMsg["face_num"] = e.face_num
backMsg = self.createMsg(status=1, msg=backMsg) # back_msg["status"] is True
else:
# 抠图失败
backMsg = self.createMsg(status=-1, msg=backMsg)
print("fail!")
except cv2.error:
status_id = "1103"
backMsg = self.createMsg(status=-1, msg=backMsg)
print("fail!")
except Exception as e:
status_id = "1500"
print("[ERROR] ", e)
backMsg["problem"] = str(e)
backMsg = self.createMsg(status=-1, msg=backMsg)
print("fail!")
else:
# 无错误
backMsg = self.createMsg(status=1, msg=backMsg)
# 处理成功,在回传消息中添加成功对应消息
backMsg["face_num"] = 1 # 人脸个数,处理成功的话必然是1
print("success!")
finally:
# print(back_msg) # 打印回传数据,方便调试
self.sendMsg(backMsg, uid)
# ------------------投递日志------------------- #
funcDiary.content = backMsg
funcDiary.uploadDiary_COS(status_id=status_id, uid=uid[0])
# ------------------投递结束------------------- #
assert status_id == "0000", f"函数出现异常: {status_id}"
def load_sess(idPhotoCreateService: IdPhotoCreateService):
while True:
yield idPhotoCreateService.human_sess