|
|
import os |
|
|
import dataclasses |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
@dataclasses.dataclass
class ModelConfig:
    """Identifiers and sizes describing the audio-to-text model assembly.

    All fields carry defaults, so ``ModelConfig()`` is valid; any field may be
    overridden by keyword or by position (in declaration order).
    """

    # Checkpoint id for the audio side (Hub-style repo id, per the default).
    audio_model_id: str = "openai/whisper-medium"
    # Checkpoint id for the text side (Hub-style repo id, per the default).
    text_model_id: str = "sarvamai/sarvam-m"
    # NOTE(review): presumably the projector's hidden width -- confirm
    # against the model-building code.
    hidden_size: int = 2048
    # Activation name for the projector, stored as a plain string.
    projector_act: str = "gelu"
    # NOTE(review): presumably how many audio frames are stacked per
    # projected step -- confirm against the consumer.
    stack_factor: int = 8

    def to_dict(self) -> dict:
        """Return the fields as a new ``dict`` built by ``dataclasses.asdict``."""
        as_mapping = dataclasses.asdict(self)
        return as_mapping
|
|
|
|
|
@dataclasses.dataclass
class TrainConfig:
    """Training hyper-parameters and run settings.

    Every field has a default, so ``TrainConfig()`` is valid and any field can
    be overridden by keyword.

    The Hub and Weights & Biases fields that fall back to environment
    variables read them when an instance is created (``default_factory``),
    not once when this module is imported. Variables set after import, for
    example by a launcher or a ``.env`` loader, are therefore honoured.
    """

    # --- Batching and data loading ---
    batch_size: int = 8
    accum_steps: int = 2
    use_bf16: bool = True
    gradient_checkpointing: bool = False
    dataloader_num_workers: int = 8
    dataloader_pin_memory: bool = True

    # --- Optimization schedule ---
    learning_rate: float = 1e-4
    lr_scheduler_type: str = "cosine"
    num_epochs: int = 1
    # NOTE(review): both num_epochs and max_steps are set; which one bounds
    # training is decided by the consumer of this config -- confirm.
    max_steps: int = 10000

    # --- Output ---
    output_dir: str = "./checkpoints"

    # --- Dataset selection ---
    dataset_name: str = "fixie-ai/common_voice_17_0"
    dataset_subset: str = "hi"
    dataset_split: str = "train"
    val_dataset_split: str = "validation"

    # --- LoRA ---
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05

    # --- Hugging Face Hub ---
    push_to_hub: bool = False
    # Defaults to $HUB_MODEL_ID (None when unset), read at instantiation.
    hub_model_id: Optional[str] = dataclasses.field(
        default_factory=lambda: os.getenv("HUB_MODEL_ID")
    )
    # Defaults to $HUB_TOKEN (None when unset), read at instantiation.
    # repr=False keeps the credential out of repr()/str(), so printing or
    # logging the config does not leak it. dataclasses.asdict() still
    # includes it, so treat that output as sensitive.
    hub_token: Optional[str] = dataclasses.field(
        default_factory=lambda: os.getenv("HUB_TOKEN"),
        repr=False,
    )
    hub_private_repo: bool = True

    # --- Weights & Biases ---
    # Defaults to $WANDB_PROJECT, else "audio-language-model"; read at
    # instantiation.
    wandb_project: str = dataclasses.field(
        default_factory=lambda: os.getenv("WANDB_PROJECT", "audio-language-model")
    )
    # Defaults to $WANDB_ENTITY (None when unset), read at instantiation.
    wandb_entity: Optional[str] = dataclasses.field(
        default_factory=lambda: os.getenv("WANDB_ENTITY")
    )
    wandb_run_name: Optional[str] = None
    # Stored as strings ("false"), not bools.
    # NOTE(review): presumably forwarded as W&B env-style settings -- confirm.
    wandb_watch: str = "false"
    wandb_log_model: str = "false"

    # --- Reproducibility and step cadences ---
    seed: int = 42
    log_steps: int = 10
    eval_steps: int = 250
    save_steps: int = 500
    save_total_limit: int = 1
    sample_pred_every_steps: int = 250
|
|
|