File size: 1,991 Bytes
0a3525d
 
 
 
 
69e8a46
 
0a3525d
 
 
 
 
69e8a46
0a3525d
 
 
 
 
 
 
 
69e8a46
0a3525d
 
 
69e8a46
0a3525d
 
 
69e8a46
0a3525d
 
69e8a46
0a3525d
 
69e8a46
0a3525d
 
 
69e8a46
0a3525d
 
69e8a46
0a3525d
 
69e8a46
0a3525d
 
 
 
 
 
 
 
 
69e8a46
 
 
 
 
 
 
0a3525d
 
 
 
69e8a46
0a3525d
 
 
 
 
 
 
 
69e8a46
0a3525d
69e8a46
0a3525d
 
 
 
69e8a46
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Hydra composition: the `base` config is loaded first and the values in this
# file (`_self_`) are applied on top of it.
defaults:
  - base
  - _self_

# Run name for this fine-tuning configuration.
project: text2semantic_finetune_dual_ar
# Token length limit; interpolated as ${max_length} by the datasets, the data
# module and the model below.
max_length: 4096
# Checkpoint directory used for both the tokenizer and the model weights.
pretrained_ckpt_path: "checkpoints/fish-speech-1.2-sft"

# Lightning Trainer
trainer:
  # Run length and numeric precision.
  max_steps: 1000
  precision: bf16-true
  # One batch per optimizer step (no gradient accumulation).
  accumulate_grad_batches: 1
  # Clip gradients by their norm, with threshold 1.0.
  gradient_clip_algorithm: norm
  gradient_clip_val: 1.0
  # Validate every 100 training batches, on at most 10 validation batches.
  val_check_interval: 100
  limit_val_batches: 10

# Tokenizer Configuration
tokenizer:
  # Load the tokenizer from the same directory as the pretrained checkpoint.
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: ${pretrained_ckpt_path}

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  # Directories of protobuf training data.
  proto_files: [data/protos]
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  causal: true
  use_speaker: false
  # NOTE(review): semantics of interactive_prob are defined by the dataset
  # class (not visible here) — confirm 0.7 is the intended value.
  interactive_prob: 0.7

val_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  # NOTE(review): same proto directory as train_dataset, so validation runs on
  # the training data unless overridden — confirm this is intended.
  proto_files: [data/protos]
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  causal: true
  use_speaker: false
  interactive_prob: 0.7

# Data module: receives both datasets plus the shared tokenizer and length limit.
data:
  _target_: fish_speech.datasets.semantic.SemanticDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  # Loader settings.
  batch_size: 8
  num_workers: 4

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
  # Transformer backbone built from the pretrained checkpoint, weights included.
  model:
    _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
    path: ${pretrained_ckpt_path}
    load_weights: true
    max_length: ${max_length}
    # null: no LoRA config is passed (presumably full fine-tuning — confirm).
    lora_config: null

  # `_partial_: true` makes Hydra return a functools.partial. The remaining
  # arguments (presumably the model parameters) are supplied by the module.
  # Floats are written with a mantissa dot (1.0e-4, not 1e-4) because YAML 1.1
  # loaders such as PyYAML otherwise resolve exponent-only literals as strings.
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1.0e-4
    weight_decay: 0
    betas: [0.9, 0.95]
    eps: 1.0e-5

  # LambdaLR driven by a warmup-then-constant multiplier (per the helper's name;
  # confirm) with 10 warmup steps. Both objects are partials.
  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 10

# Callbacks
callbacks:
  # Only the checkpoint cadence is set here; any other model_checkpoint settings
  # presumably come from the `base` config — verify there.
  model_checkpoint:
    # Save on the same step cadence as validation.
    every_n_train_steps: ${trainer.val_check_interval}