ChatAnything / DATA /test_landmark.py
ermu2001's picture
Upload DATA/test_landmark.py with huggingface_hub
d8d3000
raw
history blame
5.52 kB
# import os
# import cv2
# import time
# import glob
# import argparse
# import scipy
# import numpy as np
# from PIL import Image
# import torch
# from tqdm import tqdm
# from itertools import cycle
# from extract_kp_videos_safe import KeypointExtractor
# import numpy as np
# from PIL import Image
# with torch.no_grad():
# img_np =cv2.imread('Strawberry Monster.png')
# predictor = KeypointExtractor('cuda')
# dets = predictor.det_net.detect_faces(img_np, 0.97)
# if len(dets) == 0:
# detect = False
# else:
# print("success")
# import os
# import cv2
# import torch
# from tqdm import tqdm
# from extract_kp_videos_safe import KeypointExtractor
# # 创建 KeypointExtractor 实例
# # 设置文件夹路径
# folder_path = 'control_inversion'
# landmark_detect_false=0
# landmark_detect_success=0
# # 遍历文件夹中的图像文件
# for filename in tqdm(os.listdir(path)):
# if filename.endswith('.png') or filename.endswith('.jpg'):
# # 读取图像
# image_path = os.path.join(folder_path, filename)
# img_np = cv2.imread(image_path)
# # 进行人脸检测和关键点提取
# with torch.no_grad():
# predictor = KeypointExtractor('cuda')
# dets = predictor.det_net.detect_faces(img_np, 0.97)
# if len(dets) == 0:
# landmark_detect_false += 1
# else:
# landmark_detect_success += 1
# detect_rate = landmark_detect_success/(landmark_detect_success+landmark_detect_false)
# print(detect_rate)
# import os
# import cv2
# import torch
# from tqdm import tqdm
# from extract_kp_videos_safe import KeypointExtractor
# # 设置文件夹路径
# folder_path = 'prompts'
# # 初始化成功和失败的计数
# total_landmark_detect_success = 0
# total_landmark_detect_false = 0
# # 遍历文件夹中的 txt 文件
# for txt_filename in os.listdir(folder_path):
# if txt_filename.endswith('.txt'):
# txt_file_path = os.path.join(folder_path, txt_filename)
# # 读取 txt 文件中的图片列表
# with open(txt_file_path, 'r') as file:
# image_list = file.read().splitlines()
# landmark_detect_success = 0
# landmark_detect_false = 0
# # 遍历 txt 文件中的图片列表
# for image_filename in tqdm(image_list, desc=f'Processing {txt_filename}'):
# image_path = os.path.join('control_inversion', image_filename+'.png')
# if image_path.endswith('.png') or image_path.endswith('.jpg'):
# img_np = cv2.imread(image_path)
# # 进行人脸检测和关键点提取
# with torch.no_grad():
# predictor = KeypointExtractor('cuda')
# dets = predictor.det_net.detect_faces(img_np, 0.97)
# if len(dets) == 0:
# landmark_detect_false += 1
# else:
# landmark_detect_success += 1
# # 计算检测率
# detect_rate = landmark_detect_success / (landmark_detect_success + landmark_detect_false)
# print(f'{txt_filename}: Detect Rate = {detect_rate}')
# # 更新总的计数
# total_landmark_detect_success += landmark_detect_success
# total_landmark_detect_false += landmark_detect_false
# # 计算总的检测率
# total_detect_rate = total_landmark_detect_success / (total_landmark_detect_success + total_landmark_detect_false)
# print(f'Total Detect Rate = {total_detect_rate}')
import os
import sys
import cv2
import torch
from tqdm import tqdm
from chat_anything.sad_talker.face3d.extract_kp_videos_safe import KeypointExtractor
# Measure how often the face detector finds at least one face in a set of
# images, and print a per-subfolder rate plus an overall rate.
#
# Usage: python test_landmark.py <folder_path>
#
# NOTE(review): entries of <folder_path> are selected by a '.txt' suffix but
# are then listed as directories of images (the earlier variant that read the
# .txt file as a list of image names is disabled). Confirm the on-disk layout
# really uses directories whose names end in '.txt'.

# File suffixes treated as images.
IMAGE_SUFFIXES = ('.png', '.jpg')

# Second argument passed to detect_faces; presumably the detection confidence
# threshold -- confirm against the detector's documentation.
DETECTION_THRESHOLD = 0.97


def detection_rate(success, failure):
    """Return success / (success + failure), or None when both counts are 0.

    Returning None instead of dividing avoids a ZeroDivisionError when no
    image was evaluated.
    """
    total = success + failure
    if total == 0:
        return None
    return success / total


def format_rate(rate):
    """Render a rate for printing; None becomes an explicit 'n/a' marker."""
    return 'n/a (no images)' if rate is None else rate


def count_detections(predictor, image_dir, label):
    """Run face detection on every .png/.jpg directly inside image_dir.

    Args:
        predictor: object exposing det_net.detect_faces(image, threshold).
        image_dir: directory whose entries are scanned (non-recursive).
        label: name shown in the progress-bar description.

    Returns:
        (success, failure): number of images with at least one detection and
        number of images with none. Unreadable images are skipped with a
        warning on stderr and counted in neither.
    """
    success = 0
    failure = 0
    for image_filename in tqdm(os.listdir(image_dir), desc=f'Processing {label}'):
        image_path = os.path.join(image_dir, image_filename)
        if not image_path.endswith(IMAGE_SUFFIXES):
            continue
        img_np = cv2.imread(image_path)
        if img_np is None:
            # cv2.imread signals an unreadable/corrupt file by returning None
            # rather than raising; do not hand that to the detector.
            print(f'Skipping unreadable image: {image_path}', file=sys.stderr)
            continue
        # Face detection; no gradients are needed for inference.
        with torch.no_grad():
            dets = predictor.det_net.detect_faces(img_np, DETECTION_THRESHOLD)
        # len() rather than truthiness: dets may be an array-like whose
        # truth value is ambiguous.
        if len(dets) == 0:
            failure += 1
        else:
            success += 1
    return success, failure


def main(argv):
    """Evaluate every matching entry of argv[1]; return a process exit code."""
    if len(argv) < 2:
        print('Usage: python test_landmark.py <folder_path>', file=sys.stderr)
        return 1
    # Root folder to scan.
    folder_path = argv[1]

    # Build the extractor once and reuse it for every image; the original
    # script rebuilt it for each image inside the inner loop.
    predictor = KeypointExtractor('cuda')

    # Running totals across all evaluated entries.
    total_landmark_detect_success = 0
    total_landmark_detect_false = 0

    for txt_filename in os.listdir(folder_path):
        if not txt_filename.endswith('.txt'):
            continue
        txt_file_path = os.path.join(folder_path, txt_filename)
        landmark_detect_success, landmark_detect_false = count_detections(
            predictor, txt_file_path, txt_filename
        )
        # Per-entry detection rate.
        detect_rate = format_rate(
            detection_rate(landmark_detect_success, landmark_detect_false)
        )
        print(f'{txt_filename}: Detect Rate = {detect_rate}')
        # Fold this entry's counts into the totals.
        total_landmark_detect_success += landmark_detect_success
        total_landmark_detect_false += landmark_detect_false

    # Overall detection rate.
    total_detect_rate = format_rate(
        detection_rate(total_landmark_detect_success, total_landmark_detect_false)
    )
    print(f'Total Detect Rate = {total_detect_rate}')
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))