Spaces:
Runtime error
Runtime error
import os | |
import subprocess | |
from utils import read_video, save_video | |
from trackers import Tracker | |
import cv2 | |
import numpy as np | |
from team_assigner import TeamAssigner | |
from player_ball_assigner import PlayerBallAssigner | |
from camera_movement_estimator import CameraMovementEstimator | |
from view_transformer import ViewTransformer | |
from speed_and_distance_estimator import SpeedAndDistance_Estimator | |
from commentary_ai.generator.frame_analyzer import analyze_frame | |
from commentary_ai.generator.retriever import load_vector_store, search_similar_sentences | |
from commentary_ai.generator.openai_captioner import generate_caption | |
def save_subtitles(subtitle_texts, fps, output_path):
    """Write per-frame subtitle texts to a SubRip (.srt) file.

    Every non-empty entry of ``subtitle_texts`` becomes one cue that starts
    at ``frame_idx / fps`` seconds and lasts two seconds. Falsy entries
    (``""`` or ``None``) are skipped and do not consume a cue number.

    Args:
        subtitle_texts: Sequence indexed by frame number; each item is the
            caption text for that frame, or a falsy value for "no caption".
        fps: Frames per second used to convert a frame index to seconds.
            Must be non-zero whenever a non-empty caption is present.
        output_path: Destination file path. Missing parent directories are
            created; a bare filename is written to the current directory.
    """

    def sec_to_timestamp(sec):
        # Round to whole milliseconds once and split with integer
        # arithmetic. Truncating the fractional part of a float instead
        # turns e.g. 2.3 s into "00:00:02,299".
        total_ms = int(round(sec * 1000))
        hours, remainder = divmod(total_ms, 3600000)
        minutes, remainder = divmod(remainder, 60000)
        seconds, millis = divmod(remainder, 1000)
        return f"{hours:02}:{minutes:02}:{seconds:02},{millis:03}"

    # os.path.dirname() returns "" for a bare filename and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is named.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        index = 1
        for frame_idx, text in enumerate(subtitle_texts):
            if not text:
                continue
            start_time = frame_idx / fps
            end_time = start_time + 2  # every cue is shown for two seconds
            start_str = sec_to_timestamp(start_time)
            end_str = sec_to_timestamp(end_time)
            f.write(f"{index}\n{start_str} --> {end_str}\n{text}\n\n")
            index += 1
    print(f"SRT subtitles saved at: {output_path}")
