# roll-ai's picture
# Upload 177 files
# 59d751c verified
from pathlib import Path
from typing import Any, Dict, List
import torch
from pydantic import BaseModel
class State(BaseModel):
    """Typed container of run state, built on pydantic ``BaseModel``.

    The three ``train_*`` fields have no default and must be supplied when
    the model is constructed; every other field has a default.  How each
    field is consumed is not visible in this file, so the per-field notes
    below are limited to what the declarations themselves show (anything
    beyond that is marked as an assumption).
    """

    # Lets pydantic accept non-pydantic classes such as ``torch.dtype`` and
    # ``torch.Generator`` as field types.
    model_config = {"arbitrary_types_allowed": True}

    # Required fields (no default).
    # Presumably the frame count / height / width of training samples -
    # TODO confirm against the caller.
    train_frames: int
    train_height: int
    train_width: int

    # The default is None, so the annotation must allow None as well;
    # otherwise an explicit ``transformer_config=None`` (for example when
    # re-validating a dumped State) is rejected by validation.
    transformer_config: Dict[str, Any] | None = None

    weight_dtype: torch.dtype = torch.float32  # dtype for mixed precision training

    # Counters / flags, all starting from a zero or False default.
    num_trainable_parameters: int = 0
    overwrote_max_train_steps: bool = False
    num_update_steps_per_epoch: int = 0
    total_batch_size_count: int = 0

    # Optional torch random generator; None when not set.
    generator: torch.Generator | None = None

    # Validation inputs.  Entries of the Path lists may be None.
    # NOTE: pydantic copies mutable defaults per instance, so these ``[]``
    # literals are not shared between State objects.
    validation_prompts: List[str] = []
    validation_images: List[Path | None] = []
    validation_videos: List[Path | None] = []

    # WJ: Added..
    validation_prompt_embeddings: List[Path | None] = []
    validation_video_latents: List[Path | None] = []
    validation_flow_latents: List[Path | None] = []
    validation_valid_masks: List[Path | None] = []
    validation_valid_masks_interp: List[Path | None] = []

    # Presumably set when the run is driven by DeepSpeed - verify with caller.
    using_deepspeed: bool = False