Spaces:
Paused
Paused
from pathlib import Path | |
from typing import Any, Dict, List | |
import torch | |
from pydantic import BaseModel | |
class State(BaseModel):
    """Pydantic container for the state of a training run.

    Holds the target clip geometry (``train_frames``/``train_height``/
    ``train_width``), bookkeeping counters, the RNG, and the assets used
    for validation. Only the three ``train_*`` fields are required; every
    other field has a default.
    """

    # torch.dtype and torch.Generator are not types pydantic can validate
    # natively; this makes pydantic accept them via isinstance checks.
    model_config = {"arbitrary_types_allowed": True}

    # Target clip geometry (required).
    train_frames: int
    train_height: int
    train_width: int

    # Optional transformer configuration. Annotated as optional so that an
    # explicit ``transformer_config=None`` passes validation, consistent with
    # the ``None`` default (``Dict[str, Any]`` alone rejects ``None``).
    transformer_config: Dict[str, Any] | None = None

    weight_dtype: torch.dtype = torch.float32  # dtype for mixed precision training
    num_trainable_parameters: int = 0
    overwrote_max_train_steps: bool = False
    num_update_steps_per_epoch: int = 0
    total_batch_size_count: int = 0

    generator: torch.Generator | None = None

    # Validation assets. The ``[]`` defaults are safe here: pydantic copies
    # mutable defaults per instance, so separate States never share a list.
    validation_prompts: List[str] = []
    validation_images: List[Path | None] = []
    validation_videos: List[Path | None] = []

    # Extra validation assets (added by WJ). Presumably paths to precomputed
    # per-sample artifacts aligned with ``validation_prompts`` -- TODO confirm
    # against the code that populates them.
    validation_prompt_embeddings: List[Path | None] = []
    validation_video_latents: List[Path | None] = []
    validation_flow_latents: List[Path | None] = []
    validation_valid_masks: List[Path | None] = []
    validation_valid_masks_interp: List[Path | None] = []

    using_deepspeed: bool = False