def main(video_path=None, output_path="output_videos/output_video.mp4"):
    """Run the analysis pipeline on a video and write video + subtitles.

    The function loads the commentary vector store, reads the input frames,
    runs the tracking / camera-movement / view-transform / speed-and-distance
    / team-assignment helpers on them, derives rule-based event texts for
    every frame, asks the caption generator for a subtitle once every
    ``FPS`` frames, draws the annotations, saves the video, re-encodes it
    with FFmpeg (H.264 / AAC) and finally writes an ``.srt`` file next to
    the video.

    Args:
        video_path: Input video path. When falsy,
            ``'input_videos/sample_1.mp4'`` is used.
        output_path: Destination path of the annotated video. The subtitle
            file is written to the same path with an ``.srt`` extension.

    Raises:
        ValueError: If the input video yields no frames.
        subprocess.CalledProcessError: If the FFmpeg re-encode fails.
    """
    FPS = 30  # assumed frame rate; drives caption sampling and SRT timing
    # Hard-coded rectangle ((x_min, y_min), (x_max, y_max)) that the ball's
    # bbox origin is tested against for the "goal" event.
    goal_area = ((100, 50), (200, 100))

    vector_store_path = os.path.abspath(
        os.path.join("commentary_ai", "generator", "vector_store.pkl")
    )
    print("벡터 저장소 절대 경로:", vector_store_path)
    load_vector_store(vector_store_path)

    video_frames = read_video(video_path if video_path else 'input_videos/sample_1.mp4')
    if len(video_frames) == 0:
        # Fail early with a clear message instead of an IndexError on
        # video_frames[0] further down.
        raise ValueError("No frames could be read from the input video.")

    # NOTE(review): read_from_stub=True with fixed stub paths is used even
    # when a custom video_path is given. If the helpers return cached
    # results whenever the stub file exists, the tracks will not match the
    # supplied frames -- confirm against Tracker / CameraMovementEstimator.
    tracker = Tracker('models/best.pt')
    tracks = tracker.get_object_tracks(
        video_frames,
        read_from_stub=True,
        stub_path='stubs/track_stubs.pkl'
    )
    tracker.add_positions_to_tracks(tracks)

    camera_movement_estimator = CameraMovementEstimator(video_frames[0])
    camera_movement_per_frame = camera_movement_estimator.get_camera_movement(
        video_frames,
        read_from_stub=True,
        stub_path='stubs/camera_movement_stub.pkl'
    )
    camera_movement_estimator.add_adjust_positions_to_tracks(tracks, camera_movement_per_frame)

    view_transformer = ViewTransformer()
    view_transformer.add_transformed_position_to_tracks(tracks)

    tracks["ball"] = tracker.interpolate_ball_positions(tracks["ball"])

    speed_and_distance_estimator = SpeedAndDistance_Estimator()
    speed_and_distance_estimator.add_speed_and_distance_to_tracks(tracks)

    # Team colours are derived from the first frame, then every player in
    # every frame is tagged with a team and that team's colour.
    team_assigner = TeamAssigner()
    team_assigner.assign_team_color(video_frames[0], tracks['players'][0])
    for frame_num, player_track in enumerate(tracks['players']):
        for player_id, track in player_track.items():
            team = team_assigner.get_player_team(video_frames[frame_num], track['bbox'], player_id)
            track['team'] = team
            track['team_color'] = team_assigner.team_colors[team]

    player_assigner = PlayerBallAssigner()
    team_ball_control = []
    previous_player_with_ball = -1
    previous_team_with_ball = None
    subtitle_data = []  # one caption string per frame
    event_data = []     # one newline-joined event string per frame

    for frame_num, player_track in enumerate(tracks['players']):
        ball_bbox = tracks['ball'][frame_num][1]['bbox']
        # NOTE(review): 'speed' is looked up on the per-frame ball dict, not
        # on the ball entry under key 1 where 'bbox' lives; if the speed is
        # stored next to 'bbox' this is always 0 -- confirm.
        ball_speed = tracks['ball'][frame_num].get('speed', 0)

        assigned_player = player_assigner.assign_ball_to_player(player_track, ball_bbox)
        if assigned_player != -1:
            player_track[assigned_player]['has_ball'] = True
            current_team_with_ball = player_track[assigned_player]['team']
            tracker.update_ball_owner(assigned_player, current_team_with_ball)
        else:
            # Nobody is on the ball: keep crediting the remembered team.
            current_team_with_ball = previous_team_with_ball
        team_ball_control.append(current_team_with_ball)

        # Rule-based event detection for this frame.
        event_texts = []
        if previous_player_with_ball != -1 and assigned_player != previous_player_with_ball:
            if assigned_player != -1:
                event_texts.append(f"패스 성공! 플레이어 {previous_player_with_ball} -> 플레이어 {assigned_player}")
        elif assigned_player != -1:
            speed = player_track[assigned_player].get('speed', 0)
            if speed > 1.5:
                event_texts.append(f"플레이어 {assigned_player}이 드리블 중입니다.")
        if previous_team_with_ball is not None and current_team_with_ball != previous_team_with_ball:
            event_texts.append("태클 성공! 상대 팀이 볼을 차단했습니다.")
        if ball_speed > 8:
            event_texts.append("슛! 볼이 빠른 속도로 움직입니다.")
        if goal_area[0][0] < ball_bbox[0] < goal_area[1][0] and goal_area[0][1] < ball_bbox[1] < goal_area[1][1]:
            event_texts.append("골! 볼이 골대에 들어갔습니다!")
        event_data.append("\n".join(event_texts))

        # Captions are generated only once every FPS frames; the frames in
        # between repeat the most recent caption.
        if frame_num % (FPS * 1) != 0:
            subtitle_data.append(subtitle_data[-1] if subtitle_data else "")
            # NOTE(review): this `continue` also skips the previous_* update
            # at the bottom of the loop, so "previous" owner/team is only
            # refreshed on sampled frames. Events are therefore compared
            # against the last sampled frame, not the previous frame.
            # Left unchanged -- confirm whether that is intended.
            continue

        speed = player_track.get(assigned_player, {}).get('speed', 0) if assigned_player != -1 else 0
        frame_info = {
            "frame_num": frame_num,
            "assigned_player": assigned_player,
            "player_speed": speed,
            "ball_speed": ball_speed,
            "team_with_ball": current_team_with_ball,
            "ball_position": ball_bbox,
            "events": event_texts
        }
        frame_description = analyze_frame(frame_info)
        retrieved_examples = search_similar_sentences(frame_description, top_k=5)
        subtitle_text = generate_caption(frame_description, retrieved_examples)
        subtitle_data.append(subtitle_text)

        previous_player_with_ball = assigned_player
        previous_team_with_ball = current_team_with_ball

    output_video_frames = tracker.draw_annotations(video_frames, tracks, team_ball_control, subtitle_data, event_data)

    # dirname is "" for a bare filename and os.makedirs("") raises, so fall
    # back to the current directory in that case.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    save_video(output_video_frames, output_path)

    # Re-encode with FFmpeg so that the video can be previewed.
    # The temp name is derived from the extension only: str.replace(".mp4")
    # would leave the path unchanged for other extensions (FFmpeg would then
    # read and write the same file) and would also rewrite ".mp4" inside a
    # directory name.
    root, ext = os.path.splitext(output_path)
    temp_output_path = f"{root}_temp{ext or '.mp4'}"
    try:
        subprocess.run([
            "ffmpeg", "-y",
            "-i", output_path,
            "-vcodec", "libx264",
            "-acodec", "aac",
            temp_output_path
        ], check=True)
        os.replace(temp_output_path, output_path)
    finally:
        # Do not leave a partial temp file behind when FFmpeg fails.
        if os.path.exists(temp_output_path):
            os.remove(temp_output_path)
    print(f"✅ Output video saved at: {output_path}")

    srt_path = os.path.splitext(output_path)[0] + ".srt"
    save_subtitles(subtitle_data, FPS, srt_path)
# Script entry point: run the pipeline with its default input and output
# paths only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()