File size: 454 Bytes
08efd84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import os

import torch
import torch.distributed as dist


def score_function(engine):
    val_metric = engine.state.metrics["val_mean_dice"]
    if dist.is_initialized():
        device = torch.device("cuda:" + os.environ["LOCAL_RANK"])
        val_metric = torch.tensor([val_metric]).to(device)
        dist.all_reduce(val_metric, op=dist.ReduceOp.SUM)
        val_metric /= dist.get_world_size()
        return val_metric.item()
    return val_metric