smangrul commited on
Commit
48d3c97
1 Parent(s): f41be3d

updating fsdp docs as per resolution to issue https://github.com/huggingface/accelerate/issues/1054

Browse files
code_samples/training_configuration/pytorch_fsdp CHANGED
@@ -27,16 +27,27 @@ use_cpu: false
27
  <pre>
28
  from accelerate import Accelerator
29
 
30
- accelerator = Accelerator()
31
- - model, optimizer, dataloader, scheduler = accelerator.prepare(
32
- - model, optimizer, dataloader, scheduler
33
- -)
34
- +model = accelerator.prepare(model)
35
- +# Optimizer can be any PyTorch optimizer class
36
- +optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
37
- +optimizer, dataloader, scheduler = accelerator.prepare(
38
- + optimizer, dataloader, scheduler
39
- +)
 
 
 
 
 
 
 
 
 
 
 
40
  </pre>
41
  ##
42
  If the YAML was generated through the `accelerate config` command:
 
27
  <pre>
28
  from accelerate import Accelerator
29
 
30
+ def main():
31
+ accelerator = Accelerator()
32
+ - model, optimizer, dataloader, scheduler = accelerator.prepare(
33
+ - model, optimizer, dataloader, scheduler
34
+ - )
35
+ + model = accelerator.prepare(model)
36
+ + # Optimizer can be any PyTorch optimizer class
37
+ + optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
38
+ + optimizer, dataloader, scheduler = accelerator.prepare(
39
+ + optimizer, dataloader, scheduler
40
+ + )
41
+
42
+ ...
43
+
44
+ accelerator.unwrap_model(model).save_pretrained(
45
+ args.output_dir,
46
+ is_main_process=accelerator.is_main_process,
47
+ save_function=accelerator.save,
48
+ + state_dict=accelerator.get_state_dict(model)
49
+ )
50
+ ...
51
  </pre>
52
  ##
53
  If the YAML was generated through the `accelerate config` command: