# NSAQA / temporal_segmentation.py
# Uploaded by laurenok24 ("Upload temporal_segmentation.py", commit 7b7e262, verified; 11.7 kB).
from microprograms.temporal_segmentation.entry import entry_microprogram_one_frame
from microprograms.temporal_segmentation.somersault import somersault_microprogram_one_frame
from microprograms.temporal_segmentation.twist import twist_microprogram_one_frame
from microprograms.temporal_segmentation.start_takeoff import takeoff_microprogram_one_frame
from microprograms.errors.distance_from_springboard_micro_program import board_end
from microprograms.errors.splash_micro_program import get_splash_from_one_frame
from microprograms.errors.distance_from_springboard_micro_program import calculate_distance_from_springboard_for_one_frame
from microprograms.errors.distance_from_springboard_micro_program import calculate_distance_from_platform_for_one_frame
from microprograms.errors.angles_micro_programs import applyFeetApartError
from microprograms.errors.angles_micro_programs import applyPositionTightnessError
from models.detectron2.platform_detector_setup import get_platform_detector
from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation
from models.detectron2.diver_detector_setup import get_diver_detector
from models.pose_estimator.pose_estimator_model_setup import get_pose_model
from models.detectron2.splash_detector_setup import get_splash_detector
from somersault_counter import som_counter, twist_counter
from microprograms.errors.over_rotation import over_rotation
import pickle
import os
import math
import numpy as np
import cv2
# returns None if either pose_pred or board_end_coord is None
# returns True if diver is on board, returns False if diver is off board
def detect_on_board(board_end_coord, board_side, pose_pred, handstand):
if pose_pred is None:
print("pose_pred is None")
return
if board_end_coord is None:
print("board end coord is None")
return
# side = find_which_side_board_on(outputs)
if board_side == 'left':
# if right of board end
if np.array(pose_pred)[0][2][0] > int(board_end_coord[0]):
# print("board on left side, we're saying that the diver is on the right")
# if standing_dive:
if handstand:
# print("board end to wrist dist:", math.dist(np.array(pose_pred)[0][15], board_end_coord))
# print("wrist to elbow dist:", math.dist(np.array(pose_pred)[0][14], np.array(pose_pred)[0][15]))
distance = math.dist(np.array(pose_pred)[0][15], board_end_coord) < math.dist(np.array(pose_pred)[0][14], np.array(pose_pred)[0][15]) * 1.5
else:
# print("board end to ankle dist:", math.dist(np.array(pose_pred)[0][5], board_end_coord))
# print("ankle to knee dist:", math.dist(np.array(pose_pred)[0][5], np.array(pose_pred)[0][4]))
distance = math.dist(np.array(pose_pred)[0][5], board_end_coord) < math.dist(np.array(pose_pred)[0][5], np.array(pose_pred)[0][4]) * 1.5
# print("distance", distance)
return distance
# return False
# if left of board end
else:
# print("board on left side, we're saying that the diver is on the left")
return True
else:
# if right of board end
if np.array(pose_pred)[0][2][0] > int(board_end_coord[0]):
# print("board on right side, we're saying that the diver is on the right")
return True
# if left of board end
else:
# print("board on right side, we're saying that the diver is on the left")
# if standing_dive:
if handstand:
distance = math.dist(np.array(pose_pred)[0][10], board_end_coord) < math.dist(np.array(pose_pred)[0][11], np.array(pose_pred)[0][10]) * 1.5
else:
distance = math.dist(np.array(pose_pred)[0][0], board_end_coord) < math.dist(np.array(pose_pred)[0][1], np.array(pose_pred)[0][0]) * 1.5
# print("distance", distance)
return distance
# return False
def _frame_sort_key(file_name):
    """Sort key placing frames in numeric order when their stem is an integer.

    "2.jpg" sorts before "10.jpg". Non-numeric names sort after all numeric
    ones, alphabetically, so the order is always deterministic.
    """
    stem = os.path.splitext(file_name)[0]
    if stem.isdecimal():
        return (0, int(stem), file_name)
    return (1, 0, file_name)


