mimbres commited on
Commit
c6bfadd
·
1 Parent(s): a582c22
Files changed (1) hide show
  1. load_checkpoint.py +0 -33
load_checkpoint.py DELETED
@@ -1,33 +0,0 @@
1
- # @title Load Checkpoint
2
- model_name = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
3
- precision = '16' # @param ["32", "bf16-mixed", "16"]
4
- project = '2024'
5
-
6
- if model_name == "YMT3+":
7
- checkpoint = "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt"
8
- args = [checkpoint, '-p', project, '-pr', precision]
9
- elif model_name == "YPTF+Single (noPS)":
10
- checkpoint = "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt"
11
- args = [checkpoint, '-p', project, '-enc', 'perceiver-tf', '-ac', 'spec',
12
- '-hop', '300', '-atc', '1', '-pr', precision]
13
- elif model_name == "YPTF+Multi (PS)":
14
- checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
15
- args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256',
16
- '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
17
- '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
18
- elif model_name == "YPTF.MoE+Multi (noPS)":
19
- checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
20
- args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
21
- '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
22
- '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
23
- '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
24
- elif model_name == "YPTF.MoE+Multi (PS)":
25
- checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
26
- args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
27
- '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
28
- '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
29
- '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
30
- else:
31
- raise ValueError(model_name)
32
-
33
- model = load_model_checkpoint(args=args)