- load_checkpoint.py +0 -33
load_checkpoint.py
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
# @title Load Checkpoint
model_name = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
precision = '16' # @param ["32", "bf16-mixed", "16"]
project = '2024'

# Flags shared verbatim by both "YPTF.MoE+Multi" variants; only their
# checkpoint file names differ.
_MOE_MULTI_FLAGS = (
    '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
    '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
    '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
    '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1',
)

# Selectable model name -> (checkpoint file name, model-specific flags).
# The flags are spliced between the common '-p <project>' prefix and the
# common '-pr <precision>' suffix when the argument list is built below.
_MODEL_PRESETS = {
    "YMT3+": (
        "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt",
        (),
    ),
    "YPTF+Single (noPS)": (
        "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        ('-enc', 'perceiver-tf', '-ac', 'spec',
         '-hop', '300', '-atc', '1'),
    ),
    # NOTE(review): this checkpoint name is identical to the one used for
    # "YPTF.MoE+Multi (PS)", yet the flags here carry none of the MoE
    # options -- looks like a copy/paste slip; confirm the intended file.
    "YPTF+Multi (PS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        ('-tk', 'mc13_full_plus_256',
         '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
         '-ac', 'spec', '-hop', '300', '-atc', '1'),
    ),
    "YPTF.MoE+Multi (noPS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        _MOE_MULTI_FLAGS,
    ),
    "YPTF.MoE+Multi (PS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        _MOE_MULTI_FLAGS,
    ),
}

# An unrecognised selection is rejected before anything is loaded.
if model_name not in _MODEL_PRESETS:
    raise ValueError(model_name)

checkpoint, _preset_flags = _MODEL_PRESETS[model_name]
args = [checkpoint, '-p', project, *_preset_flags, '-pr', precision]

model = load_model_checkpoint(args=args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|