Spaces:
Runtime error
Runtime error
from gradio_client import Client, handle_file | |
from datetime import datetime | |
import os | |
import shutil | |
import logging | |
import time | |
from typing import Tuple, Optional | |
class TalkingHeadAPIClient: | |
"""DittoTalkingHead API クライアント""" | |
def __init__(self, space_name: str = "O-ken5481/talkingAvater_bgk", max_retries: int = 3, retry_delay: int = 5): | |
""" | |
Args: | |
space_name: Hugging Face SpaceのID(デフォルト: O-ken5481/talkingAvater_bgk) | |
max_retries: 最大リトライ回数 | |
retry_delay: リトライ間隔(秒) | |
""" | |
self.space_name = space_name | |
self.max_retries = max_retries | |
self.retry_delay = retry_delay | |
self.logger = self._setup_logger() | |
self.client = None | |
self._connect() | |
def _setup_logger(self) -> logging.Logger: | |
"""ロガーの設定""" | |
logger = logging.getLogger('TalkingHeadAPIClient') | |
logger.setLevel(logging.INFO) | |
if not logger.handlers: | |
handler = logging.StreamHandler() | |
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', | |
datefmt='%Y-%m-%d %H:%M:%S') | |
handler.setFormatter(formatter) | |
logger.addHandler(handler) | |
return logger | |
def _connect(self) -> None: | |
"""APIへの接続""" | |
for attempt in range(self.max_retries): | |
try: | |
self.logger.info(f"接続開始: {self.space_name} (試行 {attempt + 1}/{self.max_retries})") | |
self.client = Client(self.space_name) | |
self.logger.info("接続成功") | |
return | |
except Exception as e: | |
self.logger.error(f"接続失敗: {e}") | |
if attempt < self.max_retries - 1: | |
self.logger.info(f"{self.retry_delay}秒後にリトライします...") | |
time.sleep(self.retry_delay) | |
else: | |
raise ConnectionError(f"APIへの接続に失敗しました: {e}") | |
def generate_video(self, audio_path: str, image_path: str) -> Tuple[Optional[dict], str]: | |
""" | |
API経由で動画生成 | |
Args: | |
audio_path: 音声ファイルのパス | |
image_path: 画像ファイルのパス | |
Returns: | |
tuple: (video_data, status_message) | |
""" | |
# ファイルの存在確認 | |
if not os.path.exists(audio_path): | |
error_msg = f"音声ファイルが見つかりません: {audio_path}" | |
self.logger.error(error_msg) | |
return None, error_msg | |
if not os.path.exists(image_path): | |
error_msg = f"画像ファイルが見つかりません: {image_path}" | |
self.logger.error(error_msg) | |
return None, error_msg | |
# API呼び出し | |
for attempt in range(self.max_retries): | |
try: | |
self.logger.info(f"ファイルアップロード: {audio_path}, {image_path}") | |
self.logger.info("処理開始...") | |
result = self.client.predict( | |
audio_file=handle_file(audio_path), | |
source_image=handle_file(image_path), | |
api_name="/process_talking_head" | |
) | |
self.logger.info("動画生成完了") | |
return result | |
except Exception as e: | |
self.logger.error(f"処理エラー (試行 {attempt + 1}/{self.max_retries}): {e}") | |
if attempt < self.max_retries - 1: | |
self.logger.info(f"{self.retry_delay}秒後にリトライします...") | |
time.sleep(self.retry_delay) | |
else: | |
error_msg = f"動画生成に失敗しました: {e}" | |
return None, error_msg | |
def save_with_timestamp(self, video_path: str, output_dir: str = "example") -> Optional[str]: | |
""" | |
動画をタイムスタンプ付きで保存 | |
Args: | |
video_path: 生成された動画のパス | |
output_dir: 保存先ディレクトリ | |
Returns: | |
str: 保存されたファイルパス(エラー時はNone) | |
""" | |
try: | |
# 動画パスの確認 | |
if not video_path or not os.path.exists(video_path): | |
self.logger.error(f"動画ファイルが見つかりません: {video_path}") | |
return None | |
# 出力ディレクトリの作成 | |
os.makedirs(output_dir, exist_ok=True) | |
# YYYY-MM-DD_HH-MM-SS.mp4 形式で保存 | |
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") | |
output_path = os.path.join(output_dir, f"{timestamp}.mp4") | |
# ファイルをコピー | |
shutil.copy2(video_path, output_path) | |
# ファイルサイズの確認 | |
file_size = os.path.getsize(output_path) | |
self.logger.info(f"保存完了: {output_path} (サイズ: {file_size:,} bytes)") | |
return output_path | |
except Exception as e: | |
self.logger.error(f"保存エラー: {e}") | |
return None | |
def process_with_save(self, audio_path: str, image_path: str, output_dir: str = "example") -> Tuple[Optional[str], str]: | |
""" | |
動画生成と保存を一括実行 | |
Args: | |
audio_path: 音声ファイルのパス | |
image_path: 画像ファイルのパス | |
output_dir: 保存先ディレクトリ | |
Returns: | |
tuple: (saved_path, status_message) | |
""" | |
# 動画生成 | |
result = self.generate_video(audio_path, image_path) | |
if result[0] is None: | |
return None, result[1] | |
video_data, status = result | |
# 動画の保存 | |
if isinstance(video_data, dict) and 'video' in video_data: | |
saved_path = self.save_with_timestamp(video_data['video'], output_dir) | |
if saved_path: | |
return saved_path, f"{status}\n保存先: {saved_path}" | |
else: | |
return None, f"{status}\n保存に失敗しました" | |
else: | |
return None, f"予期しないレスポンス形式: {video_data}" | |
def main(): | |
"""テストスクリプトのメイン関数""" | |
# ロギング設定 | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(message)s', | |
datefmt='%Y-%m-%d %H:%M:%S' | |
) | |
# クライアント初期化 | |
try: | |
client = TalkingHeadAPIClient() | |
except Exception as e: | |
logging.error(f"クライアント初期化失敗: {e}") | |
return | |
# サンプルファイルを使用 | |
audio_path = "example/audio.wav" | |
image_path = "example/image.png" | |
# ファイルの存在確認 | |
if not os.path.exists(audio_path): | |
logging.error(f"音声ファイルが見つかりません: {audio_path}") | |
return | |
if not os.path.exists(image_path): | |
logging.error(f"画像ファイルが見つかりません: {image_path}") | |
return | |
try: | |
# 動画生成と保存 | |
saved_path, status = client.process_with_save(audio_path, image_path) | |
if saved_path: | |
print(f"\n✅ 成功!") | |
print(f"ステータス: {status}") | |
print(f"動画を確認してください: {saved_path}") | |
else: | |
print(f"\n❌ 失敗") | |
print(f"ステータス: {status}") | |
except KeyboardInterrupt: | |
logging.info("処理を中断しました") | |
except Exception as e: | |
logging.error(f"予期しないエラー: {e}") | |
import traceback | |
traceback.print_exc() | |
if __name__ == "__main__": | |
main() |