from __future__ import annotations import datetime import os import pathlib import shlex import shutil import subprocess import sys import gradio as gr import slugify import torch from huggingface_hub import HfApi from omegaconf import OmegaConf from app_upload import ModelUploader from utils import save_model_card sys.path.append('Tune-A-Video') URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk' class Trainer: def __init__(self, hf_token: str | None = None): self.hf_token = hf_token self.api = HfApi(token=hf_token) self.model_uploader = ModelUploader(hf_token) self.checkpoint_dir = pathlib.Path('checkpoints') self.checkpoint_dir.mkdir(exist_ok=True) def download_base_model(self, base_model_id: str) -> str: model_dir = self.checkpoint_dir / base_model_id if not model_dir.exists(): org_name = base_model_id.split('/')[0] org_dir = self.checkpoint_dir / org_name org_dir.mkdir(exist_ok=True) subprocess.run(shlex.split( f'git clone https://huggingface.co/{base_model_id}'), cwd=org_dir) return model_dir.as_posix() def join_model_library_org(self) -> None: subprocess.run( shlex.split( f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}' )) def run( self, training_video: str, training_prompt: str, output_model_name: str, overwrite_existing_model: bool, validation_prompt: str, base_model: str, resolution_s: str, n_steps: int, learning_rate: float, gradient_accumulation: int, seed: int, fp16: bool, use_8bit_adam: bool, checkpointing_steps: int, validation_epochs: int, upload_to_hub: bool, use_private_repo: bool, delete_existing_repo: bool, upload_to: str, remove_gpu_after_training: bool, ) -> str: if not torch.cuda.is_available(): raise gr.Error('CUDA is not available.') if training_video is None: raise gr.Error('You need to upload a video.') if not training_prompt: raise gr.Error('The training prompt is missing.') if not validation_prompt: raise gr.Error('The validation prompt is missing.') resolution = int(resolution_s) if not output_model_name: timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') output_model_name = f'tune-a-video-{timestamp}' output_model_name = slugify.slugify(output_model_name) repo_dir = pathlib.Path(__file__).parent output_dir = repo_dir / 'experiments' / output_model_name if overwrite_existing_model or upload_to_hub: shutil.rmtree(output_dir, ignore_errors=True) output_dir.mkdir(parents=True) if upload_to_hub: self.join_model_library_org() config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml') config.pretrained_model_path = self.download_base_model(base_model) config.output_dir = output_dir.as_posix() config.train_data.video_path = training_video.name # type: ignore config.train_data.prompt = training_prompt config.train_data.n_sample_frames = 8 config.train_data.width = resolution config.train_data.height = resolution config.train_data.sample_start_idx = 0 config.train_data.sample_frame_rate = 1 config.validation_data.prompts = [validation_prompt] config.validation_data.video_length = 8 config.validation_data.width = resolution config.validation_data.height = resolution config.validation_data.num_inference_steps = 50 config.validation_data.guidance_scale = 7.5 config.learning_rate = learning_rate config.gradient_accumulation_steps = gradient_accumulation config.train_batch_size = 1 config.max_train_steps = n_steps config.checkpointing_steps = checkpointing_steps config.validation_steps = validation_epochs config.seed = seed config.mixed_precision = 'fp16' if fp16 else '' config.use_8bit_adam = use_8bit_adam config_path = output_dir / 'config.yaml' with open(config_path, 'w') as f: OmegaConf.save(config, f) command = f'accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}' subprocess.run(shlex.split(command)) save_model_card(save_dir=output_dir, base_model=base_model, training_prompt=training_prompt, test_prompt=validation_prompt, test_image_dir='samples') message = 'Training completed!' print(message) if upload_to_hub: upload_message = self.model_uploader.upload_model( folder_path=output_dir.as_posix(), repo_name=output_model_name, upload_to=upload_to, private=use_private_repo, delete_existing_repo=delete_existing_repo) print(upload_message) message = message + '\n' + upload_message if remove_gpu_after_training: space_id = os.getenv('SPACE_ID') if space_id: self.api.request_space_hardware(repo_id=space_id, hardware='cpu-basic') return message