Diff-Refine / src /config.py
2ira's picture
Add files using upload-large-folder tool
77d636f verified
from dataclasses import dataclass
import torch
@dataclass
class ModelConfig:
    """Model hyperparameters: sequence latent space, VAE adapter and DiT.

    Every field has a default, so ``ModelConfig()`` yields a complete
    configuration; individual values can be overridden by keyword.
    """

    # --- Sequence latent space ---
    # Encoder checkpoint (relative path by default). Alternatives noted by the
    # author: "jinaai/jina-embeddings-v2-base-code",
    # "microsoft/codebert-base", or roberta-base.
    encoder_name: str = "../jina-embeddings-v2-base-code"
    input_dim: int = 768  # Jina Base is 768
    latent_dim: int = 768  # keep as much semantic information as possible
    decoder_layers: int = 4  # simple NAR decoder

    # --- VAE adapter ---
    max_seq_len: int = 2048  # set according to task
    patch_size: int = 4  # patching compression rate

    # --- DiT ---
    dit_layers: int = 12
    dit_heads: int = 8
    # Hidden width; kept below latent_dim * patch_size to avoid OOM.
    dit_hidden: int = 768
    mlp_ratio: float = 4.0
    # NOTE: a disabled alternative exposed dit_hidden as a property returning
    # self.latent_dim; here it is a plain, independently configurable field.
@dataclass
class TrainConfig:
    """Training hyperparameters and checkpoint location.

    Instantiating the config creates ``save_dir`` on disk (see
    ``__post_init__``) unless ``save_dir`` is an empty string.
    """

    # Evaluated once, when the class body is executed: "cuda" if
    # torch reports CUDA as available, otherwise "cpu".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    lr_ae: float = 1e-4
    lr_flow: float = 5e-4
    batch_size: int = 8
    # Gradient accumulation; with the defaults this is equivalent to a
    # batch size of 8 * 4 = 32.
    grad_accum_steps: int = 4
    num_epochs_ae: int = 20  # the AE is trained first, then the flow
    num_epochs_flow: int = 50  # the flow needs more epochs than the AE
    grad_clip: float = 1.0
    use_amp: bool = False  # mixed precision; Jina + AMP tends to raise errors
    save_dir: str = "./checkpoints"

    def __post_init__(self):
        """Create ``save_dir`` (with parents) if it does not already exist."""
        import os

        # An empty path would make os.makedirs raise FileNotFoundError;
        # skip creation in that case instead of failing construction.
        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)