Spaces:
Sleeping
Sleeping
split model setup and task execution
Browse files
main.py
CHANGED
@@ -14,7 +14,8 @@ from rewards import get_reward_losses
|
|
14 |
from training import LatentNoiseTrainer, get_optimizer
|
15 |
|
16 |
|
17 |
-
def
|
|
|
18 |
seed_everything(args.seed)
|
19 |
bf.makedirs(f"{args.save_dir}/logs/{args.task}")
|
20 |
# Set up logging and name settings
|
@@ -92,6 +93,10 @@ def main(args, progress_callback=None):
|
|
92 |
)
|
93 |
enable_grad = not args.no_optim
|
94 |
|
|
|
|
|
|
|
|
|
95 |
if args.task == "single":
|
96 |
init_latents = torch.randn(shape, device=device, dtype=dtype)
|
97 |
latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
|
@@ -269,7 +274,12 @@ def main(args, progress_callback=None):
|
|
269 |
# log total rewards
|
270 |
logging.info(f"Mean initial rewards: {total_init_rewards}")
|
271 |
logging.info(f"Mean best rewards: {total_best_rewards}")
|
272 |
-
|
273 |
-
|
274 |
args = parse_args()
|
275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
14 |
from training import LatentNoiseTrainer, get_optimizer
|
15 |
|
16 |
|
17 |
+
def setup(args):
|
18 |
+
#args = parse_args()
|
19 |
seed_everything(args.seed)
|
20 |
bf.makedirs(f"{args.save_dir}/logs/{args.task}")
|
21 |
# Set up logging and name settings
|
|
|
93 |
)
|
94 |
enable_grad = not args.no_optim
|
95 |
|
96 |
+
return args, trainer, device, dtype, shape, enable_grad, settings
|
97 |
+
|
98 |
+
def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, progress_callback=None):
|
99 |
+
#args = parse_args()
|
100 |
if args.task == "single":
|
101 |
init_latents = torch.randn(shape, device=device, dtype=dtype)
|
102 |
latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
|
|
|
274 |
# log total rewards
|
275 |
logging.info(f"Mean initial rewards: {total_init_rewards}")
|
276 |
logging.info(f"Mean best rewards: {total_best_rewards}")
|
277 |
+
|
278 |
+
def main():
|
279 |
args = parse_args()
|
280 |
+
args, trainer, device, dtype, shape, enable_grad, settings = setup(args)
|
281 |
+
execute_task(args, trainer, device, dtype, shape, enable_grad, settings)
|
282 |
+
|
283 |
+
|
284 |
+
if __name__ == "__main__":
|
285 |
+
main()
|