##
<pre>
import evaluate
import torch
+from accelerate import Accelerator
+accelerator = Accelerator()
+train_dataloader, eval_dataloader, model, optimizer, scheduler = (
+    accelerator.prepare(
+        train_dataloader, eval_dataloader,
+        model, optimizer, scheduler
+    )
+)
metric = evaluate.load("accuracy")
for batch in train_dataloader:
    optimizer.zero_grad()
    inputs, targets = batch
-   inputs = inputs.to(device)
-   targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
-   loss.backward()
+   accelerator.backward(loss)
    optimizer.step()
    scheduler.step()

model.eval()
for batch in eval_dataloader:
    inputs, targets = batch
-   inputs = inputs.to(device)
-   targets = targets.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    predictions = outputs.argmax(dim=-1)
+   predictions, targets = accelerator.gather_for_metrics(
+       (predictions, targets)
+   )
    metric.add_batch(
        predictions=predictions,
        references=targets
    )
print(metric.compute())</pre>
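
The `metric.add_batch(...)` calls in the loop above only accumulate predictions and references; `metric.compute()` then evaluates the metric once over everything collected. The sketch below mimics that accumulate-then-compute pattern with a hypothetical `AccuracyAccumulator` class in plain Python. It is an illustration only, not the `evaluate` library's implementation.

```python
from typing import Dict, List


class AccuracyAccumulator:
    """Hypothetical stand-in for an accuracy metric: collect batches, compute once."""

    def __init__(self) -> None:
        self._predictions: List[int] = []
        self._references: List[int] = []

    def add_batch(self, predictions: List[int], references: List[int]) -> None:
        # Only store the values; nothing is computed per batch.
        if len(predictions) != len(references):
            raise ValueError("predictions and references must have the same length")
        self._predictions.extend(predictions)
        self._references.extend(references)

    def compute(self) -> Dict[str, float]:
        # Fraction of positions where the prediction equals the reference.
        correct = sum(p == r for p, r in zip(self._predictions, self._references))
        return {"accuracy": correct / len(self._predictions)}


metric = AccuracyAccumulator()
metric.add_batch(predictions=[1, 0, 1], references=[1, 1, 1])
metric.add_batch(predictions=[0, 0], references=[0, 1])
print(metric.compute())  # {'accuracy': 0.6}
```

Because the metric is computed over all accumulated samples, any sample that appears twice in the gathered values is also counted twice.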
##
When calculating metrics on a validation set, you can use the `Accelerator.gather_for_metrics`
method to gather the predictions and references from all processes and then compute the metric on the gathered values.
It also *automatically* drops the extra samples that the prepared dataloader added to the final batch so that every
process receives batches of the same size. As a result, each sample is counted exactly once and the metric is computed on the correct values.
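
To see why dropping these extra samples matters, the sketch below simulates the bookkeeping in plain Python. It assumes a simplified sharding scheme using two hypothetical helpers, `shard_even_batches` and `gather_for_metrics_sim` (not Accelerate functions): each step hands every process a full batch, missing slots in the last step are filled by wrapping around to the start of the dataset, and the gathered values of the final step are truncated to the number of samples not yet seen. Accelerate's real sampler and `gather_for_metrics` differ in their details.

```python
from typing import List


def shard_even_batches(num_samples: int, num_processes: int, batch_size: int) -> List[List[List[int]]]:
    """Hypothetical sharding: returns steps[step][process] -> list of sample indices.

    Every process gets exactly `batch_size` indices per step; slots past the end
    of the dataset wrap around to the start, which creates duplicates.
    """
    global_batch = num_processes * batch_size
    num_steps = -(-num_samples // global_batch)  # ceiling division
    steps = []
    for step in range(num_steps):
        start = step * global_batch
        flat = [(start + i) % num_samples for i in range(global_batch)]
        steps.append([flat[p * batch_size:(p + 1) * batch_size] for p in range(num_processes)])
    return steps


def gather_for_metrics_sim(step_batches: List[List[int]], num_seen: int, num_samples: int) -> List[int]:
    """Hypothetical gather: concatenate one step's per-process batches, then keep
    only as many values as there are samples left, dropping wrap-around duplicates."""
    gathered = [value for batch in step_batches for value in batch]
    remaining = num_samples - num_seen
    return gathered[:remaining]


num_samples, num_processes, batch_size = 10, 2, 4
steps = shard_even_batches(num_samples, num_processes, batch_size)

naive_count = sum(len(batch) for step in steps for batch in step)
collected: List[int] = []
for step in steps:
    collected.extend(gather_for_metrics_sim(step, len(collected), num_samples))

print(naive_count)  # 16
print(collected)    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

With 10 samples, 2 processes and a batch size of 4, a plain gather returns 16 values, 6 of which are duplicates; truncating the last step leaves exactly the 10 original indices.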
##
To learn more, check out the related documentation:
- <a href="https://huggingface.co/docs/accelerate/en/quicktour#distributed-evaluation" target="_blank">Quicktour - Calculating metrics</a>
- <a href="https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.gather_for_metrics" target="_blank">API reference</a>
- <a href="https://github.com/huggingface/accelerate/blob/main/examples/by_feature/multi_process_metrics.py" target="_blank">Example script</a>