Mayank022 committed on
Commit
2ca914e
·
verified ·
1 Parent(s): c2c0066

Update config.py

Browse files
Files changed (1) hide show
  1. config.py +64 -0
config.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import dataclasses
3
+ from typing import Optional, Tuple
4
+
5
@dataclasses.dataclass
class ModelConfig:
    """Static configuration for the audio encoder / text decoder pair and
    the projector joining them."""

    # Hugging Face checkpoint ids for the two backbone models.
    audio_model_id: str = "openai/whisper-medium"
    text_model_id: str = "sarvamai/sarvam-m"
    # Projector settings: width, activation name, and stack factor
    # (presumably the number of audio frames merged per projected token —
    # confirm against the projector implementation).
    hidden_size: int = 2048
    projector_act: str = "gelu"
    stack_factor: int = 8

    def to_dict(self):
        """Return the configuration as a plain ``dict``, one entry per field."""
        return {field.name: getattr(self, field.name)
                for field in dataclasses.fields(self)}
15
+
16
@dataclasses.dataclass
class TrainConfig:
    """Training hyperparameters, dataset selection, LoRA settings, and
    experiment-tracking (Hub / WandB) options.

    NOTE: the ``os.getenv(...)`` defaults below are evaluated once, at
    class-definition (import) time — changing the environment variables
    afterwards has no effect on already-imported defaults.
    """

    # --- Batch & GPU (tuned for A100 80GB) ---
    batch_size: int = 8                    # per-device; try 64 if no OOM
    # Effective batch = batch_size * accum_steps = 8 * 2 = 16; reduce if OOM.
    # (Previous comment claimed 32*2=64 — stale relative to current values.)
    accum_steps: int = 2
    use_bf16: bool = True                  # A100 native bf16: faster + less VRAM
    gradient_checkpointing: bool = False   # set True if OOM to trade compute for memory
    dataloader_num_workers: int = 8
    dataloader_pin_memory: bool = True

    # --- Optimization ---
    learning_rate: float = 1e-4
    lr_scheduler_type: str = "cosine"
    num_epochs: int = 1
    max_steps: int = 10000                 # Use either epochs or steps

    # --- Paths & dataset ---
    output_dir: str = "./checkpoints"
    dataset_name: str = "fixie-ai/common_voice_17_0"
    dataset_subset: str = "hi"             # Hindi
    dataset_split: str = "train"
    val_dataset_split: str = "validation"

    # --- LoRA ---
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05

    # --- Hub ---
    push_to_hub: bool = False
    hub_model_id: Optional[str] = os.getenv("HUB_MODEL_ID", None)  # e.g. "username/model-name"
    hub_token: Optional[str] = os.getenv("HUB_TOKEN", None)
    hub_private_repo: bool = True

    # --- WandB ---
    wandb_project: str = os.getenv("WANDB_PROJECT", "audio-language-model")
    wandb_entity: Optional[str] = os.getenv("WANDB_ENTITY", None)
    wandb_run_name: Optional[str] = None
    wandb_watch: str = "false"             # "gradients", "all", "false"
    wandb_log_model: str = "false"         # "true", "false"

    # --- Misc ---
    seed: int = 42
    log_steps: int = 10
    eval_steps: int = 250
    save_steps: int = 500
    save_total_limit: int = 1
    sample_pred_every_steps: int = 250     # print ground-truth vs predicted transcript every N steps

    def to_dict(self):
        """Return the configuration as a plain ``dict`` (mirrors
        ``ModelConfig.to_dict`` for consistent serialization)."""
        return dataclasses.asdict(self)