def main():
    """Interactively segment one FineDiving dive into phases and collect error metrics.

    The user is prompted on stdin for:
    - the dive's two folder names,
    - the board side,
    - whether it is a handstand dive,
    - the expected twist and somersault counts (in halves).

    The dive's ``.jpg`` frames are then processed in sorted order. Each frame
    is run through the takeoff / twist / somersault / entry microprograms, and
    phase-specific error measurements are collected. All collected lists are
    printed at the end.
    """
    first_folder = input("what is the first folder? Ex: 01, FINAWorldChampionships2019_Women10m_final_r1, etc. ")
    second_folder = input("what is the second folder? (dive within the first folder)")
    board_side = input("what side is the board on? type either 'left' or 'right'")
    handstand = input("is the dive a handstand dive? type either 'True' or 'False' ") == 'True'
    expected_twists = int(input("what are the expected number of twists? type 1 for half a twist, 2 for full twist, etc."))
    expected_som = int(input("what are the expected number of somersaults? type 1 for half a somersault, 2 for full somersault, etc."))

    # Per-frame outputs of each phase microprogram; the loop treats a value of
    # 1 as "this frame belongs to that phase".
    takeoff = []
    twist = []
    som = []
    entry = []
    # Error measurements gathered from frames of the relevant phases.
    distance_from_board = []
    position_tightness = []
    feet_apart = []
    splash = []

    # Dive state: the diver starts on and above the board; once either flag
    # becomes False it is never set back to True.
    above_board = True
    on_board = True

    platform_detector = get_platform_detector()
    splash_detector = get_splash_detector()
    diver_detector = get_diver_detector()
    pose_model = get_pose_model()

    dive_folder_num = "{}_{}".format(first_folder, second_folder)
    directory = './FineDiving/datasets/FINADiving_MTL_256s/{}/{}/'.format(first_folder, second_folder)
    # os.listdir returns entries in arbitrary order, but the state machine and
    # the rotation counters below require frames in temporal order.
    frame_names = sorted(
        (name for name in os.listdir(directory) if name.endswith(".jpg")),
        key=_frame_sort_key,
    )

    prev_pred = None      # pose of the previous frame (fed to twist_counter)
    som_prev_pred = None  # last pose that som_counter did not ask to skip
    half_som_count = 0
    petal_count = 0
    in_petal = False

    for file_name in frame_names:
        filepath = os.path.join(directory, file_name)
        print("filepath:", filepath)
        _, pose_pred = get_pose_estimation(filepath, diver_detector=diver_detector, pose_model=pose_model)

        # Rotation counts are tentative: they are only committed below when
        # the frame is classified as a somersault or twist frame.
        calculated_half_som_count, skip = som_counter(pose_pred, som_prev_pred, half_som_count=half_som_count, handstand=handstand)
        if not skip:
            som_prev_pred = pose_pred
        print("calculated_half_som_count:", calculated_half_som_count)
        calculated_petal_count, calculated_in_petal = twist_counter(pose_pred, prev_pose_pred=prev_pred, in_petal=in_petal, petal_count=petal_count)
        print("calculated_petal_count", calculated_petal_count)
        print("calculated_in_petal", calculated_in_petal)

        im = cv2.imread(filepath)
        outputs = platform_detector(im)
        board_end_coord = board_end(outputs, board_side=board_side)

        # Uses on_board from the previous frame. Keypoint 2's y exceeding the
        # board end's y presumably means below the board (image y grows
        # downward) -- confirm.
        if (above_board and not on_board and board_end_coord is not None
                and pose_pred is not None
                and np.array(pose_pred)[0][2][1] > int(board_end_coord[1])):
            above_board = False
        if on_board:
            still_on_board = detect_on_board(board_end_coord, board_side, pose_pred, handstand)
            # None means "undetermined" (missing pose or board): keep the state.
            if still_on_board is not None and not still_on_board:
                on_board = False
        print('ON_BOARD:', on_board)
        print('ABOVE_BOARD:', above_board)

        calculated_takeoff = takeoff_microprogram_one_frame(filepath, above_board=above_board, on_board=on_board, pose_pred=pose_pred)
        calculated_twist = twist_microprogram_one_frame(filepath, on_board=on_board, pose_pred=pose_pred, expected_twists=expected_twists, petal_count=petal_count, expected_som=expected_som, half_som_count=half_som_count)
        calculated_som = somersault_microprogram_one_frame(filepath, pose_pred=pose_pred, on_board=on_board, expected_som=expected_som, half_som_count=half_som_count, expected_twists=expected_twists, petal_count=petal_count)
        calculated_entry = entry_microprogram_one_frame(filepath, above_board=above_board, on_board=on_board, pose_pred=pose_pred, expected_twists=expected_twists, petal_count=petal_count, expected_som=expected_som, half_som_count=half_som_count, splash_detector=splash_detector, visualize=True, dive_folder_num=dive_folder_num)

        if calculated_takeoff == 1:
            # Distance from board (the original note says the callee saves a
            # photo to ./output/data/distance_from_board/).
            # TODO: height off board; proximity to the end of the board.
            dist = calculate_distance_from_platform_for_one_frame(filepath, visualize=False, pose_pred=pose_pred, board_end_coord=board_end_coord, platform_detector=platform_detector)
            distance_from_board.append(dist)
        elif calculated_som == 1:
            # Commit only the somersault count; twist state is left untouched
            # here (as in the original code).
            half_som_count = calculated_half_som_count
            if above_board:
                dist = calculate_distance_from_platform_for_one_frame(filepath, visualize=False, pose_pred=pose_pred, board_end_coord=board_end_coord, platform_detector=platform_detector)
                distance_from_board.append(dist)
            # TODO: rotation speed; feet flat.
            position_tightness.append(applyPositionTightnessError(filepath, pose_pred=pose_pred))
            feet_apart.append(applyFeetApartError(filepath, pose_pred=pose_pred))
        elif calculated_twist == 1:
            half_som_count = calculated_half_som_count
            petal_count = calculated_petal_count
            in_petal = calculated_in_petal
            if above_board:
                dist = calculate_distance_from_platform_for_one_frame(filepath, visualize=False, pose_pred=pose_pred, board_end_coord=board_end_coord, platform_detector=platform_detector)
                distance_from_board.append(dist)
            # TODO: twisting speed; feet flat.
            # NOTE(review): stored as 180 minus the tightness value here, unlike
            # the somersault branch -- confirm this asymmetry is intended.
            position_tightness.append(180 - applyPositionTightnessError(filepath, pose_pred=pose_pred))
            feet_apart.append(applyFeetApartError(filepath, pose_pred=pose_pred))
        elif calculated_entry == 1:
            # TODO: over/under twisting; not straight during entry.
            print("OVER-ROTATION ERROR:", over_rotation(filepath, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model))
            splash.append(get_splash_from_one_frame(filepath, predictor=splash_detector, visualize=True))
        else:
            print('no phase of dive calculated!')

        takeoff.append(calculated_takeoff)
        twist.append(calculated_twist)
        som.append(calculated_som)
        entry.append(calculated_entry)
        prev_pred = pose_pred

    print("takeoff", takeoff)
    print("twist", twist)
    print("som", som)
    print("entry", entry)
    print("distance_from_board", distance_from_board)
    print("position_tightness", position_tightness)
    print("feet_apart", feet_apart)
    print("splash", splash)
# Run the interactive pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()