Student0809's picture
Add files using upload-large-folder tool
7feac49 verified
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Type
import gradio as gr
from swift.ui.base import BaseUI
class Advanced(BaseUI):
    """Collapsible "advanced settings" panel of the LLM training UI.

    The panel is a closed-by-default accordion holding three rows of
    widgets: optimizer/scheduler hyper-parameters, evaluation and
    DeepSpeed options, and a free-form box for extra arguments.
    """

    # UI group this component belongs to.
    group = 'llm_train'

    # Localized label/info strings keyed by each widget's ``elem_id``.
    # NOTE(review): presumably consumed by ``BaseUI`` to fill in the labels
    # of the components created in ``do_build_ui`` -- confirm in the base class.
    locale_dict = {
        'advanced_param': {
            'label': {
                'zh': '高级参数设置',
                'en': 'Advanced settings'
            },
        },
        'optim': {
            'label': {
                'zh': 'Optimizer类型',
                'en': 'The Optimizer type'
            },
            'info': {
                'zh': '设置Optimizer类型',
                'en': 'Set the Optimizer type'
            }
        },
        'weight_decay': {
            'label': {
                'zh': '权重衰减',
                'en': 'Weight decay'
            },
            'info': {
                'zh': '设置weight decay',
                'en': 'Set the weight decay'
            }
        },
        'logging_steps': {
            'label': {
                'zh': '日志打印步数',
                'en': 'Logging steps'
            },
            'info': {
                'zh': '设置日志打印的步数间隔',
                'en': 'Set the logging interval'
            }
        },
        'lr_scheduler_type': {
            'label': {
                'zh': 'LrScheduler类型',
                'en': 'The LrScheduler type'
            },
            'info': {
                'zh': '设置LrScheduler类型',
                'en': 'Set the LrScheduler type'
            }
        },
        'warmup_ratio': {
            'label': {
                'zh': '学习率warmup比例',
                'en': 'Lr warmup ratio'
            },
            'info': {
                'zh': '设置学习率warmup比例',
                'en': 'Set the warmup ratio in total steps'
            }
        },
        'more_params': {
            'label': {
                'zh': '其他高级参数',
                'en': 'Other params'
            },
            'info': {
                'zh': '以json格式或--xxx xxx命令行格式填入',
                'en': 'Fill in with json format or --xxx xxx cmd format'
            }
        },
        'truncation_strategy': {
            'label': {
                'zh': '数据集超长策略',
                'en': 'Dataset truncation strategy'
            },
            'info': {
                'zh': '如果token超长该如何处理',
                'en': 'How to deal with the rows exceed the max length'
            }
        },
        'max_steps': {
            'label': {
                'zh': '最大迭代步数',
                'en': 'Max steps',
            },
            'info': {
                'zh': '设置最大迭代步数,该值如果大于零则数据集迭代次数不生效',
                'en': 'Set the max steps, if the value > 0 then num_train_epochs has no effects',
            }
        },
        'per_device_eval_batch_size': {
            'label': {
                'zh': '验证batch size',
                'en': 'Val batch size',
            },
            'info': {
                'zh': '验证的batch size',
                'en': 'Set the val batch size',
            }
        },
        'max_grad_norm': {
            'label': {
                'zh': '梯度裁剪',
                'en': 'Max grad norm',
            },
            'info': {
                'zh': '设置梯度裁剪',
                'en': 'Set the max grad norm',
            }
        },
        'predict_with_generate': {
            'label': {
                'zh': '使用生成指标代替loss',
                'en': 'Use generate metric instead of loss',
            },
            'info': {
                'zh': '验证时使用generate/Rouge代替loss',
                'en': 'Use model.generate/Rouge instead of loss',
            }
        },
        'deepspeed': {
            'label': {
                'zh': 'deepspeed',
                'en': 'deepspeed',
            },
            'info': {
                'zh': '可以选择下拉列表,也支持传入路径',
                # The widget is a dropdown (matches the zh text), not a "dropbox".
                'en': 'Choose from the dropdown or fill in a valid path',
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the advanced-settings widgets inside a collapsed accordion.

        Args:
            base_tab: The UI class this panel is built for. It is not
                referenced by this method.
        """
        with gr.Accordion(elem_id='advanced_param', open=False):
            with gr.Blocks():
                # Row 1: optimizer, logging and learning-rate schedule settings.
                with gr.Row():
                    gr.Textbox(elem_id='optim', lines=1, scale=20)
                    gr.Textbox(elem_id='weight_decay', lines=1, scale=20)
                    gr.Textbox(elem_id='logging_steps', lines=1, scale=20)
                    gr.Textbox(elem_id='lr_scheduler_type', lines=1, scale=20)
                    gr.Textbox(elem_id='max_steps', lines=1, scale=20)
                    gr.Slider(elem_id='warmup_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=20)
                # Row 2: truncation, evaluation batch size, clipping and DeepSpeed.
                with gr.Row():
                    gr.Dropdown(elem_id='truncation_strategy', scale=20)
                    # step=1 so every integer batch size in [1, 256] is selectable;
                    # a step of 2 starting at 1 would only expose odd values and
                    # make 2, 4, 8, ... and the maximum itself unreachable.
                    gr.Slider(elem_id='per_device_eval_batch_size', minimum=1, maximum=256, step=1, scale=20)
                    gr.Textbox(elem_id='max_grad_norm', lines=1, scale=20)
                    # allow_custom_value lets the user type a config path
                    # instead of picking one of the preset names.
                    gr.Dropdown(
                        elem_id='deepspeed',
                        scale=20,
                        allow_custom_value=True,
                        value=None,
                        choices=['zero0', 'zero1', 'zero2', 'zero3', 'zero2_offload', 'zero3_offload'])
                # Row 3: free-form extra arguments (json or ``--key value`` form).
                with gr.Row():
                    gr.Textbox(elem_id='more_params', lines=4, scale=20)