# Scraped page header (not part of the program):
# yugangee's picture / Update main.py / 1834923 verified
import os
import subprocess
from utils import read_video, save_video
from trackers import Tracker
import cv2
import numpy as np
from team_assigner import TeamAssigner
from player_ball_assigner import PlayerBallAssigner
from camera_movement_estimator import CameraMovementEstimator
from view_transformer import ViewTransformer
from speed_and_distance_estimator import SpeedAndDistance_Estimator
from commentary_ai.generator.frame_analyzer import analyze_frame
from commentary_ai.generator.retriever import load_vector_store, search_similar_sentences
from commentary_ai.generator.openai_captioner import generate_caption
def save_subtitles(subtitle_texts, fps, output_path):
    """Write per-frame subtitle texts to an SRT file.

    Every non-empty entry of ``subtitle_texts`` becomes one numbered cue that
    starts at ``frame_idx / fps`` seconds and lasts two seconds; falsy entries
    are skipped. Timestamps are rounded to the nearest millisecond.

    Args:
        subtitle_texts: Sequence indexed by frame number; falsy items yield
            no cue.
        fps: Frames per second used to convert a frame index into a time.
        output_path: Destination path of the SRT file. Its parent directory
            is created when it does not exist yet.
    """
    def ms_to_timestamp(total_ms):
        """Format a whole number of milliseconds as ``HH:MM:SS,mmm``."""
        total_seconds, ms = divmod(total_ms, 1000)
        total_minutes, s = divmod(total_seconds, 60)
        h, m = divmod(total_minutes, 60)
        return f"{h:02}:{m:02}:{s:02},{ms:03}"

    # os.makedirs("") raises, so only create a directory when the path
    # actually names one (a bare file name has an empty dirname).
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    cue_duration_ms = 2000
    with open(output_path, "w", encoding="utf-8") as f:
        index = 1
        for frame_idx, text in enumerate(subtitle_texts):
            if not text:
                continue
            # Work in whole milliseconds: taking the fractional part of a
            # float second and truncating it can drop a millisecond
            # (e.g. 1.001 s would be written as ",000").
            start_ms = round(frame_idx * 1000 / fps)
            end_ms = start_ms + cue_duration_ms
            start_str = ms_to_timestamp(start_ms)
            end_str = ms_to_timestamp(end_ms)
            f.write(f"{index}\n{start_str} --> {end_str}\n{text}\n\n")
            index += 1
    print(f"SRT subtitles saved at: {output_path}")
def main(video_path=None, output_path="output_videos/output_video.mp4"):
    """Run the analysis pipeline on a video and write the annotated results.

    The function loads the commentary vector store, reads the video, builds
    object tracks, derives per-frame event texts and periodic captions, draws
    the annotations, re-encodes the saved video with FFmpeg and finally writes
    an SRT file next to the video.

    Args:
        video_path: Input video path; ``'input_videos/sample_1.mp4'`` is used
            when this is falsy.
        output_path: Destination of the annotated video. The SRT file uses the
            same path with the extension replaced by ``.srt``.

    Raises:
        subprocess.CalledProcessError: If the FFmpeg re-encode exits non-zero.
    """
    # Frame rate assumed for caption sampling and SRT timing.
    FPS = 30

    vector_store_path = os.path.abspath(
        os.path.join("commentary_ai", "generator", "vector_store.pkl")
    )
    print("벡터 저장소 절대 경로:", vector_store_path)
    load_vector_store(vector_store_path)

    video_frames = read_video(video_path if video_path else 'input_videos/sample_1.mp4')

    # NOTE(review): the stub paths are fixed regardless of video_path; if the
    # helpers reuse an existing stub, results for a different video may be
    # stale — confirm against Tracker / CameraMovementEstimator.
    tracker = Tracker('models/best.pt')
    tracks = tracker.get_object_tracks(
        video_frames,
        read_from_stub=True,
        stub_path='stubs/track_stubs.pkl'
    )
    tracker.add_positions_to_tracks(tracks)

    camera_movement_estimator = CameraMovementEstimator(video_frames[0])
    camera_movement_per_frame = camera_movement_estimator.get_camera_movement(
        video_frames,
        read_from_stub=True,
        stub_path='stubs/camera_movement_stub.pkl'
    )
    camera_movement_estimator.add_adjust_positions_to_tracks(tracks, camera_movement_per_frame)

    view_transformer = ViewTransformer()
    view_transformer.add_transformed_position_to_tracks(tracks)

    tracks["ball"] = tracker.interpolate_ball_positions(tracks["ball"])

    speed_and_distance_estimator = SpeedAndDistance_Estimator()
    speed_and_distance_estimator.add_speed_and_distance_to_tracks(tracks)

    # Team colours are derived from the first frame, then every player track
    # in every frame is tagged with its team and that team's colour.
    team_assigner = TeamAssigner()
    team_assigner.assign_team_color(video_frames[0], tracks['players'][0])
    for frame_num, player_track in enumerate(tracks['players']):
        for player_id, track in player_track.items():
            team = team_assigner.get_player_team(video_frames[frame_num], track['bbox'], player_id)
            track['team'] = team
            track['team_color'] = team_assigner.team_colors[team]

    player_assigner = PlayerBallAssigner()
    team_ball_control = []
    previous_player_with_ball = -1
    previous_team_with_ball = None
    subtitle_data = []
    event_data = []

    # Rectangle tested against ball_bbox[0] and ball_bbox[1]; it is constant,
    # so it is built once instead of on every frame.
    goal_area = ((100, 50), (200, 100))

    for frame_num, player_track in enumerate(tracks['players']):
        ball_bbox = tracks['ball'][frame_num][1]['bbox']
        # NOTE(review): 'speed' is looked up on the per-frame ball dict while
        # the bbox lives under key 1 — confirm where ball speed is stored.
        ball_speed = tracks['ball'][frame_num].get('speed', 0)

        assigned_player = player_assigner.assign_ball_to_player(player_track, ball_bbox)
        if assigned_player != -1:
            player_track[assigned_player]['has_ball'] = True
            current_team_with_ball = player_track[assigned_player]['team']
            tracker.update_ball_owner(assigned_player, current_team_with_ball)
        else:
            # Nobody is on the ball: keep crediting the previously recorded team.
            current_team_with_ball = previous_team_with_ball
        team_ball_control.append(current_team_with_ball)

        event_texts = []
        if previous_player_with_ball != -1 and assigned_player != previous_player_with_ball:
            if assigned_player != -1:
                event_texts.append(f"패스 성공! 플레이어 {previous_player_with_ball} -> 플레이어 {assigned_player}")
        elif assigned_player != -1:
            speed = player_track[assigned_player].get('speed', 0)
            if speed > 1.5:
                event_texts.append(f"플레이어 {assigned_player}이 드리블 중입니다.")
        if previous_team_with_ball is not None and current_team_with_ball != previous_team_with_ball:
            event_texts.append("태클 성공! 상대 팀이 볼을 차단했습니다.")
        if ball_speed > 8:
            event_texts.append("슛! 볼이 빠른 속도로 움직입니다.")
        if goal_area[0][0] < ball_bbox[0] < goal_area[1][0] and goal_area[0][1] < ball_bbox[1] < goal_area[1][1]:
            event_texts.append("골! 볼이 골대에 들어갔습니다!")
        event_data.append("\n".join(event_texts))

        # A new caption is generated only every FPS frames; the frames in
        # between repeat the latest caption.
        # NOTE(review): this `continue` also skips the previous_* updates at
        # the bottom of the loop, so they only change on sampled frames —
        # confirm this is intended.
        if frame_num % (FPS * 1) != 0:
            subtitle_data.append(subtitle_data[-1] if subtitle_data else "")
            continue

        speed = player_track.get(assigned_player, {}).get('speed', 0) if assigned_player != -1 else 0
        frame_info = {
            "frame_num": frame_num,
            "assigned_player": assigned_player,
            "player_speed": speed,
            "ball_speed": ball_speed,
            "team_with_ball": current_team_with_ball,
            "ball_position": ball_bbox,
            "events": event_texts
        }
        frame_description = analyze_frame(frame_info)
        retrieved_examples = search_similar_sentences(frame_description, top_k=5)
        subtitle_text = generate_caption(frame_description, retrieved_examples)
        subtitle_data.append(subtitle_text)

        previous_player_with_ball = assigned_player
        previous_team_with_ball = current_team_with_ball

    output_video_frames = tracker.draw_annotations(video_frames, tracks, team_ball_control, subtitle_data, event_data)

    # os.makedirs("") raises, so skip directory creation for a bare file name.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    save_video(output_video_frames, output_path)

    # Re-encode with FFmpeg so the resulting file can be previewed.
    # The temporary name is derived from the real extension: a plain
    # str.replace(".mp4", ...) would return the input path unchanged for any
    # other extension and make FFmpeg read and write the same file.
    output_root, output_ext = os.path.splitext(output_path)
    temp_output_path = f"{output_root}_temp{output_ext or '.mp4'}"
    subprocess.run([
        "ffmpeg", "-y",
        "-i", output_path,
        "-vcodec", "libx264",
        "-acodec", "aac",
        temp_output_path
    ], check=True)
    os.replace(temp_output_path, output_path)
    print(f"✅ Output video saved at: {output_path}")

    srt_path = os.path.splitext(output_path)[0] + ".srt"
    save_subtitles(subtitle_data, FPS, srt_path)
# Run the pipeline with its default arguments when executed as a script.
if __name__ == "__main__":
    main()